In [None]:
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from datetime import datetime
import os
from os import path
from tempfile import TemporaryFile
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable
import pycuda
from pycuda import compiler
import pycuda.driver as drv
from Datasets import Noisy_MNIST_Dataset, Tangled_MNIST_Dataset
from Plotting import plot_embeddings_single, plot_embeddings_private, display_reconstructions, save_disentangling_curves_single, save_disentangling_curves_private, grid_plot2d_single, grid_plot2d_private, plot_3d_embeddings, display_generated_images
from Nets import ACCA_Single, ACCA_Private, VCCA_Single, VCCA_Private, Discriminator_Really_Small, beefy_decoder, encoder
from Model_Training import train_acca_single, train_acca_private, train_vcca_single, train_vcca_private
from sklearn.svm import LinearSVC, LinearSVR
from sklearn.datasets import make_classification
from sklearn.neighbors import KernelDensity
from scipy.stats import multivariate_normal
import json

In [None]:
# General Parameters
MNIST_type = 'MNIST' # choose from {'MNIST', 'FashionMNIST', 'KMNIST'}
cuda = True # change to False if not using GPUs
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
# Derive the device from the `cuda` flag so that cuda = False really runs on the CPU
# (previously the device stayed "cuda:0" regardless of the flag).
device = "cuda:0" if cuda else "cpu"

# Create Datasets
train_dataset = Tangled_MNIST_Dataset(mnist_type=MNIST_type, train=True)
test_dataset = Tangled_MNIST_Dataset(mnist_type=MNIST_type, train=False)

# NOTE: despite its name, validation_loader is NOT an out-of-sample split. It re-reads the
# *training* set (unshuffled, batches of 50000) and is used to evaluate the information
# content of the learned representation.
validation_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=50000, shuffle=False, **kwargs)
# Later cells take a single next(iter(test_loader)) batch of up to 10000 test samples.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=10000, shuffle=True, **kwargs)

num_epochs = 100

# Experiment 1: No Private Info

## 5 dimensions

### ACCA

In [None]:
"""
Experiment 1.2 ACCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 5
    q(z|x,y) not just q(z|x)
Plots:
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations 
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 1000
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 5
dropout_prob = 0.0
best_results = {}
disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
recon_loss = 'L2'
results_path = './results/1/2/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
acca_single5 = ACCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob)
discriminator_z_single5 = Discriminator_Really_Small(z_dim, disc_multiplier)  #max(z_dim, 6))
if torch.cuda.device_count() > 1:
    acca_single5 = nn.DataParallel(acca_single5)

# Train Model - returns (best_accuracy, best_epoch, best_state_dict, results, ae.state_dict(), best_acc_state_dict, models)
best_result, best_epoch, acca5_best_state_dict, results, last_state_dict, acca5_best_acc_state_dict, acca5_models = train_acca_single(acca_single5.to(device), z_dim, discriminator_z_single5.to(device), train_loader, validation_loader, test_loader, num_epochs, recon_loss, device)

""" Generate/Save Plots
    1. reconstructions
    2. training curves - 3 losses, class, rota, rotb
    3. random generations
"""
# Load best model
acca_single5.load_state_dict(acca5_best_state_dict)
acca_single5.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))
# Gather embeddings and reconstructions
z = acca_single5.module.encode(x.to(device),y.to(device))
x_hat, y_hat = acca_single5.module.decode(z)
z, labels, rot_x, rot_y, x_hat, y_hat = z.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. reconstructions
display_reconstructions(x, x_hat, y, y_hat, results_path + 'acca_reconstructions.png')

# 2. training curves
save_disentangling_curves_single(results, results_path + 'acca_training_curves.png')

# 3. random generations
z_fake = torch.cuda.FloatTensor(np.random.normal(0, 1, (18, z_dim)))
x_fake, y_fake = acca_single5.module.decode(z_fake)
z_fake, x_fake, y_fake = z_fake.detach().cpu().numpy(), x_fake.detach().cpu().numpy(), y_fake.detach().cpu().numpy()
display_generated_images(x_fake, y_fake, results_path + 'acca_random_generations.png')

### VCCA

In [None]:
"""
Experiment 1.1 VCCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 5
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 1000
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 5
dropout_prob = 0.0
best_results = {}
results_path = './results/1/2/vcca'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
vcca_single5 = VCCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob)
if torch.cuda.device_count() > 1:
    vcca_single5 = nn.DataParallel(vcca_single5)

best_result, best_epoch, vcca5_best_state_dict, results, last_state_dict, vcca5_best_acc_state_dict, vcca5_models = train_vcca_single(vcca_single5.to(device), z_dim, train_loader, validation_loader, test_loader, num_epochs, device)

""" Generate/Save Plots
    1. reconstructions
    2. training curves - 3 losses, class, rota, rotb
    3. random generations
"""
# Load best model
vcca_single5.load_state_dict(vcca5_best_state_dict)
vcca_single5.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))

# Gather embeddings and reconstructions
x_hat, y_hat, z_mu, z_logvar, z = vcca_single5(x.to(device),y.to(device))
z, labels, rot_x, rot_y, x_hat, y_hat = z.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. reconstructions
display_reconstructions(x, x_hat, y, y_hat, results_path + 'vcca_reconstructions.png')

