In [None]:
import os

import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

sns.set_theme()
sns.set_context('paper')

%matplotlib inline

In [None]:
dump_dir = os.path.abspath('../.log')

def parse_dump(dump):
    '''Unpack the arrays stored in a ``data.pt`` dump dictionary.

    Keys that are absent from ``dump`` come back as ``None`` instead of raising.

    Returns:
        Tuple ``(ts, true_zt, true_zt_chaos, pred_zt, pred_zt_chaos)``.
    '''
    # Order matters: callers tuple-unpack the result positionally.
    # ('z0_orig' is deliberately not unpacked.)
    field_order = ('ts', 'true_zt', 'true_zt_chaos', 'pred_zt', 'pred_zt_chaos')
    return tuple(dump.get(field) for field in field_order)

def download_runs(sweep_id):
    '''Yield the ``data.pt`` dump of every run in a W&B sweep, downloading it if needed.

    Files are cached under ``dump_dir/<sweep name>/<run name>/data.pt``; an existing
    local copy is reused instead of being downloaded again.

    Arguments:
        sweep_id: W&B sweep path, e.g. ``'entity/project/sweep'``.

    Yields:
        ``(run.name, run.config, parse_dump(dump))`` for each run that has a
        ``data.pt`` file. Runs without one are skipped.
    '''
    api = wandb.Api()
    sweep = api.sweep(sweep_id)
    for run in sweep.runs:
        # Stop scanning the run's files as soon as the dump is found.
        dump_file = next((f for f in run.files() if f.name == 'data.pt'), None)
        if dump_file is None:
            continue

        print(run.name)
        download_root = os.path.join(dump_dir, sweep.name, run.name)
        fpath = os.path.join(download_root, dump_file.name)
        if not os.path.isfile(fpath):
            dump_file.download(root=download_root)

        # Load onto the CPU: downstream cells call `.numpy()`, which fails on CUDA
        # tensors, and a GPU-saved dump cannot be loaded on a CPU-only machine otherwise.
        # NOTE(security): torch.load unpickles the file; only use trusted sweeps.
        dump = torch.load(fpath, map_location='cpu')
        yield run.name, run.config, parse_dump(dump)

## Method colors (kept consistent across all figures)

In [None]:
# RGBA colour per uncertainty method, shared by the plotting helpers so that a
# method keeps the same colour in every figure.
method_color = {
    'ou': (0.2, 0.6, 0.4, 0.75),
    'de': (0.2, 0.2, 1.0, 0.75),
    'swag': (1.0, 0.2, 0.2, 0.75),
}

## Metric computations

Helper functions that compute every metric used in the final figures.

In [None]:
def compute_rel_error(ref, pred):
    '''Relative error of every predicted sample against the reference trajectory.

    For each sample, initial condition and time step this returns
    ||ref - pred|| / ||ref + pred||, where the norm is taken over all three
    trailing state dimensions.

    Shapes (N: initial conditions, M: prediction samples, T: time steps,
    B: bodies; the first "2" is position/velocity, the last "2" is x/y):
        ref: N x T x 2 x B x 2
        pred: M x N x T x 2 x B x 2

    Returns:
        M x N x T tensor.
    '''
    def _norm_over_state(z):
        # Reduce the last three (state) dims one at a time, then take the root.
        return z.pow(2).sum(dim=-1).sum(dim=-1).sum(dim=-1).sqrt()

    ref_batched = ref.unsqueeze(0)  # 1 x N x T x 2 x B x 2, broadcasts against pred
    numerator = _norm_over_state(ref_batched - pred)
    denominator = _norm_over_state(ref_batched + pred)
    return numerator / denominator

def compute_geom_mean(ts, values):
    '''Time-averaged geometric mean of a sampled signal.

    Computes exp( (1 / (t_max - t_min)) * integral of log(values + 1e-8) dt ),
    integrating along the last axis with the trapezoidal rule. The 1e-8 offset
    keeps zeros from producing log(0).

    Arguments:
        ts: T
        values: ... x T
    Returns:
        Tensor of shape ``...`` (the time axis is reduced away).
    '''
    log_values = torch.log(values + 1e-8)
    duration = ts.max() - ts.min()
    mean_log = torch.trapz(log_values, ts) / duration
    return torch.exp(mean_log)

