In [1]:
# Mount Google Drive (Colab only) so the project folder used by the next cell is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

# Run from the project folder on Drive; later cells derive ROOT (log/checkpoint paths) from this cwd.
PROJECT_DIR = "/content/drive/MyDrive/Colab Notebooks/project2"
os.chdir(PROJECT_DIR)

## Import Modules

In [3]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import math

from math import exp
from tqdm.auto import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from torchvision.utils import make_grid

## Configuration

In [4]:
ROOT = os.getcwd()


class Config:
    """Hyper-parameters and output locations for the training run (read as class attributes)."""

    seed = 42
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    # Length of the run and mini-batch size.
    max_iter = 1500000
    batch_size = 64

    # KL set-point schedule used by `train`: C grows by C_step_value periodically, capped at C_max.
    C_max = 16
    C_step_value = 0.15

    # Adam optimiser settings.
    lr = 1e-4
    beta1 = 0.5
    beta2 = 0.999

    s_beta1 = 0.5  # NOTE(review): not referenced by any cell visible here

    log_dir = os.path.join(ROOT, "log")
    ckpt_dir = os.path.join(ROOT, "checkpoint")
    save_step = 250000  # write a checkpoint every this many iterations


# exist_ok=True avoids the check-then-create race and keeps the cell idempotent on re-run.
os.makedirs(Config.log_dir, exist_ok=True)
os.makedirs(Config.ckpt_dir, exist_ok=True)

print(Config.device)

cuda


In [5]:
def seed_everything(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy and PyTorch (CPU and all GPUs).
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started after this point; the running interpreter's
    # hash seed was fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose algorithms by timing them, which may differ between
    # runs and defeats `deterministic = True`; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before the data loader and the model are created in later cells.
seed_everything(Config.seed)

## Dataset & Helper Code (disentangling-vae)

In [6]:
# Fetch the helper repo that provides `utils.datasets` and `utils.viz_helpers`.
# On re-runs git reports "destination path ... already exists" and the existing checkout is used.
!git clone https://github.com/YannDubs/disentangling-vae.git

fatal: destination path 'disentangling-vae' already exists and is not an empty directory.


In [7]:
# Work inside the cloned repo so its `utils` package can be imported.
# NOTE(review): presumably relies on the cwd being on sys.path (Jupyter default) — confirm.
# Config.log_dir / Config.ckpt_dir were already built from the absolute ROOT, so this chdir
# does not move them.
os.chdir(os.path.join(ROOT, "disentangling-vae"))

from utils import datasets
from utils.viz_helpers import get_samples

# NOTE(review): shuffling/other defaults of get_dataloaders live in the helper repo — confirm.
train_loader = datasets.get_dataloaders("dsprites", batch_size=Config.batch_size)

## Model

In [8]:
class View(nn.Module):
    """Reshape layer: ``forward`` returns ``tensor.view(size)`` so it can sit inside nn.Sequential."""

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, tensor):
        target_shape = self.size
        return tensor.view(target_shape)

