## MNIST VAE

The MNIST dataset is a collection of handwritten digits that is commonly used as the 'Hello World' dataset of deep learning. It contains 60,000 training images and 10,000 test images, and `carefree-learn` provides a straightforward API to access it.

The MNIST dataset can be used to train various image processing systems. In this article, we focus on how to build a custom model that solves the Variational Autoencoder (VAE) task on MNIST.

In [1]:
# preparations

import torch
import cflearn

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from typing import Any
from typing import Dict
from typing import Optional
from cflearn.types import losses_type
from cflearn.types import tensor_dict_type
from cflearn.protocol import TrainerState
from cflearn.misc.toolkit import check_is_ci
from cflearn.misc.toolkit import interpolate
from cflearn.misc.toolkit import inject_debug
from cflearn.modules.blocks import Lambda
from cflearn.modules.blocks import UpsampleConv2d

# the MNIST dataset can be prepared with a single line of code
data = cflearn.cv.MNISTData(batch_size=16, transform="for_generation")

# for reproducibility
np.random.seed(142857)
torch.manual_seed(142857)

<torch._C.Generator at 0x20c148231b0>

As shown above, the MNIST dataset can be easily turned into a `DLDataModule` instance, which is the common data interface used in `carefree-learn`.

> The `transform` argument specifies which transform to use for pre-processing the input batch. See [`Transforms`](https://carefree0910.me/carefree-learn-doc/docs/user-guides/computer-vision#transforms) for more details.

### Build Model

For demonstration purposes, we are going to build a simple convolution-based VAE:

In [2]:
@cflearn.register_module("simple_vae")
class SimpleVAE(nn.Module):
    def __init__(self, in_channels: int, img_size: int):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(1),
        )
        self.decoder = nn.Sequential(
            Lambda(lambda t: t.view(-1, 4, 4, 4), name="reshape"),
            nn.Conv2d(4, 128, 1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
            UpsampleConv2d(128, 64, kernel_size=3, padding=1, factor=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            UpsampleConv2d(64, 32, kernel_size=3, padding=1, factor=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            UpsampleConv2d(32, in_channels, kernel_size=3, padding=1, factor=2),
            Lambda(lambda t: interpolate(t, size=img_size, mode="bilinear")),
        )

    def forward(self, net: torch.Tensor) -> Dict[str, torch.Tensor]:
        net = self.encoder(net)
        mu, log_var = net.chunk(2, dim=1)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        net = eps * std + mu
        net = self.decoder(net)
        return {"mu": mu, "log_var": log_var, cflearn.PREDICTIONS_KEY: net}

A few details are worth mentioning:
+ We leveraged the [`register_module`](https://carefree0910.me/carefree-learn-doc/docs/developer-guides/computer-vision-customization#customize-models) API here, which turns a general `nn.Module` into a [`ModelProtocol`](https://carefree0910.me/carefree-learn-doc/docs/design-principles/#model) in `carefree-learn`. Once registered, the model can be accessed by its name (`"simple_vae"`).
+ We leveraged some built-in [common blocks](https://carefree0910.me/carefree-learn-doc/docs/design-principles#common-blocks) of `carefree-learn` to build our simple VAE:
  + `Lambda`, which turns a function into an `nn.Module`.
  + `UpsampleConv2d`, which upsamples its input.
  + `interpolate`, which is a handy function for resizing its input to the desired size.
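
The latent sampling inside `forward` can be checked in isolation. The following is a minimal sketch in plain PyTorch, assuming a random tensor as a stand-in for the encoder output (the batch size of 16 and the variable names are illustrative). It shows how the 128 encoder features split into a 64-dimensional `mu` and `log_var`, and how the sampled latent vector is viewed as the 4×4×4 map that the decoder expects.

```python
import torch

torch.manual_seed(0)

# Stand-in for the encoder output: a batch of 16 feature vectors of size 128,
# matching the [-1, 128] shape that the encoder above produces.
encoded = torch.randn(16, 128)

# Split the features into the mean and log-variance of the latent Gaussian.
mu, log_var = encoded.chunk(2, dim=1)

# Reparameterization trick: z = mu + std * eps, with eps ~ N(0, I).
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
z = eps * std + mu

# The decoder starts by viewing each 64-dim latent vector as a 4x4x4 map.
latent_map = z.view(-1, 4, 4, 4)

print(mu.shape, log_var.shape, z.shape, latent_map.shape)
```

Because the randomness is isolated in `eps`, gradients can flow through `mu` and `std` back to the encoder during training.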

Now that the model is implemented, we need to implement the special loss used in VAE tasks:

In [3]:
@cflearn.register_loss_module("simple_vae")
@cflearn.register_loss_module("simple_vae_foo")
class SimpleVAELoss(cflearn.LossModule):
    def forward(
        self,
        forward_results: tensor_dict_type,
        batch: tensor_dict_type,
        state: Optional[TrainerState] = None,
        **kwargs: Any,
    ) -> losses_type:
        # reconstruction loss
        original = batch[cflearn.INPUT_KEY]
        reconstruction = forward_results[cflearn.PREDICTIONS_KEY]
        mse = F.mse_loss(reconstruction, original)
        # kld loss
        mu = forward_results["mu"]
        log_var = forward_results["log_var"]
        kld_losses = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_losses, dim=0)
        # gather
        loss = mse + 0.001 * kld_loss
        return {"mse": mse, "kld": kld_loss, cflearn.LOSS_KEY: loss}
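
The KLD term above is the closed-form KL divergence between the diagonal Gaussian predicted by the encoder and a standard normal prior, summed over the latent dimensions. As a sanity check, the sketch below compares that expression against `torch.distributions.kl_divergence` on random inputs. The tensor shapes and tolerances are illustrative assumptions, not something prescribed by `carefree-learn`.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Random stand-ins for the "mu" and "log_var" entries of the forward results.
mu = torch.randn(16, 64)
log_var = torch.randn(16, 64)

# Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over the latent dimensions,
# written in the same form as the loss above.
kld_closed = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)

# The same quantity computed via torch.distributions, for comparison.
posterior = Normal(mu, torch.exp(0.5 * log_var))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_reference = kl_divergence(posterior, prior).sum(dim=1)

print(torch.allclose(kld_closed, kld_reference, rtol=1e-4, atol=1e-4))
```

The small `0.001` weight in the loss keeps this term from dominating the reconstruction error, which is averaged over pixels while the KLD is summed over latent dimensions.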

+ We used `register_loss_module` to register a general `LossModule` as a `LossProtocol` in `carefree-learn`.
+ We can call `register_loss_module` multiple times to assign multiple names to the same loss function.
+ When the loss function shares its name with the model, we don't need to specify the `loss_name` argument explicitly:

In [4]:
# Notice that we don't need to explicitly specify `loss_name`!
cflearn.api.fit_cv(
    data,
    "simple_vae",
    {"in_channels": 1, "img_size": 28},
    fixed_epoch=1,                                  # for demo purposes, we only train our model for 1 epoch
    cuda=0 if torch.cuda.is_available() else None,  # use CUDA if possible
)

Layer (type)                             Input Shape                             Output Shape    Trainable Param #
------------------------------------------------------------------------------------------------------------------------
_                                                                                                                 
  SimpleVAE                                                                                                       
    Sequential-0                     [-1, 1, 28, 28]                                [-1, 128]               97,376
      Conv2d-0                       [-1, 1, 28, 28]                         [-1, 16, 28, 28]                  160
      ReLU-0                        [-1, 16, 28, 28]                         [-1, 16, 28, 28]                    0
      BatchNorm2d-0                 [-1, 16, 28, 28]                         [-1, 16, 28, 28]                   32
      MaxPool2d-0                   [-1, 16, 28, 28]                      

<cflearn.api.cv.pipeline.CarefreePipeline at 0x20c1468b9e8>

Of course, we can still specify `loss_name` explicitly:

In [5]:
cflearn.api.fit_cv(
    data,
    "simple_vae",
    {"in_channels": 1, "img_size": 28},
    loss_name="simple_vae_foo",                     # we used the second registered name here
    fixed_epoch=1,                                  # for demo purposes, we only train our model for 1 epoch
    cuda=0 if torch.cuda.is_available() else None,  # use CUDA if possible
)

Layer (type)                             Input Shape                             Output Shape    Trainable Param #
------------------------------------------------------------------------------------------------------------------------
_                                                                                                                 
  SimpleVAE                                                                                                       
    Sequential-0                     [-1, 1, 28, 28]                                [-1, 128]               97,376
      Conv2d-0                       [-1, 1, 28, 28]                         [-1, 16, 28, 28]                  160
      ReLU-0                        [-1, 16, 28, 28]                         [-1, 16, 28, 28]                    0
      BatchNorm2d-0                 [-1, 16, 28, 28]                         [-1, 16, 28, 28]                   32
      MaxPool2d-0                   [-1, 16, 28, 28]                      

<cflearn.api.cv.pipeline.CarefreePipeline at 0x20c0d5d9e80>
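
The name-based lookup used in the two calls above can be pictured as a plain registry. The sketch below is a hypothetical illustration in pure Python, not `carefree-learn`'s actual implementation: `LOSS_REGISTRY`, `register_loss`, `resolve_loss`, and `DummyLoss` are made-up names. It shows why stacking the decorator assigns several names to one class, and how a missing `loss_name` could fall back to the model's name.

```python
from typing import Callable, Dict, Optional

# Hypothetical registry; carefree-learn's real internals may differ.
LOSS_REGISTRY: Dict[str, type] = {}


def register_loss(name: str) -> Callable[[type], type]:
    """Return a decorator that stores the decorated class under `name`."""

    def decorator(cls: type) -> type:
        LOSS_REGISTRY[name] = cls
        # Return the class unchanged so that decorators can be stacked.
        return cls

    return decorator


@register_loss("simple_vae")
@register_loss("simple_vae_foo")
class DummyLoss:
    """Placeholder standing in for a real loss class."""


def resolve_loss(model_name: str, loss_name: Optional[str] = None) -> type:
    """Fall back to the model's name when no loss name is given."""
    return LOSS_REGISTRY[loss_name or model_name]


# Both lookups resolve to the same class.
print(resolve_loss("simple_vae") is DummyLoss)
print(resolve_loss("simple_vae", "simple_vae_foo") is DummyLoss)
```

In this sketch each decorator call only adds one dictionary entry, so registering a second name costs nothing beyond the extra key.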