# MPTorch Functionality Overview

## Introduction

In this notebook, we give an overview of some of the main features of MPTorch. To install the current version of MPTorch, follow the instructions in the `README.md` of the GitHub [repo](https://github.com/mptorch/mptorch/blob/master/README.md).

In [None]:
import torch
import mptorch

## Quantization

MPTorch currently supports three types of number formats: floating-point, fixed-point, and block floating-point. Each format comes with a quantization function for PyTorch tensors: `float_quantize`, `fixed_point_quantize`, and `block_quantize`, respectively.
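To make the differences between the formats concrete, here is a minimal pure-Python sketch of round-to-nearest fixed-point and block floating-point quantization on plain floats. It is not MPTorch's implementation: the helper names `fixed_point_sketch` and `block_fp_sketch`, and the parameters `wl` (total word length) and `fl` (fractional bits), are hypothetical, and the conventions actually used by `fixed_point_quantize` and `block_quantize` (clamping, tie-breaking, choice of the shared exponent) may differ.

```python
import math


def fixed_point_sketch(x, wl, fl):
    """Hypothetical helper: round x to nearest on a signed fixed-point grid
    with wl total bits, fl of them fractional, clamping to the range."""
    step = 2.0 ** -fl                      # spacing between representable values
    lo = -(2 ** (wl - 1)) * step           # most negative representable value
    hi = (2 ** (wl - 1) - 1) * step        # most positive representable value
    # Python's round() breaks ties to even
    return min(max(round(x / step) * step, lo), hi)


def block_fp_sketch(xs, wl):
    """Hypothetical helper: quantize a block of values that share a single
    exponent, taken from the largest magnitude in the block. Each element
    keeps wl bits: 1 sign bit and wl - 1 magnitude bits."""
    max_abs = max(abs(v) for v in xs)
    if max_abs == 0.0:
        return [0.0 for _ in xs]
    shared_exp = math.floor(math.log2(max_abs))  # exponent of the largest element
    m = wl - 1                                   # magnitude bits per element
    step = 2.0 ** (shared_exp - m + 1)           # spacing inside this block
    k_max = 2 ** m - 1                           # largest integer magnitude
    return [max(-k_max, min(k_max, round(v / step))) * step for v in xs]


print(fixed_point_sketch(0.3, wl=8, fl=4))    # 0.3125
print(fixed_point_sketch(100.0, wl=8, fl=4))  # clamped to 7.9375
print(block_fp_sketch([1.0, 0.1], wl=4))      # [1.0, 0.0]
```

In the block example, `0.1` is rounded to zero because the shared exponent is dictated by `1.0`: small values that share a block with a large one lose precision, which is the main trade-off of block floating-point.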


In [None]:
from mptorch.quant import float_quantize, fixed_point_quantize, block_quantize

In [None]:
full_prec_tensor = torch.rand(4)
print(f"Full Precision value: {full_prec_tensor}")
low_prec_tensor = float_quantize(
    full_prec_tensor, exp=5, man=2, rounding='nearest'
)
print(f"Low Precision value: {low_prec_tensor}")
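To illustrate what a call such as `float_quantize(x, exp=5, man=2, rounding='nearest')` computes, here is a minimal pure-Python sketch for a single float. It assumes an IEEE-754-style layout (exponent bias `2**(exp-1) - 1`, all-ones exponent reserved for `Inf`/`NaN`), always keeps subnormals, and rounds ties to even; MPTorch's kernels may follow different conventions. The function name is hypothetical, and its `saturate` argument only mirrors the flag of the same name in `float_quantize`.

```python
import math


def float_quantize_sketch(x, exp, man, saturate=False):
    """Hypothetical helper: round x to nearest in a binary format with `exp`
    exponent bits and `man` explicit mantissa bits (IEEE-754-style layout
    assumed, subnormals always enabled)."""
    if x == 0.0 or math.isnan(x) or math.isinf(x):
        return x
    bias = 2 ** (exp - 1) - 1
    e_min = 1 - bias                                 # exponent of the smallest normal
    max_val = (2.0 - 2.0 ** -man) * 2.0 ** bias      # largest finite value
    # below the normal range the spacing stays fixed (subnormals)
    e = max(math.floor(math.log2(abs(x))), e_min)
    step = 2.0 ** (e - man)                          # distance between neighbours near x
    q = round(x / step) * step                       # nearest, ties to even
    if abs(q) > max_val:                             # overflow
        return math.copysign(max_val if saturate else math.inf, x)
    return q


for v in (0.3, 0.4, 68053.3):
    print(v, float_quantize_sketch(v, exp=5, man=2))
```

Each value is snapped to the closest point on a grid whose spacing doubles from one binade to the next, which is why the absolute rounding error grows with the magnitude of the input.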

Both round-to-nearest and stochastic rounding are currently supported. The user can also specify whether subnormal values are allowed through the `subnormals` argument (`True` by default).

In [None]:
# fix the random seeds so that the results of this cell are reproducible
torch.manual_seed(123)
torch.cuda.manual_seed(123)
torch.backends.cudnn.deterministic = True
full_prec_tensor = torch.rand(4)
nearest_round = float_quantize(full_prec_tensor, exp=5, man=2, rounding='nearest')
stochastic_round = float_quantize(full_prec_tensor, exp=5, man=2, rounding='stochastic')
print(f"Original: {full_prec_tensor}")
print(f"Nearest: {nearest_round}")
print(f"Stochastic: {stochastic_round}")
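The idea behind stochastic rounding can be seen on a uniform grid: a value is rounded up with probability equal to its fractional distance from the lower neighbour, so the expected result equals the input, whereas nearest rounding always picks the same neighbour. This pure-Python sketch uses a fixed spacing `step` for simplicity (in a floating-point format the spacing depends on the exponent of the value); the helper name is hypothetical, and MPTorch may generate its random bits differently.

```python
import math
import random


def stochastic_round_to_grid(x, step, rng):
    """Hypothetical helper: round x to a multiple of `step`, choosing the
    upper neighbour with probability proportional to the distance from
    the lower one (so the result is unbiased on average)."""
    lower = math.floor(x / step) * step
    frac = (x - lower) / step          # in [0, 1): how far x is above `lower`
    return lower + step if rng.random() < frac else lower


rng = random.Random(123)
samples = [stochastic_round_to_grid(0.3, 0.0625, rng) for _ in range(10_000)]
print(sorted(set(samples)))            # only the two neighbours 0.25 and 0.3125
print(sum(samples) / len(samples))     # close to 0.3 on average
```

Round-to-nearest would map `0.3` to `0.3125` every time, a systematic error of `0.0125`; stochastic rounding trades this bias for random noise, which is why it is attractive for accumulating many small updates.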

A `saturate` flag determines the behavior of the quantizer on overflow. If it is set to `True`, out-of-range values are clamped to the largest-magnitude finite value of the format (keeping their sign); otherwise, they are set to `±Inf`. Propagating infinities is useful when implementing or testing methods such as loss scaling, where scale-update algorithms work by detecting overflows.

In [None]:
x = torch.tensor([-102402.2, 0.4, 68053.3])
qx_saturate = float_quantize(x, exp=5, man=2, rounding='nearest', saturate=True)
qx_inf = float_quantize(x, exp=5, man=2, rounding='nearest', saturate=False)
print(f"Original: {x}")
print(f"Saturated: {qx_saturate}")
print(f"With Inf: {qx_inf}")
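To see which of the values above overflow, one can compute the largest finite value of the format. Assuming an IEEE-754-style layout (exponent bias $2^{e-1}-1$, all-ones exponent reserved for `Inf`/`NaN`; whether MPTorch uses exactly this convention is an assumption here), a format with $e$ exponent bits and $m$ mantissa bits has

$$
x_{\max} = \left(2 - 2^{-m}\right) \cdot 2^{\,2^{e-1}-1}.
$$

For `exp=5, man=2` this gives $x_{\max} = (2 - 2^{-2}) \cdot 2^{15} = 1.75 \cdot 32768 = 57344$. Under this assumption, `0.4` is in range while `-102402.2` and `68053.3` are not: they become `-57344` and `57344` with `saturate=True`, and `-Inf` and `Inf` otherwise.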

## Creating a model

To create a model for low-precision training, we can specify the number formats and quantizers to use for each layer's operations and signals, in both the forward and the backward pass. For linear and convolutional layers (currently the only supported layer types), these settings are grouped in a `QAffineFormats` object.

In [None]:
from mptorch import FloatingPoint
import mptorch.quant as qpt

# E4M2 floating-point format (4 exponent bits, 2 mantissa bits), with subnormals
# enabled and overflows propagated as infinities (no saturation)
exp, man = 4, 2
fp_format = FloatingPoint(exp=exp, man=man, subnormals=True, saturate=False)
# tensor quantizer that rounds to nearest in the same format
quant_fp = lambda x: qpt.float_quantize(
    x, exp=exp, man=man, rounding="nearest", subnormals=True, saturate=False
)

layer_formats = qpt.QAffineFormats(
    fwd_mac=(fp_format, fp_format), # formats for additions/multiplications in FWD matrix multiplications
    fwd_rnd="nearest",              # rounding mode of the FWD pass operators
    bwd_mac=(fp_format, fp_format), # formats for additions/multiplications in BWD matrix multiplications
    bwd_rnd="nearest",              # rounding mode of the BWD pass operators
    weight_quant=quant_fp,          # how weights are quantized during FWD and BWD computations
    bias_quant=quant_fp,            # how biases are quantized during FWD and BWD computations
    input_quant=quant_fp,           # how input signals are quantized during FWD and BWD computations
    grad_quant=quant_fp,            # how gradients are quantized during FWD and BWD computations
)
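The `fwd_mac` and `bwd_mac` entries describe the arithmetic inside matrix multiplications. As a rough mental model (an assumption, not a description of MPTorch's kernels, which may order or fuse operations differently), every product and every partial sum of a dot product is rounded to the corresponding format. The sketch below uses a hypothetical toy quantizer that rounds to multiples of `0.25` in place of a real floating-point format; the function names are hypothetical too.

```python
def mac_dot_sketch(a, b, quant_mul, quant_add):
    """Hypothetical helper: dot product in which every product and every
    partial sum is rounded, mimicking a low-precision multiply-accumulate."""
    acc = 0.0
    for x, y in zip(a, b):
        prod = quant_mul(x * y)        # round the product
        acc = quant_add(acc + prod)    # round the running sum
    return acc


# toy quantizer (hypothetical): round to the nearest multiple of 0.25
quarter = lambda v: round(v * 4) / 4

a, b = [0.3, 0.3, 0.3], [1.0, 1.0, 1.0]
print(mac_dot_sketch(a, b, quarter, quarter))   # 0.75: every step is rounded
print(quarter(sum(x * y for x, y in zip(a, b))))  # 1.0: exact sum rounded once
```

The gap between the two results comes from rounding at every step of the accumulation, which is why the choice of the addition format can matter as much as the multiplication format.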

Having defined the arithmetic configuration for forward and backward computations involving an affine layer (linear or convolution), we can now define a complete model. We will use the same MLP model that we saw in the Brevitas notebook.

In [None]:
class Reshape(torch.nn.Module):
    # flatten each 28x28 MNIST image into a 784-dimensional vector
    def forward(self, x):
        return x.view(-1, 28 * 28)


model = torch.nn.Sequential(
    Reshape(),
    qpt.QLinear(784, 128, formats=layer_formats),
    torch.nn.ReLU(),
    qpt.QLinear(128, 96, formats=layer_formats),
    torch.nn.ReLU(),
    qpt.QLinear(96, 10, formats=layer_formats),
)
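As a quick sanity check on the size of this MLP, the number of trainable parameters follows directly from the layer widths. This pure-Python helper (hypothetical name) assumes every `QLinear` layer has a bias, as `torch.nn.Linear` does by default.

```python
def linear_param_count(sizes):
    """Hypothetical helper: weights plus biases of a chain of fully
    connected layers with the given widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


mlp_sizes = [28 * 28, 128, 96, 10]     # layer widths of the MLP above
print(linear_param_count(mlp_sizes))   # 113834
```

With the model above in scope, `sum(p.numel() for p in model.parameters())` should return the same total if the bias assumption holds.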

## A mixed-precision training example

We are almost ready to simulate a mixed-precision training workflow, but a few things remain to be done: we need to load the training data (in our case MNIST), set the precision used during the parameter update (this is achieved through the `mptorch.optim.OptimMP` wrapper around standard PyTorch optimizers), and choose our training (hyper)parameters.
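To illustrate why the precision of the parameter update matters, here is a sketch of one SGD-with-momentum step (PyTorch's formulation, without dampening or weight decay) in which the momentum buffer and the updated weights pass through quantizers playing the roles of `momentum_quant` and `acc_quant`. Where exactly `OptimMP` applies its quantizers is an assumption here, and both the function and the toy quantizer are hypothetical.

```python
def sgd_momentum_step(params, grads, bufs, lr, momentum, acc_quant, momentum_quant):
    """Hypothetical helper: one SGD-with-momentum step on lists of floats,
    with the momentum buffer and the updated parameters rounded."""
    new_params, new_bufs = [], []
    for p, g, b in zip(params, grads, bufs):
        b = momentum_quant(momentum * b + g)   # buf <- momentum * buf + grad
        p = acc_quant(p - lr * b)              # p   <- p - lr * buf
        new_params.append(p)
        new_bufs.append(b)
    return new_params, new_bufs


identity = lambda v: v                         # stands in for full precision
quarter = lambda v: round(v * 4) / 4           # toy coarse accumulator format

print(sgd_momentum_step([1.0], [0.5], [0.0], 0.1, 0.9, identity, identity)[0])  # [0.95]
print(sgd_momentum_step([1.0], [0.5], [0.0], 0.1, 0.9, quarter, identity)[0])   # [1.0]
```

With the coarse accumulator, the update `-0.05` is smaller than half the grid spacing and is rounded away, so the weight never changes; keeping the update in a wider format (FP32 in the training cell) avoids this stagnation.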

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import SGD
from mptorch.optim import OptimMP
from mptorch.utils import trainer

# Hyperparameters
batch_size = 64  # batch size
lr_init = 0.05  # initial learning rate
num_epochs = 2  # epochs
momentum = 0.9
weight_decay = 0

# Prepare the transforms on the dataset (normalize with the MNIST mean and standard deviation)
device = "cuda" if torch.cuda.is_available() else "cpu"
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Download the dataset: MNIST
train_dataset = datasets.MNIST(
    "./data", train=True, transform=transform, download=True
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(
    "./data", train=False, transform=transform, download=False
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Prepare and launch the training process
model = model.to(device)
optimizer = SGD(
    model.parameters(), lr=lr_init, momentum=momentum, weight_decay=weight_decay
)

# choose the precision for the parameter update process (here, FP32: 8 exponent bits and 23 mantissa bits)
acc_q = lambda x: qpt.float_quantize(x, exp=8, man=23, rounding="nearest")
optimizer = OptimMP(
    optimizer,
    acc_quant=acc_q,
    momentum_quant=acc_q,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

trainer(
    model,
    train_loader,
    test_loader,
    num_epochs=num_epochs,
    lr=lr_init,
    batch_size=batch_size,
    optimizer=optimizer,
    device=device,
    init_scale=2**10,
)
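The `init_scale=2**10` argument suggests that `trainer` uses loss scaling with an initial scale of $2^{10}$. The sketch below shows a common dynamic scheme, similar in spirit to PyTorch's `torch.cuda.amp.GradScaler`: halve the scale and skip the step when the scaled gradients contain `Inf`/`NaN`, and double it after a run of overflow-free steps. Whether `trainer` follows exactly this rule is an assumption; the class and its parameters are hypothetical.

```python
import math


class DynamicLossScaler:
    """Hypothetical loss scaler. In a training loop, the loss is multiplied
    by `scale` before the backward pass and the gradients are divided by
    `scale` before the optimizer step, which runs only if update() is True."""

    def __init__(self, init_scale=2.0 ** 10, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, scaled_grads):
        """Return True if the step should be applied, False if it is skipped."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in scaled_grads)
        if overflow:
            self.scale /= 2.0          # back off and skip this update
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= 2.0          # try a larger scale again
            self._good_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=2.0 ** 10, growth_interval=3)
print(scaler.update([0.5, math.inf]), scaler.scale)  # False 512.0
```

This is where the `saturate=False` behavior of the quantizers pays off: overflows surface as infinities that the scaler can detect, instead of being silently clamped.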