In [1]:
pip install brevitas


Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from brevitas.nn import QuantConv2d, QuantLinear
import torch.nn.utils.prune as prune


In [3]:
# Define transformations for the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


In [4]:

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)


# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)




In [5]:
# Define CNN model with Brevitas quantization (3-bit)
class QuantizedCNN(nn.Module):
    def __init__(self):
        super(QuantizedCNN, self).__init__()
        # Quantized convolutional layer (3-bit quantization)
        self.conv1 = QuantConv2d(
            in_channels=1, out_channels=16, kernel_size=3, padding=1,
            weight_bit_width=3, bias=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Quantized fully connected layer (3-bit quantization)
        self.fc = QuantLinear(16 * 14 * 14, 10, weight_bit_width=3, bias=True)


    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.view(-1, 16 * 14 * 14)  # Flatten the output for the fully connected layer
        x = self.fc(x)
        return x




In [6]:
# Initialize the model, loss function, and optimizer
model = QuantizedCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [7]:
# Print initial weights (before pruning)
print("Weights before pruning:")
print("Conv1 Weights:", model.conv1.weight.data)
print("FC Weights:", model.fc.weight.data)




Weights before pruning:
Conv1 Weights: tensor([[[[-0.0910, -0.2786,  0.2632],
          [-0.0710,  0.1576,  0.0442],
          [-0.1047, -0.2730,  0.2813]]],


        [[[-0.1965,  0.0082, -0.2192],
          [ 0.2025,  0.2554, -0.0031],
          [ 0.1823,  0.2169,  0.0707]]],


        [[[ 0.2922,  0.2998,  0.0036],
          [ 0.0670, -0.0631, -0.0249],
          [ 0.2618,  0.0044,  0.1424]]],


        [[[-0.3102, -0.0480, -0.3234],
          [-0.0128, -0.2154, -0.0218],
          [-0.2907, -0.1889, -0.3316]]],


        [[[-0.1970, -0.0988, -0.1040],
          [-0.0712, -0.0657,  0.1960],
          [ 0.2197, -0.0778, -0.1412]]],


        [[[-0.2037,  0.0770,  0.2956],
          [ 0.1054, -0.2304, -0.1873],
          [ 0.2428, -0.0365, -0.1745]]],


        [[[ 0.1458,  0.1240, -0.1930],
          [-0.1460, -0.1581, -0.3102],
          [-0.1753, -0.0260, -0.2082]]],


        [[[ 0.1877,  0.1062, -0.0604],
          [-0.1049,  0.1274, -0.2945],
          [-0.0793,  0.2840,  0.0659

In [8]:
# Apply global unstructured pruning (prune 70% of the weights)
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.fc, 'weight')
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.7,  # Prune 70% of the weights globally
)


In [9]:
# Permanently remove the pruned weights
prune.remove(model.conv1, 'weight')
prune.remove(model.fc, 'weight')


QuantLinear(
  in_features=3136, out_features=10, bias=True
  (input_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
  )
  (output_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
  )
  (weight_quant): WeightQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
    (tensor_quant): RescalingIntQuant(
      (int_quant): IntQuant(
        (float_to_int_impl): RoundSte()
        (tensor_clamp_impl): TensorClampSte()
        (delay_wrapper): DelayWrapper(
          (delay_impl): _NoDelay()
        )
      )
      (scaling_impl): StatsFromParameterScaling(
        (parameter_list_stats): _ParameterListStats(
          (first_tracked_param): _ViewParameterWrapper(
            (view_shape_impl): OverTensorView()
          )
          (stats): _Stats(
            (stats_impl): AbsMax()
          )
        )
        (stats_scaling_impl): _StatsScaling(
          (affine_rescaling): Identity()
          (restrict_clamp_scalin

In [10]:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')



  return super().rename(names)


Epoch [1/5], Loss: 0.1436
Epoch [2/5], Loss: 0.0405
Epoch [3/5], Loss: 0.0754
Epoch [4/5], Loss: 0.0150
Epoch [5/5], Loss: 0.0510


In [11]:
# Testing loop
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the pruned 3-bit quantized model on the test images: {100 * correct / total:.2f}%')



Accuracy of the pruned 3-bit quantized model on the test images: 97.16%
