In [None]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import numpy as onp
import flax.linen as nn
import matplotlib.pyplot as plt
import optax
import diffrax
import distrax
import sde.markov_approximation as ma
from sde.models import FractionalSDE, VideoSDE, StaticFunction
from sde import data
from sde.util import NumpyLoader
from moviepy.editor import ImageSequenceClip
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle
import typing
import wandb
import imageio
import pandas as pd
import seaborn as sns
from sde.train import build_data_and_model, Drift, Diffusion, ControlFunction
# Solver instance shared by the evaluation functions defined further down.
solver = diffrax.StratonovichMilstein()
from sde.util import NumpyLoader

def save_strip(v, path):
    """Write the frames of a single-channel video side by side into one image.

    Args:
        v: Video array. After ``squeeze()`` it is indexed as
            ``(num_frames, height, width)``. Values are multiplied by 255 and
            clipped to 0..255 before being stored as ``uint8``.
        path: Destination handed to ``imageio.imwrite``. A missing parent
            directory is created first.
    """
    import os  # local import: only needed to create the output directory

    v = v.squeeze()
    m = 2  # width in pixels of the white gap between consecutive frames
    num_frames = len(v)
    # Read the frame size from the data instead of assuming 64x64 frames.
    height, width = v.shape[1:3]
    canvas = onp.full((height, num_frames * (width + m) - m), 255, dtype=onp.uint8)
    for f, frame in enumerate(v):
        j0 = f * (width + m)
        j1 = j0 + width
        canvas[:, j0:j1] = (frame * 255).clip(0, 255).astype(onp.uint8)

    # imwrite does not create directories; callers write into 'strips/'.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.imwrite(path, canvas)

In [None]:
api = wandb.Api()
run_ids = [
    # insert run paths here
    # e.g. 'name/jax-mmnist/abcd1234',
]

# One dict per run: its path, the API run object and that run's config.
runs = []
for run_id in run_ids:
    run = {'run_id': run_id, 'run': api.run(run_id)}
    run['cfg'] = run['run'].config
    runs.append(run)

In [None]:
def get_model_and_params(run):
    """Download a run's pickled parameters and rebuild its data and model.

    ``run`` is updated in place with the keys ``params``, ``ts``, ``dt``,
    ``data_train``, ``data_val``, ``model``, ``num_k``, ``gamma`` and ``hurst``.

    Args:
        run: Dict holding at least ``'run'`` (the API run object) and ``'cfg'``
            (that run's config mapping).
    """
    cfg = run['cfg']
    # download() returns an open handle to the downloaded file. Close it right
    # away; the file is re-opened in binary mode below.
    run['run'].file('params.p').download(replace=True).close()
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file. Only evaluate runs whose artifacts you trust.
    with open('params.p', 'rb') as f:
        params = pickle.load(f)

    ts, dt, data_train, data_val, model, _ = build_data_and_model(
        cfg['dataset'],
        cfg['white'],
        cfg['num_latents'],
        cfg['num_contents'],
        cfg['num_features'],
        cfg['num_k'],
        cfg['gamma_max'],
        cfg['int_sub_steps'],
    )
    # The noise settings stored next to the model depend on the 'white' flag.
    # NOTE(review): presumably mirrors the values used at training time - confirm.
    if cfg['white']:
        num_k = 1
        gamma = None
        hurst = - 1
    else:
        num_k = cfg['num_k']
        gamma = ma.gamma_by_gamma_max(cfg['num_k'], cfg['gamma_max'])
        hurst = None
    run.update({
        'params': params,
        'ts': ts,
        'dt': dt,
        'data_train': data_train,
        'data_val': data_val,
        'model': model,
        'num_k': num_k,
        'gamma': gamma,
        'hurst': hurst,
    })
# Fail early with a clear message: the code below indexes into `runs`, and the
# notebook ships with an empty `run_ids` list.
if not runs:
    raise ValueError("`runs` is empty - add at least one W&B run path to `run_ids`.")

for run in runs:
    get_model_and_params(run)

data_val = runs[0]['data_val'] # use the same validation set

### Validation ELBO

In [None]:
# Fixed iteration order: with drop_last=True a shuffled loader would drop a
# different random remainder on every pass, so each run would be scored on a
# different subset of the validation data.
# NOTE(review): assumes NumpyLoader forwards these arguments to torch's DataLoader - confirm.
dataloader = NumpyLoader(data_val, batch_size=32, shuffle=False, num_workers=8, drop_last=True)

