### Adopted from https://github.com/zama-ai/concrete-ml/tree/main/use_case_examples/cifar/cifar_brevitas_finetuning with modifications

### Requirements

In [1]:
!pip install  concrete-ml brevitas

Collecting concrete-ml
  Downloading concrete_ml-1.3.0-py3-none-any.whl (202 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m202.5/202.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting brevitas
  Downloading brevitas-0.10.0-py3-none-any.whl (600 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m600.8/600.8 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting boto3<2.0.0,>=1.23.5 (from concrete-ml)
  Downloading boto3-1.34.2-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting brevitas
  Downloading brevitas-0.8.0-py3-none-any.whl (357 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m357.3/357.3 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting certifi==2023.07.22 (from concrete-ml)
  Downloading certifi-2023.7.22-py3-none-any.whl (158 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

### Utils

In [1]:
import pickle as pkl
import random
import sys
import warnings
from collections import OrderedDict
from pathlib import Path
from time import time
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from brevitas import config
from concrete.fhe.compilation import Configuration
#from models import Fp32VGG11
from sklearn.metrics import top_k_accuracy_score
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
#from torchvision import datasets, transforms
#from torchvision.utils import make_grid
from tqdm import tqdm

from concrete.ml.torch.compile import compile_brevitas_qat_model

warnings.filterwarnings("ignore", category=UserWarning)

class FromTensors(Dataset):
    #def __init__(self, images_path, labels_path):
    def __init__(self, dataset_path):
        #self.images = torch.load(images_path)
        #self.labels = torch.load(labels_path).tolist()
        self.images, self.labels = torch.load(dataset_path)
        self.labels = self.labels.tolist()

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.labels)

def load_torch_dataset(
        cluster_idx: int,
        base_path :str,
        train_set :bool) -> Dataset:

    split = 'train' if train_set else 'test'
    ds_name = f'cluster{cluster_idx:02d}' if cluster_idx>=0 else 'cluster_cls'
    return FromTensors( f'{base_path}/{ds_name}_{split}.pt')

def get_dataloader_precanned(
    param: Dict,
) -> Tuple[DataLoader, DataLoader]:
    """Returns the training and the test loaders of either CIFAR-10 or CIFAR-100 data-set.
    Args:
        param (Dict): Set of hyper-parameters
    Return:
        Tuple[DataLoader, DataLoader]: Training and test data loaders.
    """
    dataset_name = param["dataset_name"]
    path = param["dir"]

    if param['train_required']:
        train_dataset = load_torch_dataset(param["cluster_idx"], path, True)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size= param["batch_size"],
            shuffle=True,
            drop_last=True,
        )
    else:
        train_loader = None

    if param['test_required']:
        test_dataset = load_torch_dataset(param["cluster_idx"], path, False)

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size= param["batch_size"],
            shuffle=True,
            drop_last=True,
        )
    else:
        test_loader = None

    return (train_loader, test_loader)


def plot_history(param: Dict, load: bool = False) -> None:
    """Display the loss and accuracy for the test and training sets.

    Args:
        param (Dict): Set of hyper-parameters to use depending on whether
            CIFAR-10 or CIFAR-100 is used.
        load (bool): If True, we upload the stored param.
    Returns:
        None.
    """

    if load:
        with open(
            f"{param['dir']}/{param['training']}/{param['dataset_name']}_history.pkl", "br"
        ) as f:
            param = pkl.load(f)

    fig, axes = plt.subplots(1, 2, figsize=(15, 8))

    # Plot the training and test loss.
    axes[0].plot(
        range(len(param["loss_train_history"])),
        param["loss_train_history"],
        label="Train loss",
        marker="o",
    )
    axes[0].plot(
        range(len(param["loss_test_history"])),
        param["loss_test_history"],
        label="Test loss",
        marker="s",
    )

    axes[0].set_title("Loss")
    axes[0].set_ylabel("Loss")
    axes[0].set_xlabel("Epochs")
    axes[0].legend(loc="best")
    axes[0].grid(True)

    # Plot the training and test accuracy.
    axes[1].plot(
        range(len(param["accuracy_train"])),
        param["accuracy_train"],
        label="Train accuracy",
        marker="o",
    )
    axes[1].plot(
        range(len(param["accuracy_test"])),
        param["accuracy_test"],
        label="Test accuracy",
        marker="s",
    )

    axes[1].set_title("Top-1 accuracy")
    axes[1].set_ylabel("Accuracy")
    axes[1].set_xlabel("Epochs")
    axes[1].legend(loc="best")
    axes[1].grid(True)

    fig.tight_layout()
    plt.show()


def plot_baseline(param: Dict, data: DataLoader, device: str) -> None:
    """
    Display the test accuracy given `param` arguments
    that we got using Transfer Learning and QAT approaches.

    Args:
        param (Dict): Set of hyper-parameters to use depending on whether
            CIFAR-10 or CIFAR-100 is used.
        data (DataLoader): Test set.
        device (str): Device type.

    Returns:
        None
    """
    # The accuracy of the counterpart pre-trained model in fp32 will be used as a baseline.
    # That we try to catch up during the Quantization Aware Training.
    checkpoint = torch.load(f"{param['dir']}/{param['pre_trained_path']}", map_location=device)
    fp32_vgg = Fp32VGG11(param["output_size"])
    fp32_vgg.load_state_dict(checkpoint)
    baseline = torch_inference(fp32_vgg, data, param, device)

    plt.plot(
        range(len(param["accuracy_test"])),
        param["accuracy_test"],
        marker="o",
        label="Test accuracy",
    )
    plt.text(x=0, y=baseline + 0.01, s=f"Baseline = {baseline * 100: 2.2f}%", fontsize=15, c="red")
    plt.plot(range(len(param["accuracy_test"])), [baseline] * len(param["accuracy_test"]), "r--")

    plt.title(f"Accuracy on the testing set with {param['dataset_name']}")
    plt.legend(loc="best")
    plt.ylabel("Accuracy")
    plt.xlabel("Epochs")
    plt.xlim(-0.3, 4.2)
    plt.ylim(0, 1)
    plt.grid(True)
    plt.show()


def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    param: Dict,
    step: int = 1,
    device: str = "cpu",
) -> nn.Module:
    """Training the model.

    Args:
        model (nn.Module): A PyTorch or Brevitas network.
        train_loader (DataLoader): The training set.
        test_loader (DataLoader): The test set.
        param (Dict): Set of hyper-parameters to use depending on whether
            CIFAR-10 or CIFAR-100 is used.
        step (int): Display the loss and accuracy every `epoch % step`.
        device (str): Device type.
    Returns:
        nn.Module: the trained model.
    """

    torch.manual_seed(param["seed"])
    random.seed(param["seed"])

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=param["lr"])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=param["milestones"], gamma=param["gamma"]
    )

    # To avoid breaking up the tqdm bar
    with tqdm(total=param["epochs"], file=sys.stdout) as pbar:

        for i in range(param["epochs"]):
            # Train the model
            model.train()
            loss_batch_train, accuracy_batch_train = [], []

            for x, y in train_loader:
                x, y = x.to(device), y.to(device)

                optimizer.zero_grad()
                yhat = model(x)
                loss_train = param["criterion"](yhat, y)
                loss_train.backward()
                optimizer.step()

                loss_batch_train.append(loss_train.item())
                accuracy_batch_train.extend((yhat.argmax(1) == y).cpu().float().tolist())

            if scheduler:
                scheduler.step()

            param["accuracy_train"].append(np.mean(accuracy_batch_train))
            param["loss_train_history"].append(np.mean(loss_batch_train))

            # Evaluation during training:
            # Disable autograd engine (no backpropagation)
            # To reduce memory usage and to speed up computations
            with torch.no_grad():
                # Notify batchnormalization & dropout layers to work in eval mode
                model.eval()
                loss_batch_test, accuracy_batch_test = [], []
                for x, y in test_loader:
                    x, y = x.to(device), y.to(device)
                    yhat = model(x)
                    loss_test = param["criterion"](yhat, y)
                    loss_batch_test.append(loss_test.item())
                    accuracy_batch_test.extend((yhat.argmax(1) == y).cpu().float().tolist())

                param["accuracy_test"].append(np.mean(accuracy_batch_test))
                param["loss_test_history"].append(np.mean(loss_batch_test))

            if i % step == 0:
                pbar.write(
                    f"Epoch {i:2}: Train loss = {param['loss_train_history'][-1]:.4f} "
                    f"VS Test loss = {param['loss_test_history'][-1]:.4f} - "
                    f"Accuracy train: {param['accuracy_train'][-1]:.4f} "
                    f"VS Accuracy test: {param['accuracy_test'][-1]:.4f}"
                )
                pbar.update(step)

    # Save the state_dict
    dir = Path(".") / param["dir"] / param["training"]
    dir.mkdir(parents=True, exist_ok=True)
    torch.save(
        model.state_dict(), f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt"
    )

    with open(f"{dir}/{param['dataset_name']}_history.pkl", "wb") as f:
        pkl.dump(param, f)

    torch.cuda.empty_cache()

    return model