# 2. training curves
save_disentangling_curves_single(results, results_path + 'vcca_training_curves.png')

# 3. random generations
z_fake = torch.cuda.FloatTensor(np.random.normal(0, 1, (18, z_dim)))
x_fake, y_fake = vcca_single5.module.decode(z_fake)
z_fake, x_fake, y_fake = z_fake.detach().cpu().numpy(), x_fake.detach().cpu().numpy(), y_fake.detach().cpu().numpy()
display_generated_images(x_fake, y_fake, results_path + 'vcca_random_generations.png')


## 2 dimensions

### ACCA

In [None]:
"""
Experiment 1.1 ACCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 2
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations over (-3,3)x(-3,3)
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 1000
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 2
dropout_prob = 0.0
best_results = {}
disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
recon_loss = 'L2'
results_path = './results/1/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
acca_single2 = ACCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob)
discriminator_z_single2 = Discriminator_Really_Small(z_dim, disc_multiplier)  #max(z_dim, 6))
if torch.cuda.device_count() > 1:
    acca_single2 = nn.DataParallel(acca_single2)

# Train Model
best_result, best_epoch, acca2_best_state_dict, acca2_results, acca2_last_state_dict, acca2_best_acc_state_dict, acca2_models = train_acca_single(acca_single2.to(device), z_dim, discriminator_z_single2.to(device), train_loader, validation_loader, test_loader, num_epochs, recon_loss, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
    4. random generations over (-4,4)x(-4,4)
"""
# Load best model
acca_single2.load_state_dict(acca2_best_acc_state_dict)
acca_single2.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))

# Gather embeddings and reconstructions
z = acca_single2.module.encode(x.to(device),y.to(device))
x_hat, y_hat = acca_single2.module.decode(z)
z_acca, labels_acca, rot_x_acca, rot_y_acca, x_hat_acca, y_hat_acca = z.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. 3 embedding plots - with class, rota, and rotb coloring
# test
plot_embeddings_single(z_acca,labels_acca,rot_x_acca,rot_y_acca,results_path + 'acca_embeddings.png')

# 2. reconstructions
display_reconstructions(x, x_hat_acca, y, y_hat_acca, results_path + 'acca_reconstructions.png')

# 3. training curves
save_disentangling_curves_single(acca2_results, results_path + 'acca_training_curves.png')

# 4. random generations over (-2.5,2.5)x(-2.5,2.5)
grid_plot2d_single(acca_single2, results_path + 'acca_x_generations.png', output_view_name='x')
grid_plot2d_single(acca_single2, results_path + 'acca_y_generations.png', output_view_name='y')

### VCCA

In [None]:
"""
Experiment 1.1 VCCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 2
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations over (-3,3)x(-3,3)
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 1000
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 2
dropout_prob = 0.0
best_results = {}
results_path = './results/1/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
vcca_single2 = VCCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob)
if torch.cuda.device_count() > 1:
    vcca_single2 = nn.DataParallel(vcca_single2)

best_result, best_epoch, vcca2_best_state_dict, vcca2_results, last_state_dict, vcca2_best_acc_state_dict, vcca2_models = train_vcca_single(vcca_single2.to(device), z_dim, train_loader, validation_loader, test_loader, num_epochs, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
    4. random generations over (-3,3)x(-3,3)
"""
# Load best model
vcca_single2.load_state_dict(vcca2_best_state_dict)
vcca_single2.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))

# Gather embeddings and reconstructions
x_hat, y_hat, z_mu, z_logvar, z = vcca_single2(x.to(device),y.to(device))
z_vcca, labels_vcca, rot_x_vcca, rot_y_vcca, x_hat_vcca, y_hat_vcca = z_mu.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. 3 embedding plots - with class, rota, and rotb coloring
# test
plot_embeddings_single(z_vcca,labels_vcca,rot_x_vcca,rot_y_vcca,results_path + 'vcca_embeddings.png')

# 2. reconstructions
display_reconstructions(x, x_hat_vcca, y, y_hat_vcca, results_path + 'vcca_reconstructions.png')

# 3. training curves
save_disentangling_curves_single(vcca2_results, results_path + 'vcca_training_curves.png')

# 4. random generations over (-2.5,2.5)x(-2.5,2.5)
grid_plot2d_single(vcca_single2, results_path + 'vcca_x_generations.png', output_view_name='x')
grid_plot2d_single(vcca_single2, results_path + 'vcca_y_generations.png', output_view_name='y')

### Goodness of Fit Analysis

In [None]:
# Goodness of fit: the N(0, I) prior's log-density next to Gaussian-KDE log-density estimates
# of the 2-D VCCA and ACCA test-set embeddings (z_vcca / z_acca from the cells above).
x_flat = np.linspace(-4, 4, 100)
y_flat = np.linspace(-4, 4, 100)
xi, yi = np.meshgrid(x_flat, y_flat)
grid_data = np.vstack([xi.flatten(), yi.flatten()]).T
levels = np.arange(-10, 0, 0.5)  # log-density contour levels

# Reference Gaussian
rv = multivariate_normal([0, 0], [[1, 0.], [0., 1]])
gaussiani = rv.logpdf(grid_data)

