# Catalyst – A PyTorch Framework For Accelerated Deep Learning 

Catalyst is a PyTorch framework developed with the intent of advancing research and development in the domain of deep learning. It enables code reusability, reproducibility and rapid experimentation so that users can conveniently create deep learning models and pipelines without writing another training loop.

The Catalyst framework is part of the PyTorch ecosystem, a collection of tools and libraries for AI development. It is also part of the Catalyst Ecosystem, an MLOps ecosystem that expedites training, analysis and deployment of deep learning experiments through the Catalyst, Alchemy and Reaction frameworks respectively.

To read more about it, please refer to [this article](https://analyticsindiamag.com/guide-to-catalyst-a-pytorch-framework-for-accelerated-deep-learning/).

# Practical implementation

Here’s a demonstration of an image classification task using Catalyst. We use the well-known MNIST dataset, which has 10 output classes (images of handwritten digits from 0 to 9).

Install the Catalyst library using pip

In [None]:
!python -m pip install pip --upgrade --user -q
!python -m pip install numpy pandas seaborn matplotlib scipy scikit-learn statsmodels tensorflow keras --user -q

In [None]:
!python -m pip install -U catalyst --user -q

In [None]:
import IPython
# Restart the kernel so the freshly installed packages are picked up
IPython.Application.instance().kernel.do_shutdown(True)

Import required libraries and modules

In [None]:
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl, utils
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

Define the neural network model, loss function and optimizer to be used

Here, we use a very simple architecture: each 28×28 image is flattened into a vector of 784 values, and a single linear layer maps that vector to 10 outputs, one for each of the 10 possible classes.

We use PyTorch's cross-entropy loss (`nn.CrossEntropyLoss`), which combines LogSoftmax and NLLLoss in one class, together with the Adam optimizer.
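
To make the "LogSoftmax plus NLLLoss" statement concrete, here is a minimal sketch that computes the loss both ways on a toy batch and checks that the results agree. The random `logits` and the `targets` values are made up for illustration and are not taken from MNIST.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy batch: 4 samples with 10 raw class scores each.
logits = torch.randn(4, 10)
targets = torch.tensor([3, 0, 9, 1])

# One step: cross-entropy applied directly to the raw scores.
combined = nn.CrossEntropyLoss()(logits, targets)

# Two steps: log-softmax over the class dimension, then negative log-likelihood.
log_probs = nn.LogSoftmax(dim=1)(logits)
two_step = nn.NLLLoss()(log_probs, targets)

# The two values should match up to floating-point tolerance.
print(torch.allclose(combined, two_step))
```

This is also why the model can end in a plain linear layer: the loss expects raw scores and applies the softmax itself.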

In [None]:
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)
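
As a quick sanity check of the model defined above, the sketch below pushes a batch of random tensors through it and counts the trainable parameters. The batch is a hypothetical stand-in with the same shape as MNIST images, not real data.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

# Hypothetical batch of 4 grayscale 28x28 images (random values, not MNIST).
images = torch.rand(4, 1, 28, 28)
logits = model(images)

# Flatten turns (4, 1, 28, 28) into (4, 784); Linear maps that to (4, 10).
print(logits.shape)

# 784 * 10 weights plus 10 biases.
num_params = sum(p.numel() for p in model.parameters())
print(num_params)
```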

Load data for training and validation

Using PyTorch's `DataLoader` class, we load the MNIST dataset from the contrib API of Catalyst. We create a dictionary that holds both the training and validation loaders, setting the `train` parameter of `MNIST()` to True or False respectively.

In [None]:
loaders = {
    "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
    "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32),
}
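
To inspect what such a dictionary of loaders yields without downloading anything, the following sketch builds the same structure over random tensors. The `TensorDataset` and its contents are assumptions made for illustration; only the dictionary layout and `batch_size=32` mirror the cell above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 100 random "images" with random digit labels.
features = torch.rand(100, 1, 28, 28)
targets = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, targets)

loaders = {
    "train": DataLoader(dataset, batch_size=32),
    "valid": DataLoader(dataset, batch_size=32),
}

# Each batch is a (features, targets) pair.
batch_features, batch_targets = next(iter(loaders["valid"]))
print(batch_features.shape, batch_targets.shape)

# 100 samples in batches of 32 give 4 batches; the last one holds 4 samples.
print(len(loaders["valid"]))
```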

Define an experiment runner for supervised models using the `SupervisedRunner` class of the Runners API.

In [None]:
runner = dl.SupervisedRunner(input_key="features", output_key="logits", target_key="targets", loss_key="loss")
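
For intuition, the sketch below shows one hand-written training step in plain PyTorch, with the batch arranged as a dictionary that uses the same key names passed to the runner. It is only an illustration of the kind of loop a supervised runner saves you from writing, under the assumption that the keys name the entries of a batch dictionary; it is not Catalyst's implementation, and the random batch is hypothetical.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

# Hypothetical batch, keyed the way the runner's arguments suggest.
batch = {
    "features": torch.rand(32, 1, 28, 28),
    "targets": torch.randint(0, 10, (32,)),
}

weights_before = model[1].weight.detach().clone()

# Forward pass: read "features", store the model output under "logits".
batch["logits"] = model(batch["features"])
# Loss: compare "logits" against "targets", store the result under "loss".
batch["loss"] = criterion(batch["logits"], batch["targets"])

# Backward pass and parameter update.
optimizer.zero_grad()
batch["loss"].backward()
optimizer.step()

weights_changed = not torch.equal(weights_before, model[1].weight)
print(batch["loss"].item(), weights_changed)
```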

Train the model using the runner's `train()` method

In [None]:
# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        # catalyst[ml] required
        dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)
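
The accuracy callback above is asked for top-1, top-3 and top-5 accuracy. As a rough illustration of what a top-k accuracy measures, here is a small sketch; `topk_accuracy` is a hypothetical helper written for this example and is not how Catalyst computes the metric internally.

```python
import torch

def topk_accuracy(logits, targets, k):
    """Fraction of samples whose true class is among the k highest-scoring classes."""
    topk_classes = logits.topk(k, dim=1).indices             # shape: (batch, k)
    hits = (topk_classes == targets.unsqueeze(1)).any(dim=1)  # shape: (batch,)
    return hits.float().mean().item()

# Hypothetical scores for 3 samples over 4 classes.
logits = torch.tensor([
    [0.1, 0.7, 0.2, 0.0],  # true class 1 has the highest score
    [0.5, 0.1, 0.3, 0.1],  # true class 2 has the second-highest score
    [0.4, 0.3, 0.2, 0.1],  # true class 3 has the lowest score
])
targets = torch.tensor([1, 2, 3])

print(topk_accuracy(logits, targets, k=1))  # one of three samples counts as a hit
print(topk_accuracy(logits, targets, k=2))  # two of three samples count as hits
```

Top-k accuracy can only stay the same or grow as k increases, so top-5 is always at least as high as top-1.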

In [None]:
features_batch = next(iter(loaders["valid"]))[0]
# model stochastic weight averaging
model.load_state_dict(utils.get_averaged_weights_by_path_mask(logdir="./logs", path_mask="*.pth"))
# model tracing
utils.trace_model(model=runner.model, batch=features_batch)
# model quantization
utils.quantize_model(model=runner.model)
# model pruning
utils.prune_model(model=runner.model, pruning_fn="l1_unstructured", amount=0.8)
# onnx export
utils.onnx_export(model=runner.model, batch=features_batch, file="./logs/mnist.onnx", verbose=True)

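
To illustrate what pruning with `amount=0.8` means, the sketch below applies PyTorch's own `torch.nn.utils.prune.l1_unstructured` to a single linear layer and measures the resulting sparsity. That the Catalyst utility with `pruning_fn="l1_unstructured"` uses the same criterion is an assumption based on the name; this example calls the PyTorch function directly.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(28 * 28, 10)

# Zero out the 80% of weights with the smallest absolute value.
prune.l1_unstructured(layer, name="weight", amount=0.8)

# After pruning, layer.weight is the original weight multiplied by a 0/1 mask.
sparsity = (layer.weight == 0).float().mean().item()
print(f"{sparsity:.2f}")
```

Unstructured pruning only sets individual weights to zero; the tensor keeps its shape, so any speed or size benefit depends on a runtime or storage format that exploits sparsity.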