For a given set of training noise levels, compare the relative error over time. This requires the evaluation trajectory dumps.

In [None]:
import os

import wandb
import torch
import numpy as np
# import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Switch matplotlib to seaborn's default plot style for all later figures.
sns.set()

In [None]:
# Local cache directory for evaluation dumps downloaded from W&B.
dump_dir = os.path.abspath(os.path.join('..', '.log'))

def parse_dump(dump):
    '''Unpack an evaluation dump into its trajectory tensors.

    Arguments:
        dump: mapping with keys 'ts', 'true_zt', 'true_zt_chaos', 'pred_zt'.

    Returns:
        (ts, true_zt, true_zt_chaos, pred_zt). The chaos tensor has its two
        leading axes swapped, presumably so that its sample axis comes first
        like pred_zt -- confirm against the dump writer.
    '''
    times = dump.get('ts')
    reference = dump.get('true_zt')
    # Swap the first two axes of the 6-D chaos tensor; the rest stay put.
    chaos = dump.get('true_zt_chaos').permute(1, 0, 2, 3, 4, 5)
    predicted = dump.get('pred_zt')
    return times, reference, chaos, predicted

def download_runs(sweep_id):
    '''Yield the parsed evaluation dump of every run in a W&B sweep.

    Each run's ``data.pt`` file is cached under
    ``<dump_dir>/<sweep name>/<run name>/`` and is only downloaded when the
    cached copy does not exist yet. Runs without a ``data.pt`` are skipped.

    Arguments:
        sweep_id: W&B sweep path, e.g. 'entity/project/sweep_id'.

    Yields:
        (run name, run config, parse_dump(loaded data.pt)) per run.
    '''
    api = wandb.Api()
    sweep = api.sweep(sweep_id)
    for run in sweep.runs:
        download_root = os.path.join(dump_dir, sweep.name, run.name)
        for f in run.files():
            if f.name != 'data.pt':
                continue
            fpath = os.path.join(download_root, f.name)
            if not os.path.isfile(fpath):
                f.download(root=download_root)
            # Load onto the CPU so dumps written from GPU runs also open on
            # CPU-only machines; the metric code converts to NumPy, which
            # needs CPU tensors anyway.
            # NOTE(review): torch.load unpickles the file -- only use it on
            # dumps produced by trusted runs.
            dump = torch.load(fpath, map_location='cpu')
            yield run.name, run.config, parse_dump(dump)
            # File names are unique within a run, so stop scanning.
            break

## Metric computations

Everything that needs to be computed for the final graphs.

In [None]:
def compute_rel_error(ref, pred):
    '''Symmetric relative error ||ref - pred|| / ||ref + pred|| per state.

    N is the number of initial conditions.
    M is the number of samples in prediction.
    The first dimension "2" corresponds to position + velocity.
    B is the number of bodies.
    The last dimension "2" corresponds to xy.

    Arguments:
        ref: N x T x 2 x B x 2
        pred: M x N x T x 2 x B x 2

    Returns:
        M x N x T tensor of relative errors.
    '''
    def _state_norm(z):
        # Euclidean norm over the trailing (2, B, 2) axes, reduced one at a time.
        return z.pow(2).sum(dim=-1).sum(dim=-1).sum(dim=-1).sqrt()

    broadcast_ref = ref.unsqueeze(0)  # 1 x N x T x 2 x B x 2
    numerator = _state_norm(broadcast_ref - pred)  # M x N x T
    denominator = _state_norm(broadcast_ref + pred)  # M x N x T
    return numerator / denominator

def compute_geom_mean(ts, values):
    '''Geometric mean of a continuous function over time.

    Computes exp( (1 / (t_max - t_min)) * integral of log(values) dt ), with
    the integral taken by the trapezoidal rule along the last axis. A 1e-8
    offset keeps the log finite where values are zero.

    Arguments:
        ts: T
        values: ... x T

    Returns:
        Tensor of shape ... (time axis reduced).
    '''
    log_values = torch.log(values + 1e-8)
    duration = ts.max() - ts.min()
    mean_log = torch.trapz(log_values, ts) / duration
    return torch.exp(mean_log)

