In [1]:
import sys, os, copy, time
import random
import numpy as np
import torch
# Current user's home directory; the MNIST dataset path used later is built from it.
HOME_PATH = os.path.expanduser('~')

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import poisevae
from poisevae.datasets import MNIST
from poisevae.networks.PixelCNN_MNIST import EncMNIST, DecMNIST
from poisevae.networks.pixelcnn import PixelCNN

In [2]:
# Path template for the serialized MNIST splits; '%s' is filled with the split name.
MNIST_PATH = os.path.join(HOME_PATH, 'Datasets/MNIST/%s.pt')

# Build the train and test datasets from the same template, train first, then test.
joint_dataset_train, joint_dataset_test = (
    MNIST(mnist_pt_path=MNIST_PATH % split_name)
    for split_name in ('train', 'test')
)

In [3]:
batch_size = 128

# Both loaders use identical settings: batches are reshuffled on every pass and the
# trailing partial batch is dropped, so every batch holds exactly `batch_size` samples.
# NOTE(review): shuffling the *test* loader makes the evaluated batch differ between
# runs (no seed is set in the visible cells) — confirm this is intended.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    for dataset in (joint_dataset_train, joint_dataset_test)
)

# Number of (full) batches in each loader.
len(train_loader), len(test_loader)

(468, 78)

In [4]:
# Latent dimensionality of each of the two modalities.
lat1, lat2 = 20, 20
# Passed to both PixelCNN and DecMNIST; presumably the number of discrete pixel
# intensity levels (2 = binarized images) — TODO confirm against poisevae.
color_level = 2
# CPU is forced here; the CUDA auto-detection is kept in the trailing comment.
device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu'

# One encoder/decoder pair per modality, each sized by its own latent dimension.
enc_mnist1 = EncMNIST(lat1).to(device)
dec_mnist1 = DecMNIST(PixelCNN(lat1, 1, color_level), color_level).to(device)
# The second encoder must be sized with lat2 (it was previously built with lat1,
# which only worked because lat1 == lat2) so it matches dec_mnist2 and latent_dims.
enc_mnist2 = EncMNIST(lat2).to(device)
dec_mnist2 = DecMNIST(PixelCNN(lat2, 1, color_level), color_level).to(device)

# Options (presumably for KL_calc): 'derivative_autograd', 'derivative_gradient', and 'std_normal'
vae = poisevae.POISEVAE([enc_mnist1, enc_mnist2], [dec_mnist1, dec_mnist2], latent_dims=[lat1, lat2],
                        enc_config='nu', KL_calc='std_normal', batch_size=batch_size
                       ).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
# `writer` is left unset (None) in this notebook.
writer = None

In [5]:
# Run a single forward pass of the model on one test batch, without tracking gradients.
with torch.no_grad():
    # Only the first batch is needed, so fetch it directly instead of looping and breaking.
    data = next(iter(test_loader))
    # Both modalities are fed the same tensor (data[0]); convert it once and reuse it.
    mnist_batch = data[0].to(device, dtype=torch.float32)
    results = vae([mnist_batch, mnist_batch], dec_kwargs={'generate_mode':False})
    # NOTE(review): an `n_gibbs_iter=30` argument was commented out of this call in the
    # original cell — confirm whether it should be restored.

In [6]:
# Display the dictionary returned by the forward pass above (keys include 'z' and 'x_rec').
# NOTE(review): this prints full tensors; consider showing only shapes to keep the output short.
results

{'z': [tensor([[ 0.2914,  0.1672, -0.4861,  ..., -0.3267, -0.0480, -0.3408],
          [-0.1464, -0.6487, -0.7244,  ...,  0.0053,  0.4496, -0.4386],
          [ 0.1726, -0.0591,  0.0678,  ...,  0.0322,  0.1208,  0.1369],
          ...,
          [ 0.1928, -0.1426,  0.1846,  ...,  0.7213,  0.5474,  0.2770],
          [-0.2040, -0.4464, -0.7286,  ...,  0.6933,  0.4375, -0.4019],
          [-0.1120, -0.1380, -0.3855,  ..., -0.6082, -0.1347,  0.0192]]),
  tensor([[-0.0898,  0.2724, -0.3209,  ...,  0.4118, -0.2057,  0.1090],
          [-0.0270,  0.0083, -0.2140,  ..., -0.0459,  0.3498,  0.1629],
          [-1.1466, -0.2157, -0.1386,  ..., -0.1205, -0.2268, -0.2199],
          ...,
          [-0.2823,  0.5071,  0.0437,  ...,  0.3191, -0.6092,  0.0818],
          [-0.4581,  0.2540,  0.3982,  ...,  0.2707, -0.1295,  0.0321],
          [-0.2472, -0.3399, -0.1706,  ..., -0.4798, -0.1646, -0.2468]])],
 'x_rec': [tensor([[[[0.1226, 0.1195, 0.1177,  ..., 0.1160, 0.1160, 0.1160],
            [0.1234