class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a diagonal Gaussian q(z|x).

    The input must be ``nc`` x 64 x 64. Four stride-2 convolutions take 64 -> 4 and the
    final 4x4 convolution takes 4 -> 1, which the ``View((-1, 256))`` layer relies on.
    """

    def __init__(self, latent_dim, nc):
        super().__init__()
        self.latent_dim = latent_dim
        self.nc = nc
        # Keep this order: it fixes the state_dict keys (net.0 ... net.11) used by checkpoints.
        layers = [
            nn.Conv2d(nc, 32, 4, 2, 1),      # 64 -> 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),      # 32 -> 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),      # 16 -> 8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),      # 8 -> 4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),        # 4 -> 1
            nn.ReLU(True),
            View((-1, 256 * 1 * 1)),
            nn.Linear(256, latent_dim * 2),  # [mu | logvar]
        ]
        self.net = nn.Sequential(*layers)
        self.weight_init()

    def weight_init(self):
        """Re-draw Conv2d/Linear weights with Kaiming-normal (ReLU gain); biases keep their defaults."""
        for container in self._modules.values():
            for layer in container:
                if isinstance(layer, (nn.Linear, nn.Conv2d)):
                    init.kaiming_normal_(layer.weight, nonlinearity='relu')

    def forward(self, x):
        """Return ``(mu, logvar)``: the first and second halves of the network output."""
        stats = self.net(x)
        k = self.latent_dim
        return stats[:, :k], stats[:, k:]

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code to an ``nc`` x 64 x 64 output.

    The last layer has no activation, so the output is unbounded (the training loss
    interprets it as logits).
    """

    def __init__(self, latent_dim, nc):
        super().__init__()
        self.latent_dim = latent_dim
        self.nc = nc
        # Keep this order: it fixes the state_dict keys (net.0 ... net.11) used by checkpoints.
        layers = [
            nn.Linear(latent_dim, 256),
            View((-1, 256, 1, 1)),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),       # 1 -> 4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),  # 4 -> 8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # 8 -> 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),  # 16 -> 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # 32 -> 64, no output activation
        ]
        self.net = nn.Sequential(*layers)

        self.weight_init()

    def weight_init(self):
        """Re-draw Linear/ConvTranspose2d weights with Kaiming-normal (ReLU gain); biases untouched."""
        for container in self._modules.values():
            for layer in container:
                if isinstance(layer, (nn.Linear, nn.ConvTranspose2d)):
                    init.kaiming_normal_(layer.weight, nonlinearity='relu')

    def forward(self, z):
        """Decode latent codes ``z`` into the raw output of the last transposed convolution."""
        return self.net(z)


class BetaVAE(nn.Module):
    """Beta-VAE for single-channel 64x64 images: encoder -> reparameterised z -> decoder.

    ``forward`` returns ``(x_recon, mu, logvar)``. ``x_recon`` is the decoder's raw output
    (its last layer has no activation), which the training loss treats as logits.
    """

    def __init__(self, latent_dim=10):
        super(BetaVAE, self).__init__()
        self.latent_dim = latent_dim
        self.img_size = (1, 64, 64)
        self.num_pixels = self.img_size[1] * self.img_size[2]

        self.encoder = Encoder(latent_dim, self.img_size[0])
        self.decoder = Decoder(latent_dim, self.img_size[0])

    def forward(self, x):
        """Encode ``x``, sample z ~ q(z|x) and decode it.

        NOTE(review): z is sampled in eval mode as well, so reconstructions are stochastic.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)

        return x_recon, mu, logvar

    def reparameterize(self, mu, logvar):
        """Return ``mu + sigma * eps`` with ``sigma = exp(logvar / 2)`` and ``eps ~ N(0, I)``.

        Gradients flow to ``mu`` and ``logvar``; ``eps`` carries none.
        """
        std = logvar.div(2).exp()
        # randn_like replaces the deprecated Variable(std.data.new(...).normal_()): same
        # distribution, dtype and device, without going through `.data`.
        eps = torch.randn_like(std)
        return mu + std * eps

In [9]:
model = BetaVAE().to(Config.device)
# Adam with the (beta1, beta2) pair from Config.
adam_betas = (Config.beta1, Config.beta2)
optimizer = optim.Adam(model.parameters(), lr=Config.lr, betas=adam_betas)
print(model)

BetaVAE(
  (encoder): Encoder(
    (net): Sequential(
      (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(64, 256, kernel_size=(4, 4), stride=(1, 1))
      (9): ReLU(inplace=True)
      (10): View()
      (11): Linear(in_features=256, out_features=20, bias=True)
    )
  )
  (decoder): Decoder(
    (net): Sequential(
      (0): Linear(in_features=10, out_features=256, bias=True)
      (1): View()
      (2): ReLU(inplace=True)
      (3): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(1, 1))
      (4): ReLU(inplace=True)
      (5): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding

## Objective

In [10]:
class PIDControl():
    """PI controller that sets the KL weight ``beta`` from the gap to a target KL.

    Despite the name only proportional and integral terms are used: ``Kd`` is accepted by
    ``pid`` but ignored, and ``e_k1`` is initialised but never updated.
    """

    def __init__(self):
        self.I_k1 = 0.0  # integral term from the previous step
        self.W_k1 = 1.0  # un-clamped controller output from the previous step
        self.e_k1 = 0.0  # previous error (unused: there is no derivative term)

    def _Kp_fun(self, Err, scale=1):
        """Return ``1 / (1 + scale * exp(Err))``.

        ``math.exp`` raises OverflowError for very large ``Err``; in that limit the fraction
        tends to 0 (or stays 1 when ``scale == 0``), so that limit is returned instead.
        """
        scale = float(scale)
        try:
            return 1.0 / (1.0 + scale * exp(Err))
        except OverflowError:
            return 1.0 if scale == 0 else 0.0

    def pid(self, exp_KL, kl_divergence, Kp=0.01, Ki=-0.001, Kd=0.01):
        """Advance the controller one step and return the new ``beta``.

        Args:
            exp_KL: target (set-point) KL value.
            kl_divergence: observed KL value for the current batch.
            Kp: proportional gain.
            Ki: integral gain. With the default negative value, a KL above the target
                (negative error) increases the integral term and therefore ``beta``.
            Kd: derivative gain; accepted for API compatibility but not used.

        Returns:
            float: ``Kp * _Kp_fun(error) + 1 + I``, clamped below at 1.0.
        """
        error_k = exp_KL - kl_divergence
        # Proportional part; the +1 offset makes it lie in [1, 1 + Kp] for Kp >= 0.
        Pk = Kp * self._Kp_fun(error_k) + 1
        Ik = self.I_k1 + Ki * error_k

        # Anti-windup: freeze the integrator while the previous output was below the lower bound.
        if self.W_k1 < 1:
            Ik = self.I_k1

        Wk = Pk + Ik
        # Store the un-clamped output so the anti-windup check above can see saturation.
        self.W_k1 = Wk
        self.I_k1 = Ik

        # Lower bound on beta (always returned as a float).
        return max(Wk, 1.0)

In [11]:
def reconstruction_loss(x, x_recon):
    """Bernoulli reconstruction loss: BCE-with-logits summed over all elements, per sample.

    Args:
        x: target images (BCE targets, expected in [0, 1]); the first dimension is the batch.
        x_recon: decoder logits with the same shape as ``x``.

    Returns:
        0-dim tensor: total BCE over every element divided by the batch size.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    # `size_average=False` is deprecated; reduction='sum' is the equivalent spelling.
    recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum').div(batch_size)

    return recon_loss

