In [2]:
import copy
import os
import tempfile
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

In [3]:
# Seed PyTorch's global random number generator so the run is repeatable.
# The return value is bound to `_` only to keep it out of the cell output.
_ = torch.manual_seed(0)

In [4]:
# Tunable settings for the data pipeline, named so they are not buried in calls.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean / std — confirm.
NORMALIZE_MEAN = (0.1307,)
NORMALIZE_STD = (0.3081,)
BATCH_SIZE = 10
DATA_ROOT = './data'

# Convert each image to a tensor, then normalise its single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Load the MNIST training set (downloaded into DATA_ROOT when missing)
mnist_trainset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
# Create a dataloader for the training; batches are reshuffled on every pass
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=BATCH_SIZE, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)
# Evaluation gains nothing from shuffling; a fixed order keeps partial
# evaluation runs (a limited number of batches) comparable between calls.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=False)

# Define the device: use the GPU when one is available, otherwise the CPU
device = "cuda" if torch.cuda.is_available() else 'cpu'

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:17<00:00, 576689.52it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 168035.61it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:03<00:00, 489884.84it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [5]:
class VerySimpleNet(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    Inputs are flattened to 28*28 features and mapped to 10 output scores.

    Args:
        hidden_size_1: Number of units in the first hidden layer.
        hidden_size_2: Number of units in the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        flat_features = 28 * 28
        # Layers are registered in forward order: input -> hidden 1 -> hidden 2 -> scores.
        self.linear1 = nn.Linear(flat_features, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        # One activation module is reused after both hidden layers.
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return the raw output scores."""
        flat = img.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)

In [6]:
# Build the float model with its default hidden sizes and move it to the selected device.
net = VerySimpleNet().to(device)


In [7]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` in place with Adam (lr=0.001) and cross-entropy loss.

    Batches are moved to the module-level ``device`` and flattened to 784
    features before the forward pass. A progress bar per epoch shows the
    running mean loss.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Module to optimise.
        epochs: Maximum number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the number of optimisation
            steps summed over all epochs; the function returns as soon as the
            cap is reached.

    Returns:
        None. ``net`` is updated in place.
    """
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        # The context manager closes the bar even when we return mid-epoch.
        with tqdm(train_loader, desc=f'Epoch {epoch+1}') as data_iterator:
            if total_iterations_limit is not None:
                # Show how many steps this epoch will really run: the steps
                # still allowed by the limit, capped by the loader length
                # when the bar knows it.
                remaining = total_iterations_limit - total_iterations
                if data_iterator.total is None:
                    data_iterator.total = remaining
                else:
                    data_iterator.total = min(remaining, data_iterator.total)

            for data in data_iterator:
                num_iterations += 1
                total_iterations += 1
                x, y = data
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()
                output = net(x.view(-1, 28*28))
                loss = cross_el(output, y)
                # Running mean of the loss over the current epoch, for display only.
                loss_sum += loss.item()
                avg_loss = loss_sum / num_iterations
                data_iterator.set_postfix(loss=avg_loss)
                loss.backward()
                optimizer.step()

                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return
            
def print_size_of_model(model):
    """Print the on-disk size, in kilobytes, of ``model``'s serialized state dict.

    The state dict is saved into a private temporary directory, so no file in
    the working directory is overwritten or deleted, and the scratch file is
    removed even when saving or measuring raises.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same file name as before; only the containing directory changes.
        scratch_path = os.path.join(scratch_dir, "temp_delme.p")
        torch.save(model.state_dict(), scratch_path)
        print('Size (KB):', os.path.getsize(scratch_path)/1e3)

# Checkpoint used to skip training when the notebook is re-run.
MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location=device))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Epoch 1: 100%|██████████| 6000/6000 [00:44<00:00, 135.34it/s, loss=0.223]


In [9]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

In [10]:
# Baseline: serialized size of the float model's state dict.
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 361.062


In [11]:
# Baseline: accuracy of the float model on the test set.
print('Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:32<00:00, 31.15it/s]

Accuracy: 0.964





In [22]:
class QuantizedSimpleNeuralnet(nn.Module):
    """Same three-layer network, wrapped in quantization entry and exit stubs.

    The linear layers carry the same attribute names as the float model
    (``linear1``..``linear3``), so its state dict can be loaded directly.
    ``QuantStub`` / ``DeQuantStub`` mark where tensors enter and leave the
    region to be quantized.

    Args:
        hidden_size_1: Number of units in the first hidden layer.
        hidden_size_2: Number of units in the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # Stubs come from torch.ao.quantization, the namespace the rest of
        # this notebook uses for qconfig / prepare / convert.
        self.quantization = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequantization = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return the output scores."""
        x = img.view(-1, 28*28)
        x = self.quantization(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequantization(x)
        return x

        

In [23]:
# Build the stub-wrapped network and copy the trained float weights into it.
# NOTE(review): the model is placed on `device`; PyTorch documents eager-mode
# quantized operators for CPU backends, so confirm the later convert/inference
# steps behave as intended when `device` is "cuda".
quantized_neural_net = QuantizedSimpleNeuralnet().to(device)
quantized_neural_net.load_state_dict(net.state_dict())
# eval() returns the module itself, so this cell also displays its structure.
quantized_neural_net.eval()

QuantizedSimpleNeuralnet(
  (quantization): QuantStub()
  (linear1): Linear(in_features=784, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=100, bias=True)
  (linear3): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (dequantization): DeQuantStub()
)

In [24]:
# Attach the default quantization config and let prepare() insert observers:
# the printed model shows a MinMaxObserver on the quant stub and on each linear layer.
# NOTE(review): this cell rebinds `quantized_neural_net` to the prepared model,
# so re-running it would prepare an already-prepared model.
quantized_neural_net.qconfig = torch.ao.quantization.default_qconfig
quantized_neural_net = torch.ao.quantization.prepare(quantized_neural_net)
quantized_neural_net

QuantizedSimpleNeuralnet(
  (quantization): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequantization): DeQuantStub()
)

In [25]:
# Calibration pass: running the test set through the prepared model lets the
# observers record the min/max of the activations they see.
test(quantized_neural_net)

Testing: 100%|██████████| 1000/1000 [00:34<00:00, 29.08it/s]

Accuracy: 0.964





In [26]:
# The observers should now hold finite min/max values instead of inf/-inf.
print('Verify the statistics collected when running inference')
quantized_neural_net

Verify the statistics collected when running inference


QuantizedSimpleNeuralnet(
  (quantization): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-54.36076736450195, max_val=35.85297393798828)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-25.92667007446289, max_val=27.277904510498047)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-28.495838165283203, max_val=21.301504135131836)
  )
  (relu): ReLU()
  (dequantization): DeQuantStub()
)

In [27]:
# Replace the observed float modules with their quantized counterparts.
# Conversion runs on a CPU copy: PyTorch's eager-mode quantized operators
# target CPU backends, and working on a copy leaves the calibrated model
# (and the device it lives on) untouched.
real_quan_nn = torch.ao.quantization.convert(
    copy.deepcopy(quantized_neural_net).to('cpu'),
    inplace=True,
)

In [28]:
# Compare the float weights of the first layer with the values obtained by
# dequantizing the quantized layer's weights; the latter are rounded
# approximations of the former.
print('Original weights of the first linear layer')
print(net.linear1.weight)
print('Dequantized weights of the first linear layer')
# On the converted layer `weight` is a method, hence the call.
print(torch.dequantize(real_quan_nn.linear1.weight()))

Original weights of the first linear layer
Parameter containing:
tensor([[ 7.4281e-05,  1.9500e-02, -2.9053e-02,  ...,  2.2277e-02,
          4.0743e-03,  2.4020e-03],
        [-1.5416e-02, -1.0620e-02, -6.0812e-03,  ..., -1.5903e-02,
         -1.6012e-03, -2.5585e-02],
        [ 2.0023e-02,  5.5070e-02,  6.8960e-03,  ...,  1.9828e-02,
          4.1354e-02,  4.8199e-02],
        ...,
        [ 4.9189e-02,  5.2900e-02,  1.8281e-02,  ...,  1.2918e-02,
          3.2172e-02, -4.7868e-03],
        [-9.3255e-03, -1.2189e-03,  3.0838e-02,  ...,  1.1121e-02,
          1.1175e-02,  1.0678e-02],
        [ 1.7474e-02,  1.2149e-02, -2.0101e-03,  ...,  3.4410e-02,
         -1.4872e-02,  5.2539e-03]], device='cuda:0', requires_grad=True)
Dequantized weights of the first linear layer
tensor([[ 0.0000,  0.0185, -0.0278,  ...,  0.0231,  0.0046,  0.0046],
        [-0.0139, -0.0093, -0.0046,  ..., -0.0139,  0.0000, -0.0278],
        [ 0.0185,  0.0555,  0.0046,  ...,  0.0185,  0.0416,  0.0463],
        ..

In [36]:
# Inspect the storage type of the converted layer's weights.
print('data type of quantized weights')
real_quan_nn.linear1.weight().dtype

data type of quantized weights


torch.qint8

In [34]:
print('the size of the model after quantization')
print_size_of_model(real_quan_nn)
# The serialized size shrinks by roughly a factor of 4, because the weights are
# converted from 32-bit single-precision floats to 8-bit integers.

the size of the model after quantization
Size (KB): 95.458
