In [31]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from src.train.ensemble_trainer import make_trainer
from src.models import HNN, CHNN
from src.systems.chain_pendulum import ChainPendulum
from src.datasets import get_chaotic_eval_dataset

import warnings
warnings.filterwarnings('ignore')

def trace_plot(t, y, ax, y_std=None, min_y=-np.inf, max_y=np.inf,
               kind='region', color='black',
               plt_args=None, n_std=2.):
    '''
    Plot the sample mean of a set of trajectories together with an
    uncertainty envelope of +/- `n_std` standard deviations.

    Arguments:
        t: Time points (x-axis values), length T.
        y: Trajectory samples, shape (num_samples, T). The zeroth dimension
           is the number of samples.
        ax: Matplotlib Axis to draw on.
        y_std: Optional per-time-step standard deviation (length T). When
            given it replaces the sample std of `y`, and the envelope is drawn
            even if `y` holds a single sample.
        min_y, max_y: Range the envelope is clipped to (e.g. `min_y=0.` for a
            positive quantity).
        kind: 'region' shades the envelope with `fill_between`; 'bound' draws
            the lower and upper envelope as two dashed curves.
        color: Colour shared by the mean line and the envelope.
        plt_args: Extra keyword arguments forwarded to the mean-line `ax.plot`.
        n_std: Half-width of the envelope in standard deviations (default 2).
    '''
    assert kind in ['region', 'bound']
    plt_args = plt_args or dict()
    y = np.asarray(y)

    mu = np.mean(y, axis=0)
    std = np.std(y, axis=0) if y_std is None else y_std

    ax.plot(t, mu, c=color, **plt_args)

    # A single sample has no spread to show unless the caller supplied a std.
    if y.shape[0] == 1 and y_std is None:
        return

    lower = np.clip(mu - n_std * std, min_y, max_y)
    upper = np.clip(mu + n_std * std, min_y, max_y)

    if kind == 'region':
        # Use the `t` argument (not a notebook global) so the band is aligned
        # with the mean line.
        ax.fill_between(t, lower, upper, color=color, alpha=0.2)
    elif kind == 'bound':
        # 2-D x/y arrays of shape (T, 2): matplotlib draws one line per column,
        # i.e. (t, lower) and (t, upper) as dashed curves.
        ax.plot(np.array([t, t]).T, np.array([lower, upper]).T, c=color, dashes=[8,4])
    else:
        raise NotImplementedError

In [None]:
os.environ['DATADIR'] = '.'

# Seed the RNGs the run may draw from (model init, minibatching, eval-dataset
# sampling) so Restart & Run All reproduces the same figures.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

cfg = {
    "num_bodies": 2,
    "lr": 3e-3,
    "tau": 10.0,
    "C": 5,
    "num_epochs": 40,
    "uq_type": "deep-ensemble",
    "device": 'cuda:0' if torch.cuda.is_available() else None
}

body = ChainPendulum(cfg['num_bodies'])
network = CHNN
trainer = make_trainer(**cfg, network=network, body=body)
trainer.train(cfg['num_epochs'])

n_init = 10
n_samples = 5
evald = get_chaotic_eval_dataset(body, n_init=n_init, n_samples=n_samples)

model = trainer.model.to(cfg['device'])
ts_dev = evald['ts'].to(cfg['device'])
z0_orig = evald['z0_orig'].to(cfg['device'])

# Run the forward pass with the same `model` object that was moved to the
# device, so the model and its inputs are guaranteed to live together.
with torch.no_grad():
    pred_zt = model(z0_orig, ts_dev).cpu()

# numpy copies for plotting. NOTE(review): the plotting cell indexes these as
# [sample?, init, time, 0, body, dof] — confirm the leading-axis meaning.
ts = ts_dev.cpu().numpy()
pred_zt = pred_zt.numpy()
true_zt = evald['true_zt'].numpy()
true_zt_chaos = evald['true_zt_chaos'].numpy()

Training ensemble member 0...


