<a href="https://colab.research.google.com/github/mehwishferoz/EIS/blob/master/Custom_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from torch.utils.data import DataLoader

In [None]:
class CustomCNN(nn.Module):
  """Small CNN classifier: three conv/ReLU/max-pool stages and three linear layers.

  Each convolution uses a 3x3 kernel with padding 1, so it keeps the spatial
  size; each 2x2/stride-2 max pool halves it (rounding down). The feature map
  entering the first linear layer is therefore 128 channels of
  (input_size // 8) x (input_size // 8).

  Args:
    num_classes: Number of output logits produced by the last linear layer.
    input_size: Height and width (in pixels) of the square 3-channel input.
      The default of 224 gives the 128 * 28 * 28 flattened feature size.

  Raises:
    ValueError: If ``input_size`` is smaller than 8, which would leave an
      empty feature map after the three pooling stages.
  """

  def __init__(self, num_classes: int = 40, input_size: int = 224):
    super(CustomCNN, self).__init__()
    if input_size < 8:
      raise ValueError("input_size must be at least 8, got %d" % input_size)

    # 1st Conv layer
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
    self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

    # 2nd Conv layer
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
    self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    # 3rd Conv layer
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
    self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Three stride-2 pools, each flooring, reduce the side length to
    # input_size // 8 (nested floor division by 2 equals floor division by 8).
    side = input_size // 8
    flat_features = 128 * side * side

    # FC layers
    self.fc1 = nn.Linear(flat_features, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, num_classes)

  # NOTE: forward must be defined at method level (a sibling of __init__);
  # nested inside __init__ it would never override nn.Module.forward.
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return class logits of shape (batch, num_classes) for an image batch.

    Args:
      x: Tensor of shape (batch, 3, input_size, input_size).
    """
    x = self.pool1(F.relu(self.conv1(x)))
    x = self.pool2(F.relu(self.conv2(x)))
    x = self.pool3(F.relu(self.conv3(x)))
    # Flatten everything except the batch dimension. Unlike view(-1, N), a
    # wrongly sized input now fails in fc1 instead of silently changing the
    # effective batch size.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

In [None]:
# Instantiate the model
model = CustomCNN()

# Eager-mode post-training static quantization expects the model to be in
# inference mode before observers are inserted, so switch to eval first.
model.eval()

# Quantization configuration ('fbgemm' is the x86 backend)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Fuse Conv, BN and ReLU layers (if any)
# No fusion step: CustomCNN defines no BatchNorm modules and applies ReLU
# through the functional API, so there are no module sequences to fuse.

# Prepare the model for quantization (inserts observers on a copy of the model)
model_prepared = torch.quantization.prepare(model)

# NOTE(review): no calibration data is passed through model_prepared before
# conversion, so the observers have recorded no activation statistics.
# Run representative batches through model_prepared here if meaningful
# quantization parameters are needed — TODO confirm intent.
# NOTE(review): CustomCNN has no QuantStub/DeQuantStub, so the converted
# model presumably cannot consume float inputs directly — verify before use.

# Convert the model to a quantized version
model_quantized = torch.quantization.convert(model_prepared)

# Save the quantized model's parameters to a file
torch.save(model_quantized.state_dict(), "model_32.pt")

# Also save the entire model with structure (pickles the module itself, so
# loading it requires the CustomCNN class definition to be importable)
torch.save(model_quantized, "model_32_complete.pt")