In [1]:
## Importing Libraries
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
import os

In [2]:
## Load MNIST Dataset

# Make torch deterministic
_ = torch.manual_seed(0)

BATCH_SIZE = 10
output_classes = 10  # MNIST has the ten digit classes 0-9

# ToTensor scales pixels to [0, 1]; Normalize uses the widely used MNIST mean/std values.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(mnist_trainset, shuffle=True, batch_size=BATCH_SIZE)

# Accuracy does not depend on sample order, so evaluate in a fixed, reproducible order.
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(mnist_testset, shuffle=False, batch_size=BATCH_SIZE)

device = "cuda" if torch.cuda.is_available() else "cpu"


In [3]:
## Neural Network

class SimpleNeuralNet(nn.Module):
    """Three-layer fully connected classifier for flattened images.

    Architecture:
        Linear(in_features, hidden_size_1) -> ReLU
        -> Linear(hidden_size_1, hidden_size_2) -> ReLU
        -> Linear(hidden_size_2, num_classes)

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
        num_classes: number of output logits (10 for the MNIST digits).
        in_features: number of elements in one flattened input image (28*28 for MNIST).
    """

    def __init__(
        self,
        hidden_size_1: int = 200,
        hidden_size_2: int = 200,
        num_classes: int = 10,
        in_features: int = 28 * 28,
    ) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden_size_1)
        self.linear_2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear_3 = nn.Linear(hidden_size_2, num_classes)
        # A single stateless ReLU module is reused after both hidden layers.
        self.relu = nn.ReLU()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Flatten ``img`` to (B, in_features) and return (B, num_classes) logits.

        For MNIST batches the input is (B, 1, 28, 28).
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        x = img.reshape(-1, self.linear_1.in_features)
        x = self.relu(self.linear_1(x))  # (B, in_features) -> (B, hidden_size_1)
        x = self.relu(self.linear_2(x))  # (B, hidden_size_1) -> (B, hidden_size_2)
        return self.linear_3(x)  # (B, hidden_size_2) -> (B, num_classes)


# Build the float model and place it on the device selected in the data cell.
model = SimpleNeuralNet()
model = model.to(device)



In [None]:
##training

