<a href="https://colab.research.google.com/github/karnwatcharasupat/latte/blob/issues%2F17-examples/examples/mnist-torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using Latte with PyTorch

Before you begin, please enable the GPU accelerator via `Runtime > Change runtime type > Hardware accelerator > GPU`.


## Installing Latte and Dependencies

In [18]:
# This command automatically installs PyTorch and TorchMetrics.
# For users with existing pytorch>=1.3.1 and torchmetrics>=0.2.0 installations,
#   use `pip install latte-metrics` with no extras instead.
!pip install -q latte-metrics[pytorch]  

# PyTorch Lightning is installed independently.
# NOTE(review): no cell shown in this notebook imports pytorch_lightning;
#   this install may be unnecessary — confirm before keeping it.
!pip install -q pytorch-lightning       

## Preparing data

### Downloading dataset

In [19]:
# Download the Morpho-MNIST archive from Google Drive and extract it into
# /content/dataset (`-o` overwrites files left over from a previous run).
# The archive unpacks into /content/dataset/global/.
!mkdir -p /content/dataset
!gdown --id "1fFGJW0IHoBmLuD6CEKCB8jz3Y5LJ5Duk" -O /content/dataset/morphomnist.zip
!unzip -o "/content/dataset/morphomnist.zip" -d /content/dataset/

Downloading...
From: https://drive.google.com/uc?id=1fFGJW0IHoBmLuD6CEKCB8jz3Y5LJ5Duk
To: /content/dataset/morphomnist.zip
  0% 0.00/15.5M [00:00<?, ?B/s]100% 15.5M/15.5M [00:00<00:00, 137MB/s]
Archive:  /content/dataset/morphomnist.zip
 extracting: /content/dataset/global/train-pert-idx1-ubyte.gz  
  inflating: /content/dataset/global/train-images-idx3-ubyte.gz  
  inflating: /content/dataset/global/train-morpho.csv  
 extracting: /content/dataset/global/train-labels-idx1-ubyte.gz  
 extracting: /content/dataset/global/t10k-pert-idx1-ubyte.gz  
  inflating: /content/dataset/global/t10k-images-idx3-ubyte.gz  
  inflating: /content/dataset/global/t10k-morpho.csv  
 extracting: /content/dataset/global/t10k-labels-idx1-ubyte.gz  
  inflating: /content/dataset/global/README-global.txt  


### Cloning Morpho-MNIST measurement code

In [20]:
# Clone the Morpho-MNIST code, which provides the `morphomnist` package imported below.
# When re-run, git reports that the destination already exists and leaves it as is.
!git clone https://github.com/dccastro/Morpho-MNIST

fatal: destination path 'Morpho-MNIST' already exists and is not an empty directory.


In [21]:
import sys

# Make the cloned Morpho-MNIST repository importable (it provides `morphomnist`).
# Only append once, so re-running this cell does not pile up duplicate entries.
MORPHO_MNIST_DIR = '/content/Morpho-MNIST'
if MORPHO_MNIST_DIR not in sys.path:
    sys.path.append(MORPHO_MNIST_DIR)

### Creating dataloader

In [33]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from morphomnist import io, morpho

