In [1]:
import torch
from pathlib import Path
from utils import evaluate_model, dynamic_quantization, static_quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models.ResNet50 import ResNet50Baseline
import json

In [2]:
supported_engines = torch.backends.quantized.supported_engines
if 'qnnpack' in supported_engines:
    torch.backends.quantized.engine = 'qnnpack'
if 'fbgemm' in supported_engines:
    torch.backends.quantized.engine = 'fbgemm'
saved_model_path = Path("./saved_models/resnet")
saved_model_path.mkdir(parents=True, exist_ok=True)
torch.backends.quantized.engine

'qnnpack'

In [3]:
train_path = Path('/Users/lakshya/quantization/tiny-imagenet-200/train')
test_path = Path('/Users/lakshya/quantization/tiny-imagenet-200/val/processed_val_2')

In [4]:
transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [5]:
train_dataset = datasets.ImageFolder(root=train_path, transform=transformations)
test_dataset = datasets.ImageFolder(root=test_path, transform=transformations)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [6]:
baseline_model = torch.load(saved_model_path / "resnet50_baseline_model.pth", map_location="cpu",weights_only=False)
baseline_model.eval()

ResNet50Baseline(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
   

In [16]:
# ===== dynamic quantisation on baseline model. Uses the function inside the utils.py script
quantized_dy_model = dynamic_quantization(baseline_model, (torch.randn(1, 3, 224, 224),))
quantized_dy_model

GraphModule(
  (model): Module(
    (conv1): ConvReLU2d(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (downsample): Module(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (relu): ReLU(inplace=True)
      )
      (1): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
    

In [18]:
torch.save(quantized_dy_model.state_dict(), saved_model_path / 'resnet50_base_dy_quant_weights.pth')
torch.save(quantized_dy_model, saved_model_path / 'resnet50_base_dy_quant_model.pth')

In [10]:
quantized_dy_model = torch.load(saved_model_path / "resnet50_base_dy_quant_model.pth", map_location="cpu",weights_only=False)
quantized_dy_model.eval()

  device=storage.device,


GraphModule(
  (model): Module(
    (conv1): ConvReLU2d(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (downsample): Module(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (relu): ReLU(inplace=True)
      )
      (1): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
    

In [7]:
baseline_metrics = evaluate_model(baseline_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [23:32<00:00,  9.00s/it, Loss=1.6259, Top1=62.42%]


In [8]:
baseline_metrics

{'top1_acc': 0.6242,
 'top5_acc': 0.8429,
 'total_inference_time': 1412.3039989471436,
 'average_inference_time': 8.9955668722748,
 'average_loss': 103.42139172402157,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [11]:
baseline_dy_metrics = evaluate_model(quantized_dy_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [23:36<00:00,  9.02s/it, Loss=1.6258, Top1=62.38%]


In [12]:
baseline_dy_metrics

{'top1_acc': 0.6238,
 'top5_acc': 0.8429,
 'total_inference_time': 1416.166141986847,
 'average_inference_time': 9.020166509470362,
 'average_loss': 103.40966021179393,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [15]:
# ======= Static Quantization. Uses the function inside the utils.py script
calibration_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(1000)), batch_size=64)
base_st_quant_model = static_quantization(baseline_model, (torch.randn(1, 3, 224, 224),), 
                                         calibration_loader)
base_st_quant_model

100%|███████████████████████████████████████████| 16/16 [02:34<00:00,  9.68s/it]


GraphModule(
  (model): Module(
    (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.014843720011413097, zero_point=0, padding=(3, 3))
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.004731349181383848, zero_point=0)
        (conv2): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.006016128696501255, zero_point=0, padding=(1, 1))
        (conv3): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=0.0090074697509408, zero_point=121)
        (downsample): Module(
          (0): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=0.01691831648349762, zero_point=155)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(256, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.004068697802722454, zero_point=0)
        (conv2): 

In [16]:
torch.save(base_st_quant_model.state_dict(), saved_model_path / 'resnet50_base_st_quant_weights.pth')
torch.save(base_st_quant_model, saved_model_path / 'resnet50_base_st_quant_model.pth')

In [17]:
baseline_st_metrics = evaluate_model(base_st_quant_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [05:53<00:00,  2.25s/it, Loss=1.6228, Top1=62.19%]


In [18]:
baseline_st_metrics

{'top1_acc': 0.6219,
 'top5_acc': 0.8425,
 'total_inference_time': 353.2770128250122,
 'average_inference_time': 2.250172056210269,
 'average_loss': 103.22040442752231,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [19]:
# ======= load the structured + unstructured style pruned ResNet50 model!
sparse_pruned_resnet_model = torch.load(saved_model_path / "resnet50_structured_pruned_SparseML40%_finalized_model.pth",
                                        map_location="cpu",weights_only=False)
sparse_pruned_resnet_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [20]:
# ======= Dyanamic Quantisation of the structured + unstructured style pruned ResNet50 model!
sparse_pruned_dy_quant = dynamic_quantization(sparse_pruned_resnet_model, 
                                              (torch.randn(1, 3, 224, 224),))

In [21]:
sparse_pruned_dy_quant

GraphModule(
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): ReLU(inplace=True)
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
      )
      (conv2): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      (downsample): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (relu): ReLU(inplace=True)
    )
    (1): Module(
      (conv1): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
      )
      (conv2): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [22]:
torch.save(sparse_pruned_dy_quant.state_dict(), saved_model_path / 'resnet50_sparse_prun_dy_quant_weights.pth')
torch.save(sparse_pruned_dy_quant, saved_model_path / 'resnet50_sparse_prun_dy_quant_model.pth')

In [23]:
sparse_prun_base_metrics = evaluate_model(sparse_pruned_resnet_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [20:49<00:00,  7.96s/it, Loss=1.6869, Top1=63.18%]


In [24]:
sparse_prun_base_metrics

{'top1_acc': 0.6318,
 'top5_acc': 0.8495,
 'total_inference_time': 1249.0078840255737,
 'average_inference_time': 7.955464229462253,
 'average_loss': 107.34647544933732,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [25]:
sparse_prun_dy_quant_metrics = evaluate_model(sparse_pruned_dy_quant, test_loader, 'resnet50')

100%|███████████████| 157/157 [19:12<00:00,  7.34s/it, Loss=1.6876, Top1=63.13%]


In [26]:
sparse_prun_dy_quant_metrics

{'top1_acc': 0.6313,
 'top5_acc': 0.849,
 'total_inference_time': 1152.8256080150604,
 'average_inference_time': 7.3428382676118495,
 'average_loss': 107.38760610446809,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [27]:
# ======= Static Quantisation of the structured + unstructured style pruned ResNet50 model!
sparse_pruned_st_quant_model = static_quantization(baseline_model, (torch.randn(1, 3, 224, 224),), 
                                         calibration_loader)
sparse_pruned_st_quant_model

100%|███████████████████████████████████████████| 16/16 [02:36<00:00,  9.76s/it]


GraphModule(
  (model): Module(
    (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.014843720011413097, zero_point=0, padding=(3, 3))
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.004731349181383848, zero_point=0)
        (conv2): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.006016128696501255, zero_point=0, padding=(1, 1))
        (conv3): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=0.0090074697509408, zero_point=121)
        (downsample): Module(
          (0): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=0.01691831648349762, zero_point=155)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(256, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.004068697802722454, zero_point=0)
        (conv2): 

In [28]:
torch.save(sparse_pruned_st_quant_model.state_dict(), saved_model_path / 'resnet50_sparse_prun_st_quant_weights.pth')
torch.save(sparse_pruned_st_quant_model, saved_model_path / 'resnet50_sparse_prun_st_quant_model.pth')

In [29]:
sparse_prun_st_quant_metrics = evaluate_model(sparse_pruned_st_quant_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [05:49<00:00,  2.22s/it, Loss=1.6228, Top1=62.19%]


In [30]:
sparse_prun_st_quant_metrics

{'top1_acc': 0.6219,
 'top5_acc': 0.8425,
 'total_inference_time': 349.1686441898346,
 'average_inference_time': 2.2240041031199658,
 'average_loss': 103.22040442752231,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [31]:
all_metrics = {
    "baseline_metrics": baseline_metrics,
    "baseline_dy_metrics": baseline_dy_metrics,
    "baseline_st_metrics": baseline_st_metrics,
    "baseline_sparse_prun_metrics": sparse_prun_base_metrics,
    "sparse_prun_dy_quant_metrics": sparse_prun_dy_quant_metrics,
    "sparse_prun_st_quant_metrics": sparse_prun_st_quant_metrics
}
metrics_folder = Path("./model_metrics/resnet50")
metrics_folder.mkdir(parents=True, exist_ok=True)

In [32]:
for name, metrics in all_metrics.items():
    with (metrics_folder / f"{name}.json").open("w") as file:
        json.dump(metrics, file, indent=1)