In [None]:
from pathlib import Path
import time
import torch
import numpy as np
from dataset import LinearDynamicalDataset
from torch.utils.data import DataLoader
from model import GPTConfig, GPT
import tqdm
import argparse
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt

In [None]:
# Overall settings
out_dir = "out"

# System settings
nx = 10
nu = 1
ny = 1
seq_len = 400


# Compute settings
cuda_device = "cuda:0"
no_cuda = False
threads = 20
compile = True  # NOTE: unused while the torch.compile call below is disabled; also shadows the builtin `compile`
batch_size = 256
seed = 42  # seeds torch's global RNG so the drawn test batch is repeatable
           # NOTE(review): full reproducibility depends on how LinearDynamicalDataset draws its randomness — TODO confirm

# Configure compute (done before loading the checkpoint so `device` is known)
torch.manual_seed(seed)
torch.set_num_threads(threads)
use_cuda = not no_cuda and torch.cuda.is_available()
device_name = cuda_device if use_cuda else "cpu"
device = torch.device(device_name)
device_type = 'cuda' if 'cuda' in device_name else 'cpu'  # for later use in torch.autocast
torch.set_float32_matmul_precision("high")
#torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
#torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

# Load the training checkpoint. map_location="cpu" lets a checkpoint written by a
# GPU run be opened on a CPU-only machine; the model is moved to `device` below.
out_dir = Path(out_dir)
exp_data = torch.load(out_dir / "ckpt_lin.pt", map_location="cpu")

# Create data loader
test_ds = LinearDynamicalDataset(nx=nx, nu=nu, ny=ny, seq_len=seq_len)
test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=threads)

# Rebuild the model from the stored hyper-parameters and weights.
model_args = exp_data["model_args"]
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
# Weights saved from a torch.compile'd model carry an '_orig_mod.' key prefix;
# strip it so the keys match the plain (uncompiled) GPT module.
unwanted_prefix = '_orig_mod.'
state_dict = {
    (k[len(unwanted_prefix):] if k.startswith(unwanted_prefix) else k): v
    for k, v in exp_data["model"].items()
}
model.load_state_dict(state_dict)
model = model.to(device)
#if compile:
#    model = torch.compile(model)
model.eval();

In [None]:
# Draw one test batch (the loader yields outputs first, then inputs) and move
# both tensors to the compute device.
batch_y, batch_u = (tensor.to(device) for tensor in next(iter(test_dl)))

In [None]:
# Teacher-forced pass, as in training: the full measured u/y sequences go in a
# single call, so future samples are present in the input and only the model's
# own (presumably causal) masking keeps each position from using them.
# The returned loss is not needed here.
with torch.no_grad():
    batch_y_pred, _train_loss = model(batch_u, batch_y, compute_loss=True)

In [None]:
# Step-by-step causal pass: at step k the model only receives u[0..k] and y[0..k],
# and the step it returns (presumably just the last position) is kept. Stitching
# the steps together should reproduce the teacher-forced output if the model is causal.
seq_len = batch_y.shape[1]
with torch.no_grad():
    step_preds = [
        model(batch_u[:, :k + 1, :], batch_y[:, :k + 1, :], compute_loss=False)[0]
        for k in range(seq_len)
    ]
batch_y_pred_rt = torch.cat(step_preds, dim=1)

In [None]:
# Largest absolute gap between the teacher-forced and the step-by-step causal
# predictions; expected to be ~0 (floating-point noise) for a causal model.
(batch_y_pred - batch_y_pred_rt).abs().max().item()

In [None]:
# Open-loop simulation: measured outputs are used only before `sim_start`; from
# there on each new output is predicted from the model's own previous outputs.
sim_start = 100
batch_y_sim = torch.cat(
    (batch_y[:, :sim_start, :], torch.zeros_like(batch_y[:, sim_start:, :])),
    dim=1,
)
with torch.no_grad():
    for k in range(sim_start, seq_len):
        y_next, _ = model(batch_u[:, :k, :], batch_y_sim[:, :k, :], compute_loss=False)
        batch_y_sim[:, k:k + 1, :] = y_next

In [None]:
# Move results to host memory as NumPy arrays for metrics and plotting.
# Idempotent: values that are already NumPy arrays pass through unchanged, so
# re-running this cell no longer fails with AttributeError on `.detach`.
def _as_numpy(x):
    """Return `x` as a NumPy array, detaching and copying to CPU if it is a torch tensor."""
    return x.detach().cpu().numpy() if torch.is_tensor(x) else x