In [12]:
def kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) for a diagonal Gaussian posterior.

    Args:
        mu, logvar: (batch, latent) tensors; 4-D inputs of shape (batch, latent, 1, 1) are flattened.

    Returns:
        (total_kld, dimension_wise_kld, mean_kld):
            total_kld: shape (1,), per-sample KL summed over latents, averaged over the batch.
            dimension_wise_kld: shape (latent,), batch-mean KL of every latent unit.
            mean_kld: shape (1,), per-sample KL averaged over latents, then over the batch.
    """
    n = mu.size(0)
    assert n != 0

    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    per_unit = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return (
        per_unit.sum(1).mean(0, True),
        per_unit.mean(0),
        per_unit.mean(1).mean(0, True),
    )

## Training

In [None]:
def train(model, data_loader, optimizer, max_iter, C_max, C_step_value, device, log_dir, ckpt_dir, save_step,
          log_step=20, C_init=0.5, C_period=5000):
    """Train ``model`` with a KL weight ``beta`` chosen at every step by a PI controller.

    Per-batch loss: ``recon_loss + beta * total_kld``. The KL set-point ``C`` starts at
    ``C_init``, grows by ``C_step_value`` every ``C_period`` iterations and is capped at
    ``C_max``; ``PIDControl.pid(C, kl)`` turns the gap into ``beta``. The data loader is
    cycled until ``max_iter`` iterations have run.

    Args:
        model: module whose forward returns ``(x_recon_logits, mu, logvar)``.
        data_loader: re-iterable of ``(x, label)`` batches; labels are ignored.
        optimizer: optimiser over ``model``'s parameters.
        max_iter: total number of optimisation steps.
        C_max: upper bound of the KL set-point.
        C_step_value: set-point increment applied every ``C_period`` iterations.
        device: device the batches are moved to.
        log_dir: directory for ``linear_log.csv``.
        ckpt_dir: directory for ``linear_step_*.pt`` checkpoints.
        save_step: save a checkpoint (and flush the CSV log) every this many iterations.
        log_step: record metrics every this many iterations (default 20).
        C_init: initial KL set-point (default 0.5).
        C_period: iterations between set-point increments (default 5000).

    Side effects:
        Writes ``linear_step_{iter}.pt`` every ``save_step`` iterations and
        ``linear_step_final.pt`` at the end; rewrites ``linear_log.csv`` at every checkpoint
        and at the end.

    Raises:
        ValueError: if ``data_loader`` yields no batches (training could never progress).
    """
    model.train()

    log_columns = ["iter", "recons_loss", "total_kld", "C_t", "beta_t"]
    log_path = os.path.join(log_dir, "linear_log.csv")
    records = []

    def _flush_logs():
        # Write every record collected so far in one go. The old DataFrame.append (removed in
        # pandas 2.0) plus a CSV rewrite every log_step iterations was quadratic in run length.
        pd.DataFrame(records, columns=log_columns).to_csv(log_path)

    pid = PIDControl()
    C = C_init
    global_iter = 0

    pbar = tqdm(total=max_iter)
    try:
        while global_iter < max_iter:
            n_batches = 0
            for x, _ in data_loader:
                n_batches += 1
                global_iter += 1
                pbar.update(1)

                x = x.to(device)
                x_recon, mu, logvar = model(x)
                recon_loss = reconstruction_loss(x, x_recon)
                total_kld, _, _ = kl_divergence(mu, logvar)

                # Raise the KL set-point on schedule, never above C_max.
                if global_iter % C_period == 0:
                    C += C_step_value
                C = min(C, C_max)
                # The controller sees the batch KL as a plain float (no gradient through beta).
                beta = pid.pid(C, total_kld.item())

                beta_vae_loss = recon_loss + beta * total_kld

                optimizer.zero_grad()
                beta_vae_loss.backward()
                optimizer.step()

                if global_iter % log_step == 0:
                    records.append({
                        "iter": global_iter,
                        "recons_loss": recon_loss.item(),
                        "total_kld": total_kld.item(),
                        "C_t": C,
                        "beta_t": beta,
                    })

                if global_iter % save_step == 0:
                    ckpt = os.path.join(ckpt_dir, "linear_step_{}.pt".format(global_iter))
                    torch.save(model.state_dict(), ckpt)
                    _flush_logs()
                    pbar.write('Saved checkpoint(iter:{})'.format(global_iter))

                if global_iter >= max_iter:
                    break

            if n_batches == 0:
                raise ValueError("data_loader yielded no batches; training cannot progress")

        ckpt = os.path.join(ckpt_dir, "linear_step_final.pt")
        torch.save(model.state_dict(), ckpt)
        _flush_logs()
        pbar.write('Saved checkpoint(iter:final)')

        pbar.write("[Training Finished]")
    finally:
        # Always release the progress bar, even if a step raises.
        pbar.close()

In [None]:
train(
    model,
    train_loader,
    optimizer,
    max_iter=Config.max_iter,
    C_max=Config.C_max,
    C_step_value=Config.C_step_value,
    device=Config.device,
    log_dir=Config.log_dir,
    ckpt_dir=Config.ckpt_dir,
    save_step=Config.save_step,
)

  0%|          | 0/1500000 [00:00<?, ?it/s]



Saved checkpoint(iter:250000)
Saved checkpoint(iter:500000)
Saved checkpoint(iter:750000)
Saved checkpoint(iter:1000000)
Saved checkpoint(iter:1250000)
Saved checkpoint(iter:1500000)
Saved checkpoint(iter:final)
[Training Finished]


## MIG Score (Mutual Information Gap)

In [None]:
def logsumexp(value, dim=None, keepdim=False):
    """Numerically stable ``value.exp().sum(dim, keepdim).log()``.

    Args:
        value: input tensor.
        dim: dimension to reduce. ``None`` reduces over every element and returns a 0-dim
            tensor; ``keepdim`` is ignored in that case.
        keepdim: keep the reduced dimension with size 1 (only used when ``dim`` is given).

    Delegates to ``torch.logsumexp``, which subtracts the max internally and returns ``-inf``
    (not NaN) when every reduced entry is ``-inf``.
    """
    if dim is None:
        return torch.logsumexp(value.reshape(-1), dim=0)
    return torch.logsumexp(value, dim=dim, keepdim=keepdim)

class Normal(nn.Module):
    """Factorised Gaussian parameterised by (mu, log sigma) stacked on the last axis.

    Without explicit params the distribution is N(mu, sigma^2) with the constructor values.
    Params tensors have shape ``(..., 2)`` with ``[..., 0] = mu`` and ``[..., 1] = log sigma``.
    """

    def __init__(self, mu=0, sigma=1):
        super().__init__()
        # log(2 * pi): the Gaussian normalising constant in log space.
        self.normalization = torch.Tensor([np.log(2 * np.pi)])

        self.mu = torch.Tensor([mu])
        self.logsigma = torch.Tensor([math.log(sigma)])

    def _check_inputs(self, size, mu_logsigma):
        """Resolve ``(mu, logsigma)`` from ``size``, ``mu_logsigma`` or both.

        With only ``size``, the module's own parameters are expanded to it; with params,
        they are split on the last axis (and expanded when ``size`` is also given).
        Raises ValueError when neither is provided.
        """
        if mu_logsigma is None:
            if size is None:
                raise ValueError(
                    'Either one of size or params should be provided.')
            return self.mu.expand(size), self.logsigma.expand(size)
        if size is None:
            return mu_logsigma.select(-1, 0), mu_logsigma.select(-1, 1)
        return (mu_logsigma.select(-1, 0).expand(size),
                mu_logsigma.select(-1, 1).expand(size))

    def sample(self, size=None, params=None):
        """Reparameterised draw ``eps * sigma + mu``; eps is drawn on CPU, then cast to mu's type."""
        loc, log_scale = self._check_inputs(size, params)
        eps = torch.randn(loc.size()).type_as(loc)
        return eps * torch.exp(log_scale) + loc

    def log_density(self, sample, params=None):
        """Element-wise log N(sample; mu, sigma^2), using ``params`` if given, else the module's own."""
        if params is None:
            loc, log_scale = self._check_inputs(sample.size(), None)
            loc = loc.type_as(sample)
            log_scale = log_scale.type_as(sample)
        else:
            loc, log_scale = self._check_inputs(None, params)

        log_two_pi = self.normalization.type_as(sample)
        standardized = (sample - loc) * torch.exp(-log_scale)
        return -0.5 * (standardized * standardized + 2 * log_scale + log_two_pi)

    def NLL(self, params, sample_params=None):
        """Analytically computes
            E_N(mu_2,sigma_2^2) [ - log N(mu_1, sigma_1^2) ]
        with (mu_1, sigma_1) from ``params`` and (mu_2, sigma_2) from ``sample_params``.
        Without ``sample_params`` this is the entropy of N(mu_1, sigma_1^2).
        """
        mu, logsigma = self._check_inputs(None, params)
        if sample_params is None:
            sample_mu, sample_logsigma = mu, logsigma
        else:
            sample_mu, sample_logsigma = self._check_inputs(None, sample_params)

        c = self.normalization.type_as(sample_mu)
        mean_term = logsigma.mul(-2).exp() * (sample_mu - mu).pow(2)
        var_ratio = torch.exp(sample_logsigma.mul(2) - logsigma.mul(2))
        nll = mean_term + var_ratio + 2 * logsigma + c
        return nll.mul(0.5)

    def kld(self, params):
        """Element-wise KL(q || N(0, 1)) where q is the Gaussian described by ``params``."""
        mu, logsigma = self._check_inputs(None, params)
        return (logsigma.mul(2).add(1) - mu.pow(2) - logsigma.exp().pow(2)).mul(-0.5)

    def get_params(self):
        """Concatenate the module's own ``mu`` and ``logsigma`` into one 1-D tensor."""
        return torch.cat([self.mu, self.logsigma])

    @property
    def nparams(self):
        return 2

    @property
    def ndim(self):
        return 1

    @property
    def is_reparameterizable(self):
        return True

    def __repr__(self):
        loc = self.mu.data[0]
        scale = self.logsigma.exp().data[0]
        return self.__class__.__name__ + ' ({:.3f}, {:.3f})'.format(loc, scale)