def compute_likelihood(ref, pred):
    '''Log-likelihood of the reference under a Gaussian fitted to the samples.

    At each (initial condition, time step) the M prediction samples are summarised
    by their per-coordinate mean and (unbiased) standard deviation. The reference
    state is scored under the diagonal Gaussian N(mean, diag(std^2)). The density
    is factored over time, so each time step is scored independently.

    Arguments:
        ref: N x T x 2 x B x 2
        pred: M x N x T x 2 x B x 2
    Returns:
        N x T tensor of log-probabilities.
    '''
    batch_shape = pred.shape[:3]  # M x N x T
    # reshape (not view) so that non-contiguous inputs, e.g. sliced tensors, work too.
    flat_pred = pred.reshape(*batch_shape, -1)  # M x N x T x D

    pred_mu = flat_pred.mean(dim=0)
    pred_std = flat_pred.std(dim=0) + 1e-6
    # With a single sample (M == 1) the unbiased std is NaN; fall back to a tiny scale.
    pred_std[torch.isnan(pred_std)] = 1e-6

    # The std must enter as a *scale*. Passing diag(std) as the covariance matrix
    # would silently treat the std as a variance.
    pred_dist = torch.distributions.Independent(
        torch.distributions.Normal(pred_mu, pred_std), 1)

    return pred_dist.log_prob(ref.reshape(*ref.shape[:2], -1))  # N x T

def compute_lad_fit(ts, rel_err):
    '''Per-trajectory straight-line fit of ``rel_err`` against time.

    Note: despite the name, this is an ordinary least-squares fit
    (``np.polyfit`` with degree 1), not a least-absolute-deviation fit.

    Arguments:
        ts: T
        rel_err: N x T
    Returns:
        N x T CPU tensor: the fitted line of each row evaluated at ``ts``.
    '''
    # detach().cpu() so that tensors on a GPU or requiring grad can be converted.
    ts_cpu = ts.detach().cpu()
    coef = np.polyfit(ts_cpu.numpy(), rel_err.detach().cpu().numpy().T, 1)  # 2 x N
    coef = torch.from_numpy(coef).float()
    slope, intercept = coef[0].unsqueeze(-1), coef[1].unsqueeze(-1)  # N x 1 each
    # Evaluate on the CPU copy of ts: coef lives on the CPU, so a CUDA `ts`
    # would raise a device-mismatch error here.
    return slope * ts_cpu.unsqueeze(0) + intercept  # N x T

## Trace Plots

In [None]:
def trace_plot(t, y, ax, y_std=None, min_y=-np.inf, max_y=np.inf,
               kind='region', color='black',
               plt_args=None):
    '''Plot the mean of a set of traces with a +/- 2 std band on ``ax``.

    Arguments:
        t: Time points, length T.
        y: S x T array of traces; the zeroth dimension indexes samples.
        ax: Matplotlib Axis to draw on.
        y_std: Optional std (broadcastable against the mean) used for the band
            instead of ``np.std(y, axis=0)``.
        min_y, max_y: The band is clipped to this range (e.g. ``min_y=0.`` for
            non-negative quantities).
        kind: ``'region'`` shades the band; ``'bound'`` draws its edges as dashed lines.
        color: Colour of the mean line and of the band.
        plt_args: Extra keyword arguments forwarded to ``ax.plot`` for the mean line.
    '''
    assert kind in ['region', 'bound']
    plt_args = plt_args or dict()

    mu = np.mean(y, axis=0)
    std = np.std(y, axis=0) if y_std is None else y_std

    ax.plot(t, mu, c=color, **plt_args)

    # A single trace has no spread of its own; only draw a band if one was supplied.
    if y.shape[0] == 1 and y_std is None:
        return

    lower = np.clip(mu - 2. * std, min_y, max_y)
    upper = np.clip(mu + 2. * std, min_y, max_y)

    if kind == 'region':
        # Must use the `t` argument, not a global time array.
        ax.fill_between(t, lower, upper, color=color, alpha=0.2)
    elif kind == 'bound':
        # Each column of the T x 2 arrays is a separate line: lower and upper edge.
        ax.plot(np.array([t, t]).T, np.array([lower, upper]).T, c=color, dashes=[8,4])
    else:
        raise NotImplementedError

