# Catalyst - customizing what happens in `train()`
Based on the Keras guide "Customizing what happens in `fit()`".

## Introduction

When you're doing supervised learning, you can use `train()` and everything works smoothly.

A core principle of Catalyst is **progressive disclosure of complexity**. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.

When you need to customize what `train()` does, you should **override the `_handle_batch` function of the `Runner` class**. This is the function that is called by `train()` for every batch of data. You will then be able to call `train()` as usual -- and it will be running your own learning algorithm.
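To make that control flow concrete, here is a minimal plain-Python sketch of the kind of loop `train()` runs around `_handle_batch`. It is *not* Catalyst's implementation: `ToyRunner`, its `train()` signature, and the rule used to set `is_train_loader` are simplified assumptions for illustration only.

```python
class ToyRunner:
    """Hypothetical, stripped-down runner: loops over epochs, loaders and batches."""

    def train(self, loaders, num_epochs=1):
        for epoch in range(num_epochs):
            for loader_name, loader in loaders.items():
                # Assumption for this sketch: loaders whose name starts with "train" are training loaders.
                self.is_train_loader = loader_name.startswith("train")
                for batch in loader:
                    self._handle_batch(batch)  # the single hook you override

    def _handle_batch(self, batch):
        raise NotImplementedError


class RecordingRunner(ToyRunner):
    """Records every batch it sees together with the current mode."""

    def __init__(self):
        self.calls = []

    def _handle_batch(self, batch):
        x, y = batch
        self.calls.append((self.is_train_loader, x, y))


runner = RecordingRunner()
loaders = {"train": [(1, 2), (3, 4)], "valid": [(5, 6)]}
runner.train(loaders, num_epochs=2)
print(len(runner.calls))  # 6: (2 train + 1 valid batches) x 2 epochs
```

The point of the design is that the outer loop never changes: only `_handle_batch` differs between runners.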

Note that this pattern does not constrain how you build your model: you can use it with **any** PyTorch model, whether it is an `nn.Sequential` or a custom `nn.Module` subclass.

Let's see how that works.

## Setup

In [None]:
!pip install catalyst==20.10.1
# don't forget to restart the runtime after installing, so that `PIL` works correctly in Colab

In [None]:
import catalyst
from catalyst import dl, utils
catalyst.__version__

## A first simple example

Let's start with a simple example:

- We create a new runner that subclasses `dl.Runner`.
- We just override the method `_handle_batch(self, batch)`.
- We do our train step with any possible custom logic.

The input argument `batch` is what gets passed to `train()` as training data. If you pass a `torch.utils.data.DataLoader` by calling `train(loaders={"train": loader, "valid": loader}, ...)`, then `batch` will be whatever `loader` yields at each step.

In the body of the `_handle_batch` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we log metrics via `self.batch_metrics`**, which passes them to the loggers.
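Conceptually, per-batch metrics are reduced to one number per loader before they are reported for an epoch. The sketch below shows one plausible reduction, a plain mean over batches. The helper `average_batch_metrics` is hypothetical, and the exact reduction Catalyst applies (for example, whether batches are weighted by size) may differ.

```python
from collections import defaultdict


def average_batch_metrics(history):
    """Reduce a list of per-batch metric dicts to per-loader means (hypothetical helper)."""
    totals, counts = defaultdict(float), defaultdict(int)
    for batch_metrics in history:
        for name, value in batch_metrics.items():
            totals[name] += float(value)
            counts[name] += 1
    return {name: totals[name] / counts[name] for name in totals}


# Two batches' worth of metrics, as `self.batch_metrics.update(...)` might produce them.
history = [{"loss": 0.5, "mae": 0.6}, {"loss": 0.3, "mae": 0.4}]
print(average_batch_metrics(history))
```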

In [None]:
import torch
from torch.nn import functional as F

class CustomRunner(dl.Runner):

  def _handle_batch(self, batch):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `train()`.
    x, y = batch

    y_pred = self.model(x) # Forward pass

    # Compute the loss value
    loss = F.mse_loss(y_pred, y)

    # Update metrics (includes the metric that tracks the loss)
    self.batch_metrics.update({"loss": loss, "mae": F.l1_loss(y_pred, y)})

    if self.is_train_loader:
      # Compute gradients
      loss.backward()
      # Update weights
      # (the optimizer passed to `train()` is available as `self.optimizer`)
      self.optimizer.step()
      self.optimizer.zero_grad()

Let's try this out:

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer, 
  loaders=loaders, 
  num_epochs=3,
  verbose=True, # more detailed logging of the training process (set to False to reduce output)
  timeit=False, # set to True to measure the execution time of different parts of the training process
)

## Going high-level

Naturally, you could skip the loss backward pass in `_handle_batch()` and instead delegate it to `Callbacks` passed to `train()`. Likewise for metrics. Here's a higher-level example that only uses `_handle_batch()` for the model forward pass and metric computation:
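The contract behind this is that `_handle_batch()` runs first and each callback then reacts to the batch, using only a *key* to find the value it works with. The plain-Python sketch below illustrates that contract. `ToyCallbackRunner` and `MetricReaderCallback` are hypothetical stand-ins, not Catalyst classes; a real optimizer callback would call `backward()` and step the optimizer where this one merely records the value.

```python
class MetricReaderCallback:
    """Hypothetical callback: reads one entry of runner.batch_metrics after each batch."""

    def __init__(self, metric_key):
        self.metric_key = metric_key
        self.seen = []

    def on_batch_end(self, runner):
        # A real optimizer callback would backpropagate this value and update the weights.
        self.seen.append(runner.batch_metrics[self.metric_key])


class ToyCallbackRunner:
    """Hypothetical runner: user code fills batch_metrics, then callbacks consume them."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def run_loader(self, loader):
        for batch in loader:
            self.batch_metrics = {}
            self._handle_batch(batch)
            for callback in self.callbacks:
                callback.on_batch_end(self)

    def _handle_batch(self, batch):
        x, y = batch
        y_pred = 2 * x  # stand-in for self.model(x)
        self.batch_metrics["loss"] = (y_pred - y) ** 2
        self.batch_metrics["mae"] = abs(y_pred - y)


reader = MetricReaderCallback(metric_key="mae")
ToyCallbackRunner([reader]).run_loader([(1, 3), (2, 1)])
print(reader.seen)  # [1, 3]
```

Because the callback only knows a key, switching `metric_key` from `"loss"` to `"mae"` changes what it acts on without touching `_handle_batch`.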

In [None]:
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset


class CustomRunner(dl.Runner):

  def _handle_batch(self, batch):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `train()`.
    x, y = batch

    y_pred = self.model(x) # Forward pass

    # Compute the loss value
    # (the criterion passed to `train()` is available as `self.criterion`)
    loss = self.criterion(y_pred, y)

    # Update metrics (includes the metric that tracks the loss)
    self.batch_metrics.update({"loss": loss, "mae": F.l1_loss(y_pred, y)})


# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer,
  criterion=criterion,       # you could also pass any PyTorch criterion for loss computation
  scheduler=None,            # or scheduler, but let's simplify the train loop for now :)
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  callbacks={
    "optimizer": dl.OptimizerCallback(
      metric_key="loss",     # you can also pass 'mae' to optimize it instead
                             # generally, you can optimize any differentiable metric from `runner.batch_metrics`
      accumulation_steps=1,  # or pass any number of steps for gradient accumulation
      grad_clip_params=None, # or use `{"func": "clip_grad_norm_", "max_norm": 1, "norm_type": 2}`
                             #     or `{"func": "clip_grad_value_", "clip_value": 1}`
                             # for gradient clipping during training!
                             # for more information about gradient clipping please follow pytorch docs
                             # https://pytorch.org/docs/stable/nn.html#clip-grad-norm
    )
  }
)
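For reference, norm-based clipping (what `torch.nn.utils.clip_grad_norm_` does) rescales all gradients together so that their combined L2 norm does not exceed `max_norm`. The plain-Python sketch below mirrors that arithmetic on a flat list of gradient values; it is an illustration rather than PyTorch's code, and `clip_by_total_norm` is a hypothetical name.

```python
import math


def clip_by_total_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient values so that their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Every gradient is scaled by the same factor, so the direction is preserved.
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


clipped, norm_before = clip_by_total_norm([3.0, 4.0], max_norm=1.0)
print(norm_before)  # 5.0
print(clipped)      # roughly [0.6, 0.8]
```

By contrast, `clip_grad_value_` clamps each gradient element independently to `[-clip_value, clip_value]`, which can change the direction of the overall gradient.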

## Metrics support through Callbacks

Let's go even deeper! Could we move the metric and criterion computations to `Callbacks` too? Of course! If you want to support different losses, you simply do the following:

- Do your model forward pass as usual.
- Save the model input to `runner.input` and the model output to `runner.output`, so that Callbacks can find them.
- Add extra callbacks that use the data from `runner.input` and `runner.output` for their computations.

That's it. That's the list. Let's see the example:
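Such callbacks only need the keys you chose. The sketch below is a hypothetical plain-Python stand-in for a criterion callback (it is not `dl.CriterionCallback`): it reads predictions from `runner.output[output_key]` and targets from `runner.input[input_key]`, and writes the result to `runner.batch_metrics[prefix]`, here using mean squared error on plain lists.

```python
from types import SimpleNamespace


class ToyCriterionCallback:
    """Hypothetical: batch_metrics[prefix] = criterion(output[output_key], input[input_key])."""

    def __init__(self, input_key, output_key, prefix, criterion):
        self.input_key = input_key
        self.output_key = output_key
        self.prefix = prefix
        self.criterion = criterion

    def on_batch_end(self, runner):
        y_pred = runner.output[self.output_key]
        y_true = runner.input[self.input_key]
        runner.batch_metrics[self.prefix] = self.criterion(y_pred, y_true)


def mse(y_pred, y_true):
    """Mean squared error over two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


# A fake runner state, shaped like what the `_handle_batch` below stores.
runner = SimpleNamespace(
    input={"features": [[0.1], [0.2]], "targets": [1.0, 2.0]},
    output={"logits": [1.5, 2.0]},
    batch_metrics={},
)
ToyCriterionCallback("targets", "logits", "loss", mse).on_batch_end(runner)
print(runner.batch_metrics["loss"])  # 0.125
```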

In [None]:
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset


class CustomRunner(dl.Runner):

  def _handle_batch(self, batch):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `train()`.
    x, y = batch

    y_pred = self.model(x) # Forward pass
    
    # store the network input in `runner.input`
    self.input = {"features": x, "targets": y}
    # and the network output in `runner.output`
    # (we recommend dicts, i.e. key-value storage, so that Callbacks can look values up by key)
    self.output = {"logits": y_pred}


# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer,
  criterion=criterion,
  scheduler=None,
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  callbacks={
    "criterion": dl.CriterionCallback(  # special Callback for criterion computation
      input_key="targets",              # `input_key` specifies correct labels (or `y_true`) from `runner.input` 
      output_key="logits",              # `output_key` specifies model predictions (`y_pred`) from `runner.output`
      prefix="loss",                    # `prefix` - key to use with `runner.batch_metrics`
    ),  # alias for `runner.batch_metrics[prefix] = runner.criterion(runner.output[output_key], runner.input[input_key])`
    "metric": dl.MetricCallback(        # special Callback for metrics computation
      input_key="targets",              # shares logic with `CriterionCallback`
      output_key="logits",
      prefix="loss_mae",
      metric_fn=F.l1_loss,              # metric function to use
    ),  # alias for `runner.batch_metrics[prefix] = metric_fn(runner.output[output_key], runner.input[input_key])`
    "optimizer": dl.OptimizerCallback(
      metric_key="loss", 
      accumulation_steps=1,
      grad_clip_params=None,
    )
  }
)

## Simplify it a bit - SupervisedRunner

But can we simplify the last example a bit? <br/>
What if we know that we are going to train a supervised model that takes some `features` in and outputs some `logits`? <br/>
Looks like a common case... could we automate it? Let's check it out!

In [None]:
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = dl.SupervisedRunner(  # `SupervisedRunner` works with any model of the form `some_output = model(some_input)`
  input_key="features",        # if your dataloader yields an (x, y) tuple, it is transformed to
  output_key="logits",         # {input_key: x, input_target_key: y} and stored in `runner.input`;
  input_target_key="targets",  # the model is then called as
)                              # runner.output = {output_key: model(runner.input[input_key])}
                               # and the loss is expected to be computed as
                               # loss = criterion(runner.output[output_key], runner.input[input_target_key])
                               # and stored in `runner.batch_metrics["loss"]`

runner.train(
  model=model, 
  optimizer=optimizer,
  criterion=criterion,
  scheduler=None,
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  callbacks={
    "criterion_mse": dl.CriterionCallback(
      input_key="targets",
      output_key="logits",
      prefix="loss",
    ),
    "criterion_mae": dl.MetricCallback(
      input_key="targets",
      output_key="logits",
      prefix="mae",
      metric_fn=F.l1_loss,
    ),
    "optimizer": dl.OptimizerCallback(
      metric_key="loss", 
      accumulation_steps=1,
      grad_clip_params=None,
    )
  }
)
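The bookkeeping described in the comments of the cell above boils down to wrapping the batch and the model output into dictionaries. Here is a minimal plain-Python sketch of that mapping; `supervised_step` is a hypothetical helper, not part of Catalyst, and it only illustrates the key handling.

```python
def supervised_step(batch, model, input_key="features", input_target_key="targets", output_key="logits"):
    """Hypothetical helper mirroring how a supervised runner wraps inputs and outputs."""
    x, y = batch
    runner_input = {input_key: x, input_target_key: y}
    runner_output = {output_key: model(runner_input[input_key])}
    return runner_input, runner_output


def double(xs):
    """Stand-in model: doubles every feature."""
    return [2 * v for v in xs]


runner_input, runner_output = supervised_step(([1, 2], [2, 5]), double)
print(runner_output)  # {'logits': [2, 4]}
```

Since every callback in the cell above uses the same `"targets"`/`"logits"` keys, they work unchanged whether the dictionaries are filled by your own `_handle_batch` or by `SupervisedRunner`.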

## Providing your own inference step

But let's return to the basics.

What if you want to do the same customization for calls to `runner.predict_*()`? Then you would override `predict_batch` in exactly the same way. Here's what it looks like:
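Conceptually, loader-level prediction is just `predict_batch` applied to every batch and yielded lazily, which is why overriding `predict_batch` is enough. The sketch below shows that relationship in plain Python; `ToyPredictRunner` is hypothetical and skips any extra work the real `predict_loader` may do, such as device placement.

```python
class ToyPredictRunner:
    """Hypothetical runner: predict_loader simply yields predict_batch(batch) for each batch."""

    def __init__(self, model):
        self.model = model

    def predict_batch(self, batch):
        features = batch[0]  # same convention as the cell below: batch = (x, y)
        return self.model(features)

    def predict_loader(self, loader):
        for batch in loader:
            yield self.predict_batch(batch)


runner = ToyPredictRunner(model=lambda xs: [v + 1 for v in xs])
predictions = list(runner.predict_loader([([1, 2], [0, 0]), ([3], [0])]))
print(predictions)  # [[2, 3], [4]]
```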

In [None]:
import torch
from torch.nn import functional as F

class CustomRunner(dl.Runner):
    
  def predict_batch(self, batch):                 # here is the trick
    return self.model(batch[0].to(self.device))   # you can write any prediction logic here

  def _handle_batch(self, batch):                 # same as in our first example
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `train()`.
    x, y = batch

    y_pred = self.model(x) # Forward pass

    # Compute the loss value
    loss = F.mse_loss(y_pred, y)

    # Update metrics (includes the metric that tracks the loss)
    self.batch_metrics.update({"loss": loss, "mae": F.l1_loss(y_pred, y)})

    if self.is_train_loader:
      # Compute gradients
      loss.backward()
      # Update weights
      # (the optimizer passed to `train()` is available as `self.optimizer`)
      self.optimizer.step()
      self.optimizer.zero_grad()

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Construct custom data
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples, 1)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

# and model
model = torch.nn.Linear(num_features, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Just use `train` as usual
runner = CustomRunner()
runner.train(
  model=model, 
  optimizer=optimizer, 
  loaders=loaders, 
  num_epochs=3,
  verbose=True,
  timeit=False,
  load_best_on_end=True, # load the best checkpoint at the end of training
  logdir="./logs",       # directory to store model checkpoints (required for `load_best_on_end`)
)
# and use `batch` prediction
prediction = runner.predict_batch(next(iter(loader))) # let's sample the first batch from the loader
# or `loader` prediction
for prediction in runner.predict_loader(loader=loader):
  assert prediction.detach().cpu().numpy().shape[-1] == 1 # as we have single-output regression
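`load_best_on_end=True` relies on the runner tracking a main validation metric and remembering which epoch was best. As a hedged illustration, assuming the main metric is the validation `loss` and lower is better (check the `main_metric` and `minimize_metric` arguments of `train()` in your Catalyst version), the selection amounts to the following. `best_epoch` is a hypothetical helper, and the numbers are made-up inputs, not results from the run above.

```python
def best_epoch(valid_metrics, main_metric="loss", minimize=True):
    """Return the 1-based epoch whose main validation metric is best."""
    values = [epoch_metrics[main_metric] for epoch_metrics in valid_metrics]
    best_value = min(values) if minimize else max(values)
    return values.index(best_value) + 1


# Made-up per-epoch validation metrics, for illustration only.
history = [{"loss": 0.30}, {"loss": 0.21}, {"loss": 0.25}]
print(best_epoch(history))  # 2
```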

Finally, after model training and evaluation, it's time to prepare the model for deployment. PyTorch supports model tracing (TorchScript) for production-friendly deployment of deep learning models.

Could we make it quick with Catalyst? Sure!

In [None]:
# you can trace your model using a single batch
traced_model = runner.trace(batch=next(iter(loader)))
# or using a loader (the first batch is taken automatically)
traced_model = runner.trace(loader=loader)

## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" - 1 and "real" - 0).



In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from catalyst.contrib.nn import GlobalMaxPool2d, Flatten, Lambda

# Create the discriminator
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    GlobalMaxPool2d(),
    Flatten(),
    nn.Linear(128, 1)
)

# Create the generator
latent_dim = 128
generator = nn.Sequential(
    # Project the latent vector to 128 * 7 * 7 values and reshape them into a 128x7x7 feature map
    nn.Linear(latent_dim, 128 * 7 * 7),
    nn.LeakyReLU(0.2, inplace=True),
    Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(128, 1, (7, 7), padding=3),
    nn.Sigmoid(),
)

# Final model
model = {
    "generator": generator,
    "discriminator": discriminator,
}

optimizer = {
    "generator": torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
    "discriminator": torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
}
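It is worth checking that these layer settings really shrink a 28x28 image to a 7x7 map in the discriminator and grow a 7x7 map back to 28x28 in the generator. The size formulas below are the standard ones from the PyTorch `Conv2d`/`ConvTranspose2d` documentation, specialised to dilation 1 and no output padding; the helper names are our own.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output size of nn.Conv2d along one spatial axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel, stride=1, padding=0):
    """Output size of nn.ConvTranspose2d along one spatial axis (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# Discriminator: 28 -> 14 -> 7; GlobalMaxPool2d then reduces 7x7 to 1x1, so Linear(128, 1) fits.
d1 = conv2d_out(28, kernel=3, stride=2, padding=1)
d2 = conv2d_out(d1, kernel=3, stride=2, padding=1)

# Generator: 7 -> 14 -> 28, and the final 7x7 convolution with padding=3 keeps 28.
g1 = conv_transpose2d_out(7, kernel=4, stride=2, padding=1)
g2 = conv_transpose2d_out(g1, kernel=4, stride=2, padding=1)
g3 = conv2d_out(g2, kernel=7, padding=3)
print(d1, d2, g1, g2, g3)  # 14 7 14 28 28
```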


Here's a feature-complete `GANRunner`, overriding `predict_batch()` to use its own signature, and implementing the entire GAN algorithm in 16 lines in `_handle_batch`:

In [None]:
class GANRunner(dl.Runner):
  
  def _init(self, latent_dim: int):
    self.latent_dim = latent_dim
    self.experiment = None  # spoiler for next lesson ;)

  def predict_batch(self, batch):
    random_latent_vectors = torch.randn(1, self.latent_dim).to(self.device)
    generated_images = self.model["generator"](random_latent_vectors)
    return generated_images

  def _handle_batch(self, batch):
    real_images, _ = batch
    batch_metrics = {}
    
    # Sample random points in the latent space
    batch_size = real_images.shape[0]
    random_latent_vectors = torch.randn(batch_size, self.latent_dim).to(self.device)
    
    # Decode them to fake images
    generated_images = self.model["generator"](random_latent_vectors).detach()
    # Combine them with real images
    combined_images = torch.cat([generated_images, real_images])
    
    # Assemble labels discriminating real from fake images
    labels = torch.cat([
        torch.ones((batch_size, 1)), torch.zeros((batch_size, 1))
    ]).to(self.device)
    # Add random noise to the labels - important trick!
    labels += 0.05 * torch.rand(labels.shape).to(self.device)
    
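    # Compute the discriminator loss
    # (the weight update itself is done by the corresponding `OptimizerCallback`)
    predictions = self.model["discriminator"](combined_images)
    batch_metrics["loss_discriminator"] = \
      F.binary_cross_entropy_with_logits(predictions, labels)
    
    # Sample random points in the latent space
    random_latent_vectors = torch.randn(batch_size, self.latent_dim).to(self.device)
    # Assemble labels that say "all real images"
    misleading_labels = torch.zeros((batch_size, 1)).to(self.device)
    
    # Compute the generator loss. The discriminator weights are frozen during this
    # forward pass, so `loss_generator` sends gradients to the generator only and
    # does not leak into the discriminator's `.grad` buffers
    for param in self.model["discriminator"].parameters():
      param.requires_grad_(False)
    generated_images = self.model["generator"](random_latent_vectors)
    predictions = self.model["discriminator"](generated_images)
    for param in self.model["discriminator"].parameters():
      param.requires_grad_(True)
    batch_metrics["loss_generator"] = \
      F.binary_cross_entropy_with_logits(predictions, misleading_labels)
    
    self.batch_metrics.update(batch_metrics)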
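Both losses use `F.binary_cross_entropy_with_logits`, which applies the sigmoid internally. For a logit $x$ and a (possibly noisy) label $z$, the per-element loss is $-[z \log \sigma(x) + (1 - z) \log(1 - \sigma(x))]$, which can be computed in the numerically stable form $\max(x, 0) - x z + \log(1 + e^{-|x|})$ and is averaged over elements by default. The plain-Python sketch below reproduces that formula; `bce_with_logits` is our own name, not a PyTorch function.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy on raw logits, in the numerically stable form."""
    losses = [
        max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
        for x, z in zip(logits, labels)
    ]
    return sum(losses) / len(losses)


# A logit of 0 means sigma(x) = 0.5, so the loss is log(2) for a label of 1.
print(round(bce_with_logits([0.0], [1.0]), 4))  # 0.6931
# Confident, correct logits give a much smaller loss than confident, wrong ones.
print(bce_with_logits([3.0, -3.0], [1.0, 0.0]) < bce_with_logits([3.0, -3.0], [0.0, 1.0]))  # True
```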

Let's test-drive it:

In [None]:
import os
from torch.utils.data import DataLoader
from catalyst.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST

loaders = {
  "train": DataLoader(
    MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), 
    batch_size=64),
}

runner = GANRunner(latent_dim=latent_dim)
runner.train(
    model=model, 
    optimizer=optimizer,
    loaders=loaders,
    callbacks=[
        dl.OptimizerCallback(
            optimizer_key="generator", 
            metric_key="loss_generator"
        ),
        dl.OptimizerCallback(
            optimizer_key="discriminator", 
            metric_key="loss_discriminator"
        ),
    ],
    main_metric="loss_generator",
    num_epochs=20,
    verbose=True,
    logdir="./logs_gan",
)

The ideas behind deep learning are simple, so why should their implementation be painful?

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

utils.set_global_seed(42)
generated_image = runner.predict_batch(None)
plt.imshow(generated_image[0, 0].detach().cpu().numpy())

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./logs_gan