def calculate_validation_elbo(run, model, params, ts, dt, **kwargs):
    """Print the mean validation loss of one run.

    The loss of a video is its summed squared reconstruction error plus the two
    terms returned by the model (``kl_x0`` and ``logpath``). Losses are averaged
    within each batch of the global ``dataloader`` and then across batches.
    The global ``solver`` is used for the model call.

    Args:
        run: Object whose ``id`` is printed next to the result.
        model: Callable ``model(params, key, ts, frames, dt, solver)``.
        params: Parameters passed to ``model``.
        ts: Time grid passed to ``model``.
        dt: Step size passed to ``model``.
        **kwargs: Ignored; lets the function be called as ``f(**run_dict)``.
    """
    def loss_fn(params, key, frames):
        frames_, (kl_x0, logpath) = model(params, key, ts, frames, dt, solver)
        nll = ((frames - frames_) ** 2).sum()
        loss = nll + 1. * (kl_x0 + logpath)
        return loss, (nll, kl_x0, logpath)

    @jax.jit
    def batched_loss_fn(params, key, frames):
        # One PRNG key per video. The count comes from the batch itself rather
        # than a hard-coded 32 that has to match the loader's batch size.
        keys = jax.random.split(key, frames.shape[0])
        loss, aux = jax.vmap(loss_fn, (None, 0, 0))(params, keys, frames)
        return loss.mean(), jax.tree_util.tree_map(jnp.mean, aux)

    random_key = jax.random.PRNGKey(7)
    elbos = []
    for frames in tqdm(dataloader):
        random_key, key = jax.random.split(random_key, 2)
        elbo, _ = batched_loss_fn(params, key, frames)
        elbos.append(elbo)
    print(run.id, onp.mean(elbos))

# Print the validation loss of every configured run.
for run in runs:
    calculate_validation_elbo(**run)

### Inference on validation set

In [None]:
# Globals read by `inference` below: the index of the validation video, the
# PRNG key passed to the model, and the video itself.
i = 1
key = jax.random.PRNGKey(7)
frames = data_val[i]

def inference(run, cfg, model, params, ts, dt, **kwargs):
    """Run a model on the currently selected video and display the result.

    Reads the notebook globals ``i``, ``key`` and ``frames``. Writes two strip
    images under ``strips/`` (the input video and the model output) and returns
    whatever ``ImageSequenceClip(...).ipython_display()`` returns.

    Args:
        run: Object whose ``id`` is printed and used in the output filename.
        cfg: Config of the run; only printed.
        model: Callable ``model(params, key, ts, frames, dt, solver)``.
        params: Parameters passed to ``model``.
        ts: Time grid passed to ``model``.
        dt: Step size passed to ``model``.
        **kwargs: Ignored; lets the function be called as ``f(**run_dict)``.
    """
    print(run.id, cfg)
    milstein = diffrax.StratonovichMilstein()
    reconstruction, _ = model(params, key, ts, frames, dt, milstein)

    true_path = f'strips/{i}_true.png'
    posterior_path = f'strips/{i}_posterior_{run.id}.png'
    save_strip(frames[::1], true_path)
    save_strip(reconstruction[::1], posterior_path)

    # Input and model output joined along axis 2, scaled to uint8, and the last
    # axis repeated three times.
    joined = jnp.concatenate([frames, reconstruction], axis=2)
    rgb = (joined * 255).clip(0, 255).astype(onp.uint8).repeat(3, axis=-1)
    clip = ImageSequenceClip(list(rgb), fps=5)
    return clip.ipython_display()

In [None]:
# Reconstruction of the selected video by the first run.
inference(**runs[0])

In [None]:
# Reconstruction of the selected video by the second run.
inference(**runs[1])

### Prior Model

In [None]:
# Validation video to condition on. `i` and `frames` are read as globals by
# `prior_predictions` below.
i = 3
frames = data_val[i]

In [None]:
# this block will use the 'frames' from the block above, i.e. it conditions on that.

