### General Settings

Change the respective settings to run appropriately

Use `limit_train_batches`, `limit_val_batches`, `limit_test_batches` as required

In [1]:
project_dir = '/Users/rajjain/PycharmProjects/ADRL-Course-Work/'
data_dir = project_dir + 'data/'
celeba_data_dir = '/Users/rajjain/Desktop/CourseWork/CelebA/'
use_gpu = False
num_cpus = 2

### Imports

In [2]:
from torch.nn import init, Linear, Sequential, Conv2d, PReLU, Flatten, Unflatten, Module, BatchNorm2d, \
    ConvTranspose2d, Hardsigmoid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning import LightningModule
from torchvision.datasets import CelebA
from pytorch_lightning import Trainer
from torchvision import transforms
from datetime import datetime
from torchinfo import summary
from torch.optim import Adam
from torch.nn import MSELoss
import torch
import numpy
import gc
import os

from plotly.subplots import make_subplots
import matplotlib.pyplot as plt

# VAE for dSprites dataset

## Model

In [3]:
class dCNNEncoderLayer(Module):
    """One layer of the encoder"""
    def __init__(self, in_channels, out_channels, conv_kernel, conv_stride, conv_padding: object = 0, normalise=True):
        super(dCNNEncoderLayer, self).__init__()
        self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(conv_kernel, conv_kernel),
                           stride=(conv_stride, conv_stride), padding=conv_padding)
        self.activation = PReLU(num_parameters=out_channels, init=0.25)
        self.norm = BatchNorm2d(num_features=out_channels)
        self.normalise = normalise

    def initialise(self):
        init.kaiming_normal_(self.conv.weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        if self.normalise:
            x = self.norm(x)
        return x


class dCNNDecoderLayer(Module):
    """One layer of the decoder"""
    def __init__(self, in_channels, out_channels, conv_kernel, conv_stride, op_padding=0, normalise=True):
        super(dCNNDecoderLayer, self).__init__()
        self.conv_trans = ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=(conv_kernel, conv_kernel), stride=(conv_stride, conv_stride),
                                          output_padding=op_padding)
        self.activation = PReLU(num_parameters=out_channels, init=0.25)
        self.normalise = normalise
        if self.normalise:
            self.norm = BatchNorm2d(num_features=out_channels)

    def initialise(self):
        init.kaiming_normal_(self.conv_trans.weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.conv_trans(x)
        x = self.activation(x)
        if self.normalise:
            x = self.norm(x)
        return x


class dCNNEncoder(Module):
    """Encoder"""
    def __init__(self, latent_dim):
        super(dCNNEncoder, self).__init__()
        self.layer1 = dCNNEncoderLayer(in_channels=1, out_channels=4, conv_kernel=3, conv_stride=1)
        self.layer2 = dCNNEncoderLayer(in_channels=4, out_channels=8, conv_kernel=3, conv_stride=1)
        self.layer3 = dCNNEncoderLayer(in_channels=8, out_channels=10, conv_kernel=4, conv_stride=2)
        self.layer4 = dCNNEncoderLayer(in_channels=10, out_channels=12, conv_kernel=4, conv_stride=2)
        self.layer5 = dCNNEncoderLayer(in_channels=12, out_channels=16, conv_kernel=4, conv_stride=2)
        self.flatten = Flatten()
        self.mean = Sequential(
            Linear(in_features=400, out_features=latent_dim),
            PReLU(num_parameters=latent_dim, init=0.25),  # negative and positive values
        )
        self.logvar = Sequential(  # As suggested in original VAE paper - instead of dealing with zeros to log
            Linear(in_features=400, out_features=latent_dim),
            PReLU(num_parameters=latent_dim, init=0.25),  # negative and positive values
        )

    def initialise(self):
        self.layer1.initialise()
        self.layer2.initialise()
        self.layer3.initialise()
        self.layer4.initialise()
        self.layer5.initialise()
        init.kaiming_normal_(self.mean._modules['0'].weight, a=0.25, nonlinearity='leaky_relu')
        init.kaiming_normal_(self.logvar._modules['0'].weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.flatten(x)
        mean, logvar = self.mean(x), self.logvar(x)
        logvar = torch.clamp(logvar, min=-16, max=16)
        return mean, logvar


class dCNNDecoder(Module):
    """Decoder"""
    def __init__(self, latent_dim):
        super(dCNNDecoder, self).__init__()
        self.expand = Sequential(
            Linear(in_features=latent_dim, out_features=400),
            PReLU(num_parameters=400, init=0.25),
        )
        self.unflatten = Unflatten(1, (16, 5, 5))
        self.layer1 = dCNNDecoderLayer(in_channels=16, out_channels=12, conv_kernel=4, conv_stride=2, op_padding=1)
        self.layer2 = dCNNDecoderLayer(in_channels=12, out_channels=10, conv_kernel=4, conv_stride=2, op_padding=1)
        self.layer3 = dCNNDecoderLayer(in_channels=10, out_channels=8, conv_kernel=4, conv_stride=2)
        self.layer4 = dCNNDecoderLayer(in_channels=8, out_channels=4, conv_kernel=3, conv_stride=1)
        self.output = ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=(3, 3), stride=(1, 1))
        self.output_activation = Hardsigmoid()

    def initialise(self):
        init.kaiming_normal_(self.expand._modules['0'].weight, a=0.25, nonlinearity='leaky_relu')
        self.layer1.initialise()
        self.layer2.initialise()
        self.layer3.initialise()
        self.layer4.initialise()
        init.xavier_normal_(self.output.weight)

    def forward(self, x):
        x = self.expand(x)
        x = self.unflatten(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.output(x)
        x = self.output_activation(x)
        return x


class dSprites_VAE(LightningModule):
    """Overall VAE Model"""
    def __init__(self, latent_dim, num_z, reduce_kl):
        super(dSprites_VAE, self).__init__()
        self.save_hyperparameters()
        self.latent_dim = latent_dim
        self.num_z = num_z  # todo: use num_z
        self.reduce_kl = reduce_kl

        self.encoder = dCNNEncoder(latent_dim)
        self.decoder = dCNNDecoder(latent_dim)

        # Initialisations
        seed_everything(0)
        self.encoder.initialise()
        self.decoder.initialise()

        self.float()

    def forward(self, x):
        # Get mean and std (in the form of log of variance) from the encoder
        mean, logvar = self.encoder(x)

        # Get latent sample - We'll get one sample for now - as done in the original paper
        epsilon = torch.normal(0, 1, size=mean.shape, device=self.device)
        z = mean + 0.5 * torch.exp(logvar) * epsilon  

        # Get output from decoder
        x_hat = self.decoder(z)
        return mean, logvar, x_hat

    def _common_step(self, batch, btype):
        x, = batch
        x = x.float()
        mean, logvar, x_hat = self(x)

        # mean across samples of the L2 norm of the different vector
        indiv_recon = MSELoss(reduction='none')(input=x_hat, target=x)
        reconstruction_loss = 0.5 * indiv_recon.sum(dim=[1, 2, 3]).mean()
        if self.reduce_kl:
            # mean of KL loss across samples
            kl_loss = 0.5 * (mean ** 2 + torch.exp(logvar) - logvar - 1).sum(dim=1).mean()  
        else:
            kl_loss = 0

        overall_loss = reconstruction_loss + kl_loss

        if torch.isnan(overall_loss).item() or torch.isinf(overall_loss).item():
            numpy.savez_compressed(project_dir + 'issue_values', indiv_recon=indiv_recon.detach().cpu().numpy(),
                                   mean=mean.detach().cpu().numpy(), logvar=logvar.detach().cpu().numpy(),
                                   x=x.cpu().numpy(), x_hat=x_hat.detach().cpu().numpy())
            raise Exception('Nan/Inf encountered')

        # Log the losses
        not_training = btype != 'train'
        self.log(f'{btype}/loss', overall_loss,              on_step=False, on_epoch=True, prog_bar=True,  
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/recon_loss', reconstruction_loss, on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/kl_loss', kl_loss,                on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)

        return overall_loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer


## Model Summary

In [20]:
summary_string = str(summary(model=dSprites_VAE(latent_dim=64, num_z=1, reduce_kl=True),
                             input_size=[(10, 1, 64, 64)],
                             dtypes=[torch.float],
                             depth=3,
                             col_names=['input_size', 'output_size', 'num_params'],
                             row_settings=['depth', 'var_names'],
                             verbose=0,
                             device=torch.device('cpu')))
print(summary_string)

Global seed set to 0


Layer (type (var_name):depth-idx)                  Input Shape               Output Shape              Param #
dSprites_VAE (dSprites_VAE)                        [10, 1, 64, 64]           [10, 64]                  --
├─dCNNEncoder (encoder): 1-1                       [10, 1, 64, 64]           [10, 64]                  --
│    └─dCNNEncoderLayer (layer1): 2-1              [10, 1, 64, 64]           [10, 4, 62, 62]           --
│    │    └─Conv2d (conv): 3-1                     [10, 1, 64, 64]           [10, 4, 62, 62]           40
│    │    └─PReLU (activation): 3-2                [10, 4, 62, 62]           [10, 4, 62, 62]           4
│    │    └─BatchNorm2d (norm): 3-3                [10, 4, 62, 62]           [10, 4, 62, 62]           8
│    └─dCNNEncoderLayer (layer2): 2-2              [10, 4, 62, 62]           [10, 8, 60, 60]           --
│    │    └─Conv2d (conv): 3-4                     [10, 4, 62, 62]           [10, 8, 60, 60]           296
│    │    └─PReLU (activation): 3-5       

## Data Related Utilities

In [5]:
def create_dsprites_split():
    """Split dSprites dataset into train, val and test"""
    num_val = num_test = 1000
    seed_everything(0)

    all_data = numpy.load(data_dir + 'dsprites.npz')['imgs']
    all_data = numpy.expand_dims(all_data, axis=1)
    numpy.random.shuffle(all_data)

    num_samples = all_data.shape[0]
    train, val, test = all_data[:num_samples - num_test - num_val], \
                       all_data[num_samples - num_test - num_val: num_samples - num_test], \
                       all_data[num_samples - num_test:]

    numpy.savez_compressed(data_dir + 'dsprites_split.npz', train=train, val=val, test=test)


def get_dsprites():  # use the above function to create the split
    data = numpy.load(data_dir + 'dsprites_split.npz')
    train, val, test = data['train'], data['val'], data['test']
    return train, val, test


def get_dataloader(data: numpy.ndarray, bs: int, shuffle: bool):
    data_tensor = torch.tensor(data)
    dataset = TensorDataset(data_tensor)
    return DataLoader(dataset, batch_size=bs, shuffle=shuffle, num_workers=num_cpus)


def get_dsprites_dataloaders(train, val, test, bs):
    """Get the dSprites dataloaders"""
    train_dataloader = get_dataloader(train, bs, shuffle=True)
    val_dataloader = get_dataloader(val, bs=val.shape[0], shuffle=False)
    test_dataloader = get_dataloader(test, bs=test.shape[0], shuffle=False)
    return train_dataloader, val_dataloader, test_dataloader


# VAE for CelebA dataset

## Model

In [6]:
class cCNNEncoderLayer(Module):
    """One layer of the encoder"""
    def __init__(self, in_channels, out_channels, conv_kernel, conv_stride, conv_padding: object = 0, normalise=True):
        super(cCNNEncoderLayer, self).__init__()
        self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(conv_kernel, conv_kernel),
                           stride=(conv_stride, conv_stride), padding=conv_padding)
        self.activation = PReLU(num_parameters=out_channels, init=0.25)
        self.norm = BatchNorm2d(num_features=out_channels)
        self.normalise = normalise

    def initialise(self):
        init.kaiming_normal_(self.conv.weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        if self.normalise:
            x = self.norm(x)
        return x


class cCNNDecoderLayer(Module):
    """One layer of the decoder"""
    def __init__(self, in_channels, out_channels, conv_kernel, conv_stride, conv_padding: object = 0, normalise=True):
        super(cCNNDecoderLayer, self).__init__()
        self.conv_trans = ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=(conv_kernel, conv_kernel), stride=(conv_stride, conv_stride),
                                          padding=conv_padding)
        self.activation = PReLU(num_parameters=out_channels, init=0.25)
        self.normalise = normalise
        if self.normalise:
            self.norm = BatchNorm2d(num_features=out_channels)

    def initialise(self):
        init.kaiming_normal_(self.conv_trans.weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.conv_trans(x)
        x = self.activation(x)
        if self.normalise:
            x = self.norm(x)
        return x


class cCNNEncoder(Module):
    """Encoder"""
    def __init__(self, latent_dim):
        super(cCNNEncoder, self).__init__()
        self.layer1 = cCNNEncoderLayer(in_channels=3, out_channels=6, conv_kernel=3, conv_stride=1)
        self.layer2 = cCNNEncoderLayer(in_channels=6, out_channels=8, conv_kernel=3, conv_stride=1)
        self.layer3 = cCNNEncoderLayer(in_channels=8, out_channels=10, conv_kernel=4, conv_stride=2)
        self.layer4 = cCNNEncoderLayer(in_channels=10, out_channels=12, conv_kernel=4, conv_stride=2)
        self.layer5 = cCNNEncoderLayer(in_channels=12, out_channels=12, conv_kernel=6, conv_stride=2)
        self.flatten = Flatten()
        self.mean = Sequential(
            Linear(in_features=5472, out_features=latent_dim),
            PReLU(num_parameters=latent_dim, init=0.25),  # negative and positive values
        )
        self.logvar = Sequential(  # As suggested in original VAE paper - instead of dealing with zeros to log
            Linear(in_features=5472, out_features=latent_dim),
            PReLU(num_parameters=latent_dim, init=0.25),  # negative and positive values
        )

    def initialise(self):
        self.layer1.initialise()
        self.layer2.initialise()
        self.layer3.initialise()
        self.layer4.initialise()
        self.layer5.initialise()
        init.kaiming_normal_(self.mean._modules['0'].weight, a=0.25, nonlinearity='leaky_relu')
        init.kaiming_normal_(self.logvar._modules['0'].weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.flatten(x)
        mean, logvar = self.mean(x), self.logvar(x)
        logvar = torch.clamp(logvar, min=-16, max=16)
        return mean, logvar


class cCNNDecoder(Module):
    """Decoder"""
    def __init__(self, latent_dim):
        super(cCNNDecoder, self).__init__()
        self.expand = Sequential(
            Linear(in_features=latent_dim, out_features=5472),
            PReLU(num_parameters=5472, init=0.25),
        )
        self.unflatten = Unflatten(1, (12, 24, 19))
        self.layer1 = cCNNDecoderLayer(in_channels=12, out_channels=12, conv_kernel=6, conv_stride=2)
        self.layer2 = cCNNDecoderLayer(in_channels=12, out_channels=10, conv_kernel=4, conv_stride=2)
        self.layer3 = cCNNDecoderLayer(in_channels=10, out_channels=8, conv_kernel=4, conv_stride=2)
        self.layer4 = cCNNDecoderLayer(in_channels=8, out_channels=6, conv_kernel=3, conv_stride=1)
        self.output = ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(3, 3), stride=(1, 1))
        self.output_activation = Hardsigmoid()

    def initialise(self):
        init.kaiming_normal_(self.expand._modules['0'].weight, a=0.25, nonlinearity='leaky_relu')
        self.layer1.initialise()
        self.layer2.initialise()
        self.layer3.initialise()
        self.layer4.initialise()
        init.xavier_normal_(self.output.weight)

    def forward(self, x):
        x = self.expand(x)
        x = self.unflatten(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.output(x)
        x = self.output_activation(x)
        return x


class celeba_VAE(LightningModule):
    """Overall VAE Model"""
    def __init__(self, latent_dim, num_z, reduce_kl):
        super(celeba_VAE, self).__init__()
        self.save_hyperparameters()
        self.latent_dim = latent_dim
        self.num_z = num_z  # todo: use num_z
        self.reduce_kl = reduce_kl

        self.encoder = cCNNEncoder(latent_dim)
        self.decoder = cCNNDecoder(latent_dim)

        # Initialisations
        seed_everything(0)
        self.encoder.initialise()
        self.decoder.initialise()

        self.float()

    def forward(self, x):
        # Get mean and std (in the form of log of variance) from the encoder
        mean, logvar = self.encoder(x)

        # Get latent sample - We'll get one sample for now - as done in the original paper
        epsilon = torch.normal(0, 1, size=mean.shape, device=self.device)
        z = mean + 0.5 * torch.exp(logvar) * epsilon

        # Get output from decoder
        x_hat = self.decoder(z)
        return mean, logvar, x_hat

    def _common_step(self, batch, btype):
        x, = batch
        x = x.float()
        mean, logvar, x_hat = self(x)

        # Mean across samples of the L2 norm of the difference vector
        indiv_recon = MSELoss(reduction='none')(input=x_hat, target=x)
        reconstruction_loss = 0.5 * indiv_recon.sum(dim=[1, 2, 3]).mean()
        if self.reduce_kl:
            # mean of KL loss across samples
            kl_loss = 0.5 * (mean ** 2 + torch.exp(logvar) - logvar - 1).sum(dim=1).mean()  
        else:
            kl_loss = 0

        overall_loss = reconstruction_loss + kl_loss

        if torch.isnan(overall_loss).item() or torch.isinf(overall_loss).item():
            numpy.savez_compressed(project_dir + 'issue_values', indiv_recon=indiv_recon.detach().cpu().numpy(),
                                   mean=mean.detach().cpu().numpy(), logvar=logvar.detach().cpu().numpy(),
                                   x=x.cpu().numpy(), x_hat=x_hat.detach().cpu().numpy())
            raise Exception('Nan/Inf encountered')

        # Log the losses
        not_training = btype != 'train'
        self.log(f'{btype}/loss', overall_loss,              on_step=False, on_epoch=True, prog_bar=True,  
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/recon_loss', reconstruction_loss, on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/kl_loss', kl_loss,                on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)

        return overall_loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer


class celeba_VAE_MAE(LightningModule):
    """MAE as loss instead of MSE"""
    def __init__(self, latent_dim, num_z, reduce_kl):
        super(celeba_VAE_MAE, self).__init__()
        self.save_hyperparameters()
        self.latent_dim = latent_dim
        self.num_z = num_z  # todo: use num_z
        self.reduce_kl = reduce_kl

        self.encoder = cCNNEncoder(latent_dim)
        self.decoder = cCNNDecoder(latent_dim)

        # Initialisations
        seed_everything(0)
        self.encoder.initialise()
        self.decoder.initialise()

        self.float()

    def forward(self, x):
        # Get mean and std (in the form of log of variance) from the encoder
        mean, logvar = self.encoder(x)

        # Get latent sample - We'll get one sample for now - as done in the original paper
        epsilon = torch.normal(0, 1, size=mean.shape, device=self.device)
        z = mean + 0.5 * torch.exp(logvar) * epsilon

        # Get output from decoder
        x_hat = self.decoder(z)
        return mean, logvar, x_hat

    def _common_step(self, batch, btype):
        x, = batch
        x = x.float()
        mean, logvar, x_hat = self(x)

        # Mean across samples of the L1 norm of the difference vector
        indiv_recon = L1Loss(reduction='none')(input=x_hat, target=x)
        reconstruction_loss = 0.5 * indiv_recon.sum(dim=[1, 2, 3]).mean()
        if self.reduce_kl:
            # mean of KL loss across samples
            kl_loss = 0.5 * (mean ** 2 + torch.exp(logvar) - logvar - 1).sum(dim=1).mean()
        else:
            kl_loss = 0

        overall_loss = reconstruction_loss + kl_loss

        if torch.isnan(overall_loss).item() or torch.isinf(overall_loss).item():
            numpy.savez_compressed(project_dir + 'issue_values', indiv_recon=indiv_recon.detach().cpu().numpy(),
                                   mean=mean.detach().cpu().numpy(), logvar=logvar.detach().cpu().numpy(),
                                   x=x.cpu().numpy(), x_hat=x_hat.detach().cpu().numpy())
            raise Exception('Nan/Inf encountered')

        # Log the losses
        not_training = btype != 'train'
        self.log(f'{btype}/loss', overall_loss,              on_step=False, on_epoch=True, prog_bar=True,  
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/recon_loss', reconstruction_loss, on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/kl_loss', kl_loss,                on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)

        return overall_loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer


class celeba_VAE_meanmse(LightningModule):
    """MSE as loss but averaged across dimensions and across samples"""
    def __init__(self, latent_dim, num_z, reduce_kl):
        super(celeba_VAE_meanmse, self).__init__()
        self.save_hyperparameters()
        self.latent_dim = latent_dim
        self.num_z = num_z  # todo: use num_z
        self.reduce_kl = reduce_kl

        self.encoder = cCNNEncoder(latent_dim)
        self.decoder = cCNNDecoder(latent_dim)

        # Initialisations
        seed_everything(0)
        self.encoder.initialise()
        self.decoder.initialise()

        self.float()

    def forward(self, x):
        # Get mean and std (in the form of log of variance) from the encoder
        mean, logvar = self.encoder(x)

        # Get latent sample - We'll get one sample for now - as done in the original paper
        epsilon = torch.normal(0, 1, size=mean.shape, device=self.device)
        z = mean + 0.5 * torch.exp(logvar) * epsilon

        # Get output from decoder
        x_hat = self.decoder(z)
        return mean, logvar, x_hat

    def _common_step(self, batch, btype):
        x, = batch
        x = x.float()
        mean, logvar, x_hat = self(x)

        # Mean across samples and dimensions of the L2 norm of the difference vector
        indiv_recon = MSELoss(reduction='none')(input=x_hat, target=x)
        reconstruction_loss = 0.5 * indiv_recon.mean()
        if self.reduce_kl:
            # mean of KL loss across samples
            kl_loss = 0.5 * (mean ** 2 + torch.exp(logvar) - logvar - 1).sum(dim=1).mean()  
        else:
            kl_loss = 0

        overall_loss = reconstruction_loss + kl_loss

        if torch.isnan(overall_loss).item() or torch.isinf(overall_loss).item():
            numpy.savez_compressed(project_dir + 'issue_values', indiv_recon=indiv_recon.detach().cpu().numpy(),
                                   mean=mean.detach().cpu().numpy(), logvar=logvar.detach().cpu().numpy(),
                                   x=x.cpu().numpy(), x_hat=x_hat.detach().cpu().numpy())
            raise Exception('Nan/Inf encountered')

        # Log the losses
        not_training = btype != 'train'
        self.log(f'{btype}/loss', overall_loss,              on_step=False, on_epoch=True, prog_bar=True,  
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/recon_loss', reconstruction_loss, on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/kl_loss', kl_loss,                on_step=False, on_epoch=True, prog_bar=False, 
                 logger=True, reduce_fx=torch.mean, sync_dist=not_training)

        return overall_loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer


## Model Summary

In [7]:
summary_string = str(summary(model=celeba_VAE(latent_dim=64, num_z=1, reduce_kl=True),
                             input_size=[(10, 3, 218, 178)],
                             dtypes=[torch.float],
                             depth=3,
                             col_names=['input_size', 'output_size', 'num_params'],
                             row_settings=['depth', 'var_names'],
                             verbose=0,
                             device=torch.device('cpu')))
print(summary_string)

Global seed set to 0


Layer (type (var_name):depth-idx)                  Input Shape               Output Shape              Param #
celeba_VAE (celeba_VAE)                            [10, 3, 218, 178]         [10, 64]                  --
├─cCNNEncoder (encoder): 1-1                       [10, 3, 218, 178]         [10, 64]                  --
│    └─cCNNEncoderLayer (layer1): 2-1              [10, 3, 218, 178]         [10, 6, 216, 176]         --
│    │    └─Conv2d (conv): 3-1                     [10, 3, 218, 178]         [10, 6, 216, 176]         168
│    │    └─PReLU (activation): 3-2                [10, 6, 216, 176]         [10, 6, 216, 176]         6
│    │    └─BatchNorm2d (norm): 3-3                [10, 6, 216, 176]         [10, 6, 216, 176]         12
│    └─cCNNEncoderLayer (layer2): 2-2              [10, 6, 216, 176]         [10, 8, 214, 174]         --
│    │    └─Conv2d (conv): 3-4                     [10, 6, 216, 176]         [10, 8, 214, 174]         440
│    │    └─PReLU (activation): 3-5     

## Data Related Utilities

In [18]:
def get_celeba_dataloaders(train, val, test, bs):
    """Setting num_workers = 0 as some issue with jupyter and pytorch. in normal implementation, num_cpus is used. 
    Check GitHub code"""
    
    def custom_collate_fn(batch):
        imgs = torch.stack([elem[0] for elem in batch])
        return [imgs]
    
    train_dataloader = DataLoader(
        CelebA(celeba_data_dir, split=train, target_type=[], transform=transforms.Compose([transforms.ToTensor()])),
        bs, shuffle=True, num_workers=0, collate_fn=custom_collate_fn
    )
    val_dataloader = DataLoader(
        CelebA(celeba_data_dir, split=val, target_type=[], transform=transforms.Compose([transforms.ToTensor()])),
        bs, shuffle=False, num_workers=0, collate_fn=custom_collate_fn
    )
    test_dataloader = DataLoader(
        CelebA(celeba_data_dir, split=test, target_type=[], transform=transforms.Compose([transforms.ToTensor()])),
        bs, shuffle=False, num_workers=0, collate_fn=custom_collate_fn
    )
    return train_dataloader, val_dataloader, test_dataloader


# Common Training Utilities

In [13]:
def train_and_test(bs: int, max_epochs: int,
                   tags: list[str], gpu_num: list[int], interactive: bool,
                   model_class, model_kwargs: dict,
                   loss_desc: str,
                   train, val, test, get_dataloaders, input_shape):
    seed_everything(0)

    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(train, val, test, bs)

    folder_name = datetime.utcnow().isoformat(sep="T", timespec="microseconds")
    results_dir = project_dir + f'vae/results/run_{folder_name}/'
    os.makedirs(results_dir, exist_ok=False)

    early_stop_callback = EarlyStopping(monitor='val/loss', mode='min', patience=100)
    checkpoint_callback = ModelCheckpoint(monitor='val/loss', mode='min', dirpath=results_dir, filename='best')

    if use_gpu:
        trainer_kwargs = dict(accelerator="gpu", devices=gpu_num,
                              strategy=None if interactive else DDPStrategy(find_unused_parameters=False))
    else:
        trainer_kwargs = dict()

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs',
                                  default_hp_metric=False)

    model = model_class(**model_kwargs)

    trainer = Trainer(auto_lr_find=True, default_root_dir=results_dir, max_epochs=max_epochs,
                      callbacks=[early_stop_callback, checkpoint_callback], logger=[tf_logger],
                      log_every_n_steps=1, num_sanity_val_steps=0,  
                      limit_train_batches=4, limit_val_batches=1, limit_test_batches=1,
                      **trainer_kwargs)
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    model = model_class.load_from_checkpoint(checkpoint_callback.best_model_path)
    if use_gpu:
        trainer_kwargs = dict(accelerator="gpu", devices=gpu_num[:1],
                              strategy=None if interactive else DDPStrategy(find_unused_parameters=False))
    else:
        trainer_kwargs = dict()
    trainer = Trainer(auto_lr_find=True, default_root_dir=results_dir, max_epochs=max_epochs,
                      callbacks=[early_stop_callback, checkpoint_callback], logger=[tf_logger],
                      log_every_n_steps=1, num_sanity_val_steps=0,  
                      limit_train_batches=4, limit_val_batches=1, limit_test_batches=1,
                      **trainer_kwargs)
    trainer.test(model, dataloaders=test_dataloader)

    summary_string = str(summary(model=model, input_size=[(10, *input_shape)],
                                 dtypes=[torch.float], depth=3, verbose=0,
                                 col_names=['input_size', 'output_size', 'num_params'],
                                 row_settings=['depth', 'var_names'], device=torch.device('cpu'))) + '\n' + loss_desc
    with open(results_dir + 'model_desc.md', 'w') as f:
        f.write(summary_string)

    gc.collect()


# Train & Test

## dSprites

In [14]:
train, val, test = get_dsprites()
train_and_test(bs=256, max_epochs=2, tags=[], gpu_num=[], interactive=True,
               model_class=dSprites_VAE, model_kwargs=dict(latent_dim=32, num_z=1, reduce_kl=True),
               loss_desc='Recon & KL Loss, Mean across samples',
               train=train, val=val, test=test, get_dataloaders=get_dsprites_dataloaders,
               input_shape=(1, 64, 64))

Global seed set to 0
Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.

  | Name    | Type        | Params
----------------------------------------
0 | encoder | dCNNEncoder | 32.5 K
1 | decoder | dCNNDecoder | 20.3 K
----------------------------------------
52.9 K    Trainable params
0         Non-trainable params
52.9 K    Total params
0.211     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test/kl_loss          0.9098517298698425
        test/loss            442.3057556152344
     test/recon_loss         441.3959045410156
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


## CelebA

In [19]:
train_and_test(bs=256, max_epochs=2, tags=[], gpu_num=[], interactive=True,
               model_class=celeba_VAE, model_kwargs=dict(latent_dim=256, num_z=1, reduce_kl=True),
               loss_desc='Recon & KL Loss, Mean across samples',
               train='train', val='valid', test='test', get_dataloaders=get_celeba_dataloaders, 
               input_shape=(3, 218, 178))

Global seed set to 0
Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.

  | Name    | Type        | Params
----------------------------------------
0 | encoder | cCNNEncoder | 2.8 M 
1 | decoder | cCNNDecoder | 1.4 M 
----------------------------------------
4.2 M     Trainable params
0         Non-trainable params
4.2 M     Total params
16.931    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test/kl_loss          14.146760940551758
        test/loss            5421.04736328125
     test/recon_loss          5406.900390625
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


# Plots & Analysis

See GitHub code for examples of usage

## dSprites

In [24]:
def disentanglement_results(result_dir):
    model = dSprites_VAE.load_from_checkpoint(project_dir + 'vae/dsprites/results/' + result_dir + '/best.ckpt')

    zs = []
    for dimi in numpy.linspace(-2, 2, 10):
        for dimj in numpy.linspace(-2, 2, 10):
            zs.append([dimi, dimj] + [0] * (model.latent_dim - 2))

    z = torch.tensor(numpy.asarray(zs), dtype=torch.float)

    decoded = model.decoder(z)
    x_hat = decoded.detach().numpy().squeeze(axis=1)

    fig = make_subplots(rows=10, cols=10, shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=0.001, vertical_spacing=0.001)

    for i in range(x_hat.shape[0]):
        col = i % 10 + 1
        row = int(i / 10) + 1
        fig.add_heatmap(z=1-x_hat[i, :, :], row=row, col=col, colorscale='Greys', showscale=False)

    if model.reduce_kl:
        title = f'dSprites - Disentanglement Effect - Latent Dim = {model.latent_dim}'
        img_name = f'dim={model.latent_dim}.png'
    else:
        title = f'dSprites - Disentanglement Effect - Latent Dim = {model.latent_dim}, KL Loss not minimised'
        img_name = f'dim={model.latent_dim}_kl_notminimised.png'

    fig.update_layout(width=1200, height=1200, title=title, title_x=0.5, title_font_size=20, title_xref='paper',
                      title_y=0.94)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.write_image(project_dir + f'vae/dsprites/img_results/vae_dsprites_disent_{img_name}')


In [None]:
def see_some_generations(result_dir):
    model = dSprites_VAE.load_from_checkpoint(project_dir + 'vae/dsprites/results/' + result_dir + '/best.ckpt')
    num_imgs = 100
    z = torch.normal(0, 1, size=(num_imgs, model.latent_dim))
    decoded = model.decoder(z)
    x_hat = decoded.detach().numpy().squeeze(axis=1)

    fig = make_subplots(rows=10, cols=10, shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=0.001, vertical_spacing=0.001)

    for i in range(x_hat.shape[0]):
        col = i % 10 + 1
        row = int(i / 10) + 1
        fig.add_heatmap(z=1-x_hat[i, :, :], row=row, col=col, colorscale='Greys', showscale=False)

    if model.reduce_kl:
        title = f'dSprites - Generated Images - Latent Dim = {model.latent_dim}'
        img_name = f'dim={model.latent_dim}.png'
    else:
        title = f'dSprites - Generated Images - Latent Dim = {model.latent_dim}, KL Loss not minimised'
        img_name = f'dim={model.latent_dim}_kl_notminimised.png'

    fig.update_layout(width=1200, height=1200, title=title, title_x=0.5, title_font_size=20, title_xref='paper',
                      title_y=0.94)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.write_image(project_dir + f'vae/dsprites/img_results/vae_dsprites_gen_{img_name}')


In [None]:
def see_some_reconstructions(result_dir):
    model = dSprites_VAE.load_from_checkpoint(project_dir + 'vae/dsprites/results/' + result_dir + '/best.ckpt')

    targets = numpy.load(data_dir + 'dsprites_split.npz')['test']
    targets_tensor = torch.tensor(targets, dtype=torch.float)
    _, _, predictions = model(targets_tensor)
    predictions = predictions.detach().numpy()
    
    assert targets.shape == predictions.shape
    indexes = numpy.random.choice(targets.shape[0], (50, ))
    actual = numpy.squeeze(targets[indexes], axis=1)
    pred = numpy.squeeze(predictions[indexes], axis=1)

    fig = make_subplots(rows=10, cols=10, shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=0.001, vertical_spacing=0.001)

    for i in range(indexes.shape[0]):
        tcol = (i % 5) + 1
        tcol = 2 * tcol - 1
        pcol = tcol + 1
        row = int(i / 5) + 1
        fig.add_heatmap(z=1-actual[i, :, :], row=row, col=tcol, colorscale='Greys', showscale=False)
        fig.add_heatmap(z=1-pred[i, :, :], row=row, col=pcol, colorscale='Greys', showscale=False)

    if model.reduce_kl:
        title = f'dSprites - Reconstructed Images - Latent Dim = {model.latent_dim}'
        img_name = f'dim={model.latent_dim}.png'
    else:
        title = f'dSprites - Reconstructed Images - Latent Dim = {model.latent_dim}, KL Loss not minimised'
        img_name = f'dim={model.latent_dim}_kl_notminimised.png'

    fig.update_layout(width=1200, height=1200, title=title, title_x=0.5, title_font_size=20, title_xref='paper',
                      title_y=0.94)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.write_image(project_dir + f'vae/dsprites/img_results/vae_dsprites_recon_{img_name}')


## CelebA

In [21]:
def disentanglement_results(result_dir: str, extra_config: str):
    model = celeba_VAE.load_from_checkpoint(project_dir + 'vae/celeba/results/' + result_dir + '/best.ckpt')

    zs = []
    for dimi in numpy.linspace(-2, 2, 10):
        for dimj in numpy.linspace(-2, 2, 10):
            zs.append([dimi, dimj] + [0] * (model.latent_dim - 2))

    z = torch.tensor(numpy.asarray(zs), dtype=torch.float)

    decoded = model.decoder(z)
    x_hat = decoded.detach().numpy()
    x_hat = numpy.transpose(x_hat, (0, 2, 3, 1))

    fig, axes = plt.subplots(10, 10, sharex='all', sharey='all', figsize=(8, 8))
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)
    axes = axes.flat

    for i in range(x_hat.shape[0]):
        ax = axes[i]
        ax.set_axis_off()
        ax.imshow(x_hat[i])

    if model.reduce_kl:
        if extra_config:
            title = f'CelebA - Disentanglement Effect - Latent Dim = {model.latent_dim} with {extra_config}'
            img_name = f'dim={model.latent_dim}_extra={extra_config.replace(" ", "")}.png'
        else:
            title = f'CelebA - Disentanglement Effect - Latent Dim = {model.latent_dim}'
            img_name = f'dim={model.latent_dim}.png'
    else:
        title = f'CelebA - Disentanglement Effect - Latent Dim = {model.latent_dim}, KL Loss not minimised'
        img_name = f'dim={model.latent_dim}_kl_notminimised.png'

    fig.suptitle(title)
    plt.savefig(project_dir + f'vae/celeba/img_results/vae_celeba_disent_{img_name}')


In [22]:
def see_some_generations(result_dir: str, extra_config: str):
    model = celeba_VAE.load_from_checkpoint(project_dir + 'vae/celeba/results/' + result_dir + '/best.ckpt')
    num_imgs = 100
    z = torch.normal(0, 1, size=(num_imgs, model.latent_dim))
    decoded = model.decoder(z)
    x_hat = decoded.detach().numpy()
    x_hat = numpy.transpose(x_hat, (0, 2, 3, 1))

    fig, axes = plt.subplots(10, 10, sharex='all', sharey='all', figsize=(8, 8))
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)
    axes = axes.flat

    for i in range(x_hat.shape[0]):
        ax = axes[i]
        ax.set_axis_off()
        ax.imshow(x_hat[i])

    if model.reduce_kl:
        if extra_config:
            title = f'CelebA - Generated Images - Latent Dim = {model.latent_dim} with {extra_config}'
            img_name = f'dim={model.latent_dim}_extra={extra_config.replace(" ", "")}.png'
        else:
            title = f'CelebA - Generated Images - Latent Dim = {model.latent_dim}'
            img_name = f'dim={model.latent_dim}.png'
    else:
        title = f'CelebA - Generated Images - Latent Dim = {model.latent_dim}, KL Loss not minimised'
        img_name = f'dim={model.latent_dim}_kl_notminimised.png'

    fig.suptitle(title)
    plt.savefig(project_dir + f'vae/celeba/img_results/vae_celeba_gen_{img_name}')