def train(train_loader, model, epochs=5, lr=1e-3, eps=1e-9):
    """Train ``model`` with Adam on the cross-entropy loss and return it.

    Each batch is moved to the device holding the model's parameters, so the loop
    works whether the model lives on the CPU or a GPU.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches; labels are class indices.
        model: ``nn.Module`` producing ``(B, num_classes)`` logits; trained in place.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        eps: Adam epsilon term.

    Returns:
        The same ``model`` object, after training.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=eps)

    for epoch in range(epochs):
        model.train()

        batch_iterator = tqdm(train_loader, desc=f"Processing Epoch: {epoch}")
        losses = []

        for train_batch in batch_iterator:
            X = train_batch[0].to(device)  # (B, 1, 28, 28) for MNIST
            Y = train_batch[1].to(device)  # (B,)
            logits = model(X)  # (B, num_classes)
            loss = F.cross_entropy(logits, Y)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()  # back-propagate the loss and compute the gradients
            optimizer.step()  # update the weights

            loss_value = loss.item()
            losses.append(loss_value)
            batch_iterator.set_postfix({"loss": loss_value})

        # An empty loader would otherwise divide by zero here.
        if losses:
            batch_iterator.write(f"Epoch :{epoch} | Avg. Training Loss: {sum(losses) / len(losses)}")
    return model


model_path = Path("model_path")
# torch.save does not create missing parent directories.
model_path.mkdir(parents=True, exist_ok=True)
model_filename = str(model_path / "simple_net.pt")
train(train_loader, model)
torch.save(model.state_dict(), model_filename)

In [4]:
## test 
# Location of the float checkpoint written by the training cell.
model_path = Path("model_path")
model_filename = os.fspath(model_path.joinpath("simple_net.pt"))
def test(test_loader, model):
    total = 0
    correct_match = 0
    for test_batch in test_loader:
        X,Y = test_batch[0],test_batch[1]

        logits = model(X) ## (B,10)
        value,predicted_labels = torch.max(logits,dim=1)
        for y_predicted,y in zip(predicted_labels,Y):
            if y_predicted == y:
                correct_match += 1
            total +=1

    acc = correct_match / total
    print(f"Accu: {acc}")
    return acc


# Reload the trained weights. map_location lets a checkpoint written on a GPU
# machine load on a CPU-only one. weights_only=True restricts unpickling to tensors
# and plain containers, which is all a state_dict contains.
state_dict = torch.load(model_filename, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

test(test_loader, model)

model_size = os.path.getsize(model_filename) / 1e3
print(f"ModelSize: {model_size} KB")


Accu: 0.9706
ModelSize: 798.959 KB


In [26]:
first_layer_weight = model.linear_1.weight
print(first_layer_weight, first_layer_weight.dtype)

Parameter containing:
tensor([[ 0.0476,  0.0670,  0.0185,  ...,  0.0698,  0.0516,  0.0499],
        [-0.0219, -0.0171, -0.0125,  ..., -0.0224, -0.0081, -0.0320],
        [-0.0385, -0.0035, -0.0516,  ..., -0.0387, -0.0172, -0.0103],
        ...,
        [ 0.0365,  0.0101,  0.0169,  ...,  0.0305,  0.0517,  0.0708],
        [ 0.0675,  0.0146,  0.0127,  ...,  0.0154,  0.0383,  0.0618],
        [ 0.0644,  0.0617,  0.0635,  ...,  0.0338,  0.0371,  0.0438]],
       requires_grad=True) torch.float32


In [64]:
##Configure the quantization settings
#print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = "qnnpack" ## this is important to set. 
quantization_config = torch.quantization.get_default_qconfig('qnnpack')


##Prepare the model for quantization
model.qconfig = quantization_config
quantized_model = torch.quantization.prepare(model)
print(quantized_model)



SimpleNeuralNet(
  (linear_1): Linear(
    in_features=784, out_features=200, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (linear_2): Linear(
    in_features=200, out_features=200, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (linear_3): Linear(
    in_features=200, out_features=10, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
)


In [65]:
##Calibrate the model
test(test_loader,quantized_model)
print(quantized_model)

Accu: 0.9706
SimpleNeuralNet(
  (linear_1): Linear(
    in_features=784, out_features=200, bias=True
    (activation_post_process): HistogramObserver(min_val=-88.74122619628906, max_val=59.636592864990234)
  )
  (linear_2): Linear(
    in_features=200, out_features=200, bias=True
    (activation_post_process): HistogramObserver(min_val=-75.33014678955078, max_val=63.807403564453125)
  )
  (linear_3): Linear(
    in_features=200, out_features=10, bias=True
    (activation_post_process): HistogramObserver(min_val=-153.69056701660156, max_val=44.299156188964844)
  )
  (relu): ReLU()
)


In [68]:
## Convert the model into a quantized form
torch.quantization.convert(quantized_model)

SimpleNeuralNet(
  (linear_1): QuantizedLinear(in_features=784, out_features=200, scale=0.45117953419685364, zero_point=159, qscheme=torch.per_tensor_affine)
  (linear_2): QuantizedLinear(in_features=200, out_features=200, scale=0.4129579961299896, zero_point=150, qscheme=torch.per_tensor_affine)
  (linear_3): QuantizedLinear(in_features=200, out_features=10, scale=0.7180463671684265, zero_point=198, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
)

In [None]:
print(f"Normal Model Weight: {model.linear_1.weight} | Dtype: {model.linear_1.weight.dtype}")
print(f"Quantized Model Weight: {quantized_model.linear_1.weight} | Dtype: {quantized_model.linear_1.weight.dtype}")

In [81]:
quantized_model_filename = str(model_path / "quantized_net.pt")
torch.save(quantized_model.state_dict(),quantized_model_filename)

quantized_model_size = os.path.getsize(quantized_model_filename) / 1e3
print(f"QuantizedModelSize: {quantized_model_size} KB")


test(test_loader,quantized_model)

QuantizedModelSize: 827.175 KB
Accu: 0.9706


0.9706