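**Static Quantization**:
   - Static quantization DOES quantize and store activations (unlike dynamic quantization, which quantizes activations on the fly at inference time)
   - Uses fixed scaling factors (scale and zero-point) determined during calibration
   - Requires a separate calibration step: representative data is run through the model so that observers can record activation ranges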

In [1]:
# importing the necessary libraries
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from tqdm import tqdm
from fastprogress.fastprogress import master_bar, progress_bar
from pathlib import Path
import os
import time

# Seed torch's default RNG once, up front: weight init and random_split (no explicit
# generator is passed) draw from it. `_ =` keeps the returned Generator out of the way.
SEED = 0
_ = torch.random.manual_seed(SEED)


def print_size_of_model(model):
    """Print and return the serialized size of ``model.state_dict()`` in kilobytes.

    The state dict is written with ``torch.save`` to a scratch file inside a private
    temporary directory. The directory is removed even if saving fails, so nothing is
    left behind in the working directory and concurrent runs cannot clobber each other.

    Args:
        model: any object with a ``state_dict()`` method (e.g. an ``nn.Module``).

    Returns:
        float: file size in KB (1 KB = 1000 bytes).
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original base name; torch.save's archive may record it, so reusing it
        # keeps the reported size comparable with earlier runs.
        path = os.path.join(tmp_dir, "temp_delme.p")
        torch.save(model.state_dict(), path)
        size = os.path.getsize(path) / 1e3
    print('Size (KB):', size)
    return size


# Define the network:
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to 10 class logits.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages followed by three fully connected layers.
    Layer attribute names must stay identical to QNet's so their state_dicts are interchangeable.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head; 16 * 5 * 5 is the flattened conv2 output for a 32x32 input
        # (32 -> 28 -> 14 -> 10 -> 5 spatially).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.reshape(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)


class QNet(nn.Module):
    """Quantization-ready twin of Net for eager-mode static quantization.

    Holds the same layers under the same attribute names as Net (so Net's state_dict loads
    directly), wrapped between a QuantStub and a DeQuantStub that mark where float tensors
    enter and leave the quantized region. After convert() the stubs become Quantize /
    DeQuantize modules.
    """

    def __init__(self):
        super().__init__()
        # Entry point of the quantized region.
        self.quant = torch.ao.quantization.QuantStub()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # Exit point: hands float logits back to the caller.
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.reshape(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.dequant(self.fc3(x))


def _run_epoch(model, dl, criterion, device, mb, optimizer=None):
    """Make one pass over `dl`: a training pass if `optimizer` is given, else a no-grad evaluation pass.

    Returns:
        (mean per-batch loss, accuracy in percent) over the whole loader.
    """
    training = optimizer is not None
    # Set the mode explicitly every pass: validation must run in eval mode, and a previous
    # eval()/evaluate() call must not leave training in the wrong mode.
    model.train(training)
    phase = "Train" if training else "Valid"
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in progress_bar(dl, parent=mb):
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            seen += labels.size(0)
            mb.child.comment = f"{phase} Loss: {loss.item():.4f}"
    return total_loss / len(dl), correct / seen * 100


def train_model(model, train_dl, valid_dl, criterion, optimizer, epochs=5, device=None):
    """Train `model` for `epochs` epochs, validating after each one.

    Each epoch writes a one-line summary (mean train/val loss and accuracy) to the
    fastprogress master bar.

    Args:
        model: module to train in place.
        train_dl: DataLoader of (inputs, labels) training batches.
        valid_dl: DataLoader of (inputs, labels) validation batches.
        criterion: loss function taking (outputs, labels).
        optimizer: optimizer over `model`'s parameters.
        epochs: number of passes over `train_dl` (default 5, the previously hard-coded value).
        device: where to send each batch. Defaults to the device of the model's first
            parameter, i.e. wherever the model already lives.

    Raises:
        ValueError: if either loader yields no batches (the averages would be undefined).
    """
    if len(train_dl) == 0 or len(valid_dl) == 0:
        raise ValueError("train_dl and valid_dl must each yield at least one batch")
    if device is None:
        device = next(model.parameters()).device

    mb = master_bar(range(epochs))
    for epoch in mb:
        train_loss, train_acc = _run_epoch(model, train_dl, criterion, device, mb, optimizer)
        val_loss, val_acc = _run_epoch(model, valid_dl, criterion, device, mb)
        mb.write(f"Epoch {epoch+1}: "
                 f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}% | "
                 f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")


# evaluate
def evaluate(model, test_loader, device=None):
    """Print and return the top-1 accuracy of `model` over `test_loader`.

    Args:
        model: classifier returning one row of class scores per input sample; switched to eval mode.
        test_loader: iterable of (images, labels) batches.
        device: where to send each batch. Defaults to the device of the model's first
            parameter (or buffer), falling back to CPU. Inferring it from the model lets the
            same function evaluate a GPU float model and a CPU-only quantized model.

    Returns:
        float: fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    if device is None:
        # Converted quantized modules may expose no nn.Parameters, so also look at buffers.
        ref = next(model.parameters(), None)
        if ref is None:
            ref = next(model.buffers(), None)
        device = ref.device if ref is not None else torch.device("cpu")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    accuracy = correct / total
    print("test acc", accuracy)
    return accuracy


def calibrate(model, dl, num_batches=5):
    """Run up to `num_batches` batches from `dl` through `model` without gradients.

    Intended for a prepared (observer-instrumented) model, so the observers record activation
    ranges before convert(). Labels are ignored.

    Args:
        model: prepared model; it is switched to eval mode.
        dl: iterable of (images, labels) batches. Batches are fed as-is, so they must
            already be on the model's device (CPU in this notebook's quantization flow).
        num_batches: how many batches to calibrate on. Defaults to 5, the count previously
            hard-coded; more batches usually give more representative activation ranges.

    Returns:
        The same `model`, for chaining.

    Raises:
        ValueError: if `num_batches` < 1 (no calibration data would be seen).
    """
    import itertools  # local import: only needed here

    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")
    total = min(num_batches, len(dl)) if hasattr(dl, "__len__") else num_batches
    model.eval()
    with torch.no_grad():
        # islice stops after `num_batches` without pulling an extra batch from the loader,
        # and `total` makes the progress bar reflect the calibration length, not the dataset.
        for images, _ in tqdm(itertools.islice(dl, num_batches), total=total):
            model(images)
    return model


def quantize_model(model, train_loader, backend='fbgemm'):
    """Return a post-training statically quantized (int8) copy of a trained float model.

    Steps:
        1. Copy the float weights into a QNet on the CPU.
        2. Attach the backend's default qconfig.
        3. Insert observers (prepare).
        4. Run calibration batches through them.
        5. Swap the observed modules for quantized ones (convert).

    Args:
        model: trained float model whose state_dict matches QNet's layer names (e.g. Net).
            Only its state_dict is read; the model itself is not modified.
        train_loader: iterable of (images, labels) batches used for calibration.
        backend: backend whose default qconfig is used. 'fbgemm' (the previously hard-coded
            value) targets x86; 'qnnpack' targets ARM. NOTE: per the PyTorch docs,
            torch.backends.quantized.engine should match this when running the result.

    Returns:
        The converted QNet, living on the CPU.
    """
    # Create quantization-ready model; quantized kernels run on CPU.
    qmodel = QNet().to('cpu')
    qmodel.load_state_dict(model.state_dict())
    # Post-training quantization is done in eval mode.
    qmodel.eval()

    # Set qconfig for static quantization (observer types and quantization params).
    qmodel.qconfig = torch.ao.quantization.get_default_qconfig(backend)

    # Prepare model for quantization (adds observers).
    qmodel = torch.ao.quantization.prepare(qmodel)

    # Calibrate: feed data so observers record activation ranges.
    calibrate(qmodel, train_loader)

    # Convert observed float modules to quantized modules.
    return torch.ao.quantization.convert(qmodel)

  warn("Couldn't import ipywidgets properly, progress bar will use console behavior")


In [2]:
# config
tfms = transforms.Compose([
    transforms.ToTensor(),
    # Map each RGB channel from [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 32
# NOTE(review): `epochs` is never passed to train_model (the training cell relies on its
# default epoch count) — wire it through or remove it.
epochs = 10
valid_size = 10000  # images held out from the CIFAR-10 training split for validation
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# dataset and dataloader
train_ds = torchvision.datasets.CIFAR10(
    root='./data/cifar', train=True, transform=tfms, download=True)
test_ds = torchvision.datasets.CIFAR10(
    root='./data/cifar', train=False, transform=tfms, download=True)

# Derive the train size from the dataset instead of hard-coding it
# (50000 - 10000 = 40000 for CIFAR-10, i.e. the same split as before).
train_ds, valid_ds = random_split(train_ds, [len(train_ds) - valid_size, valid_size])
print(f"train ds length {len(train_ds)}")
print(f"test ds length {len(test_ds)}")

train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True)
valid_dl = torch.utils.data.DataLoader(
    valid_ds, batch_size=batch_size, shuffle=False)
test_dl = torch.utils.data.DataLoader(
    test_ds, batch_size=batch_size, shuffle=False)

train ds length 40000
test ds length 10000


In [3]:
# defining loss function and optimizer
model = Net()
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train once and cache the weights; later runs reload them instead of retraining.
model_path = "model.pkl"
if Path(model_path).exists():
    # Load into the existing model, which is already on `device` and tied to `optimizer`.
    # Recreating Net() here would leave it on the CPU and detached from the optimizer.
    # map_location lets weights saved on one device be restored on another.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('model loaded successfully')
else:
    train_model(model, train_dl, valid_dl, criterion, optimizer)
    torch.save(model.state_dict(), model_path)

model loaded successfully


In [4]:
print('Weights before quantization')
print(model.fc1.weight)
print(model.fc1.weight.dtype)
print('Size of the model before quantization')
o_size = print_size_of_model(model)
print('Accuracy of the model before quantization: ')
# perf_counter is a monotonic, high-resolution clock intended for measuring durations
# (time.time() can jump if the system clock is adjusted).
# NOTE(review): the float model may be evaluated on a GPU while the quantized model runs on
# CPU only, so the two timings are directly comparable only when `device` is the CPU.
tik = time.perf_counter()
o_acc = evaluate(model, test_dl)
tok = time.perf_counter()
o_time = tok - tik

Weights before quantization
Parameter containing:
tensor([[ 0.0104, -0.0451, -0.0420,  ..., -0.0018, -0.0146,  0.0187],
        [-0.0603,  0.0163,  0.0094,  ...,  0.0270, -0.0356,  0.0184],
        [-0.0088,  0.0259,  0.0162,  ..., -0.0550,  0.0287,  0.0143],
        ...,
        [-0.0406, -0.0324,  0.0329,  ...,  0.0080, -0.0130, -0.0109],
        [-0.0100, -0.0422,  0.0076,  ..., -0.0117, -0.0102,  0.0500],
        [ 0.0027,  0.0210,  0.0457,  ..., -0.0049, -0.0600, -0.0109]],
       requires_grad=True)
torch.float32
Size of the model before quantization
Size (KB): 251.618
Accuracy of the model before quantization: 


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:01<00:00, 206.00it/s]

test acc 0.4745





In [5]:
# Post-training *static* quantization: weights and activations are quantized, using
# activation ranges calibrated on training batches.
print("applying post training static quantization on model")
quantized_model = quantize_model(model, train_dl)



applying post training tynamic quantization on model


  0%|▍                                                                                                                                                       | 4/1250 [00:00<00:24, 51.65it/s]


In [6]:
print("model summary")
print(quantized_model)
print('Weights after quantization')
# fc1.weight() returns a quantized tensor; int_repr exposes its raw integer values.
print(torch.int_repr(quantized_model.fc1.weight()))
print('Size of the model after quantization')
q_size = print_size_of_model(quantized_model)
print('Accuracy of the model after quantization: ')
# The quantized model only runs on CPU, so evaluation must feed it CPU tensors.
# perf_counter: monotonic clock intended for measuring durations.
start = time.perf_counter()
q_acc = evaluate(quantized_model, test_dl)
end = time.perf_counter()
q_time = end - start

model summary
QNet(
  (quant): Quantize(scale=tensor([0.0157]), zero_point=tensor([64]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.0658501610159874, zero_point=59)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): QuantizedConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=0.14691703021526337, zero_point=56)
  (fc1): QuantizedLinear(in_features=400, out_features=120, scale=0.15533238649368286, zero_point=68, qscheme=torch.per_channel_affine)
  (fc2): QuantizedLinear(in_features=120, out_features=84, scale=0.08135645091533661, zero_point=43, qscheme=torch.per_channel_affine)
  (fc3): QuantizedLinear(in_features=84, out_features=10, scale=0.09499918669462204, zero_point=64, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)
Weights after quantization
tensor([[  17,  -76,  -71,  ...,   -3,  -24,   31],
        [ -79,   21,   12,  ...,   35,  -47,   24],
        [ -16,   48,  

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:01<00:00, 171.96it/s]

test acc 0.4726





In [7]:
# Side-by-side view of fc1: float weights vs. int8 weights mapped back to float.
for title, weights in (('Original weights: ', model.fc1.weight),
                       ('Dequantized weights: ', torch.dequantize(quantized_model.fc1.weight()))):
    print(title)
    print(weights)
    print('')

# Summary table: accuracy, serialized size (KB) and evaluation time (s) per precision.
print(f"{'Prec':<6} | {'Accuracy':<8} | {'Model Size':<10} | {'Time Taken'}")
for prec, acc, size_kb, seconds in (('FP32', o_acc, o_size, o_time),
                                    ('INT8', q_acc, q_size, q_time)):
    print(f"{prec:<6} | {acc:<8.4f} | {size_kb:<10.3f} | {seconds:<.6f}")

Original weights: 
Parameter containing:
tensor([[ 0.0104, -0.0451, -0.0420,  ..., -0.0018, -0.0146,  0.0187],
        [-0.0603,  0.0163,  0.0094,  ...,  0.0270, -0.0356,  0.0184],
        [-0.0088,  0.0259,  0.0162,  ..., -0.0550,  0.0287,  0.0143],
        ...,
        [-0.0406, -0.0324,  0.0329,  ...,  0.0080, -0.0130, -0.0109],
        [-0.0100, -0.0422,  0.0076,  ..., -0.0117, -0.0102,  0.0500],
        [ 0.0027,  0.0210,  0.0457,  ..., -0.0049, -0.0600, -0.0109]],
       requires_grad=True)

Dequantized weights: 
tensor([[ 0.0101, -0.0453, -0.0423,  ..., -0.0018, -0.0143,  0.0185],
        [-0.0601,  0.0160,  0.0091,  ...,  0.0266, -0.0358,  0.0183],
        [-0.0087,  0.0260,  0.0162,  ..., -0.0552,  0.0287,  0.0141],
        ...,
        [-0.0406, -0.0326,  0.0332,  ...,  0.0080, -0.0129, -0.0111],
        [-0.0101, -0.0420,  0.0072,  ..., -0.0116, -0.0101,  0.0499],
        [ 0.0029,  0.0210,  0.0456,  ..., -0.0051, -0.0600, -0.0108]])

Prec   | Accuracy | Model Size | Time Ta