In [None]:
import os

# Pin the visible GPU *before* importing torch or any project module. CUDA reads
# CUDA_VISIBLE_DEVICES only when it is first initialised, and a module that calls e.g.
# torch.cuda.is_available() at import time would otherwise lock in every device.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import sys
import copy
import torch
import numpy as np
from pprint import pprint

# Make the parent directory (project root, containing `src/`) importable from this notebook.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from oil.tuning.args import argupdated_config

from src.train.det_trainer import make_trainer
from src.models import CHNN
from src.systems.chain_pendulum import ChainPendulum

# Seed the global RNGs so training, initial-condition sampling and model sampling
# reproduce under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Start from make_trainer's keyword-only defaults. Deep-copy them so the function's own
# default objects are never mutated, then override what this experiment needs.
defaults = copy.deepcopy(make_trainer.__kwdefaults__)
defaults["save"] = False

# `cfg` aliases `defaults`; to parse CLI overrides instead, use argupdated_config(defaults)
# (and pop 'local_rank' if it appears).
cfg = defaults
# Remove 'save' so it is not forwarded to make_trainer(**cfg); keep its value separately.
save = cfg.pop('save')

cfg.update(
    network=CHNN,
    device=None,  # alternative: 'cuda:0' if torch.cuda.is_available() else None
    body=ChainPendulum(3),  # presumably a 3-link chain pendulum — confirm against ChainPendulum
    C=10,
    num_epochs=50,
)

pprint(cfg)

In [None]:
# Build the trainer from the assembled config, then train for the configured epoch count.
trainer = make_trainer(**cfg)
trainer.train(cfg["num_epochs"])

In [None]:
from src.systems.rigid_body import project_onto_constraints

num_perturbed = 10    # number of perturbed copies of the reference initial condition
perturb_scale = 0.1   # perturbation noise is uniform in [-perturb_scale, perturb_scale)
t_max = 10.0          # rollout horizon, in the same units as body.dt

# Reference initial condition. The expand() below assumes a 4-D tensor with a leading
# batch axis of size 1 — TODO confirm against sample_initial_conditions.
z0_orig = cfg["body"].sample_initial_conditions(1)
z0_batch = z0_orig.expand(num_perturbed, -1, -1, -1)

# Add uniform noise in [-1, 1) scaled by perturb_scale to each copy, then map the result
# back through project_onto_constraints for this body's graph.
eps = 2. * torch.rand_like(z0_batch) - 1.
z0 = project_onto_constraints(cfg["body"].body_graph,
                              z0_batch + perturb_scale * eps, tol=1e-5)

# Shared time grid for every rollout, on the same device/dtype as the initial condition.
ts = torch.arange(0., t_max, cfg["body"].dt, device=z0_orig.device, dtype=z0_orig.dtype)

In [None]:
# Ground-truth RK4 rollout from the unperturbed initial condition.
true_zt = cfg["body"].integrate(z0_orig, ts, method="rk4")
true_zt.shape

In [None]:
# Ground-truth RK4 rollouts from each perturbed initial condition (sensitivity baseline).
true_zt_chaos = cfg["body"].integrate(z0, ts, method="rk4")
true_zt_chaos.shape

In [None]:
num_samples = 10  # number of model.sample() draws; each yields one predicted rollout

# Collect per-draw rollouts in a dedicated list rather than reusing `pred_zt` for both the
# list and the final tensor.
pred_zt_samples = []
for _ in range(num_samples):
    # Resample the model (by its name, presumably a SWAG weight draw — confirm), then roll out
    # from the reference initial condition without building an autograd graph.
    trainer.model.sample()
    with torch.no_grad():
        zt_pred = trainer.model.integrate_swag(z0_orig, ts, method='rk4')
    pred_zt_samples.append(zt_pred)
# NOTE(review): trainer.model keeps the weights from the last sample() call after this loop —
# confirm whether any later cell expects the original/mean weights.

# Stack along the batch axis: one leading entry per draw.
pred_zt = torch.cat(pred_zt_samples, dim=0)
pred_zt.shape

In [None]:
import altair as alt
import pandas as pd

alt.data_transformers.disable_max_rows()

# State component to plot: every series below is z[..., 0, body_idx, dof_idx].
# Presumably index 0 of that axis is the position part of the state — TODO confirm.
body_idx, dof_idx = 2, 1
band_n_std = 2.  # half-width of the shaded uncertainty band, in standard deviations


def _to_numpy(x):
    """Return tensor `x` as a CPU numpy array, detaching it from any autograd graph first."""
    return x.detach().cpu().numpy()


def _ensemble_charts(series, color, n_std=band_n_std):
    """Build a mean line and a mean ± n_std·std band for an ensemble of trajectories.

    Args:
        series: tensor of shape (num_trajectories, len(ts)); statistics are taken over dim 0.
        color: Altair colour for both the line and the band.
        n_std: band half-width in standard deviations.

    Returns:
        (mean, std, line_chart, band_chart)
    """
    mu = series.mean(dim=0)
    std = series.std(dim=0)
    line = alt.Chart(pd.DataFrame({
        't': _to_numpy(ts),
        'y': _to_numpy(mu),
        'y_lo': _to_numpy(mu - n_std * std),
        'y_hi': _to_numpy(mu + n_std * std),
    })).mark_line(color=color, opacity=0.5).encode(x='t:Q', y='y:Q')
    # Reuse the same data, but draw an area between y_lo and y_hi.
    band = line.mark_area(opacity=0.1, color=color).encode(y='y_lo', y2='y_hi')
    return mu, std, line, band


# Dashed ground truth from the unperturbed initial condition. The mean over dim 0 collapses
# the leading batch axis (presumably of size 1).
true_chart = alt.Chart(pd.DataFrame({
    't': _to_numpy(ts),
    'y': _to_numpy(true_zt[..., 0, body_idx, dof_idx].mean(dim=0)),
})).mark_line(color='black', strokeDash=[5, 5]).encode(x='t:Q', y=alt.Y('y:Q'))

# Spread of the model's predicted rollouts across sampled weights.
pred_zt_mu, pred_zt_std, pred_chart, pred_err_chart = _ensemble_charts(
    pred_zt[..., 0, body_idx, dof_idx], 'red')

# Spread of true rollouts from the perturbed initial conditions.
true_zt_chaos_mu, true_zt_chaos_std, chaos_chart, chaos_err_chart = _ensemble_charts(
    true_zt_chaos[..., 0, body_idx, dof_idx], 'blue')

# Left: perturbed ground truth vs. truth. Right: model predictions vs. truth.
(chaos_err_chart + chaos_chart + true_chart | pred_err_chart + pred_chart + true_chart).properties(
    title=f'Mass = {body_idx}, DoF = {dof_idx}; Chaos v/s Predictions')