<a href="https://colab.research.google.com/github/kalyaannnn/TransforMER/blob/main/QuantizationAwareTraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from pathlib import Path

In [2]:
_ = torch.manual_seed(0)

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

In [7]:
mnist_trainset = datasets.MNIST(root = './data', download = True, train = True, transform = transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size = 10, shuffle = True)

mnist_testset = datasets.MNIST(root = './data', download = True, train = False, transform = transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size = 10, shuffle = True)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [9]:
class QuantizedNet(nn.Module):
  def __init__(self, hidden_size_1 = 100, hidden_size_2 = 100):
    super(QuantizedNet, self).__init__()
    self.quant = torch.quantization.QuantStub()
    self.linear1 = nn.Linear(28*28, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, 10)
    self.relu = nn.ReLU()
    self.dequant = torch.quantization.DeQuantStub()

  def forward(self, img):
    x = img.view(-1, 28*28)
    x = self.quant(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    x = self.dequant(x)
    return x

In [10]:
net = QuantizedNet().to(device)

In [13]:
net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quantized = torch.ao.quantization.prepare_qat(net)
net_quantized

QuantizedNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [16]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None):
  cross_el = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)

  total_iterations = 0

  for epoch in range(epochs):
    net.train()

    loss_sum = 0
    num_iterations = 0

    data_iterator = tqdm(train_loader, desc = f"Epoch {epoch + 1}")
    if total_iterations_limit is not None:
      data_iterator.total = total_iterations_limit
    for data in data_iterator:
      num_iterations += 1
      total_iterations += 1
      x, y = data
      x = x.to(device)
      y = y.to(device)
      optimizer.zero_grad()
      output = net(x.view(-1, 28*28))
      loss = cross_el(output, y)
      loss_sum += loss.item()
      avg_loss = loss_sum / num_iterations
      data_iterator.set_postfix(loss = avg_loss)
      loss.backward()
      optimizer.step()

      if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
        return


def print_size_of_model(model):
  torch.save(model.state_dict(), "temp_delme.p")
  print('Size (KB) :', os.path.getsize("temp_delme.p")/1e3)
  os.remove('temp_delme.p')

In [17]:
train(train_loader, net_quantized, epochs = 1)

Epoch 1: 100%|██████████| 6000/6000 [00:47<00:00, 127.01it/s, loss=0.22]


In [20]:
def test(model : nn.Module, total_iterations : int = None):
  correct = 0
  total = 0

  iterations = 0

  model.eval()
  with torch.no_grad():
    for data in tqdm(test_loader, desc = 'Testing'):
      x, y = data
      x = x.to(device)
      y = y.to(device)
      output = model(x.view(-1, 784))
      for idx, i in enumerate(output):
        if(torch.argmax(i) == y[idx]):
          correct += 1
        total += 1
      iterations += 1
      if total_iterations is not None and iterations >= total_iterations:
        break

  print(f"Accuracy : {round(correct/total, 3)}")

In [19]:
print(f'Stastics during training')
net_quantized

Stastics during training


QuantizedNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.6093528270721436, max_val=0.33149316906929016)
    (activation_post_process): MinMaxObserver(min_val=-45.35358810424805, max_val=30.954242706298828)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.47003060579299927, max_val=0.3765200674533844)
    (activation_post_process): MinMaxObserver(min_val=-28.6861572265625, max_val=20.01570701599121)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.3520001769065857, max_val=0.2868342399597168)
    (activation_post_process): MinMaxObserver(min_val=-30.697715759277344, max_val=22.732498168945312)
  )
  (relu): ReLU()
  (dequant): DeQuantSt

In [22]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:04<00:00, 212.63it/s]

Accuracy : 0.957



