In [1]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import torch

import torchvision
import time

from model import VAE, IVAE
from train import train_geco_draw, train_beta_draw
from utils import sample_vae, marginal_KL, Compute_NLL
import datasets
from conv_draw import ConvolutionalDRAW
# from GECO import *
# from beta_vae import *
# Preferred GPU for this notebook. Only pin it when CUDA is present and the
# index actually exists; otherwise leave PyTorch's default device untouched so
# the notebook still runs on CPU-only hosts or hosts with fewer GPUs.
GPU_INDEX = 5
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)
else:
    print('GPU index {} not available; using the default device'.format(GPU_INDEX))


%matplotlib inline

In [2]:
# !nvidia-smi

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# device = 'cpu'

cuda


In [4]:
# !unzip data/celeba.zip
def plot_gallery(images, h, w, n_row=3, n_col=6):
    """Show up to ``n_row * n_col`` images in a grid of subplots.

    Args:
        images: Indexable, sized collection of images; each item is passed
            straight to ``plt.imshow``.
        h: Unused; kept so existing callers keep working.
        w: Unused; kept so existing callers keep working.
        n_row: Number of grid rows.
        n_col: Number of grid columns.

    If fewer images than grid cells are supplied, the remaining cells are
    left empty instead of raising ``IndexError``.
    """
    plt.figure(figsize=(1.5 * n_col, 1.7 * n_row))
    plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
    # Never index past the end of `images`.
    n_shown = min(len(images), n_row * n_col)
    for i in range(n_shown):
        plt.subplot(n_row, n_col, i + 1)
        # Fixed colour range [-1, 1]; the gray colormap only matters for 2-D input.
        plt.imshow(images[i], cmap=plt.cm.gray, vmin=-1, vmax=1, interpolation='nearest')
        # Hide axis ticks so only the image is visible.
        plt.xticks(())
        plt.yticks(())

In [5]:
# Which dataset to use: 'mnist', 'cifar10', or any other value for CelebA.
dataset = 'cifar10'

# Shared by every branch below: on-disk location and the image -> tensor transform.
data_root = './data/' + dataset + '/'
to_tensor = torchvision.transforms.ToTensor()

if dataset == 'mnist':
    train_set = datasets.MNIST(data_root, download=True, train=True, transform=to_tensor)
    test_set = datasets.MNIST(data_root, download=True, train=False, transform=to_tensor)
    input_size = (28, 28)

elif dataset == 'cifar10':
    train_set = datasets.CIFAR10(data_root, download=True, train=True, transform=to_tensor)
    test_set = datasets.CIFAR10(data_root, download=True, train=False, transform=to_tensor)
    input_size = (32, 32)

else:
    # CelebA is not downloaded here; presumably it must already be unpacked
    # under data_root -- TODO confirm against the project's `datasets` module.
    train_set = datasets.CELEBA(data_root, train=True, transform=to_tensor)
    test_set = datasets.CELEBA(data_root, train=False, transform=to_tensor)
    input_size = (218, 178)


batch_size = 300
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluate on the held-out split. This loader previously wrapped train_set, so
# "validation" silently ran on training data. Shuffling is off because
# evaluation does not need a random order.
# drop_last is kept so every batch has exactly batch_size items; presumably
# the training code relies on that -- TODO confirm.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# ConvolutionalDRAW with T = 1, trained through train_geco_draw.
cd_model = ConvolutionalDRAW(x_dim=3, x_shape=input_size, h_dim=256, T=1)
optimizer = optim.Adam(cd_model.parameters(), lr=1e-3)
scheduler = None  # this run uses no learning-rate scheduler

# Keyword arguments forwarded unchanged to train_geco_draw.
geco_kwargs = dict(
    train_loader=train_loader,
    valid_loader=test_loader,
    device=device,
    lbd_step=100,
    num_epochs=100,
    lambd_init=torch.FloatTensor([1]),
    tol=0.6,
    pretrain=0,
)
train_geco_draw(cd_model, optimizer, scheduler, **geco_kwargs)

In [9]:
# Same architecture as the GECO run, trained through train_beta_draw with beta = 0.5.
cd_model_beta05 = ConvolutionalDRAW(x_dim=3, x_shape=input_size, h_dim=256, T=1)
optimizer = optim.Adam(cd_model_beta05.parameters(), lr=1e-3)

# Reduce the learning rate once the monitored quantity has stopped decreasing
# for 10 scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=10,
    verbose=True,
)

train_beta_draw(
    cd_model_beta05,
    optimizer,
    scheduler,
    train_loader,
    test_loader,
    num_epochs=50,
    beta=0.5,
)