In [2]:
!pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.4.2-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)
[K     |████████████████████████████████| 10.3 MB 2.4 MB/s eta 0:00:01
Collecting cycler>=0.10
  Using cached cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)
Collecting kiwisolver>=1.0.1
  Using cached kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)
Installing collected packages: cycler, kiwisolver, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.2


In [3]:
# Standard
import numpy as np
from matplotlib.pyplot import imshow, figure

# Pytorch
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.utils import make_grid

# Pytorch Lightning
import pytorch_lightning as pl
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)

In [4]:
CIFAR10_DIR = '/home/djbiega/Documents/school/practicum/vae/'

In [9]:
class MyCIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

        # This is inefficient, but whatever
        train_dataset = torchvision.datasets.CIFAR10(
            CIFAR10_DIR,
            train=True
        )
        mean, std = self._get_mean_std(train_dataset)
        self.transforms = torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            train_dataset = torchvision.datasets.CIFAR10(
                CIFAR10_DIR,
                train=True,
                transform = self.transforms
            )

            num_train = int(0.75*len(train_dataset))
            num_val = int(0.25*len(train_dataset))
            self.train_dataset, self.val_dataset = random_split(
                train_dataset, 
                lengths=[num_train, num_val]
            )
        if stage == 'test' or stage is None:
            self.test_dataset = torchvision.datasets.CIFAR10(
                CIFAR10_DIR,
                train=False,
                transform = self.transforms
            )
        
    def train_dataloader(self):
        cifar_train = DataLoader(
            self.train_dataset, 
            batch_size=self.batch_size, 
            shuffle=True,
            num_workers=4
        )
        return cifar_train 
        
    def val_dataloader(self):
        cifar_val = DataLoader(
            self.val_dataset, 
            batch_size=self.batch_size, 
            shuffle=True,
            num_workers=4
        )
        return cifar_val
        
    def test_dataloader(self):
        cifar_test = DataLoader(
            self.test_dataset, 
            batch_size=self.batch_size, 
            shuffle=True,
            num_workers=4
        )
        return cifar_test
    
    def _get_mean_std(self, dataset):
        x = np.concatenate([np.asarray(dataset[i][0]) for i in range(len(dataset))])
        mean = np.mean(x, axis=(0, 1))
        std = np.std(x, axis=(0, 1))
        return mean, std

In [6]:
class VAE(pl.LightningModule):
    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        super().__init__()

        self.save_hyperparameters()

        # encoder, decoder
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        # distribution parameters
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # for the gaussian likelihood
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def gaussian_likelihood(self, x_hat, logscale, x):
        scale = torch.exp(logscale)
        mean = x_hat
        dist = torch.distributions.Normal(mean, scale)

        # measure prob of seeing image under p(x|z)
        log_pxz = dist.log_prob(x)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        # 2. get the probabilities from the equation
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        # kl
        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)
        return kl

    def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode x to get the mu and variance parameters
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)

        # sample z from q
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()

        # decoded
        x_hat = self.decoder(z)

        # reconstruction loss
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)

        # kl
        kl = self.kl_divergence(z, mu, std)

        # elbo
        elbo = (kl - recon_loss)
        elbo = elbo.mean()

        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(),
            'reconstruction': recon_loss.mean(),
            'kl': kl.mean(),
        })

        return elbo

In [7]:
class ImageSampler(pl.Callback):
    def __init__(self):
        super().__init__()
        self.img_size = None
        self.num_preds = 16

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        figure(figsize=(8, 3), dpi=300)

        # Z COMES FROM NORMAL(0, 1)
        rand_v = torch.rand((self.num_preds, pl_module.hparams.latent_dim), device=pl_module.device)
        p = torch.distributions.Normal(torch.zeros_like(rand_v), torch.zeros_like(rand_v))
        z = p.rsample()

        # SAMPLE IMAGES
        with torch.no_grad():
            pred = pl_module.decoder(z.to(pl_module.device)).cpu()

        # UNDO DATA NORMALIZATION
        normalize = cifar10_normalization()
        mean, std = np.array(normalize.mean), np.array(normalize.std)
        img = make_grid(pred).permute(1, 2, 0).numpy() * std + mean

        # PLOT IMAGES
        trainer.logger.experiment.add_image('img',torch.tensor(img).permute(2, 0, 1), global_step=trainer.global_step)



In [10]:
dataset = MyCIFAR10DataModule()
sampler = ImageSampler()

vae = VAE()
trainer = pl.Trainer(gpus=1, max_epochs=20, callbacks=[sampler])
trainer.fit(vae, dataset)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | encoder | ResNetEncoder | 11.2 M
1 | decoder | ResNetDecoder | 8.6 M 
2 | fc_mu   | Linear        | 131 K 
3 | fc_var  | Linear        | 131 K 
------------------------------------------
20.1 M    Trainable params
0         Non-trainable params
20.1 M    Total params
80.228    Total estimated model params size (MB)


Epoch 0: 100%|██████████| 1172/1172 [37:19<00:00,  1.91s/it, loss=2.49e+03, v_num=1]

ValueError: The parameter scale has invalid values

<Figure size 2400x900 with 0 Axes>