In [None]:
import torch
import torch.nn as nn

# Utilities for neural state-space model learning

In [None]:
# Adapted from pytorch-ident, https://github.com/forgi86/pytorch-ident/blob/master/torchid/ss/dt/models.py
# Implements a state-update function f(x, u)
class NeuralStateUpdate(nn.Module):
    """Neural state-update map f(x, u) of a discrete-time state-space model.

    A one-hidden-layer tanh MLP maps the concatenation [x, u] to a state
    increment dx of the same size as x. Every Linear weight is drawn from a
    normal distribution with mean 0 and std 1e-2, and every bias is zeroed,
    so the predicted increments are small at initialisation.

    Args:
        n_x: state dimension.
        n_u: input dimension.
        n_feat: number of hidden units.
    """

    def __init__(self, n_x=2, n_u=1, n_feat=32):
        super().__init__()

        hidden = nn.Linear(n_x + n_u, n_feat)
        readout = nn.Linear(n_feat, n_x)
        self.net = nn.Sequential(hidden, nn.Tanh(), readout)

        # Small weights and zero biases keep dx close to zero at the start of training.
        for layer in (hidden, readout):
            nn.init.normal_(layer.weight, mean=0, std=1e-2)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, x, u):
        """Return the state increment dx = f(x, u).

        x and u are concatenated along their last dimension, so all leading
        dimensions must agree. The result keeps those leading dimensions and
        has trailing size n_x.
        """
        return self.net(torch.cat([x, u], dim=-1))

