In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

import os
import sys
import random

from tqdm.notebook import tqdm

module_path = os.path.abspath(os.path.join('..'))
abs_path = "/content/drive/MyDrive/atml"
sys.path.append(abs_path+"/models")
sys.path.append(abs_path+"/train")
sys.path.append(abs_path+"/datasets")

from datasets import load_dsprites, CustomDSpritesDataset, train_test_random_split
from beta_vae import BetaVAEDSprites
from entanglement_metric import compute_mig

# Fix seed 
torch.manual_seed(2)
random.seed(2)
np.random.seed(2)

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [None]:
class NoiseGeneratorNet(nn.Module):
    def __init__(self, max_norm=1):
        super(NoiseGeneratorNet, self).__init__()

        self.max_norm = max_norm
        self.net = nn.Sequential(
            nn.Linear(5, 1200),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.ReLU(),
            nn.Linear(1200, 64 * 64),
        )

    def forward(self, x):
        out = self.net(x)
        out = out.view(-1, 64, 64)
        out = torch.clamp(out, min=-self.max_norm, max=self.max_norm)

        return out

In [None]:
def get_new_entangle_distangle_model_and_optimizer():
    model = BetaVAEDSprites()
    for param in model.encoder.parameters():
        param.requires_grad = False
    model_optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4)

    return model, model_optimizer

In [None]:
def save_model(epoch, model, optimizer, path):
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
          }, path[:-4] + f'_epoch_{epoch}' + '.pth' )