class MorphoMnistDataset:
    """Load the Morpho-MNIST train / test splits into in-memory ``TensorDataset``s.

    Each sample is a pair ``(image, attributes)``:

    * ``image``: float32 tensor scaled to [0, 1], with a channel axis inserted
      at position 1 (assumes ``io.load_idx`` returns ``(N, H, W)`` — TODO confirm).
    * ``attributes``: the float32 row of the matching ``*-morpho.csv`` file.

    Parameters
    ----------
    root_dir : str
        Directory containing ``{train,t10k}-images-idx3-ubyte.gz`` and
        ``{train,t10k}-morpho.csv``.
    """

    def __init__(self, root_dir='/content/dataset/global'):
        super().__init__()
        # Extra DataLoader options, only used when a GPU is present.
        self.kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
        self.root_dir = root_dir
        self.data_path_str = "-images-idx3-ubyte.gz"
        # Not read by this class (digit labels are not loaded); kept for callers.
        self.label_path_str = "-labels-idx1-ubyte.gz"
        self.morpho_path_str = "-morpho.csv"

        self.train_dataset = self._create_dataset(dataset_type="train")
        self.val_dataset = self._create_dataset(dataset_type="t10k")

    def data_loaders(self, batch_size):
        """Return ``(train_loader, val_loader)`` with the given batch size.

        Only the training loader is shuffled. Both loaders use the same extra
        options (``self.kwargs``).
        """
        train_dl = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            **self.kwargs
        )
        # Previously the validation loader silently dropped these options.
        val_dl = DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            **self.kwargs
        )
        return train_dl, val_dl

    def _create_dataset(self, dataset_type="train"):
        """Build a ``TensorDataset`` for the split prefix ``dataset_type``
        (``"train"`` or ``"t10k"``)."""
        data_path = os.path.join(self.root_dir, dataset_type + self.data_path_str)
        morpho_path = os.path.join(self.root_dir, dataset_type + self.morpho_path_str)

        images = io.load_idx(data_path)
        # Add a channel axis and rescale 8-bit pixel values to [0, 1].
        images = np.expand_dims(images, axis=1).astype('float32') / 255.0

        # NOTE(review): `.values` keeps every CSV column, including a leading
        # row-index column if the file has one — confirm which columns are
        # meant to be treated as attributes.
        morpho_labels = pd.read_csv(morpho_path).values.astype('float32')

        return TensorDataset(
            torch.from_numpy(images),
            torch.from_numpy(morpho_labels)
        )

## Creating a simple VAE

This example uses the VAE architecture from:
> A. Pati and A. Lerch, Attribute-based regularization of latent spaces for variational auto-encoders. Neural Computing & Applications, 33, 4429–4444 (2021). https://doi.org/10.1007/s00521-020-05270-2



In [34]:
from torch import nn, distributions

