<a href="https://colab.research.google.com/github/cuichenx/FactorVAE/blob/master/FactorVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


# Load Data

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
# from tqdm import tqdm

# load data
dataset = np.load("drive/My Drive/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
imgs = dataset['imgs']  # (737280, 64, 64)
num_imgs = imgs.shape[0]

# Define VAE and Discriminator

In [0]:


class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class FactorVAE(nn.Module):
    def __init__(self):
        super().__init__()

        # padding=1 so that dimension gets cut in half exactly
        # input (1, 64, 64)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),  # (32, 32, 32)
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # (32, 16, 16)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # (64, 8, 8)
            nn.ReLU(),
            nn.Conv2d(64, 64, 4, stride=2, padding=1),  # (64, 4, 4)
            nn.ReLU(),
            View((-1, 64*4*4)),
            nn.Linear(64*4*4, 128),  # no relu here?
            nn.Linear(128, 20),
        )

        self.decoder = nn.Sequential(
            nn.Linear(10, 128),  
            nn.ReLU(),
            nn.Linear(128, 64*4*4),
            nn.ReLU(),
            View((-1, 64, 4, 4)),  # (64, 4, 4)
            nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1),  # (64, 8, 8)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # (32, 16, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1),  # (32, 32, 32)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),  # (1, 64, 64)
            # nn.Sigmoid()
        )

    def forward(self, x):
        t = self.encoder(x)
        mu = t[..., :10]
        log_sigma2 = t[..., 10:]
        sigma = torch.exp(1/2 * log_sigma2)  # output is log variance
        epsilon = torch.randn_like(mu)  # mean 0 std 1 shaped like mu
        z = mu + sigma * epsilon
        x_tilde = self.decoder(z)

        return x_tilde, z, mu, log_sigma2


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(10, 1000),
            nn.LeakyReLU(),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(),
            nn.Linear(1000, 2),  # predicts two logits
        )

    def forward(self, x):
        return self.mlp(x)


# Training

In [7]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
debug_steps = 10 if device == 'cpu' else 500
vae_model = FactorVAE().to(device)
vae_optimizer = optim.Adam(vae_model.parameters(), lr=1e-4, betas=(0.9, 0.999))

d_model = Discriminator().to(device)
d_optimizer = optim.Adam(d_model.parameters(), lr=1e-4, betas=(0.5, 0.9))

num_iterations = int(3e5)
batch_size = 64  # m = 64