In [None]:
def load_models_and_optimizers(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']  

    return epoch, model, optimizer

In [None]:
def test(dataset, test_loader, device,
         noisenet, noisenet_optimizer,
         ent_model, ent_optimizer,
         dis_model, dis_optimizer):

    ent_loss_func = nn.MSELoss()
    dis_loss_func = nn.MSELoss()

    noisenet.to(device)
    ent_model.to(device)
    dis_model.to(device)

    noisenet.eval()
    ent_model.eval()
    dis_model.eval()

    ent_losses = []
    dis_losses = []
    noisenet_losses = []
    total_samples = 0 

    with torch.no_grad():
        for lt_mx in tqdm(test_loader, leave=False):
            indices = dataset.latent_to_index(lt_mx)
            lt_values = dataset.retrieve_latent_values(lt_mx.numpy())[:, 1:].to(device)

            x = dataset[indices].to(device).type(torch.float)
            noise = noisenet(lt_values.type(torch.float)).to(device)
            x_mod = x + epsilon * noise

            ent_recon, _, _ = ent_model(x)
            dis_recon, _, _ = dis_model(x)

            ent_recon = ent_recon.view(-1, 64, 64)
            dis_recon = dis_recon.view(-1, 64, 64)

            ent_recon_loss = ent_loss_func(ent_recon, x_mod)
            dis_recon_loss = dis_loss_func(dis_recon, x_mod)
            noisenet_loss = ent_recon_loss - dis_recon_loss

            ent_losses.append(ent_recon_loss * len(lt_mx))
            dis_losses.append(dis_recon_loss * len(lt_mx))
            noisenet_losses.append(noisenet_loss * len(lt_mx))
            total_samples += len(lt_mx)

        avg_ent_loss = sum(ent_losses) / total_samples
        avg_dis_loss = sum(dis_losses) / total_samples
        avg_noisenet_loss = sum(noisenet_losses) / total_samples

        print(f"Val | Ent Rec Loss: {avg_ent_loss} | Dis Rec Loss: {avg_dis_loss} | Noisenet Loss: {avg_noisenet_loss}")

In [None]:
def train(epochs, epsilon, dataset, 
          train_loader, test_loader, load_previous, device,
          noisenet, noisenet_optimizer, noisenet_path,
          ent_model, ent_optimizer, entangled_model_path,
          dis_model, dis_optimizer, disentangled_model_path):
  
    if load_previous:
        epoch, noisenet, noisenet_optimizer = load_models_and_optimizers(noisenet, noisenet_optimizer, noisenet_path)
        _, ent_model, ent_optimizer = load_models_and_optimizers(ent_model, ent_optimizer, entangled_model_path)
        _, dis_model, dis_optimizer = load_models_and_optimizers(dis_model, dis_optimizer, disentangled_model_path)

    ent_loss_func = nn.MSELoss()
    dis_loss_func = nn.MSELoss()

    noisenet.to(device)
    ent_model.to(device)
    dis_model.to(device)
    
    noisenet.train()
    ent_model.train()
    dis_model.train()

    for i in tqdm(range(epochs)):

        ent_losses = []
        dis_losses = []
        noisenet_losses = []
        total_samples = 0 

        for lt_mx in tqdm(train_loader, leave=False):
            indices = dataset.latent_to_index(lt_mx)
            lt_values = dataset.retrieve_latent_values(lt_mx.numpy())[:, 1:].to(device)

            x = dataset[indices].to(device).type(torch.float)
            noise = noisenet(lt_values.type(torch.float)).to(device)
            x_mod = x + epsilon * noise

            ent_recon, _, _ = ent_model(x)
            dis_recon, _, _ = dis_model(x)

            ent_recon = ent_recon.view(-1, 64, 64)
            dis_recon = dis_recon.view(-1, 64, 64)

            # print(ent_recon)
            # print(dis_recon)

            ent_recon_loss = ent_loss_func(ent_recon, x_mod)
            dis_recon_loss = dis_loss_func(dis_recon, x_mod)
            noisenet_loss = ent_recon_loss - dis_recon_loss

            ent_recon_loss.backward(retain_graph=True)
            dis_recon_loss.backward(retain_graph=True)
            noisenet_loss.backward()

            ent_optimizer.step()
            ent_optimizer.zero_grad()
            
            dis_optimizer.step()
            dis_optimizer.zero_grad()
            
            noisenet_optimizer.step()
            noisenet_optimizer.zero_grad()

            ent_losses.append(ent_recon_loss * len(lt_mx))
            dis_losses.append(dis_recon_loss * len(lt_mx))
            noisenet_losses.append(noisenet_loss * len(lt_mx))
            total_samples += len(lt_mx) 

        avg_ent_loss = sum(ent_losses) / total_samples
        avg_dis_loss = sum(dis_losses) / total_samples
        avg_noisenet_loss = sum(noisenet_losses) / total_samples

        print(f"Epoch: {i+1} | Ent Rec Loss: {avg_ent_loss} | Dis Rec Loss: {avg_dis_loss} | Noisenet Loss: {avg_noisenet_loss}")

        if i % 10 == 1:
            test(dataset, test_loader, device,
                 noisenet, noisenet_optimizer,
                 ent_model, ent_optimizer, dis_model, dis_optimizer)
            save_model(epoch + i + 1, noisenet, noisenet_optimizer, noisenet_path)
            save_model(epoch + i + 1, ent_model, ent_optimizer, entangled_model_path)
            save_model(epoch + i + 1, dis_model, dis_optimizer, disentangled_model_path)
            


In [None]:
# Define paths
dataset_path = abs_path + "/datasets/dsprites.npz"
noisenet_path = abs_path + '/experiments/trained_models/noise_net.pth'
entangled_model_path = abs_path + '/experiments/trained_models/entangled_model.pth'
disentangled_model_path = abs_path + '/experiments/trained_models/disentangled_model.pth'

In [None]:
## Prepare dataset

dataset = CustomDSpritesDataset(load_dsprites(dataset_path, False))

latent_matrix = dataset.sample_latent(len(dataset))
# sample_latent_values = dataset.retrieve_latent_values(latent_matrix)[:, 1:]

data_train, data_test = train_test_random_split(latent_matrix, 0.8)

batch_size = 64
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=False)

