In [1]:
import os
import torch
import torchvision
import torch.backends.cudnn as cudnn
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from importlib import reload
from datasets.three_dim_shapes import ThreeDimShapesDataset
from datasets.small_norb import SmallNORBDataset
from datasets.seq_mnist import SequentialMNIST
import models.seqae as seqae
import signal
import seaborn as sns
from tqdm import tqdm
# Select the compute device once; later cells read the module-level `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    # Reproducible cuDNN behaviour: restrict to deterministic kernels and turn
    # the autotuner off. With benchmark=True the autotuner may select a
    # different algorithm from run to run, which works against the
    # deterministic setting above.
    cudnn.deterministic = True
    cudnn.benchmark = False
else:
    device = torch.device('cpu')
    # Note: gpu_index is only defined on this CPU path.
    gpu_index = -1


In [2]:
# PATH TO THE ROOT OF DATASETS DIRECTORIES
# NOTE(review): the value below looks like a placeholder; set it to the real
# dataset root before running the evaluation cells.
datadir = '/tmp/path/to/datadir'
# Root directory of the training logs. Snapshots are read from
# <logdir>/<dataset>-<model>-seed<seed>/snapshot_model_iter_<iters>.
logdir = '/tmp/path/to/logdir'

# Equivariance error

In [None]:
def load_model(model, log_dir, iters):
    """Restore the weights of `model` in place from a training snapshot.

    Args:
        model: module whose `load_state_dict` receives the stored weights.
        log_dir: directory that contains the snapshot file.
        iters: iteration number embedded in the snapshot file name
            (`snapshot_model_iter_<iters>`).

    The stored tensors are mapped onto the module-level `device`.
    """
    snapshot_name = 'snapshot_model_iter_{}'.format(iters)
    snapshot_path = os.path.join(log_dir, snapshot_name)
    # torch.load unpickles the file, so only point it at trusted snapshots.
    state_dict = torch.load(snapshot_path, map_location=device)
    model.load_state_dict(state_dict)

# Number of leading frames fed to the model as conditioning input; the frame
# right after them (index n_cond) is the prediction target.
n_cond = 2
# Samples per evaluation batch.
bsize = 32
# Training seeds whose snapshots are evaluated and then aggregated.
seeds = [1, 2, 3]

shared_transition = True
# equiv_errors[dataset][model] = [mean over seeds, std over seeds]. Each entry
# is a length-2 array: (error using the sequence's own M, error using an M
# taken from a different sequence of the same batch).
equiv_errors = {}
dataset_names = ['mnist', 'mnist_bg', '3dshapes', 'smallNORB']
model_names = ['neuralM', 'lstsq_rec', 'lstsq']

