In [16]:
import numpy as np
import os
import sys
import random
import torch
import matplotlib.pyplot as plt
import matplotlib.lines as mlines


from torch.utils.data import DataLoader

# Repository root: two directories above the current working directory.
module_path = os.path.abspath(os.path.join('..', '..'))
# Make the project's source folders importable. Entries already on sys.path
# are skipped so re-running this cell does not keep growing sys.path.
for _subdir in ("models", "train", "datasets"):
    _path = os.path.join(module_path, _subdir)
    if _path not in sys.path:
        sys.path.append(_path)
del _subdir, _path

from datasets import train_test_random_split, load_dsprites, CustomDSpritesDataset
from entanglement_metric import entanglement_metric_factor_vae, entanglement_metric_beta_vae, compute_mig
from beta_vae import Classifier
from factor_vae import Discriminator
from train import test_beta_vae, test_factor_vae, test_control_vae

In [17]:
# Notebook-wide seed, passed as `seed=` to the dataset, split and metric helpers
# below. Also seed the global RNG of every library in use, so randomness that is
# not driven by an explicit `seed=` argument is reproducible as well.
seed = 2
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

In [18]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device.type

'cpu'

In [19]:
def to_curve_list(loss_lists):
    """Normalize ``loss_lists`` into a list of curves (one per run).

    A flat sequence of numbers (a single run) is wrapped in a one-element list.
    A sequence whose items are themselves lists/tuples/arrays (e.g. one curve
    per random seed, or the rows of a 2-D array) becomes a list of those
    items. Empty input yields ``[]``.
    """
    if len(loss_lists) == 0:
        return []
    if isinstance(loss_lists[0], (list, tuple, np.ndarray)):
        return list(loss_lists)
    return [loss_lists]


def plot_loss(loss_lists, title, ylabel=None):
    """Plot one or several per-epoch training curves of the classifier.

    Parameters
    ----------
    loss_lists : sequence of numbers, or sequence of such sequences
        Either a single curve (one value per epoch) or several curves;
        each curve is drawn as its own line.
    title : str
        Name of the plotted quantity. It is used in the figure title and,
        by default, as the y-axis label.
    ylabel : str, optional
        Y-axis label; defaults to ``title``.
    """
    curves = to_curve_list(loss_lists)
    _, ax = plt.subplots()
    for curve in curves:
        # Epochs are numbered from 1 on the x-axis.
        ax.plot(np.arange(1, len(curve) + 1), curve)
    ax.set_xlabel("Epoch")
    # Label the y-axis with the plotted quantity; a fixed "Accuracy" label
    # would mislabel loss curves.
    ax.set_ylabel(title if ylabel is None else ylabel)
    ax.set_title("Training " + title + " of the classifier over epochs")
    plt.show()

In [20]:
# Raw dSprites archive, relative to the current working directory.
dsprites_path = "../../datasets/dsprites.npz"
# NOTE(review): the meaning of the positional False flag is defined by
# load_dsprites (not visible here) — confirm before changing it.
dataset = load_dsprites(dsprites_path, False)

In [21]:
# Wrap the loaded dSprites data in the project's dataset class, passing the
# notebook-wide seed (how it is used is up to CustomDSpritesDataset).
data_ = CustomDSpritesDataset(dataset,seed=seed)

In [22]:
# Optional input transform handed to the test helpers during evaluation;
# None means the dataset is evaluated as-is. `transform_needs_latents` is
# passed alongside it — presumably it tells the helpers to feed the latent
# factors to the transform (TODO confirm in train.py).
transform = None
transform_needs_latents = False

# Alternatives. The transform classes are not imported in this notebook;
# import them before uncommenting.
# transform = AddUniformNoise(-.1, .1)
# transform_needs_latents = False

# transform = AddGeneratedNoise(module_path + "/datasets/noisenet.pth", device)
# transform_needs_latents = True

In [23]:
# Fraction of the dataset indices assigned to the training split.
train_fraction = 0.8
data_train, data_test = train_test_random_split(data_.idx, train_fraction, seed=seed)

In [24]:
# Evaluation loader over the held-out split, iterated in a fixed order.
batch_size = 64
test_loader = DataLoader(
    data_test,
    batch_size=batch_size,
    shuffle=False,
)

In [30]:
# Sub-directory of ../trained_models holding the checkpoints to evaluate; also
# used as the filename prefix of the saved result arrays.
folder = "normal_dataset_bernoulli_cnn_models"

In [31]:
# Checkpoints to evaluate: every ".dat" file in the chosen folder. os.listdir
# returns entries in arbitrary order, so sort them to keep the model numbering
# and the row order of the saved result arrays reproducible.
model_names = sorted(
    f for f in os.listdir(os.path.join('../trained_models', folder))
    if f.endswith('.dat')
)
model_names

['cnnbetavae_epochs50_gamma4_lrvae0.0001_lrd0.0001.dat',
 'cnnfactorvae_epochs50_gamma20_lrvae0.0001_lrd1e-05.dat']

In [32]:

# Per-model result accumulators, filled in order by the evaluation loop.
recon_losses, betavae_metric_accuracies, factorvae_metric_accuracies, mig_scores = [], [], [], []