def torch_inference(
    model: nn.Module,
    data: DataLoader,
    device: str = "cpu",
    verbose: bool = False,
) -> float:

    """Returns the `top_k` accuracy.

    Args:
        model (nn.Module): A PyTorch or Brevitas network.
        data (DataLoader): The test or evaluation set.
        device (str): Device type.
        verbose (bool): For display.
    Returns:
        float: The top_k accuracy.
    """
    correct = []
    total_example = 0
    model = model.to(device)

    with torch.no_grad():
        model.eval()
        for x, y in tqdm(data, disable=verbose is False):
            x, y = x.to(device), y
            yhat = model(x).cpu().detach()
            correct.append(yhat.argmax(1) == y)
            total_example += len(x)

    return np.mean(np.vstack(correct), dtype="float64")


def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
    """Test if the model is FHE-compatible.

    Args:
        model (Callable): The Brevitas model.
        data (DataLoader): The data loader.

    Returns:
        Callable: Quantized model.
    """

    qmodel = compile_brevitas_qat_model(
        model.to("cpu"),
        # Training
        torch_inputset=data,
        show_mlir=False,
        output_onnx_file="test.onnx",
    )

    return qmodel


def mapping_keys(pre_trained_weights: Dict, model: nn.Module, device: str) -> nn.Module:

    """
    Initialize the quantized model with pre-trained fp32 weights.

    Args:
        pre_trained_weights (Dict): The state_dict of the pre-trained fp32 model.
        model (nn.Module): The Brevitas model.
        device (str): Device type.

    Returns:
        Callable: The quantized model with the pre-trained state_dict.
    """

    # Brevitas requirement to ignore missing keys
    config.IGNORE_MISSING_KEYS = True

    old_keys = list(pre_trained_weights.keys())
    new_keys = list(model.state_dict().keys())
    new_state_dict = OrderedDict()

    for old_key, new_key in zip(old_keys, new_keys):
        new_state_dict[new_key] = pre_trained_weights[old_key]

    model.load_state_dict(new_state_dict)
    model = model.to(device)

    return model


