In [1]:
import torch
import torchvision
import time
from torch.quantization import quantize_qat, prepare_qat, convert
from torchvision.models.quantization import resnet18
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

In [None]:

# Simulated training loop for QAT (Here we use a small dataset)
# Replace with your actual dataset and training loop as needed
class RandomDataset(Dataset):
    def __init__(self, num_samples, num_classes):
        self.num_samples = num_samples
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.rand(3, 224, 224)
        label = torch.randint(0, self.num_classes, (1,)).item()
        return image, label

# DataLoader with random data for demonstration
train_loader = DataLoader(RandomDataset(100, 50), batch_size=32)


In [None]:

# Load and prepare a ResNet18 model for quantization-aware training
original_model = resnet18(weights=True)
original_model.eval()



In [None]:
# Helper function to measure inference time
def measure_inference_time(model, data_loader, num_batches=5):
    start_time = time.time()
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            if i >= num_batches:
                break
            _ = model(images)
    end_time = time.time()
    avg_inference_time = (end_time - start_time) / num_batches
    return avg_inference_time

# Prepare data loader for inference
test_loader = DataLoader(RandomDataset(100, 50), batch_size=16)

# Measure inference times for the original and quantized models
original_inference_time = measure_inference_time(original_model, test_loader)

print(f"Original model inference time (avg per batch): {original_inference_time:.4f} seconds")
print(f"Quantized model inference time (avg per batch): {quantized_inference_time:.4f} seconds")

Original model inference time (avg per batch): 2.1879 seconds
Quantized model inference time (avg per batch): 5.5030 seconds
