# Fashion MNIST Toy Model with PyTorch QAT and BST QAT

In [20]:
# Both torch 1.9.1 and 1.13.0 should work
import torch
from torch import nn

print(torch.__version__)

1.9.1+cu102


## 1. Regular Training with FP32

### 1.1 Loading data

In [21]:
import torchvision
from torch.utils.data import DataLoader

# Download training data from open datasets.
training_data = torchvision.datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Download test data from open datasets.
test_data = torchvision.datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


### 1.2 Defining Models

In [22]:
# Define model
class FashionCNN(nn.Module):
    def __init__(self):
        super(FashionCNN, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        # View avoids explicit data copy
        # out = out.view(out.size(0), -1)        
        out = torch.reshape(out, (out.size(0), -1))
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)
        
        return out

### 1.3 Defining Training and Testing Functions

In [23]:
# In a single training loop, the model makes predictions on the training dataset (fed to it in batches), 
# and backpropagates the prediction error to adjust the model’s parameters.
def train(dataloader, model, loss_fn, optimizer, device):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)        

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# We also check the model’s performance against the test dataset to ensure it is learning.
def test(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")            

### 1.4 Training and Testing

In [24]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

fp_model = FashionCNN()
fp_model.to(device)
print(fp_model)

# Define a loss function
loss_fn = nn.CrossEntropyLoss()

# Define an optimizer, for this model, Adam is better than SGD
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(fp_model.parameters(), lr=1e-3)

epochs = 5
for nepoch in range(epochs):
    print(f"Epoch {nepoch}\n-------------------------------")
    train(train_dataloader, fp_model, loss_fn, optimizer, device=device)
    test(test_dataloader, fp_model, loss_fn, device=device)
print("Done!")

Using cuda device
FashionCNN(
  (layer1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=2304, out_features=600, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
)
Epoch 0
-------------------------------
loss: 2.387211  [    0/60000]
loss: 0.429180  [ 6400/60000]
loss: 0.272270  [12800/60000]
loss: 0.510478  [19200/60000]
loss: 0.340889

### 1.5 Saving Models

In [25]:
# torch.save(model, "data/fashion_mnist.pth")
torch.save(fp_model.state_dict(), "data/fashion_mnist.pth")
print("Saved PyTorch Model to data/fashion_mnist.pth")

Saved PyTorch Model to data/fashion_mnist.pth


### 1.6 Export to ONNX

In [26]:
# Input to the model
x = torch.randn(batch_size, 1, 28, 28, requires_grad=True)
fp_model.eval()
fp_model.to('cpu')
torch_out = fp_model(x)

# Export the model
torch.onnx.export(fp_model,                     # model being run
                  x,                            # model input (or a tuple for multiple inputs)
                  "data/fashion_mnist.onnx",    # where to save the model (can be a file or file-like object)
                  export_params=True,           # store the trained parameter weights inside the model file
                  opset_version=13,             # the ONNX version to export the model to
                  do_constant_folding=True,     # whether to execute constant folding for optimization
                  input_names = ['input'],      # the model's input names
                  output_names = ['output'],    # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

## PyTorch QAT

### 2.1 Loading Data

It is the same as training a float model.

### 2.2 Defining Models

- Insert QuantStub and DeQuantStub at the beginning and end of the network.
- Define `fuse_module` function to fuse CONV+BN etc.

In [27]:
from torch.quantization import QuantStub, DeQuantStub

# Define model
class FashionCNNQAT(nn.Module):
    def __init__(self):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)

        out = self.layer1(x)
        out = self.layer2(out)
        # View avoids explicit data copy
        # out = out.view(out.size(0), -1)
        out = torch.reshape(out, (out.size(0), -1))
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)

        out = self.dequant(out)
        
        return out

    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    # This operation does not change the numerics
    def fuse_model(self):
        for m in self.modules():
            if type(m) == nn.Sequential:
                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)

### 2.3 Defining Training and Testing Functions

They are the same as training and testing float models.

### 2.4 Training and Testing

#### 2.4.1 Load pretrained float model

In [28]:
# Load pretrained float model
qat_model = FashionCNNQAT()
state_dict = torch.load("data/fashion_mnist.pth")
qat_model.load_state_dict(state_dict)
print(qat_model)

FashionCNNQAT(
  (layer1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=2304, out_features=600, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)


#### 2.4.2 Fuse model

In [29]:
qat_model.eval()
qat_model.fuse_model()
print(qat_model)

FashionCNNQAT(
  (layer1): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (1): Identity()
    (2): Identity()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (1): Identity()
    (2): Identity()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=2304, out_features=600, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)


#### 2.4.3 Prepare for QAT

##### Option 1:  Using default qconfig

In [30]:

quantization_config = torch.quantization.get_default_qat_qconfig('fbgemm')

##### Option 2: Using customerized observers and quantization schemes

In [31]:
from torch.quantization.fake_quantize import FakeQuantize
from torch.quantization.qconfig import QConfig

activation_quant = FakeQuantize.with_args(observer=torch.quantization.default_observer.with_args(dtype=torch.qint8),
                                           quant_min=-128,
                                           quant_max=127,
                                           dtype=torch.qint8,
                                           qscheme=torch.per_tensor_affine)

weight_quant = FakeQuantize.with_args(observer=torch.quantization.default_observer.with_args(dtype=torch.qint8),
                                        quant_min=-128,
                                        quant_max=127,
                                        dtype=torch.qint8,
                                        qscheme=torch.per_tensor_affine)

quantization_config = QConfig(activation=activation_quant, weight=weight_quant) 

In [32]:
qat_model.qconfig = quantization_config
print(qat_model.qconfig)

QConfig(activation=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, observer=functools.partial(functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, reduce_range=True), dtype=torch.qint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine), weight=functools.partial(<class 'torch.quantization.fake_quantize.FakeQuantize'>, observer=functools.partial(functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, reduce_range=True), dtype=torch.qint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine))


