Use this notebook as the central place to collect all data dumps and process them into figures for _Trace plots_.

In [None]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Turn off Altair's max-rows check for chart data in this notebook.
# The trailing semicolon keeps the notebook from echoing the call's return value.
alt.data_transformers.disable_max_rows();


In [None]:
## TODO: change this to the dump path.
swag_dump_path = '../.log/swag-3body/files/data.pt'
de_dump_path = '../.log/de-3body/files/data.pt'

# Fail early and name the offending path; a bare assert gives no hint which
# file is missing and is skipped entirely when Python runs with -O.
for dump_path in (swag_dump_path, de_dump_path):
    if not os.path.isfile(dump_path):
        raise FileNotFoundError(f'Dump file not found: {dump_path}')

# Map every tensor to the CPU so dumps written on a GPU machine can still be
# opened on a CPU-only one; all downstream plotting converts to NumPy anyway.
swag_dump = torch.load(swag_dump_path, map_location='cpu')
de_dump = torch.load(de_dump_path, map_location='cpu')

# Time grid, initial conditions and reference trajectories are read from the
# SWAG dump. NOTE(review): presumably the DE dump was produced on the same
# trajectories and time grid -- confirm before comparing the two.
ts = swag_dump.get('ts')
z0_orig = swag_dump.get('z0_orig')
true_zt = swag_dump.get('true_zt')
# Swap the first two axes of the 6-D chaos tensor; the plotting cell indexes
# it as [:, traj_idx, ...], i.e. with the trajectory on the second axis.
true_zt_chaos = swag_dump.get('true_zt_chaos').permute(1, 0, 2, 3, 4, 5)
swag_pred_zt = swag_dump.get('pred_zt')
de_pred_zt = de_dump.get('pred_zt')

In [None]:
# Inspect the shape of the chaos tensor after its first two axes were swapped on load.
true_zt_chaos.shape

In [None]:
def _band_chart(t, samples, color, y_title=None):
    '''
    Build a mean line with a +/- 2 standard-deviation band for one coordinate.

    Arguments:
        t: array of T time values used for the x axis
        samples: M x T tensor; mean and std are taken over the first axis
        color: mark color shared by the mean line and the shaded band
        y_title: optional explicit y-axis title attached to the band layer
    Returns:
        Layered Altair chart (mean line + shaded band).
    '''
    mean = samples.mean(dim=0).cpu().numpy()
    std = samples.std(dim=0).cpu().numpy()

    mean_chart = alt.Chart(pd.DataFrame({
        't': t,
        'y': mean,
        'y_lo': mean - 2. * std,
        'y_hi': mean + 2. * std,
    })).mark_line(color=color, opacity=0.5).encode(x='t:Q', y='y:Q')

    # The band reuses the mean chart's data and x encoding, only swapping the
    # mark and the y channels.
    y_lo = 'y_lo' if y_title is None else alt.Y('y_lo', title=y_title)
    err_chart = mean_chart.mark_area(opacity=0.2, color=color).encode(y=y_lo, y2='y_hi')
    return mean_chart + err_chart


def plot_trace(ref, chaos, swag_pred, de_pred, times=None):
    '''
    Generate one trace figure per body, comparing the reference trajectory
    with the chaos, deep-ensemble (DE) and SWAG sample statistics.

    M: number of samples
    T: timesteps
    N: number of bodies
    The last "2" corresponds to xy dimension.
    Arguments:
        ref: T x N x 2
        chaos: M x T x N x 2
        swag_pred: M x T x N x 2
        de_pred: M x T x N x 2
        times: optional tensor of T time values; when omitted, the
            notebook-level `ts` tensor is used
    Yields:
        For each body, the x and y panels placed side by side and titled
        'Body Mass <n>'. Each panel layers the reference (black, dashed),
        chaos (green), DE (blue) and SWAG (red) traces, the latter three as
        a mean line with a +/- 2 std band.
    '''
    # The time axis is identical for every body and coordinate: convert once.
    t = (ts if times is None else times).cpu().numpy()
    N = ref.size(1)
    ref = ref.cpu().numpy()
    label = ['x', 'y']

    for n in range(N):
        panels = []
        for dof in range(2):
            true_chart = alt.Chart(pd.DataFrame({
                't': t,
                'y': ref[:, n, dof],
            })).mark_line(color='black', strokeDash=[5, 5]).encode(x='t:Q', y=alt.Y('y:Q'))

            chaos_chart = _band_chart(t, chaos[..., n, dof], 'green')
            de_chart = _band_chart(t, de_pred[..., n, dof], 'blue')
            # Only the SWAG band carries an explicit y-axis title.
            # NOTE(review): presumably so the shared axis reads 'x'/'y'
            # instead of the data-frame column names -- confirm in the figure.
            swag_chart = _band_chart(t, swag_pred[..., n, dof], 'red', y_title=label[dof])

            panels.append(true_chart + chaos_chart + de_chart + swag_chart)

        yield (panels[0] | panels[1]).properties(title=f'Body Mass {n + 1}')


In [None]:
# Index of the trajectory to visualize.
traj_idx = 0

# Slice out the selected trajectory from every tensor, keeping index 0 of the
# axis that follows time. NOTE(review): presumably that axis separates
# positions from momenta -- confirm against the code that writes the dump.
ref_traj = true_zt[traj_idx, :, 0, ...]
chaos_traj = true_zt_chaos[:, traj_idx, :, 0, ...]
swag_traj = swag_pred_zt[:, traj_idx, :, 0, ...]
de_traj = de_pred_zt[:, traj_idx, :, 0, ...]

# Stack the per-body figures vertically into a single chart.
all_plot = None
for b, pl in enumerate(plot_trace(ref_traj, chaos_traj, swag_traj, de_traj)):
    if all_plot is None:
        all_plot = pl
    else:
        all_plot = all_plot & pl
    # Uncomment to also export each body's figure on its own:
    # pl.save(f'traj{traj_idx + 1}-b{b + 1}.json')

In [None]:
# Display the combined figure (all bodies stacked vertically) as the cell output.
all_plot