In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Define AlexNet Model for MNIST
class AlexNetMNIST(nn.Module):
    def __init__(self, q=False):
        super(AlexNetMNIST, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=2),  # Conv1
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),  # Conv2
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),  # Conv3
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),  # Conv4
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),  # Conv5
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),  # FC6
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),  # FC7
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),  # FC8 (output for MNIST)
        )
        self.q = q
        if q:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.q:
            x = self.quant(x)
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.classifier(x)
        if self.q:
            x = self.dequant(x)
        return x

# MNIST Dataset Preparation
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for AlexNet
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize for grayscale images
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)

# Training Function
def train(model, dataloader, epochs=10, cuda=False):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(dataloader):
            inputs, labels = data
            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(dataloader):.4f}, Accuracy: {100 * correct / total:.2f}%')

# Testing Function
def test(model, dataloader, cuda=False):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data
            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy}%')
    return accuracy


# Train the FP32 Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
alexnet_fp32 = AlexNetMNIST(q=False).to(device)
train(alexnet_fp32, trainloader, epochs=10, cuda=(device == 'cuda'))
torch.save(alexnet_fp32.state_dict(), "alexnet_fp32_mnist.pth")
print("FP32 Model Trained and Saved.")

# Test FP32 Model
fp32_accuracy = test(alexnet_fp32, testloader, cuda=(device == 'cuda'))
print(f"FP32 Model Accuracy on MNIST: {fp32_accuracy}%")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 186242737.42it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 14373005.91it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 73548275.02it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17606773.35it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Epoch [1/10], Loss: 0.4617, Accuracy: 84.07%
Epoch [2/10], Loss: 0.0840, Accuracy: 97.52%
Epoch [3/10], Loss: 0.0666, Accuracy: 98.05%
Epoch [4/10], Loss: 0.0612, Accuracy: 98.19%
Epoch [5/10], Loss: 0.0547, Accuracy: 98.39%
Epoch [6/10], Loss: 0.0519, Accuracy: 98.47%
Epoch [7/10], Loss: 0.0499, Accuracy: 98.50%
Epoch [8/10], Loss: 0.0453, Accuracy: 98.64%
Epoch [9/10], Loss: 0.0424, Accuracy: 98.82%
Epoch [10/10], Loss: 0.0425, Accuracy: 98.78%
FP32 Model Trained and Saved.
Accuracy: 99.19%
FP32 Model Accuracy on MNIST: 99.19%


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Per-Layer Quantization
def per_layer_quantize(tensor):
    max_val = tensor.abs().amax()
    scale = 127 / max_val
    quantized_tensor = (tensor * scale).clamp(-127, 127).round().char()
    return quantized_tensor, scale

def per_layer_dequantize(quantized_tensor, scale):
    return quantized_tensor.float() / scale

# Quantized Forward Pass
def quantized_forward_per_layer(model, x, quantize_fn, dequantize_fn):
    with torch.no_grad():
        weights_q = {}
        scales = {}

        # Quantize weights
        for name, param in model.named_parameters():
            weights_q[name], scales[name] = quantize_fn(param.data)

        # Forward pass through features
        for i, layer in enumerate(model.features):
            if isinstance(layer, nn.Conv2d):
                key = f"features.{i}.weight"
                activation_scale = 127 / (x.abs().amax() + 1e-8)
                x = F.conv2d(
                    x / activation_scale,
                    dequantize_fn(weights_q[key], scales[key]),
                    stride=layer.stride,
                    padding=layer.padding
                )
                x = (x * activation_scale).clamp(-127, 127).round().char()
            elif isinstance(layer, nn.MaxPool2d):
                x = x.float()  # Convert back to Float for pooling
                x = layer(x)
                activation_scale = 127 / (x.abs().amax() + 1e-8)  # Recompute scale
                x = (x * activation_scale).clamp(-127, 127).round().char()  # Requantize
            elif isinstance(layer, nn.ReLU):
                x = x.float()  # Convert back to Float for ReLU
                x = layer(x)
                activation_scale = 127 / (x.abs().amax() + 1e-8)  # Recompute scale
                x = (x * activation_scale).clamp(-127, 127).round().char()  # Requantize

        x = x.view(x.size(0), -1)

        # Forward pass through classifier
        for i, layer in enumerate(model.classifier):
            if isinstance(layer, nn.Linear):
                key = f"classifier.{i}.weight"
                activation_scale = 127 / (x.abs().amax() + 1e-8)
                x = F.linear(
                    x / activation_scale,
                    dequantize_fn(weights_q[key], scales[key])
                )
                x = (x * activation_scale).clamp(-127, 127).round().char()
            elif isinstance(layer, nn.ReLU) or isinstance(layer, nn.Dropout):
                x = x.float()  # Convert back to Float for ReLU or Dropout
                x = layer(x)
                activation_scale = 127 / (x.abs().amax() + 1e-8)  # Recompute scale
                x = (x * activation_scale).clamp(-127, 127).round().char()  # Requantize

        return x