batch_y_pred = _as_numpy(batch_y_pred)
batch_y_pred_rt = _as_numpy(batch_y_pred_rt)
batch_y_sim = _as_numpy(batch_y_sim)
batch_y = _as_numpy(batch_y)

In [None]:
# Align model outputs with their targets: output position k is the estimate of
# y_{k+1}, so the targets are the measured outputs shifted forward by one step.
# The slicing below overwrites batch_y_pred / batch_y_sim, so running this cell
# twice would silently shift them again — refuse instead.
assert batch_y_pred.shape[1] == batch_y.shape[1] and batch_y_sim.shape[1] == batch_y.shape[1], (
    "batch_y_pred / batch_y_sim are already aligned; re-run the notebook from the top"
)
batch_y_target = batch_y[:, 1:, :]  # target @ time k: y_{k+1}
batch_y_pred = batch_y_pred[:, :-1, :]  # prediction @ time k: y_{k+1|k}
batch_y_sim = batch_y_sim[:, 1:, :]  # simulation @ time k: simulated y_{k+1} (open loop from sim_start on)
batch_pred_err = batch_y_target - batch_y_pred
batch_sim_err = batch_y_target - batch_y_sim  # previously (wrongly) computed from batch_y_pred

In [None]:
# Measured output vs. open-loop simulation for one sequence of the batch.
instance = 1
# After alignment, index k of batch_y_target / batch_y_sim holds y_{k+1}. Plotting
# against the true time index k+1 puts the `sim_start` marker on the first simulated
# sample; against the raw array index it landed one step late.
time_idx = np.arange(1, batch_y_target.shape[1] + 1)
fig = go.Figure()
fig.add_trace(go.Scatter(x=time_idx, y=batch_y_target[instance].squeeze(), name="y", line_color="black"))
fig.add_trace(go.Scatter(x=time_idx, y=batch_y_sim[instance].squeeze(), name="y_sim", line_color="blue"))
fig.add_vline(x=sim_start, line_color="red", name="sim_start")
fig.update_layout(xaxis_title="time step", yaxis_title="y")

In [None]:
from torchid import metrics

# Score only the part of the horizon after the warm-up: the first `sim_start`
# aligned steps (which include the window where batch_y_sim copies measured
# outputs) are skipped. One RMSE value per sequence (time_axis=1).
skip = sim_start
window = (slice(None), slice(skip, None), slice(None))
rmse_pred = metrics.rmse(batch_y_target[window], batch_y_pred[window], time_axis=1)
rmse_sim = metrics.rmse(batch_y_target[window], batch_y_sim[window], time_axis=1)

In [None]:
# Batch-averaged RMSE: (one-step-ahead prediction, open-loop simulation).
tuple(r.mean() for r in (rmse_pred, rmse_sim))

In [None]:
# Distribution of per-sequence RMSE over the test batch. Both series share the same
# bin edges and are semi-transparent; with independent default bins and opaque bars
# the "pred" histogram covered the "sim" one and the two were not comparable.
# Inputs are raveled so several output channels (ny > 1) are pooled into one series.
rmse_sim_flat = np.ravel(rmse_sim)
rmse_pred_flat = np.ravel(rmse_pred)
bins = np.histogram_bin_edges(np.concatenate([rmse_sim_flat, rmse_pred_flat]), bins=10)
fig, ax = plt.subplots()
ax.hist(rmse_sim_flat, bins=bins, color="black", alpha=0.5, label="sim")
ax.hist(rmse_pred_flat, bins=bins, color="red", alpha=0.5, label="pred")
ax.set(title="RMSE", xlabel="RMSE", ylabel="count")
ax.legend();

In [None]:
# Side-by-side boxplots of per-sequence RMSE: one-step prediction vs. simulation.
fig, ax = plt.subplots()
ax.boxplot([rmse_pred.ravel(), rmse_sim.ravel()])
# Tick labels are set separately: boxplot's `labels=` keyword is deprecated since
# Matplotlib 3.9 (renamed `tick_labels`), while this form works on all versions.
ax.set_xticklabels(["pred", "sim"])
ax.set(title="RMSE", ylabel="RMSE");