In [None]:
# Define model and optimizers
noisenet = NoiseGeneratorNet()
noisenet_optimizer = optim.Adam(noisenet.parameters(), lr=3e-4)

ent_model, ent_optimizer = get_new_entangle_distangle_model_and_optimizer()
dis_model, dis_optimizer = get_new_entangle_distangle_model_and_optimizer()

In [None]:
# Define hyperparameters
epochs = 30
load_previous = True
epsilon = .1

In [None]:
train(epochs, epsilon, dataset, 
      train_loader, test_loader, load_previous, device,
      noisenet, noisenet_optimizer, noisenet_path,
      ent_model, ent_optimizer, entangled_model_path,
      dis_model, dis_optimizer, disentangled_model_path)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 1 | Ent Rec Loss: 0.010092332027852535 | Dis Rec Loss: 86.66987609863281 | Noisenet Loss: -86.65958404541016


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 2 | Ent Rec Loss: 0.008881114423274994 | Dis Rec Loss: 86.53910827636719 | Noisenet Loss: -86.53013610839844


HBox(children=(FloatProgress(value=0.0, max=2304.0), HTML(value='')))

Val | Ent Rec Loss: 0.00846796203404665 | Dis Rec Loss: 86.5679702758789 | Noisenet Loss: -86.5595474243164


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 3 | Ent Rec Loss: 0.00831396784633398 | Dis Rec Loss: 86.51904296875 | Noisenet Loss: -86.51072692871094


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 4 | Ent Rec Loss: 0.007945897988975048 | Dis Rec Loss: 86.5387191772461 | Noisenet Loss: -86.5309829711914


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 5 | Ent Rec Loss: 0.007668616250157356 | Dis Rec Loss: 86.53812408447266 | Noisenet Loss: -86.53035736083984


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 6 | Ent Rec Loss: 0.007530121598392725 | Dis Rec Loss: 86.53314208984375 | Noisenet Loss: -86.52555084228516


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 7 | Ent Rec Loss: 0.0073712849989533424 | Dis Rec Loss: 86.52690887451172 | Noisenet Loss: -86.5193099975586


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 8 | Ent Rec Loss: 0.007240653038024902 | Dis Rec Loss: 86.56037139892578 | Noisenet Loss: -86.55345153808594


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 9 | Ent Rec Loss: 0.007184520363807678 | Dis Rec Loss: 86.52892303466797 | Noisenet Loss: -86.52153015136719


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 10 | Ent Rec Loss: 0.00710141658782959 | Dis Rec Loss: 86.54950714111328 | Noisenet Loss: -86.54254913330078


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 11 | Ent Rec Loss: 0.007024945691227913 | Dis Rec Loss: 86.53284454345703 | Noisenet Loss: -86.52590942382812


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 12 | Ent Rec Loss: 0.0069568101316690445 | Dis Rec Loss: 86.55333709716797 | Noisenet Loss: -86.54650115966797


HBox(children=(FloatProgress(value=0.0, max=2304.0), HTML(value='')))