# KDE of each model's embeddings, evaluated on the same grid
kde_vcca = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z_vcca)
zi_vcca = kde_vcca.score_samples(grid_data)
kde_acca = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z_acca)
zi_acca = kde_acca.score_samples(grid_data)

panels = [
    ('N(0,I) Log Probability Contours', gaussiani),
    ('VCCA Embeddings on Out of Sample Data', zi_vcca),
    ('ACCA Embeddings on Out of Sample Data', zi_acca),
]

# One plot area per panel, all drawn identically
fig, axes = plt.subplots(ncols=len(panels), nrows=1, figsize=(21, 5))
for ax, (title, log_density) in zip(axes, panels):
    surface = log_density.reshape(xi.shape)
    ax.set_title(title)
    mesh = ax.pcolormesh(xi, yi, surface, shading='gouraud', cmap='plasma')
    ax.contour(xi, yi, surface, levels=levels)
    ax.axis('equal')
    ax.set_aspect('equal', 'box')
    ax.set(xlim=(-4, 4), ylim=(-4, 4))
    fig.colorbar(mesh, ax=ax)

fig.tight_layout()

## Beefy Decoder Experiments

### ACCA

In [None]:
"""
Experiment 1.1 ACCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 2
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations over (-3,3)x(-3,3)
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 400
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 2
dropout_prob = 0.0
best_results = {}
disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
recon_loss = 'L2'
results_path = './results/1/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
acca_single2 = ACCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob, encoder_function=encoder, decoder_function=beefy_decoder)
discriminator_z_single2 = Discriminator_Really_Small(z_dim, disc_multiplier)  #max(z_dim, 6))
if torch.cuda.device_count() > 1:
    acca_single2 = nn.DataParallel(acca_single2)

# Train Model
best_result, best_epoch, acca2_best_state_dict, acca2_results, acca2_last_state_dict, acca2_best_acc_state_dict, acca2_models = train_acca_single(acca_single2.to(device), z_dim, discriminator_z_single2.to(device), train_loader, validation_loader, test_loader, num_epochs, recon_loss, device)

discriminator_z_single2.cpu()
del discriminator_z_single2
torch.cuda.empty_cache()

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
    4. random generations over (-4,4)x(-4,4)
"""

# Load best model
acca_single2.load_state_dict(acca2_best_acc_state_dict)
acca_single2.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))
# x, y, rot_x, rot_y, labels = next(iter(validation_loader))

# Gather embeddings and reconstructions
z = acca_single2.module.encode(x.to(device),y.to(device))
x_hat, y_hat = acca_single2.module.decode(z)
z_acca, labels_acca, rot_x_acca, rot_y_acca, x_hat_acca, y_hat_acca = z.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. 3 embedding plots - with class, rota, and rotb coloring
# test
plot_embeddings_single(z_acca,labels_acca,rot_x_acca,rot_y_acca,results_path + 'acca_embeddings.png')

# 2. reconstructions
display_reconstructions(x, x_hat_acca, y, y_hat_acca, results_path + 'acca_reconstructions.png')

# 3. training curves
save_disentangling_curves_single(acca2_results, results_path + 'acca_training_curves.png')

# 4. random generations over (-2.5,2.5)x(-2.5,2.5)
grid_plot2d_single(acca_single2, results_path + 'acca_x_generations.png', output_view_name='x')
grid_plot2d_single(acca_single2, results_path + 'acca_y_generations.png', output_view_name='y')

del acca_single2
x.detach().cpu()
y.detach().cpu()
z.detach().cpu()
del x
del y
del x_hat
del y_hat
del z
torch.cuda.empty_cache()

### VCCA

In [None]:
"""
Experiment 1.1 VCCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 2
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations over (-3,3)x(-3,3)
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 300
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 2
dropout_prob = 0.0
best_results = {}
results_path = './results/1/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
vcca_single2 = VCCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob, encoder_function=encoder, decoder_function=beefy_decoder)
if torch.cuda.device_count() > 1:
    vcca_single2 = nn.DataParallel(vcca_single2)

best_result, best_epoch, vcca2_best_state_dict, vcca2_results, last_state_dict, vcca2_best_acc_state_dict, vcca2_models = train_vcca_single(vcca_single2.to(device), z_dim, train_loader, validation_loader, test_loader, num_epochs, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
    4. random generations over (-3,3)x(-3,3)
"""
# Load best model
vcca_single2.load_state_dict(vcca2_best_state_dict)
vcca_single2.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))
# x, y, rot_x, rot_y, labels = next(iter(validation_loader))

# Gather embeddings and reconstructions
x_hat, y_hat, z_mu, z_logvar, z = vcca_single2(x.to(device),y.to(device))
z_vcca, labels_vcca, rot_x_vcca, rot_y_vcca, x_hat_vcca, y_hat_vcca = z_mu.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. 3 embedding plots - with class, rota, and rotb coloring
# test
plot_embeddings_single(z_vcca,labels_vcca,rot_x_vcca,rot_y_vcca,results_path + 'vcca_embeddings.png')

# 2. reconstructions
display_reconstructions(x, x_hat_vcca, y, y_hat_vcca, results_path + 'vcca_reconstructions.png')

