In [1]:
import os
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.distributions import Normal, Laplace

import poisevae
from poisevae.datasets import MNIST_SVHN
from poisevae.networks.MNISTSVHNNetworks import EncMNIST, DecMNIST, EncSVHN, DecSVHN

import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

# Publication-style figure defaults: embed Type 42 (TrueType) fonts in PDF/PS output,
# use Times New Roman for text and Computer Modern for math, and render text without
# an external LaTeX installation.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'normal',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

In [2]:
# The dataset paths below are resolved relative to the current user's home directory.
HOME_PATH = os.path.expanduser("~")

In [3]:
# Run on the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Per-split file templates under the home directory; '%s' is filled with 'train' or 'test'.
MNIST_PATH, SVHN_PATH = (
    os.path.join(HOME_PATH, template)
    for template in ('Datasets/MNIST/%s.pt', 'Datasets/SVHN/%s_32x32.mat')
)

# Paired MNIST/SVHN datasets, constructed train split first, then test split.
joint_dataset_train, joint_dataset_test = (
    MNIST_SVHN(mnist_pt_path=MNIST_PATH % split, svhn_mat_path=SVHN_PATH % split)
    for split in ('train', 'test')
)

In [5]:
# Fixed batch size. The same value is passed to POISEVAE_Gibbs below, and drop_last=True
# keeps every batch at exactly this size (presumably required by the model — TODO confirm).
batch_size = 8000
train_loader = torch.utils.data.DataLoader(joint_dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)
# No shuffling for evaluation: combined with drop_last=True, a shuffled loader would score a
# different random subset of the test set on every run, making the accuracy irreproducible.
test_loader = torch.utils.data.DataLoader(joint_dataset_test, batch_size=batch_size, shuffle=False, drop_last=True)

In [6]:
# Latent dimensionality of the MNIST and SVHN branches.
lat1, lat2 = 32, 32

# Modules are built in the same order as before: MNIST encoder, MNIST decoder,
# SVHN encoder, SVHN decoder.
enc_mnist, dec_mnist = EncMNIST(lat1).to(device), DecMNIST(lat1).to(device)
enc_svhn, dec_svhn = EncSVHN(lat2).to(device), DecSVHN(lat2).to(device)

# Two-modality POISE-VAE with Laplace likelihoods; the SVHN latent is shaped (lat2, 1, 1).
vae = poisevae.POISEVAE_Gibbs(
    'autograd',
    [enc_mnist, enc_svhn],
    [dec_mnist, dec_svhn],
    likelihoods=[Laplace, Laplace],
    latent_dims=[lat1, (lat2, 1, 1)],
    enc_config='nu',
    KL_calc='derivative',
    batch_size=batch_size,
).to(device)

In [21]:
# SVM-RBF
def eval_clf(train_data, test_data):
    """Fit an RBF-kernel SVM on the training features and score it on the test features.

    Parameters
    ----------
    train_data, test_data : sequence
        ``(features, labels)`` pairs of torch tensors. Both items are moved to the CPU
        and converted to NumPy arrays before use.

    Returns
    -------
    tuple
        ``(classifier, accuracy)``: the fitted ``SVC`` and its accuracy on ``test_data``.
    """
    def as_numpy(pair):
        # Features first, labels second, matching the caller's [X, Y] convention.
        return pair[0].cpu().numpy(), pair[1].cpu().numpy()

    train_X, train_Y = as_numpy(train_data)
    test_X, test_Y = as_numpy(test_data)

    model = SVC(kernel='rbf')
    model.fit(train_X, train_Y)
    return model, accuracy_score(test_Y, model.predict(test_X))

# # Logistic Regression
# def eval_clf(train_data, test_data):
#     train_X, train_Y = train_data[0].cpu().numpy(), train_data[1].cpu().numpy()
#     test_X, test_Y = test_data[0].cpu().numpy(), test_data[1].cpu().numpy()
    
#     clf = LogisticRegression(random_state=0, solver='lbfgs', multi_class='auto', max_iter=1000)
#     clf.fit(train_X, train_Y)
#     Y_hat = clf.predict(test_X)
#     acc = accuracy_score(test_Y, Y_hat)
#     return clf, acc

# # One-hot linear model
# class LatentClassifier(nn.Module):
#     def __init__(self, lat_dim):
#         super(LatentClassifier, self).__init__()
#         self.mlp = nn.Linear(lat_dim, 10)

#     def forward(self, x):
#         return self.mlp(x)
    
# def eval_clf(train_data, test_data):
#     train_X, train_Y = train_data[0].to(device), train_data[1].to(device)#torch.nn.functional.one_hot(train_data[1], num_classes=10)
#     test_X, test_Y = test_data[0].to(device), test_data[1].to(device)#torch.nn.functional.one_hot(test_data[1], num_classes=10)
    
