In [25]:
import os
import sys
import torch
import altair as alt
import pandas as pd
import numpy as np

# Disable altair's row-count limit on data embedded in charts, so plotting larger
# DataFrames does not raise MaxRowsError. The trailing ';' suppresses the cell's output.
alt.data_transformers.disable_max_rows();


In [26]:
## TODO: change these to the dump paths.
de_dump_path = '../../de-3body-data.pt'
swag_dump_path = '../../swag-3body-data.pt'

# Check both inputs up front so a bad path fails here with a clear message,
# not halfway through the cell.
for _path in (de_dump_path, swag_dump_path):
    assert os.path.isfile(_path), f'missing data dump: {_path}'

# NOTE(review): torch.load unpickles its input -- only load dumps from a trusted source.
# map_location='cpu' lets dumps saved from GPU tensors load on CPU-only machines;
# every downstream use calls .cpu() before converting to numpy anyway.
de_dump = torch.load(de_dump_path, map_location='cpu')
swag_dump = torch.load(swag_dump_path, map_location='cpu')

# Required keys are indexed with [] rather than .get() so a missing key raises
# KeyError here instead of propagating None into later cells.
ts = de_dump['ts']
# Not used by the visible cells; .get() keeps dumps without this key loadable.
z0_orig = de_dump.get('z0_orig')
true_zt = de_dump['true_zt']
# Swap the first two axes so the sample axis comes first (M x N x ...), the layout
# compute_rel_error documents for `pred`. Presumably stored as N x M x ... -- TODO confirm.
true_zt_chaos = de_dump['true_zt_chaos'].permute(1, 0, 2, 3, 4, 5)
de_pred_zt = de_dump['pred_zt']
swag_pred_zt = swag_dump['pred_zt']

In [27]:
def compute_rel_error(ref, pred):
    '''
    Log of the average pairwise Euclidean distance between prediction samples.

    For every initial condition and time step, measures how far apart the M
    sampled trajectories in `pred` are from one another.

    N is the number of initial conditions.
    M is the number of samples in prediction.
    The first dimension "2" corresponds to position + velocity.
    B is the number of bodies.
    The last dimension "2" corresponds to xy.

    Arguments:
    ref: N x T x 2 x B x 2. Not used by the computation; kept so existing
         calls keep working.
    pred: M x N x T x 2 x B x 2

    Returns:
    N x T tensor equal to log(mean over sample pairs (i, j) of
    ||pred[i] - pred[j]|| + 1e-10), where the norm is taken jointly over the
    last three axes. Self-pairs (i == j) are included and contribute zero.
    '''
    # Pairwise differences between samples: M x M x N x T x 2 x B x 2.
    pair_diffs = pred.unsqueeze(0) - pred.unsqueeze(1)
    # Euclidean distance for each sample pair over (pos/vel, body, xy): M x M x N x T.
    # The norm must be taken BEFORE averaging over pairs: the mean of the raw,
    # antisymmetric differences (p_i - p_j) over all pairs is identically zero.
    pair_dists = pair_diffs.pow(2).sum(dim=(-3, -2, -1)).sqrt()
    # Average over all M x M sample pairs: N x T.
    mean_dist = pair_dists.mean(dim=(0, 1))
    # Epsilon keeps the log finite when all samples coincide.
    return (mean_dist + 1e-10).log()

In [28]:
# Apply compute_rel_error to each trajectory source, always against the same
# reference true_zt: the perturbed true trajectories, the DE samples and the SWAG samples.
true_rel_err, de_pred_rel_err, swag_pred_rel_err = (
    compute_rel_error(true_zt, samples)
    for samples in (true_zt_chaos, de_pred_zt, swag_pred_zt)
)