np.random.shuffle(imgs)
val_imgs = imgs[:num_imgs//5]
train_imgs = imgs[num_imgs//5:]

# Reconstruction + KL divergence losses summed over all elements and batch
def vae_loss_function(x_tilde, x, mu, log_sigma2, dz_logits):
    # since dsprites data consists of binary images, distribution of x is bernoulli
    # use negative cross entropy as reconstruction objective log p(x|z)
    # i.e. use cross entropy as reconstruction loss
    reconstruction_loss = F.binary_cross_entropy_with_logits(x_tilde, x)

    # KL divergence of two gaussian distributions has analytical solution
    # from appendix B of Kingma 2014
    KL_div = - 1/2 * torch.mean(1 + log_sigma2 - mu ** 2 - log_sigma2.exp())

    # Total Correlation using the density ratio trick
    # discriminator logits has shape (64, 2)
    # Dz = exp(logits[:, 0]) / (exp(logits[:, 0])+exp(logits[:, 1]))
    # 1-Dz = exp(logits[:, 1]) / (exp(logits[:, 0])+exp(logits[:, 1]))
    # Dz / (1-Dz) = exp(logits[:, 0]) / exp(logits[:, 1]) = exp(logits[:, 0]-logits[:, 1])
    # TC = mean(log(Dz / (1-Dz))) = mean(logits[:, 0]-logits[:, 1])
    TC = torch.mean(dz_logits[:, 0] - dz_logits[:, 1])  # can be negative?

    return reconstruction_loss, KL_div, TC
    # return reconstruction_loss, KL_div, 0


def d_loss_function(dz_logits, dz_perm_logits):
    # dz_logits has shape (64, 2)
    # Dz = exp(logits[:, 0]) / (exp(logits[:, 0])+exp(logits[:, 1]))
    # log(Dz) = logits[:, 0] - log((exp(logits[:, 0])+exp(logits[:, 1])))
    #         = logits[:, 0] - logsumexp(logits, dim=1)
    # similarly
    # log(1-Dz_perm) = permlogits[:, 1] - logsumexp(permlogits, dim=1)

    return -1/2 * torch.mean(
                      dz_logits[:, 0] - torch.logsumexp(dz_logits, dim=1)
                    + dz_perm_logits[:, 1] - torch.logsumexp(dz_perm_logits, dim=1), dim=0)  # TODO: check this

def permute_dims(z):
    """
    z has shape (N, d), N=batch_size=64, d = 10
    """
    N, d = z.shape
    for j in range(d):
        z[:, j] = z[np.random.permutation(N), j]
    return z


def train():
    vae_model.train()
    d_model.train()

    ## train VAE
    vae_optimizer.zero_grad()
    vae_batch = train_imgs[np.random.randint(num_imgs*4//5, size=batch_size)]  # (N, 64, 64)
    x = torch.tensor(vae_batch[:, np.newaxis, :, :]).float().to(device)  # (N, 1, 64, 64)
    x_tilde, z, mu, log_sigma2 = vae_model(x)

    Dz = d_model(z)
    reconstruction_loss, KL_div, TC = vae_loss_function(x_tilde, x, mu, log_sigma2, Dz)
    vae_loss = reconstruction_loss + KL_div + gamma * TC
    vae_loss.backward(retain_graph=True)  ## set to true
    vae_optimizer.step()

    ## train discriminator
    d_optimizer.zero_grad()
    d_batch = train_imgs[np.random.randint(num_imgs*4//5, size=batch_size)]  # (N, 64, 64)
    x_prime = torch.tensor(d_batch[:, np.newaxis, :, :]).float().to(device)  # (N, 1, 64, 64)

    _, z_prime, _, _ = vae_model(x_prime)  # (N, 10)

    z_perm = permute_dims(z_prime)  # (N, 10)
    Dz_perm = d_model(z_perm)
    d_loss = d_loss_function(Dz, Dz_perm)
    d_loss.backward()
    d_optimizer.step()

    return reconstruction_loss.item(), KL_div.item(), TC.item(), d_loss.item()


def validate():
    vae_model.eval()
    d_model.eval()

    vae_eval_batch = val_imgs #[np.random.randint(num_imgs//5, size=batch_size)]  # (N, 64, 64)
    x = torch.tensor(vae_eval_batch[:, np.newaxis, :, :]).float().to(device)  # (N, 1, 64, 64)
    x_tilde, z, mu, log_sigma2 = vae_model(x)

    Dz = d_model(z)
    reconstruction_loss, KL_div, TC = vae_loss_function(x_tilde, x, mu, log_sigma2, Dz)
    
    return reconstruction_loss.item(), KL_div.item(), TC.item()

def main(iterations, gamma):
    cumul_recon_loss = 0
    cumul_KLD_loss = 0
    cumul_TC_loss = 0
    cumul_vae_loss = 0
    cumul_d_loss = 0

    for it in range(1, iterations+1):
        reconstruction_loss, KL_div, TC, d_loss = train()

        cumul_recon_loss += reconstruction_loss
        cumul_KLD_loss += KL_div
        cumul_TC_loss += TC
        cumul_vae_loss += reconstruction_loss + KL_div + TC
        cumul_d_loss += d_loss

        if it % debug_steps == 0:  # Training Loss
            print(f"iteration {it: 6d}: " +
                  f"vae_loss {cumul_vae_loss/debug_steps:.4f}, " + 
                  f"recon_loss {cumul_recon_loss/debug_steps:.5f}, " +
                  f"KLD {cumul_KLD_loss/debug_steps:.4g}, " +
                  f"TC {cumul_TC_loss/debug_steps:.4g}, " +
                  f"d_loss {cumul_d_loss/debug_steps:.6f}")
            cumul_recon_loss = 0
            cumul_KLD_loss = 0
            cumul_TC_loss = 0
            cumul_vae_loss = 0
            cumul_d_loss = 0

        if it % 5000 == 0:  # Validation Loss
            recon_loss, KL_div, TC = validate()
            print(f"===> Validation: iteration {it: 6d}: " + 
                  f"recon_loss {recon_loss:.5f}, " +
                  f"KLD {KL_div:.4g}, " +
                  f"TC {TC:.4g}")
        if it % 25000 == 0:  # Save model
            torch.save(vae_model, f"drive/My Drive/FactorVAEModels/vae{gamma}_it{it:06d}_recon{recon_loss:.4f}.pt")

gamma = 30
main(num_iterations, gamma)

iteration    500: vae_loss 5.9956, recon_loss 0.22537, KLD 5.717, TC 0.05316, d_loss 0.670750
iteration   1000: vae_loss 10.1512, recon_loss 0.14347, KLD 9.949, TC 0.05843, d_loss 0.666296
iteration   1500: vae_loss 12.0733, recon_loss 0.14747, KLD 11.87, TC 0.05699, d_loss 0.670676
iteration   2000: vae_loss 0.5871, recon_loss 0.14235, KLD 0.4419, TC 0.002877, d_loss 0.693286
iteration   2500: vae_loss 0.2549, recon_loss 0.14034, KLD 0.1146, TC -2.317e-05, d_loss 0.693262
iteration   3000: vae_loss 0.2899, recon_loss 0.13956, KLD 0.1495, TC 0.0008505, d_loss 0.693218
iteration   3500: vae_loss 4.5496, recon_loss 0.14148, KLD 4.388, TC 0.01976, d_loss 0.683977
iteration   4000: vae_loss 12.9646, recon_loss 0.14216, KLD 12.73, TC 0.09318, d_loss 0.651279
iteration   4500: vae_loss 12.3008, recon_loss 0.14152, KLD 12.06, TC 0.1001, d_loss 0.652934
iteration   5000: vae_loss 16.0448, recon_loss 0.14010, KLD 15.8, TC 0.1067, d_loss 0.658459
===> Validation: iteration   5000: recon_loss 0.1

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


iteration   5500: vae_loss 17.0919, recon_loss 0.14344, KLD 16.78, TC 0.1702, d_loss 0.653434
iteration   6000: vae_loss 6.5499, recon_loss 0.15024, KLD 6.254, TC 0.1458, d_loss 0.651732
iteration   6500: vae_loss 0.9278, recon_loss 0.13985, KLD 0.782, TC 0.005983, d_loss 0.693336
iteration   7000: vae_loss 0.7106, recon_loss 0.13845, KLD 0.571, TC 0.001138, d_loss 0.692908
iteration   7500: vae_loss 0.6651, recon_loss 0.13810, KLD 0.5265, TC 0.0004987, d_loss 0.693230
iteration   8000: vae_loss 0.4139, recon_loss 0.13850, KLD 0.2752, TC 0.0001875, d_loss 0.693155
iteration   8500: vae_loss 6.6550, recon_loss 0.13924, KLD 6.424, TC 0.0914, d_loss 0.663227
iteration   9000: vae_loss 11.4223, recon_loss 0.14137, KLD 11.15, TC 0.1281, d_loss 0.660484
iteration   9500: vae_loss 9.6299, recon_loss 0.14001, KLD 9.393, TC 0.09679, d_loss 0.664784
iteration  10000: vae_loss 18.0600, recon_loss 0.14000, KLD 17.77, TC 0.154, d_loss 0.671009
===> Validation: iteration  10000: recon_loss 0.15189, 

In [0]:
# torch.save(vae_model, "vae_model.pt")
# torch.save(d_model, "d_model.pt")

In [0]:
# import matplotlib.pyplot as plt
# # check for reconstruction
# img = imgs[12345]
# plt.imshow(img)

please see other jupyter notebook for experiments
