Use this notebook as the central place to collect all data dumps and process them into figures for _Trace plots_.

In [1]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Lift Altair's maximum-row check on data embedded in charts, since the trace
# DataFrames built below can be long. The trailing ';' keeps the notebook
# cell from echoing the call's return value.
alt.data_transformers.disable_max_rows();


In [6]:
## TODO: change these to the dump paths.
de_dump_path = '../../de-3body-data.pt'
swag_dump_path = '../../swag-3body-data.pt'

# Check both dumps up front so a bad path fails before any (slow) load.
for data_dump_path in (de_dump_path, swag_dump_path):
    assert os.path.isfile(data_dump_path), f'missing data dump: {data_dump_path}'

# map_location='cpu' lets dumps saved from GPU runs load on CPU-only machines;
# every consumer below calls .cpu() before plotting anyway.
data_dump = torch.load(de_dump_path, map_location='cpu')

ts = data_dump.get('ts')
z0_orig = data_dump.get('z0_orig')
true_zt = data_dump.get('true_zt')
# Swap the first two axes so the sample axis comes first; later cells index
# this as [:, traj_idx, ...], the same layout used for pred_zt.
# Indexed (not .get) so a missing key raises a clear KeyError instead of
# "'NoneType' object has no attribute 'permute'".
true_zt_chaos = data_dump['true_zt_chaos'].permute(1, 0, 2, 3, 4, 5)
de_pred_zt = data_dump.get('pred_zt')

data_dump = torch.load(swag_dump_path, map_location='cpu')
swag_pred_zt = data_dump.get('pred_zt')

In [30]:
def _band_chart(t, samples, color, y_title=None):
    '''
    Build a mean line and a shaded mean +/- 2 std band for one sample set.

    Arguments:
        t: length-T array of time values (x axis).
        samples: M x T tensor; mean and std are taken over dim 0 (samples).
        color: Altair color used for both the line and the band.
        y_title: optional y-axis title attached to the band layer.
    Returns:
        (mean_chart, err_chart), meant to be layered with other charts.
    '''
    mean = samples.mean(dim=0).cpu().numpy()
    std = samples.std(dim=0).cpu().numpy()

    mean_chart = alt.Chart(pd.DataFrame({
        't': t,
        'y': mean,
        'y_lo': mean - 2. * std,
        'y_hi': mean + 2. * std,
    })).mark_line(color=color, opacity=0.5).encode(x='t:Q', y='y:Q')

    y_lo = 'y_lo' if y_title is None else alt.Y('y_lo', title=y_title)
    err_chart = mean_chart.mark_area(opacity=0.2, color=color).encode(y=y_lo, y2='y_hi')
    return mean_chart, err_chart


def plot_trace(ref, chaos, swag_pred, de_pred, n=0, dof=0, times=None):
    '''
    Plot one coordinate of one body against three sample ensembles.

    Returns three side-by-side panels (chaos in blue, SWAG prediction in red,
    DE prediction in green). Each panel overlays the reference trajectory
    (dashed black) with that ensemble's mean and mean +/- 2 std band.

    M: number of samples
    T: timesteps
    N: number of bodies
    The last "2" corresponds to xy dimension.
    Arguments:
        ref: T x N x 2 tensor
        chaos: M x T x N x 2 tensor
        swag_pred: M x T x N x 2 tensor
        de_pred: M x T x N x 2 tensor
        n: index of the body to plot (default 0).
        dof: coordinate to plot, 0 for x and 1 for y (default 0).
        times: length-T tensor of time values; defaults to the notebook-global
            `ts` loaded from the data dump.
    Returns:
        An Altair HConcatChart titled 'Body Mass {n + 1}'.
    '''
    if times is None:
        times = ts
    t = times.cpu().numpy()
    ref = ref.cpu().numpy()
    label = ['x', 'y']

    true_chart = alt.Chart(pd.DataFrame({
        't': t,
        'y': ref[:, n, dof],
    })).mark_line(color='black', strokeDash=[5, 5]).encode(x='t:Q', y=alt.Y('y:Q'))

    panels = []
    for samples, color in ((chaos, 'blue'), (swag_pred, 'red'), (de_pred, 'green')):
        # Every panel gets the same y-axis title (previously the chaos panel had none).
        mean_chart, err_chart = _band_chart(t, samples[..., n, dof], color, y_title=label[dof])
        panels.append(true_chart + mean_chart + err_chart)

    body_chart = alt.hconcat(*panels)
    return body_chart.properties(title=f'Body Mass {n + 1}')


In [31]:
# Compare two trajectories, stacked vertically: trajectory 2 on top, 0 below.
left_idx, right_idx = 2, 0
left, right = (
    plot_trace(
        true_zt[idx, :, 0, ...],
        true_zt_chaos[:, idx, :, 0, ...],
        swag_pred_zt[:, idx, :, 0, ...],
        de_pred_zt[:, idx, :, 0, ...],
    )
    for idx in (left_idx, right_idx)
)
left & right

In [19]:
traj_idx = 0

# plot_trace returns a single chart (not a generator of per-body charts), so
# iterating over its result raises. Instead, slice each body out along the
# N axis (kept as a size-1 axis so plot_trace's body index 0 selects it),
# plot it, and restore the correct body number in the title.
traj_ref = true_zt[traj_idx, :, 0, ...]              # T x N x 2
traj_chaos = true_zt_chaos[:, traj_idx, :, 0, ...]   # M x T x N x 2
traj_swag = swag_pred_zt[:, traj_idx, :, 0, ...]     # M x T x N x 2
traj_de = de_pred_zt[:, traj_idx, :, 0, ...]         # M x T x N x 2

all_plot = None
for body in range(traj_ref.size(1)):
    pl = plot_trace(
        traj_ref[:, body:body + 1],
        traj_chaos[..., body:body + 1, :],
        traj_swag[..., body:body + 1, :],
        traj_de[..., body:body + 1, :],
    ).properties(title=f'Body Mass {body + 1}')
    all_plot = pl if all_plot is None else (all_plot & pl)

## TODO: need to decide the exact layout here.
all_plot