In [None]:
from argparse import Namespace
import nonlinear_benchmarks
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import torch
import numpy as np
from model import (
    StateSpaceSimulator,
    NeuralStateUpdate,
    CascadedTanksOverflowNeuralStateSpaceModel,
)
from tqdm import tqdm

In [None]:
# Hyper-parameters for this run (also stored in the checkpoint).
cfg = Namespace(
    n_x=2,  # state dimension (size of x0 and of each simulated state)
    n_feat=32,  # width parameter passed to NeuralStateUpdate
    lr=1e-3,  # AdamW learning rate for the network and x0
    seed=42,  # seed for torch's global RNG
)
# Seed torch before the model is constructed so that any random parameter
# initialization, and therefore the whole training run, is reproducible.
torch.manual_seed(cfg.seed);

In [None]:
# Cascaded Tanks benchmark: estimation split (used for training below) and
# the held-out test split.
train_val, test = nonlinear_benchmarks.Cascaded_Tanks()
# Unpack the input/output signals and make each one a (T, 1) column.
train_u, train_y = (signal.reshape(-1, 1) for signal in train_val)

In [None]:
# Number of input / output channels = size of the last axis of each (T, 1) array.
n_u, n_y = train_u.shape[-1], train_y.shape[-1]

In [None]:
# Z-score input and output independently. The fitted scalers are kept (and
# later stored in the checkpoint) so the normalization can be inverted.
scaler_u, scaler_y = StandardScaler(), StandardScaler()
u = scaler_u.fit_transform(train_u)
y = scaler_y.fit_transform(train_y)

In [None]:
# Learned state-update map, wrapped in the simulator that produces the state
# trajectory used for training.
# Alternative state-update model (swap in to try it):
#   f_xu = CascadedTanksOverflowNeuralStateSpaceModel()
f_xu = NeuralStateUpdate(cfg.n_x, n_u, n_feat=cfg.n_feat)
model = StateSpaceSimulator(f_xu)
# Initial state (batch of 1); a trainable leaf tensor optimized with the model.
x0 = torch.zeros(1, cfg.n_x, requires_grad=True)

In [None]:
# One AdamW optimizer over two parameter groups: the network weights and the
# learnable initial state x0. Both groups inherit the shared lr below; note
# that AdamW's default weight_decay is therefore applied to x0 as well.
opt = torch.optim.AdamW(
    [{"params": model.parameters()}, {"params": [x0]}],
    lr=cfg.lr,
)

In [None]:
# Convert the standardized signals to float32 tensors with a leading batch
# axis: (T, 1) -> (B=1, T, 1).
# The tensors are rebuilt from the fitted scalers rather than from `u`/`y`
# themselves. That makes this cell idempotent: converting `u` in place would
# add another batch axis (and re-wrap a tensor) on every re-run.
u = torch.as_tensor(scaler_u.transform(train_u), dtype=torch.float32).unsqueeze(0)
y = torch.as_tensor(scaler_y.transform(train_y), dtype=torch.float32).unsqueeze(0)

In [None]:
# Simulation-error training: simulate the states from x0 over the whole
# training input and fit the simulated output to the measured one.
# (x_sim is assumed to be (B, T, n_x) -- TODO confirm against model.py.)
N_ITR = 5000  # number of optimizer steps
LOSS = []
for itr in tqdm(range(N_ITR)):
    x_sim = model(x0, u)
    y_pred = x_sim[:, :, [1]]  # output is the second state
    # mse_loss(input, target): prediction first, measurement second.
    loss = torch.nn.functional.mse_loss(y_pred, y)

    # Clear gradients before backward so that a previously interrupted run
    # cannot leave stale gradients that get accumulated into this step.
    opt.zero_grad()
    loss.backward()
    opt.step()

    loss_val = loss.item()
    LOSS.append(loss_val)
    if itr % 100 == 0:
        # tqdm.write prints without breaking the progress bar (print does not).
        tqdm.write(str(loss_val))

In [None]:
# Bundle what is needed to reuse the trained model: the fitted scalers, the
# network weights, the learned initial state, the loss history and the config.
# NOTE(review): the scalers and the Namespace are arbitrary Python objects, so
# torch.load (whose default is weights_only=True from torch 2.6) needs
# weights_only=False to read this file. Only do that for trusted checkpoints.
checkpoint = {
    "scaler_u": scaler_u,
    "scaler_y": scaler_y,
    "model": model.state_dict(),
    # Store a detached value rather than the trainable autograd leaf. A loaded
    # tensor with requires_grad=True cannot be passed to .numpy() and would
    # drag grad tracking into inference.
    "x0": x0.detach(),
    "LOSS": np.array(LOSS),
    "cfg": cfg,
}

torch.save(checkpoint, "ckpt_model2.pt")

In [None]:
# Training-loss curve. The loss is an MSE on z-scored data, so it is unitless.
fig, ax = plt.subplots()
ax.plot(LOSS)
ax.set(
    xlabel="iteration",
    ylabel="training MSE (standardized)",
    title="Training loss",
    ylim=(0, 0.2),  # zoom on the low-loss region; larger early values may be clipped
)
plt.show()