def prior_predictions(run, model, params, hurst, ts, dt, cfg, gamma, **kwargs):
    """Sample videos from the SDE without control, starting from the posterior x0.

    Reads the notebook globals ``frames`` (the video whose encoding provides the
    content vector and the initial-state posterior) and ``i`` (used in output
    filenames). One strip image per sample is written under ``strips/``.

    Args:
        run: Object whose ``id`` is used in the output filenames.
        model: Trained model providing ``encoder``, ``content``, ``infer``,
            ``image_size`` and ``num_channels``.
        params: Parameters passed to the model components.
        hurst: Forwarded to ``FractionalSDE``.
        ts: Time grid of the simulation.
        dt: Step size of the simulation.
        cfg: Config mapping (``num_latents``, ``num_features``, ``num_contents``).
        gamma: Forwarded to ``FractionalSDE``.
        **kwargs: Ignored; lets the function be called as ``f(**run_dict)``.

    Returns:
        ``uint8`` array in which the sampled videos are concatenated along
        axis 2 and the last axis is repeated three times, so that iterating
        over it yields one image per time step.
    """
    num_samples = 4
    random_key = jax.random.PRNGKey(42)

    h = model.encoder(params, frames)
    w = model.content(params, h)
    x0_posterior, h = model.infer(params, h)
    key, random_key = jax.random.split(random_key, 2)
    x0 = x0_posterior.sample(seed=key, sample_shape=(num_samples,))

    b = Drift(cfg['num_latents'])
    # A control that is identically zero.
    u = StaticFunction(lambda *args: jnp.zeros(cfg['num_latents']))
    s = Diffusion(cfg['num_latents'])
    sde = FractionalSDE(b, u, s, gamma, hurst=hurst, type=1, time_horizon=ts[-1], num_latents=cfg['num_latents'])
    x0_prior = distrax.MultivariateNormalDiag(jnp.zeros(cfg['num_latents']), jnp.ones(cfg['num_latents']))
    prior_model = VideoSDE(model.image_size, model.num_channels, cfg['num_features'], cfg['num_latents'], cfg['num_contents'], x0_prior, True, sde)

    keys = jax.random.split(random_key, num_samples)
    xs_prior, _ = jax.vmap(prior_model.sde, (None, 0, 0, None, None, None, None))(params, keys, x0, ts, dt, solver, {})
    # Repeat the content vector for every time step and every sample.
    w_t = w[None, :].repeat(len(ts), axis=0)
    w_tb = w_t[None, :].repeat(num_samples, axis=0)
    frames_prior = jax.vmap(prior_model.decoder, (None, 0))(params, jnp.concatenate([w_tb, xs_prior], axis=-1))

    # A dedicated index name keeps the Diffusion instance `s` from being shadowed.
    for sample_idx, sample in enumerate(frames_prior):
        save_strip(sample[::1], f'strips/{i}_prediction_{run.id}_sample_{sample_idx}.png')
    v = jnp.concatenate(frames_prior, axis=2)
    # Scale the concatenated video. The previous code rescaled `frames_prior`
    # here, which discarded the concatenation and returned a per-sample stack.
    v = (v * 255).clip(0, 255).astype(onp.uint8).repeat(3, axis=-1)
    return v

In [None]:
# Prior samples of the first run, shown as a clip at 5 fps.
v = prior_predictions(**runs[0])
ImageSequenceClip(list(v), fps=5).ipython_display()

In [None]:
# Prior samples of the second run, shown as a clip at 5 fps.
v = prior_predictions(**runs[1])
ImageSequenceClip(list(v), fps=5).ipython_display()

### Best Prediction

In [None]:
# Validation video to condition on. `frames` is read as a global by
# `best_predictions` below.
i = 3
frames = data_val[i]

In [None]:
class ControlFunctionMask(ControlFunction):
    """Control function that returns zeros once ``t`` is past the last context time."""

    def __call__(self, params, t, x, y, args):
        """Evaluate the control MLP, masked to zero after the context window.

        ``args['context']`` must provide ``'ts'`` (context times) and ``'hs'``
        (context features, interpolated column by column at time ``t``).
        """
        ctx = args['context']
        context_ts = ctx['ts']
        # Interpolate every column of the context features at time t.
        h_t = jax.vmap(jnp.interp, (None, None, 1))(t, context_ts, ctx['hs'])
        mlp_input = jnp.concatenate([x, y.flatten(), h_t], axis=-1)
        control = self.mlp.apply(params, mlp_input)
        no_control = jnp.zeros(self.num_latents)  # no control after context -> prior
        return jnp.where(t > context_ts[-1], no_control, control)

