In [None]:
import torch
import torch.nn as nn

In [None]:
# Taken from pytorch-ident, https://github.com/forgi86/pytorch-ident/blob/master/torchid/ss/dt/simulator.py

class StateSpaceSimulator(nn.Module):
    r""" Discrete-time state-space simulator.

    Rolls out the recursion :math:`x_{t+1} = x_t + f_{xu}(x_t, u_t)` from the
    initial state ``x_0``. The state is stored *before* each update, so the
    first simulated state is ``x_0`` itself and the state produced by the last
    input sample is not returned.

    Args:
        f_xu (nn.Module): The neural state-update model, called as
            ``f_xu(x, u)``; its output is added to the current state.
        batch_first (bool): If True (default), the time axis of ``u`` and of
            the output is dimension 1; otherwise it is dimension 0.

    Inputs: x_0, u
        * **x_0**: tensor of shape :math:`(N, n_{x})` containing the
          initial state for each element in the batch. It is required.
        * **u**: tensor of shape :math:`(N, L, n_{u})` when ``batch_first=True`` or
          :math:`(L, N, n_{u})` when ``batch_first=False`` containing the input sequence

    Outputs: x
        * **x**: tensor of shape :math:`(N, L, n_{x})` when ``batch_first=True`` or
          :math:`(L, N, n_{x})` when ``batch_first=False`` corresponding to
          the simulated state sequence. An empty input sequence
          (:math:`L = 0`) yields an empty state sequence.

    Examples::

        >>> f_xu = NeuralStateUpdate(n_x=3, n_u=2)
        >>> nn_solution = StateSpaceSimulator(f_xu)  # batch_first=True
        >>> x0 = torch.randn(64, 3)
        >>> u = torch.randn(64, 100, 2)
        >>> x = nn_solution(x0, u)
        >>> print(x.size())
        torch.Size([64, 100, 3])
     """

    def __init__(self, f_xu, batch_first=True):
        super().__init__()
        self.f_xu = f_xu
        self.batch_first = batch_first

    def forward(self, x_0, u):
        dim_time = 1 if self.batch_first else 0

        # torch.stack cannot stack an empty list, so an empty input sequence is
        # answered directly with an empty trajectory of matching dtype/device.
        if u.shape[dim_time] == 0:
            empty_shape = list(x_0.shape)
            empty_shape.insert(dim_time, 0)
            return x_0.new_empty(empty_shape)

        x_step = x_0
        x = []
        for u_step in u.unbind(dim_time):  # one input sample per time step
            x.append(x_step)  # store the state before applying the update
            dx = self.f_xu(x_step, u_step)
            x_step = x_step + dx

        return torch.stack(x, dim_time)

In [None]:

class StateSpaceSimulatorBasic(nn.Module):
    r"""Discrete-time state-space simulator that fills a preallocated tensor.

    Rolls out :math:`x_{t+1} = x_t + f_{xu}(x_t, u_t)` from ``x_0`` and stores
    the state *before* each update, so the first stored state is ``x_0``.

    Args:
        f_xu (callable): State-update model, called as ``f_xu(x, u)``; its
            output is added to the current state.
        batch_first (bool): If True (default), the time axis of ``u`` and of
            the output is dimension 1; otherwise it is dimension 0.

    Inputs: x_0, u
        * **x_0**: tensor of shape :math:`(N, n_{x})` with the initial state.
        * **u**: tensor of shape :math:`(N, L, n_{u})` when ``batch_first=True``
          or :math:`(L, N, n_{u})` when ``batch_first=False``.

    Outputs: x
        * **x**: tensor of shape :math:`(N, L, n_{x})` when ``batch_first=True``
          or :math:`(L, N, n_{x})` when ``batch_first=False``, allocated with
          the dtype and device of ``x_0``.
    """

    def __init__(self, f_xu, batch_first=True):
        super().__init__()
        self.f_xu = f_xu
        self.batch_first = batch_first

    def forward(self, x_0, u):
        dim_time = 1 if self.batch_first else 0
        n_batch, n_x = x_0.shape
        n_steps = u.shape[dim_time]

        if self.batch_first:
            out_shape = (n_batch, n_steps, n_x)
        else:
            out_shape = (n_steps, n_batch, n_x)
        # new_empty inherits dtype and device from x_0, so the simulator also
        # works for non-default dtypes and for tensors living on an accelerator.
        x = x_0.new_empty(out_shape)

        x_step = x_0
        for t in range(n_steps):
            # Integer selection drops the time axis, so the (N, n_x) state fits
            # for any batch size (a list index would keep a length-1 time axis
            # and only broadcast when N == 1).
            x.select(dim_time, t).copy_(x_step)
            dx = self.f_xu(x_step, u.select(dim_time, t))
            x_step = x_step + dx

        return x