# 3. training curves
save_disentangling_curves_single(vcca2_results, results_path + 'vcca_training_curves.png')

# 4. random generations over (-2.5,2.5)x(-2.5,2.5)
grid_plot2d_single(vcca_single2, results_path + 'vcca_x_generations.png', output_view_name='x')
grid_plot2d_single(vcca_single2, results_path + 'vcca_y_generations.png', output_view_name='y')

del vcca_single2
x.detach().cpu()
y.detach().cpu()
del x
del y
del x_hat
del y_hat
del z_mu
del z_logvar
del z
torch.cuda.empty_cache()

### Beefy Goodness of Fit Analysis

In [None]:
# Goodness of fit (beefy decoder runs): the N(0, I) prior's log-density next to Gaussian-KDE
# log-density estimates of the 2-D VCCA and ACCA test-set embeddings (z_vcca / z_acca).
x_flat = np.linspace(-4, 4, 100)
y_flat = np.linspace(-4, 4, 100)
xi, yi = np.meshgrid(x_flat, y_flat)
grid_data = np.vstack([xi.flatten(), yi.flatten()]).T
levels = np.arange(-10, 0, 0.5)  # log-density contour levels

# Reference Gaussian
rv = multivariate_normal([0, 0], [[1, 0.], [0., 1]])
gaussiani = rv.logpdf(grid_data)

# KDE of each model's embeddings, evaluated on the same grid
kde_vcca = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z_vcca)
zi_vcca = kde_vcca.score_samples(grid_data)
kde_acca = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z_acca)
zi_acca = kde_acca.score_samples(grid_data)

panels = [
    ('N(0,I) Log Probability Contours', gaussiani),
    ('VCCA Embeddings on Out of Sample Data', zi_vcca),
    ('ACCA Embeddings on Out of Sample Data', zi_acca),
]

# One plot area per panel, all drawn identically
fig, axes = plt.subplots(ncols=len(panels), nrows=1, figsize=(21, 5))
for ax, (title, log_density) in zip(axes, panels):
    surface = log_density.reshape(xi.shape)
    ax.set_title(title)
    mesh = ax.pcolormesh(xi, yi, surface, shading='gouraud', cmap='plasma')
    ax.contour(xi, yi, surface, levels=levels)
    ax.axis('equal')
    ax.set_aspect('equal', 'box')
    ax.set(xlim=(-4, 4), ylim=(-4, 4))
    fig.colorbar(mesh, ax=ax)

fig.tight_layout()

# Experiment 2: Private Variables

## 6 dimensions (z, hx, hy: 2 each)

### VCCA_Private

In [None]:
"""
Experiment 2.1 VCCA_Private on Tangled MNIST
Model params:
    No dropout
    z_dim = 2
    hx_dim = 2
    hy_dim = 2
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 300
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 1 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 2
hx_dim = 2
hy_dim = 2
dropout_prob = 0.0
best_results = {}
results_path = './results/2/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
vcca_private = VCCA_Private(z_dim=z_dim, num_z_inputs=num_z, hx_dim=hx_dim, hy_dim=hy_dim, dropout_prob=dropout_prob)

if torch.cuda.device_count() > 1:
    vcca_private = nn.DataParallel(vcca_private)

# Train Model
best_result, best_epoch, vccap_best_state_dict, vccap_results, vccap_last_state_dict, vccap_best_acc_state_dict, vccap_state_dict_list = train_vcca_private(vcca_private.to(device), z_dim, hx_dim, hy_dim, train_loader, validation_loader, test_loader, num_epochs, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
"""
# Load best model
vcca_private.load_state_dict(vccap_best_state_dict)
vcca_private.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))
x_val, y_val, rot_x_val, rot_y_val, labels_val = next(iter(validation_loader))

# Gather embeddings and reconstructions
x_hat, y_hat, z_mu, z_logvar, hx_mu, hx_logvar, hy_mu, hy_logvar, z, hx, hy = vcca_private(x,y)
z, hx, hy, labels, rot_x, rot_y, x_hat, y_hat = z_mu.detach().cpu().numpy(), hx_mu.detach().cpu().numpy(), hy_mu.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. 3 embedding plots - with class, rota, and rotb coloring
# test
plot_embeddings_private(z, hx, hy, labels, results_path + 'vcca_embeddings.png')

# 2. reconstructions
display_reconstructions(x, x_hat, y, y_hat, results_path + 'vcca_reconstructions.png')

# 3. training curves
save_disentangling_curves_private(vccap_results, results_path + 'vcca_training_curves.png')

# 4. embeddings plots comparing info content
plot_embeddings_single(z, labels, rot_x, rot_y, results_path + 'vcca_z_embeddings.png')
plot_embeddings_single(hx, labels, rot_x, rot_y, results_path + 'vcca_hx_embeddings.png')
plot_embeddings_single(hy, labels, rot_x, rot_y, results_path + 'vcca_hy_embeddings.png')

# 5. Goodness of Fit
x_flat = np.linspace(-4, 4, 100)
y_flat = np.linspace(-4, 4, 100)
xi, yi = np.meshgrid(x_flat, y_flat)

