In [1]:
# importing the necessary libraries
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from tqdm import tqdm
from fastprogress.fastprogress import master_bar, progress_bar
from pathlib import Path
import os
import time

_ = torch.manual_seed(0)


def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    size = os.path.getsize("temp_delme.p")/1e3
    print('Size (KB):', size)
    os.remove('temp_delme.p')
    return size


# Define the network:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.quant = torch.quantization.QuantStub() # quantization layer
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x


def train_model(model, train_dl, valid_dl, criterion, optimizer):
    mb = master_bar(range(5))
    for epoch in mb:
        running_loss = 0.0
        correct_train, total_train = 0, 0
        # Progress bar for training batches
        pb = progress_bar(train_dl, parent=mb)

        for inputs, labels in pb:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Compute training accuracy
            preds = torch.argmax(outputs, dim=1)
            correct_train += (preds == labels).sum().item()
            total_train += labels.size(0)

            mb.child.comment = f"Train Loss: {loss.item():.4f}"

        # Compute average train loss & accuracy
        avg_train_loss = running_loss / len(train_dl)
        train_accuracy = correct_train / total_train * 100

        # Validation Phase (No Gradients)
        val_loss = 0.0
        correct_val, total_val = 0, 0

        with torch.no_grad():
            # Progress bar for validation
            pb = progress_bar(valid_dl, parent=mb)

            for inputs, labels in pb:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                # Compute validation accuracy
                preds = torch.argmax(outputs, dim=1)
                correct_val += (preds == labels).sum().item()
                total_val += labels.size(0)

                mb.child.comment = f"Valid Loss: {loss.item():.4f}"

        # Compute average validation loss & accuracy
        avg_val_loss = val_loss / len(valid_dl)
        val_accuracy = correct_val / total_val * 100

        # Write epoch summary
        mb.write(f"Epoch {epoch+1}: "
                 f"Train Loss={avg_train_loss:.4f}, Train Acc={train_accuracy:.2f}% | "
                 f"Val Loss={avg_val_loss: .4f}, Val Acc={val_accuracy: .2f}")


# evaluate
def evaluate(model, test_loader):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in tqdm(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("test acc", correct / total)
    return correct/total


def calibrate(model, dl):
    model.eval()
    with torch.no_grad():
        for idx, (images, _) in enumerate(tqdm(dl)):
            model(images)
            if idx == 4:
                break
    return model


def quantize_model(model, train_loader):
    # Create quantization-ready model
    qmodel = Net().to('cpu')  # quantization needs to happen on CPU
    qmodel.load_state_dict(model.state_dict())
    qmodel.eval()

    # Set qconfig for static quantization
    qmodel.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')

    # Prepare model for quantization (adds observers)
    qmodel = torch.ao.quantization.prepare(qmodel)

    # Calibrate the model
    calibrate(qmodel, train_loader)

    # Convert to quantized model
    qmodel = torch.ao.quantization.convert(qmodel)

    return qmodel

  warn("Couldn't import ipywidgets properly, progress bar will use console behavior")


In [2]:
# config
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 32
epochs = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# dataset and dataloader
train_ds = torchvision.datasets.CIFAR10(
    root='./data/cifar', train=True, transform=tfms, download=True)
test_ds = torchvision.datasets.CIFAR10(
    root='./data/cifar', train=False, transform=tfms, download=True)

train_ds, valid_ds = random_split(train_ds, [40000, 10000])
print(f"train ds length {len(train_ds)}")
print(f"test ds length {len(test_ds)}")
# train_ds = Subset(train_ds, range(100))
# valid_ds = Subset(train_ds, range(100))
# test_ds = Subset(train_ds, range(100))

train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True)
valid_dl = torch.utils.data.DataLoader(
    valid_ds, batch_size=batch_size, shuffle=False)
test_dl = torch.utils.data.DataLoader(
    test_ds, batch_size=batch_size, shuffle=False)

train ds length 40000
test ds length 10000


In [3]:
# insert min-max observers in the model
model = Net().to(device)
model.train()
model.qconfig = torch.ao.quantization.default_qconfig
model = torch.ao.quantization.prepare_qat(model)
print("Model before training")
print(model)

Model before training
Net(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (conv1): Conv2d(
    3, 6, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (fc2): Linear(
    in_features=120, out_features=84, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): M

In [4]:
# defining loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# saving the model
model_path = "model_qat.pkl"
if Path(model_path).exists():
    model = Net()
    model.load_state_dict(torch.load(model_path))
    print('model loaded successfully')
else:
    train_model(model, train_dl, valid_dl, criterion, optimizer)
    torch.save(model.state_dict(), model_path)

Epoch 1: Train Loss=2.2679, Train Acc=14.82% | Val Loss= 2.1236, Val Acc= 25.06                     
Epoch 2: Train Loss=1.9576, Train Acc=29.41% | Val Loss= 1.7920, Val Acc= 35.13                     
Epoch 3: Train Loss=1.6837, Train Acc=38.14% | Val Loss= 1.5950, Val Acc= 42.33                     
Epoch 4: Train Loss=1.5304, Train Acc=44.22% | Val Loss= 1.5104, Val Acc= 44.51                     
Epoch 5: Train Loss=1.4410, Train Acc=48.04% | Val Loss= 1.4249, Val Acc= 48.83                     


In [5]:
print("Model after training")
print(model)
# quantization of model
model.eval()
model = torch.ao.quantization.convert(model)

Model after training
Net(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-1.0, max_val=1.0)
  )
  (conv1): Conv2d(
    3, 6, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.29749810695648193, max_val=0.3572477102279663)
    (activation_post_process): MinMaxObserver(min_val=-5.882468223571777, max_val=5.583756446838379)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.21562230587005615, max_val=0.22265036404132843)
    (activation_post_process): MinMaxObserver(min_val=-9.163378715515137, max_val=11.960596084594727)
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.11489372700452805, max_val=0.10775470733642578)
    (activation_post_process): MinMaxObserver(min_val=-8.196516036987305, max_val=10.78540

In [6]:
print('Weights after quantization')
print(torch.int_repr(model.fc1.weight()))
print('Size of the model after quantization')
q_size = print_size_of_model(model)
print('Accuracy of the model after quantization: ')
start = time.time()
q_acc = evaluate(model, test_dl)
end = time.time()
q_time = end-start

print(f"{'Prec':<6} | {'Accuracy':<8} | {'Model Size':<10} | {'Time Taken'}")
print(f"{'INT8':<6} | {q_acc:<8.4f} | {q_size:<10.3f} | {q_time:<.6f}")


Weights after quantization
tensor([[ -4, -21,  57,  ...,  37,  39, -76],
        [-11,   7,  12,  ...,  10,  17,  78],
        [ -5, -13,  70,  ...,  12, -45,  20],
        ...,
        [ 13,  -6, -11,  ..., -33, -37,   6],
        [ 29,  -4,  19,  ..., -63, -46,  30],
        [ 23,   9, -25,  ...,   0,   7,  52]], dtype=torch.int8)
Size of the model after quantization
Size (KB): 69.922
Accuracy of the model after quantization: 


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:02<00:00, 106.34it/s]

test acc 0.486
Prec   | Accuracy | Model Size | Time Taken
INT8   | 0.4860   | 69.922     | 2.949037



