In [1]:
# Notebook setup: (re)load the autoreload extension in mode 2 so edited
# project modules are picked up without restarting the kernel, and render
# matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Mount Google Drive at /content/drive so the project files stored there are
# reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Move into the project's experiments folder on Drive.
# NOTE(review): these are relative `%cd` steps, so they only resolve when the
# cell starts from /content; re-running the cell from inside the project
# folder will not find `drive/My Drive`.
%cd drive/My\ Drive/
%cd atml
%cd experiments

/content/drive/My Drive
/content/drive/My Drive/atml
/content/drive/My Drive/atml/experiments


In [4]:
import numpy as np
import torch
import os
import sys
import matplotlib.pyplot as plt
from torchvision import transforms

from torch.utils.data import DataLoader

# Make the project's sibling packages importable: every listed folder under
# the parent directory is appended to sys.path, in this order.
module_path = os.path.abspath(os.path.join('..'))
sys.path.extend(
    [module_path + subdir for subdir in ("/models", "/train", "/experiments", "/datasets")]
)

from factor_vae import FactorVAEDSprites, Discriminator, FactorVAEcnn
from beta_vae import Classifier, BetaVAECelebA
from datasets import train_test_random_split, load_dsprites, AddUniformNoise, CustomDSpritesDataset, AddGeneratedNoise
from train import train_factor_vae, test_factor_vae, train_beta_vae, test_beta_vae
from entanglement_metric import entanglement_metric_factor_vae, entanglement_metric_beta_vae
from utils import save_checkpoint_factorvae, load_checkpoint_factorvae

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Last expression of the cell: displays which device type was selected.
device.type

'cuda'

In [6]:
# Single seed reused by every seeding call and seeded helper in this notebook.
seed = 2
# NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so setting
# it here is only seen by processes launched afterwards, not by hashing in the
# already-running kernel.
os.environ.update({'PYTHONHASHSEED': str(seed)})

In [7]:
# Load the dSprites archive from the sibling datasets folder.
# NOTE(review): the meaning of the second positional argument (False) is
# defined by load_dsprites in the datasets module - confirm before changing it.
dataset = load_dsprites("../datasets/dsprites.npz", False)

In [8]:
# Select the input transform handed to the train/test helpers further down.
# Exactly one option is active; `transform` and `transform_needs_latents`
# are always set as a pair.

# Option 1: no transform.
# transform = None
# transform_needs_latents = False

# Option 2: uniform noise with bounds -0.1 and 0.1.
# transform = AddUniformNoise(-.1, .1)
# transform_needs_latents = False

# Option 3 (active): noise from the saved "noisenet" file.
# NOTE(review): presumably the flag is True because this transform consumes
# the sample's latents - confirm against AddGeneratedNoise.
transform = AddGeneratedNoise("../datasets/noisenet.pth", device)
transform_needs_latents = True

In [9]:
# Seed the global torch and numpy RNGs, then wrap the loaded dataset using the
# same seed.
torch.manual_seed(seed)
np.random.seed(seed)
data_ = CustomDSpritesDataset(dataset, seed=seed)

In [10]:
# Split the dataset's index (`data_.idx`) into train and test parts with a
# seeded random split.
# NOTE(review): 0.8 is presumably the training fraction - confirm against
# train_test_random_split.
data_train, data_test = train_test_random_split(data_.idx, 0.8, seed=seed)

In [11]:
# Mini-batch loaders over the two splits: training batches are shuffled,
# the test split keeps its order.
batch_size = 64
train_loader = DataLoader(dataset=data_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=data_test, batch_size=batch_size, shuffle=False)

In [12]:
def plot_loss(loss_lists, title):
    """Plot one or several loss curves, each in its own figure.

    Parameters
    ----------
    loss_lists : sequence
        Either a flat sequence of per-epoch values, or a sequence of such
        sequences (lists, tuples or numpy arrays). An empty input is
        accepted and produces no figure.
    title : str
        Suffix appended to "Training " to form the figure title.

    Returns
    -------
    None
        Figures are shown with ``plt.show()`` as a side effect.
    """
    # Nothing to draw; also avoids indexing into an empty sequence below.
    if len(loss_lists) == 0:
        return
    # A flat sequence of numbers is treated as a single curve. Tuples and
    # arrays count as nested curves too, not only plain lists.
    if not isinstance(loss_lists[0], (list, tuple, np.ndarray)):
        loss_lists = [loss_lists]
    for loss_list in loss_lists:
        # Epochs are numbered from 1 on the x-axis.
        epochs = np.arange(1, len(loss_list) + 1, 1)
        plt.plot(epochs, loss_list)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training " + title)
        # show() sits inside the loop, so every curve gets its own figure.
        plt.show()