def _chaos_trace_frame(ts, true_zt_chaos, init_id, b_id):
    '''Long-form DataFrame of the chaos position traces for one trajectory and body.'''
    # S x T x D slice; index 0 of the z dimension is position.
    traces = np.asarray(true_zt_chaos[:, init_id, :, 0, b_id, :])
    n_s, n_t, n_d = traces.shape
    s_ids, dof_ids, t_ids = np.meshgrid(np.arange(n_s), np.arange(n_d), np.arange(n_t),
                                        indexing='ij')
    return pd.DataFrame({
        'method': 'chaos',
        's_id': s_ids.ravel(),
        'dof_id': dof_ids.ravel(),
        't': np.asarray(ts)[t_ids.ravel()],
        'v': traces.transpose(0, 2, 1).ravel(),  # S x D x T, matching the id grids
    })

def generate_trace_plots(ts, true_zt, true_zt_chaos, pred_zt, idx=5, method=None, save=False):
    '''Plot position traces (one figure per body) for selected initial conditions.

    Each figure has one row per degree of freedom. It overlays the ground-truth
    trace (dashed dark grey), the chaos samples (light grey mean with a shaded
    +/- 2 std band) and the method's predictions (mean plus dashed +/- 2 std
    bounds in the method colour).

    Arguments:
        ts: T time points (array).
        true_zt: N x T x 2 x B x 2 ground truth.
        true_zt_chaos: S x N x T x 2 x B x 2 chaos samples.
        pred_zt: M x N x T x 2 x B x 2 predicted samples.
        idx: Either an int n (plot n random initial conditions) or an explicit
            list of initial-condition indices.
        method: Key into ``method_color``; required.
        save: If True, also write each figure to a PDF in the working directory.
    '''
    assert method is not None
    pcolor = method_color[method]

    _, n_init, _, _, n_body, n_dof = true_zt_chaos.shape

    if isinstance(idx, int):
        idx = np.random.permutation(n_init)[:idx]
    true_zt = true_zt[idx]
    true_zt_chaos = true_zt_chaos[:, idx, ...]
    pred_zt = pred_zt[:, idx, ...]

    dof_label = ['x', 'y']
    for init_id in range(len(idx)):
        for b_id in range(n_body):
            df = _chaos_trace_frame(ts, true_zt_chaos, init_id, b_id)

            # relplot only builds the faceted grid (one row per dof); its own lines
            # are cleared below and redrawn with trace_plot.
            g = sns.relplot(data=df, x='t', y='v', row='dof_id', kind='line', color='gray')

            for dof_id in range(n_dof):
                ax = g.axes[dof_id, 0]
                ax.clear()

                ax.set_title(rf'Body {b_id + 1} $\mid$ Dimension {dof_label[dof_id]}')
                ax.set_xlabel('t')

                trace_plot(ts, true_zt[np.newaxis, init_id, :, 0, b_id, dof_id], ax,
                           color=(.2,.2,.2), plt_args=dict(dashes=[4,2]))
                trace_plot(ts, true_zt_chaos[:, init_id, :, 0, b_id, dof_id], ax,
                           color=(.6,.6,.6))
                trace_plot(ts, pred_zt[:, init_id, :, 0, b_id, dof_id], ax,
                           kind='bound', color=pcolor)

            if save:
                g.fig.savefig(f'chnn-{method}-traj{idx[init_id] + 1}-body{b_id + 1}-xy.pdf')
            # plt.show takes no figure argument; close the figure to free memory in the loop.
            plt.show()
            plt.close(g.fig)

In [None]:
assert False # comment to exec.

## Can use integer to plot random traces, or a fixed list of traces.
plot_idx = [20]
should_save = False

## W&B sweeps behind the CHNN-OU, CHNN-DE and CHNN-SWAG trace plots.
trace_sweeps = {
    'ou': 'snym/phy-unc-exps/1w572o1p',
    'de': 'snym/phy-unc-exps/5mjska2l',
    'swag': 'snym/phy-unc-exps/9hy4ksyf',
}

for method, sweep_id in trace_sweeps.items():
    for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, *_) in download_runs(sweep_id=sweep_id):
        # generate_trace_plots works on numpy arrays.
        ts = ts.numpy()
        true_zt = true_zt.numpy()
        true_zt_chaos = true_zt_chaos.numpy()
        pred_zt = pred_zt.numpy()
        generate_trace_plots(ts, true_zt, true_zt_chaos, pred_zt,
                             idx=plot_idx, method=method, save=should_save)

## Relative Error Plots

