In [None]:
import os
import urllib
import urllib.request

import numpy as np
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# Definitions

## Variational Autoencoder (VAE)

In [None]:
# --- Model ---------------------------------------------------------------
LAT_DIM = 2         # output size of the encoder's mu / logvar heads
CAPACITY = 64       # channels of the first conv layer (the second uses 2x)

# Shared by every (transposed) convolution in the encoder and decoder.
KERN_SIZE = 4
STRIDE = 2
PAD = 1

# --- Optimisation --------------------------------------------------------
EPOCH_NUM = 50      # number of passes over the training loader
BATCH_SIZE = 128    # batch size handed to the DataLoader
LRN_RATE = 1e-3     # learning rate given to Adam
VAR_BETA = 1        # weight of the KL term in vae_loss

# --- Hardware ------------------------------------------------------------
GPU = True          # request CUDA; the code falls back to CPU if unavailable

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of the latent Gaussian.

    Two strided convolutions are followed by two parallel linear heads that
    emit the mean and the log-variance of a ``LAT_DIM``-dimensional code.
    The linear heads expect a ``CAPACITY * 2 * 7 * 7`` feature vector, i.e.
    a 7x7 map after the second convolution.
    """

    def __init__(self):
        super().__init__()

        conv_kwargs = dict(kernel_size = KERN_SIZE, stride = STRIDE, padding = PAD)
        flat_features = CAPACITY * 2 * 7 * 7

        # 1 -> CAPACITY channels
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = CAPACITY, **conv_kwargs)
        # CAPACITY -> CAPACITY * 2 channels
        self.conv2 = nn.Conv2d(in_channels = CAPACITY, out_channels = CAPACITY * 2, **conv_kwargs)

        # Parallel heads reading the same flattened feature map.
        self.fc_mu = nn.Linear(in_features = flat_features, out_features = LAT_DIM)
        self.fc_logvar = nn.Linear(in_features = flat_features, out_features = LAT_DIM)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch ``x``."""
        hidden = functional.relu(self.conv1(x))
        hidden = functional.relu(self.conv2(hidden))
        # Flatten everything except the batch dimension.
        flat = hidden.view(hidden.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes back to images.

    A linear layer expands the ``LAT_DIM`` code to ``CAPACITY * 2 * 7 * 7``
    features, which are reshaped to a 7x7 map and passed through two
    transposed convolutions; the output is squashed with a sigmoid.
    """

    def __init__(self):
        super().__init__()

        deconv_kwargs = dict(kernel_size = KERN_SIZE, stride = STRIDE, padding = PAD)

        self.fc = nn.Linear(in_features = LAT_DIM, out_features = CAPACITY * 2 * 7 * 7)
        # CAPACITY * 2 -> CAPACITY channels
        self.conv2 = nn.ConvTranspose2d(in_channels = CAPACITY * 2, out_channels = CAPACITY, **deconv_kwargs)
        # CAPACITY -> 1 channel
        self.conv1 = nn.ConvTranspose2d(in_channels = CAPACITY, out_channels = 1, **deconv_kwargs)

    def forward(self, x):
        """Decode the batch of latent codes ``x``; outputs lie in (0, 1) via sigmoid."""
        hidden = self.fc(x)
        # Restore the (channels, 7, 7) layout expected by the transposed convs.
        hidden = hidden.view(hidden.size(0), CAPACITY * 2, 7, 7)
        hidden = functional.relu(self.conv2(hidden))
        return torch.sigmoid(self.conv1(hidden))

In [None]:
class VariationalAutoencoder(nn.Module):
    """VAE that wires an ``Encoder`` and a ``Decoder`` around a sampled latent code."""

    def __init__(self):
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, draw a latent code and decode it.

        Returns:
            ``(x_recon, latent_mu, latent_logvar)``
        """
        latent_mu, latent_logvar = self.encoder(x)
        code = self.latent_sample(latent_mu, latent_logvar)
        return self.decoder(code), latent_mu, latent_logvar

    def latent_sample(self, mu, logvar):
        """Draw a latent code.

        In evaluation mode the mean ``mu`` is returned unchanged; in training
        mode a sample ``mu + eps * std`` with standard-normal ``eps`` is
        returned (the reparameterization trick).
        """
        if not self.training:
            return mu

        # std = exp(logvar / 2)
        std = torch.exp(logvar.mul(0.5))
        noise = torch.empty_like(std).normal_()
        return noise * std + mu

In [None]:
def reconstruction_erorr(recon_x, x):
    """Summed binary cross-entropy between a reconstruction and its target.

    Both tensors are viewed as rows of 784 values and the element-wise binary
    cross-entropy is summed over every element of the batch
    (``reduction = "sum"``), so the result is a scalar tensor.
    """
    flat_recon = recon_x.view(-1, 784)
    flat_target = x.view(-1, 784)
    return functional.binary_cross_entropy(flat_recon, flat_target, reduction = "sum")

def vae_loss(recon_loss, mu, logvar):
    """Total VAE loss: reconstruction term plus ``VAR_BETA`` times the KL term.

    The KL term is the closed-form divergence between the diagonal Gaussian
    described by ``mu`` / ``logvar`` and a standard normal, summed over every
    element of the batch.
    """
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return recon_loss + VAR_BETA * kl_divergence

## RRH for Gaussian Mixtures

In [None]:
def mvn_renyi(C, q=1):
    """ Computes the Rényi heterogeneity for a multivariate Gaussian 
    Arguments: 
        C: `ndarray((n,n))`. Covariance matrix
        q: `0<float`. Order of the heterogeneity (`np.inf` is accepted)
    Returns: 
        `float`
    Raises:
        `ValueError` if `q` is not strictly positive
    """
    # Orders q <= 0 (and NaN) have no branch below; fail loudly instead of
    # hitting an UnboundLocalError on the return statement.
    if not q > 0:
        raise ValueError("q must be strictly positive, got %r" % (q,))

    n = C.shape[0]
    SqrtDetC = np.sqrt(np.linalg.det(C))
    if q == 1:
        # limit q -> 1 of the general expression: q**(1/(q-1)) -> e
        return (2*np.pi*np.e)**(n/2) * SqrtDetC
    if q == np.inf:
        # limit q -> inf: q**(n/(2*(q-1))) -> 1
        return (2*np.pi)**(n/2) * SqrtDetC
    return ((2*np.pi)**(n/2))*(q**(n/(2*(q-1))))*SqrtDetC

In [None]:
def mvn_renyi_alpha(C,  q=1):
    """ Computes the alpha-heterogeneity for a Gaussian mixture where each sample has equal weight

    Arguments: 

        C: `ndarray((nsamples, n, n))`. Covariance matrices 
        q: `0<float`. Order of the heterogeneity metric (finite)

    Returns: 

        `float`. The alpha-heterogeneity

    Raises:

        `ValueError` if `q` is not a strictly positive, finite number
    """
    K, n, _ = C.shape
    # equal weight 1/K for each of the K components
    p = np.repeat(1/K, K)
    if q == 1:
        return np.exp((n + np.sum(p*np.log(np.linalg.det(2*np.pi*C))))/2)
    if 0 < q < np.inf:
        # NOTE(review): this branch averages sqrt(det(C_i)) arithmetically,
        # whereas the q == 1 branch above is a geometric mean; confirm the
        # q != 1 expression against the reference derivation.
        wbar = (p**q)/np.sum(p**q)
        return ((2*np.pi)**(n/2))*np.sum(wbar*np.sqrt(np.linalg.det(C)))/(q**(n/2))**(1/(1-q))
    # q <= 0, q == inf and NaN have no formula here; previously these fell
    # through to an UnboundLocalError.
    raise ValueError("q must be a strictly positive, finite number, got %r" % (q,))


In [None]:
def scale_to_cov(scales):
    """Stack one diagonal matrix per row of ``scales``.

    Each entry ``s`` of ``scales`` becomes ``np.diagflat(s)``; the matrices
    are stacked along a new leading axis, giving shape ``(len(scales), n, n)``.
    """
    diagonal_blocks = []
    for row in scales:
        # add the leading axis so vstack concatenates along it
        diagonal_blocks.append(np.diagflat(row)[np.newaxis])
    return np.vstack(diagonal_blocks)

In [None]:
def pool_covariance(means, covs):
    """Covariance of an equally weighted mixture with the given components.

    Arguments:
        means: array indexed as ``(component, dim)``.
        covs: array indexed as ``(component, dim, dim)``.

    Returns:
        The average component covariance plus the average outer product of
        the component means, minus the outer product of the pooled mean.
    """
    n_components = covs.shape[0]
    weights = np.repeat(1/n_components, n_components)

    pooled_mean = np.einsum('ij,i->j', means, weights)
    mean_cov = np.einsum('ijk,i->jk', covs, weights)
    mean_outer = np.einsum('ij,ik,i->jk', means, means, weights)

    second_moment = mean_cov + mean_outer
    return second_moment - np.einsum('i,j->ij', pooled_mean, pooled_mean)

# MNIST Experiments

## Make MNIST training and evaluation sets

In [None]:
# ToTensor is the only transform applied to the images.
img_transform = transforms.Compose([transforms.ToTensor()])


def load_mnist(train):
    """Build a shuffled ``DataLoader`` over one MNIST split.

    Arguments:
        train: forwarded to ``MNIST(train=...)`` to pick the split.

    Returns:
        A ``DataLoader`` with ``batch_size = BATCH_SIZE`` and ``shuffle = True``
        over a dataset rooted at ``./data/MNIST`` (created with
        ``download = True``).
    """
    mnist_split = MNIST(
        './data/MNIST',
        train = train,
        transform = img_transform,
        download = True,
    )
    loader = DataLoader(dataset = mnist_split, batch_size = BATCH_SIZE, shuffle = True)
    return loader


train_dataloader = load_mnist(train = True)
test_dataloader = load_mnist(train = False)

In [None]:
# Walk the training loader once and keep every batch as a pair of NumPy
# arrays; X then holds all images and y the matching labels.
traindata = [
    [images.numpy(), labels.numpy()]
    for images, labels in train_dataloader
]
X = np.vstack([batch[0] for batch in traindata])
y = np.hstack([batch[1] for batch in traindata])

## Load and Evaluate Pre-Trained VAE

Load

In [None]:
# Build an untrained VAE on the selected device; its weights are replaced by
# the downloaded checkpoint below.
pretrained_vae = VariationalAutoencoder()
device = torch.device("cuda:0" if GPU and torch.cuda.is_available() else "cpu")
pretrained_vae = pretrained_vae.to(device)

filename = 'vae_2d.pth'
weights_path = './pretrained/' + filename

os.makedirs('./pretrained', exist_ok=True)

# Download only when the checkpoint is not cached yet, so Restart & Run All
# does not depend on the network after the first run. The file is written to a
# temporary name first so an interrupted download is never mistaken for a
# complete checkpoint.
if not os.path.isfile(weights_path):
    print('downloading ...')
    partial_path = weights_path + '.part'
    urllib.request.urlretrieve(
        "http://geometry.cs.ucl.ac.uk/creativeai/pretrained/" + filename,
        partial_path,
    )
    os.replace(partial_path, weights_path)

# NOTE(review): the checkpoint is fetched over plain HTTP and torch.load
# unpickles it; only run this against a source you trust.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
pretrained_vae.load_state_dict(torch.load(weights_path, map_location=device))
print('done')

downloading ...
done


Evaluate

In [None]:
# set to evaluation mode (latent_sample then returns the mean instead of sampling)
pretrained_vae.eval()

test_loss_avg, recon_loss_avg, num_batches = 0, 0, 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for image_batch, _ in test_dataloader:

        image_batch = image_batch.to(device)

        # vae reconstruction
        image_batch_recon, latent_mu, latent_logvar = pretrained_vae(image_batch)

        # reconstruction error and total loss (reconstruction + weighted KL)
        recon_loss = reconstruction_erorr(image_batch_recon, image_batch)
        loss = vae_loss(recon_loss, latent_mu, latent_logvar)

        # Accumulate plain Python floats for both metrics; previously the
        # reconstruction term was accumulated as a tensor on `device`.
        recon_loss_avg += recon_loss.item()
        test_loss_avg += loss.item()
        num_batches += 1

# Both losses are sums over a batch, so these are per-batch averages.
recon_loss_avg /= num_batches
test_loss_avg /= num_batches
print('average reconstruction error: %f' % (recon_loss_avg))
print('average error: %f' % (test_loss_avg))

average reconstruction error: 18493.664062
average error: 19294.807194


## Train and Evaluate VAE

In [None]:
# Pick the device, then place a freshly initialised VAE on it.
device = torch.device("cuda:0" if GPU and torch.cuda.is_available() else "cpu")
vae = VariationalAutoencoder().to(device)

trainable_sizes = [p.numel() for p in vae.parameters() if p.requires_grad]
num_params = sum(trainable_sizes)
print('Number of parameters: %d' % num_params)

optimizer = torch.optim.Adam(vae.parameters(), lr = LRN_RATE, weight_decay = 1e-5)

# training mode: latent_sample draws noisy codes instead of returning the mean
vae.train()

# one entry per epoch: the mean per-batch value of vae_loss
train_loss_avg = []

print('Training ...')
for epoch in range(EPOCH_NUM):
    epoch_loss_sum = 0
    num_batches = 0

    for image_batch, _ in train_dataloader:
        image_batch = image_batch.to(device)

        # forward pass through encoder, sampler and decoder
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # reconstruction term, then the full loss including the KL term
        recon_loss = reconstruction_erorr(image_batch_recon, image_batch)
        loss = vae_loss(recon_loss, latent_mu, latent_logvar)

        # clear old gradients, backpropagate, take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    train_loss_avg.append(epoch_loss_sum / num_batches)
    # The printed value is the total loss (reconstruction + KL), although the
    # label says "reconstruction error".
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, EPOCH_NUM, train_loss_avg[-1]))

