In [None]:
import os
import sys

# Hide all GPUs *before* torch (or any project module that might touch CUDA) is
# imported: CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA
# runtime is initialised, so setting it after the imports is fragile.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import torch
import numpy as np
import altair as alt
import pandas as pd

# Make the parent of the working directory importable so `src.*` resolves
# (presumably the repo root when the notebook is run from its own folder).
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.train.det_trainer import make_trainer
from src.models import CHNN
from src.systems.chain_pendulum import ChainPendulum

# Allow altair to embed DataFrames of any size (the trajectory frames are large).
alt.data_transformers.disable_max_rows()

# With the GPUs hidden above, torch.cuda.is_available() is expected to be False,
# so this resolves to None (CPU). The CUDA branch only matters if the
# CUDA_VISIBLE_DEVICES line is removed.
device = "cuda:0" if torch.cuda.is_available() else None

In [None]:
# Experiment configuration; every key is forwarded to `make_trainer(**cfg)`.
cfg = {
    "network": CHNN,               # model class to train
    "body": ChainPendulum(3),      # system under study (3 presumably = number of links)
    "device": None,                # CPU; 'cuda' if torch.cuda.is_available() would enable GPU
    "C": 5,                        # forwarded to make_trainer; meaning defined there
    "num_epochs": 10,
}

cfg

In [None]:
from src.systems.rigid_body import project_onto_constraints

# Seed torch's global RNG so the perturbation `eps` (and any torch-based
# sampling inside `sample_initial_conditions`) is reproducible on re-run.
torch.manual_seed(0)

num_chaos = 10     # number of perturbed copies of the initial condition
eps_scale = 1e-2   # amplitude of the uniform perturbation in [-eps_scale, eps_scale]

z0_orig = cfg["body"].sample_initial_conditions(1)
# Replicate the single initial condition along the leading (batch) axis.
z0_batch = z0_orig.expand(num_chaos, -1, -1, -1)

# Uniform noise in [-1, 1), scaled, then projected back onto the system's
# constraint manifold so every perturbed state is physically valid.
eps = 2. * torch.rand_like(z0_batch) - 1.
z0 = project_onto_constraints(cfg["body"].body_graph, z0_batch + eps_scale * eps, tol=1e-5)
ts = torch.arange(0., 10.0, cfg["body"].dt, device=z0_orig.device, dtype=z0_orig.dtype)

# Reference trajectory from the unperturbed state, and the perturbed ensemble.
true_zt = cfg["body"].integrate(z0_orig, ts, method='rk4')
true_zt_chaos = cfg["body"].integrate(z0, ts, method='rk4')

true_zt.shape, true_zt_chaos.shape

In [None]:
# Inspect the perturbed initial states and the time grid; the chart cell below
# derives its row counts from these shapes.
z0.shape, ts.shape

In [None]:
body_idx = 2
dof_idx = 1

# Shared y-axis range so the ensemble and the reference trajectory overlay on one scale.
y_domain = (-3.1, -1.)

# One row per (trajectory, timestep). The sizes are derived from the data rather
# than hard-coded, so changing the ensemble size or `dt` keeps the columns the
# same length (pandas requires that).
chaos_y = true_zt_chaos[..., 0, body_idx, dof_idx]  # assumes (n_traj, n_steps) — TODO confirm
n_traj = chaos_y.shape[0]
n_steps = ts.shape[0]

chaos_chart = alt.Chart(pd.DataFrame({
    't': ts.repeat(n_traj).cpu().numpy(),
    'y': chaos_y.reshape(-1).cpu().numpy(),
    'idx': torch.arange(n_traj).repeat_interleave(n_steps).cpu().numpy(),
})).mark_line(opacity=.3).encode(
    x='t',
    y=alt.Y('y', scale=alt.Scale(domain=y_domain)),
    color=alt.Color('idx:N', legend=None, scale=alt.Scale(scheme='category10')),
)

# `true_zt` presumably carries a batch of one (from sample_initial_conditions(1));
# the mean over dim 0 just drops that axis.
true_chart = alt.Chart(pd.DataFrame({
    't': ts.cpu().numpy(),
    'y': true_zt[..., 0, body_idx, dof_idx].mean(dim=0).cpu().numpy(),
})).mark_line(color='black', strokeDash=[5, 5]).encode(
    x='t',
    y=alt.Y('y', scale=alt.Scale(domain=y_domain)),
)

chaos_vs_true_chart = (chaos_chart + true_chart).properties(width=600, height=200)
chaos_vs_true_chart.save('chart.json')
chaos_vs_true_chart

In [None]:
# Build the trainer from the config, then train for the configured epoch count.
n_epochs = cfg["num_epochs"]
trainer = make_trainer(**cfg)
trainer.train(n_epochs)

In [None]:
num_samples = 10

# The initial state and time grid are the same for every sample, so move them
# to the target device once instead of on every iteration.
z0_orig = z0_orig.to(device)
ts = ts.to(device)

pred_zt = []
for _ in range(num_samples):
    # Draw a new weight sample (presumably SWAG-style, given `integrate_swag`),
    # then make sure the model sits on the same device as its inputs.
    trainer.model.sample()
    trainer.model.to(device)
    with torch.no_grad():
        zt_pred = trainer.model.integrate_swag(z0_orig, ts, method='rk4')
    pred_zt.append(zt_pred)

# Stack the per-sample rollouts along the leading (batch) axis.
pred_zt = torch.cat(pred_zt, dim=0)
pred_zt.shape

In [None]:
# Summarise the sampled predictions for the chosen mass / DoF as a mean line
# with a ±2 standard-deviation band (statistics taken over the sample axis).
pred_series = pred_zt[..., 0, body_idx, dof_idx]
pred_zt_mu = pred_series.mean(dim=0)
pred_zt_std = pred_series.std(dim=0)
band_lo = pred_zt_mu - 2. * pred_zt_std
band_hi = pred_zt_mu + 2. * pred_zt_std

pred_df = pd.DataFrame({
    't': ts.cpu().numpy(),
    'y': pred_zt_mu.cpu().numpy(),
    'y_lo': band_lo.cpu().numpy(),
    'y_hi': band_hi.cpu().numpy(),
})
pred_chart = (
    alt.Chart(pred_df)
    .mark_line(color='red', opacity=0.5)
    .encode(x='t:Q', y='y:Q')
)
pred_err_chart = pred_chart.mark_area(opacity=0.1, color='red').encode(y='y_lo', y2='y_hi')

In [None]:
# Left: the perturbed-initial-condition ensemble; right: the model's sampled
# predictions with their ±2σ band. Both panels overlay the unperturbed reference.
# (`chaos_err_chart` was only ever defined in commented-out code, so it is not used.)
(
    (chaos_chart + true_chart)
    | (pred_err_chart + pred_chart + true_chart)
).properties(title=f'Mass = {body_idx}, DoF = {dof_idx}; Chaos v/s Predictions')