In [1]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the project's local packages (models/, train/, datasets/) importable.
module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path+"/models")
sys.path.append(module_path+"/train")
sys.path.append(module_path+"/datasets")

from factor_vae import FactorVAEDSprites, Discriminator
from beta_vae import Classifier
from datasets import train_test_random_split, load_dsprites, CustomDSpritesDatasetFactorVAE
from train import train_factor_vae, test_factor_vae
from entanglement_metric import entanglement_metric_factor_vae, entanglement_metric_beta_vae

In [None]:
# Configuration: seed every RNG the notebook may rely on (Python `random`,
# NumPy, PyTorch) so Restart & Run All gives repeatable splits, shuffles and
# weight initialisations. CUDA kernels may still be nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available; the device is displayed as cell output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device.type

In [None]:
def plot_loss(loss_list, title, xlabel="Epoch", ylabel="Loss"):
    """Plot a training curve and show it immediately.

    Parameters
    ----------
    loss_list : sequence of numbers
        Values to plot; drawn against x = 1, 2, ..., len(loss_list).
    title : str
        Curve name; the figure title is "Training <title>".
    xlabel, ylabel : str, optional
        Axis labels. The defaults reproduce the original "Epoch"/"Loss"
        labels; override ``ylabel`` for non-loss series such as accuracy.
    """
    fig, ax = plt.subplots()
    ax.plot(np.arange(1, len(loss_list) + 1), loss_list)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title("Training " + title)
    plt.show()

In [None]:
# Load the raw dSprites archive. The second positional argument is forwarded
# unchanged; its meaning is defined by load_dsprites (TODO: document it here).
dataset = load_dsprites(
    "../datasets/dsprites.npz",
    False,
)

In [None]:
# Wrap the raw arrays in the FactorVAE dataset. data_size presumably bounds
# how many samples are used (confirm in CustomDSpritesDatasetFactorVAE).
data_size = 1000
data_ = CustomDSpritesDatasetFactorVAE(
    dataset,
    data_size,
)

In [None]:
# Random train/test split of data_; 0.8 is presumably the training fraction.
data_train, data_test = train_test_random_split(
    data_,
    0.8,
)


In [None]:
# Mini-batch both splits; only the training loader reshuffles every epoch.
batch_size = 64
train_loader = DataLoader(dataset=data_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=data_test, batch_size=batch_size, shuffle=False)

In [None]:
# VAE and its discriminator, moved to the selected device (nn.Module.to
# returns the module itself) before the optimizers collect their parameters.
model = FactorVAEDSprites().to(device)
discriminator = Discriminator(nb_layers=4, hidden_dim=500).to(device)

vae_optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

# NOTE(review): gamma is forwarded to train_factor_vae / test_factor_vae. If it
# weights the total-correlation term (as in FactorVAE), 0 disables it — confirm
# this is intended and not a leftover from a smoke-test run.
gamma = 0
epochs = 3

In [None]:
# Train the VAE and discriminator together; the per-epoch histories are kept
# for plotting. 'bernoulli' is forwarded as-is (presumably the likelihood type).
(
    train_losses_list,
    recon_losses_list,
    kl_divs_list,
    tc_losses_list,
    discriminator_losses_list,
) = train_factor_vae(
    model,
    discriminator,
    epochs,
    train_loader,
    vae_optimizer,
    discriminator_optimizer,
    gamma,
    'bernoulli',
    device=device,
)

In [None]:
# One figure per tracked training quantity, in the same order as before.
for curve, curve_name in (
    (train_losses_list, "Total loss"),
    (recon_losses_list, "Reconstruction loss"),
    (kl_divs_list, "KL divergence"),
    (tc_losses_list, "TC loss"),
    (discriminator_losses_list, "Discriminator loss"),
):
    plot_loss(curve, curve_name)

In [None]:
# Evaluate on the held-out split; the return value is shown as cell output.
test_factor_vae(
    model,
    discriminator,
    test_loader,
    gamma,
    'bernoulli',
    device=device,
)

In [None]:
# Visual sanity check: show a few held-out images next to their reconstructions.
# Run the forward passes in inference mode (affects dropout/batch-norm layers,
# if the model has any) and restore the previous mode afterwards, so later
# cells see the model exactly as training left it.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i in [1, 2, 20]:  # arbitrary test-set indices to inspect
            data, _ = data_test[i]
            data = data.float()
            if device is not None:
                data = data.to(device)
            recon, mu, logvar, z = model(data)

            fig, (ax_real, ax_recon) = plt.subplots(1, 2)
            # squeeze() drops any singleton dims so imshow receives a 2-D
            # image, matching how the reconstruction is handled below.
            ax_real.imshow(data.squeeze().cpu(), cmap='Greys_r')
            ax_real.set_title("Real data")
            # The sigmoid treats the model output as logits, consistent with
            # the 'bernoulli' setting used for training (confirm in the model).
            ax_recon.imshow(torch.sigmoid(recon.squeeze().cpu()), cmap='Greys_r')
            ax_recon.set_title("Reconstruction")
            plt.show()
finally:
    model.train(was_training)


In [None]:
# Disentanglement metrics for the trained model.
# Requires both metric functions from the `entanglement_metric` module, imported
# in the first cell (entanglement_metric_beta_vae was previously never imported,
# which made this cell fail with NameError).
# The positional counts (300, 200, 1000, 500, 50) are forwarded unchanged; see
# the metric functions for their meaning.

# FactorVAE metric: the returned value is reported as an error, accuracy = 1 - loss.
print("Factor Vae metric: ")
loss = entanglement_metric_factor_vae(model, data_, 300, 200, random_seeds=5)
print("Accuracy: " + str(1 - loss))

# Beta-VAE metric: trains a fresh Classifier on the model's representations.
# NOTE(review): the classifier is not moved to `device` while `model` may be on
# CUDA; confirm entanglement_metric_beta_vae handles device placement.
print("Beta Vae metric: ")
classifier = Classifier()
optimizer = torch.optim.Adagrad(classifier.parameters(), lr=1e-2)
train_losses, train_accuracies, test_accuracy = entanglement_metric_beta_vae(
    model, classifier, optimizer, 1000, data_, 500, 50, random_seeds=5
)
print("Accuracy: " + str(test_accuracy))
plot_loss(train_losses, "NLL Loss")
plot_loss(train_accuracies, "Accuracy")