In [147]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [148]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [149]:
import numpy as np
import torch
import os
import sys
import random

from sklearn import metrics

module_path = os.path.abspath(os.path.join('..'))
abs_path = "/content/drive/MyDrive/atml"
sys.path.append(abs_path+"/models")
sys.path.append(abs_path+"/train")
sys.path.append(abs_path+"/datasets")

from datasets import load_dsprites, CustomDSpritesDataset
from beta_vae import BetaVAEDSprites
from control_vae import ControlVAEDSprites
from factor_vae import FactorVAEDSprites

# Fix seed 
torch.manual_seed(2)
random.seed(2)
np.random.seed(2)

In [153]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [154]:
def discretize_matrix(matrix, num_bins):
    '''
    The matrix would get discretized along its rows
    '''
    discretized_matrix = np.zeros_like(matrix)
    for i in range(matrix.shape[0]):
        discretized_matrix[i: ] = np.digitize(matrix[i: ], 
                                              bins=np.histogram_bin_edges(matrix[i: ], num_bins)[:-1])
        
    return discretized_matrix

In [155]:
def load_vae_model(model_type, model_path, *args, **kwargs):
    model = None

    if model_type == "beta_vae":
        model = BetaVAEDSprites(*args, **kwargs)
        checkpoint = torch.load(abs_path + model_path)
        model.load_state_dict(checkpoint['state_dict'])
    elif model_type == "control_vae":
        # model = ControlVAEDSprites(*args, **kwargs)
        model = torch.load(abs_path + model_path)
    elif model_type == "factor_vae":
        # model = FactorVAEDSprites(*args, **kwargs)
        model = torch.load(abs_path + model_path)
    else:
      raise NotImplementedError()

    model.eval()

    return model

In [156]:
def get_factor_and_z_matrices(model, dataset, num_samples, batch_size):
    model = model.to(device)

    factor_matrix = None
    z_matrix = None
    
    i = 0
    while i < num_samples:
      ns = min(num_samples - i, batch_size)

      sampled_factors = dataset.sample_latent(ns)
      sampled_indices = dataset.latent_to_index(sampled_factors)
      sampled_x = dataset[sampled_indices].type(torch.float).to(device)
      sampled_z = model.get_latent_representation(sampled_x).cpu().detach().numpy()

      if factor_matrix is None:
        factor_matrix = sampled_factors[:]
        z_matrix = sampled_z[:]
      else:
        factor_matrix = np.vstack((factor_matrix, sampled_factors))
        z_matrix = np.vstack((z_matrix, sampled_z))

      i += ns

    factor_matrix = factor_matrix.transpose()
    z_matrix = z_matrix.transpose()

    return factor_matrix, z_matrix

In [157]:
def compute_mig(model, dataset, num_samples=100000, batch_size=1024):
    factor_matrix, z_matrix = get_factor_and_z_matrices(model, dataset, num_samples, batch_size)
    z_matrix = discretize_matrix(z_matrix, num_bins=20)

    factor_matrix = factor_matrix.astype('uint8')
    z_matrix = z_matrix.astype('uint8') 

    mutual_info_matrix = np.zeros((z_matrix.shape[0], factor_matrix.shape[0])) # z_dim * num of factors
    for i in range(z_matrix.shape[0]):
        for j in range(factor_matrix.shape[0]):
            mutual_info_matrix[i, j] = metrics.mutual_info_score(z_matrix[i, :], factor_matrix[j, :])

    sorted_mi_matrix = np.sort(mutual_info_matrix, axis=0)[::-1]

    factor_entropies = np.zeros(factor_matrix.shape[0])
    for i in range(len(factor_entropies)):
        factor_entropies[i] = metrics.mutual_info_score(factor_matrix[i, :], factor_matrix[i, :])

    factor_entropies[factor_entropies == 0] = np.nan
    mig = np.nanmean(np.divide((sorted_mi_matrix[0, :] - sorted_mi_matrix[1, :]), factor_entropies))

    return mig

In [158]:
dataset = CustomDSpritesDataset(load_dsprites(abs_path + "/datasets/dsprites.npz", False))

In [144]:
for beta in [1, 4]:
  model = load_vae_model('beta_vae', f'/experiments/trained_models/betavae_models/betavae_beta{beta}_e50_alldata_n.pth.tar')
  print(f"beta: {beta} | mig: {compute_mig(model, dataset):.3f}")

beta: 1 | mig: 0.053
beta: 4 | mig: 0.217


In [145]:
beta_controller_args = {
    'C' : 0.5,
    'C_max' : 12,
    'C_step_val' : 0.15,
    'C_step_period' : 5000,
    'Kp' : 0.01,
    'Ki' : -0.001,
    'Kd' : 0.0
}

for cmax in [8, 10, 12]:
  model = load_vae_model('control_vae', f'/experiments/trained_models/controlvae_epoch50_lr1e2_Cmax{cmax}.dat', beta_controller_args)
  print(f"cmax: {cmax} | mig: {compute_mig(model, dataset):.3f}")

cmax: 8 | mig: 0.266
cmax: 10 | mig: 0.297
cmax: 12 | mig: 0.294


In [160]:
for gamma, lrd in [(5, '0.0001'), (40, '5e-05')]:
  model = load_vae_model('factor_vae', f'/experiments/trained_models/factorvae_epochs{epochs}_gamma{gamma}_lrvae0.01_lrd{lrd}.dat', beta_controller_args)
  print(f"gamma: {gamma} | mig: {compute_mig(model, dataset):.3f}")

gamma: 5 | mig: 0.160
gamma: 40 | mig: 0.118
