Use this notebook as the central place to collect all data dumps and process them into figures for _Likelihood plots_.

In [None]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

alt.data_transformers.disable_max_rows();


In [None]:
## TODO: change this to the dump path.
data_dump_path = '../.log/swag-3body/files/data.pt'

## Fail with an explicit, descriptive error instead of `assert`: asserts are
## stripped under `python -O` and would not report which path was missing.
if not os.path.isfile(data_dump_path):
    raise FileNotFoundError(f'Data dump not found: {data_dump_path}')

## NOTE(review): torch.load unpickles the file, so only load dumps you trust.
data_dump = torch.load(data_dump_path)

## Pull the individual entries out of the dump. `.get` yields None for a
## missing key, so a dump lacking 'true_zt_chaos' fails on the permute below.
ts = data_dump.get('ts')
z0_orig = data_dump.get('z0_orig')
true_zt = data_dump.get('true_zt')
## Swap the first two axes of the 6-D chaos samples.
## Presumably this yields the M x N x ... layout compute_likelihood expects
## for `pred` -- TODO confirm against the code that writes the dump.
true_zt_chaos = data_dump.get('true_zt_chaos').permute(1, 0, 2, 3, 4, 5)
pred_zt = data_dump.get('pred_zt')

In [None]:
def compute_likelihood(ref, pred):
    '''
    Log-likelihood of the reference under a diagonal Gaussian estimated
    from the pred samples, factored over time.

    A separate Gaussian is fit for every (n, t) pair: the mean and the
    per-coordinate standard deviation are taken across the M samples, with
    the trailing 2 x B x 2 state flattened into a single event vector.

    Arguments:
    ref: N x T x 2 x B x 2
    pred: M x N x T x 2 x B x 2

    Returns:
    N x T tensor of log-probabilities of ref.

    Note: torch's std is the unbiased estimator by default, so a single
    sample (M = 1) gives NaN scales.
    '''
    batch_shape = pred.shape[:3]

    ## reshape (rather than view) also accepts non-contiguous inputs,
    ## e.g. tensors that were permuted before being passed in.
    samples = pred.reshape(*batch_shape, -1)  ## M x N x T x D
    pred_mu = samples.mean(dim=0)  ## N x T x D
    ## The small offset keeps the scale strictly positive when samples coincide.
    pred_std = samples.std(dim=0) + 1e-6  ## N x T x D

    ## pred_std is a standard deviation, so it has to be used as the *scale*.
    ## The second positional argument of MultivariateNormal is the covariance
    ## matrix, so passing diag_embed(pred_std) there treats the std as the
    ## variance. An Independent Normal is the same diagonal Gaussian without
    ## materialising D x D matrices.
    pred_dist = torch.distributions.Independent(
        torch.distributions.Normal(pred_mu, pred_std), 1)

    log_prob = pred_dist.log_prob(ref.reshape(*ref.shape[:2], -1))  ## N x T
    return log_prob


In [None]:
  ## Score the ground-truth trajectories (true_zt) under each set of sampled
  ## rollouts: first the true_zt_chaos samples, then the model's pred_zt
  ## samples. Per compute_likelihood, each result is N x T.
  chaos_likelihood = compute_likelihood(true_zt, true_zt_chaos)
  pred_likelihood = compute_likelihood(true_zt, pred_zt)

In [None]:
def plot_likelihood(likelihood, color):
    '''
    Build a layered Altair chart of a likelihood curve over time.

    The line is the mean of `likelihood` over its first dimension and the
    shaded band spans two standard deviations on either side of that mean.
    The time axis comes from the notebook-level `ts` tensor.

    Arguments:
    likelihood: tensor reduced over dim 0; the remaining values are plotted
        against `ts`.
    color: color applied to both the line and the band.
    '''
    center = likelihood.mean(0).cpu().numpy()
    spread = likelihood.std(0).cpu().numpy()

    frame = pd.DataFrame({
        't': ts.cpu().numpy(),
        'y': center,
        'y_hi': center + 2. * spread,
        'y_lo': center - 2. * spread,
    })

    ## Mean curve.
    line = alt.Chart(frame).mark_line(color=color, opacity=0.7)
    line = line.encode(x='t', y='y')

    ## Two-sigma band, derived from the line chart so it shares the same data.
    band = line.mark_area(color=color, opacity=0.2)
    band = band.encode(
        x='t',
        y=alt.Y('y_lo', title='Likelihood'),
        y2='y_hi',
    )

    return line + band


In [None]:
## Overlay both curves on one chart: blue is the likelihood under the
## true_zt_chaos samples, red is the likelihood under the pred_zt samples.
chart = plot_likelihood(chaos_likelihood, 'blue') + plot_likelihood(pred_likelihood, 'red')
# chart.save('chart.json')
## Bare expression so the notebook renders the chart as the cell output.
chart