HBox(children=(FloatProgress(value=0.0, description='train', max=40.0, style=ProgressStyle(description_width='…

   Minibatch_Loss  Train_MAE    lr0  nfe  test_MAE
0        0.128798   0.122779  0.003  0.0  0.127001
    Minibatch_Loss  Train_MAE       lr0  nfe  test_MAE
14        0.053317   0.054422  0.002951  0.8  0.055059
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
26        0.028735   0.027445  0.002823  0.421053  0.025779
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
38        0.022321   0.020103  0.002621  0.285714  0.018597
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
50        0.008935   0.010073  0.002358  0.216216  0.009306
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
62        0.004661   0.005948  0.002047  0.173913  0.006027
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
73        0.006398   0.006731  0.001735  0.146789  0.006517
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
84        0.003144   0.004247  0.001412  0.126984  0.004225
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
95      

HBox(children=(FloatProgress(value=0.0, description='train', max=40.0, style=ProgressStyle(description_width='…

   Minibatch_Loss  Train_MAE    lr0  nfe  test_MAE
0        0.120503   0.120837  0.003  0.0   0.12431
    Minibatch_Loss  Train_MAE       lr0  nfe  test_MAE
10        0.045549   0.052132  0.002977  1.0  0.049603
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
21        0.038472   0.038553  0.002886  0.484848  0.036564
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
31        0.024269   0.026172  0.002747  0.326531  0.024004
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
39        0.020267   0.020428  0.002601  0.253968  0.018522
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
48        0.010117   0.013487  0.002405  0.205128  0.012842
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
58        0.007359   0.006967  0.002155  0.170213  0.006287
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
68        0.007406   0.006881  0.001879  0.145455  0.006997
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
78      

HBox(children=(FloatProgress(value=0.0, description='train', max=40.0, style=ProgressStyle(description_width='…

   Minibatch_Loss  Train_MAE    lr0  nfe  test_MAE
0        0.117283   0.119897  0.003  0.0  0.122648
   Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
9        0.050904   0.050611  0.002982  1.066667  0.048483
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
19        0.046912   0.049103  0.002907  0.516129  0.047653
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
29        0.033687   0.037001  0.002779  0.340426  0.037483
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
38        0.027635   0.027543  0.002621  0.258065  0.027289
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
48        0.020556   0.021232  0.002405  0.205128  0.020333
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
59        0.016804   0.016388  0.002128  0.168421  0.015951
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
68        0.012911   0.013568  0.001879  0.145455  0.012976
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE


HBox(children=(FloatProgress(value=0.0, description='train', max=40.0, style=ProgressStyle(description_width='…

   Minibatch_Loss  Train_MAE    lr0  nfe  test_MAE
0        0.125439   0.127919  0.003  0.0  0.132452
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
11        0.068868   0.066223  0.002971  0.941176  0.065647
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
22        0.033319   0.036758  0.002874  0.470588  0.034467
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
33        0.022233   0.025205  0.002714  0.313725  0.023914
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
44         0.01343   0.016361  0.002496  0.235294  0.015352
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
55        0.012123   0.012364  0.002233  0.188235  0.012083
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
66        0.008661   0.009432  0.001935  0.156863  0.009328
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
77        0.005583   0.005234  0.001618  0.134454  0.005025
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MA

HBox(children=(FloatProgress(value=0.0, description='train', max=40.0, style=ProgressStyle(description_width='…

   Minibatch_Loss  Train_MAE    lr0  nfe  test_MAE
0        0.120367   0.121233  0.003  0.0  0.124926
    Minibatch_Loss  Train_MAE       lr0       nfe  test_MAE
11        0.057447   0.056907  0.002971  0.941176  0.056552


In [None]:
T, n_dof = ts.shape[-1], true_zt_chaos.shape[-1]
# Axis names per degree of freedom; indices beyond this list fall back to the number.
dof_label = ['x', 'y', 'z']

for init_id in range(n_init):
    # squeeze=False keeps `ax` 2-D even when num_bodies or n_dof is 1.
    fig, ax = plt.subplots(n_dof, cfg['num_bodies'], figsize=(10, 10), squeeze=False)
    for b_id in range(cfg['num_bodies']):
        for dof_id in range(n_dof):
            cur_ax = ax[dof_id, b_id]
            label = dof_label[dof_id] if dof_id < len(dof_label) else str(dof_id)
            # Raw string: `\mid` is LaTeX for the title, not a Python escape.
            cur_ax.set_title(rf'Body {b_id + 1} $\mid$ Dimension {label}')
            cur_ax.set_xlabel('t')

            # Reference trajectory: a single sample, so only the dashed mean line is drawn.
            trace_plot(ts, true_zt[np.newaxis, init_id, :, 0, b_id, dof_id], cur_ax,
                       color=(.2, .2, .2), plt_args=dict(dashes=[4, 2]))

            # `true_zt_chaos` samples (leading axis), drawn as mean + shaded envelope.
            trace_plot(ts, true_zt_chaos[:, init_id, :, 0, b_id, dof_id], cur_ax,
                       color=(.6, .6, .6))

            # Model predictions (leading axis presumably ensemble members — confirm),
            # drawn as mean + dashed bound curves.
            trace_plot(ts, pred_zt[:, init_id, :, 0, b_id, dof_id], cur_ax, kind='bound',
                       color=(.2, .2, 1., .75))

    fig.tight_layout()
    fig.suptitle(f'Init {init_id + 1}', fontsize=16, y=1.01)
    # plt.show() takes no figure argument; close afterwards so n_init figures
    # don't accumulate in memory.
    plt.show()
    plt.close(fig)
    print("\n")