In [58]:
def plot_fit_line(rel_err, color, n_fit=167, times=None):
    '''
    Plot a 1-D curve over time together with its least-squares line fit.

    Arguments:
    rel_err: 1-D tensor of values over time (e.g. one row of compute_rel_error output).
    color: color used for both the data line and the fitted line.
    n_fit: number of leading time steps to plot and fit (default 167, the value
           that was previously hard-coded).
    times: 1-D tensor of time stamps aligned with rel_err; defaults to the
           notebook-level `ts`.

    Returns:
    Layered altair chart: dashed line = data, solid line = linear fit.
    '''
    if times is None:
        times = ts
    y = rel_err[:n_fit].cpu().numpy()
    # Align the time axis with however many points rel_err actually provides, so a
    # curve shorter than n_fit does not hand polyfit arrays of different lengths.
    t = times[:len(y)].cpu().numpy()

    slope, intercept = np.polyfit(t, y, 1)
    y_fit = slope * t + intercept

    # Both layers share a y domain derived from the fitted line, padded 1 below and 2 above.
    y_scale = alt.Scale(domain=(np.min(y_fit) - 1, np.max(y_fit) + 2))

    data_chart = alt.Chart(pd.DataFrame({
        't': t,
        'y': y,
    })).mark_line(color=color, opacity=1.0, strokeDash=[2, 2]).encode(x='t', y=alt.Y('y', scale=y_scale))

    fit_chart = alt.Chart(pd.DataFrame({
        't': t,
        'y': y_fit,
    })).mark_line(color=color, opacity=0.8).encode(x='t', y=alt.Y('y', scale=y_scale, title='Log Average Distance'))

    return data_chart + fit_chart

def plot_slope_point(rel_err, shape, color, n_fit=167, times=None):
    '''
    Plot the (slope, intercept) of a line fit to a curve as a single marker.

    Arguments:
    rel_err: 1-D tensor of values over time.
    shape: marker kind, one of "circle", "square", "tick".
    color: marker color.
    n_fit: number of leading time steps used for the fit (default 167, the value
           that was previously hard-coded).
    times: 1-D tensor of time stamps aligned with rel_err; defaults to the
           notebook-level `ts`.

    Returns:
    altair chart with one marker at x = slope (m), y = intercept (b).

    Raises:
    ValueError: if `shape` is not one of the supported marker kinds.
    '''
    # Map each supported shape to the altair mark method that draws it.
    mark_methods = {
        'circle': 'mark_point',
        'square': 'mark_square',
        'tick': 'mark_tick',
    }
    if shape not in mark_methods:
        # Fail loudly instead of falling through to an unbound chart variable.
        raise ValueError(f'unsupported shape {shape!r}; expected one of {sorted(mark_methods)}')
    if times is None:
        times = ts

    y = rel_err[:n_fit].cpu().numpy()
    # Align the time axis with the number of points actually available.
    t = times[:len(y)].cpu().numpy()
    m, b = np.polyfit(t, y, 1)

    base = alt.Chart(pd.DataFrame({
        'm': [m],
        'b': [b],
    }))
    draw_mark = getattr(base, mark_methods[shape])
    return draw_mark(color=color, opacity=0.7).encode(x='m', y='b')

In [30]:
# Sanity-check the layout: expected (N initial conditions, T time steps).
true_rel_err.shape

torch.Size([5, 334])

In [63]:
# Which initial condition (row of the *_rel_err tensors) to inspect.
idx = 4
# Set to True to overlay the DE curve as well (previously a commented-out line).
show_de = False

# Colors: blue = true (perturbed) trajectories, red = SWAG samples, green = DE samples.
chart = (
    plot_fit_line(true_rel_err[idx, :167], 'blue')
    + plot_fit_line(swag_pred_rel_err[idx, :167], 'red')
)
if show_de:
    chart += plot_fit_line(de_pred_rel_err[idx, :167], 'green')
chart

In [32]:
# One color per initial condition; the marker shape encodes the trajectory source:
# circle = true_zt_chaos, square = SWAG samples, tick = DE samples.
colors = ['red', 'blue', 'green', 'black', 'magenta']
n_initial_conditions = true_rel_err.shape[0]

slope_layers = []
for idx in range(n_initial_conditions):
    # Cycle through the palette so more initial conditions than colors cannot IndexError.
    color = colors[idx % len(colors)]
    slope_layers += [
        plot_slope_point(true_rel_err[idx, ...], 'circle', color),
        plot_slope_point(swag_pred_rel_err[idx, ...], 'square', color),
        plot_slope_point(de_pred_rel_err[idx, ...], 'tick', color),
    ]

# Build the layered chart in one call rather than growing a `chart` variable that
# might be left over from another cell.
chart = alt.layer(*slope_layers)
chart