In [23]:
def see_some_reconstructions(result_dir: str, targets, extra_config: str):
    model = celeba_VAE.load_from_checkpoint(project_dir + 'vae/celeba/results/' + result_dir + '/best.ckpt')
    _, _, predictions = model(targets)
    predictions = predictions.detach().numpy()

    assert targets.shape == predictions.shape
    predictions = numpy.transpose(predictions, (0, 2, 3, 1))
    targets = numpy.transpose(targets, (0, 2, 3, 1))

    fig, axes = plt.subplots(10, 10, sharex='all', sharey='all', figsize=(8, 8))
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)
    axes = axes.flat

    for i in range(targets.shape[0]):
        target_idx = 2 * i
        ax = axes[target_idx]
        ax.set_axis_off()
        ax.imshow(targets[i])

        pred_idx = target_idx + 1
        ax = axes[pred_idx]
        ax.set_axis_off()
        ax.imshow(predictions[i])

    if model.reduce_kl:
        if extra_config:
            title = f'CelebA - Reconstructed Images - Latent Dim = {model.latent_dim} with {extra_config}'
            img_name = f'dim={model.latent_dim}_extra={extra_config.replace(" ", "")}.png'
        else:
            title = f'CelebA - Reconstructed Images - Latent Dim = {model.latent_dim}'
            img_name = f'dim={model.latent_dim}.png'
    else:
        title = f'CelebA - Reconstructed Images - Latent Dim = {model.latent_dim}, KL Loss not minimised'
        img_name = f'dim={model.latent_dim}_kl_notminimised.png'

    fig.suptitle(title)
    plt.savefig(project_dir + f'vae/celeba/img_results/vae_celeba_recon_{img_name}')


