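# Marginal log-likelihood

Estimate the marginal log-likelihood $\log p(x)$ of the held-out MNIST split (`train=False`) under two trained checkpoints: a fully-connected VAE and a convolutional VAE (CVAE).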

In [1]:
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.distributions as tdist
import numpy as np

import pmldiku
from pmldiku import data, vae, model_utils

%load_ext autoreload
%autoreload 1
%aimport pmldiku.data, pmldiku.vae, pmldiku.model_utils

In [2]:
# List the checkpoints available on disk. The call prints each checkpoint with its size,
# and the returned object's `.models` maps checkpoint file name -> path (used by the
# model-loading cells).
trained_models = model_utils.show_trained_models()
trained_models

Model mnist-CVAE-MSE-epoch=24-val_loss=4918.31.ckpt; size 1.18 mb.
Model mnist-VAE-CB-epoch=40-val_loss=-205819.23.ckpt; size 8.53 mb.
Model mnist-CVAE-BCE-epoch=16-val_loss=19557.79.ckpt; size 1.18 mb.
Model mnist-VAE-BCE-epoch=34-val_loss=18057.60.ckpt; size 8.53 mb.
Model mnist-VAE-MSE-epoch=21-val_loss=18096.77.ckpt; size 8.53 mb.
Model mnist-VAE-MSE-epoch=15-val_loss=4599.13.ckpt; size 8.53 mb.


TrainedModels(models={'mnist-CVAE-MSE-epoch=24-val_loss=4918.31.ckpt': PosixPath('/scratch/fjr906/projects/pml/pmldiku-exam-paper/code/models/mnist-CVAE-MSE-epoch=24-val_loss=4918.31.ckpt'), 'mnist-VAE-CB-epoch=40-val_loss=-205819.23.ckpt': PosixPath('/scratch/fjr906/projects/pml/pmldiku-exam-paper/code/models/mnist-VAE-CB-epoch=40-val_loss=-205819.23.ckpt'), 'mnist-CVAE-BCE-epoch=16-val_loss=19557.79.ckpt': PosixPath('/scratch/fjr906/projects/pml/pmldiku-exam-paper/code/models/mnist-CVAE-BCE-epoch=16-val_loss=19557.79.ckpt'), 'mnist-VAE-BCE-epoch=34-val_loss=18057.60.ckpt': PosixPath('/scratch/fjr906/projects/pml/pmldiku-exam-paper/code/models/mnist-VAE-BCE-epoch=34-val_loss=18057.60.ckpt'), 'mnist-VAE-MSE-epoch=21-val_loss=18096.77.ckpt': PosixPath('/scratch/fjr906/projects/pml/pmldiku-exam-paper/code/models/mnist-VAE-MSE-epoch=21-val_loss=18096.77.ckpt'), 'mnist-VAE-MSE-epoch=15-val_loss=4599.13.ckpt': PosixPath('/scratch/fjr906/projects/pml/pmldiku-exam-paper/code/models/mnist-VAE-

In [5]:
# Reproducibility: seed every RNG library imported by this notebook.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU only when one is actually present, so the notebook also runs on
# CPU-only machines instead of failing/warning on CUDA-specific settings.
CUDA = torch.cuda.is_available()
BATCH_SIZE = 128
LOGPX_LOSS = "bce"
DEVICE_NAME = "cuda" if CUDA else "cpu"

DEVICE = torch.device(DEVICE_NAME)
# `pin_memory` is a boolean flag (assuming setup_data_loader forwards these kwargs to
# torch.utils.data.DataLoader — confirm in pmldiku.data). Pinning only helps when
# batches are copied to a CUDA device, so tie it to CUDA rather than passing a
# torch.device object that merely happens to be truthy.
kwargs = {'num_workers': 4, 'pin_memory': CUDA}

train_loader = data.load_mnist(train=True).setup_data_loader(batch_size=BATCH_SIZE, **kwargs)
val_loader = data.load_mnist(train=False).setup_data_loader(batch_size=BATCH_SIZE, **kwargs)

In [6]:
# Pull a single validation batch as a quick sanity check of the loader
# (presumably X = images, y = labels — confirm against pmldiku.data).
X, y = next(iter(val_loader))

# Compute the marginal log-likelihood

## VAE

In [14]:
# Restore the fully-connected VAE (2-d latent space, per fc21/fc22 in the printed
# architecture) from the checkpoint whose name indicates BCE reconstruction loss.
vae_checkpoint = trained_models.models["mnist-VAE-BCE-epoch=34-val_loss=18057.60.ckpt"]
base_vae = vae.BaseVAE()
model = vae.LitVAE.load_from_checkpoint(
    vae_checkpoint,
    vae=base_vae,
    logpx_loss=LOGPX_LOSS,
)
model

LitVAE(
  (vae): BaseVAE(
    (fc1): Linear(in_features=784, out_features=400, bias=True)
    (fc1a): Linear(in_features=400, out_features=100, bias=True)
    (fc21): Linear(in_features=100, out_features=2, bias=True)
    (fc22): Linear(in_features=100, out_features=2, bias=True)
    (fc3): Linear(in_features=2, out_features=100, bias=True)
    (fc3a): Linear(in_features=100, out_features=400, bias=True)
    (fc4): Linear(in_features=400, out_features=784, bias=True)
  )
)

In [15]:
# Estimate log p(x) over the validation loader for the fully-connected VAE.
# latent_dim=2 must match the checkpoint's latent size (fc21/fc22 have out_features=2).
# L=5: number of samples per data point used by the estimator — TODO confirm the
# exact semantics in pmldiku.vae.MarginalLogLikVAE.
mloglik_vae = vae.MarginalLogLikVAE(val_loader, model, latent_dim=2, L=5)
# Evaluation needs no gradients. Without no_grad the estimate carries a grad_fn,
# i.e. autograd records (and keeps alive) the graph for every batch processed.
with torch.no_grad():
    logpx_vae = mloglik_vae.estimate()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [01:11<00:00,  1.10it/s]


In [19]:
# Display the estimated marginal log-likelihood for the fully-connected VAE.
logpx_vae

tensor(-39.0804, grad_fn=<AddBackward0>)

## CVAE

In [16]:
# Restore the convolutional VAE from the checkpoint whose name indicates BCE loss.
# hidden_dim=2 presumably sets the latent size (the printed fc_mean/fc_logvar
# layers have out_features=2) — confirm in pmldiku.vae.CVAE.
cvae_checkpoint = trained_models.models["mnist-CVAE-BCE-epoch=16-val_loss=19557.79.ckpt"]
base_vae = vae.CVAE(hidden_dim=2)
model = vae.LitVAE.load_from_checkpoint(
    cvae_checkpoint,
    vae=base_vae,
    logpx_loss=LOGPX_LOSS,
)
model

LitVAE(
  (vae): CVAE(
    (encoder): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Flatten(start_dim=1, end_dim=-1)
    )
    (fc_mean): Linear(in_features=3136, out_features=2, bias=True)
    (fc_logvar): Linear(in_features=3136, out_features=2, bias=True)
    (decoder): Sequential(
      (0): Linear(in_features=2, out_features=3136, bias=True)
      (1): Unflatten(dim=1, unflattened_size=[64, 7, 7])
      (2): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): ReLU()
      (6): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(1, 1))
      (7): Sigmoid()
    )
  )
)

In [20]:
# Estimate log p(x) over the validation loader for the convolutional VAE, using the
# same estimator settings as the fully-connected VAE so the two numbers are comparable.
# latent_dim=2 must match the CVAE's latent size (fc_mean/fc_logvar have out_features=2).
# L=5: number of samples per data point used by the estimator — TODO confirm the
# exact semantics in pmldiku.vae.MarginalLogLikVAE.
# Named mloglik_cvae so it does not shadow the VAE estimator.
mloglik_cvae = vae.MarginalLogLikVAE(val_loader, model, latent_dim=2, L=5)
# Evaluation needs no gradients. Without no_grad the estimate carries a grad_fn,
# i.e. autograd records (and keeps alive) the graph for every batch processed.
with torch.no_grad():
    logpx_cvae = mloglik_cvae.estimate()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [01:54<00:00,  1.45s/it]


In [21]:
# Display the estimated marginal log-likelihood for the convolutional VAE.
logpx_cvae

tensor(-44.1215, grad_fn=<AddBackward0>)