Number of parameters: 308357
Training ...
Epoch [1 / 50] average reconstruction error: 24048.731531
Epoch [2 / 50] average reconstruction error: 21667.731683
Epoch [3 / 50] average reconstruction error: 21049.359298
Epoch [4 / 50] average reconstruction error: 20652.804427
Epoch [5 / 50] average reconstruction error: 20374.810378
Epoch [6 / 50] average reconstruction error: 20175.633152
Epoch [7 / 50] average reconstruction error: 20017.137179
Epoch [8 / 50] average reconstruction error: 19916.152313
Epoch [9 / 50] average reconstruction error: 19817.292373
Epoch [10 / 50] average reconstruction error: 19746.652104
Epoch [11 / 50] average reconstruction error: 19674.891341
Epoch [12 / 50] average reconstruction error: 19610.826786
Epoch [13 / 50] average reconstruction error: 19559.579778
Epoch [14 / 50] average reconstruction error: 19513.977491
Epoch [15 / 50] average reconstruction error: 19456.698226
Epoch [16 / 50] average reconstruction error: 19425.032179
Epoch [17 / 50] average

In [None]:
plt.ion()

# Training curve: one point per epoch from train_loss_avg.
fig = plt.figure()
loss_axes = fig.gca()
loss_axes.plot(train_loss_avg)
loss_axes.set_xlabel('Epochs')
loss_axes.set_ylabel('Loss')
plt.show()

