In [2]:
import torch
import torch.nn as nn

In [3]:
hidden_size = 10
net = nn.Sequential(
            nn.Linear(1+hidden_size, 50),
            nn.Tanh(),
            nn.Linear(50, hidden_size),
        )

In [57]:
ms = []
for m in net.modules():
    ms.append(m)

In [61]:
ms[1].in_features

4

In [62]:
ms[3].out_features

10

In [65]:
ms[2]

Tanh()

In [66]:
shapes = []
for i,p in enumerate(net.parameters()):
    print(p)

Parameter containing:
tensor([[ 0.2431,  0.1728, -0.1301,  0.1556, -0.2487,  0.1381,  0.2616,  0.2980,
         -0.1954, -0.1844,  0.0204],
        [ 0.1235, -0.1747,  0.1130,  0.0724,  0.2134, -0.2320, -0.2578,  0.2722,
          0.2914,  0.0156,  0.2283],
        [-0.1328,  0.1423, -0.0695, -0.2352,  0.1493,  0.2161,  0.1734, -0.2429,
         -0.0385,  0.1320, -0.2446],
        [ 0.2692,  0.1658, -0.0068,  0.2027, -0.0337, -0.1300,  0.1165, -0.1442,
         -0.0568, -0.0525, -0.0513],
        [ 0.0882,  0.2568,  0.2755,  0.2468, -0.2314,  0.1749,  0.2672, -0.2952,
          0.1497, -0.1292, -0.0819],
        [ 0.2630, -0.1544, -0.2802, -0.0734,  0.0845,  0.0382,  0.0081, -0.0770,
         -0.1221,  0.0240,  0.2368],
        [-0.0731,  0.0697, -0.1398, -0.3009, -0.2604,  0.0865, -0.0176,  0.2259,
         -0.1398,  0.0226,  0.0956],
        [-0.2079,  0.0529, -0.0756,  0.1801,  0.1422,  0.1234,  0.2019,  0.0340,
         -0.1446, -0.1416,  0.1307],
        [ 0.2119, -0.1792,  0.2025

In [50]:
for i,m in enumerate(net.named_modules()):
    print(i,m)

0 ('', Sequential(
  (0): Linear(in_features=11, out_features=50, bias=True)
  (1): Tanh()
  (2): Linear(in_features=50, out_features=10, bias=True)
))
1 ('0', Linear(in_features=11, out_features=50, bias=True))
2 ('1', Tanh())
3 ('2', Linear(in_features=50, out_features=10, bias=True))


In [48]:
net.named_modules()

<generator object Module.named_modules at 0x7f9aa433e510>

In [35]:
net.parameters

<bound method Module.parameters of Sequential(
  (0): Linear(in_features=11, out_features=50, bias=True)
  (1): Tanh()
  (2): Linear(in_features=50, out_features=10, bias=True)
)>

In [30]:
net.parameters

48

In [6]:
import inspect

In [10]:
type(net) == nn.Sequential

True

In [12]:
inspect.getfullargspec(net.forward).args[1:]

['input']

In [None]:
class NeuralODEfromSequential(nn.Module):
    def __init__(self,net:nn.Sequential,time_dependent=False,data_dependent=False):

        self.net = net
        self.has_t_arg = time_dependent
        self.has_input_arg = data_dependent

    def forward(self,hidden,t):
        if self.has_input_arg and self.has_t_arg:
            # dh/ds = f(h,t,x)
            z = torch.cat((hidden,t,input),1)
            output = net.forward(z)
        elif self.has_input_arg and not self.has_t_arg:
            # dh/ds=f(h,x)
            z = torch.cat((hidden,input),1)
            output = net.forward(z)
        elif not self.has_input_arg and self.has_t_arg:
            # dh/ds=f(h,t)
            z = torch.cat((hidden,t),1)
            output = net.forward(z)
        else:
            # dh/ds=f(h)
            output = net.forward(hidden)
        return output