def fhe_simulation_inference(quantized_module, data_loader, verbose: bool = False) -> float:
    """Evaluate the model in FHE simulation mode.

    Args:
        quantized_module (Callable): The quantized module.
        data_loader (int): The test or evaluation set.
        verbose (bool): For display.

    Returns:
        float: The accuracy measured through FHE simulation
    """
    correct_sim = []
    total_example = 0

    disable_tqdm = not verbose
    for data, labels in tqdm(data_loader, disable=disable_tqdm):

        data, labels = data.detach().cpu().numpy(), labels.detach().cpu().numpy()

        # Store the predicted quantized probabilities
        predictions = quantized_module.forward(data, fhe="simulate")

        total_example += data.shape[0]

        # Store the class predictions
        correct_sim.extend(predictions.argmax(1) == labels)

    acc = np.mean(correct_sim, dtype="float64")
    return acc

### Models

In [2]:
import brevitas
import brevitas.nn as qnn
import torch
import torch.nn as nn
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat

""" In this models.py we provide the code for the PyTorch and Brevitas networks."""

# This architecture is inspired by the original VGG-11 network available in
# PyTorch.hub (https://pytorch.org/hub/pytorch_vision_vgg/)

# Each tuple refers to a PyTorch or Brevitas layer:
# I: QuantIdentity layer, only required for the Brevitas network. Mainly used to quantize
# the input data or to encapsulate a PyTorch layer inside the Brevitas model.
# C: Convolutional layer.
# P: Pooling layer, we replaced the original `MaxPool2d` in VGG-11 by a `AvgPool2d` layer.
# Because in the current version of Concrete ML `MaxPool2d` isn't available yet.
# R: ReLU activation.
FEATURES_MAPS = [
    ("I",),
    ("C", 3, 64, 3, 1, 1),
    ("R",),
    ("P", 2, 2, 0, 1, False),
    ("I",),
    ("C", 64, 128, 3, 1, 1),
    ("R",),
    ("P", 2, 2, 0, 1, False),
    ("I",),
    ("C", 128, 256, 3, 1, 1),
    ("R",),
    ("C", 256, 256, 3, 1, 1),
    ("R",),
    ("P", 2, 2, 0, 1, False),
    ("I",),
    ("C", 256, 512, 3, 1, 1),
    ("R",),
    ("C", 512, 512, 3, 1, 1),
    ("R",),
    ("P", 2, 2, 0, 1, False),
    ("I",),
    ("C", 512, 512, 3, 1, 1),
    ("R",),
    ("C", 512, 512, 3, 1, 1),
    ("R",),
    ("P", 2, 2, 0, 1, False),
    ("I",),
    ("P", 2, 2, 0, 1, False),  # this and the next layer are needed in QuantVGG11 but not in Fp32VGG11
    ("I",),
    ]

