<a href="https://colab.research.google.com/github/karnwatcharasupat/latte/blob/issues%2F17-examples/examples/mnist-torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Example Notebook for Using Latte with PyTorch Lightning

## Getting Started

### Installing Latte and Dependencies

In [1]:
# This command automatically installs PyTorch and TorchMetrics.
# For users with an existing pytorch>=1.3.1 and torchmetrics>=0.2.0 installation,
#   use `pip install latte-metrics` with no extras
!pip install -q latte-metrics[pytorch]  

# PyTorch Lightning is installed separately
!pip install -q pytorch-lightning       

### Downloading dataset

In [None]:
!mkdir /content/dataset
!gdown --id "1fFGJW0IHoBmLuD6CEKCB8jz3Y5LJ5Duk" -O /content/dataset/morphomnist.zip
!unzip "/content/dataset/morphomnist.zip" -d /content/dataset/

### Cloning Morpho-MNIST measurement code

In [3]:
# Fetch the Morpho-MNIST measurement code into ./Morpho-MNIST.
# If that directory already exists, git refuses to clone and prints the
# "destination path ... already exists" error shown in the output below.
!git clone https://github.com/dccastro/Morpho-MNIST

fatal: destination path 'Morpho-MNIST' already exists and is not an empty directory.


## Creating a simple VAE

Using the model from
> A. Pati and A. Lerch, Attribute-based regularization of latent spaces for variational auto-encoders. Neural Computing & Applications, 33, 4429–4444 (2021). https://doi.org/10.1007/s00521-020-05270-2



In [None]:
import torch
from torch import nn
from torch import distributions

class MnistVAE(nn.Module):
    """Convolutional variational auto-encoder with a 16-dimensional latent space.

    The encoder stacks three un-padded 4x4 stride-1 convolutions (each one
    shrinks every spatial side by 3) followed by a linear layer taking
    2888 = 8 * 19 * 19 features, which corresponds to single-channel
    28x28 inputs. The decoder mirrors the encoder with transposed
    convolutions and ends without an output non-linearity.

    Attributes:
        input_size: Number of pixels per image (784).
        z_dim: Dimensionality of the latent code.
        inter_dim: Spatial side length of the feature map handed from the
            decoder's linear stage to its transposed convolutions.
    """

    def __init__(self):
        super().__init__()
        self.input_size = 784
        self.z_dim = 16
        self.inter_dim = 19

        # NOTE: sub-modules are created in a fixed order so that parameter
        # names and random-number consumption during init stay stable.
        encoder_conv_layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 8, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
        ]
        self.enc_conv = nn.Sequential(*encoder_conv_layers)
        self.enc_lin = nn.Sequential(nn.Linear(2888, 256), nn.SELU())

        # Two heads on the shared 256-d feature: mean and log-std of q(z|x).
        self.enc_mean = nn.Linear(256, self.z_dim)
        self.enc_log_std = nn.Linear(256, self.z_dim)

        decoder_lin_layers = [
            nn.Linear(self.z_dim, 256),
            nn.SELU(),
            nn.Linear(256, 2888),
            nn.SELU(),
        ]
        self.dec_lin = nn.Sequential(*decoder_lin_layers)

        decoder_conv_layers = [
            nn.ConvTranspose2d(8, 64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=1),
        ]
        self.dec_conv = nn.Sequential(*decoder_conv_layers)

        self.xavier_initialization()

    def xavier_initialization(self):
        """Re-initialise every parameter whose name contains 'weight'.

        Uses Xavier (Glorot) normal initialisation; parameters without
        'weight' in their name (the biases) keep their default values.
        """
        for param_name, tensor in self.named_parameters():
            if 'weight' not in param_name:
                continue
            nn.init.xavier_normal_(tensor)

    def encode(self, x):
        """Map a batch of images to the approximate posterior over z.

        Args:
            x: Input batch; its first dimension is used as the batch size
                when flattening the convolutional features.

        Returns:
            A ``Normal`` distribution whose mean comes from ``enc_mean`` and
            whose scale is the exponential of the ``enc_log_std`` output.
        """
        batch_size = x.size(0)
        features = self.enc_conv(x).view(batch_size, -1)
        features = self.enc_lin(features)
        mean = self.enc_mean(features)
        std = torch.exp(self.enc_log_std(features))
        return distributions.Normal(loc=mean, scale=std)

    def decode(self, z):
        """Decode latent codes into image-shaped outputs.

        Args:
            z: Latent batch; its first dimension is the batch size.

        Returns:
            The raw output of the transposed-convolution stack (no final
            activation is applied).
        """
        side = self.inter_dim
        feature_map = self.dec_lin(z).view(z.size(0), -1, side, side)
        return self.dec_conv(feature_map)

    def reparametrize(self, z_dist):
        """Sample from the posterior and from a matching standard-normal prior.

        Args:
            z_dist: Posterior distribution produced by :meth:`encode`.

        Returns:
            A tuple ``(z_tilde, z_prior, prior_dist)``: a reparameterised
            posterior sample, a sample from the prior, and the prior itself.
        """
        # rsample keeps the sample differentiable w.r.t. the posterior parameters.
        posterior_sample = z_dist.rsample()

        # Zero-mean, unit-scale prior with the same shape as the posterior.
        zero_mean = torch.zeros_like(z_dist.loc)
        unit_scale = torch.ones_like(z_dist.scale)
        prior_dist = distributions.Normal(loc=zero_mean, scale=unit_scale)

        prior_sample = prior_dist.sample()
        return posterior_sample, prior_sample, prior_dist

    def forward(self, x):
        """Run a full encode / sample / decode pass.

        Args:
            x: Input batch of images.

        Returns:
            A tuple ``(output, z_dist, prior_dist, z_tilde, z_prior)`` where
            ``output`` is the reconstruction reshaped to ``x.size()``.
        """
        posterior = self.encode(x)
        z_tilde, z_prior, prior_dist = self.reparametrize(posterior)

        # Reshape the decoder output so it lines up with the input tensor.
        reconstruction = self.decode(z_tilde).view(x.size())

        return reconstruction, posterior, prior_dist, z_tilde, z_prior