Val | Ent Rec Loss: 0.00698463711887598 | Dis Rec Loss: 86.55661010742188 | Noisenet Loss: -86.549560546875


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 13 | Ent Rec Loss: 0.006939350161701441 | Dis Rec Loss: 86.54370880126953 | Noisenet Loss: -86.53668975830078


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 14 | Ent Rec Loss: 0.006892199162393808 | Dis Rec Loss: 86.53593444824219 | Noisenet Loss: -86.52890014648438


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 15 | Ent Rec Loss: 0.006857556756585836 | Dis Rec Loss: 86.53009796142578 | Noisenet Loss: -86.52323150634766


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 16 | Ent Rec Loss: 0.00680677080526948 | Dis Rec Loss: 86.5538101196289 | Noisenet Loss: -86.54694366455078


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 17 | Ent Rec Loss: 0.006766307633370161 | Dis Rec Loss: 86.51815032958984 | Noisenet Loss: -86.51153564453125


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 18 | Ent Rec Loss: 0.006742486730217934 | Dis Rec Loss: 86.5492172241211 | Noisenet Loss: -86.542724609375


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 19 | Ent Rec Loss: 0.006713216193020344 | Dis Rec Loss: 86.529052734375 | Noisenet Loss: -86.5224609375


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 20 | Ent Rec Loss: 0.006680119317024946 | Dis Rec Loss: 86.55083465576172 | Noisenet Loss: -86.5439682006836


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 21 | Ent Rec Loss: 0.0066680521704256535 | Dis Rec Loss: 86.54971313476562 | Noisenet Loss: -86.5431137084961


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 22 | Ent Rec Loss: 0.0066298674792051315 | Dis Rec Loss: 86.54182434082031 | Noisenet Loss: -86.5352783203125


HBox(children=(FloatProgress(value=0.0, max=2304.0), HTML(value='')))

Val | Ent Rec Loss: 0.006538779009133577 | Dis Rec Loss: 86.49456787109375 | Noisenet Loss: -86.4880599975586


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 23 | Ent Rec Loss: 0.0066008856520056725 | Dis Rec Loss: 86.53150939941406 | Noisenet Loss: -86.52497100830078


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 24 | Ent Rec Loss: 0.006582634523510933 | Dis Rec Loss: 86.5346450805664 | Noisenet Loss: -86.52802276611328


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 25 | Ent Rec Loss: 0.0065825022757053375 | Dis Rec Loss: 86.5115966796875 | Noisenet Loss: -86.50492095947266


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 26 | Ent Rec Loss: 0.006585023831576109 | Dis Rec Loss: 86.55160522460938 | Noisenet Loss: -86.54488372802734


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 27 | Ent Rec Loss: 0.006570361088961363 | Dis Rec Loss: 86.54497528076172 | Noisenet Loss: -86.53843688964844


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 28 | Ent Rec Loss: 0.006534775253385305 | Dis Rec Loss: 86.55387115478516 | Noisenet Loss: -86.54741668701172


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 29 | Ent Rec Loss: 0.0065191639587283134 | Dis Rec Loss: 86.53340911865234 | Noisenet Loss: -86.52701568603516


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 30 | Ent Rec Loss: 0.006491084583103657 | Dis Rec Loss: 86.55618286132812 | Noisenet Loss: -86.5495376586914


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 31 | Ent Rec Loss: 0.006475829053670168 | Dis Rec Loss: 86.5018081665039 | Noisenet Loss: -86.49529266357422


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

Epoch: 32 | Ent Rec Loss: 0.006460115779191256 | Dis Rec Loss: 86.5286865234375 | Noisenet Loss: -86.52244567871094


HBox(children=(FloatProgress(value=0.0, max=2304.0), HTML(value='')))

Val | Ent Rec Loss: 0.006363532040268183 | Dis Rec Loss: 86.598388671875 | Noisenet Loss: -86.59211730957031


HBox(children=(FloatProgress(value=0.0, max=9216.0), HTML(value='')))

In [None]:
# _, ent_model, ent_optimizer = load_models_and_optimizers(ent_model, ent_optimizer, entangled_model_path)
# for params in ent_model.decoder.parameters():
#     print(params)

In [None]:
# _, dis_model, dis_optimizer = load_models_and_optimizers(dis_model, dis_optimizer, disentangled_model_path)
# for params in dis_model.decoder.parameters():
#     print(params)

In [4]:
bce = torch.nn.BCELoss()

In [26]:
bce(torch.Tensor([.3, .84]), torch.Tensor([.3, .85]))

tensor(0.5170)

In [6]:
torch.Tensor([1, 1])

tensor([1., 1.])

In [10]:
torch.empty(3).random_(2)

tensor([0., 1., 1.])