In [None]:
# Taken from pytorch-ident, https://github.com/forgi86/pytorch-ident/blob/master/torchid/ss/dt/models.py
class NeuralStateUpdate(nn.Module):
    """One-hidden-layer MLP mapping a state/input pair to a state increment.

    Args:
        n_x (int): Number of state variables (also the output size).
        n_u (int): Number of input variables.
        n_feat (int): Width of the hidden tanh layer.

    All linear layers start with weights drawn from a normal distribution
    with zero mean and standard deviation 1e-2, and with zero biases.
    """

    def __init__(self, n_x=2, n_u=1, n_feat=32):
        super().__init__()

        # Build both layers first, then re-initialise them, so that the random
        # number generator is consumed in the same order as layer construction
        # followed by a pass over the linear layers.
        hidden = nn.Linear(n_x + n_u, n_feat)
        readout = nn.Linear(n_feat, n_x)
        self.net = nn.Sequential(hidden, nn.Tanh(), readout)

        for layer in (hidden, readout):
            nn.init.normal_(layer.weight, mean=0, std=1e-2)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, x, u):
        """Return the increment ``dx`` for state ``x`` and input ``u``.

        ``x`` and ``u`` are concatenated along their last dimension before
        being passed through the network.
        """
        return self.net(torch.cat((x, u), dim=-1))

In [None]:
# Problem dimensions: number of outputs, states and inputs.
n_y, n_x, n_u = 1, 2, 1
# Batch size: a single sequence.
B = 1
# Random placeholders of length 1024; replace with the actual training data.
u = torch.randn(B, 1024, n_u)  # training input
y = torch.randn(B, 1024, n_y)  # training output

In [None]:
# The initial state is a leaf tensor with gradients enabled, so it can be
# optimised together with the network weights.
x0 = torch.zeros(B, n_x, requires_grad=True)
# State-update network and the simulator that rolls it out over time.
f_xu = NeuralStateUpdate(n_x=n_x, n_u=n_u, n_feat=32)
simulator = StateSpaceSimulator(f_xu)
# TODO: define the output map, an MLP with n_x inputs and n_y outputs:
# g_x = NeuralOutput(n_x, n_y, n_feat=32)

In [None]:
# One parameter group per trainable component: the state-update network and
# the initial state. The trailing lr is the default for groups without one.
opt = torch.optim.AdamW(
    [
        dict(params=f_xu.parameters(), lr=1e-3),
        # dict(params=g_x.parameters(), lr=1e-3),  # enable once g_x is defined
        dict(params=[x0], lr=1e-3),
    ],
    lr=1e-3,
)
# Single-group alternative once g_x exists:
# opt = torch.optim.AdamW(list(f_xu.parameters()) + list(g_x.parameters()) + [x0], 1e-3)

In [None]:
# Roll the model out over the whole input sequence; x_sim is (B, T, n_x).
x_sim = simulator(x_0=x0, u=u)
# TODO: once the output map g_x is defined, compute the fit loss:
# y_sim = g_x(x_sim)  # (B, T, n_y)
# loss = torch.nn.functional.mse_loss(y, y_sim)