In [None]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Lift Altair's default row limit (5000 rows by default) so the DataFrames built
# in the plotting cells are embedded regardless of length; the trailing ';'
# suppresses this cell's output.
alt.data_transformers.disable_max_rows();


In [None]:
# Each dump must come from the run directory matching its variable name:
# SWAG results live under swag-3body, DE results under de-3body.
swag_dump_path = '../.log/swag-3body/files/data.pt'
de_dump_path = '../.log/de-3body/files/data.pt'

assert os.path.isfile(swag_dump_path), f'missing SWAG dump: {swag_dump_path}'
assert os.path.isfile(de_dump_path), f'missing DE dump: {de_dump_path}'

# map_location='cpu': every tensor is moved to CPU for plotting anyway, and this
# lets dumps written by GPU runs load on a CPU-only machine.
# torch.load unpickles the file -- only point it at dumps you produced yourself.
swag_dump = torch.load(swag_dump_path, map_location='cpu')
de_dump = torch.load(de_dump_path, map_location='cpu')

# Time grid and ground truth are read from the SWAG dump only;
# presumably both runs share them -- the shape checks below guard against drift.
ts = swag_dump.get('ts')
z0_orig = swag_dump.get('z0_orig')
true_zt = swag_dump.get('true_zt')
# Swap the first two axes so the extra leading axis lines up with the
# prediction tensors' leading axis (checked by the asserts below).
true_zt_chaos = swag_dump.get('true_zt_chaos').permute(1, 0, 2, 3, 4, 5)
swag_pred_zt = swag_dump.get('pred_zt')
de_pred_zt = de_dump.get('pred_zt')

assert true_zt_chaos.shape[1:] == true_zt.shape, (true_zt_chaos.shape, true_zt.shape)
assert true_zt_chaos.shape == swag_pred_zt.shape, (true_zt_chaos.shape, swag_pred_zt.shape)
assert de_pred_zt.shape == swag_pred_zt.shape, (de_pred_zt.shape, swag_pred_zt.shape)

true_zt_chaos.shape

In [None]:
def compute_mrse(ref, pred):
    '''
    Root squared error between reference and predicted trajectories, per time step.

    The squared difference is summed over the three trailing state axes
    (position/velocity, body, xy) and square-rooted. No averaging happens here;
    the "mean" in the name refers to averaging the result afterwards
    (e.g. over initial conditions).

    Shape legend:
    N is the number of initial conditions.
    M is the number of samples in prediction (optional leading axis).
    T is the number of time steps.
    The first dimension "2" corresponds to position + velocity.
    B is the number of bodies.
    The last dimension "2" corresponds to xy.

    Arguments:
    ref: N x T x 2 x B x 2
    pred: [M x] N x T x 2 x B x 2 -- broadcast against ref, so a prediction
          without the sample axis (e.g. an ensemble mean) is also accepted.

    Returns:
    [M x] N x T tensor of root squared errors.
    '''
    delta_z = ref - pred  # [M x] N x T x 2 x B x 2
    # Reduce the three trailing state axes (pos/vel, body, xy) in a single call.
    return delta_z.pow(2).sum(dim=(-3, -2, -1)).sqrt()  # [M x] N x T

In [None]:
# Per-initial-condition, per-time-step errors (N x T) for three predictors:
# a single model (the first DE member), the DE mean, and the SWAG sample mean.
det_mrse = compute_mrse(true_zt, de_pred_zt[0])
de_mrse, swag_mrse = (
    compute_mrse(true_zt, samples.mean(dim=0))
    for samples in (de_pred_zt, swag_pred_zt)
)

# All three must share one shape so they can be overlaid on a single plot.
assert det_mrse.shape == de_mrse.shape == swag_mrse.shape

det_mrse.shape

In [None]:
# Gray: single model (first DE member) -- mean +/- DET_N_STD std as a shaded band.
# Blue: DE mean, red: SWAG mean -- mean +/- ENSEMBLE_N_STD std as dashed lines.
# Mean/std are taken over axis 0 (initial conditions) of each N x T error tensor.
# NOTE(review): the single-model band is 1 std wide while the ensemble bands are
# 2 std wide -- confirm this asymmetry is intended before comparing band widths.
MRSE_TITLE = 'Mean Root Squared Error'
DET_N_STD = 1.
ENSEMBLE_N_STD = 2.


def _mrse_band_frame(ts, mrse, n_std):
    '''Summarise an N x T error tensor over its first axis for plotting.

    Returns (mean, std, frame): per-time-step mean and std as numpy arrays, and
    a DataFrame with columns t, y (mean), y_hi / y_lo (mean +/- n_std * std,
    with y_lo clipped at 0 because an error cannot be negative).
    '''
    mean = mrse.mean(0).cpu().numpy()
    std = mrse.std(0).cpu().numpy()
    frame = pd.DataFrame({
        't': ts.cpu().numpy(),
        'y': mean,
        'y_hi': mean + n_std * std,
        'y_lo': np.clip(mean - n_std * std, 0.0, np.inf),
    })
    return mean, std, frame


def _dashed_band_layers(frame, color):
    '''Return (mean, hi, lo) line layers: a solid mean line and dashed band edges.

    Every layer carries the same explicit y title so the layered chart shows a
    single axis title instead of Vega-Lite joining the per-layer field names.
    '''
    mean_line = alt.Chart(frame).mark_line(color=color, opacity=0.8).encode(
        x='t', y=alt.Y('y', title=MRSE_TITLE))
    dashed = dict(color=color, opacity=0.8, strokeDash=[2, 2])
    hi_line = mean_line.mark_line(**dashed).encode(x='t', y=alt.Y('y_hi', title=MRSE_TITLE))
    lo_line = mean_line.mark_line(**dashed).encode(x='t', y=alt.Y('y_lo', title=MRSE_TITLE))
    return mean_line, hi_line, lo_line


det_mean, det_std, det_frame = _mrse_band_frame(ts, det_mrse, DET_N_STD)
det_mean_chart = alt.Chart(det_frame).mark_line(color='gray').encode(
    x='t', y=alt.Y('y', title=MRSE_TITLE))
# Re-encoding y replaces the whole channel, so the title must be repeated here.
det_err_chart = det_mean_chart.mark_area(color='gray', opacity=.25).encode(
    y=alt.Y('y_hi', title=MRSE_TITLE), y2='y_lo')
det_chart = det_err_chart + det_mean_chart

de_mean, de_std, de_frame = _mrse_band_frame(ts, de_mrse, ENSEMBLE_N_STD)
de_mean_chart, de_hi_chart, de_lo_chart = _dashed_band_layers(de_frame, 'blue')
de_chart = de_mean_chart + de_hi_chart + de_lo_chart

swag_mean, swag_std, swag_frame = _mrse_band_frame(ts, swag_mrse, ENSEMBLE_N_STD)
swag_mean_chart, swag_hi_chart, swag_lo_chart = _dashed_band_layers(swag_frame, 'red')
swag_chart = swag_mean_chart + swag_hi_chart + swag_lo_chart

de_chart + swag_chart + det_chart