<a href="https://colab.research.google.com/github/kalyaannnn/TransforMER/blob/main/PostTrainingQuantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

In [2]:
_ = torch.manual_seed(0)

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

mnist_trainset = datasets.MNIST(root = './data', train = True, download = True, transform = transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size = 10, shuffle = True)

mnist_testset = datasets.MNIST(root = './data', train = False, download = True)
test_dataloader = torch.utils.data.DataLoader(mnist_trainset, batch_size = 10, shuffle = True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15303985.98it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 456610.97it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4188025.30it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1143420.49it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






device(type='cuda')

In [14]:
class SimpleNet(nn.Module):
  def __init__(self, hidden_size_1 = 100, hidden_size_2 = 100):
    super(SimpleNet, self).__init__()
    self.linear1 = nn.Linear(28 * 28, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, 10)
    self.relu = nn.ReLU()

  def forward(self, img):
    x = img.view(-1, 28*28)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    return x

In [15]:
net = SimpleNet().to(device)

In [7]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None):
  cross_el = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)

  total_iterations = 0

  for epoch in range(epochs):
    net.train()

    loss_sum = 0
    num_iterations = 0


    data_iterator = tqdm(train_loader, desc = f'Epoch {epoch + 1}')
    if total_iterations_limit is not None:
      data_iterator.total = total_iterations_limit
    for data in data_iterator:
      num_iterations += 1
      total_iterations += 1
      x, y = data
      x = x.to(device)
      y = y.to(device)
      optimizer.zero_grad()
      output = net(x.view(-1, 28*28))
      loss = cross_el(output, y)
      loss_sum += loss.item()
      avg_loss = loss_sum / num_iterations
      data_iterator.set_postfix(loss = avg_loss)
      loss.backward()
      optimizer.step()

      if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
        return

def print_size_of_model(model):
  torch.save(model.state_dict(), "temp_delme.p")
  print('Size (KB) :', os.path.getsize("temp_delme.p")/1e3)
  os.remove('temp_delme.p')

In [8]:
MODEL_FILENAME = 'simplenet_ptq.pt'

In [16]:
if Path(MODEL_FILENAME).exists():
  net.load_state_dict(torch.load(MODEL_FILENAME))
  print('Loaded Model from the disk')
else:
  train(train_loader, net, epochs = 1)
  torch.save(net.state_dict(), MODEL_FILENAME)

Epoch 1: 100%|██████████| 6000/6000 [00:41<00:00, 145.76it/s, loss=0.215]


In [17]:
def test(model : nn.Module, total_iterations : int = None):
  correct = 0
  total = 0

  iterations = 0

  model.eval()
  with torch.no_grad():
    for data in tqdm(test_dataloader, desc = 'Testing'):
      x, y = data
      x = x.to(device)
      y = y.to(device)
      output = model(x.view(-1, 784))
      for idx, i in enumerate(output):
        if(torch.argmax(i) == y[idx]):
          correct += 1
        total += 1
      iterations += 1
      if total_iterations is not None and iterations >= total_iterations:
        break

  print(f"Accuracy : {round(correct/total, 3)}")

Weights and size of the model before quantization

In [19]:
print("Weights before quantization")
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0111, -0.0337,  0.0008,  ..., -0.0153,  0.0071,  0.0363],
        [ 0.0133,  0.0466,  0.0403,  ...,  0.0583,  0.0349,  0.0255],
        [ 0.0225, -0.0102,  0.0374,  ...,  0.0290, -0.0150, -0.0116],
        ...,
        [ 0.0106,  0.0444, -0.0069,  ...,  0.0517,  0.0119,  0.0226],
        [-0.0179, -0.0056,  0.0184,  ..., -0.0132, -0.0009, -0.0169],
        [ 0.0028,  0.0514,  0.0516,  ...,  0.0459,  0.0241, -0.0012]],
       device='cuda:0', requires_grad=True)
torch.float32


In [20]:
print("Size of the model before quantization")
print_size_of_model(net)

Size of the model before quantization
Size (KB) : 361.062


In [21]:
print(f"Accuracy of the mode before quantization")
test(net)

Accuracy of the mode before quantization


Testing: 100%|██████████| 6000/6000 [00:19<00:00, 313.01it/s]

Accuracy : 0.96





Insert Mix Max Observers in the model

In [34]:
class QuantizedNet(nn.Module):
  def __init__(self, hidden_size_1 = 100, hidden_size_2 = 100):
    super(QuantizedNet, self).__init__()
    self.quant = torch.quantization.QuantStub()
    self.linear1 = nn.Linear(28*28, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, 10)
    self.relu = nn.ReLU()
    self.dequant = torch.quantization.DeQuantStub()

  def forward(self, img):
    x = img.view(-1, 28*28)
    x = self.quant(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    x = self.dequant(x)
    return x

In [35]:
net_quantized = QuantizedNet().to(device)

net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) #Insert Observer
net_quantized

QuantizedNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [36]:
test(net_quantized)

Testing: 100%|██████████| 6000/6000 [00:22<00:00, 262.14it/s]

Accuracy : 0.96





In [37]:
net_quantized = torch.ao.quantization.convert(net_quantized)

In [38]:
print(f'Statistics of the various layers')
net_quantized

Statistics of the various layers


QuantizedNet(
  (quant): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6587608456611633, zero_point=72, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.39274415373802185, zero_point=70, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.42928650975227356, zero_point=77, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [39]:
print(f'Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[ 3, -8,  0,  ..., -4,  2,  9],
        [ 3, 11, 10,  ..., 14,  9,  6],
        [ 6, -2,  9,  ...,  7, -4, -3],
        ...,
        [ 3, 11, -2,  ..., 13,  3,  6],
        [-4, -1,  4,  ..., -3,  0, -4],
        [ 1, 13, 13,  ..., 11,  6,  0]], device='cuda:0', dtype=torch.int8)


In [40]:
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

Original weights: 
Parameter containing:
tensor([[ 0.0111, -0.0337,  0.0008,  ..., -0.0153,  0.0071,  0.0363],
        [ 0.0133,  0.0466,  0.0403,  ...,  0.0583,  0.0349,  0.0255],
        [ 0.0225, -0.0102,  0.0374,  ...,  0.0290, -0.0150, -0.0116],
        ...,
        [ 0.0106,  0.0444, -0.0069,  ...,  0.0517,  0.0119,  0.0226],
        [-0.0179, -0.0056,  0.0184,  ..., -0.0132, -0.0009, -0.0169],
        [ 0.0028,  0.0514,  0.0516,  ...,  0.0459,  0.0241, -0.0012]],
       device='cuda:0', requires_grad=True)

Dequantized weights: 
tensor([[ 0.0123, -0.0328,  0.0000,  ..., -0.0164,  0.0082,  0.0369],
        [ 0.0123,  0.0450,  0.0409,  ...,  0.0573,  0.0369,  0.0246],
        [ 0.0246, -0.0082,  0.0369,  ...,  0.0287, -0.0164, -0.0123],
        ...,
        [ 0.0123,  0.0450, -0.0082,  ...,  0.0532,  0.0123,  0.0246],
        [-0.0164, -0.0041,  0.0164,  ..., -0.0123,  0.0000, -0.0164],
        [ 0.0041,  0.0532,  0.0532,  ...,  0.0450,  0.0246,  0.0000]],
       device='cuda:0')


In [41]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB) : 95.458
