In [None]:
from pathlib import Path
import time
import torch
import numpy as np
import pandas as pd
from dataset import LinearDynamicalDataset, WHDataset, PWHDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from model import GPTConfig, GPT
import tqdm
import argparse

In [None]:
# General settings
out_dir = "out"  # folder the checkpoint is loaded from

# Problem dimensions (forwarded to the dataset constructor)
nu = 1
ny = 1
#seq_len = 600
batch_size = 32 # 256


# Hardware settings
cuda_device = "cuda:2"  # used only when CUDA is enabled and available
no_cuda = False  # set to True to force execution on the CPU
threads = 5  # CPU threads for torch (also used as number of data-loader workers)
compile = True  # NOTE(review): shadows the built-in `compile`; not used in the visible cells

# Apply the hardware settings
torch.set_num_threads(threads)
use_cuda = (not no_cuda) and torch.cuda.is_available()
if use_cuda:
    device_name = cuda_device
else:
    device_name = "cpu"
device = torch.device(device_name)
device_type = "cuda" if "cuda" in device_name else "cpu"  # for later use in torch.autocast
torch.set_float32_matmul_precision("high")
#torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
#torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

In [None]:
# Resolve the checkpoint folder (nothing is created here) and load the experiment data
out_dir = Path(out_dir)
#exp_data = torch.load(out_dir/"ckpt_lin.pt") # trained on linear models!
#exp_data = torch.load(out_dir/"ckpt_small_wh_last.pt")
#exp_data = torch.load(out_dir/"ckpt_small_wh.pt")
#exp_data = torch.load(out_dir/"ckpt_big.pt", map_location=device)
#exp_data = torch.load(out_dir/"ckpt_small_wh_adapt_last.pt")

# The checkpoint holds more than tensors: exp_data["cfg"] is an object read through
# attributes (cfg.seq_len, cfg.nx). Since PyTorch 2.6 torch.load defaults to
# weights_only=True, which rejects such objects, so full unpickling is requested
# explicitly.
# SECURITY: full unpickling can execute arbitrary code -- only load trusted checkpoints.
exp_data = torch.load(out_dir/"ckpt_big_pwh.pt", map_location=device, weights_only=False)

In [None]:
# seq_len and nx are taken from the configuration object stored in the checkpoint,
# so that the test data match what the loaded model was trained with
seq_len, nx = (getattr(exp_data["cfg"], field) for field in ("seq_len", "nx"))

In [None]:
# Training / validation loss histories stored in the checkpoint
train_loss = exp_data["LOSS"]
val_loss = exp_data["LOSS_VAL"]

# Rolling mean over 100 entries of the raw training loss
loss_smooth = pd.Series(train_loss).rolling(100).mean()

# x positions of the validation points: 2000, 4000, ...
# NOTE(review): presumably one validation every 2000 training iterations -- confirm
val_iters = 2000 * np.arange(1, len(val_loss) + 1)

fig = go.Figure(
    data=[
        go.Scatter(y=train_loss, name="TRAINING LOSS", line_color="black"),
        go.Scatter(y=loss_smooth, name="TRAINING LOSS SMOOTH", line_color="blue"),
        go.Scatter(x=val_iters, y=val_loss, name="VAL LOSS", line_color="red"),
    ]
)
fig.show()

In [None]:
# Rebuild the network with the hyper-parameters saved in the checkpoint
model_args = exp_data["model_args"]
gptconf = GPTConfig(**model_args)
model = GPT(gptconf).to(device)


# Some checkpoints store parameter names with a leading "_orig_mod." (typically
# added when the model was wrapped by torch.compile). Strip it so that the keys
# match the plain model. A new dict is built instead of editing exp_data["model"]
# in place, so the loaded checkpoint stays untouched and the cell can be re-run.
unwanted_prefix = '_orig_mod.'
state_dict = {
    (k[len(unwanted_prefix):] if k.startswith(unwanted_prefix) else k): v
    for k, v in exp_data["model"].items()
}
model.load_state_dict(state_dict)

# Switch to evaluation mode: the model is only used for inference from here on,
# and train-mode layers such as dropout would otherwise perturb the predictions.
model.eval();

In [None]:
# Test data, generated with the nx and seq_len read from the checkpoint configuration.
# Alternative datasets, kept for reference:
#test_ds = LinearDynamicalDataset(nx=nx, nu=nu, ny=ny, seq_len=seq_len)
#test_ds = WHDataset(nx=nx, nu=nu, ny=ny, seq_len=seq_len, system_seed=42, data_seed=None, fixed_system=True)
test_ds = PWHDataset(
    nx=nx,
    nu=nu,
    ny=ny,
    seq_len=seq_len,
    # system_seed=42, data_seed=None, fixed_system=True,
)
# NOTE(review): no seed is passed, so the test batch presumably differs between runs -- confirm
test_dl = DataLoader(dataset=test_ds, batch_size=batch_size, num_workers=threads)