In [13]:
# Hyper-parameter grid consumed by the training loop in the next cell.
lr1s = [0.01]  # learning rates for the VAE optimizer
params = [[5,0.0001]]  # (gamma, discriminator learning rate) pairs
step = 50  # passed to train_factor_vae; printed by the loop as the epoch span
starts= np.arange(0,50,step)  # starting-epoch labels; with step = 50 this is just [0]

In [14]:
# Grid search: for every VAE learning rate, (gamma, discriminator learning
# rate) pair and starting-epoch label, train a fresh FactorVAE, save the model
# and its loss curves, plot the curves, then evaluate on the test split and
# with both disentanglement metrics.
for lr1 in lr1s:
    for (gamma, lr2) in params:
        for start in starts:
            # Re-seed every RNG so each configuration starts from the same state.
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)

            end_epoch = start + step
            print("lr1 :", lr1)
            print("lr2 :", lr2)
            print("gamma :", gamma)
            print("epochs :", str(start) + " to " + str(end_epoch))

            model = FactorVAEDSprites()
            model.to(device)
            discriminator = Discriminator(nb_layers=6, hidden_dim=1000)
            discriminator.to(device)
            vae_optimizer = torch.optim.Adagrad(model.parameters(), lr=lr1)
            discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr2)

            # Checkpoint resume is currently disabled, so every pass trains a
            # freshly initialised model for `step` epochs.
            #old_name = "/content/drive/My Drive/atml_models/factorvae_epochs"+str(start)+"_gamma"+str(gamma)+"_lrvae"+str(lr1)+"_lrd"+str(lr2)+".pth.tar"
            #model,discriminator,vae_optimizer,discriminator_optimizer,start_epoch = load_checkpoint_factorvae(model,discriminator,vae_optimizer,discriminator_optimizer,old_name)
            (train_losses_list,
             recon_losses_list,
             kl_divs_list,
             tc_losses_list,
             discriminator_losses_list) = train_factor_vae(
                model, discriminator, step, train_loader,
                vae_optimizer, discriminator_optimizer, gamma, 'gaussian',
                data_, transform, transform_needs_latents, device=device)

            # Every artefact of a run shares one tag. The epoch count comes
            # from the loop variables rather than a hard-coded "50", so runs
            # with different `start`/`step` values no longer overwrite each other.
            out_dir = "/content/drive/My Drive/atml_models/"
            run_tag = ("factorvae_epochs" + str(end_epoch) + "_gamma" + str(gamma)
                       + "_lrvae" + str(lr1) + "_lrd" + str(lr2))
            # NOTE(review): torch.save(model, ...) pickles the whole module
            # object; kept as-is so existing loaders of these .dat files work.
            torch.save(model, out_dir + "noisy3_" + run_tag + ".dat")
            np.save(out_dir + "noisy3_recon_loss_" + run_tag + ".npy", recon_losses_list)
            np.save(out_dir + "noisy3_kl_divs_" + run_tag + ".npy", kl_divs_list)
            np.save(out_dir + "noisy3_tc_loss_" + run_tag + ".npy", tc_losses_list)

            plot_loss(train_losses_list, "Total loss")
            plot_loss(recon_losses_list, "Reconstruction loss")
            plot_loss(kl_divs_list, "KL divergence")
            plot_loss(tc_losses_list, "TC loss")
            plot_loss(discriminator_losses_list, "Discriminator loss")
            #new_name = "/content/drive/My Drive/atml_models/factorvae_epochs"+str(start+step)+"_gamma"+str(gamma)+"_lrvae"+str(lr1)+"_lrd"+str(lr2)+".pth.tar"
            #save_checkpoint_factorvae(model, discriminator, vae_optimizer, discriminator_optimizer, new_name, start+step)

            # Evaluation on the held-out loader, with the same transform setup.
            test_factor_vae(model, discriminator, test_loader, gamma, 'gaussian',
                            data_, transform, transform_needs_latents, device=device)

            print("Factor Vae metric: ")
            # Kept under its own name so it is not confused with the
            # classifier losses returned by the beta-VAE metric below.
            factor_vae_scores = entanglement_metric_factor_vae(
                model, data_, 500, 500, random_seeds=5, device=device, seed=seed)
            print("Accuracy: " + str(np.mean(factor_vae_scores)))

            print("Beta Vae metric: ")
            # The beta-VAE metric fits a fresh classifier with its own optimizer.
            classifier = Classifier()
            classifier.to(device)
            optimizer = torch.optim.Adagrad(classifier.parameters(), lr=1e-2)
            losses, accuracies, test_accuracies = entanglement_metric_beta_vae(
                model, classifier, optimizer, 2000, data_, 1000, 50,
                random_seeds=2, device=device, seed=seed)
            print("Accuracy: " + str(np.mean(test_accuracies)))
            plot_loss(losses, "NLL Loss")
            plot_loss(accuracies, "Accuracy")


KeyboardInterrupt: ignored

gamma 5, lr2 0.0001 or 0.00001

gamma 40 lr2 0.00005