def estimate_entropies(qz_samples, qz_params, q_dist=Normal(), n_samples=10000, weights=None):
    """Estimate, per latent dimension, the entropy of the aggregate posterior.

    Computes
        E_{p(x)} E_{q(z_j|x)} [-log q(z_j)],   with q(z_j) = 1/N sum_n q(z_j|x_n),
    via the stable form ``-log q(z_j) = log N - logsumexp_n log q(z_j|x_n)``.
    Assumes the samples come from q(z|x) for *all* x in the dataset and that q(z|x) is
    factorial, q(z|x) = prod_j q(z_j|x).

    Args:
        qz_samples: (K, N) samples, one column per data point.
        qz_params: (N, K, nparams) parameters of q(z|x_n).
        q_dist: distribution object providing ``nparams`` and ``log_density``.
        n_samples: number of sample columns used for the Monte-Carlo estimate.
        weights: optional (N,) non-negative weights over data points. When given, columns
            are drawn with replacement proportionally to them and q(z) is the weighted mixture.

    Returns:
        (K,) tensor of entropy estimates on the same device as ``qz_samples``.
    """
    device = qz_samples.device

    # Monte-Carlo subset of the sample columns (previously hard-coded to .cuda()).
    if weights is None:
        keep = torch.randperm(qz_samples.size(1))[:n_samples].to(device)
    else:
        keep = torch.multinomial(weights, n_samples, replacement=True)
    qz_samples = qz_samples.index_select(1, keep)

    K, S = qz_samples.size()
    N, _, nparams = qz_params.size()
    assert nparams == q_dist.nparams
    assert K == qz_params.size(1)

    # Log mixture weights: uniform 1/N, or normalised user weights broadcast over (K, chunk).
    if weights is None:
        log_weights = -math.log(N)
    else:
        log_weights = torch.log(weights.view(N, 1, 1) / weights.sum())

    entropies = torch.zeros(K, device=device)

    # Evaluate log q(z_i | x_n) against all N posteriors, 10 samples at a time, so the
    # materialised tensor stays (N, K, 10).
    k = 0
    while k < S:
        batch_size = min(10, S - k)
        logqz_i = q_dist.log_density(
            qz_samples.view(1, K, S).expand(N, K, S)[:, :, k:k + batch_size],
            qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)[:, :, k:k + batch_size])
        k += batch_size

        # -log q(z_i), summed over this chunk of samples.
        entropies += -logsumexp(logqz_i + log_weights, dim=0, keepdim=False).data.sum(1)

    entropies /= S

    return entropies