class ImageVAE(nn.Module):
    """Convolutional VAE for single-channel 28x28 images with a 16-d latent space.

    Encoder: three 4x4, stride-1 convolutions shrink 28 -> 25 -> 22 -> 19
    (8 output channels, i.e. 8 * 19 * 19 = 2888 features). A 256-unit dense
    layer follows and predicts the mean and log-std of a diagonal Gaussian
    posterior. The decoder mirrors this with transposed convolutions and
    returns logits of the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        self.input_size = 784   # flattened 28x28 image (not used by the layers)
        self.z_dim = 16         # latent dimensionality
        self.inter_dim = 19     # spatial size of the innermost feature map

        # Encoder convolutions: (B, 1, 28, 28) -> (B, 8, 19, 19).
        self.enc_conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=64, out_channels=8, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
        )
        self.enc_lin = nn.Sequential(
            nn.Linear(in_features=2888, out_features=256),
            nn.SELU(),
        )
        # Posterior heads.
        self.enc_mean = nn.Linear(in_features=256, out_features=self.z_dim)
        self.enc_log_std = nn.Linear(in_features=256, out_features=self.z_dim)

        # Decoder: latent -> (B, 8, 19, 19) -> (B, 1, 28, 28).
        self.dec_lin = nn.Sequential(
            nn.Linear(in_features=self.z_dim, out_features=256),
            nn.SELU(),
            nn.Linear(in_features=256, out_features=2888),
            nn.SELU(),
        )
        self.dec_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=8, out_channels=64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=1),
            nn.SELU(),
            nn.Dropout(p=0.5),
            nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=1),
        )

        self.xavier_initialization()

    def xavier_initialization(self):
        """Re-initialise every parameter whose name contains 'weight' with Xavier-normal."""
        for param_name, tensor in self.named_parameters():
            if 'weight' not in param_name:
                continue
            nn.init.xavier_normal_(tensor)

    def encode(self, x):
        """Map a batch of images to the approximate posterior q(z | x)."""
        features = self.enc_conv(x).view(x.size(0), -1)
        features = self.enc_lin(features)
        mean = self.enc_mean(features)
        log_std = self.enc_log_std(features)
        # Tiny offset keeps the scale strictly positive.
        std = torch.exp(log_std) + 1e-16
        return distributions.Normal(loc=mean, scale=std)

    def decode(self, z):
        """Map latent vectors back to image-shaped logits."""
        grid = self.dec_lin(z).view(z.size(0), -1, self.inter_dim, self.inter_dim)
        return self.dec_conv(grid)

    def reparametrize(self, z_dist):
        """Draw a differentiable sample from ``z_dist``; also return a
        standard-normal prior of matching shape."""
        sample = z_dist.rsample()
        prior = distributions.Normal(
            loc=torch.zeros_like(z_dist.loc),
            scale=torch.ones_like(z_dist.scale),
        )
        return sample, prior

    def forward(self, x):
        """Return ``(reconstruction_logits, posterior, prior, z_sample)``."""
        posterior = self.encode(x)
        z_tilde, prior = self.reparametrize(posterior)
        reconstruction = self.decode(z_tilde).view(x.size())
        return reconstruction, posterior, prior, z_tilde

## Defining the Loss Function

In [49]:
from torch.nn import functional as F

def ar_signed_loss(z, a, factor=10.0):
    """Signed attribute-regularisation loss (Pati & Lerch, 2021).

    For every pair of samples (i, j) and every attribute k, the latent
    difference ``z[i, k] - z[j, k]`` (squashed by ``tanh(factor * .)``) is
    pushed towards ``sign(a[i, k] - a[j, k])``. Only the first
    ``a.shape[-1]`` latent dimensions are regularised.

    Parameters
    ----------
    z : torch.Tensor
        Latent codes, shape (batch, z_dim).
    a : torch.Tensor
        Attributes, shape (batch, n_attr).
    factor : float
        Steepness of the tanh applied to latent differences.

    Returns
    -------
    torch.Tensor
        Scalar: summed L1 error over all pairs and attributes, divided by batch size.
    """
    batch = z.shape[0]
    regularized = z[:, :a.shape[-1]]

    # Pairwise differences: entry [i, j, k] = t[i, k] - t[j, k].
    latent_diff = regularized.unsqueeze(1) - regularized.unsqueeze(0)
    attr_diff = a.unsqueeze(1) - a.unsqueeze(0)

    prediction = torch.tanh(factor * latent_diff)
    target = torch.sign(attr_diff).float()
    return F.l1_loss(prediction, target, reduction='sum') / batch

def compute_loss(x, xhat, zd, z0, z, a):
    """Total VAE training objective: reconstruction + KL + attribute regularisation.

    Parameters
    ----------
    x : target images; xhat : decoder logits of the same shape.
    zd : approximate posterior; z0 : prior (both ``torch.distributions``).
    z : sampled latent codes; a : attributes used by ``ar_signed_loss``.

    Returns
    -------
    torch.Tensor
        Scalar loss. The reconstruction MSE (against ``sigmoid(xhat)``) is summed
        over all elements and divided by the batch size. The KL divergence is
        summed over latent dimensions and averaged over the batch.
    """
    batch = z.shape[0]
    reconstruction = torch.sigmoid(xhat)
    recon_term = F.mse_loss(x, reconstruction, reduction='sum') / batch

    kl_per_sample = distributions.kl.kl_divergence(zd, z0).sum(-1)
    kl_term = kl_per_sample.mean()

    reg_term = ar_signed_loss(z, a)

    return recon_term + kl_term + reg_term

## Initialize Model, Optimizer, and Metrics

In [50]:
batch_size = 32   # samples per mini-batch (train and validation loaders)
num_epochs = 10   # full passes over the training set
lr = 1e-4         # Adam learning rate

In [51]:
import latte
from latte.metrics.torch.bundles import DependencyAwareMutualInformationBundle

# There is no need for this call: Latte uses seed=42 by default anyway.
# It only demonstrates that a seed can be set manually.
latte.seed(42)

# Use the GPU when available instead of failing outright on `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = MorphoMnistDataset()
train_dl, val_dl = dataset.data_loaders(batch_size)

model = ImageVAE().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Attribute a[:, i] is regularised by latent z[:, i], for every attribute column
# actually present in the loaded data (instead of a hard-coded count).
n_attributes = dataset.train_dataset.tensors[1].shape[1]
dami = DependencyAwareMutualInformationBundle(reg_dim=list(range(n_attributes)))

Docstring for `DependencyAwareMutualInformationBundle`

    Calculate between latent vectors and attributes:
        - Mutual Information Gap (MIG) 
        - Dependency-Aware Mutual Information Gap (DMIG) 
        - Dependency-Blind Mutual Information Gap (XMIG) 
        - Dependency-Aware Latent Information Gap (DLIG) 

    Parameters
    ----------
    reg_dim : Optional[List], optional
        regularized dimensions, by default None
        Attribute `a[:, i]` is regularized by `z[:, reg_dim[i]]`. If `None`, `a[:, i]` is assumed to be regularized by `z[:, i]`.
    discrete : bool, optional
        Whether the attributes are discrete, by default False

    Returns
    -------
    Dict[str, np.ndarray]
        A dictionary of mutual information metrics with keys ['MIG', 'DMIG', 'XMIG', 'DLIG'] each mapping to a corresponding metric np.ndarray of shape (n_attributes,).
    
    References
    ----------
    .. [1] Q. Chen, X. Li, R. Grosse, and D. Duvenaud, “Isolating sources of disentanglement in variational autoencoders”, in Proceedings of the 32nd International Conference on Neural Information Processing Systems, 2018.
    .. [2] K. N. Watcharasupat and A. Lerch, “Evaluation of Latent Space Disentanglement in the Presence of Interdependent Attributes”, in Extended Abstracts of the Late-Breaking Demo Session of the 22nd International Society for Music Information Retrieval Conference, 2021.
    .. [3] K. N. Watcharasupat, “Controllable Music: Supervised Learning of Disentangled Representations for Music Generation”, 2021.

## Training the model

In [None]:
from tqdm.notebook import tqdm

n_batch = len(train_dl)

# Run on whichever device the model's parameters live on.
device = next(model.parameters()).device

postfix = {}


for epoch_index in range(num_epochs):

    # ---------------- training ----------------
    model.train()
    with tqdm(total=n_batch) as prog_bar:
        prog_bar.set_description(f"epoch {epoch_index+1}/{num_epochs}")

        for i, (inputs, attributes) in enumerate(train_dl):
            prog_bar.update()

            inputs = inputs.to(device)
            attributes = attributes.to(device)

            recon, z_dist, prior_dist, z_tilde = model(inputs)

            optimizer.zero_grad()

            loss = compute_loss(
                inputs, recon, z_dist, prior_dist, z_tilde, attributes
            )

            # `.item()` yields a plain Python float for display.
            postfix.update({"train/loss": f"{loss.item():3.2g}"})

            # Compute train metrics using information from the last batch of
            # the train set in each epoch. Latte converts the returned values
            # back to torch types automatically.
            if i == n_batch - 1:

                # Latte moves the data to CPU itself, so no `.cpu()` is needed.
                # Detach so the metric cache never holds on to the autograd graph.
                dami.update(z_tilde.detach(), attributes)

                train_metrics = dami.compute()

                # Each metric holds one value per attribute; only the mean over
                # attributes is shown on the progress bar, for demonstration.
                postfix.update(
                    {f"train/{k}": f"{v.numpy().mean():.3g}" for k, v in train_metrics.items()}
                )

            loss.backward()
            optimizer.step()

            prog_bar.set_postfix(postfix)

    # reset the metric cache before the validation loop
    dami.reset()

    # ---------------- validation ----------------
    model.eval()
    val_loss = 0.0

    # Evaluation needs no gradients: `no_grad` avoids building autograd graphs,
    # and accumulating `loss.item()` keeps `val_loss` a plain float, so no
    # per-batch graph is kept alive across the loop.
    with torch.no_grad():
        for inputs, attributes in val_dl:

            inputs = inputs.to(device)
            attributes = attributes.to(device)

            recon, z_dist, prior_dist, z_tilde = model(inputs)

            loss = compute_loss(
                inputs, recon, z_dist, prior_dist, z_tilde, attributes
            )

            val_loss += loss.item()

            # use the entire validation set to compute metrics this time
            dami.update(z_tilde, attributes)

    # only compute once at the end of the validation loop,
    # using all validation batches
    val_metrics = dami.compute()

    print(f"Epoch {epoch_index+1}/{num_epochs}")
    print(f"Validation loss: {val_loss/len(val_dl):3.2g}")
    for metric, values in val_metrics.items():
        print(f"Validation {metric}:\t{values.numpy().mean():.3g}")

    # reset the metric cache for the next training epoch
    dami.reset()

  0%|          | 0/1875 [00:00<?, ?it/s]