class QuantVGG11(nn.Module):
    def __init__(
        self,
        bit: int,
        output_size: int = 3,
        act_quant: brevitas.quant = Int8ActPerTensorFloat,
        weight_quant: brevitas.quant = Int8WeightPerTensorFloat,
    ):
        """A quantized network with Brevitas.

        Args:
            bit (int): Bit of quantization.
            output_size (int): Number of classes.
            act_quant (brevitas.quant): Quantization protocol of activations.
            weight_quant (brevitas.quant): Quantization protocol of the weights.

        """
        super(QuantVGG11, self).__init__()
        self.bit = bit

        def tuple2quantlayer(t):
            if t[0] == "R":
                return qnn.QuantReLU(return_quant_tensor=True, bit_width=bit, act_quant=act_quant)
            if t[0] == "P":
                return nn.AvgPool2d(kernel_size=t[1], stride=t[2], padding=t[3], ceil_mode=t[5])
            if t[0] == "C":
                return qnn.QuantConv2d(
                    t[1],
                    t[2],
                    kernel_size=t[3],
                    stride=t[4],
                    padding=t[5],
                    weight_bit_width=bit,
                    weight_quant=weight_quant,
                    return_quant_tensor=True,
                )
            if t[0] == "L":
                return qnn.QuantLinear(
                    in_features=t[1],
                    out_features=t[2],
                    weight_bit_width=bit,
                    weight_quant=weight_quant,
                    bias=True,
                    return_quant_tensor=True,
                )
            if t[0] == "I":
                # According to the literature, the first layer holds the most information
                # about the input data. So, it is possible to quantize the input using more
                # precision bit-width than the rest of the network.
                identity_quant = t[1] if len(t) == 2 else bit
                return qnn.QuantIdentity(
                    bit_width=identity_quant, act_quant=act_quant, return_quant_tensor=True
                )

        # The very first layer is a `QuantIdentity` layer, which is very important
        # to ensure that the input data is also quantized.
        self.features = nn.Sequential(*[tuple2quantlayer(t) for t in FEATURES_MAPS])

        # self.identity1 and self.identity2 are used to encapsulate the `torch.flatten`.
        self.identity1 = qnn.QuantIdentity(
            bit_width=bit, act_quant=act_quant, return_quant_tensor=True
        )

        self.identity2 = qnn.QuantIdentity(
            bit_width=bit, act_quant=act_quant, return_quant_tensor=True
        )

        # Fully connected linear layer.
        self.final_layer = qnn.QuantLinear(
            in_features=512 * 3 * 3,
            out_features=output_size,
            weight_quant=weight_quant,
            weight_bit_width=bit,
            bias=True,
            return_quant_tensor=True,
        )

    def forward(self, x):
        x = self.features(x)
        x = self.identity1(x)
        # As `torch.flatten` is a PyTorch layer, you must place it between two `QuantIdentity`
        # layers to ensure that all intermediate values of the network are properly quantized.
        x = torch.flatten(x, 1)
        # Replace `x.view(x.shape[0], -1)` by `torch.flatten(x, 1)` which is an equivalent
        # But is compatible with Concrete ML.
        x = self.identity2(x)
        x = self.final_layer(x)
        return x.value

## Settings

In [3]:
import torch

bit = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device Type: {device}")

Device Type: cuda


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Quantization Aware Training (QAT)

In [6]:
param = {
    "output_size": 105,
    "batch_size": 16,
    "training": "quant",
    "cluster_idx": 0,
    "dataset_name": "cluster00_large_150",
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
    "lr": 6e-5,
    "seed": 727,
    "epochs": 10,
    "gamma": 0.01,
    "milestones": [3, 5],
    "criterion": torch.nn.CrossEntropyLoss(),
    "accuracy_test": [],
    "accuracy_train": [],
    "loss_test_history": [],
    "loss_train_history": [],
    "dir": "/content/drive/MyDrive/Colab Notebooks/Zama/f32_pretrained",
    "pre_trained_model": "cluster00_SPEC_fp32__state_dict.pt",
    'train_required': True,
    'test_required': True,
}