def best_predictions(run, model, params, hurst, ts, dt, cfg, gamma, **kwargs):
    """Sample videos that are controlled only during the first few frames.

    The first ``num_conditioned`` frames of the notebook global ``frames`` are
    encoded and passed as context to ``ControlFunctionMask``. Five trajectories
    are simulated over the full time grid ``ts`` and decoded.

    Args:
        run: Unused here; accepted so the function can be called as ``f(**run_dict)``.
        model: Trained model providing ``encoder``, ``content``, ``infer``,
            ``image_size`` and ``num_channels``.
        params: Parameters passed to the model components.
        hurst: Forwarded to ``FractionalSDE``.
        ts: Time grid of the simulation.
        dt: Step size of the simulation.
        cfg: Config mapping (``num_latents``, ``num_k``, ``num_features``,
            ``num_contents``).
        gamma: Forwarded to ``FractionalSDE``.
        **kwargs: Ignored.

    Returns:
        The decoded frames with one leading entry per sample.
    """
    num_samples = 5
    num_conditioned = 5
    num_latents = cfg['num_latents']
    rng = jax.random.PRNGKey(42)

    # Encode only the conditioning frames.
    features = model.encoder(params, frames[:num_conditioned])
    content = model.content(params, features)
    x0_posterior, features = model.infer(params, features)
    x0_key, rng = jax.random.split(rng, 2)
    x0 = x0_posterior.sample(seed=x0_key, sample_shape=(num_samples,))
    context = {'ts': ts[:num_conditioned], 'hs': features}

    drift = Drift(num_latents)
    control = ControlFunctionMask(cfg['num_k'], num_latents, cfg['num_features'])
    diffusion = Diffusion(num_latents)
    sde = FractionalSDE(drift, control, diffusion, gamma, hurst=hurst, type=1, time_horizon=ts[-1], num_latents=num_latents)
    x0_prior = distrax.MultivariateNormalDiag(jnp.zeros(num_latents), jnp.ones(num_latents))
    prior_model = VideoSDE(model.image_size, model.num_channels, cfg['num_features'], num_latents, cfg['num_contents'], x0_prior, True, sde)

    sample_keys = jax.random.split(rng, num_samples)
    simulate = jax.vmap(prior_model.sde, (None, 0, 0, None, None, None, None))
    latents, _ = simulate(params, sample_keys, x0, ts, dt, solver, {'context': context})

    # Repeat the content vector for every time step and every sample.
    content_t = content[None, :].repeat(len(ts), axis=0)
    content_bt = content_t[None, :].repeat(num_samples, axis=0)
    decode = jax.vmap(prior_model.decoder, (None, 0))
    return decode(params, jnp.concatenate([content_bt, latents], axis=-1))

In [None]:
# First run: draw predictions and keep the sample with the smallest summed
# squared error against the true video.
frames_pred = best_predictions(**runs[0])
errors = ((frames[None] - frames_pred) ** 2).sum(axis=(1, 2, 3, 4))
best_i = jnp.argmin(errors)
# True video and best sample joined along axis 1, scaled to uint8, last axis
# repeated three times.
v = jnp.concatenate([frames, frames_pred[best_i]], axis=1)
v = (v * 255).clip(0, 255).astype(onp.uint8).repeat(3, axis=-1)
ImageSequenceClip(list(v), fps=5).ipython_display()

In [None]:
# Second run: draw predictions and keep the sample with the smallest summed
# squared error against the true video.
frames_pred = best_predictions(**runs[1])
errors = ((frames[None] - frames_pred) ** 2).sum(axis=(1, 2, 3, 4))
best_i = jnp.argmin(errors)
# True video and best sample joined along axis 1, scaled to uint8, last axis
# repeated three times.
v = jnp.concatenate([frames, frames_pred[best_i]], axis=1)
v = (v * 255).clip(0, 255).astype(onp.uint8).repeat(3, axis=-1)
ImageSequenceClip(list(v), fps=5).ipython_display()