def MIG(mi_normed):
    """Mutual Information Gap: mean over factors of (largest - second largest) normalised MI.

    Args:
        mi_normed: (num_factors, K) normalised mutual information, expected sorted in
            descending order along dim 1.

    Returns:
        0-dim tensor.
    """
    gaps = mi_normed[:, 0] - mi_normed[:, 1]
    return gaps.mean()

def compute_metric_shapes(marginal_entropies, cond_entropies, factor_sizes=(6, 40, 32, 32)):
    """Compute MIG from entropy estimates.

    Args:
        marginal_entropies: (K,) estimates of H(z_j).
        cond_entropies: (F, K) estimates of H(z_j | v_k), one row per ground-truth factor.
        factor_sizes: number of values of each factor, in the row order of ``cond_entropies``.
            The default (6, 40, 32, 32) matches the rows built in the MIG cell: scale,
            orientation, pos x, pos y. Mutual information is normalised by log(size),
            the entropy of a uniformly distributed factor.

    Returns:
        0-dim tensor: the MIG score.

    Raises:
        ValueError: if ``factor_sizes`` does not have one entry per row of ``cond_entropies``.
    """
    mutual_infos = marginal_entropies[None] - cond_entropies
    # Sort each factor's MI over latents (descending) and clip negative estimates to 0.
    mutual_infos = torch.sort(mutual_infos, dim=1, descending=True)[0].clamp(min=0)

    factor_entropies = torch.tensor(
        factor_sizes, dtype=mutual_infos.dtype, device=mutual_infos.device).log()
    if factor_entropies.numel() != mutual_infos.size(0):
        raise ValueError("expected {} factor sizes, got {}".format(
            mutual_infos.size(0), factor_entropies.numel()))

    mi_normed = mutual_infos / factor_entropies[:, None]
    metric = MIG(mi_normed)
    return metric

