In [1]:
import sys, os, copy, time
import random
import numpy as np
import torch
# Current user's home directory; the MNIST dataset path in the next cell is built under it.
HOME_PATH = os.path.expanduser('~')

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import poisevae
from poisevae.datasets import MNIST_MNIST
from poisevae.networks.PixelCNN_MNIST import EncMNIST, DecMNIST
from poisevae.networks.pixelcnn import PixelCNN

In [2]:
# Path template for the MNIST tensor files; '%s' is filled with the split name.
MNIST_PATH = os.path.join(HOME_PATH, 'Datasets/MNIST/%s.pt')

# Build the paired MNIST-MNIST datasets for both splits (train first, then test).
joint_dataset_train, joint_dataset_test = (
    MNIST_MNIST(mnist_pt_path=MNIST_PATH % split) for split in ('train', 'test')
)

In [3]:
# Mini-batch size; also handed to POISEVAE when the model is constructed.
batch_size = 128

# Both loaders shuffle and drop the final incomplete batch, so every batch
# has exactly `batch_size` samples.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    for dataset in (joint_dataset_train, joint_dataset_test)
)

# Number of batches per epoch for (train, test).
tuple(map(len, (train_loader, test_loader)))

(468, 78)

In [4]:
# Latent dimensionality of each of the two MNIST modalities.
lat1, lat2 = 20, 20
# Passed to both PixelCNN and DecMNIST; presumably the number of pixel
# intensity levels the decoder models -- TODO confirm against PixelCNN.
color_level = 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One encoder/decoder pair per modality. Each pair must be sized by its own
# latent dimension so that it agrees with `latent_dims` below.
enc_mnist1 = EncMNIST(lat1).to(device)
dec_mnist1 = DecMNIST(PixelCNN(lat1, 1, color_level), color_level).to(device)
enc_mnist2 = EncMNIST(lat2).to(device)  # second modality uses lat2, like its decoder
dec_mnist2 = DecMNIST(PixelCNN(lat2, 1, color_level), color_level).to(device)

# Options: 'derivative_autograd', 'derivative_gradient', and 'std_normal'
# (presumably the accepted values of `KL_calc` -- confirm against POISEVAE).
vae = poisevae.POISEVAE([enc_mnist1, enc_mnist2], [dec_mnist1, dec_mnist2], latent_dims=[lat1, lat2],
                        enc_config='nu', KL_calc='std_normal', batch_size=batch_size
                       ).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
# No summary writer is used; the train/test helpers receive writer=None.
writer = None

In [5]:
# Starting epoch index; the training cell uses it as the lower bound of its epoch range.
epoch = 0

In [6]:
# Train for 3 more epochs, starting from the global `epoch` counter.
#
# `epoch` always holds the index of the NEXT epoch to run. It is advanced only
# after an epoch has fully completed, so re-running this cell (or re-running it
# after an interruption) resumes at the right index instead of repeating the
# last completed one.
#
# NOTE(review): the recorded run of this cell stopped with "CUDA out of memory"
# on a ~12 GiB GPU. Lowering `batch_size` or `n_gibbs_iter` may be needed --
# confirm on the target hardware.
epochs = 3 + epoch
for current_epoch in tqdm(range(epoch, epochs)):
    poisevae.utils.train(vae, train_loader, optimizer, current_epoch, kl_weight=1, n_gibbs_iter=5,
                         writer=writer, dec_kwargs={'generate_mode':False})
    poisevae.utils.test(vae, test_loader, current_epoch, kl_weight=1, n_gibbs_iter=5,
                        writer=writer, dec_kwargs={'generate_mode':False})
    # Optional checkpointing (requires a PATH variable to be defined first):
    # if (current_epoch+1) % 10 == 0 and current_epoch > 0:
    #     poisevae.utils.save_checkpoint(vae, optimizer, os.path.join(PATH, 'training_%d.pt' % (current_epoch+1)), current_epoch+1)
    epoch = current_epoch + 1

  0%|                                                                                   | 0/3 [00:13<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 58.00 MiB (GPU 0; 11.91 GiB total capacity; 8.93 GiB already allocated; 52.62 MiB free; 9.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Run the model once on the first batch drawn from the test loader, without
# tracking gradients. `data` and `results` are reused by the plotting cell.
# NOTE(review): both model inputs are built from data[0]; if the dataset yields
# a distinct second image at data[1], confirm that feeding data[0] twice is intended.
with torch.no_grad():
    for i, data in enumerate(test_loader):
        inputs = [data[0].to(device, dtype=torch.float32) for _ in range(2)]
        results = vae(inputs)
        break

# Shape of the first entry of the reconstructions.
results['x_rec'][0].shape

In [None]:
# Grid of at most 20 columns: row 0 shows the input images from `data[0]`,
# row 1 shows the matching entries of results['x_rec'][0].
ncols = min(len(results['x_rec'][0]), 20)
fig, ax = plt.subplots(nrows=2, ncols=ncols, figsize=(15, 1.5))
for col, pair in zip(range(ncols), zip(data[0], results['x_rec'][0])):
    for row, img in enumerate(pair):
        panel = ax[row, col]
        # Each image is flattened back to a 28x28 array for display.
        panel.imshow(img.cpu().numpy().reshape(28, 28), cmap='gray')
        panel.set_axis_off()
fig.tight_layout(pad=0)
# To label or save the figure (PATH must be defined first):
# ax[1, 0].set_ylabel('Rec', fontsize=24)
# fig.savefig(PATH + 'MNISTRec.pdf', dpi=300)

## Joint Generation

In [None]:
# Ask the model for 12 generated samples with the decoders switched to
# generate mode; gradient tracking is disabled for the call.
with torch.no_grad():
    results = vae.generate(
        n_samples=12,
        dec_kwargs=dict(generate_mode=True),
    )

# Shape of the first entry of the generated outputs.
results['x_rec'][0].shape

In [None]:
# Show the jointly generated samples: row 0 is results['x_rec'][0] and row 1 is
# results['x_rec'][1], so each column is one generated pair. The plot no longer
# pairs the samples with `data[0]`, which are unrelated test images left over
# from the reconstruction cell.
# NOTE(review): assumes results['x_rec'][1] holds the second modality's samples
# in the same layout as index 0 -- confirm against POISEVAE.generate.
ncols = min(len(results['x_rec'][0]), 20)
# squeeze=False keeps `ax` two-dimensional even when only one column is drawn.
fig, ax = plt.subplots(nrows=2, ncols=ncols, figsize=(15, 1.5), squeeze=False)
for i, aux in enumerate(zip(results['x_rec'][0], results['x_rec'][1])):
    if i >= ncols:
        break
    for j, im in enumerate(aux):
        ax[j, i].imshow(im.cpu().numpy().reshape(28, 28), cmap='gray')
        ax[j, i].set_axis_off()
fig.tight_layout(pad=0)
# To save the figure (PATH must be defined first):
# fig.savefig(PATH + 'MNISTGen.pdf', dpi=300)

In [None]:
# Second reconstruction pass on the first batch drawn from the test loader,
# with gradient tracking disabled. `data` and `results` are overwritten.
# NOTE(review): both model inputs are built from data[0] -- confirm this is intended.
with torch.no_grad():
    for i, data in enumerate(test_loader):
        inputs = [data[0].to(device, dtype=torch.float32) for _ in range(2)]
        # A longer Gibbs chain can be requested by also passing n_gibbs_iter=30.
        results = vae(inputs)
        break