In [None]:
def get_calculate_error_fn(run, model, params, hurst, ts, dt, cfg, gamma, **kwargs):
    """Build a jitted function that scores best-of-100 predictions for one run.

    The returned ``calculate_error(frames, key)`` encodes the first five frames
    of ``frames``, simulates 100 trajectories that are controlled only during
    that context, decodes them, and returns for every time step the smallest
    mean squared error among the samples.

    Args:
        run: Unused here; accepted so the function can be called as ``f(**run_dict)``.
        model: Trained model providing ``encoder``, ``content``, ``infer``,
            ``image_size`` and ``num_channels``.
        params: Parameters passed to the model components.
        hurst: Forwarded to ``FractionalSDE``.
        ts: Time grid of the simulation.
        dt: Step size of the simulation.
        cfg: Config mapping (``num_latents``, ``num_k``, ``num_features``,
            ``num_contents``).
        gamma: Forwarded to ``FractionalSDE``.
        **kwargs: Ignored.

    Returns:
        The jitted ``calculate_error`` function.
    """
    num_samples = 100
    num_conditioned = 5
    num_latents = cfg['num_latents']

    @jax.jit
    def calculate_error(frames, key):
        x0_key, path_key = jax.random.split(key, 2)

        # Encode only the conditioning frames.
        features = model.encoder(params, frames[:num_conditioned])
        content = model.content(params, features)
        x0_posterior, features = model.infer(params, features)
        x0 = x0_posterior.sample(seed=x0_key, sample_shape=(num_samples,))
        context = {'ts': ts[:num_conditioned], 'hs': features}

        drift = Drift(num_latents)
        control = ControlFunctionMask(cfg['num_k'], num_latents, cfg['num_features'])
        diffusion = Diffusion(num_latents)
        sde = FractionalSDE(drift, control, diffusion, gamma, hurst=hurst, type=1, time_horizon=ts[-1], num_latents=num_latents)
        x0_prior = distrax.MultivariateNormalDiag(jnp.zeros(num_latents), jnp.ones(num_latents))
        prior_model = VideoSDE(model.image_size, model.num_channels, cfg['num_features'], num_latents, cfg['num_contents'], x0_prior, True, sde)

        sample_keys = jax.random.split(path_key, num_samples)
        simulate = jax.vmap(prior_model.sde, (None, 0, 0, None, None, None, None))
        latents, _ = simulate(params, sample_keys, x0, ts, dt, solver, {'context': context})

        # Repeat the content vector for every time step and every sample.
        content_t = content[None, :].repeat(len(ts), axis=0)
        content_bt = content_t[None, :].repeat(num_samples, axis=0)
        decode = jax.vmap(prior_model.decoder, (None, 0))
        frames_pred = decode(params, jnp.concatenate([content_bt, latents], axis=-1))

        per_sample_mse = ((frames[None] - frames_pred) ** 2).mean(axis=(2, 3, 4))
        return jnp.min(per_sample_mse, axis=0) # take the best one per timestep

    return calculate_error

In [None]:
random_key = jax.random.PRNGKey(7)
error_fns = [get_calculate_error_fn(**run) for run in runs]

# One record per (validation video, run, frame). The records are collected in a
# list and turned into a DataFrame once at the end.
# NOTE: this loop overwrites the notebook globals `i` and `frames`.
records = []
for i in tqdm(range(len(data_val))):
    frames = data_val[i]
    for run, error_fn in zip(runs, error_fns):
        random_key, key = jax.random.split(random_key, 2)
        # Convert the per-frame errors to a NumPy array once instead of
        # converting every element separately.
        errors = onp.asarray(error_fn(frames, key))
        for f, error in enumerate(errors):
            records.append({
                'data_i': i,
                'model': run['run'].id,
                'error': float(error),
                'frame': float(f),
                'white': run['cfg']['white'],
            })
df = pd.DataFrame(records)
# PSNR = 10 * log10(1 / error), computed on the whole column at once.
df['psnr'] = 10 * onp.log10(1 / df['error'])

In [None]:
# Store the per-frame results table next to the notebook.
df.to_pickle('psnr.p')

In [None]:
# PSNR per frame with one line per value of `white`; axis limits are set by hand.
g = sns.lineplot(df, x='frame', y='psnr', hue='white')
g.set(xlim=(4, 25))
g.set(ylim=(12, 24))

In [None]:
# Mean error over all frames for every (validation video, model) pair.
df_all_frames = df.groupby(['data_i', 'model'])['error'].mean()
# The result is a Series indexed by (data_i, model); it is converted to PSNR in the next cell.

In [None]:
# PSNR of the frame-averaged error, then its mean per model.
df_all_frames_psnr = 10 * onp.log10(1/df_all_frames)
df_all_frames_psnr.groupby('model').mean()

In [None]:
# Standard deviation of that PSNR across validation videos, per model.
df_all_frames_psnr.groupby('model').std()