In [None]:
# Adapted from pytorch-ident, https://github.com/forgi86/pytorch-ident/blob/master/torchid/ss/dt/models.py
# Implements the output function g(x). In fact, it is just a standard feedforward net
class NeuralOutput(nn.Module):
    """Neural output map y = g(x) of a state-space model.

    A one-hidden-layer tanh MLP from the state (size n_x) to the output
    (size n_y). No custom initialisation is applied: the layers keep
    PyTorch's default nn.Linear initialisation.

    Args:
        n_x: state dimension.
        n_y: output dimension.
        n_feat: number of hidden units.
    """

    def __init__(self, n_x, n_y, n_feat=32):
        super().__init__()
        layers = [nn.Linear(n_x, n_feat), nn.Tanh(), nn.Linear(n_feat, n_y)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map states with trailing size n_x to outputs with trailing size n_y."""
        return self.net(x)

In [None]:
# Adapted from pytorch-ident, https://github.com/forgi86/pytorch-ident/blob/master/torchid/ss/dt/models.py
# Unrolls the state-update function f(x,u) over time, starting from an initial state x_0

class StateSpaceSimulator(nn.Module):
    """Simulate a discrete-time state-space model by unrolling f_xu over time.

    The recursion is x[t+1] = x[t] + f_xu(x[t], u[t]), starting from x[0] = x_0.

    Args:
        f_xu: callable (x, u) -> dx returning a state increment shaped like x,
            e.g. a NeuralStateUpdate.
    """

    def __init__(self, f_xu):
        super().__init__()
        self.f_xu = f_xu

    def forward(self, x_0, u):
        """Return the simulated state sequence.

        Args:
            x_0: initial state, shape (B, n_x).
            u: input sequence, shape (B, T, n_u).

        Returns:
            Tensor of shape (B, T, n_x) whose slice [:, t, :] is the state at
            time t, i.e. *before* the update driven by u[:, t, :]. It has the
            same dtype and device as x_0.
        """
        B, n_x = x_0.shape
        _, T, _ = u.shape  # B, T, n_u

        if T == 0:
            # torch.stack cannot stack an empty list; return an empty trajectory
            # with the right shape, dtype and device instead.
            return x_0.new_empty((B, 0, n_x))

        # Collect the states and stack them at the end. Unlike writing into a
        # freshly allocated torch.empty buffer, this keeps x_0's dtype and device
        # and avoids in-place writes into a tensor that is part of the autograd graph.
        states = []
        x_step = x_0
        for t in range(T):
            states.append(x_step)
            x_step = x_step + self.f_xu(x_step, u[:, t, :])

        return torch.stack(states, dim=1)

## Basic usage

In [None]:
# Fix the RNG seed so the weight initialisation and the random inputs used in the
# cells below are reproducible under "Restart & Run All".
torch.manual_seed(0)

n_x, n_u = 2, 1  # state and input dimensions
f_xu = NeuralStateUpdate(n_x, n_u, n_feat=32)
simulator = StateSpaceSimulator(f_xu)

In [None]:
# A single state update x_new = x + f_xu(x, u), applied by hand to a random batch.
B = 32
batch_x = torch.randn(B, n_x)   # (B, n_x) current states
batch_u = torch.randn(B, n_u)   # (B, n_u) current inputs

batch_dx = f_xu(batch_x, batch_u)   # (B, n_x) predicted increments
batch_x_new = batch_x + batch_dx    # (B, n_x) next states

In [None]:
# The simulator applies the same update repeatedly along a whole input sequence.
B = 32
T = 1024
batch_x0 = torch.zeros(B, n_x)     # (B, n_x) initial states
batch_u = torch.randn(B, T, n_u)   # (B, T, n_u) placeholder inputs; replace with actual training input
batch_x_sim = simulator(batch_x0, batch_u)
batch_x_sim.shape  # (B, T, n_x)

In [None]:
# Overall, the simulator is a custom recurrent network. Its interface mirrors
# torch.nn.RNN with batch_first=True: input (B, T, n_u) -> state sequence (B, T, n_x).
# Differences: nn.RNN updates h_t = tanh(W_ih u_t + b_ih + W_hh h_{t-1} + b_hh)
# instead of x + f_xu(x, u). Its output at step t is the state *after* consuming u_t,
# whereas `simulator` returns the state *before* the update with u[:, t, :].
rnn = torch.nn.RNN(input_size=n_u, hidden_size=n_x, batch_first=True, num_layers=1)
batch_h, _ = rnn(batch_u)  # returns the full hidden-state sequence and the final hidden state
batch_h.shape

In [None]:
# An explicit initial state must be shaped (num_layers, B, hidden_size) = (1, B, n_x),
# hence the leading singleton axis. batch_x0 is all zeros, which is also nn.RNN's default h_0.
batch_h, _ = rnn(batch_u, batch_x0.unsqueeze(0))
batch_h.shape

## Example: cascaded two-tank system

In [None]:
# Placeholder data: single input, single output, 2-dimensional state.
n_x, n_u, n_y = 2, 1, 1
B = 1     # just one sequence
T = 1024  # sequence length
u = torch.randn((B, T, n_u))  # replace with actual training input
y = torch.randn((B, T, n_y))  # replace with actual training output

In [None]:
# Model components. The initial state x0 is not measured, so it is a leaf tensor
# with requires_grad=True and is trained together with the network weights.
f_xu = NeuralStateUpdate(n_x, n_u, n_feat=32)
g_x = NeuralOutput(n_x, n_y)
simulator = StateSpaceSimulator(f_xu)
x0 = torch.zeros((B, n_x), requires_grad=True)

In [None]:
# Jointly optimise the state-update network, the output network and the initial state.
# Separate parameter groups let each component get its own "lr" (or other hyper-parameter)
# later; for now all groups use the shared default below.
# NOTE(review): AdamW's default weight_decay (0.01) also applies to x0, pulling it toward
# zero; pass "weight_decay": 0.0 in its group if that is not intended.
LR = 1e-3
opt = torch.optim.AdamW(
    [
        {"params": f_xu.parameters()},
        {"params": g_x.parameters()},
        {"params": [x0]},
    ],
    lr=LR,
)

In [None]:
# Simulate from the learnable x0, map states to outputs and compare with the measurements.
x_sim = simulator(x0, u)  # (B, T, n_x)
y_sim = g_x(x_sim)        # (B, T, n_y)
# F.mse_loss expects (input, target): the prediction comes first, the data second.
loss = torch.nn.functional.mse_loss(y_sim, y)
loss