In [None]:
def generate_rel_err_plots(ts, true_zt, true_zt_chaos, pred_zt, idx=5, method=None, save=False):
    '''Plot relative error over time for selected initial conditions.

    For each selected trajectory, draw two curves. The chaos samples' relative
    error is shown in grey with a shaded +/- 2 std band. The method's predictions
    are shown in the method colour with dashed +/- 2 std bounds. Both bands are
    clipped at zero.

    Arguments:
        ts: T time points (torch tensor or array).
        true_zt: N x T x 2 x B x 2 ground truth (torch).
        true_zt_chaos: S x N x T x 2 x B x 2 chaos samples (torch).
        pred_zt: M x N x T x 2 x B x 2 predicted samples (torch).
        idx: int n (plot n random initial conditions) or an explicit list of indices.
        method: Key into ``method_color``; required.
        save: If True, also write each figure to a PDF in the working directory.
    '''
    assert method is not None
    pcolor = method_color[method]

    n_init = true_zt_chaos.shape[1]

    pred_rel_err = compute_rel_error(true_zt, pred_zt)  # M x N x T
    chaos_rel_err = compute_rel_error(true_zt, true_zt_chaos)  # S x N x T

    ts = ts.numpy() if torch.is_tensor(ts) else np.asarray(ts)
    if isinstance(idx, int):
        idx = np.random.permutation(n_init)[:idx]
    pred_rel_err = pred_rel_err[:, idx, ...].numpy()
    chaos_rel_err = chaos_rel_err[:, idx, ...].numpy()

    for init_id in range(len(idx)):
        g = sns.relplot(x=ts, y=np.mean(chaos_rel_err, axis=0)[init_id], kind='line', color='gray')
        ax = g.axes[0][0]

        trace_plot(ts, chaos_rel_err[:, init_id, :], ax,
                   color=(.6,.6,.6), min_y=0.)
        trace_plot(ts, pred_rel_err[:, init_id, :], ax,
                   color=pcolor, min_y=0., kind='bound')

        ax.set_xlabel('t')
        ax.set_ylabel(r'$\delta(t)$')
        ax.set_title(f'Relative Error (Trajectory {idx[init_id] + 1})')

        if save:
            g.fig.savefig(f'chnn-{method}-traj{idx[init_id] + 1}-relerr.pdf')
        # plt.show takes no figure argument; close the figure to free memory in the loop.
        plt.show()
        plt.close(g.fig)

In [None]:
assert False # comment to exec.

## Can use integer to plot random traces, or a fixed list of traces.
plot_idx = [20]
should_save = False

## W&B sweeps behind the CHNN-OU, CHNN-DE and CHNN-SWAG relative-error plots.
rel_err_sweeps = {
    'ou': 'snym/phy-unc-exps/1w572o1p',
    'de': 'snym/phy-unc-exps/5mjska2l',
    'swag': 'snym/phy-unc-exps/9hy4ksyf',
}

for method, sweep_id in rel_err_sweeps.items():
    for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, *_) in download_runs(sweep_id=sweep_id):
        # generate_rel_err_plots computes the relative errors itself.
        generate_rel_err_plots(ts, true_zt, true_zt_chaos, pred_zt,
                               idx=plot_idx, method=method, save=should_save)

## TODO: combine runs from other trajectories

## Log Average Error Plots