In [None]:
# Draw a single batch from the loader (it yields y first, then u) and move it to the device
batch_y, batch_u = next(iter(test_dl))
batch_u, batch_y = batch_u.to(device), batch_y.to(device)

# Forward pass without gradient tracking; the model returns the predictions and a loss
with torch.no_grad():
    batch_y_pred, loss = model(batch_u, batch_y)

# Bring everything back to the host as numpy arrays for the analysis below
batch_y_pred, batch_y, batch_u = (
    tensor.detach().cpu().numpy() for tensor in (batch_y_pred, batch_y, batch_u)
)

In [None]:
# Align one-step-ahead predictions with their targets along axis 1.
# The number of (prediction, target) pairs is derived from batch_y, which this cell
# never modifies. The original `batch_y_pred[:, :-1, :]` removed one more sample at
# every re-execution of the cell, breaking the subtraction below; slicing to a fixed
# length gives the same result on the first run and is safe to re-run.
n_pairs = batch_y.shape[1] - 1
batch_y_target = batch_y[:, 1:, :] # target @ time k: y_{k+1}
batch_y_pred = batch_y_pred[:, :n_pairs, :] # prediction @ time k: y_{k+1|k}
batch_y_pred_dummy = batch_y[:, :-1, :] # dummy estimator: y_{k+1} \approx y_{k}
batch_pred_err = batch_y_target - batch_y_pred # error of the model predictions
batch_pred_err_dummy = batch_y_target - batch_y_pred_dummy # error of the dummy estimator

In [None]:
#exp_data["LOSS_VAL"]

In [None]:
# True output, one-step-ahead prediction and prediction error for one sequence of the batch
plt.figure()
idx = 5 # position of the displayed sequence inside the batch (must be < batch_size)
# The title used to read "RMSE", but this figure shows signals, not an RMSE
plt.title(f"One-step-ahead prediction (batch sequence {idx})")
plt.plot(batch_y_target[idx], 'k', label="True")
plt.plot(batch_y_pred[idx], 'b', label="Pred")
#plt.plot(batch_y_pred_dummy[idx], 'm', label="Pred dummy")
plt.plot(batch_y_target[idx] - batch_y_pred[idx], 'r', label="Err")
#plt.plot(batch_y_target[idx] - batch_y_pred_dummy[idx], 'm', label="Err dummy")
plt.xlabel("time step")
plt.legend();
#plt.xlim([0, 600]);

In [None]:
# Prediction errors of every sequence in the batch, overlaid on the same axes.
# The trailing axis (size 1) is dropped and the array transposed so that each
# sequence becomes one line; the dummy estimator is drawn first (red), the model
# on top (black).
err_dummy_lines = np.squeeze(batch_pred_err_dummy, axis=-1).T
err_model_lines = np.squeeze(batch_pred_err, axis=-1).T
plt.plot(err_dummy_lines, "r", alpha=0.4)
plt.plot(err_model_lines, "k", alpha=0.4);

In [None]:
from torchid import metrics

# Only samples from index `skip` onwards (along axis 1) enter the RMSE.
# NOTE(review): presumably meant to discard the initial part of each sequence, where
# the model has seen little context -- confirm; requires the sequences to be longer than `skip`
skip = 400
y_target_tail = batch_y_target[:, skip:, :]
rmse_transformer = metrics.rmse(y_target_tail, batch_y_pred[:, skip:, :], time_axis=1)
rmse_dummy = metrics.rmse(y_target_tail, batch_y_pred_dummy[:, skip:, :], time_axis=1)

In [None]:
# Mean RMSE of the model and of the dummy estimator, rounded to two decimals;
# the "=" in the f-string fields prints each expression next to its value
print(f"{rmse_transformer.mean()=:.2f}, {rmse_dummy.mean()=:.2f}")

In [None]:
# Distribution of the RMSE values: dummy estimator (red) vs. transformer (black, drawn on top)
plt.figure()
plt.hist(x=rmse_dummy, label="dummy", color="red")
plt.hist(x=rmse_transformer, label="transformer", color="black")
plt.title("RMSE")
plt.legend();

In [None]:
# Average of the squared RMSE values of the transformer
# (presumably comparable to the MSE loss reported during training -- confirm)
np.mean(np.square(rmse_transformer))