In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os



# Load MNIST Dataset

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device 
device = "cpu"


# Define the model

In [6]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet, self).__init__()
        self.quant = torch.quantization.QuantStub() # observer attached
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub() # observer attached

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

net = VerySimpleNet()

NameError: name 'QuantizedVerySimpleNet' is not defined

# Train the model

In [4]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        num_iterations = 0
        data_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y  =  data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print("Size (KB):", os.path.getsize("temp_delme.p"))
    os.remove("temp_delme.p")

def get_num_params(model):
    return sum(p.numel() for p in model.parameters())

MODEL_FILENAME = "simplenet_ptq.pt"

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print("Model Loaded")
else:
    train(train_loader, net, epochs=5)
    torch.save(net.state_dict(), MODEL_FILENAME)



NameError: name 'net' is not defined

# Define the testing loop


In [None]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    iterations = 0
    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break

    print("Accuracy: ", {round(correct/total, 3)})
    

# Print weight and size of the model before quantization

In [None]:
print("Before Quantization")
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Before Quantization
Parameter containing:
tensor([[ 0.0395,  0.0484,  0.0550,  ...,  0.0748,  0.0439,  0.0949],
        [ 0.0086, -0.0313, -0.0395,  ..., -0.0131, -0.0173,  0.0230],
        [ 0.1131,  0.1254,  0.1067,  ...,  0.0882,  0.1117,  0.1030],
        ...,
        [ 0.0932,  0.0657,  0.0992,  ...,  0.0686,  0.0909,  0.0601],
        [-0.0316, -0.0026,  0.0270,  ..., -0.0317, -0.0047,  0.0196],
        [ 0.0045, -0.0258, -0.0381,  ..., -0.0010, -0.0300,  0.0055]],
       requires_grad=True)
torch.float32


In [None]:
print("Size of the model before quantization")
print_size_of_model(net)

Size of the model before quantization
Size (KB): 360998


In [None]:
print(f"Accuracy of the model before quantization")
test(net)

Accuracy of the model before quantization


Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1342.65it/s]

Accuracy:  {0.973}





# Insert min-max observers in the model

In [15]:
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet, self).__init__()
        self.quant = torch.quantization.QuantStub() # observer attached
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub() # observer attached

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

# copy weights from the pretrained model to this new model

In [16]:
net_quantized = QuantizedVerySimpleNet().to(device)
# copy the weights from the original model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers

net_quantized


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Calibrate the model using the test set

In [17]:
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1234.59it/s]

Accuracy:  {0.973}





In [18]:
print(f"Check statistics of various layers")
net_quantized

Check statistics of various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-90.62647247314453, max_val=65.07781982421875)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-92.76065063476562, max_val=48.8410758972168)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-95.79275512695312, max_val=32.822120666503906)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantize the model using the statistics collected

In [22]:
torch.backends.quantized.engine = 'qnnpack' # this is needed else it throws an error
net_quantized = torch.ao.quantization.convert(net_quantized)


In [23]:
print(f'Check the statistics of the various layers after conversion')
net_quantized

Check the statistics of the various layers after conversion


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=1.226017951965332, zero_point=74, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=1.1149742603302002, zero_point=83, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=1.012715458869934, zero_point=95, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# Print the weights of the model after quantization

In [25]:
# print the weights of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[ 4,  5,  6,  ...,  8,  5, 10],
        [ 1, -3, -4,  ..., -1, -2,  3],
        [12, 14, 12,  ..., 10, 12, 11],
        ...,
        [10,  7, 11,  ...,  7, 10,  7],
        [-3,  0,  3,  ..., -3, -1,  2],
        [ 0, -3, -4,  ...,  0, -3,  1]], dtype=torch.int8)


# Compare the dequantized weights and the original weights

In [26]:
print('Original weights')
print(net.linear1.weight)
print('')
print("Quantized weights")
print(net_quantized.linear1.weight())

Original weights
Parameter containing:
tensor([[ 0.0395,  0.0484,  0.0550,  ...,  0.0748,  0.0439,  0.0949],
        [ 0.0086, -0.0313, -0.0395,  ..., -0.0131, -0.0173,  0.0230],
        [ 0.1131,  0.1254,  0.1067,  ...,  0.0882,  0.1117,  0.1030],
        ...,
        [ 0.0932,  0.0657,  0.0992,  ...,  0.0686,  0.0909,  0.0601],
        [-0.0316, -0.0026,  0.0270,  ..., -0.0317, -0.0047,  0.0196],
        [ 0.0045, -0.0258, -0.0381,  ..., -0.0010, -0.0300,  0.0055]],
       requires_grad=True)

Quantized weights
tensor([[ 0.0367,  0.0459,  0.0551,  ...,  0.0735,  0.0459,  0.0918],
        [ 0.0092, -0.0276, -0.0367,  ..., -0.0092, -0.0184,  0.0276],
        [ 0.1102,  0.1286,  0.1102,  ...,  0.0918,  0.1102,  0.1010],
        ...,
        [ 0.0918,  0.0643,  0.1010,  ...,  0.0643,  0.0918,  0.0643],
        [-0.0276,  0.0000,  0.0276,  ..., -0.0276, -0.0092,  0.0184],
        [ 0.0000, -0.0276, -0.0367,  ...,  0.0000, -0.0276,  0.0092]],
       size=(100, 784), dtype=torch.qint8,
    

# Print size and accuracy of the quantized model

In [27]:
print("Size of the model after quantization")
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 95266


In [28]:
print("Testing the model after quantization")
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 762.67it/s]

Accuracy:  {0.971}



