# Imports and setup
Modify `clargs` to point to the path of the model you wish to load.

In [None]:
# Command-line style options for this notebook, written as one shell-like
# string. The trailing backslashes are string line continuations, so the value
# is a single long line; it is tokenised with shlex and handed to the argparse
# parser further down in the notebook.
clargs = '''--model_path ./results/phys101/001/model_400.pth \
    --num_inner_steps 1 \
    --n_future 20 \
    --horiz_flip \
    --test_set_length 78 \
    --train_set_length 311 \
    --reuse_lstm_eps \
    --seed 1612 \
    --data_root ./data/phys101/phys101/scenarios/ramp \
    --dataset phys101 \
    --n_past 2 \
    --tailor \
    --n_trials 1 \
    --only_twenty_degree \
    --frame_step 2 \
    --crop_upper_right 1080 \
    --center_crop 1080 \
    --batch_size 2 \
    --image_width 128 \
    --num_threads 4 '''

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

import os
# Expose only GPU 0 to this process. This is set here, before torch is
# imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as mpatches
import seaborn as sns
import scipy.stats as st

import shlex  # for clargs in Jupyter

from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
import argparse
import copy
import random
from torch.utils.data import DataLoader
import utils
import itertools
import progressbar
import numpy as np
from scipy.ndimage.filters import gaussian_filter

import higher
from lpips_pytorch import LPIPS

from models.forward import predict_many_steps, tailor_many_steps
from models.cn import replace_cn_layers, CNLayer
from models.svg import SVGModel
from models.embedding import ConservedEmbedding
from utils import svg_crit
# Report the id of the process running this notebook.
print(f'PID: {os.getpid()}')

# Make CUDA device 0 the current device for subsequent CUDA work.
torch.cuda.set_device(0)


# NOTE: deterministic for debugging.
# cuDNN benchmark mode auto-tunes convolution algorithms and may select a
# different one from run to run, which defeats the deterministic flag above it.
# It is therefore switched off here; set it back to True only if speed matters
# more than run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Argument parser for the evaluation options. It is fed the tokenised `clargs`
# string rather than sys.argv, so every option not mentioned in `clargs` takes
# the default declared here.
parser = argparse.ArgumentParser()
# --- data / model basics ---
parser.add_argument('--batch_size', default=100, type=int, help='batch size')
parser.add_argument('--data_root', default='data', help='root directory for data')
parser.add_argument('--model_path', default='', help='path to model')
parser.add_argument('--baseline_model_path', default='', help='path to model')
parser.add_argument('--seed', default=1, type=int, help='manual seed')
parser.add_argument('--n_past', type=int, default=2, help='number of frames to condition on')
parser.add_argument('--n_future', type=int, default=10, help='number of frames to predict')
parser.add_argument('--num_threads', type=int, default=0, help='number of data loading threads')
parser.add_argument('--dataset', default='bair', help='dataset to train with')
parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--g_dim', type=int, default=128, help='dimensionality of encoder output vector and decoder input vector')
parser.add_argument('--use_action', type=int, default=0, help='if true, train action-conditional model')
parser.add_argument('--channels', default=3, type=int)
# --- tailoring (inner loop) options ---
parser.add_argument('--tailor', action='store_true', help='if true, perform tailoring')
parser.add_argument('--num_inner_steps', type=int, default=1, help='how many tailoring steps?')
parser.add_argument('--num_train_batch', type=int, default=-1, help='if -1, do all of them')
parser.add_argument('--num_val_batch', type=int, default=-1, help='if -1, do all of them')
parser.add_argument('--train_set_length', type=int, default=256, help='size of training set')
parser.add_argument('--test_set_length', type=int, default=-1, help='size of test set')
parser.add_argument('--learn_inner_lr', action='store_true', help='optimize inner LR in outer loop?')
parser.add_argument('--n_trials', type=int, default=7, help='number of trials to average over')
parser.add_argument('--emb_dim', type=int, default=8, help='dim for Emb')
parser.add_argument('--inner_crit_mode', default='mse', help='mse or cosine')
# -1 is a sentinel for the two learning rates: the checkpoint's value is used instead.
parser.add_argument('--inner_lr', type=float, default=-1, help='learning rate for inner optimizer')
parser.add_argument('--val_inner_lr', type=float, default=-1, help='val. LR for inner opt (if -1, use orig.)')
parser.add_argument('--svg_loss_kl_weight', type=float, default=0.0001, help='weighting factor for KL loss')
parser.add_argument('--reuse_lstm_eps', action='store_true', help='correlated eps samples for prior & posterior?')
parser.add_argument('--only_tailor_on_improvement', action='store_true', help='no outer update if no inner improvement')
parser.add_argument('--stack_frames', action='store_true', help='stack every 2 frames channel-wise')
# --- Phys101-specific preprocessing ---
parser.add_argument('--only_twenty_degree', action='store_true', help='for Phys101 ramp, only 20 degree setting?')
parser.add_argument('--center_crop', type=int, default=1080, help='center crop param (phys101)')
parser.add_argument('--crop_upper_right', type=int, default=1080, help='upper right crop param (phys101)')
parser.add_argument('--frame_step', type=int, default=2, help='controls frame rate for Phys101')
parser.add_argument('--num_emb_frames', type=int, default=1, help='number of frames to pass to the embedding')
parser.add_argument('--horiz_flip', action='store_true', help='randomly flip phys101 sequences horizontally (p=.5)?')
parser.add_argument('--save_warmstart_dataset', action='store_true', help='save_warmstart_dataset')
parser.add_argument('--inner_opt_all_model_weights', action='store_true', help='optimize non-CN model weights in inner loop?')
parser.add_argument('--adam_inner_opt', action='store_true', help='use Adam in inner loop?')

# Tokenise the option string the same way a shell would. The isinstance guard
# keeps this cell re-runnable: after the first run `clargs` is already a list,
# and shlex.split() only accepts strings.
if isinstance(clargs, str):
    clargs = shlex.split(clargs)

opt = parser.parse_args(clargs)

# When True, the evaluation loop keeps every generated sequence in `all_gen`.
track_gen = True

# Total number of frames handled per sequence: conditioning + predicted.
opt.n_eval = opt.n_past+opt.n_future
opt.max_step = opt.n_eval

# Seed every RNG this notebook imports (Python, NumPy, torch CPU and CUDA) so
# that seed-dependent behaviour is repeatable across runs.
print("Random Seed: ", opt.seed)
random.seed(opt.seed)
# NumPy only accepts seeds in [0, 2**32 - 1]; wrap so any integer seed is valid.
np.random.seed(opt.seed % (2 ** 32))
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
# Tensor type and device passed to the data-normalisation helpers below.
dtype = torch.cuda.FloatTensor
device = torch.device('cuda')

# --------- load a dataset ------------------------------------
train_data, test_data = utils.load_dataset(opt)

# Clamp the requested batch counts to what each split can actually provide;
# -1 means "use every full batch".
full_train_batches = len(train_data) // opt.batch_size
if opt.num_train_batch == -1 or opt.num_train_batch > full_train_batches:
    opt.num_train_batch = full_train_batches

full_val_batches = len(test_data) // opt.batch_size
if opt.num_val_batch == -1 or opt.num_val_batch > full_val_batches:
    opt.num_val_batch = full_val_batches

# Settings shared by both loaders; only the shuffling differs.
shared_loader_kwargs = dict(
    num_workers=opt.num_threads,
    batch_size=opt.batch_size,
    drop_last=True,
    pin_memory=True,
)
train_loader = DataLoader(train_data, shuffle=True, **shared_loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **shared_loader_kwargs)


def get_batch_generator(data_loader):
    """Yield normalized batches from ``data_loader`` forever.

    The loader is iterated again from the start every time it is exhausted, so
    callers can keep calling ``next()`` without handling ``StopIteration``.

    Args:
        data_loader: iterable of batches. When ``opt.use_action`` is falsy each
            item is passed to ``utils.normalize_data``; otherwise each item is
            unpacked as an ``(images, actions)`` pair.

    Yields:
        The normalized batch, or an ``(images, actions)`` tuple when
        ``opt.use_action`` is truthy.

    Raises:
        ValueError: if a complete pass over ``data_loader`` produces no batch
            at all (e.g. the dataset is smaller than one batch and the loader
            drops incomplete batches). Without this check ``next()`` would spin
            forever in the outer loop.
    """
    while True:
        produced_any = False
        for sequence in data_loader:
            produced_any = True
            if not opt.use_action:
                batch = utils.normalize_data(opt, dtype, sequence)
                yield batch
            else:
                images, actions = sequence
                images = utils.normalize_data(opt, dtype, images)
                # transpose_ swaps the first two axes of `actions` in place.
                actions = utils.sequence_input(actions.transpose_(0, 1), dtype)
                yield images, actions
        if not produced_any:
            raise ValueError(
                'data_loader yielded no batches; check the dataset size '
                'against batch_size / drop_last'
            )

# Endless batch streams over the two loaders; consume them with next().
training_batch_generator = get_batch_generator(train_loader)
testing_batch_generator = get_batch_generator(test_loader)

print('\nDatasets loaded!')



# ---------------- plotting util fns --------------------------
def combine_dims(a, start=2, count=2):
    """Swap axes 1 and 2 of a 5-D array, then merge ``count`` axes into one.

    The array is first transposed to axis order (0, 2, 1, 3, 4). In that
    transposed array, the ``count`` consecutive axes beginning at index
    ``start`` are flattened into a single axis.
    """
    swapped = a.transpose((0, 2, 1, 3, 4))
    dims = swapped.shape
    merged_shape = dims[:start] + (-1,) + dims[start + count:]
    return np.reshape(swapped, merged_shape)


def conf_int(data, alpha=0.95, dist='t'):
    """Return ``(lower, upper)`` bounds around the mean of ``data`` along axis 0.

    Args:
        data: array whose first axis indexes the samples.
        alpha: confidence level handed to the scipy ``interval`` call
            (ignored when ``dist == 'sem'``).
        dist: ``'t'`` for a Student-t interval with ``data.shape[0] - 1``
            degrees of freedom, ``'norm'`` for a normal interval, or ``'sem'``
            for mean minus/plus one standard error of the mean.

    Raises:
        NotImplementedError: for any other value of ``dist``.
    """
    if dist not in ('t', 'norm', 'sem'):
        raise NotImplementedError

    center = data.mean(axis=0)
    std_err = st.sem(data, axis=0)

    if dist == 'sem':
        return center - std_err, center + std_err
    if dist == 'norm':
        return st.norm.interval(alpha, loc=center, scale=std_err)
    return st.t.interval(alpha, data.shape[0]-1, loc=center, scale=std_err)


# Load the checkpoint. NOTE(review): torch.load unpickles the file, and the
# checkpoint stores whole Python objects ('opt', and possibly 'svg_model'
# below), so only load checkpoints from a trusted source.
ckpt = torch.load(opt.model_path)
# Options the model was trained with.
_opt = ckpt['opt']

# ---------------- set the options ----------------------------
# Copy the training-time settings over the notebook's options so evaluation
# matches the checkpoint. Attributes guarded by hasattr may be missing from
# some checkpoints; the unguarded ones are required.
if hasattr(_opt, 'num_emb_frames'):
    opt.num_emb_frames = _opt.num_emb_frames
opt.dataset = _opt.dataset
opt.last_frame_skip = _opt.last_frame_skip
opt.channels = _opt.channels
opt.image_width = _opt.image_width
if hasattr(_opt, 'inner_crit_compare_to'):
    opt.inner_crit_compare_to = _opt.inner_crit_compare_to

# ---------------- load the models ----------------------------
if 'svg_model' in ckpt.keys():
    # -1 means "not given on the command line": fall back to the value stored
    # with the checkpoint.
    if opt.inner_lr == -1:
        opt.inner_lr = _opt.inner_lr
    # Some checkpoints have no val_inner_lr; test for it explicitly instead of
    # swallowing every exception, as is done for the other optional attributes.
    if opt.val_inner_lr == -1 and hasattr(_opt, 'val_inner_lr'):
        opt.val_inner_lr = _opt.val_inner_lr

    opt.inner_crit_mode = _opt.inner_crit_mode
    opt.svg_loss_kl_weight = _opt.svg_loss_kl_weight
    svg_model = ckpt['svg_model']
    print('\nSVG model with pre-trained weights loaded!')
else:
    # Checkpoints without an 'svg_model' entry are converted by the helper.
    svg_model = utils.modernize_model(opt.model_path, opt)
    print('\nOld SVG model with pre-trained weights loaded and modernized!')

# Learning rate used for the inner steps at evaluation time: the dedicated
# validation value when one is available, otherwise the training inner LR.
val_inner_lr = opt.inner_lr
if opt.val_inner_lr != -1:
    val_inner_lr = opt.val_inner_lr

# Propagate the evaluation batch size to every batch-size-aware component.
replace_cn_layers(svg_model.encoder, batch_size=opt.batch_size)
replace_cn_layers(svg_model.decoder, batch_size=opt.batch_size)
svg_model.frame_predictor.batch_size = opt.batch_size
svg_model.posterior.batch_size = opt.batch_size
svg_model.prior.batch_size = opt.batch_size
opt.only_cn_decoder = False

svg_model.cuda()
emb = svg_model.emb
svg_model.eval()

# Per-channel (x - 0.5) / 0.5 normalisation, applied to frames before they are
# passed to LPIPS in the evaluation loop.
norm_trnfm = transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])

# Perceptual-similarity metric, moved to the GPU.
lpips = LPIPS(
    net_type='alex',  # choose a network type from ['alex', 'squeeze', 'vgg']
    version='0.1'  # Currently, v0.1 is supported
).cuda()

# Show the final, merged set of options.
print(opt)

# evaluation loop
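If you are using cached metrics, you can skip this block, unless you plan to run the Grad-CAM section below, which needs the `all_gen` list populated by this loop.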

In [None]:
# Evaluate the loaded model on the validation batches, `opt.n_trials` times.
# Per-trial results are collected in the `all_*` lists and converted to numpy
# arrays once every trial has finished.
print('starting eval loop...')
all_tailor_ssims = []
all_tailor_psnrs = []
all_tailor_mses = []
all_tailor_lpips = []
all_val_inner_losses = []
all_val_svg_losses = []
all_gen = []  # generated sequences, kept on the CPU (only when track_gen is True)
all_outer_losses = []

for trial_num in range(opt.n_trials):
    print(f'TRIAL {trial_num}')
    
    # Per-trial accumulators. They are handed to tailor_many_steps as keyword
    # arguments and read back afterwards, so the helper presumably appends to
    # them in place -- TODO confirm against models.forward.
    tailor_ssims = []
    tailor_psnrs = []
    tailor_mses = []
    tailor_lpips = []
    val_inner_losses = []
    val_svg_losses = []
    val_outer_loss = 0.
    
    for batch_num in tqdm(range(opt.num_val_batch)):
        batch = next(testing_batch_generator)

        # tailoring pass
        gen_seq, mus, logvars, mu_ps, logvar_ps = tailor_many_steps(
            svg_model, batch, opt=opt, track_higher_grads=False,
            mode='eval',
            # extra kwargs
            inner_crit_mode=opt.inner_crit_mode,
            reuse_lstm_eps=opt.reuse_lstm_eps,
            val_inner_lr=val_inner_lr,
            svg_losses=val_svg_losses,
            tailor_losses=val_inner_losses,
            tailor_ssims=tailor_ssims,
            tailor_psnrs=tailor_psnrs,
            tailor_mses=tailor_mses,
            only_cn_decoder=opt.only_cn_decoder,
            adam_inner_opt=opt.adam_inner_opt,
        )
        
        if track_gen:
            # NOTE: all_gen is shared by all trials, so it grows with n_trials.
            all_gen.append([f.detach().cpu() for f in gen_seq])
        
        # LPIPS
        # One score per (sample in the batch, predicted frame); the first
        # opt.n_past conditioning frames are skipped.
        with torch.no_grad():
            lpips_scores = [[lpips(norm_trnfm(b[_idx]), norm_trnfm(g[_idx])).detach().cpu().item() for b, g in zip(batch[opt.n_past:], gen_seq[opt.n_past:])] for _idx in range(batch[0].shape[0])]
            tailor_lpips.append(lpips_scores)

        with torch.no_grad():
            outer_loss = svg_crit(gen_seq, batch, mus, logvars, mu_ps, logvar_ps, opt)

        val_outer_loss += outer_loss.detach().cpu().numpy().item()

    # zip(*...) regroups the collected loss lists position-wise; each position
    # is summed and divided by the number of validation batches (presumably one
    # list per batch with one entry per inner step -- TODO confirm).
    all_val_inner_losses.append([sum(x) / (opt.num_val_batch) for x in zip(*val_inner_losses)])
    all_val_svg_losses.append([sum(x) / (opt.num_val_batch) for x in zip(*val_svg_losses)])
    all_tailor_ssims.append(copy.deepcopy(tailor_ssims))
    all_tailor_psnrs.append(copy.deepcopy(tailor_psnrs))
    all_tailor_mses.append(copy.deepcopy(tailor_mses))
    all_tailor_lpips.append(copy.deepcopy(tailor_lpips))
    all_outer_losses.append(val_outer_loss / (opt.num_val_batch))
    
    # Summary of the trial that just finished.
    print(f'Model {trial_num}:')
    print(np.array(all_val_svg_losses[-1]).shape)
    print(f'\tOuter SVG loss:   {np.array(all_val_svg_losses[-1]).mean(axis=(0))}')
    print(f'\tInner VAL loss:   {np.array(all_val_inner_losses[-1]).mean(axis=(0))}')
    print(f'\tOuter VAL loss:   {all_outer_losses[-1]}')
    print(f'\tOuter SSIM:       {np.array(all_tailor_ssims[-1]).mean(axis=(0,1,-2))}\n\t\tmean: {np.array(all_tailor_ssims[-1]).mean()}')
    print(f'\tOuter PSNR:       {np.array(all_tailor_psnrs[-1]).mean(axis=(0,1,-2))}\n\t\tmean: {np.array(all_tailor_psnrs[-1]).mean()}')
    print(f'\tOuter MSE:        {np.array(all_tailor_mses[-1]).mean(axis=(0,1,-2))}\n\t\tmean: {np.array(all_tailor_mses[-1]).mean()}')

# Convert the collected metrics to arrays for the plotting cells.
all_tailor_ssims = np.array(all_tailor_ssims)
all_tailor_psnrs = np.array(all_tailor_psnrs)
all_tailor_mses = np.array(all_tailor_mses)
# LPIPS is collected as (trial, batch, sample, frame); a singleton axis is
# inserted at position 2 so it has five axes like the other metric arrays.
all_tailor_lpips = np.array(all_tailor_lpips)[:,:,None,:,:]
all_val_inner_losses = np.array(all_val_inner_losses)
all_val_svg_losses = np.array(all_val_svg_losses)
all_outer_losses = np.array(all_outer_losses)

# load cached metrics
The loaded metrics have the following shapes:
```
all_tailor_ssims.shape:               (n_trials, num_batches, num_inner_steps+1, batch_size, n_future)
combine_dims(all_tailor_ssims).shape: (n_trials, num_inner_steps+1, num_batches*batch_size, n_future)
```

In [None]:
# When True, the Noether Network metrics loaded from the cache files replace
# whatever the evaluation loop produced.
_use_cached_noether_metrics = True

# specify ID(s) of the training run(s) and the learning rate(s)
# Each entry is (experiment id, learning rate); both go into the cache file name.
noether_experiment_ids = [
    ('001', .001),
    ('002', .001),
    # etc.
]
baseline_experiment_ids = [
    '101',
    '102',
    # etc.
]

# The settings below are used to build the cache file names.
cache_steps = 1  # number of inner steps
baseline_eval_epoch = 2  # model checkpoint epoch
noether_eval_epoch = baseline_eval_epoch
num_trials_per_model = 1  # number of trials for evaluation loop
baseline_num_trials_per_model = 1
cache_noether_use_adam = False  # Adam or SGD?

print(f'noether_experiment_ids: {noether_experiment_ids}')

# Load the cached baseline metrics: one .npz archive per baseline experiment.
# Each metric is averaged over its first axis (kept as a length-1 axis) so that
# every experiment contributes exactly one entry along axis 0.
svg_all_base_tailor_ssims = []
svg_all_base_tailor_psnrs = []
svg_all_base_tailor_mses = []
svg_all_base_tailor_lpips = []
for exp_id in baseline_experiment_ids:
    fname = (
        f'eval_metrics/cached_metrics_id{exp_id}-ep{int(baseline_eval_epoch)}'
        f'-trials{baseline_num_trials_per_model}.npz'
    )
    # np.load keeps the archive's file handle open until it is closed; the
    # with-block releases it as soon as the arrays have been read.
    with np.load(fname) as svg_baseline_metrics:
        svg_all_base_tailor_ssims.append(svg_baseline_metrics['all_base_tailor_ssims'].mean(axis=0, keepdims=True))
        svg_all_base_tailor_psnrs.append(svg_baseline_metrics['all_base_tailor_psnrs'].mean(axis=0, keepdims=True))
        svg_all_base_tailor_mses.append(svg_baseline_metrics['all_base_tailor_mses'].mean(axis=0, keepdims=True))
        svg_all_base_tailor_lpips.append(svg_baseline_metrics['all_base_tailor_lpips'].mean(axis=0, keepdims=True))
    print(f'loaded baseline metrics from {fname}')

# Load the cached Noether Network metrics. Experiments whose cache file is
# missing are reported and skipped, not treated as errors.
noether_tailor_ssims = []
noether_tailor_psnrs = []
noether_tailor_mses = []
noether_tailor_lpips = []
noether_val_svg_losses = []
noether_val_inner_losses = []

# Ids of the experiments whose cache file was actually found.
actual_noether_experiment_ids = []

for exp_id, cache_lr in noether_experiment_ids:
    # Use the Noether-specific epoch setting. It defaults to the baseline
    # epoch, so file names are unchanged unless it is set to something else.
    fname = \
        f'eval_metrics/cached_metrics_id{exp_id}-ep{int(noether_eval_epoch)}'
    fname += f'-trials{num_trials_per_model}'
    if cache_lr is not None and cache_steps is not None:
        fname += f'-lr{cache_lr}'
        fname += f'-steps{cache_steps}'
    if cache_noether_use_adam:
        fname += '-adam'
    fname += '.npz'
    if os.path.isfile(fname):
        # The with-block closes the archive once the arrays have been read.
        with np.load(fname) as noether_metrics:
            # The arrays are stored under the same 'all_base_*' keys as in the
            # baseline archives.
            noether_tailor_ssims.append(noether_metrics['all_base_tailor_ssims'].mean(axis=0, keepdims=True))
            noether_tailor_psnrs.append(noether_metrics['all_base_tailor_psnrs'].mean(axis=0, keepdims=True))
            noether_tailor_mses.append(noether_metrics['all_base_tailor_mses'].mean(axis=0, keepdims=True))
            noether_tailor_lpips.append(noether_metrics['all_base_tailor_lpips'].mean(axis=0, keepdims=True))
            # The loss curves are optional in the archive.
            if 'all_val_svg_losses' in noether_metrics and 'all_val_inner_losses' in noether_metrics:
                noether_val_svg_losses.append(noether_metrics['all_val_svg_losses'])
                noether_val_inner_losses.append(noether_metrics['all_val_inner_losses'])
        print(f'loaded Noether Network metrics from {fname}')
        actual_noether_experiment_ids.append(exp_id)
    else:
        print(f'File {fname} does not exist!')

# From here on the name holds plain ids (no learning rates) of the loaded runs.
noether_experiment_ids = actual_noether_experiment_ids

# Stack the per-experiment baseline metrics along axis 0 (one row per experiment).
svg_all_base_tailor_ssims = np.concatenate(svg_all_base_tailor_ssims)
svg_all_base_tailor_psnrs = np.concatenate(svg_all_base_tailor_psnrs)
svg_all_base_tailor_mses = np.concatenate(svg_all_base_tailor_mses)
svg_all_base_tailor_lpips = np.concatenate(svg_all_base_tailor_lpips)

if _use_cached_noether_metrics:
    # np.concatenate raises an opaque "need at least one array" error on an
    # empty list, so report the actual problem instead.
    if not noether_tailor_ssims:
        raise FileNotFoundError(
            'No cached Noether Network metrics were loaded; check '
            'noether_experiment_ids and the cache settings above.'
        )
    all_tailor_ssims = np.concatenate(noether_tailor_ssims)
    all_tailor_psnrs = np.concatenate(noether_tailor_psnrs)
    all_tailor_mses = np.concatenate(noether_tailor_mses)
    all_tailor_lpips = np.concatenate(noether_tailor_lpips)
    # The loss curves are optional in the cache files; overwrite the existing
    # arrays only when at least one file provided them.
    if noether_val_inner_losses and noether_val_svg_losses:
        all_val_inner_losses = np.concatenate(noether_val_inner_losses)
        all_val_svg_losses = np.concatenate(noether_val_svg_losses)
    else:
        print('Cached metrics contain no inner/outer loss curves; '
              'the loss-vs-inner-step plot needs the evaluation loop to be run.')

# Plot evaluation metrics vs. prediction horizon

In [None]:
%matplotlib inline
from matplotlib.ticker import MaxNLocator

# Derive the epoch number and experiment id from a model path of the form
# '.../<experiment_id>/model_<epoch>.pth'.
num_epochs = int(opt.model_path.split('/')[-1].split('_')[-1].split('.')[0])
experiment_id = opt.model_path.split('/')[-2]
# Only the last assignment takes effect; the earlier ones are alternatives kept
# for reference. NOTE(review): the title is not applied to the figure anywhere
# in this cell.
plot_super_title = f"SVG baseline (400 ep) vs. meta-tailored (90 more epochs) on {opt.dataset} ({opt.image_width}x{opt.image_width})"
plot_super_title = f"SVG ({experiment_id}) meta-trained for {num_epochs} epoch(s)"
plot_super_title = f"Physics 101: real-world video prediction\n"

# Font sizes shared by all four panels.
_title_size = 30
_label_size = 20
_legend_size = 14
_tick_size = 14

ci_alpha = 0.95
# As above, the second assignment of base_label overrides the first.
base_label = 'naive baseline'
base_label = 'SVG (more steps)'
svg_label = 'SVG baseline'
plot_baseline = True
tailor_plot_idx = [1, 99]  # indices of inner steps to show (if they exist)

sns.set_style("whitegrid")

# Swap axes 1 and 2 and merge the next two axes of each metric array; see the
# shape description in the "load cached metrics" section.
ssim_mean = combine_dims(all_tailor_ssims)
psnr_mean = combine_dims(all_tailor_psnrs)
mse_mean =  combine_dims(all_tailor_mses)
lpips_mean =  combine_dims(all_tailor_lpips)

print(ssim_mean.shape)

if plot_baseline:
    # For the baseline, keep only the last entry along axis 1 of the combined array.
    svg_base_ssim_mean = combine_dims(svg_all_base_tailor_ssims)[:,-1,:,:]
    svg_base_psnr_mean = combine_dims(svg_all_base_tailor_psnrs)[:,-1,:,:]
    svg_base_mse_mean =  combine_dims(svg_all_base_tailor_mses)[:,-1,:,:]
    svg_base_lpips_mean =  combine_dims(svg_all_base_tailor_lpips)[:,-1,:,:]

# Colours: the panels below use noether_color and base_color; clrs and
# outer_clrs are defined but only referenced in commented-out code.
clrs = sns.color_palette('husl', ssim_mean.shape[1]+1)
clrs[-1] = 'black'
noether_color = '#2ca02c' # green
base_color = '#ff7f0e' # orange
outer_clrs = sns.color_palette('husl', 3)

# One row of four panels. From left to right the grid columns hold
# ax11 (LPIPS), ax10 (MSE), ax01 (PSNR) and ax00 (SSIM).
fig = plt.figure(figsize=(24, 4))
gs = GridSpec(1, 4, figure=fig)
ax00 = fig.add_subplot(gs[0,3])
ax01 = fig.add_subplot(gs[0,2])
ax11 = fig.add_subplot(gs[0,0])
ax10 = fig.add_subplot(gs[0,1])

# Integer-only x ticks and a common tick label size on every panel.
for ax in ax00, ax01, ax10, ax11:
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.tick_params(axis='x', labelsize=_tick_size)
    ax.tick_params(axis='y', labelsize=_tick_size)

# Scientific notation on the y axis of the MSE panel.
ax10.ticklabel_format(style='sci', scilimits=(0,0), axis='y')
ax10.yaxis.offsetText.set_fontsize(_tick_size)

# SSIM panel. This is the only panel that draws a legend.
ax00.set_title('Test SSIM ⬆️', fontsize=_title_size)
# x axis: frame index of each predicted frame, starting right after the
# opt.n_past conditioning frames.
x_ax = np.arange(opt.n_past+1, opt.n_past+1+ssim_mean.shape[-1])
for step in range(ssim_mean.shape[1]):
    mod_label = f'Noether Network ({step} step)'
    # Only the inner steps listed in tailor_plot_idx are drawn.
    if step in tailor_plot_idx:
        # Line: mean over axis 0, then over the second-to-last axis of the slice.
        ax00.plot(x_ax, ssim_mean[:,step,:,:].mean(axis=0).mean(axis=-2),
                  color=noether_color,
                  label=mod_label)
        # Band: mean -/+ one standard error across axis 0 (dist='sem').
        lb, ub = conf_int(ssim_mean[:,step,:,:].mean(axis=-2), dist='sem')
        ax00.fill_between(x_ax, lb, ub, alpha=0.3, color=noether_color)

if plot_baseline:
    ax00.plot(x_ax, svg_base_ssim_mean.mean(axis=0).mean(-2), color=base_color, label=svg_label)
    lb, ub = conf_int(svg_base_ssim_mean.mean(axis=-2), dist='sem')
    ax00.fill_between(x_ax, lb, ub, alpha=0.3, color=base_color)

ax00.legend(prop={'size': _legend_size})
ax00.set_xlabel('Prediction horizon', fontsize=_label_size)

# PSNR panel: same layout as the SSIM panel, but the lines are unlabeled and no
# legend is drawn (mod_label is computed but not passed to plot()).
ax01.set_title('Test PSNR ⬆️', fontsize=_title_size)
for step in range(ssim_mean.shape[1]):
    mod_label = f'tailored ({step} step)'
    if step in tailor_plot_idx:
        ax01.plot(x_ax, psnr_mean[:,step,:,:].mean(axis=0).mean(axis=-2), color=noether_color)
        # Band: mean -/+ one standard error across axis 0.
        lb, ub = conf_int(psnr_mean[:,step,:,:].mean(axis=-2), dist='sem')
        ax01.fill_between(x_ax, lb, ub, alpha=0.3, color=noether_color)

if plot_baseline:
    ax01.plot(x_ax, svg_base_psnr_mean.mean(axis=0).mean(-2), color=base_color)
    lb, ub = conf_int(svg_base_psnr_mean.mean(axis=-2), dist='sem')
    ax01.fill_between(x_ax, lb, ub, alpha=0.3, color=base_color)

ax01.set_xlabel('Prediction horizon', fontsize=_label_size)
# MSE panel. The lines carry a label, but no legend is drawn on this axis.
ax10.set_title('Test MSE ⬇️', fontsize=_title_size)
for step in range(ssim_mean.shape[1]):
    mod_label = f'tailored ({step} step)'
    if step in tailor_plot_idx:
        ax10.plot(x_ax, mse_mean[:,step,:,:].mean(axis=0).mean(axis=-2),
                  color=noether_color,
                  label=mod_label)
        # Band: mean -/+ one standard error across axis 0.
        lb, ub = conf_int(mse_mean[:,step,:,:].mean(axis=-2), dist='sem')
        ax10.fill_between(x_ax, lb, ub, alpha=0.3, color=noether_color)  # color=clrs[step])

if plot_baseline:
    ax10.plot(x_ax, svg_base_mse_mean.mean(axis=0).mean(-2), color=base_color)
    lb, ub = conf_int(svg_base_mse_mean.mean(-2), dist='sem')
    ax10.fill_between(x_ax, lb, ub, alpha=0.3, color=base_color)

ax10.set_xlabel('Prediction horizon', fontsize=_label_size)

# LPIPS panel. Unlike the other panels there is no loop over inner steps:
# index 0 along axis 1 is always plotted.
ax11.set_title('Test LPIPS ⬇️', fontsize=_title_size)
# NOTE(review): `step` here is whatever value an earlier loop left behind, and
# mod_label is not used afterwards in this cell.
mod_label = f'tailored ({step} step)'
ax11.plot(x_ax, lpips_mean[:,0,:,:].mean(axis=0).mean(axis=-2), color=noether_color)
# Band: mean -/+ one standard error across axis 0.
lb, ub = conf_int(lpips_mean[:,0,:,:].mean(axis=-2), dist='sem')
ax11.fill_between(x_ax, lb, ub, alpha=0.3, color=noether_color)  # color=clrs[step])

if plot_baseline:
    ax11.plot(x_ax, svg_base_lpips_mean.mean(axis=0).mean(-2), color=base_color)
    lb, ub = conf_int(svg_base_lpips_mean.mean(-2), dist='sem')
    ax11.fill_between(x_ax, lb, ub, alpha=0.3, color=base_color)

ax11.set_xlabel('Prediction horizon', fontsize=_label_size)
print()

# Plot {inner, outer} loss vs. inner step
This is only interesting for evaluation runs where you take many inner steps.

In [None]:
%matplotlib inline

# Font sizes for this figure.
_title_size = 21
_label_size = 18
_legend_size = 14
_tick_size = 14

# Number of entries along axis 1 of the loss arrays (the plotted x range).
plt_steps = all_val_inner_losses.shape[1]

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(6, 10))

for ax in axes:
    for which_axis in ('x', 'y'):
        ax.tick_params(axis=which_axis, labelsize=_tick_size)
# Scientific notation on the y axis of the lower (outer loss) panel.
axes[1].ticklabel_format(style='sci', scilimits=(0,0), axis='y')
axes[1].yaxis.offsetText.set_fontsize(_tick_size)

# (axis, loss array, title, y label) for the upper and lower panel.
loss_panels = (
    (axes[0], all_val_inner_losses, 'Inner loss vs. Inner step', 'Inner loss'),
    (axes[1], all_val_svg_losses, 'Outer loss vs. Inner step', 'Outer loss'),
)
for panel_ax, panel_losses, panel_title, panel_ylabel in loss_panels:
    # Mean curve over axis 0, with a band of -/+ one standard error.
    panel_ax.plot(panel_losses.mean(0)[:plt_steps])
    lb, ub = conf_int(panel_losses[:,:plt_steps], dist='sem')
    panel_ax.fill_between(np.arange(0,plt_steps), lb, ub, alpha=0.3,)
    panel_ax.set_title(panel_title, fontsize=_title_size)
    panel_ax.set_ylabel(panel_ylabel, fontsize=_label_size)
    panel_ax.set_xlabel('Inner step', fontsize=_label_size)

axes[0].set_ylim(bottom=0)

fig.tight_layout()
plt.show()

# Grad-CAM
This section contains code to produce Grad-CAM heatmaps to visualize the "important" regions of frames for each dimension of the embedding (either in the embedding space or in PCA coordinates, which is useful for large embeddings).
You'll need to run the evaluation loop earlier in the notebook to populate `all_gen` with the predicted frames for all of the sequences in the validation set.

In [None]:
def _stack_time_major(batches):
    """Concatenate the t-th entry of every batch, then stack the results over t."""
    n_entries = len(batches[0])
    return torch.stack([torch.cat([b[t] for b in batches]) for t in range(n_entries)])


# Pull opt.num_val_batch batches from the (endless) test generator.
gt_all = [next(testing_batch_generator) for _ in tqdm(range(opt.num_val_batch))]
gt_catted = _stack_time_major(gt_all).cpu()

# Same rearrangement for the generated sequences stored by the evaluation loop.
gc_all_gen = _stack_time_major(all_gen)
gt_catted.shape, gc_all_gen.shape

In [None]:
single_seq = True
grad_cam_seq_id = 17

if single_seq:
    # Every entry along the first axis, restricted to one sequence.
    batch = list(gt_catted[:,grad_cam_seq_id:grad_cam_seq_id+1])
else:
    # Entries 17 and 18 along the first axis, all sequences.
    batch = list(gt_catted[17:19])

# Pair each entry with its successor and concatenate the pair along dim 1.
stacked_batch = [torch.cat(pair, dim=1) for pair in zip(batch, batch[1:])]
if not single_seq:
    # Split the single stacked pair into one input of leading size 1 per row.
    stacked_batch = [fr.unsqueeze(0) for fr in stacked_batch[0]]
len(stacked_batch), stacked_batch[0].shape

In [None]:
class EmbWrapper(nn.Module):
    """Embedding network followed by a fitted PCA projection, as one module.

    Args:
        emb: module mapping an input batch to a 2-D ``(batch, features)``
            embedding tensor.
        pca_model: fitted PCA-like object exposing ``mean_``, ``components_``,
            ``whiten`` and ``explained_variance_``.
    """

    def __init__(self, emb, pca_model):
        super().__init__()
        self.emb = emb
        self.pca_model = pca_model

    def forward(self, X):
        """Return the PCA coordinates of ``emb(X)``."""
        return self.transform_torch(self.emb(X))

    @staticmethod
    def _like(array, reference):
        """Convert a NumPy array to a tensor with ``reference``'s dtype and device."""
        return torch.as_tensor(array, dtype=reference.dtype, device=reference.device)

    def transform_torch(self, X):
        """Apply the PCA transform to ``X`` using torch ops so gradients flow.

        The PCA parameters are moved to the device and dtype of ``X`` rather
        than to a hard-coded CUDA device, so the wrapper works wherever the
        embedding lives and never silently promotes the result to the dtype of
        the NumPy arrays.
        """
        pca = self.pca_model
        if pca.mean_ is not None:
            X = X - self._like(pca.mean_, X)
        X_transformed = torch.mm(X, self._like(pca.components_.T, X))
        if pca.whiten:
            X_transformed = X_transformed / self._like(np.sqrt(pca.explained_variance_), X)
        return X_transformed

In [None]:
# When True, Grad-CAM is computed on PCA coordinates of the embedding instead
# of on the raw embedding dimensions.
USE_PCA = True

emb = svg_model.emb

if USE_PCA:
    from sklearn.decomposition import PCA

    # One large batch for fitting the PCA. The original hard-coded size of 78 is
    # capped at the dataset size: with drop_last=True a larger batch size gives
    # an empty loader, and next() on the endless batch generator would then
    # never return.
    pca_batch_size = min(78, len(test_data))
    debug_loader = DataLoader(test_data,
                             num_workers=opt.num_threads,
                             batch_size=pca_batch_size,
                             shuffle=False,
                             drop_last=True,
                             pin_memory=True)

    debug_batch_generator = get_batch_generator(debug_loader)
    dbatch = next(debug_batch_generator)

    with torch.no_grad():
        # Channel-wise concatenation of every pair of consecutive frames; each
        # pair is embedded separately.
        stacked_dbatch = []
        for i in range(1, len(dbatch)):
            stacked_dbatch.append(torch.cat((dbatch[i-1], dbatch[i]), dim=1))
        all_dembs = [emb(frame).detach().cpu() for frame in stacked_dbatch]  # len(embs) = len(gen_seq) - 1

    # The embeddings of the first two frame pairs are left out of the fit.
    dembs = torch.cat(all_dembs[2:])
    my_model = PCA(n_components=6)
    # Only the fitted model is needed; the projected data was never used.
    my_model.fit(dembs.numpy())
    print(my_model.explained_variance_ratio_.cumsum())

    # Wrap a copy of the embedding so the original network is left untouched.
    emb = EmbWrapper(copy.deepcopy(svg_model.emb), my_model)
emb

In [None]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2

# `emb` is either the raw embedding network or an EmbWrapper around it; the
# wrapper is recognised by its inner `emb` attribute.
wrapped = hasattr(emb, 'emb')
if not wrapped:
    target_layer = emb.layer2[-2]
    # One Grad-CAM target per output feature of the embedding's fc1 layer.
    grad_cam_emb_dim = emb.fc1.out_features
    print('using emb directly')
else:
    target_layer = emb.emb.layer2[-2]
    # One Grad-CAM target per PCA component.
    grad_cam_emb_dim = emb.pca_model.n_components
    print('using wrapper')
cam_model = GradCAM(model=emb, target_layer=target_layer, use_cuda=True)

In [None]:
%matplotlib inline

plt.figure(figsize=(50, 300))
ax = plt.gca()
# No tick marks on the image grid.
for tick_axis in (ax.axes.xaxis, ax.axes.yaxis):
    tick_axis.set_ticks([])

# cams[j][dim] / visualizations[j][dim]: raw CAM and overlay image for input j
# and embedding dimension `dim`.
cams = []
visualizations = []

for in_tensor in tqdm(stacked_batch):
    dim_cams = []
    dim_visualizations = []
    for dim in range(grad_cam_emb_dim):
        cam = cam_model(input_tensor=in_tensor, target_category=dim)
        heatmap = cv2.cvtColor(cam[0], cv2.COLOR_GRAY2BGR)
        # Overlay on the first three channels of the first item, as H x W x C.
        background = in_tensor[0,:3].cpu().numpy().transpose((1,2,0))
        overlay = show_cam_on_image(background, heatmap, use_rgb=True)
        dim_visualizations.append(overlay)
        dim_cams.append(cam)
    visualizations.append(dim_visualizations)
    cams.append(dim_cams)

# One image row per input (dimensions side by side), rows stacked vertically.
image_rows = [np.concatenate(v, 1) for v in visualizations]
plt.imshow(np.concatenate(image_rows, 0))

In [None]:
import imageio

# Output directory for the animation; created if it does not exist yet.
gif_path = './grad_cam/'
os.makedirs(gif_path, exist_ok=True)

# One GIF frame per input: that input's per-dimension overlays side by side.
gif_frames = []
for per_dim_images in visualizations:
    gif_frames.append(np.concatenate(per_dim_images, 1))

gif_file = gif_path + f'grad_cam_seq{grad_cam_seq_id}.gif'
imageio.mimsave(gif_file, gif_frames)