# Create a figure with 4 plot areas
fig, axes = plt.subplots(ncols=4, nrows=1, figsize=(21, 5))
nbins = 100
grid_data = np.vstack([xi.flatten(), yi.flatten()]).T
levels = np.arange(-10,0,0.5)

# Gaussian plot
rv = multivariate_normal([0, 0], [[1, 0.], [0., 1]])
gaussiani = rv.logpdf(grid_data)
axes[0].set_title('N(0,I) Log Probibility Contours')
im0 = axes[0].pcolormesh(xi, yi, gaussiani.reshape(xi.shape), shading='gouraud', cmap='plasma')
axes[0].contour(xi, yi, gaussiani.reshape(xi.shape), levels=levels)
axes[0].axis('equal')
axes[0].set_aspect('equal', 'box')
axes[0].set(xlim=(-4, 4), ylim=(-4, 4))
fig.colorbar(im0, ax=axes[0])

# z Plot
kde_z = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z)
z_scores = kde_z.score_samples(grid_data)
axes[1].set_title('z Embeddings on Out of Sample Data')
im1 = axes[1].pcolormesh(xi, yi, z_scores.reshape(xi.shape), shading='gouraud', cmap='plasma')
axes[1].contour(xi, yi, z_scores.reshape(xi.shape), levels=levels)
axes[1].axis('equal')
axes[1].set_aspect('equal', 'box')
axes[1].set(xlim=(-4, 4), ylim=(-4, 4))
fig.colorbar(im1, ax=axes[1])

# hx Plot
kde_hx = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(hx)
hx_scores = kde_hx.score_samples(grid_data)
axes[2].set_title('hx Embeddings on Out of Sample Data')
im2 = axes[2].pcolormesh(xi, yi, hx_scores.reshape(xi.shape), shading='gouraud', cmap='plasma')
axes[2].contour(xi, yi, hx_scores.reshape(xi.shape), levels=levels)
axes[2].axis('equal')
axes[2].set_aspect('equal', 'box')
axes[2].set(xlim=(-4, 4), ylim=(-4, 4))
fig.colorbar(im2, ax=axes[2])

# hy Plot
kde_hy = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(hy)
hy_scores = kde_hy.score_samples(grid_data)
axes[3].set_title('hy Embeddings on Out of Sample Data')
im3 = axes[3].pcolormesh(xi, yi, hy_scores.reshape(xi.shape), shading='gouraud', cmap='plasma')
axes[3].contour(xi, yi, hy_scores.reshape(xi.shape), levels=levels)
axes[3].axis('equal')
axes[3].set_aspect('equal', 'box')
axes[3].set(xlim=(-4, 4), ylim=(-4, 4))
fig.colorbar(im3, ax=axes[3])

fig.tight_layout()

### ACCA-Private

In [None]:
"""
Experiment 2.1 ACCA_Private on Tangled MNIST
Model params:
    No dropout
    z_dim = 2
    hx_dim = 2
    hy_dim = 2
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 2
batch_size = 300
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 1 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 2
hx_dim = 2
hy_dim = 2
dropout_prob = 0.0
best_results = {}
disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
recon_loss = 'L2'
results_path = './results/2/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
acca_private = ACCA_Private(z_dim=z_dim, num_z_inputs=num_z, hx_dim=hx_dim, hy_dim=hy_dim, dropout_prob=dropout_prob)
discriminator_z = Discriminator_Really_Small(z_dim, disc_multiplier)  #max(z_dim, 6))
discriminator_hx = Discriminator_Really_Small(hx_dim, disc_multiplier)
discriminator_hy = Discriminator_Really_Small(hy_dim, disc_multiplier)

if torch.cuda.device_count() > 1:
    acca_private = nn.DataParallel(acca_private)

# Train Model
best_result, best_epoch, accap_best_state_dict, accap_results, accap_last_state_dict, accap_best_acc_state_dict, accap_state_dict_list = train_acca_private(acca_private.to(device), z_dim, hx_dim, hy_dim, [discriminator_z.to(device), discriminator_hx.to(device), discriminator_hy.to(device)], train_loader, validation_loader, test_loader, num_epochs, recon_loss, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
"""
# Load best model - accap_best_state_dict, accap_last_state_dict, accap_best_acc_state_dict, accap_state_dict_list
acca_private.load_state_dict(accap_best_acc_state_dict)
acca_private.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))
x_val, y_val, rot_x_val, rot_y_val, labels_val = next(iter(validation_loader))

# Gather embeddings and reconstructions
z_ten, hx_ten, hy_ten = acca_private.module.encode(x.to(device),y.to(device))
x_hat, y_hat = acca_private.module.decode(z_ten,hx_ten,hy_ten)
z, hx, hy, labels, rot_x, rot_y, x_hat, y_hat = z_ten.detach().cpu().numpy(), hx_ten.detach().cpu().numpy(), hy_ten.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. Joint embedding figure for z, hx and hy on the test batch (given the class labels)
plot_embeddings_private(z, hx, hy, labels, results_path + 'acca_embeddings.png')
# Alternative kept for reference:
# plot_embeddings_experiment3b(z, hx, hy, labels, rot_x, rot_y, results_path + 'acca_embeddings.png')

