# Fine-tuning and Quantizing a Vision Transformer for Edge Devices

This notebook demonstrates how to optimize a pre-trained Vision Transformer (ViT) model for deployment on edge devices using various quantization and optimization techniques.

### Table of Contents

1. Setup and Imports
2. Load Pre-trained ViT Model
3. Prepare Dataset
4. Mixed Precision Fine-tuning
5. Quantization-Aware Fine-tuning (QAF)
6. Pruning for Model Size Reduction
7. Layer Fusion
8. Efficient Quantized Attention Mechanisms
9. Optimization for Target Hardware
10. Performance Benchmarking


## 1. Setup and Imports

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import numpy as np
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler

## 2. Load Pre-trained ViT Model

Now, let's load a pre-trained ViT model:

In [None]:
model_name = 'vit_base_patch16_224'
# The model is fine-tuned on CIFAR-10 below, which has 10 classes. Passing
# num_classes asks timm to build a 10-way classifier head instead of keeping
# the 1000-way ImageNet head, while still loading the pre-trained backbone.
num_classes = 10
model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
print(f"Loaded {model_name}")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Get the configuration for data preprocessing (input size, interpolation,
# normalisation) that matches the loaded model, and build the transform from it
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

## 3. Prepare Dataset

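For this example, we'll use a subset of the CIFAR-10 dataset (the first 10,000 training images and the first 1,000 test images), preprocessed with the transform built for the model above: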

In [None]:
# CIFAR-10 train/test splits, downloaded to ./data on first use. Both splits
# go through the same `transform` that was built from the model's data config.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Use a subset of the data for faster training
# (the first 10,000 training samples and the first 1,000 test samples)
train_subset = torch.utils.data.Subset(train_dataset, range(10000))
val_subset = torch.utils.data.Subset(val_dataset, range(1000))

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)

## 4. Mixed Precision Fine-tuning

Let's implement mixed precision training for the initial fine-tuning:

In [None]:
def train_epoch(model, loader, optimizer, scaler, criterion):
    """Run one mixed-precision training pass over ``loader``.

    The forward pass and loss run inside ``autocast``; the backward pass and
    the optimizer step go through ``scaler``. Batches are moved to the
    module-level ``device`` before use.

    Returns:
        A ``(mean_loss, accuracy)`` tuple: the per-batch loss averaged over
        the number of batches, and the percentage of correctly classified
        samples.
    """
    model.train()
    loss_sum = 0
    hits = 0
    seen = 0

    for batch_inputs, batch_targets in loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        optimizer.zero_grad()

        with autocast():
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_targets)

        # Scale the loss before backward; the scaler then drives the
        # optimizer step and refreshes its own scale factor.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_sum += batch_loss.item()
        predicted = logits.max(1).indices
        seen += batch_targets.size(0)
        hits += (predicted == batch_targets).sum().item()

    mean_loss = loss_sum / len(loader)
    accuracy = 100. * hits / seen
    return mean_loss, accuracy

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The model is switched to eval mode and stays in eval mode afterwards.
    Batches are moved to the module-level ``device`` before use.

    Returns:
        A ``(mean_loss, accuracy)`` tuple: the per-batch loss averaged over
        the number of batches, and the percentage of correctly classified
        samples.
    """
    model.eval()
    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            logits = model(batch_inputs)

            loss_sum += criterion(logits, batch_targets).item()
            predicted = logits.max(1).indices
            seen += batch_targets.size(0)
            hits += (predicted == batch_targets).sum().item()

    mean_loss = loss_sum / len(loader)
    accuracy = 100. * hits / seen
    return mean_loss, accuracy

# Fine-tuning setup
# `criterion` is also read as a module-level name by apply_qaf() further down.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
# Gradient scaler handed to train_epoch(), which uses it around backward/step.
scaler = GradScaler()

# Fine-tuning loop
# Each epoch trains on train_loader, then evaluates on val_loader.
# validate() runs last and puts the model in eval mode, so the model is in
# eval mode when this cell finishes.
epochs = 5
for epoch in range(epochs):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scaler, criterion)
    val_loss, val_acc = validate(model, val_loader, criterion)
    print(f"Epoch {epoch+1}/{epochs}:")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

print("Mixed precision fine-tuning completed")

## 5. Quantization-Aware Fine-tuning (QAF)

Now, let's implement Quantization-Aware Fine-tuning:

In [None]:
import torch.quantization

class QuantizedViT(nn.Module):
    """Wrap ``model`` between a ``QuantStub`` and a ``DeQuantStub``.

    The stubs mark where tensors enter and leave the wrapped model, so the
    input passes through ``quant``, then ``model``, then ``dequant``.
    """

    def __init__(self, model):
        """Store ``model`` and create the entry and exit stubs."""
        super().__init__()
        # Sub-modules are registered in the order quant, model, dequant.
        self.quant = torch.quantization.QuantStub()
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return ``dequant(model(quant(x)))``."""
        return self.dequant(self.model(self.quant(x)))

def _wrap_linear_layers(module, qconfig, skipped):
    """Attach ``qconfig`` to every ``nn.Linear`` below ``module`` and wrap it.

    Each selected layer is replaced by a ``QuantWrapper`` (QuantStub -> layer
    -> DeQuantStub), so after conversion it takes and returns float tensors
    and the surrounding float operations keep working. Sub-modules listed in
    ``skipped``, and everything inside them, are left untouched.
    """
    for name, child in list(module.named_children()):
        if any(child is kept for kept in skipped):
            continue
        if isinstance(child, nn.Linear):
            child.qconfig = qconfig
            setattr(module, name, torch.quantization.QuantWrapper(child))
        else:
            _wrap_linear_layers(child, qconfig, skipped)


def apply_qaf(model, train_loader, val_loader, epochs=3):
    """Quantization-aware fine-tune a copy of ``model`` and convert it to int8.

    Only ``nn.Linear`` layers are quantized; ``patch_embed`` and ``head``
    (when the model has them) and all other layer types stay in floating
    point. The caller's ``model`` is not modified.

    Args:
        model: Float model to quantize.
        train_loader: Loader passed to ``train_epoch`` each epoch.
        val_loader: Loader passed to ``validate`` each epoch.
        epochs: Number of quantization-aware fine-tuning epochs.

    Returns:
        The converted quantized model, in eval mode and on the CPU (the
        'fbgemm' quantized kernels are a CPU backend).
    """
    import copy

    # Work on a copy so the float model stays available for comparison.
    model_qat = copy.deepcopy(model)
    # prepare_qat only accepts models in training mode, and validate() leaves
    # the model in eval mode.
    model_qat.train()

    qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # Keep first and last layers in floating point by never giving them a
    # qconfig, instead of re-attaching the original modules after preparation
    # (which would share parameters with the caller's model).
    float_layers = [
        layer
        for layer in (getattr(model_qat, 'patch_embed', None),
                      getattr(model_qat, 'head', None))
        if layer is not None
    ]
    _wrap_linear_layers(model_qat, qat_qconfig, float_layers)
    torch.quantization.prepare_qat(model_qat, inplace=True)

    optimizer = optim.AdamW(model_qat.parameters(), lr=1e-5, weight_decay=1e-2)
    scaler = GradScaler()

    # Observers must run from the first step: with observers off, fake
    # quantization would use its uncalibrated initial scale/zero-point.
    model_qat.apply(torch.quantization.enable_observer)
    model_qat.apply(torch.quantization.enable_fake_quant)

    # NOTE(review): train_epoch runs under autocast, so observers may see
    # fp16 activations on GPU - confirm the installed PyTorch supports that,
    # otherwise run this phase in fp32.
    for epoch in range(epochs):
        # Freeze the quantization ranges for the final epoch so the weights
        # settle against fixed scale/zero-point values.
        if epochs > 1 and epoch == epochs - 1:
            model_qat.apply(torch.quantization.disable_observer)

        train_loss, train_acc = train_epoch(model_qat, train_loader, optimizer, scaler, criterion)
        val_loss, val_acc = validate(model_qat, val_loader, criterion)
        print(f"QAT Epoch {epoch+1}/{epochs}:")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Convert to quantized model. Conversion packs weights for the CPU
    # backend, so the model is moved to the CPU first.
    model_qat.eval()
    model_qat.cpu()
    model_quantized = torch.quantization.convert(model_qat, inplace=False)
    return model_quantized

# Run quantization-aware fine-tuning with apply_qaf's default epoch count and
# keep the model it returns.
model_quantized = apply_qaf(model, train_loader, val_loader)
print("Quantization-Aware Fine-tuning completed")

## 6. Pruning for Model Size Reduction

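Let's implement attention head pruning by zeroing the query, key and value weights of the lowest-norm heads in each attention layer. Because the heads are masked rather than removed, tensor shapes stay the same, so the stored model only gets smaller if the zeroed weights are later removed or compressed: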

In [None]:
def _qkv_row_mask(weight, num_heads, num_prune):
    """Return a 1-D 0/1 mask over the rows of a fused qkv weight.

    Rows are treated as laid out in (q|k|v, head, head_dim) order. A head's
    importance is the mean of the norms of its q, k and v slices; the
    ``num_prune`` least important heads get 0, all others 1.
    """
    out_features = weight.shape[0]
    if out_features % (3 * num_heads) != 0:
        raise ValueError(
            f"qkv weight with {out_features} rows cannot be split into "
            f"3 x {num_heads} heads"
        )
    head_dim = out_features // (3 * num_heads)

    head_importance = torch.norm(weight.reshape(3, num_heads, -1), dim=2).mean(dim=0)
    _, indices = torch.topk(head_importance, k=num_prune, largest=False)

    # Build the mask on the weight's device/dtype so it can be applied directly.
    head_mask = torch.ones(num_heads, dtype=weight.dtype, device=weight.device)
    head_mask[indices] = 0
    # Expand one entry per head to one entry per output row, for q, k and v.
    return head_mask.repeat_interleave(head_dim).repeat(3)


def _zero_weakest_heads(qkv, num_heads, num_prune):
    """Zero the rows of the weakest heads in one fused qkv layer, in place."""
    if isinstance(qkv, torch.quantization.QuantWrapper):
        qkv = qkv.module

    weight = qkv.weight
    if not callable(weight):
        # Float layer: weight and bias are ordinary tensors.
        row_mask = _qkv_row_mask(weight, num_heads, num_prune)
        weight.mul_(row_mask.unsqueeze(1))
        if qkv.bias is not None:
            qkv.bias.mul_(row_mask)
        return

    # Converted quantized layer: weight()/bias() are accessors, and the packed
    # weight must be rebuilt with its original quantization parameters.
    qweight = weight()
    float_weight = qweight.dequantize()
    row_mask = _qkv_row_mask(float_weight, num_heads, num_prune)
    pruned = float_weight * row_mask.unsqueeze(1)
    if qweight.qscheme() in (torch.per_tensor_affine, torch.per_tensor_symmetric):
        requantized = torch.quantize_per_tensor(
            pruned, qweight.q_scale(), qweight.q_zero_point(), qweight.dtype
        )
    else:
        requantized = torch.quantize_per_channel(
            pruned,
            qweight.q_per_channel_scales(),
            qweight.q_per_channel_zero_points(),
            qweight.q_per_channel_axis(),
            qweight.dtype,
        )
    bias = qkv.bias()
    if bias is not None:
        bias = bias * row_mask
    qkv.set_weight_bias(requantized, bias)


def prune_attention_heads(model, prune_ratio=0.3):
    """Zero out the least important attention heads of every ViT attention layer.

    For each ``timm`` ViT ``Attention`` module, ``int(num_heads * prune_ratio)``
    heads with the smallest qkv weight norm have their query, key and value
    weights (and biases) set to zero. The model is modified in place and
    tensor shapes are unchanged.

    Args:
        model: Model whose attention layers are pruned.
        prune_ratio: Fraction of heads to zero per layer, between 0 and 1.

    Raises:
        ValueError: If ``prune_ratio`` is outside ``[0, 1]`` or a qkv weight
            cannot be split into ``3 * num_heads`` equal groups of rows.
    """
    if not 0 <= prune_ratio <= 1:
        raise ValueError(f"prune_ratio must be within [0, 1], got {prune_ratio}")

    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, timm.models.vision_transformer.Attention):
                num_heads = module.num_heads
                num_prune = int(num_heads * prune_ratio)
                if num_prune:
                    _zero_weakest_heads(module.qkv, num_heads, num_prune)

# Prune the quantized model in place, using the function's default prune_ratio.
prune_attention_heads(model_quantized)
print("Attention head pruning completed")

## 7. Layer Fusion

For layer fusion, we'll focus on fusing batch normalization layers:

In [None]:
def fuse_bn_recursively(model):
    """Fold ``BatchNorm2d`` layers into the ``Conv2d`` that precedes them.

    The module tree is walked recursively. Inside every ``nn.Sequential``,
    a ``Conv2d`` directly followed by a ``BatchNorm2d`` is replaced by a
    single fused convolution, and the batch-norm slot becomes ``nn.Identity``
    so indices and names of the remaining layers do not change. The model is
    modified in place and its outputs are preserved.

    A batch-norm layer is left untouched when it cannot be folded safely:
    its parent is not an ``nn.Sequential`` (registration order is then not
    guaranteed to be execution order), it is not directly preceded by a
    matching ``Conv2d``, it has no running statistics, or either layer is in
    training mode (folding uses the running statistics).
    """
    from torch.nn.utils.fusion import fuse_conv_bn_eval

    ordered_container = isinstance(model, torch.nn.Sequential)
    prev_name, prev_module = None, None

    for module_name, module in list(model.named_children()):
        if list(module.named_children()):
            fuse_bn_recursively(module)

        can_fold = (
            ordered_container
            and isinstance(module, torch.nn.BatchNorm2d)
            and isinstance(prev_module, torch.nn.Conv2d)
            and prev_module.out_channels == module.num_features
            and module.running_mean is not None
            and module.running_var is not None
            and not (module.training or prev_module.training)
        )
        if can_fold:
            # Replace the pair by one conv carrying the batch-norm statistics.
            setattr(model, prev_name, fuse_conv_bn_eval(prev_module, module))
            setattr(model, module_name, torch.nn.Identity())
            prev_name, prev_module = None, None
        else:
            prev_name, prev_module = module_name, module

# Apply the batch-norm pass to the quantized model (modifies it in place).
fuse_bn_recursively(model_quantized)
print("Layer fusion completed")

## 8. Efficient Quantized Attention Mechanisms

We've already implemented quantized attention mechanisms in our QAF process. Here's a function to verify the quantization:

In [None]:
def _has_quantized_weight(layer):
    """Return True if ``layer`` holds a quantized weight tensor.

    Handles plain float layers (``weight`` is a tensor), converted quantized
    layers (``weight`` is an accessor method) and layers wrapped in a
    ``QuantWrapper``.
    """
    if isinstance(layer, torch.quantization.QuantWrapper):
        layer = layer.module
    weight = getattr(layer, "weight", None)
    if callable(weight):
        weight = weight()
    return bool(getattr(weight, "is_quantized", False))


def verify_quantized_attention(model):
    """Print, for every ViT attention layer, whether its weights are quantized.

    For each ``timm`` ViT ``Attention`` module in ``model`` this prints the
    module itself followed by one line each for the ``qkv`` and ``proj``
    layers, stating ``True`` or ``False``.
    """
    for module in model.modules():
        if isinstance(module, timm.models.vision_transformer.Attention):
            print(f"Attention module: {module}")
            print(f"QKV weight quantized: {_has_quantized_weight(module.qkv)}")
            print(f"Projection weight quantized: {_has_quantized_weight(module.proj)}")

# Report the quantization status of each attention layer in the quantized model.
verify_quantized_attention(model_quantized)

## 9. Optimization for Target Hardware

For this step, we'll trace the model with TorchScript and run PyTorch's mobile optimizer on the traced module before saving it for deployment:

In [None]:
def optimize_for_mobile(model, input_shape=(1, 3, 224, 224)):
    """Trace ``model`` with TorchScript and run the mobile optimizer on it.

    Args:
        model: Model to export; it is switched to eval mode.
        input_shape: Shape of the random example input used for tracing.

    Returns:
        The optimized TorchScript module.
    """
    from itertools import chain

    # Aliased because this function has the same name as the library function.
    from torch.utils.mobile_optimizer import optimize_for_mobile as _mobile_optimize

    model.eval()

    # Trace on the device the model lives on; a model without parameters or
    # buffers is treated as a CPU model.
    first_tensor = next(chain(model.parameters(), model.buffers()), None)
    target = first_tensor.device if first_tensor is not None else torch.device("cpu")
    example_input = torch.rand(input_shape).to(target)

    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_input)
    optimized_model = _mobile_optimize(traced_model)
    return optimized_model

# Export the quantized model and write the result to optimized_vit_edge.pt
# in the current working directory.
optimized_model = optimize_for_mobile(model_quantized)
optimized_model.save("optimized_vit_edge.pt")
print("Model optimized and saved for mobile deployment")

## 10. Performance Benchmarking

Finally, let's benchmark the original, quantized, and mobile-optimized models and compare their average inference time per forward pass:

In [None]:
def benchmark(model, input_shape, num_runs=100, warmup_runs=10):
    """Measure the average forward-pass time of ``model`` in milliseconds.

    The model is switched to eval mode and fed a random tensor of
    ``input_shape`` placed on the model's own device. CUDA models are timed
    with CUDA events; everything else is timed with a wall-clock timer, so
    CPU-only and quantized models can be measured too.

    Args:
        model: Model to time.
        input_shape: Shape of the random input tensor.
        num_runs: Number of timed forward passes (at least 1).
        warmup_runs: Untimed forward passes executed before timing starts.

    Returns:
        Average time per forward pass, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    import time
    from itertools import chain

    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    model.eval()

    # A model without parameters or buffers is treated as a CPU model.
    first_tensor = next(chain(model.parameters(), model.buffers()), None)
    target = first_tensor.device if first_tensor is not None else torch.device("cpu")
    input_tensor = torch.rand(input_shape).to(target)

    with torch.no_grad():
        # Warm-up run
        for _ in range(warmup_runs):
            model(input_tensor)

        # Timed runs
        if target.type == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_runs):
                model(input_tensor)
            end_event.record()
            # Wait for the recorded events before reading the elapsed time.
            torch.cuda.synchronize()
            total_ms = start_event.elapsed_time(end_event)
        else:
            started = time.perf_counter()
            for _ in range(num_runs):
                model(input_tensor)
            total_ms = (time.perf_counter() - started) * 1000.0

    return total_ms / num_runs

# Time a single-image (1 x 3 x 224 x 224) forward pass for each model variant,
# using benchmark()'s default number of runs.
original_time = benchmark(model, (1, 3, 224, 224))
quantized_time = benchmark(model_quantized, (1, 3, 224, 224))
optimized_time = benchmark(optimized_model, (1, 3, 224, 224))

# Report the per-run times and the speedup of the optimized model relative to
# the original one (original time divided by optimized time).
print(f"Original model inference time: {original_time:.2f} ms")
print(f"Quantized model inference time: {quantized_time:.2f} ms")
print(f"Optimized model inference time: {optimized_time:.2f} ms")
print(f"Speedup: {original_time / optimized_time:.2f}x")