In [1]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Make the project's source folders importable. module_path is the parent of the
# current working directory (presumably the repository root, with this notebook one
# level below it — confirm).
module_path = os.path.abspath('..')
for _subdir in ("models", "train", "experiments", "datasets"):
    _entry = module_path + "/" + _subdir
    # Skip folders already on sys.path so re-running this cell does not keep
    # appending duplicates (the first occurrence is the one used for imports anyway).
    if _entry not in sys.path:
        sys.path.append(_entry)
del _subdir, _entry

from factor_vae import FactorVAEDSprites, Discriminator, FactorVAEcnn
from beta_vae import Classifier, BetaVAECelebA
from datasets import train_test_random_split, load_dsprites, CustomDSpritesDatasetFactorVAE, AddUniformNoise, CustomDSpritesDataset, AddGeneratedNoise
from train import train_factor_vae, test_factor_vae, train_beta_vae, test_beta_vae
from entanglement_metric import entanglement_metric_factor_vae, entanglement_metric_beta_vae
from utils import save_checkpoint_factorvae, load_checkpoint_factorvae


In [4]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device.type

'cuda'

In [5]:
# Single seed shared by every random-number generator in this notebook.
seed = 2
# PYTHONHASHSEED is only read when a Python interpreter starts, so this does NOT change
# str/bytes hashing inside the running kernel; it only reaches Python processes that
# are launched from here afterwards (e.g. via subprocess).
os.environ['PYTHONHASHSEED'] = str(seed)
# Pin Python's built-in `random` module too; numpy and torch are seeded in the cells
# that build the dataset and train the models.
random.seed(seed)


In [6]:
# dSprites archive, resolved relative to the notebook's working directory.
dsprites_file = "../datasets/dsprites.npz"
# NOTE(review): the positional False flag is interpreted by load_dsprites — confirm its meaning there.
dataset = load_dsprites(dsprites_file, False)


In [7]:
# Noise transform applied to the dataset samples. Alternatives kept for quick switching:
#   transform = None                                             # no transform
#   transform = transforms.Compose([AddUniformNoise(-.1, .1)])   # presumably uniform noise in (-0.1, 0.1)
# The noisenet checkpoint sits next to dsprites.npz, so resolve it with the same
# "../datasets" prefix used when loading the dataset (the former "atml/datasets/..."
# path could not resolve from the same working directory as "../datasets/dsprites.npz").
transform = AddGeneratedNoise("../datasets/noisenet.pth", device)

In [8]:
# Re-seed both RNGs right before building the dataset so its construction is
# reproducible regardless of what earlier cells consumed.
np.random.seed(seed)
torch.manual_seed(seed)
# Alternative dataset class, kept for reference:
# data_ = CustomDSpritesDataset(dataset, seed=seed, transform=transform)
data_ = CustomDSpritesDatasetFactorVAE(dataset, transform=transform, seed=seed)



In [9]:
# Ratio handed to the splitter — presumably the share of samples kept for training (confirm).
split_ratio = 0.8
data_train, data_test = train_test_random_split(data_, split_ratio, seed=seed)

In [10]:
batch_size = 64
train_loader = DataLoader(
    data_train,
    batch_size=batch_size,
    shuffle=True,   # reshuffle the training split at every epoch
)
test_loader = DataLoader(
    data_test,
    batch_size=batch_size,
    shuffle=False,  # fixed order so evaluation is repeatable
)

In [11]:
def plot_loss(loss_lists, title, xlabel="Epoch", ylabel="Loss"):
    """Plot one or several per-epoch curves, each in its own figure.

    Parameters
    ----------
    loss_lists : sequence of numbers, or sequence of such sequences
        Either a single curve (e.g. ``[0.9, 0.7, 0.6]``) or several curves
        (e.g. one per random seed, as lists, tuples or numpy arrays).
        Every curve is drawn in a separate figure, with x = 1..len(curve).
    title : str
        Appended to ``"Training "`` to form each figure's title.
    xlabel, ylabel : str, optional
        Axis labels. They default to ``"Epoch"`` and ``"Loss"``. Override
        ``ylabel`` when plotting non-loss curves such as accuracies.
    """
    # An empty input has no first element to inspect and nothing to draw.
    if len(loss_lists) == 0:
        return
    # A single curve has a scalar first element; wrap it so both cases share one loop.
    # np.ndim also treats tuples / numpy arrays / tensors as nested curves, not only lists.
    if np.ndim(loss_lists[0]) == 0:
        loss_lists = [loss_lists]
    for loss_list in loss_lists:
        # A fresh figure per curve, so nothing is drawn onto a stale current figure.
        _, ax = plt.subplots()
        ax.plot(np.arange(1, len(loss_list) + 1), loss_list)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title("Training " + title)
        plt.show()

In [12]:
# Sweep grid: learning rates for the VAE's Adagrad optimiser ...
lr1s = [1e-2]
# ... and (gamma, discriminator Adam learning rate) pairs.
params = [[5, 1e-4]]
# Epochs per training window, and the first epoch of each window within [0, 50).
step = 50
starts = np.arange(start=0, stop=50, step=step)

In [None]:
# Hyper-parameter sweep: for every (lr1, gamma, lr2) setting and every epoch window,
# train a FactorVAE, save the model and loss curves, plot the curves, then score
# disentanglement with the FactorVAE and beta-VAE metrics.
SAVE_DIR = "/content/drive/My Drive/atml_models"
# NOTE(review): prefix shared by every artifact of this sweep; presumably tags the
# generated-noise variant of the experiment — confirm before reusing.
RUN_PREFIX = "noisy3_"


