Use this notebook as the central place to collect the SWAG and DE data dumps for the 3-body runs and process them into the figures for the _Likelihood plots_.

In [1]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Let Altair embed DataFrames of any length instead of raising MaxRowsError above its
# default row limit; the trailing ';' keeps Jupyter from echoing the return value.
alt.data_transformers.disable_max_rows();


In [2]:
# Each dump lives under the log directory named after the method that produced it.
swag_dump_path = '../.log/swag-3body/files/data.pt'
de_dump_path = '../.log/de-3body/files/data.pt'

assert os.path.isfile(swag_dump_path), f'missing dump: {swag_dump_path}'
assert os.path.isfile(de_dump_path), f'missing dump: {de_dump_path}'

swag_dump = torch.load(swag_dump_path)
de_dump = torch.load(de_dump_path)

# The time grid, initial state and ground-truth trajectories are taken from the
# SWAG dump; the DE dump only contributes its predictions. Indexing (rather than
# .get) makes a missing key fail immediately with a KeyError naming it.
ts = swag_dump['ts']
z0_orig = swag_dump['z0_orig']
true_zt = swag_dump['true_zt']
# Swap the first two axes so the chaos samples sit on dim 0 like the model
# predictions (the asserts below check the resulting layout).
true_zt_chaos = swag_dump['true_zt_chaos'].permute(1, 0, 2, 3, 4, 5)
swag_pred_zt = swag_dump['pred_zt']
de_pred_zt = de_dump['pred_zt']

# Sample axis first: (M, *true_zt.shape) for the chaos set and both models.
assert true_zt_chaos.shape[1:] == true_zt.shape
assert true_zt_chaos.shape == swag_pred_zt.shape
assert de_pred_zt.shape == swag_pred_zt.shape

true_zt_chaos.shape

torch.Size([10, 5, 334, 2, 3, 2])

In [3]:
def compute_log_likelihood(ref, pred, eps=1e-6):
    '''
    Log-Likelihood of the reference under a diagonal Gaussian estimated
    from the pred samples, factored over time.

    The Gaussian's mean and standard deviation are the per-coordinate
    sample mean and (unbiased) sample std of `pred` over its first axis;
    the last three axes are flattened into a single event dimension.

    Arguments:
    ref: N x T x 2 x B x 2 reference trajectories.
    pred: M x N x T x 2 x B x 2 samples; M >= 2 is needed to estimate a std.
    eps: added to the std so coordinates with zero spread stay well defined.

    Returns:
    N x T tensor of log-densities.

    Raises:
    ValueError: if pred has fewer than two samples along dim 0.
    '''
    n_samples = pred.shape[0]
    if n_samples < 2:
        # torch's unbiased std of a single sample is NaN.
        raise ValueError(
            f'need at least 2 samples along dim 0 of pred to estimate a std, got {n_samples}')

    # reshape (not view): permuted or sliced inputs may not be viewable.
    ref = ref.reshape(*ref.shape[:-3], -1)
    pred = pred.reshape(*pred.shape[:-3], -1)

    pred_mu = pred.mean(dim=0)
    pred_std = pred.std(dim=0) + eps

    # Independent Normals over the last dim == MultivariateNormal with covariance
    # diag(std**2). Note MultivariateNormal's second positional argument is a
    # *covariance*, so the std must not be passed there directly.
    pred_dist = torch.distributions.Independent(
        torch.distributions.Normal(pred_mu, pred_std), 1)
    log_prob = pred_dist.log_prob(ref)  ## N x T
    return log_prob


In [4]:
# Per-(trajectory, time step) log-likelihood of the ground truth `true_zt`
# under the Gaussian fitted to each sample set: the chaos baseline
# (`true_zt_chaos`), then the DE and SWAG predictions.
chaos_likelihood = compute_log_likelihood(true_zt, true_zt_chaos)
de_likelihood = compute_log_likelihood(true_zt, de_pred_zt)
swag_likelihood = compute_log_likelihood(true_zt, swag_pred_zt)
de_likelihood.shape

torch.Size([5, 334])

In [6]:
def _likelihood_band_frame(t, log_lik):
    '''
    Tabulate a log-likelihood tensor for plotting.

    For each time step: the mean over dim 0 ('y') and a mean +/- 2 std band
    ('y_lo', 'y_hi').

    t: length-T time grid.
    log_lik: N x T log-likelihoods.
    '''
    mean = log_lik.mean(dim=0)
    half_width = 2. * log_lik.std(dim=0)
    return pd.DataFrame({
        't': t.cpu().numpy(),
        'y': mean.cpu().numpy(),
        'y_lo': (mean - half_width).cpu().numpy(),
        'y_hi': (mean + half_width).cpu().numpy(),
    })


def _dashed_band_chart(frame, color):
    '''Solid mean line plus dashed lines at the upper and lower band edges, all in `color`.'''
    mean_line = alt.Chart(frame).mark_line(color=color, opacity=0.8).encode(x='t', y='y')
    edge = mean_line.mark_line(color=color, opacity=0.8, strokeDash=[2, 2])
    return mean_line + edge.encode(x='t', y='y_hi') + edge.encode(x='t', y='y_lo')


# Chaos baseline: shaded band under a solid gray mean line.
chaos_mean_chart = alt.Chart(_likelihood_band_frame(ts, chaos_likelihood)).mark_line(
    color='gray').encode(x='t', y='y')
chaos_err_chart = chaos_mean_chart.mark_area(color='gray', opacity=0.25).encode(
    y=alt.Y('y_lo', title='Log Likelihood'), y2='y_hi')
chaos_chart = chaos_err_chart + chaos_mean_chart

# Model predictions: blue = DE, red = SWAG.
de_chart = _dashed_band_chart(_likelihood_band_frame(ts, de_likelihood), 'blue')
swag_chart = _dashed_band_chart(_likelihood_band_frame(ts, swag_likelihood), 'red')

# (swag_chart + de_chart + chaos_chart).save('chart.json')

swag_chart + de_chart + chaos_chart