In [None]:
def generate_lae_plots(ts, true_zt, true_zt_chaos, pred_zt, n=5, color=(.2,.2,1.,.75)):
    '''Plot the relative error of the mean prediction together with a log-linear fit.

    This is done for both the chaos samples (grey) and the predictions (``color``).
    The relative error of the sample mean is computed first. A straight line is
    then fitted to its log over time. Both the error and exp(fit) are drawn for up
    to ``n`` random trajectories. The y-axis is log-scaled so the fit appears as a
    straight dashed line.

    Arguments:
        ts: T time points (torch).
        true_zt: N x T x 2 x B x 2 ground truth (torch).
        true_zt_chaos: S x N x T x 2 x B x 2 chaos samples (torch).
        pred_zt: M x N x T x 2 x B x 2 predicted samples (torch).
        n: Number of random trajectories to plot (capped at N).
        color: Colour of the prediction curves.
    '''
    n_init = true_zt_chaos.shape[1]

    pred_mean_rel_err = compute_rel_error(true_zt, pred_zt.mean(dim=0, keepdim=True)).squeeze(0)  # N x T
    chaos_mean_rel_err = compute_rel_error(true_zt, true_zt_chaos.mean(dim=0, keepdim=True)).squeeze(0)  # N x T

    ## FIXME: should we take the log here?
    # The epsilon (same as compute_geom_mean) keeps an exactly-zero error from
    # becoming -inf, which would make the whole least-squares fit NaN.
    pred_lad_fit = compute_lad_fit(ts, (pred_mean_rel_err + 1e-8).log())  # N x T
    chaos_lad_fit = compute_lad_fit(ts, (chaos_mean_rel_err + 1e-8).log())  # N x T

    idx = np.random.permutation(n_init)[:n]

    ts = ts.numpy()
    pred_mean_rel_err = pred_mean_rel_err[idx].numpy()
    pred_lad_fit = pred_lad_fit[idx].exp().numpy()
    chaos_mean_rel_err = chaos_mean_rel_err[idx].numpy()
    chaos_lad_fit = chaos_lad_fit[idx].exp().numpy()

    # len(idx), not n: idx has at most n_init entries.
    for init_id in range(len(idx)):
        g = sns.relplot(x=ts, y=chaos_mean_rel_err[init_id], kind='line', color='gray')
        ax = g.axes[0][0]
        sns.lineplot(x=ts, y=pred_mean_rel_err[init_id], color=color, ax=ax)

        # Draw the fits with ax.plot: seaborn's `dashes` only applies to a `style`
        # mapping and is ignored otherwise, which left these lines solid.
        ax.plot(ts, chaos_lad_fit[init_id], color='gray', dashes=[4,4])
        ax.plot(ts, pred_lad_fit[init_id], color=color, dashes=[4,4])

        ax.set_yscale('log')
        ax.set_xlabel('t')
        ax.set_ylabel(r'$\delta(t)$')
        ax.set_title(f'Log Error Fit (Trajectory {idx[init_id] + 1})')

        plt.show()
        plt.close(g.fig)

In [None]:
## Disabled: log-error fit plots for the CHNN-SWAG sweep. Uncomment to run.
# for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, *_) in download_runs(sweep_id='snym/phy-unc-exps/9hy4ksyf'):
#     generate_lae_plots(ts, true_zt, true_zt_chaos, pred_zt)

## TODO: combine runs from other trajectories.
## FIXME: generate_lae_plots still carries an open FIXME on whether to fit log(delta).

## Geometric Error

In [None]:
assert False

## Per-trajectory geometric mean (over time) of the relative error of each method's
## sample-mean prediction. The SWAG sweep also supplies the chaos baseline.
geom_sweeps = [
    'snym/phy-unc-exps/1w572o1p',  # CHNN-OU
    'snym/phy-unc-exps/5mjska2l',  # CHNN-DE
    'snym/phy-unc-exps/9hy4ksyf',  # CHNN-SWAG
]
chaos_sweep = 'snym/phy-unc-exps/9hy4ksyf'

rows = []
for sweep_id in geom_sweeps:
    for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, _) in download_runs(sweep_id=sweep_id):
        uq_type = cfg.get('uq_type')

        # Average the samples first (1 x N x T) so every method is treated the same;
        # for single-sample runs this is a no-op.
        pred_rel_err = compute_rel_error(true_zt, pred_zt.mean(dim=0, keepdim=True))
        pred_geom_err = compute_geom_mean(ts, pred_rel_err).squeeze(0)  # N
        rows.extend((uq_type, i, err.item()) for i, err in enumerate(pred_geom_err))

        if sweep_id == chaos_sweep:
            chaos_rel_err = compute_rel_error(true_zt, true_zt_chaos.mean(dim=0, keepdim=True))
            chaos_geom_err = compute_geom_mean(ts, chaos_rel_err).squeeze(0)  # N
            # The chaos rows must record the chaos errors, not the prediction errors.
            rows.extend(('chaos', i, err.item()) for i, err in enumerate(chaos_geom_err))

df = pd.DataFrame(rows, columns=['method', 'traj_id', 'geom_err'])
g = sns.catplot(data=df, x='method', y='geom_err', kind='bar')
# g.fig.savefig('rel-err.pdf')
plt.show()

## Likelihood Plots