In [None]:
# set to evaluation mode (latent_sample then returns the mean instead of sampling)
vae.eval()

test_loss_avg, recon_loss_avg, num_batches = 0, 0, 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for image_batch, _ in test_dataloader:

        image_batch = image_batch.to(device)

        # vae reconstruction
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # vae_loss takes the already-computed reconstruction term, so compute
        # it first (passing the images directly raised a TypeError).
        recon_loss = reconstruction_erorr(image_batch_recon, image_batch)
        loss = vae_loss(recon_loss, latent_mu, latent_logvar)

        recon_loss_avg += recon_loss.item()
        test_loss_avg += loss.item()
        num_batches += 1

# Both losses are sums over a batch, so these are per-batch averages.
recon_loss_avg /= num_batches
test_loss_avg /= num_batches
# Same two metrics and labels as the pre-trained model's evaluation cell.
print('average reconstruction error: %f' % (recon_loss_avg))
print('average error: %f' % (test_loss_avg))

## Experiment 1: $\beta$ heterogeneity of each digit class

### 1.1 RRH on MNIST-trained VAE

In [None]:
# Per-digit heterogeneity under the MNIST-trained encoder:
#   gamma - heterogeneity of the pooled covariance,
#   alpha - alpha-heterogeneity of the per-observation covariances,
#   beta  - their ratio gamma / alpha.
gammas = []
alphas = []
betas = []
for i in range(10):
    # Inference only: skip building an autograd graph for the whole digit class.
    with torch.no_grad():
        mu, logvar = vae.encoder(torch.Tensor(X[y == i]).to(device))
        loc = mu.cpu().numpy()
        # exp(logvar) is the per-dimension variance, i.e. the covariance diagonal
        scale = logvar.exp().cpu().numpy()
    cov = scale_to_cov(scale)
    gamma = mvn_renyi(pool_covariance(loc, cov), q=1)
    alpha = mvn_renyi_alpha(cov,q=1)
    beta = gamma/alpha
    gammas.append(gamma)
    alphas.append(alpha)
    betas.append(beta)

