In [4]:
import os
import torch
import torchvision
import torch.backends.cudnn as cudnn
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from importlib import reload
from datasets.three_dim_shapes import ThreeDimShapesDataset
from datasets.small_norb import SmallNORBDataset
from datasets.seq_mnist import SequentialMNIST
import models.seqae as seqae
import signal
import seaborn as sns

# Select the compute device. On GPU, cuDNN is configured for reproducibility:
# `deterministic=True` restricts cuDNN to deterministic kernels, and
# `benchmark` must stay False because auto-tuning may choose a different
# (possibly non-deterministic) algorithm on each run, which would defeat the
# seeding done in `init_random_seed`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.deterministic = True
    cudnn.benchmark = False
else:
    device = torch.device('cpu')
    # NOTE(review): not referenced by the visible cells; kept for compatibility.
    gpu_index = -1


In [5]:
# PATH TO THE ROOT OF DATASETS DIRECTORIES
# The literals below are placeholders. Set SEQAE_DATADIR / SEQAE_LOGDIR in the
# environment to point the notebook at real data and training logs without
# editing this cell.
datadir = os.environ.get('SEQAE_DATADIR', '/tmp/path/to/datadir')
logdir = os.environ.get('SEQAE_LOGDIR', '/tmp/path/to/logdir')

In [6]:
def load_model(model, log_dir, iters):
    """Restore `model`'s parameters from the snapshot saved at iteration `iters`.

    The snapshot file is ``<log_dir>/snapshot_model_iter_<iters>``. Its tensors
    are mapped onto the notebook-global `device` before being loaded into
    `model` in place. Nothing is returned.
    """
    snapshot_path = os.path.join(log_dir, f'snapshot_model_iter_{iters}')
    state_dict = torch.load(snapshot_path, map_location=device)
    model.load_state_dict(state_dict)