def _artifact_path(kind, end_epoch, gamma, lr1, lr2, ext):
    """Return the save path of one artifact of a training run.

    ``kind`` is inserted right after the run prefix ("" for the model itself,
    e.g. "recon_loss_" for a loss curve). ``end_epoch`` is the end of the
    epoch window (start + step) that the artifact belongs to.
    """
    filename = (RUN_PREFIX + kind + "factorvae_epochs" + str(end_epoch)
                + "_gamma" + str(gamma) + "_lrvae" + str(lr1) + "_lrd" + str(lr2) + ext)
    return os.path.join(SAVE_DIR, filename)


for lr1 in lr1s:
    for (gamma, lr2) in params:
        for start in starts:
            end_epoch = start + step
            # torch.manual_seed seeds the CPU and every CUDA device, so separate
            # torch.cuda seeding calls are unnecessary.
            torch.manual_seed(seed)
            np.random.seed(seed)
            print("lr1 :", lr1)
            print("lr2 :", lr2)
            print("gamma :", gamma)
            print("epochs :", str(start) + " to " + str(end_epoch))

            model = FactorVAEDSprites()
            model.to(device)
            discriminator = Discriminator(nb_layers=6, hidden_dim=1000)
            discriminator.to(device)
            vae_optimizer = torch.optim.Adagrad(model.parameters(), lr=lr1)
            discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr2)
            # NOTE: model and discriminator are freshly initialised on every iteration.
            # For start > 0, restore the previous window's checkpoint (uncomment below);
            # otherwise the run trains from scratch while being labelled with a later window.
            #old_name = "/content/drive/My Drive/atml_models/factorvae_epochs"+str(start)+"_gamma"+str(gamma)+"_lrvae"+str(lr1)+"_lrd"+str(lr2)+".pth.tar"
            #model,discriminator,vae_optimizer,discriminator_optimizer,start_epoch = load_checkpoint_factorvae(model,discriminator,vae_optimizer,discriminator_optimizer,old_name)

            (train_losses_list, recon_losses_list, kl_divs_list,
             tc_losses_list, discriminator_losses_list) = train_factor_vae(
                model, discriminator, step, train_loader, vae_optimizer,
                discriminator_optimizer, gamma, 'gaussian', device=device)

            # Persist the model and the per-epoch curves. The filename carries the actual
            # window end, so successive windows do not overwrite each other's artifacts.
            torch.save(model, _artifact_path("", end_epoch, gamma, lr1, lr2, ".dat"))
            for kind, curve in (("recon_loss_", recon_losses_list),
                                ("kl_divs_", kl_divs_list),
                                ("tc_loss_", tc_losses_list)):
                np.save(_artifact_path(kind, end_epoch, gamma, lr1, lr2, ".npy"), curve)

            plot_loss(train_losses_list, "Total loss")
            plot_loss(recon_losses_list, "Reconstruction loss")
            plot_loss(kl_divs_list, "KL divergence")
            plot_loss(tc_losses_list, "TC loss")
            plot_loss(discriminator_losses_list, "Discriminator loss")
            #new_name = "/content/drive/My Drive/atml_models/factorvae_epochs"+str(start+step)+"_gamma"+str(gamma)+"_lrvae"+str(lr1)+"_lrd"+str(lr2)+".pth.tar"
            #save_checkpoint_factorvae(model, discriminator, vae_optimizer, discriminator_optimizer, new_name, start+step)

            test_factor_vae(model, discriminator, test_loader, gamma, 'gaussian', device=device)

            # FactorVAE disentanglement metric; its per-seed results are averaged and
            # reported as an accuracy.
            print("Factor Vae metric: ")
            factor_vae_scores = entanglement_metric_factor_vae(
                model, data_, 500, 500, random_seeds=5, device=device, seed=seed)
            print("Accuracy: " + str(np.mean(factor_vae_scores)))

            # beta-VAE disentanglement metric: trains a fresh classifier for each run.
            print("Beta Vae metric: ")
            classifier = Classifier()
            classifier.to(device)
            optimizer = torch.optim.Adagrad(classifier.parameters(), lr=1e-2)
            losses, accuracies, test_accuracies = entanglement_metric_beta_vae(
                model, classifier, optimizer, 2000, data_, 1000, 50,
                random_seeds=2, device=device, seed=seed)
            print("Accuracy: " + str(np.mean(test_accuracies)))
            plot_loss(losses, "NLL Loss")
            plot_loss(accuracies, "Accuracy")


lr1 : 0.01
lr2 : 0.0001
gamma : 5
epochs : 0 to 50




Epoch 0 finished, loss: 60.33667491707537, recon loss: 46.344402058463956, kl div: 13.774261785457181, TC loss: 0.043602170549187726, discriminator loss: 0.682157525081291
Epoch 1 finished, loss: 39.48241013288498, recon loss: 31.04262187745836, kl div: 8.300892201976644, TC loss: 0.027779216067983725, discriminator loss: 0.6853030270883917
Epoch 2 finished, loss: 38.87419954977102, recon loss: 30.42804006776876, kl div: 8.326883636777186, TC loss: 0.023855174140029096, discriminator loss: 0.6863800497635061
Epoch 3 finished, loss: 38.577981761760185, recon loss: 30.129345694142913, kl div: 8.341780984774232, TC loss: 0.021371015859800637, discriminator loss: 0.687119894706282
Epoch 4 finished, loss: 38.37926313446628, recon loss: 29.93422055741151, kl div: 8.347666165584492, TC loss: 0.01947528271326367, discriminator loss: 0.6875606259175887
Epoch 5 finished, loss: 38.25119371008542, recon loss: 29.80379956588149, kl div: 8.350886991040575, TC loss: 0.019301434449011542, discriminato

Tuning notes — candidate discriminator learning rates (lr2) for each gamma:

- gamma 5: lr2 = 0.0001 or 0.00001
- gamma 40: lr2 = 0.00005