In [None]:
# Panel titles and the per-digit series they describe, in matching order:
# pooled <-> gammas, within-observation <-> alphas, between-observation <-> betas.
plotlabels = [r"Pooled", r"Within-Observation", r"Between-Observation"]
hetvalues = [gammas, alphas, betas]

In [None]:
# One bar chart per heterogeneity measure, one bar per digit.
fig, ax = plt.subplots(ncols=3, figsize=(9, 2.5))
ax[0].set_ylabel("Heterogeneity")

digits = np.arange(10)
bar_fill = plt.get_cmap("Greys")(0.4)

for panel, title, values in zip(ax, plotlabels, hetvalues):
    panel.set_title(title)
    panel.set_xlabel("Digit")
    panel.set_xticks(digits)
    panel.set_xticklabels(digits)
    panel.bar(digits, values, facecolor = bar_fill, edgecolor = "black")

plt.tight_layout()
plt.savefig("digit-class-heterogeneity.pdf", bbox_inches="tight")

### 1.2 RRH on Pre-trained VAE

In [None]:
# Per-digit heterogeneity under the pre-trained encoder:
#   gamma - heterogeneity of the pooled covariance,
#   alpha - alpha-heterogeneity of the per-observation covariances,
#   beta  - their ratio gamma / alpha.
gammas = []
alphas = []
betas = []
for i in range(10):
    # Inference only: skip building an autograd graph for the whole digit class.
    with torch.no_grad():
        mu, logvar = pretrained_vae.encoder(torch.Tensor(X[y == i]).to(device))
        loc = mu.cpu().numpy()
        # exp(logvar) is the per-dimension variance, i.e. the covariance diagonal
        scale = logvar.exp().cpu().numpy()
    cov = scale_to_cov(scale)
    gamma = mvn_renyi(pool_covariance(loc, cov), q=1)
    alpha = mvn_renyi_alpha(cov,q=1)
    beta = gamma/alpha
    gammas.append(gamma)
    alphas.append(alpha)
    betas.append(beta)

hetvalues = [gammas, alphas, betas]
plotlabels = [r"Pooled", r"Within-Observation", r"Between-Observation"]

fig, ax = plt.subplots(ncols=3, figsize=(9, 2.5))
ax[0].set_ylabel("Heterogeneity")

for i in range(3): 
    ax[i].set_title(plotlabels[i])
    ax[i].set_xlabel("Digit")
    ax[i].set_xticks(np.arange(10))
    ax[i].set_xticklabels(np.arange(10))
    ax[i].bar(
        np.arange(10),
        hetvalues[i],
        facecolor = plt.get_cmap("Greys")(0.4), 
        edgecolor = "black",
    )
plt.tight_layout()
# Saved under its own name: section 1.1 writes "digit-class-heterogeneity.pdf",
# and reusing that path here would overwrite the trained-VAE figure.
plt.savefig("digit-class-heterogeneity-pretrained.pdf", bbox_inches="tight")