Use this notebook as the central place to load the data dumps and process them into the figures for the _relative error plots_.

In [None]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Lift Altair's default cap on the number of rows a chart's dataset may hold.
# The trailing semicolon keeps Jupyter from echoing the call's return value.
alt.data_transformers.disable_max_rows();


In [None]:
## TODO: change this to the dump path.
data_dump_path = '../.log/swag-3body/files/data.pt'

# Raise explicitly instead of using `assert`, which is skipped under `python -O`
# and gives no hint about which path was missing.
if not os.path.isfile(data_dump_path):
    raise FileNotFoundError(f'Data dump not found: {data_dump_path}')

# Map every tensor to the CPU so that a dump written from a GPU run can also be
# opened on a machine without CUDA; the plotting code in this notebook moves
# tensors to the CPU before use anyway.
data_dump = torch.load(data_dump_path, map_location='cpu')

ts = data_dump.get('ts')  # used as the time axis of the plots
z0_orig = data_dump.get('z0_orig')
true_zt = data_dump.get('true_zt')
# Swap the first two axes. NOTE(review): presumably the dump stores this tensor
# as N x M x ... while compute_rel_error expects M x N x ... -- confirm against
# the code that writes the dump.
true_zt_chaos = data_dump.get('true_zt_chaos').permute(1, 0, 2, 3, 4, 5)
pred_zt = data_dump.get('pred_zt')

In [None]:
def compute_rel_error(ref, pred):
    '''
    Relative error of each predicted trajectory against a reference one.

    For every sample, initial condition and time step this returns
    ||ref - pred|| / ||ref + pred||, where the norm is the Euclidean norm
    over the trailing three axes (position/velocity, bodies, xy).

    Shapes, with N initial conditions, M prediction samples, T time steps
    and B bodies:

    Arguments:
    ref: N x T x 2 x B x 2
    pred: M x N x T x 2 x B x 2

    Returns:
    M x N x T tensor of relative errors.
    '''
    def _state_norm(z):
        # Reduce the xy, body and position/velocity axes one after another.
        sq = z.pow(2)
        for _ in range(3):
            sq = sq.sum(dim=-1)
        return sq.sqrt()

    # A leading singleton axis lets the reference broadcast over the M samples.
    target = ref.unsqueeze(0)  # 1 x N x T x 2 x B x 2

    abs_err = _state_norm(target - pred)  # M x N x T
    scale = _state_norm(target + pred)  # M x N x T

    return abs_err / scale

In [None]:
# Relative error, measured against true_zt, of the true_zt_chaos rollouts and
# of the predictions in pred_zt, in that order.
true_rel_err, pred_rel_err = (
    compute_rel_error(true_zt, rollouts) for rollouts in (true_zt_chaos, pred_zt)
)

In [None]:
def plot_rel_err(rel_err, color):
    '''
    Plot the mean relative error over time with a +/- 2 standard deviation band.

    Arguments:
    rel_err: tensor of relative errors whose first dimension indexes samples;
        mean and standard deviation are taken over that dimension. The
        remaining dimension must line up with the notebook-level `ts`
        (this notebook passes an M x T slice).
    color: color used for both the mean line and the shaded band.

    Returns:
    A layered Altair chart: the mean line on top of the shaded band.
    '''
    y_mean = rel_err.mean(0).cpu().numpy()
    y_std = rel_err.std(0).cpu().numpy()

    # Both layers share the y axis. Giving them the same explicit title keeps
    # the layered chart from labelling the axis with a merge of the default
    # field name 'y' and 'Relative Error'.
    y_title = 'Relative Error'

    base = alt.Chart(pd.DataFrame({
        't': ts.cpu().numpy(),
        'y': y_mean,
        'y_hi': y_mean + 2. * y_std,
        # The error is non-negative, so the band is not drawn below zero.
        'y_lo': np.clip(y_mean - 2. * y_std, 0.0, np.inf),
    }))

    err_mean_chart = base.mark_line(color=color, opacity=0.7).encode(
        x='t', y=alt.Y('y', title=y_title))
    err_std_chart = base.mark_area(color=color, opacity=0.2).encode(
        x='t', y=alt.Y('y_lo', title=y_title), y2='y_hi')

    return err_mean_chart + err_std_chart


In [None]:
# Index of the initial condition (second axis of the error tensors) to plot.
idx = 0

# Blue: error of the true_zt_chaos rollouts; red: error of the predictions.
true_chart = plot_rel_err(true_rel_err[:, idx, ...], 'blue')
pred_chart = plot_rel_err(pred_rel_err[:, idx, ...], 'red')
rel_err_chart = true_chart + pred_chart

# Uncomment to export the chart:
# rel_err_chart.save('chart.json')