In [None]:
import glob
import os

import numpy as np

import torch
import torch.nn as nn
from torch.distributions import Normal, Laplace

import poisevae
from poisevae.datasets import MNIST_SVHN
from poisevae.networks.MNISTSVHNNetworks import EncMNIST, DecMNIST, EncSVHN, DecSVHN

import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression

# Figure defaults for publication-quality output: TrueType (type 42) fonts in PDF/PS,
# Times New Roman text at 20 pt, Computer Modern math glyphs, and no external LaTeX.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'normal',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

In [2]:
# The current user's home directory; the dataset paths below are built relative to it.
HOME_PATH = os.path.expanduser("~")

In [3]:
# Use the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Path templates: '%s' is filled with the split name ('train' / 'test').
MNIST_PATH = os.path.join(HOME_PATH, 'Datasets/MNIST/%s.pt')
SVHN_PATH = os.path.join(HOME_PATH, 'Datasets/SVHN/%s_32x32.mat')

# Build the paired MNIST-SVHN dataset for each split, train first, then test.
joint_dataset_train, joint_dataset_test = (
    MNIST_SVHN(mnist_pt_path=MNIST_PATH % split, svhn_mat_path=SVHN_PATH % split)
    for split in ('train', 'test')
)

In [5]:
batch_size = 1000

# drop_last=True keeps every batch at exactly `batch_size` samples; the model is constructed
# with batch_size=batch_size, presumably expecting fixed-size batches (TODO confirm).
loader = torch.utils.data.DataLoader(joint_dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Later cells evaluate on `test_loader`, which was previously never defined.
# Shuffle so the first batch mixes digit classes, in case the paired dataset is stored grouped
# by label (NOTE(review): unverified). The seeded generator makes that first batch reproducible.
test_generator = torch.Generator().manual_seed(0)
test_loader = torch.utils.data.DataLoader(joint_dataset_test, batch_size=batch_size, shuffle=True,
                                          drop_last=True, generator=test_generator)

len(loader), len(test_loader)

(234, 39)

In [7]:
lat1, lat2 = 16, 24

# One encoder/decoder pair per modality: MNIST uses `lat1` latent dims, SVHN uses `lat2`.
# Construction order is kept (MNIST enc, MNIST dec, SVHN enc, SVHN dec), so weight init draws
# from the RNG in the same sequence.
enc_mnist, dec_mnist = EncMNIST(lat1).to(device), DecMNIST(lat1).to(device)
enc_svhn, dec_svhn = EncSVHN(lat2).to(device), DecSVHN(lat2).to(device)

vae = poisevae.POISEVAE(
    [enc_mnist, enc_svhn],
    [dec_mnist, dec_svhn],
    likelihoods=[Laplace, Laplace],
    latent_dims=[lat1, (lat2, 1, 1)],
    batch_size=batch_size,
).to(device)

In [16]:
path = 'runs/MNIST_SVHN'

# Restore the lexicographically last 'train*.pt' checkpoint under `path`.
# (The original referenced an undefined `PATH` and an unimported `glob`.)
# NOTE(review): a plain string sort orders train_10.pt before train_9.pt; confirm that the
# checkpoint names are zero-padded, or this may not select the newest file.
checkpoints = sorted(glob.glob(os.path.join(path, 'train*.pt')))
if not checkpoints:
    raise FileNotFoundError(f"No checkpoint matching 'train*.pt' found in {path!r}")
vae, _, epoch = poisevae.utils.load_checkpoint(vae, load_path=checkpoints[-1])
epoch

In [17]:
# Run a single test batch through the model without tracking gradients.
# eval() switches any dropout/batch-norm layers to inference mode (a no-op if there are none).
vae.eval()
with torch.no_grad():
    # next(iter(...)) takes the first batch and fails loudly if the loader is empty.
    batch = next(iter(test_loader))
    label = batch[-1]
    # The first two batch elements are fed in the same order as the model's encoders
    # (presumably MNIST, then SVHN).
    data = [batch[0].to(device, dtype=torch.float32), batch[1].to(device, dtype=torch.float32)]
    results = vae(data)

In [None]:
# FIXME(review): this is meant to be a *latent* classifier, but the features are the raw
# first-modality (presumably MNIST) inputs of the test batch. Swap in the MNIST latent codes
# from `results` once the layout returned by POISEVAE.forward is confirmed.
# The original passed the whole `data` list (two differently-shaped device tensors), which
# sklearn cannot convert. sklearn needs a 2-D CPU array, so detach, move off the device and
# flatten each sample.
mnist_features = data[0].detach().cpu().numpy().reshape(len(data[0]), -1)

# multi_class is omitted: 'auto' is the default since scikit-learn 0.22, and the argument is
# deprecated as of 1.5.
MNIST_latent_clf = LogisticRegression(solver='lbfgs', max_iter=1000)
MNIST_latent_clf.fit(mnist_features, label)

In [None]:
# FIXME(review): this is meant to be a *latent* classifier, but the features are the raw
# second-modality (presumably SVHN) inputs of the test batch. Swap in the SVHN latent codes
# from `results` once the layout returned by POISEVAE.forward is confirmed.
# The original passed the whole `data` list (two differently-shaped device tensors), which
# sklearn cannot convert. sklearn needs a 2-D CPU array, so detach, move off the device and
# flatten each sample.
svhn_features = data[1].detach().cpu().numpy().reshape(len(data[1]), -1)

# multi_class is omitted: 'auto' is the default since scikit-learn 0.22, and the argument is
# deprecated as of 1.5.
SVHN_latent_clf = LogisticRegression(solver='lbfgs', max_iter=1000)
SVHN_latent_clf.fit(svhn_features, label)