for dataset_name in dataset_names:
    equiv_errors[dataset_name] = {}
    rng = np.random.RandomState(1)

    # Per-dataset test set, architecture hyper-parameters and snapshot iteration.
    if dataset_name in ('mnist', 'mnist_bg'):
        # The two MNIST variants differ only in the background flag, k and iters.
        with_background = dataset_name == 'mnist_bg'
        T = 10
        test_data = SequentialMNIST(
            datadir, False, T=T, max_angle_velocity_ratio=[-0.2, 0.2],
            max_color_velocity_ratio=[-0.2, 0.2],
            max_pos=[-10, 10],
            max_T=T,
            only_use_digit4=True,
            backgrnd=with_background,
            rng=rng)
        bottom_width = 4
        n_blocks = 3
        k = 4 if with_background else 2
        ch_x = 3
        iters = 100000 if with_background else 50000
    elif dataset_name == '3dshapes':
        T = 8
        # NOTE(review): this is the only dataset built with train=True --
        # confirm that evaluating on the training split is intended here.
        test_data = ThreeDimShapesDataset(
            root=datadir,
            train=True, T=T,
            rng=rng)
        bottom_width = 8
        n_blocks = 3
        k = 1
        ch_x = 3
        iters = 50000
    elif dataset_name == 'smallNORB':
        T = 6
        test_data = SmallNORBDataset(
            root=datadir,
            train=False,
            T=T,
            rng=rng)
        bottom_width = 6
        n_blocks = 4
        k = 1
        ch_x = 1
        iters = 50000
    else:
        raise NotImplementedError(dataset_name)

    n_rolls = T - n_cond

    for model_name in model_names:
        # Constructor arguments shared by every model variant.
        model_kwargs = dict(
            dim_a=16, dim_m=256, global_average_pooling=False, k=k,
            bottom_width=bottom_width, n_blocks=n_blocks, ch_x=ch_x)
        if model_name == 'neural_trans':
            model = seqae.SeqAENeuralTransition(T_cond=n_cond, **model_kwargs)
        elif model_name == 'neuralM':
            model = seqae.SeqAENeuralM(**model_kwargs)
        elif model_name == 'multi-lstsq_K8':
            model = seqae.SeqAEMultiLSTSQ(K=8, **model_kwargs)
        elif model_name in ('lstsq', 'lstsq_rec'):
            # Same architecture for both names; only the checkpoint directory
            # (which embeds model_name) differs.
            model = seqae.SeqAELSTSQ(**model_kwargs)
        else:
            # Fail loudly instead of silently reusing the previous model.
            raise NotImplementedError(model_name)
        model.to(device)
        # NOTE(review): the model is never switched to eval mode; confirm that
        # is intended if it contains layers that behave differently in training.
        test_loader = DataLoader(test_data, bsize, shuffle=True, num_workers=0)

        # Initialize lazy modules with one forward pass so that the state dict
        # loaded below matches the materialized parameters.
        images = next(iter(test_loader))
        images = torch.stack(images).transpose(1, 0)
        images = images.to(device)
        model(images[:, :n_cond])

        results = []
        for seed in seeds:
            path = os.path.join(logdir, "{}-{}-seed{}".format(dataset_name, model_name, seed))
            load_model(model, path, iters=iters)
            losses = []
            losses_perm = []

            with torch.no_grad():
                for images in tqdm(test_loader):
                    images = torch.stack(images).transpose(1, 0)
                    images = images.to(device)
                    images_cond = images[:, :n_cond]
                    images_target = images[:, n_cond:n_cond + 1]
                    M = model.get_M(images_cond)  # n a a
                    H = model.encode(images_cond[:, -1:])[:, 0]  # n s a
                    H_next = H @ M
                    # Pair every sample with the M of another sample in the
                    # batch by rotating M half a batch along dim 0. This equals
                    # the fixed bsize//2 shift for full batches and stays valid
                    # for a short final batch (a single-sample batch is left
                    # unchanged, since there is nothing to swap with).
                    swapM = torch.roll(M, shifts=M.shape[0] // 2, dims=0)
                    H_next_perm = H @ swapM
                    x_next = torch.sigmoid(model.decode(H_next[:, None]))
                    x_next_perm = torch.sigmoid(model.decode(H_next_perm[:, None]))

                    if dataset_name == 'smallNORB':
                        # Collapse dim 2 to size 1 (presumably the channel axis,
                        # since ch_x is 1 for this dataset -- TODO confirm).
                        x_next = torch.mean(x_next, 2, keepdim=True)
                        x_next_perm = torch.mean(x_next_perm, 2, keepdim=True)
                    # Squared error summed over dims 2-4 for each sample.
                    loss = torch.sum((images_target - x_next) ** 2, dim=(2, 3, 4)).cpu().numpy()
                    loss_perm = torch.sum((images_target - x_next_perm) ** 2, dim=(2, 3, 4)).cpu().numpy()
                    losses.append(loss)
                    losses_perm.append(loss_perm)

                    # Called after every batch, as in the original loop;
                    # presumably re-draws the dataset's shared transition
                    # parameters -- verify against the dataset classes.
                    test_data.init_shared_transition_parameters()
            results.append([np.mean(np.concatenate(losses)), np.mean(np.concatenate(losses_perm))])
        # Aggregate over seeds: element 0 is the mean, element 1 the std.
        equiv_errors[dataset_name][model_name] = [np.mean(np.array(results), 0), np.std(np.array(results), 0)]

In [None]:
# Dump the nested results: equiv_errors[dataset][model] = [mean, std] over seeds.
print(equiv_errors)