## Marginal Likelihood

In [25]:
def gradient_evaluated_at(model, x: torch.Tensor, z: torch.Tensor):
    """Evaluate Gradient at one z"""
    z.requires_grad_()
    x_hat = model.decoder(z)
    x_hat = torch.squeeze(x_hat, dim=0)
    mse_loss = 0.5 * MSELoss(reduction='sum')(input=x_hat, target=x)
    mse_loss.backward()
    gradient = -z - z.grad
    gradient.detach_()
    z.detach_()
    return gradient


def take_one_step(model, x: torch.Tensor, old_z: torch.Tensor, step_size: int = 0.001):
    """Make one gradient ascent step"""
    gradient = gradient_evaluated_at(model, x, old_z)
    new_z = old_z + step_size * gradient
    return new_z


def get_samples_from_posterior(model, x: torch.Tensor, ignore: int = 20, L: int = 50) -> numpy.ndarray:
    """Get L samples from the MCMC"""
    z = torch.normal(0, 1, size=(1, model.latent_dim))
    for _ in range(ignore):  # ignore first few steps
        z = take_one_step(model, x, z)

    zs = []
    for _ in range(L):  # start recording the samples
        z = take_one_step(model, x, z)
        zs.append(z)

    zs = torch.concat(zs).numpy()
    return zs