#     clf = LatentClassifier(train_X.shape[1]).to(device)
#     optimizer = torch.optim.Adam(clf.parameters(), lr=1e-3)
#     clf.train()
#     losses = []
#     for _ in  range(40):
#         for i in range(0, train_X.shape[0], 100):
#             optimizer.zero_grad()
#             Y_hat = clf(train_X[i:min(i+100, train_X.shape[0])])
#             loss = torch.nn.functional.cross_entropy(Y_hat, train_Y[i:min(i+100, train_X.shape[0])])
#             loss.backward() 
#             optimizer.step()
#             losses.append(loss.item())
#     plt.plot(losses)
#     clf.eval()
#     with torch.no_grad():
#         _, Y_hat = clf(test_X).max(dim=1)
#         return clf, (Y_hat == test_Y).sum().item() / test_Y.shape[0]

In [22]:
def get_latents(loader, model=None):
    """Encode every batch of ``loader`` and return the latent means of both modalities.

    For each modality the mean is ``-(nu1 + t1) / (2 * (t2 + nu2))``, built from the
    encoder outputs ``nu1, nu2`` and the parameters returned by ``model.get_t()``. When the
    encoder returns ``None`` for a modality, only ``t1`` and ``t2`` are used.

    Gradients are disabled internally, so no autograd graph is kept across the dataset
    even when the caller forgets ``torch.no_grad()``.

    Parameters
    ----------
    loader : iterable
        Yields batches whose first two items are the MNIST and SVHN inputs and whose last
        item is the label tensor. Inputs are cast to float32 and moved to the notebook-level
        ``device``.
    model : optional
        Object exposing ``encode`` and ``get_t``. Defaults to the notebook-level ``vae``.

    Returns
    -------
    tuple
        ``(latents, labels)``. ``latents`` is ``{'mnist': Tensor, 'svhn': Tensor}``, with the
        per-batch means concatenated along dim 0; ``labels`` holds the concatenated labels.

    Raises
    ------
    ValueError
        If ``loader`` yields no batches.
    """
    if model is None:
        model = vae  # notebook-level model; reassigned by the checkpoint-loading cell
    latents, labels = {'mnist': [], 'svhn': []}, []
    with torch.no_grad():
        for batch in loader:
            labels.append(batch[-1])
            inputs = [batch[0].to(device, dtype=torch.float32),
                      batch[1].to(device, dtype=torch.float32)]
            nu1, nu2, _, _ = model.encode(inputs)
            t1, t2 = model.get_t()
            mu = []
            for nu1_i, nu2_i, t1_i, t2_i in zip(nu1, nu2, t1, t2):
                # Assumes nu1_i and nu2_i are either both None or both tensors.
                if nu1_i is None and nu2_i is None:
                    mu.append(-torch.reciprocal(2 * t2_i) * t1_i)
                else:
                    mu.append(-torch.reciprocal(2 * (t2_i + nu2_i)) * (nu1_i + t1_i))
            latents['mnist'].append(mu[0])
            latents['svhn'].append(mu[1])
    if not labels:
        raise ValueError('loader yielded no batches')
    y = torch.cat(labels, 0)
    x = {name: torch.cat(chunks, 0) for name, chunks in latents.items()}
    print(x['svhn'].shape)
    return x, y

In [23]:
# paths = glob.glob('../example/runs/MNIST_SVHN/fix_t/22*')
# A ``None`` entry loads the checkpoint from the current working directory; any other
# entry is a run directory that contains the checkpoint file.
paths = [None]
CHECKPOINT_NAME = 'training_200.pt'

clf_models = {'model': [], 'latent space': [], 'accuracy': [], 'classifier': []}

for path in paths:
    # Resolve the checkpoint per run directory so each entry of ``paths`` is actually used.
    load_path = CHECKPOINT_NAME if path is None else os.path.join(path, CHECKPOINT_NAME)
    vae, _, epoch = poisevae.utils.load_checkpoint(vae, load_path=load_path)
    # NOTE(review): the model is not switched to eval mode here; if the encoders use
    # dropout/batch-norm and load_checkpoint does not do it, call vae.eval() — TODO confirm.
    with torch.no_grad():
        clf_train_x, clf_train_y = get_latents(train_loader)
        clf_test_x, clf_test_y = get_latents(test_loader)

    for lat_space in ('mnist', 'svhn'):
        # Fit first and record afterwards, so the result lists stay the same length
        # even if fitting raises.
        clf_fit, clf_acc = eval_clf([clf_train_x[lat_space], clf_train_y],
                                    [clf_test_x[lat_space], clf_test_y])
        clf_models['model'].append('POISE-VAE')
        clf_models['latent space'].append(lat_space)
        clf_models['classifier'].append(clf_fit)
        clf_models['accuracy'].append(clf_acc)

torch.Size([56000, 32])
torch.Size([8000, 32])


In [24]:
# One row per (model, latent space), with the SVM accuracy and the fitted classifier.
pd.DataFrame.from_dict(clf_models)

Unnamed: 0,model,latent space,accuracy,classifier
0,POISE-VAE,mnist,0.929125,SVC()
1,POISE-VAE,svhn,0.68875,SVC()
