Use this notebook as the central place to collect all data dumps and process them into figures for _Relative error plots_.

In [1]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Lift Altair's row limit on chart data so the full error tables can be plotted.
# NOTE(review): the trailing ';' presumably just suppresses the echoed return
# value in the notebook cell output -- confirm before removing it.
alt.data_transformers.disable_max_rows();


In [2]:
# Locations of the serialized evaluation dumps, one per method.
# NOTE(review): each variable now points at the run directory matching its own
# name (they were crossed before, which would swap the SWAG and DE curves).
# Confirm against the actual contents of ../.log if the runs were named unusually.
swag_dump_path = '../.log/swag-3body/files/data.pt'
de_dump_path = '../.log/de-3body/files/data.pt'

assert os.path.isfile(swag_dump_path), f'missing SWAG dump: {swag_dump_path}'
assert os.path.isfile(de_dump_path), f'missing DE dump: {de_dump_path}'

# NOTE: torch.load unpickles its input -- only load dumps from a trusted source.
# map_location='cpu' keeps the notebook usable on machines without the device
# the dumps were saved from; everything downstream is moved to CPU anyway.
swag_dump = torch.load(swag_dump_path, map_location='cpu')
de_dump = torch.load(de_dump_path, map_location='cpu')

# Index with [] rather than .get() so a missing key fails here with a KeyError
# instead of surfacing later as an unrelated error on None.
ts = swag_dump['ts']
z0_orig = swag_dump['z0_orig']
true_zt = swag_dump['true_zt']
# Swap the first two axes so the layout matches pred_zt (verified by the shape
# assertions below).
true_zt_chaos = swag_dump['true_zt_chaos'].permute(1, 0, 2, 3, 4, 5)
swag_pred_zt = swag_dump['pred_zt']
de_pred_zt = de_dump['pred_zt']

# Sanity checks: all sampled trajectories share one layout, and dropping their
# leading axis yields the layout of the reference trajectory.
assert true_zt_chaos.shape[1:] == true_zt.shape, \
    f'true_zt_chaos {tuple(true_zt_chaos.shape)} vs true_zt {tuple(true_zt.shape)}'
assert true_zt_chaos.shape == swag_pred_zt.shape, \
    f'true_zt_chaos {tuple(true_zt_chaos.shape)} vs SWAG pred_zt {tuple(swag_pred_zt.shape)}'
assert de_pred_zt.shape == swag_pred_zt.shape, \
    f'DE pred_zt {tuple(de_pred_zt.shape)} vs SWAG pred_zt {tuple(swag_pred_zt.shape)}'

true_zt_chaos.shape

torch.Size([10, 5, 334, 2, 3, 2])

In [3]:
def compute_rel_error(ref, pred):
    '''
    Relative error of sampled predictions against a reference trajectory,
    evaluated as ||ref - pred|| / ||ref + pred|| for every sample, initial
    condition and time step.

    Shape legend:
    M - number of samples in the prediction.
    N - number of initial conditions.
    T - number of time steps.
    B - number of bodies.
    The first "2" axis is position + velocity, the last "2" axis is xy.

    Arguments:
    ref: N x T x 2 x B x 2
    pred: M x N x T x 2 x B x 2

    Returns:
    M x N x T tensor; both norms are taken over the trailing three axes.
    '''
    def _trailing_norm(z):
        # L2 norm over the last three axes, collapsed one axis at a time.
        squared = z.pow(2)
        for _ in range(3):
            squared = squared.sum(dim=-1)
        return squared.sqrt()

    ref_b = ref.unsqueeze(0)  # 1 x N x T x 2 x B x 2, broadcasts against pred

    numerator = _trailing_norm(ref_b - pred)    # M x N x T
    denominator = _trailing_norm(ref_b + pred)  # M x N x T
    return numerator / denominator

In [4]:
# Relative error against true_zt for, in order: the perturbed ground-truth
# trajectories, the SWAG predictions and the deep-ensemble predictions.
true_rel_err, swag_rel_err, de_rel_err = (
    compute_rel_error(true_zt, sampled_zt)
    for sampled_zt in (true_zt_chaos, swag_pred_zt, de_pred_zt)
)

In [24]:
def _rel_err_frame(err, t):
    '''Tabulate mean and mean +/- 2 std over axis 0 of an M x T error tensor.'''
    mean = err.mean(0).cpu().numpy()
    std = err.std(0).cpu().numpy()
    return pd.DataFrame({
        't': t,
        'y': mean,
        'y_hi': mean + 2. * std,
        # Floor the lower band at zero; the plotted quantity is a ratio of norms.
        'y_lo': np.clip(mean - 2. * std, 0.0, np.inf),
    })


def _rel_err_axis(field):
    '''Y encoding for `field`, titled identically on every layer.'''
    # Layers share one y axis; giving each layer the same title keeps the
    # merged axis label from listing the raw column names.
    return alt.Y(field, title='Relative Error')


def _dashed_band_chart(err, t, color):
    '''Solid mean line plus dashed mean +/- 2 std lines, all drawn in `color`.'''
    base = alt.Chart(_rel_err_frame(err, t))
    mean_line = base.mark_line(color=color, opacity=0.8).encode(
        x='t', y=_rel_err_axis('y'))
    hi_line, lo_line = (
        base.mark_line(color=color, opacity=0.8, strokeDash=[2, 2]).encode(
            x='t', y=_rel_err_axis(column))
        for column in ('y_hi', 'y_lo')
    )
    return mean_line + hi_line + lo_line


def _shaded_band_chart(err, t, color):
    '''Solid mean line over a translucent mean +/- 2 std area, drawn in `color`.'''
    base = alt.Chart(_rel_err_frame(err, t))
    mean_line = base.mark_line(color=color).encode(x='t', y=_rel_err_axis('y'))
    band = base.mark_area(color=color, opacity=0.25).encode(
        x='t', y=_rel_err_axis('y_lo'), y2='y_hi')
    return band + mean_line


def plot_rel_err(ref, swag, de, t=None):
    '''
    Layer the relative-error curves of the three methods into one chart.

    Each input is reduced over its first axis to a mean curve with a
    +/- 2 standard deviation band (lower bound clipped at 0).

    Arguments:
    ref: M x T, drawn in gray as a line over a shaded band.
    swag: M x T, drawn in red as a line with dashed band edges.
    de: M x T, drawn in blue as a line with dashed band edges.
    t: optional length-T time stamps (tensor or array-like); when omitted the
       notebook-level `ts` is used.

    Returns:
    Layered Altair chart, stacked bottom to top as de, swag, ref.
    '''
    times = ts if t is None else t
    if torch.is_tensor(times):
        times = times.cpu().numpy()

    de_chart = _dashed_band_chart(de, times, 'blue')
    swag_chart = _dashed_band_chart(swag, times, 'red')
    ref_chart = _shaded_band_chart(ref, times, 'gray')

    return de_chart + swag_chart + ref_chart

In [26]:
# One figure per initial condition; axis 1 of the error tensors indexes them,
# so take the count from the data instead of hard-coding it.
for eval_idx in range(true_rel_err.shape[1]):
    pl = plot_rel_err(true_rel_err[:, eval_idx, ...], swag_rel_err[:, eval_idx, ...], de_rel_err[:, eval_idx, ...])
    # Uncomment to export each figure as a chart spec (file names are 1-based).
    # pl.save(f'eval{eval_idx + 1}_rel_err.json')