def get_q(model, x: torch.Tensor, apply_pca=True):
    """Use GMMs as density estimator. Use PCA on data if mentioned. Use the learned PCA going forward"""
    zs = get_samples_from_posterior(model, x)
    gmm = GaussianMixture(n_components=4, init_params='k-means++')
    if apply_pca:
        pca = PCA(n_components=4)
        zs = pca.fit_transform(zs)
    else:
        pca = None
    gmm.fit(zs)
    return gmm, pca


def evaluate_prior(model, z):
    """Evaluate the log prior"""
    normalise = -model.latent_dim * numpy.log(2 * math.pi) / 2
    log_prob = -z.dot(z) / 2
    return normalise + log_prob


def evaluate_posterior(model, x, z):
    """Evaluate the log posterior"""
    z = torch.tensor(z).reshape(1, -1)
    x_hat = model.decoder(z)
    x_hat = torch.squeeze(x_hat, dim=0)
    diff = x - x_hat
    diff = diff.detach().numpy().reshape(-1)
    normalise = -model.latent_dim * numpy.log(2 * math.pi) / 2
    log_prob = -diff.dot(diff) / 2
    return normalise + log_prob


def log_likelihood_estimate(model, x: torch.Tensor):
    """Get the likelihood estimate for one sample, x"""
    gmm, pca = get_q(model, x, apply_pca=model.latent_dim > 5)  # Get the density estimator. 
                                                                # Apply PCA if latent dim is more than 5

    zs = get_samples_from_posterior(model, x)  # Get fresh samples from MCMC 
    if pca:
        tzs = pca.transform(zs)
    else:
        tzs = zs

    values = []
    for j in range(zs.shape[0]):
        z = zs[j]
        tz = tzs[j]
        log_q = gmm.score_samples(tz.reshape(1, -1))[0]  # Get the density estimate 
        log_prior = evaluate_prior(model, z)
        log_posterior = evaluate_posterior(model, x, z)
        value = log_q - log_prior - log_posterior  # Get all values in terms of log
        val = numpy.exp(value)  # Get actual value
        values.append(val)
    inv = numpy.mean(values)  # Take the mean
    ll = -numpy.log(inv)  # Get the log likelihood estimate 
    return ll