In [None]:
# Evaluate MIG of the final checkpoint on the full, unshuffled dSprites dataset.
model = BetaVAE().to(Config.device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime too.
model.load_state_dict(torch.load(os.path.join(Config.ckpt_dir, "linear_step_final.pt"),
                                 map_location=Config.device))
model.eval()

test_loader = datasets.get_dataloaders("dsprites", batch_size=Config.batch_size, shuffle=False)
q_dist = Normal()
N = len(test_loader.dataset)
K = model.latent_dim
nparams = q_dist.nparams

# The reshape below assumes the unshuffled dataset enumerates the ground-truth factors in
# this order: dim 0 (size 3, presumably shape; not scored), scale (6), orientation (40),
# pos x (32), pos y (32).
DSPRITES_FACTOR_SIZES = (3, 6, 40, 32, 32)
assert N == int(np.prod(DSPRITES_FACTOR_SIZES)), "unexpected dSprites size: {}".format(N)

print('Computing q(z|x) distributions.')
qz_params = torch.empty(N, K, nparams)

n = 0
with torch.no_grad():  # inference only: no autograd graph per batch
    for i, (x, _) in enumerate(test_loader):
        print(i + 1, len(test_loader), end='\r')
        batch_size = x.size(0)
        # encoder.net emits [mu | logvar] per row: (batch, 2, K) -> (batch, K, 2).
        stats = model.encoder.net(x.to(Config.device))
        qz_params[n:n + batch_size] = stats.view(batch_size, nparams, K).transpose(1, 2).cpu()
        n += batch_size

qz_params = qz_params.view(*DSPRITES_FACTOR_SIZES, K, nparams).to(Config.device)
# Normal expects (mu, log sigma) but the encoder predicts log variance: log sigma = logvar / 2.
qz_params[..., 1] = qz_params[..., 1] / 2
qz_samples = q_dist.sample(params=qz_params)

print('Estimating marginal entropies.')

marginal_entropies = estimate_entropies(
    qz_samples.view(N, K).transpose(0, 1),
    qz_params.view(N, K, nparams),
    q_dist).cpu()


def conditional_entropies(factor_dim):
    """Average, over the values of one factor, of the per-latent entropies H(z_j | v = value).

    ``factor_dim`` indexes DSPRITES_FACTOR_SIZES and the matching leading dim of
    ``qz_samples`` / ``qz_params``. Returns a (K,) CPU tensor.
    """
    factor_size = DSPRITES_FACTOR_SIZES[factor_dim]
    rows = N // factor_size
    avg = torch.zeros(K)
    for value in range(factor_size):
        samples_v = qz_samples.select(factor_dim, value).contiguous()
        params_v = qz_params.select(factor_dim, value).contiguous()
        h = estimate_entropies(
            samples_v.view(rows, K).transpose(0, 1),
            params_v.view(rows, K, nparams),
            q_dist)
        avg += h.cpu() / factor_size
    return avg


# (label, dim in DSPRITES_FACTOR_SIZES); row order must match the factor sizes
# (6, 40, 32, 32) used by compute_metric_shapes.
SCORED_FACTORS = [("scale", 1), ("orientation", 2), ("pos x", 3), ("pos y", 4)]
cond_entropies = torch.zeros(len(SCORED_FACTORS), K)
for row, (label, factor_dim) in enumerate(SCORED_FACTORS):
    print('Estimating conditional entropies for {}.'.format(label))
    cond_entropies[row] = conditional_entropies(factor_dim)

metric = compute_metric_shapes(marginal_entropies, cond_entropies)

print()
print('MIG: {}'.format(metric.cpu().numpy()))

Computing q(z|x) distributions.
Estimating marginal entropies.
Estimating conditional entropies for scale.
Estimating conditional entropies for orientation.
Estimating conditional entropies for pos x.
Estimating conditional entropies for pox y.

MIG: 0.5279951095581055


## Reconstruction & Latent Traversal

In [13]:
def reconstruct(model, upsample_factor, data, size):
    """Build an image grid: originals in the first half of the rows, reconstructions after.

    Args:
        model: BetaVAE-like module whose forward returns ``(logits, mu, logvar)`` and which
            exposes ``img_size``.
        upsample_factor: ``scale_factor`` passed to ``F.interpolate``.
        data: batch of images; the first ``size[0] // 2 * size[1]`` are used.
        size: ``(rows, cols)`` of the grid; half of the rows show originals.

    Returns:
        (H, W, C) uint8 numpy array.
    """
    n_samples = size[0] // 2 * size[1]

    with torch.no_grad():
        originals = data.to(Config.device)[:n_samples, ...]
        recon_logits, _, _ = model(originals)
        # The decoder outputs logits (it is trained with BCE-with-logits); map them to
        # [0, 1] pixel intensities before plotting.
        recs = torch.sigmoid(recon_logits)

    originals = originals.cpu()
    recs = recs.view(-1, *model.img_size).cpu()

    to_plot = torch.cat([originals, recs])

    to_plot = F.interpolate(to_plot, scale_factor=upsample_factor)
    # nrow=size[1] keeps originals and reconstructions on separate rows for any grid width.
    grid = make_grid(to_plot, nrow=size[1])
    img_grid = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    img_grid = img_grid.to('cpu', torch.uint8).numpy()

    return img_grid

def traverse(model, upsample_factor, n_per_latent, n_latents, max_traversal):
    """Decode single-latent traversals into an image grid, one row per latent.

    Each row varies one latent linearly over ``[-max_traversal, max_traversal]`` in
    ``n_per_latent`` steps while all other latents stay at 0.

    Args:
        model: module exposing ``decoder`` that maps (B, n_latents) codes to image logits.
        upsample_factor: ``scale_factor`` passed to ``F.interpolate``.
        n_per_latent: number of steps per latent (grid columns).
        n_latents: number of latents to traverse (grid rows).
        max_traversal: half-width of the traversal range.

    Returns:
        (H, W, C) uint8 numpy array.
    """
    latent_samples = torch.cat(
        [traverse_line(dim, n_latents, n_per_latent, max_traversal) for dim in range(n_latents)],
        dim=0,
    ).to(Config.device)

    with torch.no_grad():
        # The decoder outputs logits; the sigmoid maps them to [0, 1] pixel intensities.
        decoded_traversal = torch.sigmoid(model.decoder(latent_samples)).cpu()

    to_plot = F.interpolate(decoded_traversal, scale_factor=upsample_factor)
    # nrow=n_per_latent puts each latent's traversal on its own row.
    grid = make_grid(to_plot, nrow=n_per_latent)
    img_grid = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    img_grid = img_grid.to('cpu', torch.uint8).numpy()

    return img_grid
        
def traverse_line(idx, latent_dim, n_samples, max_traversal):
    """Latent codes that sweep dimension ``idx`` linearly over [-max_traversal, max_traversal].

    Returns:
        (n_samples, latent_dim) float tensor, zero everywhere except column ``idx``.
    """
    codes = torch.zeros(n_samples, latent_dim)
    codes[:, idx] = torch.linspace(-max_traversal, max_traversal, steps=n_samples)
    return codes

In [14]:
model = BetaVAE().to(Config.device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime too.
model.load_state_dict(torch.load(os.path.join(Config.ckpt_dir, "linear_step_final.pt"),
                                 map_location=Config.device))
model.eval()

samples = get_samples("dsprites", 6 * 7)
n_latents = model.latent_dim

with torch.no_grad():  # inference only
    # 2 x 8 grid: the first row shows 8 originals, the second their reconstructions.
    reconstructions = reconstruct(model, 1, samples[:2 * 8, ...], (2, 8))
    # 8 steps per latent over [-2, 2].
    traversals = traverse(model, 1, 8, n_latents, 2)

reconstructions = Image.fromarray(reconstructions)
traversals = Image.fromarray(traversals)

reconstructions_name = os.path.join(Config.log_dir, "linear_reconstruct.png")
traversals_name = os.path.join(Config.log_dir, "linear_traverse.png")

reconstructions.save(reconstructions_name)
traversals.save(traversals_name)

Selected idcs: [670487, 116739, 26225, 288389, 256787, 234053, 146316, 107473, 709570, 571858, 91161, 619176, 442417, 33326, 31244, 98246, 229258, 243962, 529903, 631262, 27824, 588508, 208496, 681453, 735392, 571412, 439898, 231148, 471029, 617889, 291704, 6814, 167414, 732052, 443143, 356778, 291369, 163032, 225772, 352944, 107175, 97251]