In [None]:
def generate_likelihood_plot(ts, true_zt, true_zt_chaos, pred_zt, n=5, color=(.2,.2,1.,.75)):
    '''Plot the log-likelihood of the ground truth over time for random trajectories.

    The chaos samples (grey) and the predictions (``color``) are each scored with
    ``compute_likelihood`` against the ground truth. One figure is drawn per
    selected trajectory.

    Arguments:
        ts: T time points (torch).
        true_zt: N x T x 2 x B x 2 ground truth (torch).
        true_zt_chaos: S x N x T x 2 x B x 2 chaos samples (torch).
        pred_zt: M x N x T x 2 x B x 2 predicted samples (torch).
        n: Number of random trajectories to plot (capped at N).
        color: Colour of the prediction curve.
    '''
    n_init = true_zt_chaos.shape[1]

    chaos_likelihood = compute_likelihood(true_zt, true_zt_chaos)  # N x T
    pred_likelihood = compute_likelihood(true_zt, pred_zt)  # N x T

    ts = ts.numpy()
    idx = np.random.permutation(n_init)[:n]
    chaos_likelihood = chaos_likelihood[idx].numpy()
    pred_likelihood = pred_likelihood[idx].numpy()

    # len(idx), not n: idx has at most n_init entries.
    for init_id in range(len(idx)):
        g = sns.relplot(x=ts, y=chaos_likelihood[init_id], kind='line', color='gray')
        ax = g.axes[0][0]
        sns.lineplot(x=ts, y=pred_likelihood[init_id], color=color, ax=ax)

        ax.set_xlabel('t')
        ax.set_ylabel(r'$\mathcal{L}(t)$')
        ax.set_title(f'Likelihood over time (Trajectory {idx[init_id] + 1})')

        plt.show()
        plt.close(g.fig)

In [None]:
## Ground-truth likelihood under the CHNN-SWAG predictions vs. the chaos samples.
# parse_dump yields five arrays; `*_` absorbs the unused pred_zt_chaos.
for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, *_) in download_runs(sweep_id='snym/phy-unc-exps/9hy4ksyf'):
    generate_likelihood_plot(ts, true_zt, true_zt_chaos, pred_zt)

## TODO: combine runs from other trajectories

## Varying Training Data

In [None]:
label = {'swag': 'CHNN-SWAG', 'deep-ensemble': 'CHNN-DE'}

## CHNN-DE, CHNN-SWAG trained on different numbers of samples (`n_subsample`).
rows = []
for sweep_id in ['snym/phy-unc-exps/hrn4us3n', 'snym/phy-unc-exps/xnjmhsxf']:
    for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, _) in download_runs(sweep_id=sweep_id):
        uq_type = label[cfg.get('uq_type')]
        n_subsample = cfg.get('n_subsample')

        # Relative error of the sample-mean prediction (1 x N x T), then its
        # geometric mean over time (N).
        pred_rel_err = compute_rel_error(true_zt, pred_zt.mean(dim=0, keepdim=True))
        pred_geom_err = compute_geom_mean(ts, pred_rel_err).squeeze(0)

        rows.extend((uq_type, n_subsample, i, err.item()) for i, err in enumerate(pred_geom_err))

df = pd.DataFrame(rows, columns=['method', 'n', 'traj_id', 'geom_err'])
g = sns.catplot(data=df, x='n', y='geom_err', hue='method', kind='bar', legend=False)
g.ax.legend(loc='upper right')
g.ax.set_ylabel('Geom. mean of rel. error')
g.ax.set_xlabel('Num. of train samples')
g.fig.savefig('vary-data-plot.pdf')
plt.show()

## Varying data noise corruption

In [None]:
label = {'swag': 'CHNN-SWAG', 'deep-ensemble': 'CHNN-DE'}

## CHNN-DE, CHNN-SWAG trained with different data-noise levels (`noise_rate`).
rows = []
for sweep_id in ['snym/phy-unc-exps/rxroo7zi', 'snym/phy-unc-exps/2zqxdb4f']:
    for name, cfg, (ts, true_zt, true_zt_chaos, pred_zt, _) in download_runs(sweep_id=sweep_id):
        uq_type = label[cfg.get('uq_type')]
        noise_rate = cfg.get('noise_rate')

        # Relative error of the sample-mean prediction (1 x N x T), then its
        # geometric mean over time (N).
        pred_rel_err = compute_rel_error(true_zt, pred_zt.mean(dim=0, keepdim=True))
        pred_geom_err = compute_geom_mean(ts, pred_rel_err).squeeze(0)

        rows.extend((uq_type, noise_rate, i, err.item()) for i, err in enumerate(pred_geom_err))

df = pd.DataFrame(rows, columns=['method', 'noise', 'traj_id', 'geom_err'])
g = sns.catplot(data=df, x='noise', y='geom_err', hue='method', kind='bar', legend=False)
g.ax.legend(loc='upper left')
g.ax.set_ylabel('Geom. mean of rel. error')
g.ax.set_xlabel('Noise std. deviation')
g.fig.savefig('vary-noise-plot.pdf')
plt.show()