# 2. Test-batch reconstructions of both views
display_reconstructions(x, x_hat, y, y_hat, results_path + 'acca_reconstructions.png')

# 3. Training history returned by train_acca_private
save_disentangling_curves_private(accap_results, results_path + 'acca_training_curves.png')

# 4. One figure per subspace, each given the class labels and both rotations,
#    to compare the information content of z, hx and hy
for _subspace, _filename in ((z, 'acca_z_embeddings.png'),
                             (hx, 'acca_hx_embeddings.png'),
                             (hy, 'acca_hy_embeddings.png')):
    plot_embeddings_single(_subspace, labels, rot_x, rot_y, results_path + _filename)
del _subspace, _filename

# 5. Goodness of Fit: log-density of the N(0, I) prior next to Gaussian-KDE
#    log-density estimates of the z, hx and hy test embeddings, all evaluated
#    on the same 100 x 100 grid over [-4, 4]^2.
# NOTE(review): the KDEs are scored on 2-D grid points, so this only works if
# z, hx and hy are 2-dimensional in this experiment - TODO confirm.
x_flat = np.linspace(-4, 4, 100)
y_flat = np.linspace(-4, 4, 100)
xi, yi = np.meshgrid(x_flat, y_flat)
grid_data = np.vstack([xi.flatten(), yi.flatten()]).T  # (10000, 2) grid points
levels = np.arange(-10, 0, 0.5)  # log-density contour levels

# Reference: log-density of the standard bivariate normal
rv = multivariate_normal([0, 0], [[1, 0.], [0., 1]])
gaussiani = rv.logpdf(grid_data)

# Gaussian KDEs (bandwidth 0.2) fitted to each embedding, scored on the grid
kde_z = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(z)
kde_hx = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(hx)
kde_hy = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(hy)
z_scores = kde_z.score_samples(grid_data)
hx_scores = kde_hx.score_samples(grid_data)
hy_scores = kde_hy.score_samples(grid_data)

# Four side-by-side panels sharing the grid, contour levels and axis limits
fig, axes = plt.subplots(ncols=4, nrows=1, figsize=(21, 5))
panels = (
    ('N(0,I) Log Probability Contours', gaussiani),
    ('z Embeddings on Out of Sample Data', z_scores),
    ('hx Embeddings on Out of Sample Data', hx_scores),
    ('hy Embeddings on Out of Sample Data', hy_scores),
)
for ax, (title, log_density) in zip(axes, panels):
    surface = log_density.reshape(xi.shape)
    ax.set_title(title)
    mesh = ax.pcolormesh(xi, yi, surface, shading='gouraud', cmap='plasma')
    ax.contour(xi, yi, surface, levels=levels)
    ax.axis('equal')
    ax.set_aspect('equal', 'box')
    ax.set(xlim=(-4, 4), ylim=(-4, 4))
    fig.colorbar(mesh, ax=ax)

fig.tight_layout()
# Persist the figure like every other plot of this experiment (it was only displayed before)
fig.savefig(results_path + 'acca_goodness_of_fit.png')

## 12 Latent Dims (z_dim = hx_dim = hy_dim = 4)

### ACCA-Private

In [None]:
"""
Experiment 2.1 ACCA_Private on Tangled MNIST
Model params:
    No dropout
    z_dim = 4
    hx_dim = 4
    hy_dim = 4
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 2
batch_size = 300
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 1 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 4
hx_dim = 4
hy_dim = 4
dropout_prob = 0.0
best_results = {}
disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
recon_loss = 'L2'
results_path = './results/2/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
acca_private = ACCA_Private(z_dim=z_dim, num_z_inputs=num_z, hx_dim=hx_dim, hy_dim=hy_dim, dropout_prob=dropout_prob)
discriminator_z = Discriminator_Really_Small(z_dim, disc_multiplier)  #max(z_dim, 6))
discriminator_hx = Discriminator_Really_Small(hx_dim, disc_multiplier)
discriminator_hy = Discriminator_Really_Small(hy_dim, disc_multiplier)

if torch.cuda.device_count() > 1:
    acca_private = nn.DataParallel(acca_private)

# Train Model
best_result, best_epoch, accap_best_state_dict, accap_results, accap_last_state_dict, accap_best_acc_state_dict, accap_state_dict_list = train_acca_private(acca_private.to(device), z_dim, hx_dim, hy_dim, [discriminator_z.to(device), discriminator_hx.to(device), discriminator_hy.to(device)], train_loader, validation_loader, test_loader, num_epochs, recon_loss, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
"""
# Load best model - accap_best_state_dict, accap_last_state_dict, accap_best_acc_state_dict, accap_state_dict_list
acca_private.load_state_dict(accap_best_acc_state_dict)
acca_private.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))

# Gather reconstructions and curves
z_ten, hx_ten, hy_ten = acca_private.module.encode(x.to(device),y.to(device))
x_hat, y_hat = acca_private.module.decode(z_ten,hx_ten,hy_ten)
z, hx, hy, labels, rot_x, rot_y, x_hat, y_hat = z_ten.detach().cpu().numpy(), hx_ten.detach().cpu().numpy(), hy_ten.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 2. reconstructions
display_reconstructions(x, x_hat, y, y_hat, results_path + 'acca_reconstructions.png')