def compute_likelihood(ref, pred):
    '''Log-likelihood of the reference under a Gaussian fit to the samples.

    For every initial condition and time step, a diagonal Gaussian is fitted
    to the M prediction samples (per-coordinate sample mean and standard
    deviation) and the flattened reference state is scored under it. Time
    steps are scored independently (factored over time).

    Arguments:
        ref: N x T x 2 x B x 2
        pred: M x N x T x 2 x B x 2

    Returns:
        N x T tensor of log-probabilities.
    '''
    batch_shape = pred.shape[:3]  # M x N x T
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    flat_pred = pred.reshape(*batch_shape, -1)  # M x N x T x D

    pred_mu = flat_pred.mean(dim=0)  # N x T x D
    pred_std = flat_pred.std(dim=0) + 1e-6  # N x T x D
    # With a single sample (M == 1) the unbiased std is NaN; use a tiny spread.
    pred_std[torch.isnan(pred_std)] = 1e-6

    # Diagonal Gaussian parameterized by its standard deviation. Passing
    # std.diag_embed() as MultivariateNormal's second argument would treat the
    # std as the covariance (i.e. as the variance) and mis-scale the spread.
    pred_dist = torch.distributions.Independent(
        torch.distributions.Normal(pred_mu, pred_std), 1)

    log_prob = pred_dist.log_prob(ref.reshape(*ref.shape[:2], -1))  # N x T
    return log_prob

def compute_lad_fit(ts, rel_err):
    '''Fit a straight line over time to each relative-error curve.

    NOTE(review): despite the "lad" (least absolute deviations) name, this
    uses ``np.polyfit``, i.e. an ordinary least-squares fit.

    Arguments:
        ts: T
        rel_err: N x T (a single length-T curve also works and gives 1 x T)

    Returns:
        N x T float32 tensor holding the fitted line evaluated at ``ts``,
        on the same device as ``ts``.
    '''
    # polyfit fits each column of a 2-D y independently, hence the transpose.
    coef = np.polyfit(
        ts.detach().cpu().numpy(), rel_err.detach().cpu().numpy().T, 1)  # 2 x N
    # Move back to ts's device so the broadcast below never mixes devices.
    coef = torch.from_numpy(coef).float().to(ts.device)
    slope, intercept = coef[0], coef[1]
    line = slope.unsqueeze(-1) * ts.unsqueeze(0) + intercept.unsqueeze(-1)  # N x T
    return line

