In [None]:
import torchvision.transforms as transforms
import onnxruntime as ort
import torch
from PIL import Image
import numpy as np
import pandas as pd
import glob
import os
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

In [None]:
# Diagnostic labels, indexed by model output position: the argmax of a model's
# output vector is looked up in this list, so the order here must match the
# label order the models were exported with.
# NOTE(review): the eighth entry used to be "BEL". That was the only entry
# breaking the otherwise alphabetical order, so it is assumed to be a typo for
# "MEL" -- confirm against the training label list and the "Class" column of
# data.csv.
CLASS_NAMES = [
    "AKIEC",
    "BCC",
    "BEN_OTH",
    "BKL",
    "DF",
    "INF",
    "MAL_OTH",
    "MEL",
    "NV",
    "SCCKA",
    "VASC",
]

# Directory that image file names passed to ONNXInference are resolved against.
base_img_dir = "../../dataset"


class ONNXInference:
    """Run a two-input ONNX model (image + demographics) on single samples.

    The first session input receives the image tensor and the second receives
    a ``(1, 3)`` float32 vector ``[age, gender, location]``.
    """

    def __init__(self, model_path, image_size=(512, 512)):
        """Open an ONNX Runtime session for *model_path*.

        Args:
            model_path: Path to the ``.onnx`` file.
            image_size: ``(width, height)`` every image is resized to before
                inference. Defaults to the 512x512 the pipeline has always
                used.
        """
        self.session = ort.InferenceSession(model_path)
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.image_size = tuple(image_size)

    @staticmethod
    def _encode_gender(gender):
        """Return 1.0 for the string "m" (any case), 0.0 for anything else.

        Non-string values (``None``, or the float NaN pandas yields for a
        missing cell) fall into the 0.0 branch instead of raising
        ``AttributeError`` on ``.lower()``.
        """
        if isinstance(gender, str) and gender.lower() == "m":
            return 1.0
        return 0.0

    def preprocess_image(self, image_path):
        """Load an image and return it as a float32 array of shape (1, 3, H, W).

        Args:
            image_path: File name relative to ``base_img_dir``.

        Returns:
            The RGB image resized to ``self.image_size``, scaled to [0, 1],
            in batch-channel-height-width layout.
        """
        full_path = os.path.join(base_img_dir, image_path)
        # The context manager releases the file handle as soon as the pixels
        # have been decoded, rather than whenever the Image is collected.
        with Image.open(full_path) as raw:
            img = raw.convert("RGB").resize(self.image_size)
        img_array = np.array(img, dtype=np.float32) / 255.0  # scale to [0,1]
        img_array = np.transpose(img_array, (2, 0, 1))  # HWC -> CHW
        return np.expand_dims(img_array, axis=0)  # add batch dimension

    def predict(self, image_path, age, gender, location):
        """Run inference on a single image with demographic data.

        Args:
            image_path: File name relative to ``base_img_dir``.
            age: Value convertible with ``float()``.
            gender: "m"/"M" is encoded as 1.0; every other value as 0.0.
            location: Numeric location code convertible with ``float()``.

        Returns:
            A tuple ``(probs, pred, name)``: the flattened first model output,
            the index of its largest entry, and the matching ``CLASS_NAMES``
            entry.
            NOTE(review): whether the first output is already a probability
            distribution depends on how the model was exported -- confirm.
        """
        image_tensor = self.preprocess_image(image_path)
        demo_tensor = np.array(
            [[float(age), self._encode_gender(gender), float(location)]],
            dtype=np.float32,
        )
        inputs = {self.input_names[0]: image_tensor, self.input_names[1]: demo_tensor}
        outputs = self.session.run(None, inputs)
        probs = outputs[0].flatten()
        pred = np.argmax(probs)
        return probs, pred, CLASS_NAMES[pred]


# Load the evaluation table. The loop below reads the columns "NewFileName",
# "Age", "Gender", "Location" and "Class" from it.
df = pd.read_csv(os.path.join(base_img_dir, "data.csv"))

# Replace the location values with integer codes so they can be fed to the
# model as floats.
# NOTE(review): the encoder is fitted on this table, so the codes only agree
# with the ones used at training time if both tables contain the same set of
# location values -- confirm.
le = LabelEncoder()
df["Location"] = le.fit_transform(df["Location"])

# Find all ONNX models below the run folder. Sorted so that the evaluation
# order (and the key order of the result dicts) is the same on every run.
model_folder = "../../models/combine/2025-11-27"
model_paths = sorted(
    glob.glob(os.path.join(model_folder, "**", "*.onnx"), recursive=True)
)

# Additional hand-picked models from other runs.
model_paths.append("../../models/2025-11-27/speechmaster/18_model118.onnx")
model_paths.append("../../models/2025-11-27/speechmaster/62_model94.onnx")
model_paths.append("../../models/2025-11-27/grose/61_model08.onnx")

print("Found models:", model_paths)

# results:           model path -> {class name: number of correct predictions}
# prediction_result: model path -> one {"probs", "pred", "cls"} dict per row
results = {}
prediction_result = {}

for model_path in model_paths:
    onnx_model = ONNXInference(model_path=model_path)
    model_result = {cls: 0 for cls in CLASS_NAMES}

    res_list = []
    # Inner loop wrapped with tqdm for progress
    for _, row in tqdm(
        df.iterrows(),
        total=len(df),
        desc=f"Processing {os.path.basename(model_path)}",
    ):
        probs, pred, cls = onnx_model.predict(
            row["NewFileName"], row["Age"], row["Gender"], row["Location"]
        )
        cell = {"probs": probs, "pred": pred, "cls": cls}
        # Exact match only: a substring test would count a partial or empty
        # label as a hit and then fail on the dictionary lookup.
        if row["Class"] == cls:
            model_result[cls] += 1
        res_list.append(cell)

    results[model_path] = model_result
    prediction_result[model_path] = res_list