# Install and Import Dependencies

In [1]:
# Install dependencies into this kernel's environment (%pip targets the running kernel).
# PyTorch wheels are taken from the CUDA 11.8 index given below.
# NOTE(review): versions are unpinned, so results/timings may differ between installs — TODO pin.
# NOTE(review): `datasets`, `matching` and `evaluation` imported in the next cell are not
# installed here; presumably they are repo-local packages — confirm.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
%pip install opencv-python numpy seaborn matplotlib scikit-learn ipykernel tqdm pillow
%pip install onnx onnxruntime quanto

Looking in indexes: https://download.pytorch.org/whl/cu118
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import time
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
import onnx
import onnxruntime
import quanto
from tqdm import tqdm

from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data

from datasets import dataset_utils
from matching import matching
from evaluation.metrics import createPR, recallAt100precision, recallAtK
from datasets.load_dataset import GardensPointDataset, SFUDataset, StLuciaDataset

In [3]:
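# Uncomment on a fresh checkout to prepare the datasets (presumably what `.load()`
# does — TODO confirm). Only the SFU dataset is used by the cells below.
# GardensPointDataset().load()
# SFUDataset().load()
# StLuciaDataset().load()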

# Constants

In [4]:
# Weights file loaded with torch.load and indexed by parameter name below
# (presumably converted from the original Caffe model — see file name).
WEIGHTS_FILE: str = "calc.caffemodel.pt"
# Number of full passes averaged when timing each model variant.
ITERATIONS: int = 100

# Preprocess Images

In [5]:
class ConvertToYUVandEqualizeHist:
    """Histogram-equalize the luma channel of an RGB PIL image.

    The image is converted RGB -> YUV, only the Y (luma) plane is equalized
    with OpenCV, and the result is converted back to RGB. The chroma (U/V)
    planes are left untouched. The input must have exactly 3 channels,
    because ``cv2.COLOR_RGB2YUV`` requires that.

    Returns:
        A new ``PIL.Image`` in RGB mode.
    """

    def __call__(self, img):
        yuv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2YUV)
        luma = yuv[:, :, 0]
        yuv[:, :, 0] = cv2.equalizeHist(luma)
        return Image.fromarray(cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB))

# Per-image pipeline:
# 1. equalize luma;
# 2. collapse to one grayscale channel;
# 3. resize to 120x160 (H x W, matching CalcModel.input_dim);
# 4. convert to a float tensor of shape (1, 120, 160).
# The InterpolationMode enum is torchvision's native type and is equivalent to
# PIL's BICUBIC constant.
preprocess = transforms.Compose([
    ConvertToYUVandEqualizeHist(),
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((120, 160), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

In [6]:
class CustomImageDataset(Dataset):
    """Map-style dataset over the images in ``<name>/<folder>``.

    Args:
        name: Dataset root directory (e.g. ``"images/SFU"``). Its basename is
            stored as ``self.name``.
        folder: Sub-directory of ``name`` that holds the images (e.g. ``"dry"``).
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, name, folder, transform=None):
        self.name = os.path.basename(name)
        self.folder = os.path.join(name, folder)
        # Image order comes from dataset_utils. NOTE(review): db/query feature rows
        # are matched by index, so this order presumably must be stable — confirm.
        self.image_paths = dataset_utils.read_images_paths(self.folder, get_abs_path=True)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply ``transform`` if one was given."""
        image_path = self.image_paths[index]
        # The context manager releases the file handle even when no transform
        # forces a load. convert("RGB") returns an independent, fully loaded copy
        # and normalizes grayscale/RGBA/palette files. The RGB->YUV conversion in
        # the preprocessing would otherwise reject those modes.
        with Image.open(image_path) as img:
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img

In [7]:
# SFU: the "dry" traversal is the database, the "jan" traversal is the query set.
dataset_db, dataset_q = (
    CustomImageDataset("images/SFU", split, preprocess) for split in ("dry", "jan")
)

print("Dataset Length:", len(dataset_db))
dataset_db[0]

Dataset Length: 385


tensor([[[0.3098, 0.4431, 0.6235,  ..., 0.0314, 0.0196, 0.0196],
         [0.1882, 0.4471, 0.7020,  ..., 0.0471, 0.0353, 0.0353],
         [0.1412, 0.4392, 0.6510,  ..., 0.0431, 0.0314, 0.0353],
         ...,
         [0.7882, 0.8157, 0.8314,  ..., 0.1529, 0.1176, 0.0824],
         [0.7961, 0.8118, 0.8353,  ..., 0.1333, 0.0902, 0.0627],
         [0.7843, 0.8000, 0.8196,  ..., 0.0980, 0.0588, 0.0431]]])

In [8]:
batch_size = 64
# Cap the worker count at the available CPUs: DataLoader warns (and the machine
# oversubscribes) when num_workers exceeds them. os.cpu_count() may return None.
num_workers = min(8, os.cpu_count() or 1)
# shuffle=False keeps feature row i aligned with image i of each dataset.
db_dataloader = DataLoader(dataset_db, batch_size=batch_size, shuffle=False, num_workers=num_workers)
q_dataloader = DataLoader(dataset_q, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Model Definition

In [9]:
class CalcModel(nn.Module):
    """CALC descriptor network (eager PyTorch version).

    Three convolutions over a single-channel image:
    - conv1 and conv2 are each followed by ReLU, 3x3/stride-2 max pooling and
      local response normalization;
    - conv3 is followed by ReLU only.

    For a 120x160 input the final map is (batch, 4, 13, 18), which is
    flattened into a 936-dimensional descriptor per image.
    """

    def __init__(self):
        super().__init__()

        # (channels, height, width) of one input image.
        self.input_dim = (1, 120, 160)
        # Attribute names and registration order define the state_dict keys,
        # which must match the converted checkpoint loaded below.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=2, padding=4)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=1, padding=2)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = nn.Conv2d(128, 4, kernel_size=(3, 3), stride=1, padding=0)
        self.relu3 = nn.ReLU(inplace=False)
        # One pooling module is shared by both stages (it has no parameters).
        self.pool = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
        lrn_args = dict(alpha=1e-4, beta=0.75)
        self.lrn1 = nn.LocalResponseNorm(5, **lrn_args)
        self.lrn2 = nn.LocalResponseNorm(5, **lrn_args)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to flattened (batch, features) descriptors."""
        stages = (
            (self.conv1, self.relu1, self.lrn1),
            (self.conv2, self.relu2, self.lrn2),
        )
        for conv, act, norm in stages:
            x = norm(self.pool(act(conv(x))))
        return torch.flatten(self.relu3(self.conv3(x)), 1)

In [10]:
class CalcModelCompiled(nn.Module):
    """Same layers and forward pass as ``CalcModel``, but ``forward`` is wrapped
    with ``torch.compile`` and is compiled lazily on its first call.

    NOTE(review): the cells below compile ``calc`` via ``torch.compile(calc)``
    instead; this class is not instantiated there.
    """

    def __init__(self):
        super().__init__()

        # (channels, height, width) of one input image.
        self.input_dim = (1, 120, 160)
        # Same names/order as CalcModel so the same state_dict keys apply.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=2, padding=4)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(4, 4), stride=1, padding=2)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = nn.Conv2d(128, 4, kernel_size=(3, 3), stride=1, padding=0)
        self.relu3 = nn.ReLU(inplace=False)
        self.pool = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
        self.lrn1 = nn.LocalResponseNorm(5, alpha=1e-4, beta=0.75)
        self.lrn2 = nn.LocalResponseNorm(5, alpha=1e-4, beta=0.75)

    @torch.compile
    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to flattened (batch, features) descriptors."""
        stage1 = self.lrn1(self.pool(self.relu1(self.conv1(x))))
        stage2 = self.lrn2(self.pool(self.relu2(self.conv2(stage1))))
        features = self.relu3(self.conv3(stage2))
        return features.flatten(1)

### Normal Model

In [11]:
calc = CalcModel()

# Load the weights, keeping only the keys this model defines; a missing key
# raises KeyError. torch.load is pickle-based:
# - weights_only=True restricts unpickling to tensors/containers;
# - map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine.
state_dict = torch.load(WEIGHTS_FILE, map_location="cpu", weights_only=True)
my_layers = list(calc.state_dict().keys())
my_new_state_dict = {layer: state_dict[layer] for layer in my_layers}
calc.load_state_dict(my_new_state_dict)

print(calc)

CalcModel(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(4, 4))
  (relu1): ReLU()
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
  (relu2): ReLU()
  (conv3): Conv2d(128, 4, kernel_size=(3, 3), stride=(1, 1))
  (relu3): ReLU()
  (pool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
  (lrn1): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
  (lrn2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
)


### ONNX Model

In [12]:
# Trace with one random 1x120x160 image. Axis 0 of both input and output is
# marked dynamic so the exported graph accepts any batch size.
example_input = torch.randn(1, 1, 120, 160)

dynamic_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}

torch.onnx.export(
    calc,
    example_input,
    "calc_model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
)

ort_session = onnxruntime.InferenceSession("calc_model.onnx")

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


### Dynamic Quantized Model (ONNX)

In [13]:
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic quantization: weights are stored as uint8 and no calibration data is
# needed. The quantized graph is written to `model_quant`; the call's return
# value is not used below.
model_fp32 = 'calc_model.onnx'
model_quant = 'calc_model_quant_dynamic.onnx'
quantized_model = quantize_dynamic(
    model_fp32,
    model_quant,
    weight_type=QuantType.QUInt8,
)

ort_session_quant_dynamic = onnxruntime.InferenceSession(model_quant)

  elem_type: 7
  shape {
    dim {
      dim_value: 5
    }
    dim {
      dim_value: 2
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_value: 5
    }
    dim {
      dim_value: 2
    }
  }
}
.


### Static Quantized Model (ONNX)

In [14]:
# from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

from onnxruntime.quantization.shape_inference import quant_pre_process

# Pre-process the FP32 graph (shape inference/optimization) before static quantization.
quant_pre_process(
    'calc_model.onnx',
    'calc_model_quant_static_prep.onnx',
)

In [15]:
# Split the database images: the first 100 calibrate static quantization, and
# the remaining len(dataset_db) - 100 go to val_ds (not used further below).
calib_ds = torch.stack([dataset_db[idx] for idx in range(100)])
val_ds = torch.stack([dataset_db[idx] for idx in range(100, len(dataset_db))])

print(calib_ds.shape)
print(val_ds.shape)

torch.Size([100, 1, 120, 160])
torch.Size([285, 1, 120, 160])


In [16]:
from onnxruntime.quantization.calibrate import CalibrationDataReader

class QuantizationDataReader(CalibrationDataReader):
    """Feed calibration batches to ONNX Runtime's static quantizer.

    Wraps ``torch_ds`` in a non-shuffling ``DataLoader``. ``get_next`` returns
    one ``{input_name: ndarray}`` feed per batch, then ``None`` when exhausted.

    Args:
        torch_ds: Map-style dataset or tensor. For a tensor of shape
            (N, C, H, W), each batch is collated to (B, C, H, W).
        batch_size: Number of samples per calibration batch.
        input_name: Name of the ONNX graph input to feed.
    """

    def __init__(self, torch_ds, batch_size, input_name):
        self.torch_dl = torch.utils.data.DataLoader(torch_ds, batch_size=batch_size, shuffle=False)
        self.input_name = input_name
        self.datasize = len(self.torch_dl)  # number of batches, not samples
        self.enum_data = iter(self.torch_dl)

    def to_numpy(self, pt_tensor):
        """Return ``pt_tensor`` as a detached CPU NumPy array."""
        return pt_tensor.detach().cpu().numpy()

    def get_next(self):
        """Return the next feed dict, or ``None`` once every batch has been served."""
        batch = next(self.enum_data, None)
        if batch is None:
            return None
        # Datasets yielding (input, target) pairs collate to a list/tuple;
        # only the input goes to the model.
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        # Feed the whole batch. Indexing batch[0] on a tensor batch would keep
        # only the first image of every batch (2 of 100 images at batch_size=64),
        # which starves the min/max calibration. The batch axis of the exported
        # graph is dynamic, so full batches are accepted.
        return {self.input_name: self.to_numpy(batch)}

    def rewind(self):
        """Restart iteration from the first batch."""
        self.enum_data = iter(self.torch_dl)

# Calibration feeds: batches of 64 calibration images under the graph's input name.
qdr = QuantizationDataReader(
    calib_ds,
    batch_size=64,
    input_name=ort_session.get_inputs()[0].name,
)

In [17]:
from onnxruntime.quantization import quantize_static

# Asymmetric activation ranges, symmetric weight ranges.
q_static_opts = {"ActivationSymmetric": False,
                 "WeightSymmetric": True}

# The reader is a one-shot iterator. Rewind it so that re-running this cell
# still feeds the full calibration set rather than an exhausted reader.
qdr.rewind()

# quantize_static writes the quantized graph to model_output; the return value
# is not used below.
quantized_model = quantize_static(model_input='calc_model_quant_static_prep.onnx',
                                  model_output='calc_model_quant_static.onnx',
                                  calibration_data_reader=qdr,
                                  extra_options=q_static_opts)

# Load the static quantized model
ort_session_quant_static = onnxruntime.InferenceSession('calc_model_quant_static.onnx')



### Quantization (Quanto)

In [18]:
calc_quanto = CalcModel()

# Load the same weights as `calc`, keeping only the keys this model defines.
state_dict = torch.load(WEIGHTS_FILE, map_location="cpu", weights_only=True)
my_layers = list(calc_quanto.state_dict().keys())
my_new_state_dict = {layer: state_dict[layer] for layer in my_layers}
calc_quanto.load_state_dict(my_new_state_dict)
calc_quanto.eval()

print(calc_quanto)

# Replace the Conv2d layers with quanto QConv2d, in place (int8 weights + activations).
quanto.quantize(calc_quanto, weights=quanto.qint8, activations=quanto.qint8)

# quanto's workflow for activation quantization is quantize -> Calibration pass
# -> freeze. Without calibration the activation scales are never fitted to the
# data, so run the calibration images through once.
with torch.no_grad(), quanto.Calibration():
    calc_quanto(calib_ds)

# Freeze replaces the float weights with their int8 versions once, instead of
# re-quantizing them on every forward pass.
quanto.freeze(calc_quanto)
print(calc_quanto)

CalcModel(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(4, 4))
  (relu1): ReLU()
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
  (relu2): ReLU()
  (conv3): Conv2d(128, 4, kernel_size=(3, 3), stride=(1, 1))
  (relu3): ReLU()
  (pool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
  (lrn1): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
  (lrn2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
)
CalcModel(
  (conv1): QConv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(4, 4))
  (relu1): ReLU()
  (conv2): QConv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
  (relu2): ReLU()
  (conv3): QConv2d(128, 4, kernel_size=(3, 3), stride=(1, 1))
  (relu3): ReLU()
  (pool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
  (lrn1): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
  (lrn2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
)


### Torch Compile

In [19]:
# torch.compile wraps `calc` itself (see `_orig_mod` in the repr); tracing and
# compilation happen lazily on the first forward call, not here.
calc_compiled = torch.compile(
    calc,
    mode='default',
)

print(calc_compiled)

OptimizedModule(
  (_orig_mod): CalcModel(
    (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(4, 4))
    (relu1): ReLU()
    (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (relu2): ReLU()
    (conv3): Conv2d(128, 4, kernel_size=(3, 3), stride=(1, 1))
    (relu3): ReLU()
    (pool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    (lrn1): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
    (lrn2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
  )
)


# Run Models

### Normal Model

In [20]:
def run_model(dataloader, model):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        dataloader: Iterable of input tensors (e.g. a ``DataLoader``).
        model: Callable mapping a batch tensor to a tensor whose first
            dimension is the batch. The caller is responsible for eval mode.

    Returns:
        All outputs concatenated along dimension 0, in iteration order,
        computed without autograd tracking.
    """
    with torch.no_grad():
        features = [model(batch) for batch in dataloader]
    # `dim` is torch's native keyword (`axis` is only a NumPy-compat alias).
    return torch.cat(features, dim=0)

In [21]:
calc.eval()

# Eager FP32 descriptors: one 936-dim row per image, database then query.
db_features = run_model(db_dataloader, calc)
print(db_features.shape)

q_features = run_model(q_dataloader, calc)
print(q_features.shape)

torch.Size([385, 936])
torch.Size([385, 936])


### ONNX Model

In [22]:
def run_onnx_model(dataloader, ort_session, input_name):
    """Run an ONNX Runtime session over every batch and return a torch tensor.

    Each batch is converted to a NumPy array and fed under ``input_name``.
    Only the session's first output is kept. The per-batch outputs are
    concatenated along axis 0, in iteration order.
    """
    outputs = [
        ort_session.run(None, {input_name: batch.detach().cpu().numpy()})[0]
        for batch in dataloader
    ]
    return torch.from_numpy(np.concatenate(outputs, axis=0))

In [23]:
# Check if model is a valid ONNX model
onnx_model = onnx.load("calc_model.onnx")
onnx.checker.check_model(onnx_model)

# Load the ONNX model
ort_session = onnxruntime.InferenceSession("calc_model.onnx")

input_name = ort_session.get_inputs()[0].name

In [24]:
# Process database images
db_features_onnx = run_onnx_model(db_dataloader, ort_session, input_name)
print(db_features_onnx.shape)

# Process query images
q_features_onnx = run_onnx_model(q_dataloader, ort_session, input_name)
print(q_features_onnx.shape)

torch.Size([385, 936])
torch.Size([385, 936])


### Dynamic Quantized Model (ONNX)

In [25]:
# Check if model is a valid ONNX model
onnx_model_quant_dynamic = onnx.load("calc_model_quant_dynamic.onnx")
onnx.checker.check_model(onnx_model_quant_dynamic)

# Load the ONNX model
ort_session_quant_dynamic = onnxruntime.InferenceSession("calc_model_quant_dynamic.onnx")

input_name = ort_session_quant_dynamic.get_inputs()[0].name

In [26]:
# Process database images
db_features_quant_dynamic = run_onnx_model(db_dataloader, ort_session_quant_dynamic, input_name)
print(db_features_quant_dynamic.shape)

# Process query images
q_features_quant_dynamic = run_onnx_model(q_dataloader, ort_session_quant_dynamic, input_name)
print(q_features_quant_dynamic.shape)

torch.Size([385, 936])
torch.Size([385, 936])


### Static Quantized Model (ONNX)

In [27]:
# Check if model is a valid ONNX model
onnx_model_quant_static = onnx.load("calc_model_quant_static.onnx")
onnx.checker.check_model(onnx_model_quant_static)

# Load the ONNX model
ort_session_quant_static = onnxruntime.InferenceSession("calc_model_quant_static.onnx")

input_name = ort_session_quant_static.get_inputs()[0].name

In [28]:
# Process database images
db_features_quant_static = run_onnx_model(db_dataloader, ort_session_quant_static, input_name)
print(db_features_quant_static.shape)

# Process query images
q_features_quant_static = run_onnx_model(q_dataloader, ort_session_quant_static, input_name)
print(q_features_quant_static.shape)

torch.Size([385, 936])
torch.Size([385, 936])


### Quantization (Quanto)

In [29]:
calc_quanto.eval()

# Run the quanto-quantized model, not `calc`, so these features actually
# reflect int8 quantization.
db_features_quanto = run_model(db_dataloader, calc_quanto)
print(db_features_quanto.shape)

q_features_quanto = run_model(q_dataloader, calc_quanto)
print(q_features_quanto.shape)

torch.Size([385, 936])
torch.Size([385, 936])


### Torch Compile

In [30]:
calc_compiled.eval()

# Run the compiled wrapper, not `calc`, so these features (and the first-call
# compilation) actually go through torch.compile.
db_features_torch_comp = run_model(db_dataloader, calc_compiled)
print(db_features_torch_comp.shape)

q_features_torch_comp = run_model(q_dataloader, calc_compiled)
print(q_features_torch_comp.shape)

torch.Size([385, 936])
torch.Size([385, 936])


# Average Time

In [31]:
def measure_time(dataloader, model, iterations, desc, warmup=0):
    """Average wall-clock seconds of one full ``run_model`` pass over ``dataloader``.

    The measured time includes iterating the DataLoader (image loading and
    preprocessing), not just model inference.

    Args:
        dataloader: Iterable of input batches, passed to ``run_model``.
        model: Callable model (eager, compiled or quantized).
        iterations: Number of timed passes to average; must be >= 1.
        desc: Label for the tqdm progress bar.
        warmup: Untimed passes run first, e.g. to exclude torch.compile's
            first-call compilation. Defaults to 0, meaning every pass is timed.

    Returns:
        Mean duration of the timed passes, in seconds.

    Raises:
        ValueError: If ``iterations`` < 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    for _ in range(warmup):
        run_model(dataloader, model)

    times = []
    for _ in tqdm(range(iterations), desc=desc):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        start_time = time.perf_counter()
        run_model(dataloader, model)
        times.append(time.perf_counter() - start_time)

    return sum(times) / len(times)

In [32]:
def measure_time_onnx(dataloader, ort_session, input_name, iterations, desc, warmup=0):
    """Average wall-clock seconds of one full ``run_onnx_model`` pass over ``dataloader``.

    The measured time includes iterating the DataLoader (image loading and
    preprocessing), not just ONNX Runtime inference.

    Args:
        dataloader: Iterable of input batches.
        ort_session: ONNX Runtime ``InferenceSession`` to benchmark.
        input_name: Graph input name for ``ort_session``.
        iterations: Number of timed passes to average; must be >= 1.
        desc: Label for the tqdm progress bar.
        warmup: Untimed passes run first. Defaults to 0, meaning every pass is timed.

    Returns:
        Mean duration of the timed passes, in seconds.

    Raises:
        ValueError: If ``iterations`` < 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    for _ in range(warmup):
        run_onnx_model(dataloader, ort_session, input_name)

    times = []
    for _ in tqdm(range(iterations), desc=desc):
        # Monotonic, high-resolution clock intended for benchmarking.
        start_time = time.perf_counter()
        run_onnx_model(dataloader, ort_session, input_name)
        times.append(time.perf_counter() - start_time)

    return sum(times) / len(times)

### Normal Model

In [33]:
# Baseline: eager FP32 PyTorch model.
db_avg_time = measure_time(db_dataloader, calc, ITERATIONS, desc="Processing database dataset")
print(f"Database Average Time: {db_avg_time}")

q_avg_time = measure_time(q_dataloader, calc, ITERATIONS, desc="Processing query dataset")
print(f"Query Average Time: {q_avg_time}")

Processing database dataset: 100%|██████████| 100/100 [01:24<00:00,  1.18it/s]


Database Average Time: 0.845102665424347


Processing query dataset: 100%|██████████| 100/100 [01:25<00:00,  1.18it/s]

Query Average Time: 0.8495185351371766





### Dynamic Quantized Model (ONNX)

In [34]:
# Use this session's own input name. The global `input_name` holds whatever the
# most recently executed cell assigned (in linear order, the static session's).
input_name_quant_dynamic = ort_session_quant_dynamic.get_inputs()[0].name

db_avg_time_quant_dynamic = measure_time_onnx(db_dataloader, ort_session_quant_dynamic, input_name_quant_dynamic, ITERATIONS, "Processing database dataset")
print(f"Database Average Time: {db_avg_time_quant_dynamic}")

q_avg_time_quant_dynamic = measure_time_onnx(q_dataloader, ort_session_quant_dynamic, input_name_quant_dynamic, ITERATIONS, "Processing query dataset")
print(f"Query Average Time: {q_avg_time_quant_dynamic}")

Processing database dataset: 100%|██████████| 100/100 [01:15<00:00,  1.33it/s]


Database Average Time: 0.7500852632522583


Processing query dataset: 100%|██████████| 100/100 [01:16<00:00,  1.30it/s]

Query Average Time: 0.76621089220047





### Static Quantized Model (ONNX)

In [35]:
# Look the input name up from this session explicitly, instead of relying on
# the global `input_name` left behind by whichever cell ran last.
input_name_quant_static = ort_session_quant_static.get_inputs()[0].name

db_avg_time_quant_static = measure_time_onnx(db_dataloader, ort_session_quant_static, input_name_quant_static, ITERATIONS, "Processing database dataset")
print(f"Database Average Time: {db_avg_time_quant_static}")

q_avg_time_quant_static = measure_time_onnx(q_dataloader, ort_session_quant_static, input_name_quant_static, ITERATIONS, "Processing query dataset")
print(f"Query Average Time: {q_avg_time_quant_static}")

Processing database dataset: 100%|██████████| 100/100 [01:14<00:00,  1.33it/s]


Database Average Time: 0.7493634605407715


Processing query dataset: 100%|██████████| 100/100 [01:14<00:00,  1.35it/s]

Query Average Time: 0.741964647769928





### Quantization (Quanto)

In [36]:
# quanto int8 model.
db_avg_time_quanto = measure_time(db_dataloader, calc_quanto, ITERATIONS, desc="Processing database dataset")
print(f"Database Average Time: {db_avg_time_quanto}")

q_avg_time_quanto = measure_time(q_dataloader, calc_quanto, ITERATIONS, desc="Processing query dataset")
print(f"Query Average Time: {q_avg_time_quanto}")

Processing database dataset: 100%|██████████| 100/100 [01:50<00:00,  1.10s/it]


Database Average Time: 1.1022711992263794


Processing query dataset: 100%|██████████| 100/100 [01:49<00:00,  1.10s/it]

Query Average Time: 1.0992321348190308





### Torch Compile

In [37]:
# torch.compile wrapper. If it has not run yet, the first timed pass includes
# compilation.
db_avg_time_comp = measure_time(db_dataloader, calc_compiled, ITERATIONS, desc="Processing database dataset")
print(f"Database Average Time: {db_avg_time_comp}")

q_avg_time_comp = measure_time(q_dataloader, calc_compiled, ITERATIONS, desc="Processing query dataset")
print(f"Query Average Time: {q_avg_time_comp}")

Processing database dataset: 100%|██████████| 100/100 [01:13<00:00,  1.37it/s]


Database Average Time: 0.7305069255828858


Processing query dataset: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]

Query Average Time: 0.6960044693946839



