In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, fuse_modules, prepare_qat, convert

class SimpleQATModel(nn.Module):
    """Small single-conv CNN classifier prepared for eager-mode quantization.

    Layout: QuantStub -> Conv2d(1 -> 8, 3x3, stride 1, pad 1) -> ReLU -> flatten
    -> Linear(8*28*28 -> 10) -> DeQuantStub.

    The Linear input size hard-codes single-channel 28x28 inputs: a 3x3 conv
    with stride 1 and padding 1 keeps the spatial size at 28x28.

    The attribute names ``conv1`` and ``relu`` are referred to by name when the
    pair is fused for QAT, so keep them stable.
    """

    def __init__(self):
        super().__init__()
        # Boundary where float inputs enter the quantized region after conversion.
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # 8 channels * 28 * 28 spatial positions -> 10 class logits.
        self.fc = nn.Linear(8 * 28 * 28, 10)
        # Boundary where outputs leave the quantized region after conversion.
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Map a (batch, 1, 28, 28) input to (batch, 10) logits."""
        x = self.quant(x)  # Quantize input
        x = self.conv1(x)
        x = self.relu(x)
        # torch.flatten (reshape semantics) rather than .view: after convert(),
        # quantized conv kernels produce channels_last tensors, and
        # .view(N, -1) raises on those non-NCHW-contiguous strides.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.dequant(x)  # Dequantize output
        return x


In [2]:
import torchvision
import torchvision.transforms as transforms

# MNIST training split. ToTensor() produces float images, then
# Normalize(mean=0.5, std=0.5) shifts and rescales each pixel: (p - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 images for the training loops.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)


In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

# Seed the global torch RNG before building the model so weight init (and, with
# the default sampler, the shuffle order) is repeatable across Restart & Run All.
torch.manual_seed(0)

# Instantiate the full-precision model
model = SimpleQATModel()

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
# Profile a short window of steps instead of the whole epoch. Without a schedule,
# every op of every step is recorded (with shapes) into one very large trace.
# tensorboard_trace_handler writes the trace itself when the active window ends,
# so no manual export call is needed.
with profile(activities=[ProfilerActivity.CPU],
             schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
             on_trace_ready=torch.profiler.tensorboard_trace_handler('./log', worker_name='fp32'),
             record_shapes=True) as prof:
    for epoch in range(1):  # Train the full-precision model
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            prof.step()  # advance the profiler schedule once per training step

print("Full-precision training complete.")
torch.save(model.state_dict(), "full_precision_model.pth")


[W117 02:38:07.720226687 kineto_shim.cpp:415] Adding profiling metadata requires using torch.profiler with Kineto support (USE_KINETO=1)


In [6]:
# Fuse conv1 + relu into one module so QAT simulates the fused int8 kernel.
fuse_modules(model, [['conv1', 'relu']], inplace=True)

# prepare_qat only accepts a model in training mode; set it explicitly rather
# than relying on whatever mode an earlier cell left the model in.
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
prepare_qat(model, inplace=True)

# NOTE(review): `optimizer` was built on the full-precision model before fusion
# and prepare_qat. This assumes the QAT modules keep the original Parameter
# objects. Confirm, or rebuild the optimizer here (e.g. with a lower LR).
with profile(activities=[ProfilerActivity.CPU],
             schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
             on_trace_ready=torch.profiler.tensorboard_trace_handler('./log', worker_name='qat'),
             record_shapes=True) as prof:
    for epoch in range(1):  # Fine-tune with QAT
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            prof.step()  # advance the profiler schedule; the handler writes the trace

print("QAT fine-tuning complete.")




QAT fine-tuning complete.


In [7]:
# Convert to a quantized model
model.eval()
model = convert(model, inplace=True)
print("Model quantized.")


Model quantized.


In [8]:
import os

# Save the quantized weights next to the full-precision checkpoint written during training.
torch.save(model.state_dict(), "quantized_model.pth")
print("Quantized model saved as 'quantized_model.pth'.")

# Compare on-disk checkpoint sizes (bytes / 1024, labelled "KB").
full_precision_size, quantized_size = (
    os.path.getsize(checkpoint)
    for checkpoint in ("full_precision_model.pth", "quantized_model.pth")
)

for label, size_bytes in (("Full-Precision", full_precision_size), ("Quantized", quantized_size)):
    print(f"{label} Model Size: {size_bytes / 1024:.2f} KB")



Quantized model saved as 'quantized_model.pth'.
Full-Precision Model Size: 247.43 KB
Quantized Model Size: 66.50 KB


In [6]:
!tensorboard

Traceback (most recent call last):
  File "/home/bandham/miniconda3/envs/dl_venv/bin/tensorboard", line 6, in <module>
    from tensorboard.main import run_main
  File "/home/bandham/miniconda3/envs/dl_venv/lib/python3.13/site-packages/tensorboard/main.py", line 27, in <module>
    from tensorboard import default
  File "/home/bandham/miniconda3/envs/dl_venv/lib/python3.13/site-packages/tensorboard/default.py", line 40, in <module>
    from tensorboard.plugins.image import images_plugin
  File "/home/bandham/miniconda3/envs/dl_venv/lib/python3.13/site-packages/tensorboard/plugins/image/images_plugin.py", line 18, in <module>
    import imghdr
ModuleNotFoundError: No module named 'imghdr'