def compute_metrics(ts, true_zt, true_zt_chaos, pred_zt):
    '''Compute the evaluation metrics for one run's trajectory dump.

    Arguments:
        ts: T
        true_zt: N x T x 2 x B x 2 reference trajectories.
        true_zt_chaos: M x N x T x 2 x B x 2 "chaos" trajectories.
        pred_zt: M x N x T x 2 x B x 2 sampled predictions.

    Returns:
        dict of NumPy arrays.
        - ``geom_pred_err``, ``geom_chaos_err``, ``geom_determ_err`` (each N):
          time-geometric-mean relative error of the mean prediction, of the
          mean chaos trajectory, and of the first prediction sample.
        - The remaining keys hold the intermediate curves (relative errors,
          line fits, log-likelihoods) computed along the way.
    '''
    assert ts.shape == true_zt.shape[1:2], 'ts must have one entry per time step of true_zt'
    assert true_zt_chaos.shape[1:] == true_zt.shape, 'true_zt_chaos must be M x true_zt.shape'
    assert true_zt_chaos.shape == pred_zt.shape, 'true_zt_chaos and pred_zt must have equal shapes'

    ## Uncollapsed relative error, for individual plots.
    pred_rel_err = compute_rel_error(true_zt, pred_zt)  # M x N x T
    chaos_rel_err = compute_rel_error(true_zt, true_zt_chaos)  # M x N x T

    ## Relative error for the prediction (after BMA), for time-evolution plots.
    pred_mean_rel_err = compute_rel_error(true_zt, pred_zt.mean(dim=0, keepdim=True)).squeeze(0)  # N x T
    chaos_mean_rel_err = compute_rel_error(true_zt, true_zt_chaos.mean(dim=0, keepdim=True)).squeeze(0)  # N x T
    # Presumably the first sample stands in for a single deterministic prediction.
    determ_rel_err = compute_rel_error(true_zt, pred_zt[:1]).squeeze(0)  # N x T

    ## For bar plots of error comparison.
    geom_pred_err = compute_geom_mean(ts, pred_mean_rel_err)  # N
    geom_chaos_err = compute_geom_mean(ts, chaos_mean_rel_err)  # N
    geom_determ_err = compute_geom_mean(ts, determ_rel_err)  # N

    ## Relative error slope fits.
    pred_lad_fit = compute_lad_fit(ts, pred_mean_rel_err)  # N x T
    chaos_lad_fit = compute_lad_fit(ts, chaos_mean_rel_err)  # N x T
    determ_lad_fit = compute_lad_fit(ts, determ_rel_err)  # N x T

    ## For likelihood evolution over time.
    chaos_likelihood = compute_likelihood(true_zt, true_zt_chaos)  # N x T
    pred_likelihood = compute_likelihood(true_zt, pred_zt)  # N x T

    def _to_numpy(t):
        # Detach and move to CPU first: .numpy() rejects GPU / grad tensors.
        return t.detach().cpu().numpy()

    return dict(
        geom_pred_err=_to_numpy(geom_pred_err),
        geom_chaos_err=_to_numpy(geom_chaos_err),
        geom_determ_err=_to_numpy(geom_determ_err),
        pred_rel_err=_to_numpy(pred_rel_err),
        chaos_rel_err=_to_numpy(chaos_rel_err),
        pred_mean_rel_err=_to_numpy(pred_mean_rel_err),
        chaos_mean_rel_err=_to_numpy(chaos_mean_rel_err),
        determ_rel_err=_to_numpy(determ_rel_err),
        pred_lad_fit=_to_numpy(pred_lad_fit),
        chaos_lad_fit=_to_numpy(chaos_lad_fit),
        determ_lad_fit=_to_numpy(determ_lad_fit),
        chaos_likelihood=_to_numpy(chaos_likelihood),
        pred_likelihood=_to_numpy(pred_likelihood),
    )

## Geometric Error

In [None]:
# Rows of (method, eval_id, noise_rate, geom_err) for the box plot below.
geom_err_df = []
metrics = None  # metrics of the most recently processed run

for name, cfg, dump in download_runs(sweep_id='snym/physics-uncertainty-exps/hye2w7je'):
    noise_rate = cfg.get('noise_rate')
    uq_type = cfg.get('uq_type')

    metrics = compute_metrics(*dump)

    geom_pred_err = metrics.get('geom_pred_err')
    geom_chaos_err = metrics.get('geom_chaos_err')
    geom_determ_err = metrics.get('geom_determ_err')

    geom_err_df.extend((uq_type, i, noise_rate, err) for i, err in enumerate(geom_pred_err))

# Without this guard an empty sweep fails below with an opaque NameError.
if metrics is None:
    raise RuntimeError('No runs with a data.pt dump were found in the sweep.')

# NOTE(review): the chaos and deterministic baselines come from the LAST run
# only and are recorded with noise_rate 0. -- confirm this is intended
# (determ depends on each run's own predictions).
geom_err_df.extend(('chaos', i, 0., err) for i, err in enumerate(geom_chaos_err))
geom_err_df.extend(('determ', i, 0., err) for i, err in enumerate(geom_determ_err))

geom_err_df = pd.DataFrame(geom_err_df, columns=['method', 'eval_id', 'noise_rate', 'geom_err'])

In [None]:
## Geometric error of the mean prediction, restricted to rows whose
## noise_rate is <= 0. (this includes the chaos/determ baselines).
ax = sns.boxplot(
    x='method',
    y='geom_err',
    data=geom_err_df[geom_err_df['noise_rate'] <= 0.],
)