# Quantized Model Testing
def test_quantized_per_layer(model, dataloader, device, quantize_fn, dequantize_fn):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = quantized_forward_per_layer(model, inputs, quantize_fn, dequantize_fn)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Quantized Model Accuracy: {accuracy}%")
    return accuracy


# Quantized Testing for CIFAR-10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
alexnet_fp32.load_state_dict(torch.load("alexnet_fp32_mnist.pth"))
print("Testing Per-Layer INT8 Quantized Model...")
int8_accuracy = test_quantized_per_layer(alexnet_fp32, testloader, device, per_layer_quantize, per_layer_dequantize)
print(f"INT8 Quantized Model Accuracy: {int8_accuracy}%")


Testing Per-Layer INT8 Quantized Model...
Quantized Model Accuracy: 98.6%
INT8 Quantized Model Accuracy: 98.6%


In [10]:
# INT16 Per-Layer Quantization
def per_layer_quantize_int16(tensor):
    max_val = tensor.abs().amax()
    scale = 32767 / max_val
    quantized_tensor = (tensor * scale).clamp(-32767, 32767).round().short()
    return quantized_tensor, scale

def per_layer_dequantize_int16(quantized_tensor, scale):
    return quantized_tensor.float() / scale

# INT16 Quantized Forward Pass
def quantized_forward_per_layer_int16(model, x, quantize_fn, dequantize_fn):
    with torch.no_grad():
        weights_q = {}
        scales = {}

        # Quantize weights
        for name, param in model.named_parameters():
            weights_q[name], scales[name] = quantize_fn(param.data)

        # Forward pass through features
        for i, layer in enumerate(model.features):
            if isinstance(layer, nn.Conv2d):
                key = f"features.{i}.weight"
                activation_scale = 32767 / (x.abs().amax() + 1e-8)
                x = F.conv2d(
                    x / activation_scale,
                    dequantize_fn(weights_q[key], scales[key]),
                    stride=layer.stride,
                    padding=layer.padding
                )
                x = (x * activation_scale).clamp(-32767, 32767).round().short()
            elif isinstance(layer, nn.MaxPool2d):
                x = x.float()  # Convert back to Float for pooling
                x = layer(x)
                activation_scale = 32767 / (x.abs().amax() + 1e-8)  # Recompute scale
                x = (x * activation_scale).clamp(-32767, 32767).round().short()  # Requantize
            elif isinstance(layer, nn.ReLU):
                x = x.float()  # Convert back to Float for ReLU
                x = layer(x)
                activation_scale = 32767 / (x.abs().amax() + 1e-8)  # Recompute scale
                x = (x * activation_scale).clamp(-32767, 32767).round().short()  # Requantize

        x = x.view(x.size(0), -1)

        # Forward pass through classifier
        for i, layer in enumerate(model.classifier):
            if isinstance(layer, nn.Linear):
                key = f"classifier.{i}.weight"
                activation_scale = 32767 / (x.abs().amax() + 1e-8)
                x = F.linear(
                    x / activation_scale,
                    dequantize_fn(weights_q[key], scales[key])
                )
                x = (x * activation_scale).clamp(-32767, 32767).round().short()
            elif isinstance(layer, nn.ReLU) or isinstance(layer, nn.Dropout):
                x = x.float()  # Convert back to Float for ReLU or Dropout
                x = layer(x)
                activation_scale = 32767 / (x.abs().amax() + 1e-8)  # Recompute scale
                x = (x * activation_scale).clamp(-32767, 32767).round().short()  # Requantize

        return x

# Quantized Model Testing for INT16
def test_quantized_per_layer_int16(model, dataloader, device, quantize_fn, dequantize_fn):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = quantized_forward_per_layer_int16(model, inputs, quantize_fn, dequantize_fn)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Quantized Model Accuracy: {accuracy}%")
    return accuracy

# Test INT16 Quantized Model
print("Testing INT16 Quantized Model...")
int16_accuracy = test_quantized_per_layer_int16(
    alexnet_fp32, testloader, device, per_layer_quantize_int16, per_layer_dequantize_int16
)
print(f"INT16 Quantized Model Accuracy: {int16_accuracy}%")


Testing INT16 Quantized Model...
Quantized Model Accuracy: 98.64%
INT16 Quantized Model Accuracy: 98.64%


In [9]:
import torch.onnx

# Convert FP32 Model to ONNX
def convert_fp32_to_onnx(model, onnx_filename, input_size=(1, 1, 224, 224)):
    model.eval()
    dummy_input = torch.randn(*input_size).to(next(model.parameters()).device)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_filename,
        export_params=True,
        opset_version=11,
        input_names=['input'],
        output_names=['output']
    )
    print(f"FP32 Model exported to {onnx_filename}")

# Example Usage
convert_fp32_to_onnx(alexnet_fp32, "alexnet_fp32_mnist.onnx")

FP32 Model exported to alexnet_fp32_mnist.onnx
