## Testing the interplay of the `NeuralDiffEq` object with `torchsde` and `torchdiffeq`

We focus on two types of neural networks, both implemented in PyTorch using `torch.nn.Sequential`:

1. **Potential nets**, where the input has the dimensionality of the data and the output is a single scalar (the potential). Under the hood, in this implementation of neural differential equations, the value handed to the solver is the gradient of that potential with respect to the input position, which maps the scalar output back to the input dimension of the data.
2. **Forward nets**, where the input and output have the same dimension; this is the canonical form of the function passed to a neural DE.
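The gradient step described for potential nets can be sketched with plain PyTorch autograd. This is a minimal illustration, not the `neural_diffeqs` implementation: the `potential` network and the `potential_drift` helper are hypothetical names, the layer sizes only mirror the printouts in this notebook, and whether the library uses the gradient or its negative is not shown here.

```python
import torch

# Hypothetical stand-in for a potential net: 50-dimensional input, scalar output.
potential = torch.nn.Sequential(
    torch.nn.Linear(50, 400),
    torch.nn.Tanh(),
    torch.nn.Linear(400, 1),
)


def potential_drift(net, x):
    """Return the gradient of the scalar potential with respect to x."""
    # Work on a fresh leaf tensor so autograd can differentiate with respect to it.
    x = x.detach().requires_grad_(True)
    psi = net(x)  # shape: (batch, 1)
    # Samples are independent, so the gradient of the batch sum is the
    # per-sample gradient stacked along the batch dimension.
    (grad,) = torch.autograd.grad(psi.sum(), x, create_graph=True)
    return grad  # shape: (batch, 50), same as the input


x = torch.randn(200, 50)
drift = potential_drift(potential, x)
print(drift.shape)
```

The returned tensor has the same shape as the input, which is why a network with `out_features=1` can still drive a 50-dimensional state.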

### Import packages and generate some dummy data

In [1]:
import neural_diffeqs
import torch
from torchdiffeq import odeint
from torchsde import sdeint

X0 = torch.randn([200, 50])
t = torch.Tensor([0, 1, 2])
X0.shape, t.shape

(torch.Size([200, 50]), torch.Size([3]))

#### SDE with `mu` as a potential net

This is the most likely use-case.

In [2]:
SDE = neural_diffeqs.neural_diffeq(mu_dropout=False, sigma_dropout=False)
SDE

NeuralDiffEq(
  (mu): Sequential(
    (input_layer): Linear(in_features=50, out_features=400, bias=True)
    (activation_1): Tanh()
    (hidden_1): Linear(in_features=400, out_features=400, bias=True)
    (activation_2): Tanh()
    (output_layer): Linear(in_features=400, out_features=1, bias=True)
  )
  (sigma): Sequential(
    (input_layer): Linear(in_features=50, out_features=400, bias=True)
    (activation_1): Tanh()
    (hidden_1): Linear(in_features=400, out_features=400, bias=True)
    (activation_2): Tanh()
    (output_layer): Linear(in_features=400, out_features=50, bias=True)
  )
)

In [3]:
X_pred = sdeint(SDE, X0, t)
X_pred.shape

torch.Size([3, 200, 50])

#### SDE with `mu` and `sigma` as potential nets

I'm still not sure whether this scenario is necessary, or even relevant to a realistic drift-diffusion equation, but I am testing it nonetheless. It is unclear to me whether the gradient of the potential is encoded and discovered through the drift (`mu`) alone, or whether the diffusion term (`sigma`) also contains components that are gradients of some function. The latter may be the case: one can imagine the magnitude of diffusion dropping off along a gradient.
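To make the scenario concrete, here is a rough sketch of an SDE object in which both the drift and the diffusion are gradients of separate scalar potentials. The class `PotentialSDE` and its helpers are hypothetical and are not taken from `neural_diffeqs`. The only `torchsde` conventions assumed are the `noise_type` and `sde_type` attributes and the `f(t, y)` / `g(t, y)` methods; with diagonal noise, `g` is expected to return a tensor shaped like the state.

```python
import torch


class PotentialSDE(torch.nn.Module):
    # Attributes that torchsde reads from the SDE object.
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, dim=50, hidden=400):
        super().__init__()
        # Two independent scalar-valued networks (assumption for illustration).
        self.mu = self._potential(dim, hidden)
        self.sigma = self._potential(dim, hidden)

    @staticmethod
    def _potential(dim, hidden):
        return torch.nn.Sequential(
            torch.nn.Linear(dim, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, 1),
        )

    @staticmethod
    def _grad(net, y):
        # Enable autograd explicitly in case the caller runs under no_grad.
        with torch.enable_grad():
            y = y.detach().requires_grad_(True)
            (grad,) = torch.autograd.grad(net(y).sum(), y, create_graph=True)
        return grad

    def f(self, t, y):
        # Drift: gradient of the mu potential, shape (batch, dim).
        return self._grad(self.mu, y)

    def g(self, t, y):
        # Diagonal diffusion: gradient of the sigma potential, shape (batch, dim).
        return self._grad(self.sigma, y)


sde = PotentialSDE()
y0 = torch.randn(200, 50)
drift, diffusion = sde.f(0.0, y0), sde.g(0.0, y0)
print(drift.shape, diffusion.shape)
```

Both terms come out with the shape of the state, so an object like this should be usable with `torchsde.sdeint(sde, y0, t)`. Whether a gradient-derived diffusion is physically meaningful is exactly the open question raised above.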

In [4]:
SDE = neural_diffeqs.neural_diffeq(mu_potential=True, mu_dropout=False, sigma_potential=True, sigma_dropout=False)
SDE

NeuralDiffEq(
  (mu): Sequential(
    (input_layer): Linear(in_features=50, out_features=400, bias=True)
    (activation_1): Tanh()
    (hidden_1): Linear(in_features=400, out_features=400, bias=True)
    (activation_2): Tanh()
    (output_layer): Linear(in_features=400, out_features=1, bias=True)
  )
  (sigma): Sequential(
    (input_layer): Linear(in_features=50, out_features=400, bias=True)
    (activation_1): Tanh()
    (hidden_1): Linear(in_features=400, out_features=400, bias=True)
    (activation_2): Tanh()
    (output_layer): Linear(in_features=400, out_features=1, bias=True)
  )
)

In [5]:
X_pred = sdeint(SDE, X0,  t)
X_pred.shape

torch.Size([3, 200, 50])

#### Neural ODE with `mu` as a potential net

In [6]:
ODE = neural_diffeqs.neural_diffeq(mu_dropout=False, mu_potential=True, sigma_hidden=False)
ODE

NeuralDiffEq(
  (mu): Sequential(
    (input_layer): Linear(in_features=50, out_features=400, bias=True)
    (activation_1): Tanh()
    (hidden_1): Linear(in_features=400, out_features=400, bias=True)
    (activation_2): Tanh()
    (output_layer): Linear(in_features=400, out_features=1, bias=True)
  )
)

In [7]:
X_pred = odeint(ODE, X0, t)
X_pred.shape

torch.Size([3, 200, 50])

#### Neural ODE with `mu` as a forward net

In [8]:
ODE = neural_diffeqs.neural_diffeq(mu_dropout=False, mu_potential=False, sigma_hidden=False)
ODE

NeuralDiffEq(
  (mu): Sequential(
    (input_layer): Linear(in_features=50, out_features=400, bias=True)
    (activation_1): Tanh()
    (hidden_1): Linear(in_features=400, out_features=400, bias=True)
    (activation_2): Tanh()
    (output_layer): Linear(in_features=400, out_features=50, bias=True)
  )
)

In [9]:
X_pred = odeint(ODE, X0, t)
X_pred.shape

torch.Size([3, 200, 50])
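The `[3, 200, 50]` shape returned in each test is one state of shape `[200, 50]` per entry in `t`. The sketch below reproduces that layout with a forward net and a hand-written fixed-step Euler loop. Both `ForwardFunc` and `euler_odeint` are hypothetical stand-ins for illustration only: `torchdiffeq.odeint` defaults to an adaptive solver, so its values would differ, but it uses the same `func(t, y)` calling convention and returns the same time-first layout.

```python
import torch


class ForwardFunc(torch.nn.Module):
    """Forward net: input and output share the data dimension."""

    def __init__(self, dim=50, hidden=400):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, dim),
        )

    def forward(self, t, y):
        # The (t, y) signature matches what ODE solvers pass in; t is unused here.
        return self.net(y)


def euler_odeint(func, y0, t):
    """Fixed-step Euler integration taking one step between consecutive times."""
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        ys.append(ys[-1] + (t1 - t0) * func(t0, ys[-1]))
    # Stack along a new leading time axis: (len(t), batch, dim).
    return torch.stack(ys)


func = ForwardFunc()
X0 = torch.randn(200, 50)
t = torch.tensor([0.0, 1.0, 2.0])
with torch.no_grad():
    X_pred = euler_odeint(func, X0, t)
print(X_pred.shape)
```

The first slice `X_pred[0]` is the initial condition itself, which is a quick sanity check that also applies to the solver outputs above.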