In [1]:
import sys
import os, copy, time
import random
import numpy as np
import torch
HOME_PATH = os.path.expanduser('~')
sys.path.append(os.path.join(HOME_PATH, 'vae_project/PixelCNN-Pytorch'))

import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, Laplace
from tqdm import tqdm
import matplotlib.pyplot as plt
import poisevae
from poisevae.utils import CategoricalImage as Categorical
from poisevae.datasets import MNIST_SVHN
from poisevae.networks.MNISTMNISTNetworks_X import EncMNIST1, DecMNIST1, EncMNIST2, DecMNIST2
from poisevae.networks.MNISTSVHNNetworks_pixelcnn import pixelcnn_decoder
from Model import PixelCNN

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
MNIST_PATH = os.path.join(HOME_PATH, 'Datasets/MNIST/%s.pt')
SVHN_PATH = os.path.join(HOME_PATH, 'Datasets/SVHN/%s_32x32.mat')

joint_dataset_train = MNIST_SVHN(mnist_pt_path=MNIST_PATH % 'train', svhn_mat_path=SVHN_PATH % 'train')
joint_dataset_test = MNIST_SVHN(mnist_pt_path=MNIST_PATH % 'test', svhn_mat_path=SVHN_PATH % 'test')

In [3]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(joint_dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(joint_dataset_test, batch_size=batch_size, shuffle=True, drop_last=True)
len(train_loader), len(test_loader)

(1874, 312)

In [4]:
lat1, lat2 = 20, 20
enc_mnist1 = EncMNIST1(lat1, lat2).to(device)
enc_mnist2 = EncMNIST2(lat1, lat2).to(device)

mlp1 = torch.nn.Sequential(nn.Linear(lat1, 8), nn.ReLU(inplace=True), nn.Linear(8, 1))
mlp2 = torch.nn.Sequential(nn.Linear(lat2, 8), nn.ReLU(inplace=True), nn.Linear(8, 1))

dec_mnist1 = pixelcnn_decoder(mlp1, PixelCNN(no_layers=3), (1, 28, 28)).to(device)
dec_mnist2 = pixelcnn_decoder(mlp1, PixelCNN(no_layers=3), (1, 28, 28)).to(device)

# Options: 'autograd' and 'gradient'
vae = poisevae.POISEVAE_Gibbs('gradient',
                              [enc_mnist1, enc_mnist2], [dec_mnist1, dec_mnist2], likelihoods=[Categorical, Categorical],
                              latent_dims=[lat1, (lat2)], enc_config='nu', KL_calc='derivative', 
                              batch_size=batch_size, reduction='mean'
                             ).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
writer = None

In [5]:
with torch.no_grad():
    for i, data in enumerate(test_loader):
        # with poisevae._debug.Capturing() as output:
        results = vae([data[0].to(device, dtype=torch.float32), data[0].to(device, dtype=torch.float32)])#, 
                      # n_gibbs_iter=30)
        break

mu max: 0.16091635823249817 mu mean: 0.05937446281313896
mup max: 0.19173772633075714 mup mean: 0.04307272657752037
var min: 0.8618493676185608 var mean: 1.0055755376815796
varp min: 0.8768165707588196 varp mean: 0.9829151034355164
total loss: 8730.015625 kl term: 0.0
rec1 loss: 4403.138541666666 rec2 loss: 4326.877083333334

