In [None]:
import time
import warnings

import numpy as np
import onnxruntime as ort
import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from transformers import ViTFeatureExtractor, ViTForImageClassification

In [None]:
# Number of images taken from the head of the streamed ImageNet-1k validation
# split; both the PyTorch and the ONNX evaluations run on this same subset.
num_samples: int = 256

In [None]:
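# Load the evaluation data: stream the first `num_samples` ImageNet-1k validation images and preprocess them for ViT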

# Pull the checkpoint's own normalisation statistics so our inputs are scaled
# the same way the model expects.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor; consider switching once the import is updated.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")


def _to_rgb(img):
    """Convert an image to 3-channel RGB (presumably because some source images are not RGB)."""
    return img.convert("RGB")


_preprocess_steps = [
    transforms.Lambda(_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
]
preprocess = transforms.Compose(_preprocess_steps)

def imageTransform(example):
    """Replace a record's raw ``"image"`` with the output of ``preprocess``.

    The record dict is updated in place and also returned, so the function can
    be passed straight to ``map``.
    """
    transformed = preprocess(example["image"])
    example["image"] = transformed
    return example
datasetStream = load_dataset(
    "imagenet-1k",
    split="validation",
    streaming=True,
    # NOTE(review): trust_remote_code=True executes the dataset's loading code
    # from the Hub; keep it only if that repository is trusted.
    trust_remote_code=True,
)
iterable_dataset = iter(datasetStream)

# Fetch the first `num_samples` records, then preprocess them.
selected_samples = []
for _ in range(num_samples):
    selected_samples.append(next(iterable_dataset))
selected_samples = [imageTransform(sample) for sample in selected_samples]

In [None]:
# Original model metrics

def evaluate_torch(model, selected_samples, device, warmup=0):
    """Measure top-1 accuracy and mean per-image latency of a PyTorch classifier.

    Every sample is run on its own (batch size 1), so the latency reflects
    single-image inference. Only the forward pass is timed; moving the image
    to ``device`` happens outside the timed region.

    Args:
        model: module whose forward returns an object with a ``logits``
            attribute of shape (1, num_classes). It is switched to eval mode.
        selected_samples: iterable of dicts holding an ``"image"`` tensor
            (unbatched) and an integer ``"label"``.
        device: ``torch.device`` or device string to run inference on.
        warmup: number of untimed forward passes on the first image before
            measurement, to keep one-time setup cost out of the latency.
            Defaults to 0 (no warm-up).

    Returns:
        tuple: ``(accuracy, avg_latency)`` - fraction of correct predictions
        and mean forward-pass latency in seconds.

    Raises:
        ValueError: if ``selected_samples`` is empty.
    """
    samples = list(selected_samples)
    if not samples:
        raise ValueError("selected_samples must contain at least one sample")

    device = torch.device(device)

    def _sync():
        # CUDA work is launched asynchronously; wait for it so the timer covers
        # the actual forward pass rather than just the kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    model.eval()
    correct = 0
    latencies = []
    with torch.no_grad():
        if warmup > 0:
            first_image = samples[0]["image"].unsqueeze(0).to(device)
            for _ in range(warmup):
                model(first_image)

        for example in samples:
            image = example["image"].unsqueeze(0).to(device)

            _sync()
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            output = model(image)
            _sync()
            latencies.append(time.perf_counter() - start_time)

            pred = output.logits.argmax(dim=1).item()
            correct += int(pred == example["label"])

    accuracy = correct / len(samples)
    avg_latency = float(np.mean(latencies))
    return accuracy, avg_latency

# Baseline: the unmodified Hugging Face checkpoint evaluated with PyTorch on CPU.
device = torch.device("cpu")
model = (
    ViTForImageClassification
    .from_pretrained("google/vit-base-patch16-224")
    .to(device)
)
accuracy, avg_latency = evaluate_torch(model, selected_samples, device)

print(f"Original Model Accuracy: {accuracy * 100:.2f}%")
print(f"Original Model Average Latency Per Image: {avg_latency * 1000:.2f} ms")

In [None]:
# Quantized model metrics

def evaluate_onnx(session, selected_samples, warmup=0):
    """Measure top-1 accuracy and mean per-image latency of an ONNX Runtime session.

    Each sample is fed individually (batch size 1) to the session's first
    input, and the session's first output is read. Only ``session.run`` is
    timed.

    Args:
        session: ONNX Runtime ``InferenceSession`` (or any object exposing
            ``get_inputs``, ``get_outputs`` and ``run`` with the same API).
            Its first output is assumed to be logits of shape
            (1, num_classes).
        selected_samples: iterable of dicts holding an unbatched ``"image"``
            (torch tensor or NumPy array) and an integer ``"label"``.
        warmup: number of untimed ``run`` calls on the first image before
            measurement, to keep one-time setup cost out of the latency.
            Defaults to 0 (no warm-up).

    Returns:
        tuple: ``(accuracy, avg_latency)`` - fraction of correct predictions
        and mean ``run`` latency in seconds, as plain Python floats.

    Raises:
        ValueError: if ``selected_samples`` is empty.
    """
    samples = list(selected_samples)
    if not samples:
        raise ValueError("selected_samples must contain at least one sample")

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    def _as_batch(image):
        # Convert explicitly to NumPy (the input may be a torch tensor) and add
        # the leading batch dimension.
        return np.expand_dims(np.asarray(image), axis=0)

    if warmup > 0:
        warm_feed = {input_name: _as_batch(samples[0]["image"])}
        for _ in range(warmup):
            session.run([output_name], warm_feed)

    correct = 0
    latencies = []
    for example in samples:
        feed = {input_name: _as_batch(example["image"])}

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        output = session.run([output_name], feed)[0]
        latencies.append(time.perf_counter() - start_time)

        pred = int(np.argmax(output, axis=1)[0])
        correct += int(pred == example["label"])

    accuracy = correct / len(samples)
    avg_latency = float(np.mean(latencies))
    return accuracy, avg_latency

model_file_path = "./model/model.onnx"
options = ort.SessionOptions()
session = ort.InferenceSession(model_file_path, sess_options=options,
                               providers=["QNNExecutionProvider"],
                               provider_options=[{"backend_path": "QnnHtp.dll"}])

# If the QNN execution provider is not available in this onnxruntime build,
# session creation only emits a warning and the model runs on the CPU provider.
# Surface that here, next to the results; otherwise the "quantized" latency
# below could silently be a CPU measurement.
active_providers = session.get_providers()
if "QNNExecutionProvider" not in active_providers:
    warnings.warn(
        f"QNNExecutionProvider is not active (active providers: {active_providers}); "
        "the quantized-model numbers below were measured on a fallback provider.",
        RuntimeWarning,
    )

accuracy, avg_latency = evaluate_onnx(session, selected_samples)

print(f"Quantized Model Accuracy: {accuracy * 100:.2f}%")
print(f"Quantized Model Average Latency Per Image: {avg_latency * 1000:.2f} ms")