In [33]:
qat_model.train()
qat_model.to(device)
torch.quantization.prepare_qat(qat_model, inplace=True)
print(qat_model)

FashionCNNQAT(
  (layer1): Sequential(
    (0): ConvReLU2d(
      1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)



#### 2.4.4 Training and Testing

- Training can be done on GPU/CUDA, Testing has to be done on CPU.

In [34]:
# Define a loss function
loss_fn = nn.CrossEntropyLoss()

# Define an optimizer, for this model, Adam is better than SGD
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(qat_model.parameters(), lr=1e-3)

epochs = 3
for nepoch in range(epochs):
    print(f"Epoch {nepoch}\n-------------------------------")
    qat_model.to(device)
    train(train_dataloader, qat_model, loss_fn, optimizer, device=device)
    if nepoch > 1:
        # Freeze quantizer parameters
        qat_model.apply(torch.quantization.disable_observer)
    if nepoch > 0:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch
    # Convert the quantized model first
    qat_model.to('cpu')
    quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    test(test_dataloader, quantized_model, loss_fn, device='cpu')
print("Done!")

Epoch 0
-------------------------------
loss: 0.151610  [    0/60000]
loss: 0.278644  [ 6400/60000]
loss: 0.157579  [12800/60000]
loss: 0.294432  [19200/60000]
loss: 0.347096  [25600/60000]
loss: 0.406027  [32000/60000]
loss: 0.168670  [38400/60000]
loss: 0.294289  [44800/60000]
loss: 0.223967  [51200/60000]
loss: 0.119438  [57600/60000]


RuntimeError: expected scalar type QUInt8 but found QInt8

### 2.5 Saving Models

In [None]:
# torch.save(model, "data/fashion_mnist.pth")
torch.save(quantized_model.state_dict(), "data/fashion_mnist_qat.pth")
print("Saved PyTorch Model to data/fashion_mnist_qat.pth")

Saved PyTorch Model to data/fashion_mnist_qat.pth


### 2.6 Save JIT model

- Directly export to ONNX won't work.

In [None]:
torch.jit.save(torch.jit.script(quantized_model), "data/fashion_mnist_jit.pth")

## 3. BST QAT

### 3.0 Make sure bstnnx_training package is in the system python path

In [None]:
# Make sure bstnnx_training is in the system Python path
import sys
import os

sys.path.append("/home/hongbing/Projects/bst-study/bstnnx_training")

import bstnnx_training

print(bstnnx_training.__version__)

### 3.1 Loading data

It is the same as training a float model and PyTorch QAT

### 3.2 Defining Models

It is the same as PyTorch QAT

### 3.3 Defining Training and Testing Functions

They are teh same as training and testing float models and PyTorch QAT.

### 3.4 Training and Testing

#### 3.4.1 Load pretrained float model

It is the same as PyTorch QAT.

In [36]:
# Load pretrained float model
bst_model = FashionCNNQAT()
state_dict = torch.load("data/fashion_mnist.pth")
bst_model.load_state_dict(state_dict)
print(bst_model)

FashionCNNQAT(
  (layer1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=2304, out_features=600, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)


#### 3.4.2 Fuse Model

It is the same as PyTorch QAT.

In [40]:
bst_model.eval()
bst_model.fuse_model()
print(bst_model)

FashionCNNQAT(
  (layer1): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (1): Identity()
    (2): Identity()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (1): Identity()
    (2): Identity()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=2304, out_features=600, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)


- BST also supports auto fuse detection with the replacement of PyTorch's fuse_modules() function
- It seems ReLU wasn't fused. Let's ignore this apporach for now

In [39]:
from bstnnx_training.PyTorch.QAT.core import fuse_modules

x = torch.randn(batch_size, 1, 28, 28, requires_grad=True)
fused_model = fuse_modules(bst_model, auto_detect=True, input_tensor=x)
print(fused_model)

FashionCNNQAT(
  (layer1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): Identity()
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=2304, out_features=600, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=600, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)


#### 3.4.3 Prepare for BST QAT

- We can't use default qconfig, only customerized observers and quatization schemes are supported

In [41]:
from bstnnx_training.PyTorch.QAT.core.observer.observer import BSTObserver
from bstnnx_training.PyTorch.QAT.core.fake_quantize import FakeQuantize
from bstnnx_training.PyTorch.QAT.core.qconfig import QConfig

bst_activation_quant = FakeQuantize.with_args(observer=BSTObserver.with_args(dtype=torch.qint8),
                                              quant_min=-128,
                                              quant_max=127,
                                              dtype=torch.qint8,
                                              qscheme=torch.per_tensor_affine,
                                              reduce_range=False)
bst_weight_quant = FakeQuantize.with_args(observer=BSTObserver.with_args(dtype=torch.qint8),
                                          quant_min=-128,
                                          quant_max=127,
                                          dtype=torch.qint8,
                                          qscheme=torch.per_tensor_affine,
                                          reduce_range=False)

bst_model.qconfig = QConfig(activation=bst_activation_quant, weight=bst_weight_quant)

ModuleNotFoundError: No module named 'bstnnx_training.PyTorch.QAT.core.observer.bst_observer'