# 3. training curves
save_disentangling_curves_private(accap_results, results_path + 'acca_training_curves.png')

### VCCA-Private

In [None]:
"""
Experiment 2.1 VCCA_Private on Tangled MNIST
Model params:
    No dropout
    z_dim = 4
    hx_dim = 4
    hy_dim = 4
    q(z|x,y) not just q(z|x)
Plots:
    3 embedding plots - with class, rota, and rotb coloring
    reconstructions
    training curves - 3 losses, class, rota, rotb
    want to demonstrate goodness of fit argument
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 300
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 1 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 4
hx_dim = 4
hy_dim = 4
dropout_prob = 0.0
best_results = {}
results_path = './results/2/1/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
vcca_private = VCCA_Private(z_dim=z_dim, num_z_inputs=num_z, hx_dim=hx_dim, hy_dim=hy_dim, dropout_prob=dropout_prob)

if torch.cuda.device_count() > 1:
    vcca_private = nn.DataParallel(vcca_private)

# Train Model
best_result, best_epoch, vccap_best_state_dict, vccap_results, vccap_last_state_dict, vccap_best_acc_state_dict, vccap_state_dict_list = train_vcca_private(vcca_private.to(device), z_dim, hx_dim, hy_dim, train_loader, validation_loader, test_loader, num_epochs, device)

""" Generate/Save Plots
    1. 3 embedding plots - with class, rota, and rotb coloring
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
"""
# Load best model
vcca_private.load_state_dict(vccap_best_state_dict)
vcca_private.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))
x_val, y_val, rot_x_val, rot_y_val, labels_val = next(iter(validation_loader))

# Gather embeddings and reconstructions
x_hat, y_hat, z_mu, z_logvar, hx_mu, hx_logvar, hy_mu, hy_logvar, z, hx, hy = vcca_private(x,y)
z, hx, hy, labels, rot_x, rot_y, x_hat, y_hat = z_mu.detach().cpu().numpy(), hx_mu.detach().cpu().numpy(), hy_mu.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 2. reconstructions
display_reconstructions(x, x_hat, y, y_hat, results_path + 'vcca_reconstructions.png')

# 3. training curves
save_disentangling_curves_private(vccap_results, results_path + 'vcca_training_curves.png')

# Experiment 3: S-Curve

In [None]:
"""
ACCA_Single on Tangled MNIST
Model params:
    No dropout
    z_dim = 3
    q(z|x,y) not just q(z|x)
    p(z) is S-Manifold Distribution
Plots:
    embeddings
    reconstructions
    training curves - 3 losses, class, rota, rotb
    random generations 
"""

# For reproducibility
torch.manual_seed(10)

# Experiment Paramaters
# num_epochs = 100
batch_size = 1000
train_loader = torch.utils.data.DataLoader( 
    train_dataset,
    batch_size=batch_size, shuffle=True, **kwargs)
num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
z_dim = 3
dropout_prob = 0.0
best_results = {}
disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
recon_loss = 'L2'
results_path = './results/1/3/'
if not os.path.exists(results_path):  # checking on results directory
    os.makedirs(results_path)

# Network
acca_single_s = ACCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob)
discriminator_z_single = Discriminator_Really_Small(z_dim, disc_multiplier)  #max(z_dim, 6))
if torch.cuda.device_count() > 1:
    acca_single_s = nn.DataParallel(acca_single_s)

# Train Model
best_result, best_epoch, acca3_best_state_dict, acca3_results, acca3_last_state_dict, acca3_best_acc_state_dict, acca3_models = train_acca_single(acca_single_s.to(device), z_dim, discriminator_z_single.to(device), train_loader, validation_loader, test_loader, num_epochs, recon_loss, device, 'S_manifold')

""" Generate/Save Plots
    1. embeddings
    2. reconstructions
    3. training curves - 3 losses, class, rota, rotb
    4. random generations