## Quant aware classiication training
- Clusters from 0 to num_clusters (we use num_clusters=80)
- cluster_idx = -1 is a special case, it indicates Cluster classification dataset and model


In [6]:
for cluster_idx in range(-1,0):
    param["cluster_idx"] = cluster_idx
    param["dataset_name"] = f"cluster{cluster_idx:02d}_large_150" if cluster_idx >= 0 else 'cluster_cls'
    param["pre_trained_model"] = f"cluster{cluster_idx:02d}_SPEC_fp32_state_dict.pt" if cluster_idx >= 0 else f"cluster_cls_SPEC_fp32_state_dict.pt"

    # Load dataset
    train_loader, test_loader = get_dataloader_precanned(param= param)

    # fp32 pretrained model
    checkpoint = torch.load(f"{param['dir']}/{param['pre_trained_model']}", map_location=device)
    param["output_size"] = checkpoint['final_layer.bias'].shape.numel()
    print(f"Cluster {cluster_idx:02d} output size = {param['output_size']}")

    # Quantized model
    quant_vgg = QuantVGG11(bit=bit, output_size=param["output_size"])

    # Mapping keys from pretrained to quantized model
    quant_vgg = mapping_keys(checkpoint, quant_vgg, device)

    # Train
    quant_vgg = train(quant_vgg, train_loader, test_loader, param, device=device)

Cluster -1 output size = 80
Epoch  0: Train loss = 3.3306 VS Test loss = 2.9646 - Accuracy train: 0.1624 VS Accuracy test: 0.2262
Epoch  1: Train loss = 2.8746 VS Test loss = 2.8378 - Accuracy train: 0.2337 VS Accuracy test: 0.2367
Epoch  2: Train loss = 2.6408 VS Test loss = 2.7533 - Accuracy train: 0.2873 VS Accuracy test: 0.2519
Epoch  3: Train loss = 2.2218 VS Test loss = 2.6132 - Accuracy train: 0.3854 VS Accuracy test: 0.2902
Epoch  4: Train loss = 2.1421 VS Test loss = 2.5947 - Accuracy train: 0.4041 VS Accuracy test: 0.2982
Epoch  5: Train loss = 2.1166 VS Test loss = 2.5970 - Accuracy train: 0.4131 VS Accuracy test: 0.2925
Epoch  6: Train loss = 2.1139 VS Test loss = 2.5962 - Accuracy train: 0.4133 VS Accuracy test: 0.2917
Epoch  7: Train loss = 2.1130 VS Test loss = 2.5971 - Accuracy train: 0.4135 VS Accuracy test: 0.2942
Epoch  8: Train loss = 2.1129 VS Test loss = 2.5956 - Accuracy train: 0.4126 VS Accuracy test: 0.2913
Epoch  9: Train loss = 2.1152 VS Test loss = 2.5961 - 

## End to End test

In [11]:
from copy import deepcopy

def instantiate_model(model_path):
    # Load cluster classification model state_dict
    state_dict = torch.load( model_path, map_location=device)
    output_size = state_dict['final_layer.bias'].shape.numel()

    # Create model template
    model = QuantVGG11(bit=bit, output_size=output_size)

    # Copy the trained state_dict
    model.load_state_dict(deepcopy(state_dict), strict=False)
    # Move model to device
    model = model.to(device)

    return model

In [None]:
import numpy as np
import pandas as pd

results = []

# Instantiate cluster classification model
cluster_cls_model = instantiate_model("/content/drive/MyDrive/Colab Notebooks/Zama/f32_pretrained/quant/cluster_cls_quant_state_dict.pt")

