Use this notebook as the central place to collect all data dumps and process them into figures for _Trace plots_.

In [None]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Altair refuses to embed large inline datasets by default; lift that row limit so
# long trajectories (one DataFrame row per timestep) can be plotted.
# The trailing ';' only suppresses the cell's echoed return value in Jupyter.
alt.data_transformers.disable_max_rows();


In [None]:
## TODO: change this to the path of the data dump you want to plot.
data_dump_path = '../.log/swag-3body/files/data.pt'

# Raise explicitly instead of `assert`: the check survives `python -O` and the
# error message names the missing file.
if not os.path.isfile(data_dump_path):
    raise FileNotFoundError(f'data dump not found: {data_dump_path}')

# map_location='cpu' lets dumps saved from GPU tensors load on CPU-only machines;
# every consumer below converts to NumPy via .cpu() anyway.
# NOTE: torch.load unpickles the file -- only load dumps you produced yourself.
data_dump = torch.load(data_dump_path, map_location='cpu')

ts = data_dump.get('ts')
z0_orig = data_dump.get('z0_orig')
true_zt = data_dump.get('true_zt')
# Swap the first two axes so this tensor is indexed the same way as pred_zt in the
# plotting cell ([sample, trajectory, ...]). Indexing with [...] gives a clear
# KeyError if the dump lacks this entry, instead of an AttributeError on None.
true_zt_chaos = data_dump['true_zt_chaos'].permute(1, 0, 2, 3, 4, 5)
pred_zt = data_dump.get('pred_zt')

In [None]:
# Inspect the permuted ensemble tensor's dimensions (a torch.Size, shown as cell output).
# NOTE(review): presumably (samples, trajectories, T, ?, N, 2) -- confirm against the dump writer.
true_zt_chaos.size()

In [None]:
def _mean_band_layers(t, samples, color, y_title):
    '''
    Build the mean line and the mean +/- 2*std band for one sampled coordinate.

    Arguments:
        t: NumPy array of T time values (x axis).
        samples: M x T tensor; statistics are taken over dim 0 (the samples).
        color: Vega-Lite colour used for both layers.
        y_title: y-axis title attached to the band layer.
    Returns:
        (line_chart, band_chart), meant to be layered with `+`.
    '''
    # detach() so .numpy() also works on tensors that still require grad.
    mean = samples.mean(dim=0).detach().cpu().numpy()
    std = samples.std(dim=0).detach().cpu().numpy()

    line_chart = alt.Chart(pd.DataFrame({
        't': t,
        'y': mean,
        'y_lo': mean - 2. * std,
        'y_hi': mean + 2. * std,
    })).mark_line(color=color, opacity=0.5).encode(x='t:Q', y='y:Q')
    # Same data, area mark: replace the y channel with the lower/upper band edges.
    band_chart = line_chart.mark_area(opacity=0.2, color=color).encode(
        y=alt.Y('y_lo:Q', title=y_title), y2='y_hi')
    return line_chart, band_chart


def plot_trace(ref, chaos, pred, times=None):
    '''
    Yield one chart per body comparing a reference trajectory against two ensembles.

    M: number of samples
    T: timesteps
    N: number of bodies
    The last "2" corresponds to xy dimension.
    Arguments:
        ref: T x N x 2 tensor, drawn as a dashed black line.
        chaos: M x T x N x 2 tensor, drawn as its mean (blue) with a +/- 2 std band.
        pred: M x T x N x 2 tensor, drawn as its mean (red) with a +/- 2 std band.
        times: length-T tensor of time values for the x axis. Defaults to the
            notebook-global `ts` loaded from the data dump (the original behaviour).
    Yields:
        For each body n, an Altair chart with the x and y panels side by side,
        titled 'Body Mass {n + 1}'.
    Note:
        torch's std() defaults to the unbiased estimator, so M must be >= 2 for
        the bands to be finite (M == 1 gives NaN).
    '''
    if times is None:
        times = ts  # notebook-global time grid from the loading cell
    # Convert once; every layer below shares the same x values.
    t = times.detach().cpu().numpy()
    ref = ref.detach().cpu().numpy()
    labels = ('x', 'y')

    for n in range(ref.shape[1]):
        panels = []
        for dof, label in enumerate(labels):
            true_chart = alt.Chart(pd.DataFrame({
                't': t,
                'y': ref[:, n, dof],
            })).mark_line(color='black', strokeDash=[5, 5]).encode(x='t:Q', y=alt.Y('y:Q'))
            chaos_line, chaos_band = _mean_band_layers(t, chaos[..., n, dof], 'blue', label)
            pred_line, pred_band = _mean_band_layers(t, pred[..., n, dof], 'red', label)
            # Layer order matches drawing order: reference, chaos ensemble, prediction.
            panels.append(true_chart + chaos_line + chaos_band + pred_line + pred_band)

        yield alt.hconcat(*panels).properties(title=f'Body Mass {n + 1}')


In [None]:
traj_idx = 0
save_traces = False  # set True to also write one Vega-Lite JSON spec per body

# Take trajectory `traj_idx` and index 0 along the fourth axis of each tensor, giving
# the T x N x 2 (reference) and M x T x N x 2 (ensembles) layouts plot_trace expects.
# NOTE(review): the meaning of that fixed 0 index is not visible here -- confirm
# against the code that wrote the dump.
body_plots = plot_trace(
    true_zt[traj_idx, :, 0, ...],
    true_zt_chaos[:, traj_idx, :, 0, ...],
    pred_zt[:, traj_idx, :, 0, ...],
)

# Stack the per-body charts vertically; stays None if there are no bodies.
all_plot = None
for b, pl in enumerate(body_plots):
    all_plot = pl if all_plot is None else (all_plot & pl)
    if save_traces:
        pl.save(f'traj{traj_idx + 1}-b{b + 1}.json')

In [None]:
# Render the stacked per-body chart (Jupyter displays the cell's last expression).
all_plot