"""
# Load best model
acca_single_s.load_state_dict(acca3_best_state_dict)
acca_single_s.eval()

# Get out-of-sample (test) data
x, y, rot_x, rot_y, labels = next(iter(test_loader))

# Gather embeddings and reconstructions
z = acca_single_s.module.encode(x.to(device),y.to(device))
x_hat, y_hat = acca_single_s.module.decode(z)
z, labels, rot_x, rot_y, x_hat, y_hat = z.detach().cpu().numpy(), labels.detach().cpu().numpy(), rot_x.detach().cpu().numpy(), rot_y.detach().cpu().numpy(), x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

# 1. embeddings
plot_3d_embeddings(z, labels, results_path + 's_prior.png', cmap_name='jet', plot_dataset=True)
plot_3d_embeddings(z, labels, results_path + 'embedding.png', cmap_name='jet')
plot_3d_embeddings(z, rot_x, results_path + 'embedding.png', cmap_name='jet')
plot_3d_embeddings(z, rot_y, results_path + 'embedding.png', cmap_name='jet')

# 2. reconstructions
display_reconstructions(x, x_hat, y, y_hat, results_path + 'reconstructions.png')

# 3. training curves
save_disentangling_curves_single(acca3_results, results_path + 'training_curves.png')

# Experiment 4: Information Capacity in ACCA

In [None]:
def experiment4(z_dim, results_path=None):
    """Train and evaluate an ACCA_Single model whose shared latent z has ``z_dim`` dims.

    Part of Experiment 4 (information capacity in ACCA): the same setup is run for
    several ``z_dim`` values. It saves reconstructions, training curves and 18 random
    generations under ``results_path``.

    Relies on notebook globals from earlier cells: ``train_dataset``, ``kwargs``,
    ``validation_loader``, ``test_loader``, ``num_epochs`` and ``device``.

    Args:
        z_dim: Dimensionality of the shared latent z.
        results_path: Output directory (must end with '/'). Defaults to
            ``'./results/4/1/z_dim_<z_dim>/'``, so runs with different ``z_dim`` no
            longer overwrite each other's figures.

    Returns:
        ``(results, acca_best_state_dict)``: the training history and the best state
        dict, as returned by ``train_acca_single``.
    """
    # For reproducibility
    torch.manual_seed(10)

    # Experiment Parameters
    # num_epochs = 2
    batch_size = 1000
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)
    num_z = 2 # 2 indicates encoder is q(z|x,y), 1 indicates encoder is q(z|x) - though with 1, there is no hope of view-specific info from y making it into z, only view-specific info from x
    dropout_prob = 0.0
    disc_multiplier = 6 # a complexity multiplier (number of neurons) for the Discriminator_Really_Small discriminator - see Nets.py
    recon_loss = 'L2'
    if results_path is None:
        # One directory per z_dim: every call used to write into './results/4/1/',
        # so each run overwrote the figures of the previous one.
        results_path = './results/4/1/z_dim_{}/'.format(z_dim)
    os.makedirs(results_path, exist_ok=True)

    # Network
    acca_single = ACCA_Single(z_dim=z_dim, num_z=num_z, dropout_prob=dropout_prob)
    discriminator_z_single = Discriminator_Really_Small(z_dim, disc_multiplier)
    if torch.cuda.device_count() > 1:
        acca_single = nn.DataParallel(acca_single)

    # Train Model - returns (best_accuracy, best_epoch, best_state_dict, results, ae.state_dict(), best_acc_state_dict, models)
    best_result, best_epoch, acca_best_state_dict, results, last_state_dict, acca_best_acc_state_dict, acca_models = train_acca_single(
        acca_single.to(device), z_dim, discriminator_z_single.to(device),
        train_loader, validation_loader, test_loader, num_epochs, recon_loss, device)

    """ Generate/Save Plots
        1. reconstructions
        2. training curves - 3 losses, class, rota, rotb
        3. random generations
    """
    # Load best model
    acca_single.load_state_dict(acca_best_state_dict)
    acca_single.eval()
    # encode/decode are reached through `.module` only when wrapped in nn.DataParallel
    # (more than one GPU); a bare model exposes them directly.
    model = acca_single.module if isinstance(acca_single, nn.DataParallel) else acca_single

    # Get out-of-sample (test) data and reconstruct it (inference only, no autograd graph)
    x, y, rot_x, rot_y, labels = next(iter(test_loader))
    with torch.no_grad():
        z = model.encode(x.to(device), y.to(device))
        x_hat, y_hat = model.decode(z)
    x_hat, y_hat = x_hat.detach().cpu().numpy(), y_hat.detach().cpu().numpy()

    # 1. reconstructions
    display_reconstructions(x, x_hat, y, y_hat, results_path + 'acca_reconstructions.png')

    # 2. training curves
    save_disentangling_curves_single(results, results_path + 'acca_training_curves.png')

    # 3. random generations: 18 latent codes drawn from N(0, I) with torch's RNG
    #    (seeded at the top of this function) instead of numpy's global RNG, which is
    #    not seeded here; also replaces the legacy torch.cuda.FloatTensor constructor.
    with torch.no_grad():
        z_fake = torch.randn(18, z_dim, device=device)
        x_fake, y_fake = model.decode(z_fake)
    x_fake, y_fake = x_fake.detach().cpu().numpy(), y_fake.detach().cpu().numpy()
    display_generated_images(x_fake, y_fake, results_path + 'acca_random_generations.png')

    return results, acca_best_state_dict

## Experiments

In [None]:
# Experiment 4 with a 3-dimensional shared latent z
results_3, state_dict_3 = experiment4(z_dim=3)

In [None]:
# Experiment 4 with a 4-dimensional shared latent z
results_4, state_dict_4 = experiment4(z_dim=4)

In [None]:
# Experiment 4 with a 6-dimensional shared latent z
results_6, state_dict_6 = experiment4(z_dim=6)

In [None]:
# Experiment 4 with an 8-dimensional shared latent z
results_8, state_dict_8 = experiment4(z_dim=8)

In [None]:
# Experiment 4 with a 10-dimensional shared latent z
results_10, state_dict_10 = experiment4(z_dim=10)

In [None]:
# Experiment 4 with a 15-dimensional shared latent z
results_15, state_dict_15 = experiment4(z_dim=15)

In [None]:
# Experiment 4 with a 20-dimensional shared latent z
results_20, state_dict_20 = experiment4(z_dim=20)