# Test for available clusters (ones having intra-cluster classificaton models trained)
for cluster_idx in range(0, 6):
    param["cluster_idx"] = cluster_idx
    param['batch_size'] = 1
    param['train_required'] = False
    param['test_required'] = True

    # Instantiate intra-cluster track classification model
    intra_cluster_model = instantiate_model(f"/content/drive/MyDrive/Colab Notebooks/Zama/f32_pretrained/quant/cluster{cluster_idx:02d}_large_150_quant_state_dict.pt")

    # Test with various augmentation layers
    for augm in ['small', 'medium', 'large']:

        total_samples = 0
        top1_cls_correct, top1_e2e_correct = 0, 0
        top3_cls_correct, top3_e2e_correct = 0, 0

        param['dir'] = f"/content/drive/MyDrive/Colab Notebooks/Zama/e2e_test_data/{augm}"
        param["dataset_name"] = f"cluster{cluster_idx:02d}_large_150" if cluster_idx >= 0 else 'cluster_cls'

        # Load dataset
        _, test_loader = get_dataloader_precanned(param= param)   # only test end2end datasets

        # Run intra-cluster track classification inference
        # Collect all correctly classfied tracks (within the given cluster)
        intra_cluster_correct = []
        with torch.no_grad():
            intra_cluster_model.eval()
            for x, track_id in test_loader:
                x = x.to(device)
                track_preds = intra_cluster_model(x).cpu().detach()
                track_preds = track_preds.numpy()

                total_samples += 1
                if track_id == np.argmax(track_preds):
                    intra_cluster_correct.append(track_id)
                break

        # Run cluster classification inference
        with torch.no_grad():
            cluster_cls_model.eval()
            for x, track_id in test_loader:
                x = x.to(device)
                cluster_preds = cluster_cls_model(x).cpu().detach()
                cluster_preds = cluster_preds.numpy()

                if cluster_idx == np.argmax(cluster_preds):
                    top1_cls_correct += 1
                if cluster_idx == np.argmax(cluster_preds) and track_id in intra_cluster_correct:
                    top1_e2e_correct += 1
                if cluster_idx in np.argsort(cluster_preds)[-3:]:
                    top3_cls_correct += 1
                if cluster_idx in np.argsort(cluster_preds)[-3:] and track_id in intra_cluster_correct:
                    top3_e2e_correct += 1
                break

        results.append({
            "cluster": cluster_idx,
            "augmentation": augm,
            "top1_cls_correct": int(top1_cls_correct/total_samples*100),
            "top3_cls_correct": int(top3_cls_correct/total_samples*100),
            "top1_e2e_correct": int(top1_e2e_correct/total_samples*100),
            "top3_e2e_correct": int(top3_e2e_correct/total_samples*100)
        })

df = pd.DataFrame(results, columns=["cluster", "augmentation", "top1_cls_correct", "top3_cls_correct", "top1_e2e_correct", "top3_e2e_correct"])
df.to_csv("/content/drive/MyDrive/Colab Notebooks/Zama/e2e_test_results.csv", index=False)


## Backup

In [None]:
!pip install torchinfo
from torchinfo import summary

In [57]:
quant_vgg = QuantVGG11(bit=bit, output_size=105)
summary(quant_vgg, input_size=[16, 3, 224, 224])

Layer (type:depth-idx)                                                 Output Shape              Param #
QuantVGG11                                                             [16, 105]                 --
├─Sequential: 1-1                                                      [16, 512, 3, 3]           9,220,494
│    └─QuantIdentity: 2-1                                              [16, 3, 224, 224]         --
│    │    └─ActQuantProxyFromInjector: 3-1                             [16, 3, 224, 224]         --
│    │    └─ActQuantProxyFromInjector: 3-2                             [16, 3, 224, 224]         1
├─QuantIdentity: 1-30                                                  --                        (recursive)
│    └─ActQuantProxyFromInjector: 2-58                                 --                        (recursive)
│    │    └─FusedActivationQuantProxy: 3-91                            --                        (recursive)
├─Sequential: 1-31                                            

In [50]:
fp32_vgg = Fp32VGG11(output_size=105)
summary(fp32_vgg, input_size=[16, 3, 224, 224])

Layer (type:depth-idx)                   Output Shape              Param #
Fp32VGG11                                [16, 105]                 --
├─Sequential: 1-1                        [16, 512, 7, 7]           --
│    └─Conv2d: 2-1                       [16, 64, 224, 224]        1,792
│    └─ReLU: 2-2                         [16, 64, 224, 224]        --
│    └─AvgPool2d: 2-3                    [16, 64, 112, 112]        --
│    └─Conv2d: 2-4                       [16, 128, 112, 112]       73,856
│    └─ReLU: 2-5                         [16, 128, 112, 112]       --
│    └─AvgPool2d: 2-6                    [16, 128, 56, 56]         --
│    └─Conv2d: 2-7                       [16, 256, 56, 56]         295,168
│    └─ReLU: 2-8                         [16, 256, 56, 56]         --
│    └─Conv2d: 2-9                       [16, 256, 56, 56]         590,080
│    └─ReLU: 2-10                        [16, 256, 56, 56]         --
│    └─AvgPool2d: 2-11                   [16, 256, 28, 28]         -