def init_random_seed(seed=1):
    """Seed every RNG the notebook draws from so generated figures are reproducible.

    Seeds Python's `random`, NumPy's legacy global RNG, and torch's CPU and
    CUDA generators with the same value.

    Args:
        seed: integer seed. Defaults to 1, the value previously hard-coded.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by torch when CUDA is unavailable, so safe on CPU-only hosts.
    torch.cuda.manual_seed(seed)

def save_images(images_cond, images_true, images_pred, fname, n_col=2, out_dir='gen_images'):
    """Plot predicted vs. ground-truth frame sequences and save the figure.

    Draws one subplot per sequence in the batch, laid out in `n_col` columns.
    Each subplot is a `make_grid` image with `T = n_cond + n_true` frames per row:
      * first row: white placeholders for the conditioning frames, then the
        predicted frames;
      * second row: the conditioning frames followed by the ground-truth frames
        (assuming `images_pred` and `images_true` have the same length).

    Args:
        images_cond: conditioning frames, indexed (batch, time, ...). Presumably
            (B, n_cond, C, H, W); TODO confirm.
        images_true: ground-truth future frames, same layout as `images_cond`.
        images_pred: predicted future frames, same layout as `images_true`.
        fname: file name written inside `out_dir` (matplotlib infers the format
            from the extension, or uses its default when there is none).
        n_col: number of subplot columns. Rows are rounded up, so a batch that
            is not a multiple of `n_col` still fits.
        out_dir: output directory; created if it does not exist.

    Returns:
        The matplotlib Figure. It is left open so inline backends still display
        it; callers generating many figures may `plt.close` it.
    """
    T = images_cond.shape[1] + images_true.shape[1]
    batch_size = images_pred.shape[0]
    n_row = -(-batch_size // n_col)  # ceiling division
    with torch.no_grad():
        # Prefix the predictions with blank (all-ones = white) frames so they
        # line up under the conditioning frames of the second row.
        pred_row = torch.cat([torch.ones_like(images_cond), images_pred], 1)
        seq_images = torch.cat([pred_row, images_cond, images_true], 1)
        fig = plt.figure(figsize=[30, 8])
        for idx in range(batch_size):
            ax = fig.add_subplot(n_row, n_col, idx + 1)
            grid = torchvision.utils.make_grid(seq_images[idx], nrow=T, pad_value=1.0)
            # make_grid returns (C, H, W); imshow expects (H, W, C).
            ax.imshow(grid.detach().cpu().numpy().transpose(1, 2, 0))
            ax.axis('off')
    fig.tight_layout()
    fig.subplots_adjust(hspace=0)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, fname))
    return fig

# Gen images

In [10]:

# Output directory for the figures. Unlike the `!mkdir` shell escape, this is
# idempotent (re-running the cell does not error) and cross-platform.
os.makedirs('gen_images', exist_ok=True)
n_cond = 2        # number of conditioning frames given to each model
seed = 1          # training seed; selects the `<dataset>-<model>-seed<seed>` log dir
batch_size = 4    # sequences plotted per figure
dataset_names = ['mnist', 'mnist_bg', '3dshapes', 'smallNORB']
model_names = ['neural_trans', 'neuralM', 'lstsq_rec', 'lstsq_multi', 'lstsq']


def _make_dataset(dataset_name, rng):
    """Build the evaluation dataset and the architecture its checkpoints use.

    Returns:
        (T, test_data, arch, iters): sequence length, dataset, a dict of
        architecture kwargs (bottom_width, n_blocks, k, ch_x), and the
        snapshot iteration to load.

    Raises:
        NotImplementedError: for an unknown `dataset_name`.
    """
    if dataset_name in ('mnist', 'mnist_bg'):
        T = 10
        test_data = SequentialMNIST(
            datadir, False, T=T, max_angle_velocity_ratio=[-0.2, 0.2],
            max_color_velocity_ratio=[-0.2, 0.2],
            max_pos=[-10, 10],
            max_T=T,
            only_use_digit4=True,
            backgrnd=(dataset_name == 'mnist_bg'),
            rng=rng)
        if dataset_name == 'mnist':
            return T, test_data, dict(bottom_width=4, n_blocks=3, k=2, ch_x=3), 50000
        return T, test_data, dict(bottom_width=4, n_blocks=3, k=4, ch_x=3), 100000
    if dataset_name == '3dshapes':
        T = 8
        # NOTE(review): evaluates on the training split (train=True); confirm intended.
        test_data = ThreeDimShapesDataset(root=datadir, train=True, T=T, rng=rng)
        return T, test_data, dict(bottom_width=8, n_blocks=3, k=1, ch_x=3), 50000
    if dataset_name == 'smallNORB':
        T = 6
        test_data = SmallNORBDataset(root=datadir, train=False, T=T, rng=rng)
        return T, test_data, dict(bottom_width=6, n_blocks=4, k=1, ch_x=1), 50000
    raise NotImplementedError(dataset_name)


def _make_model(model_name, arch):
    """Instantiate the untrained model named `model_name` with architecture kwargs `arch`.

    Raises:
        NotImplementedError: for an unknown name. Previously an unmatched name
            silently reused the model left over from the prior loop iteration.
    """
    common = dict(dim_a=16, dim_m=256, global_average_pooling=False, **arch)
    if model_name == 'neural_trans':
        # T_cond must match the checkpoint's training config (equal to n_cond here).
        return seqae.SeqAENeuralTransition(T_cond=2, **common)
    if model_name == 'neuralM':
        return seqae.SeqAENeuralM(**common)
    if model_name in ('lstsq_multi', 'multi-lstsq_K8'):
        # NOTE(review): assumes the 'lstsq_multi' checkpoints were trained with
        # K=8, as the legacy 'multi-lstsq_K8' name implies; confirm.
        return seqae.SeqAEMultiLSTSQ(K=8, **common)
    if model_name in ('lstsq', 'lstsq_rec'):
        return seqae.SeqAELSTSQ(**common)
    raise NotImplementedError(model_name)


def _next_batch(loader):
    """Fetch one batch from `loader` as a single batch-first tensor on `device`.

    Uses the builtin `next(iter(loader))`. DataLoader iterators in recent
    PyTorch no longer provide the Python-2 style `.next()` method.
    """
    frames = next(iter(loader))  # presumably a list with one tensor per time step
    return torch.stack(frames).transpose(1, 0).to(device)


for dataset_name in dataset_names:
    rng = np.random.RandomState(1)
    T, test_data, arch, iters = _make_dataset(dataset_name, rng)
    n_rolls = T - n_cond  # number of frames predicted after the conditioning frames

    for model_name in model_names:
        model = _make_model(model_name, arch)
        model.to(device)

        # Lazy modules get their parameter shapes from the first forward pass,
        # so run one batch through the model before loading the snapshot.
        test_loader = DataLoader(test_data, batch_size, True, num_workers=0)
        images = _next_batch(test_loader)
        model(images[:, :n_cond])

        path = os.path.join(logdir, "{}-{}-seed{}".format(dataset_name, model_name, seed))
        load_model(model, path, iters=iters)

        model.eval()
        with torch.no_grad():
            # Reset the global RNGs and the dataset's RNG so every model is
            # (presumably) evaluated on the same batch of sequences.
            init_random_seed()
            test_data.rng = np.random.RandomState(1234)
            test_loader = DataLoader(test_data, batch_size, True, num_workers=0)
            images = _next_batch(test_loader)
            images_cond = images[:, :n_cond]
            images_target = images[:, n_cond:n_cond + n_rolls]
            images_pred = torch.sigmoid(model(images_cond, n_rolls=n_rolls))
            save_images(images_cond, images_target, images_pred,
                        fname=dataset_name + '_' + model_name, n_col=2)