In [None]:
def reconstruction_loss(model, name):
    """Return the test-set reconstruction loss of ``model``, chosen by file name.

    The model family is inferred from a substring of ``name`` ("betavae",
    "controlvae" or "factorvae"). The matching test helper is called with the
    notebook's test loader, dataset, transform settings and device.

    Raises
    ------
    ValueError
        If ``name`` matches none of the known model families.
    """
    if "betavae" in name:
        return test_beta_vae(model, test_loader, 0, 'bernoulli', data_, transform, transform_needs_latents, device=device)
    if "controlvae" in name:
        return test_control_vae(model, test_loader, 'bernoulli', data_, transform, transform_needs_latents, device=device)
    if "factorvae" in name:
        # NOTE(review): a freshly constructed Discriminator is passed here;
        # presumably test_factor_vae only needs it for its signature — confirm.
        discriminator = Discriminator()
        return test_factor_vae(model, discriminator, test_loader, 0, 'bernoulli', data_, transform, transform_needs_latents, device=device)
    # Fail loudly instead of silently stopping. A silent stop would leave the
    # result lists shorter than model_names and misalign the saved arrays.
    raise ValueError("Unrecognised model type in file name: " + name)


for i, name in enumerate(model_names, start=1):
    # SECURITY: torch.load unpickles the checkpoint, which can execute
    # arbitrary code — only evaluate files from a trusted source.
    model = torch.load('../trained_models/' + folder + "/" + name, map_location=device)
    print("model " + str(i) + ": " + name)

    recon_loss = reconstruction_loss(model, name)
    recon_losses.append(recon_loss)
    print("Reconstruction loss: " + str(recon_loss))

    # FactorVAE metric over 5 random seeds; the numeric positional arguments
    # are interpreted by entanglement_metric_factor_vae (not visible here).
    accuracies = entanglement_metric_factor_vae(model, data_, 500, 200, random_seeds=5, device=device, seed=seed)
    factorvae_metric_accuracies.append(accuracies)
    print("Factor Vae metric: ")
    print("Accuracy: " + str(np.mean(accuracies)))

    print("Beta Vae metric: ")
    # Re-seed right before building the classifier so its initialization is
    # identical for every evaluated model.
    torch.manual_seed(seed)
    classifier = Classifier()
    classifier.to(device)
    optimizer = torch.optim.Adagrad(classifier.parameters(), lr=1e-2)
    train_losses, train_accuracies, test_accuracies = entanglement_metric_beta_vae(
        model, classifier, optimizer, 10000, data_, 1000, 50,
        random_seeds=5, device=device, seed=seed,
    )
    betavae_metric_accuracies.append(test_accuracies)
    print("Accuracy: " + str(np.mean(test_accuracies)))
    plot_loss(train_losses, "NLL Loss")
    plot_loss(train_accuracies, "Accuracy")

    print("Mig metric")
    scores = compute_mig(model, data_, num_samples=100000, random_seeds=5, device=device, seed=seed)
    mig_scores.append(scores)
    print("Scores: " + str(scores))
    print("Score: " + str(np.mean(scores)))

model 1: cnnbetavae_epochs50_gamma4_lrvae0.0001_lrd0.0001.dat


  0%|          | 0/2304 [00:00<?, ?it/s]



Reconstruction loss: 45.15021482772298
accuracies : [0.6779999999999999, 0.692, 0.65, 0.632, 0.646]
[[[ 0.   2.2  3.4  8.4  0.   0.2]
  [ 0.  16.2  6.   8.2  3.8  0. ]
  [ 0.   0.   3.8 25.   0.   0. ]
  [ 0.   8.   7.2  1.4 16.6 80. ]
  [ 0.   1.6 36.4  2.  12.2 10.6]
  [ 0.  30.6 15.  15.4  0.  11.2]
  [ 0.  44.   3.6 16.4  0.   0.8]
  [ 0.   0.2  4.6 16.8  0.   0. ]
  [ 0.   0.   6.4  4.4  0.   0.2]
  [ 0.   3.   7.4  1.  65.8  0. ]]]
0.6596
Factor Vae metric: 
Accuracy: 0.6596
Beta Vae metric: 


In [None]:
# Persist the per-model results to the current working directory, prefixed
# with `folder`. Every array is indexed by model, so refuse to write files
# whose rows would not line up with model_names (e.g. after an interrupted
# or partially re-run evaluation loop).
for _label, _values in (
    ("recon_losses", recon_losses),
    ("betavae_metric_accuracies", betavae_metric_accuracies),
    ("factorvae_metric_accuracies", factorvae_metric_accuracies),
    ("mig_scores", mig_scores),
):
    if len(_values) != len(model_names):
        raise ValueError(
            f"{_label} has {len(_values)} entries but model_names has "
            f"{len(model_names)}; re-run the evaluation before saving."
        )

np.save(folder+"_model_names.npy", np.array(model_names))
np.save(folder+"_test_recon_losses.npy", np.array(recon_losses))
np.save(folder+"_betavaemetric_scores.npy", np.array(betavae_metric_accuracies))
np.save(folder+"_factorvaemetric_scores.npy", np.array(factorvae_metric_accuracies))
np.save(folder+"_mig_scores.npy", np.array(mig_scores))