In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torch.quantization import quantize_dynamic

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data augmentation and normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform_test, download=True)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# Load EfficientNet with pretrained weights
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, 10)
model = model.to(device)

# Train function
def train(model, dataloader, cuda=False):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    model.train()
    for epoch in range(10):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(dataloader):
            if cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if i % 100 == 0:
                print(f'Epoch [{epoch + 1}], Step [{i}], Loss: {running_loss / (i + 1):.4f}, Accuracy: {100 * correct / total:.2f}%')

# Evaluate the model
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Training the model
print("Training EfficientNet...")
train(model, train_loader, cuda=(device.type == 'cuda'))

# Testing the model
print("Testing Full-Precision Model...")
full_precision_accuracy = test_model(model, test_loader)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 20.6MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 116MB/s] 


Training EfficientNet...
Epoch [1], Step [0], Loss: 2.3858, Accuracy: 9.38%
Epoch [1], Step [100], Loss: 2.2593, Accuracy: 17.30%
Epoch [1], Step [200], Loss: 2.1315, Accuracy: 22.99%
Epoch [1], Step [300], Loss: 2.0230, Accuracy: 27.57%
Epoch [1], Step [400], Loss: 1.9319, Accuracy: 30.81%
Epoch [1], Step [500], Loss: 1.8513, Accuracy: 33.83%
Epoch [1], Step [600], Loss: 1.7904, Accuracy: 36.18%
Epoch [1], Step [700], Loss: 1.7337, Accuracy: 38.25%
Epoch [2], Step [0], Loss: 1.1864, Accuracy: 59.38%
Epoch [2], Step [100], Loss: 1.3231, Accuracy: 53.12%
Epoch [2], Step [200], Loss: 1.2922, Accuracy: 54.12%
Epoch [2], Step [300], Loss: 1.2799, Accuracy: 54.34%
Epoch [2], Step [400], Loss: 1.2613, Accuracy: 55.10%
Epoch [2], Step [500], Loss: 1.2403, Accuracy: 55.94%
Epoch [2], Step [600], Loss: 1.2286, Accuracy: 56.45%
Epoch [2], Step [700], Loss: 1.2144, Accuracy: 56.93%
Epoch [3], Step [0], Loss: 1.1619, Accuracy: 62.50%
Epoch [3], Step [100], Loss: 1.1155, Accuracy: 61.23%
Epoch [3],

In [3]:
def quantize_and_test(model, test_loader, quant_type):

    # Move the model to CPU for quantization
    model.cpu()

    if quant_type == 'int8':
        quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)  # INT8 quantization
    elif quant_type == 'int16':
        quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.float16)  # INT16 simulation
    elif quant_type == 'int4':
        print("Simulating INT4 quantization using INT8.")
        quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)  # Simulate INT4
    else:
        raise ValueError(f"Unsupported quantization type: {quant_type}")

    # Test the quantized model
    quantized_model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            # Move data to CPU for testing
            images, labels = images.cpu(), labels.cpu()
            outputs = quantized_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"{quant_type.upper()} Quantized Model Accuracy: {accuracy:.2f}%")
    return accuracy

# Quantized model evaluations
quant_types = ['int4', 'int8', 'int16']
quantized_accuracies = {}
for qt in quant_types:
    quantized_accuracies[qt] = quantize_and_test(model, test_loader, qt)

# Display results
print("\nQuantization Results:")
for qt, acc in quantized_accuracies.items():
    print(f"{qt.upper()} Accuracy: {acc:.2f}%")

Simulating INT4 quantization using INT8.
INT4 Quantized Model Accuracy: 79.39%
INT8 Quantized Model Accuracy: 79.39%
INT16 Quantized Model Accuracy: 79.41%

Quantization Results:
INT4 Accuracy: 79.39%
INT8 Accuracy: 79.39%
INT16 Accuracy: 79.41%
