In [3]:
import os
import torch
import torchvision
import torch.backends.cudnn as cudnn
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from importlib import reload
from datasets.three_dim_shapes import ThreeDimShapesDataset
from datasets.small_norb import SmallNORBDataset
from datasets.seq_mnist import SequentialMNIST
import models.seqae as seqae
from tqdm import tqdm

# Select the compute device. On GPU, configure cuDNN for reproducible evaluation:
# `deterministic = True` restricts convolutions to deterministic algorithms, and
# `benchmark = False` stops cuDNN from auto-tuning. Auto-tuning may select a
# different algorithm on each run, which would defeat the determinism setting.
if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.deterministic = True
    cudnn.benchmark = False
else:
    device = torch.device('cpu')
    # Sentinel meaning "no GPU"; not read by any of the visible cells.
    gpu_index = -1



In [4]:
# PATH TO THE ROOT OF DATASETS DIRECTORIES and of the trained-model log directories.
# Both can be overridden with the DATADIR / LOGDIR environment variables, so the
# notebook runs without editing. The defaults are placeholders to replace.
datadir = os.environ.get('DATADIR', '/tmp/path/to/datadir')
logdir = os.environ.get('LOGDIR', '/tmp/path/to/logdir')

# Prediction 

In [None]:
def load_model(model, log_dir, iters):
    """Restore ``model``'s parameters in place from a saved snapshot file.

    Args:
        model: module whose ``load_state_dict`` receives the loaded checkpoint.
        log_dir: directory holding the snapshots of one training run.
        iters: training iteration of the snapshot. The file read is
            ``<log_dir>/snapshot_model_iter_<iters>``.

    Tensors are mapped onto the notebook-global ``device``. Returns None.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints you trust.
    snapshot_path = os.path.join(log_dir, 'snapshot_model_iter_{}'.format(iters))
    state_dict = torch.load(snapshot_path, map_location=device)
    model.load_state_dict(state_dict)

# Number of ground-truth frames each model is conditioned on. The remaining
# T - n_cond frames of every sequence are predicted and scored.
n_cond = 2

dataset_names = ['mnist', 'mnist_bg', '3dshapes', 'smallNORB']
model_names = ['neural_trans', 'neuralM', 'lstsq_rec', 'lstsq_multi', 'lstsq']
# comp_errors[dataset_name][model_name] -> array of length n_rolls. Each entry is
# the per-step squared error (summed over the last three tensor dims), averaged
# over all test sequences and then over the training seeds below.
comp_errors = {}

for dataset_name in dataset_names:
    # Re-seeded per dataset so every model is scored on identically generated data.
    rng = np.random.RandomState(1)
    comp_errors[dataset_name] = {}
    # Per-dataset test set, plus the architecture / checkpoint settings the
    # models were trained with (k, bottom_width, n_blocks, ch_x, snapshot iteration).
    if dataset_name in ('mnist', 'mnist_bg'):
        T = 10
        with_bg = dataset_name == 'mnist_bg'
        test_data = SequentialMNIST(
            datadir, False, T=T, max_angle_velocity_ratio=[-0.2, 0.2],
            max_color_velocity_ratio=[-0.2, 0.2],
            max_pos=[-10, 10],
            max_T=T,
            only_use_digit4=True,
            backgrnd=with_bg,
            rng=rng)
        bottom_width = 4
        n_blocks = 3
        k = 4 if with_bg else 2
        iters = 100000 if with_bg else 50000
        ch_x = 3
    elif dataset_name == '3dshapes':
        T = 8
        # NOTE(review): unlike the other datasets this uses train=True. Confirm
        # that is intended for evaluation.
        test_data = ThreeDimShapesDataset(
            root=datadir,
            train=True, T=T,
            rng=rng)
        bottom_width = 8
        n_blocks = 3
        k = 1
        iters = 50000
        ch_x = 3
    elif dataset_name == 'smallNORB':
        T = 6
        test_data = SmallNORBDataset(
            root=datadir,
            train=False,
            T=T,
            rng=rng)
        bottom_width = 6
        n_blocks = 4
        k = 1
        iters = 50000
        ch_x = 1
    else:
        raise NotImplementedError(dataset_name)
    n_rolls = T - n_cond
    # Shuffling does not affect the result: errors are averaged over the whole set.
    test_loader = DataLoader(test_data, 16, True, num_workers=0)

    # Constructor arguments shared by every model variant on this dataset.
    common_kwargs = dict(dim_a=16, dim_m=256, global_average_pooling=False, k=k,
                         bottom_width=bottom_width, n_blocks=n_blocks, ch_x=ch_x)
    for model_name in model_names:
        if model_name == 'neural_trans':
            model = seqae.SeqAENeuralTransition(T_cond=n_cond, **common_kwargs)
        elif model_name == 'neuralM':
            model = seqae.SeqAENeuralM(**common_kwargs)
        elif model_name in ('lstsq_multi', 'multi-lstsq_K8'):
            # Both spellings are accepted; model_names uses 'lstsq_multi'.
            # NOTE(review): K=8 comes from the 'multi-lstsq_K8' naming. Confirm it
            # matches the config the lstsq_multi checkpoints were trained with.
            model = seqae.SeqAEMultiLSTSQ(K=8, **common_kwargs)
        elif model_name in ('lstsq', 'lstsq_rec'):
            model = seqae.SeqAELSTSQ(**common_kwargs)
        else:
            # Fail loudly rather than silently re-evaluating the previous model object.
            raise NotImplementedError(model_name)
        model.to(device)
        results = []
        # One log directory per independently trained run; results are averaged.
        for seed in [1, 2, 3]:
            path = os.path.join(logdir, "{}-{}-seed{}".format(dataset_name, model_name, seed))
            load_model(model, path, iters=iters)
            model.eval()
            losses = []
            with torch.no_grad():
                for images in tqdm(test_loader):
                    # The batch is presumably a per-frame sequence of tensors. After
                    # stack + transpose, dim 1 indexes time (sliced below).
                    images = torch.stack(images).transpose(1, 0).to(device)
                    images_cond = images[:, :n_cond]
                    images_target = images[:, n_cond:n_cond + n_rolls]
                    # Model output is mapped through a sigmoid before comparison.
                    images_pred = torch.sigmoid(model(images_cond, n_rolls=n_rolls))
                    # Squared error per sequence and per predicted step.
                    loss = torch.sum((images_target - images_pred) ** 2, dim=[2, 3, 4])
                    losses.append(loss.cpu().numpy())
            # Mean over all test sequences -> one error curve for this seed.
            results.append(np.mean(np.concatenate(losses, 0), 0))
        comp_errors[dataset_name][model_name] = np.mean(np.stack(results), 0)


In [None]:
# Plot the per-step prediction error of every model, writing one PDF per dataset.
# os.makedirs(exist_ok=True) is idempotent, unlike `!mkdir`, which errors on re-run.
os.makedirs('figs', exist_ok=True)

# Legend labels and line styles. l_styles is indexed in the same order as model_names.
label_name = {'neural_trans': 'Neural transition', 'neuralM': 'Neural$M^*$',
              'lstsq_rec': 'Rec. model', 'lstsq_multi': 'Ours w/fixed blocks', 'lstsq': 'Ours'}
l_styles = ['-.', '-o', ':', '--', '-']

# Apply the style BEFORE reading the colour cycle, so that every figure
# (including the first) uses the same ggplot palette.
plt.style.use('ggplot')
cs = plt.rcParams['axes.prop_cycle'].by_key()['color']

for dataset_name in dataset_names:
    fig, ax = plt.subplots(figsize=[6, 4])
    # Length of the error curve, i.e. the number of predicted steps.
    T = comp_errors[dataset_name]['lstsq'].shape[0]
    steps = np.arange(1, T + 1)
    for i, model_name in enumerate(model_names):
        # Labels are attached so a legend can be added with ax.legend(); none is drawn here.
        ax.plot(steps, comp_errors[dataset_name][model_name], l_styles[i],
                label=label_name[model_name], c=cs[i])
    # Thin the x ticks for long horizons (step of 1 + (T+1)//10).
    ax.set_xticks(np.arange(1, T + 1, 1 + (T + 1) // 10))
    ax.tick_params(axis='both', labelsize=16)
    ax.set_xlabel('$t_p$', fontsize=24)
    ax.set_ylabel('MSE', fontsize=20)
    fig.tight_layout()
    fig.savefig('figs/comp_error_{}.pdf'.format(dataset_name))
