<a href="https://colab.research.google.com/github/juanserrano90/codelatam/blob/main/Training/Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the project repository; later cells read images and split files from /content/codelatam/Data.
!git clone https://github.com/juanserrano90/codelatam.git

Cloning into 'codelatam'...
remote: Enumerating objects: 75776, done.[K
remote: Counting objects: 100% (3959/3959), done.[K
remote: Compressing objects: 100% (3948/3948), done.[K
remote: Total 75776 (delta 25), reused 3934 (delta 11), pack-reused 71817 (from 2)[K
Receiving objects: 100% (75776/75776), 696.42 MiB | 19.18 MiB/s, done.
Resolving deltas: 100% (1285/1285), done.
Updating files: 100% (90963/90963), done.


In [None]:
# Mount Google Drive at /content/drive; `working_dir` (defined below) points inside this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import glob
import numpy as np
import pandas as pd
import copy
import json
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle
from pathlib import Path
from collections import defaultdict
import random
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from torchvision.io import read_image, ImageReadMode
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel
from transformers import AutoImageProcessor, Swinv2Model
from transformers import DINOv3ViTModel
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from huggingface_hub import notebook_login



In [None]:
# Global definitions ---------------
# Locations: data shipped with the cloned repository, and the Drive folder used for run files.
data_dir = "/content/codelatam/Data"
dataset_folder = 'Dataset_augmented_images'
working_dir = "/content/drive/MyDrive/Doctorado/Codelatam/Files_codelatam"

# Class index -> class name, listed in index order.
inv_dict_mapping_classes = dict(enumerate(('Ia-norm', 'Ia-pec', 'Others')))
num_classes = len(inv_dict_mapping_classes)

def subtype_to_class_mapping(a):
  """Return the class index for subtype id `a`.

  Subtype 0 maps to class 0, subtypes 1-5 to class 1 and subtypes 6-16 to
  class 2. Any other id raises KeyError.
  """
  class_of_subtype = {0: 0}
  class_of_subtype.update(dict.fromkeys(range(1, 6), 1))
  class_of_subtype.update(dict.fromkeys(range(6, 17), 2))
  return class_of_subtype[a]

def id_to_subtype_mapping(a):
  """Return the subtype name for subtype id `a` (0-16).

  Raises:
    KeyError: if `a` is not one of the ids in the table.
  """
  id_to_subtype = {
      0: 'Ia-norm', 1: 'Ia-91T', 2: 'Ia-91bg', 3: 'Ia-csm', 4: 'Iax', 5: 'Ia-pec',
      6: 'Ib-norm', 7: 'Ibn', 8: 'IIb', 9: 'Ib-pec',
      10: 'Ic-norm', 11: 'Ic-broad', 12: 'Ic-pec',
      13: 'IIP', 14: 'IIL', 15: 'IIn', 16: 'II-pec',
  }
  # Index the table: calling the dict (`id_to_subtype(a)`) raised TypeError on every call.
  return id_to_subtype[a]

# Dataset folder names: the base version first, then every size variant
# followed by its "_n"-suffixed counterpart.
dataset_versions = ['augmented_images_v2.0'] + [
    f'augmented_images_v2.0_{size}{suffix}'
    for size in ('20x20', '224x112', '224x224', '224x56', '50x50')
    for suffix in ('', '_n')
    ]

In [None]:
def load_split(n):
  """Unpickle and return the split object stored in `{data_dir}/Splits/tvt_split{n}.pkl`."""
  split_path = f"{data_dir}/Splits/tvt_split{n}.pkl"
  with open(split_path, 'rb') as handle:
    return pickle.load(handle)

def extract_data_from_splits(splits, root_dir):
    """Expand a split mapping into per-partition lists of samples.

    Args:
        splits: mapping ``{subfolder: {split: [image_name, ...]}}`` where every
            ``split`` key is ``'train'``, ``'val'`` or ``'test'``.
        root_dir: path of the dataset folder that contains the subfolders.

    Returns:
        ``(train, val, test)``; each is a list of
        ``(image_path, image_id, label)`` tuples. ``label`` is
        ``subtype_to_class_mapping`` applied to the second ``_``-separated
        field of the file name, and ``image_id`` is the file name without its
        last four characters.

    Raises:
        KeyError: for a split key other than train/val/test, or a subtype id
            that ``subtype_to_class_mapping`` does not know.
    """
    data = {'train': [], 'val': [], 'test': []}

    # dfdw datasets have different naming: folders whose name has more than
    # three "_"-separated parts store files with an extra "_m" marker.
    # The check depends only on root_dir, so do it once; stripping a trailing
    # "/" keeps ".../name/" from being read as an empty folder name.
    dataset_name = root_dir.rstrip('/').split('/')[-1]
    uses_m_marker = len(dataset_name.split('_')) > 3

    for subfolder, split_dict in splits.items():
        for split, image_list in split_dict.items():
            for image_name in image_list:
                if uses_m_marker:
                    if 'COPY' in image_name:
                        # Assumes the name ends in "COPY.png" (8 characters):
                        # "<stem>_COPY.png" -> "<stem>_m_COPY.png".
                        image_name = image_name[:-8] + "m_" + "COPY" + ".png"
                    else:
                        # "<stem>.png" -> "<stem>_m.png"
                        image_name = image_name[:-4] + "_m" + ".png"
                image_path = os.path.join(root_dir, subfolder, image_name)
                label = subtype_to_class_mapping(int(image_name.split('_')[1]))
                image_id = image_name[:-4]
                data[split].append((image_path, image_id, label))

    return data['train'], data['val'], data['test']

class CustomDataset(Dataset):
    """Dataset over (image_path, id, label) records whose images are run through `processor`."""

    def __init__(self, data, processor):
        """
        data: list of tuples (image_path, id, label)
        processor: callable invoked as processor(images=..., return_tensors="pt")
        """
        self.image_paths = [record[0] for record in data]
        self.ids = [record[1] for record in data]
        self.labels = [record[2] for record in data]
        self.processor = processor

    def __len__(self):
        """Number of records in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Read image `idx` as RGB, preprocess it and return it with its id and label."""
        rgb_image = read_image(self.image_paths[idx], mode=ImageReadMode.RGB)
        encoded = self.processor(images=rgb_image, return_tensors="pt")
        return {
            'id': self.ids[idx],
            # squeeze(0) removes the leading dimension when it has size 1
            'pixel_values': encoded['pixel_values'].squeeze(0),
            'y_true': torch.tensor(self.labels[idx]).long(),
        }

def load_datasets(splits, dataset_name, processor):
  """Build the train, val and test CustomDataset objects for one dataset folder.

  The folder is `data_dir/dataset_folder/dataset_name`; all three datasets
  share the same `processor`.
  """
  dataset_root = os.path.join(data_dir, dataset_folder, dataset_name)
  partitions = extract_data_from_splits(splits, dataset_root)

  # Partitions arrive in (train, val, test) order.
  train_dataset, val_dataset, test_dataset = (
      CustomDataset(records, processor) for records in partitions
  )

  print(f'Loading {dataset_name}...')

  return train_dataset, val_dataset, test_dataset

def show_example_image(dataset, n, processor=None):
  """Display sample `n` of `dataset` after undoing the processor's normalisation.

  Args:
    dataset: dataset whose items are dicts holding a 'pixel_values' tensor
      laid out as (C, H, W) with three channels.
    n: index of the sample to show.
    processor: object exposing `image_mean` and `image_std`. Defaults to
      `dataset.processor`, i.e. the processor that produced the pixel values.

  Raises:
    ValueError: if no processor is passed and the dataset does not carry one.
  """
  # The original body read a notebook-level `processor`, which is not defined
  # at that level here (it is a local of run_inference), so take it from the
  # dataset unless the caller supplies one.
  if processor is None:
    processor = getattr(dataset, 'processor', None)
  if processor is None:
    raise ValueError("No processor given and the dataset has no `processor` attribute.")

  image_data = dataset[n]
  # Denormalize the image using processor's mean and std
  # pixel_values are (C, H, W)
  mean = torch.tensor(processor.image_mean).view(3, 1, 1)
  std = torch.tensor(processor.image_std).view(3, 1, 1)
  denormalized_image = (image_data['pixel_values'] * std) + mean
  # Keep values inside [0, 1] so the uint8 cast below cannot wrap around.
  denormalized_image = denormalized_image.clamp(0.0, 1.0)
  # Convert to numpy array, scale to 0-255, and change to uint8
  image_np = (denormalized_image.permute(1, 2, 0).numpy() * 255).astype('uint8')
  img = Image.fromarray(image_np)
  display(img)

def show_model_architecture(model):
  """Return the torchinfo summary of `model`, installing torchinfo only if it is missing.

  The summary object is returned so that a notebook cell ending in this call
  renders it. torchinfo is documented to default to quiet output inside
  notebooks, so a discarded result showed nothing.
  """
  try:
    from torchinfo import summary
  except ImportError:
    # Install into the interpreter that is running this kernel, once.
    import importlib
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "torchinfo"])
    importlib.invalidate_caches()
    from torchinfo import summary
  return summary(model)

def load_model_and_classifier(pt_model_name, dropout, head_n):
  """Load a pretrained backbone and build a fresh classification head for it.

  Args:
    pt_model_name: Hugging Face model id. The backbone class is chosen by
      substring: 'google' -> ViTModel, 'microsoft' -> Swinv2Model,
      'facebook' -> DINOv3ViTModel.
    dropout: dropout probability used inside the 2- and 3-layer heads.
    head_n: number of linear layers in the head (1, 2 or 3).

  Returns:
    (model, classifier), both moved to the notebook-level `device`. The head
    maps `model.config.hidden_size` features to `num_classes` logits.

  Raises:
    ValueError: if `head_n` is not 1, 2 or 3, or `pt_model_name` matches none
      of the supported families. Previously these cases fell through and
      failed with UnboundLocalError.
  """
  # Validate before loading anything, so a typo fails before weights are fetched.
  if head_n not in (1, 2, 3):
    raise ValueError(f"Unsupported head_n={head_n!r}; expected 1, 2 or 3.")

  if 'google' in pt_model_name:
    model = ViTModel.from_pretrained(pt_model_name).to(device)
  elif 'microsoft' in pt_model_name:
    model = Swinv2Model.from_pretrained(pt_model_name).to(device)
  elif 'facebook' in pt_model_name:
    model = DINOv3ViTModel.from_pretrained(pt_model_name).to(device)
  else:
    raise ValueError(
        f"Unsupported pt_model_name={pt_model_name!r}; "
        "expected a name containing 'google', 'microsoft' or 'facebook'."
    )

  hidden_size = model.config.hidden_size

  if head_n == 3:
    classifier = nn.Sequential(
      nn.Linear(hidden_size, 512),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(512, 256),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(256, num_classes)
      ).to(device)

  elif head_n == 2:
    classifier = nn.Sequential(
      nn.Linear(hidden_size, 256),
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(256, num_classes)).to(device)

  else:  # head_n == 1: a single linear layer
    classifier = nn.Sequential(
        nn.Linear(hidden_size, num_classes)).to(device)

  return model, classifier

def train_step(batch_data, model, processor, classifier, optimizer, device):
    """Run one optimisation step on a batch.

    `batch_data` supplies 'pixel_values' and 'y_true'. `processor` is accepted
    but not used, because the pixel values arrive already preprocessed.
    Returns `(loss_value, y_pred, y_true)` with the tensors on `device`.
    """
    for module in (model, classifier):
        module.train()

    targets = batch_data['y_true'].to(device)
    pixel_values = batch_data['pixel_values'].to(device)

    # The head consumes the backbone's pooled output
    # (outputs.last_hidden_state[:, 0, :] would be the alternative).
    features = model(pixel_values).pooler_output
    logits = classifier(features)
    loss = F.cross_entropy(logits, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), torch.argmax(logits, dim=-1), targets


@torch.no_grad()
def validate_step(batch_data, model, processor, classifier, device):
    """Score one batch without gradient tracking or parameter updates.

    Puts `model` and `classifier` in eval mode and returns
    `(loss_value, y_pred, y_true)`. `processor` is accepted but not used.
    """
    for module in (model, classifier):
        module.eval()

    targets = batch_data['y_true'].to(device)
    pixel_values = batch_data['pixel_values'].to(device)

    features = model(pixel_values).pooler_output
    logits = classifier(features)
    loss_value = F.cross_entropy(logits, targets).item()

    return loss_value, torch.argmax(logits, dim=-1), targets


def predict_step(batch_data, model, processor, classifier, device):
    """Predict classes and class probabilities for one batch.

    Puts `model` and `classifier` in eval mode and returns a dict with the
    batch 'id' values passed through unchanged, plus CPU tensors 'y_pred',
    'y_pred_prob' (softmax over dim 1) and 'y_true'. `processor` is accepted
    but not used.
    """
    model.eval()
    classifier.eval()

    with torch.no_grad():
        targets = batch_data['y_true'].to(device)
        pixel_values = batch_data['pixel_values'].to(device)

        logits = classifier(model(pixel_values).pooler_output)
        probabilities = F.softmax(logits, dim=1)
        predicted = torch.argmax(logits, dim=-1)

        result = {'id': batch_data['id']}
        result['y_pred'] = predicted.cpu()
        result['y_pred_prob'] = probabilities.cpu()
        result['y_true'] = targets.cpu()
        return result

def evaluate(
    model,
    classifier,
    processor,
    dataloader,
    inv_class_map=None,
    save_dir=None,
    ):
    """Predict over `dataloader`, print a classification report and optionally save it.

    Args:
        model, classifier, processor: forwarded to `predict_step`.
        dataloader: iterable of batches accepted by `predict_step`.
        inv_class_map: optional mapping from class index to class name. When
            given, the returned `y_true` / `y_pred` are lists of names.
        save_dir: optional sub-folder of `{working_dir}/Runs` in which to
            write the report as text. It is created if it does not exist.

    Returns:
        `(y_true, y_pred, y_pred_prob, macro_f1)`. Without `inv_class_map`
        the first two are NumPy arrays of class indices; `y_pred_prob` is
        always a NumPy array.

    Note:
        Predictions run on the notebook-level `device`. When `save_dir` is
        set, the file name is built from a notebook-level `hp` dict with keys
        batch_size, lr, wd, dropout, patience, freeze and head_n. That dict
        must exist before calling with `save_dir`.
    """

    model.eval()
    classifier.eval()

    y_true = []
    y_pred = []
    y_pred_prob = []

    with torch.no_grad():
        for batch in dataloader:
          output = predict_step(batch, model, processor, classifier, device)
          y_true.append(output['y_true'])
          y_pred.append(output['y_pred'])
          y_pred_prob.append(output['y_pred_prob'])

    y_true = torch.cat(y_true).numpy()
    y_pred = torch.cat(y_pred).numpy()
    y_pred_prob = torch.cat(y_pred_prob).numpy()

    if inv_class_map is not None:
        y_true = [inv_class_map[i] for i in y_true]
        y_pred = [inv_class_map[i] for i in y_pred]

    # --- Classification report ---
    report_str = classification_report(y_true, y_pred, digits=4)
    print("\n=== Classification Report ===")
    print(report_str)
    macro_f1 = f1_score(y_true, y_pred, average="macro")

    # --- Save to disk ---
    if save_dir is not None:
      run_dir = f"{working_dir}/Runs/{save_dir}"
      # Create the folder first (as run_inference does); open() would otherwise
      # fail with FileNotFoundError for a new run directory.
      os.makedirs(run_dir, exist_ok=True)
      file_name = f"clfreport_bs_{hp['batch_size']}_lr_{hp['lr']}_wd_{hp['wd']}_dp_{hp['dropout']}_pt_{hp['patience']}_f_{hp['freeze']}_h{hp['head_n']}.txt"
      txt_path = os.path.join(run_dir, file_name)
      with open(txt_path, "w") as f:
        f.write(report_str)

    return y_true, y_pred, y_pred_prob, macro_f1

def load_processor(pt_model_name):
  """Return the image processor that matches a pretrained model id.

  The family is chosen by substring: 'google' -> ViTImageProcessor,
  'microsoft' and 'facebook' -> AutoImageProcessor with `use_fast=True`.
  The 'facebook' branch calls `notebook_login()` first (presumably because
  that checkpoint requires an authenticated Hub session; confirm).

  Raises:
    ValueError: if `pt_model_name` matches none of the supported families.
      Previously this fell through and failed with UnboundLocalError.
  """
  if 'google' in pt_model_name:
    processor = ViTImageProcessor.from_pretrained(pt_model_name)
  elif 'microsoft' in pt_model_name:
    processor = AutoImageProcessor.from_pretrained(pt_model_name, use_fast=True)
  elif 'facebook' in pt_model_name:
    from huggingface_hub import notebook_login
    notebook_login()
    processor = AutoImageProcessor.from_pretrained(pt_model_name, use_fast=True)
  else:
    raise ValueError(
        f"Unsupported pt_model_name={pt_model_name!r}; "
        "expected a name containing 'google', 'microsoft' or 'facebook'."
    )

  return processor

def plot_confusion_matrix(
    y_true,
    y_pred,
    inv_dict_mapping_classes,
    normalize=True,
    figsize=(10, 7),
    save_dir=None,
    dpi=100,
    model_name=None,
):
    """Plot (and optionally save) the confusion matrix of `y_pred` against `y_true`.

    Args:
        y_true, y_pred: true and predicted labels, passed to
            `sklearn.metrics.confusion_matrix`.
        inv_dict_mapping_classes: mapping whose values, in order, are used as
            tick labels on both axes.
        normalize: if True, divide every row by its sum so cells are
            per-true-class fractions; rows that sum to zero stay all-zero.
        figsize: figure size handed to matplotlib.
        save_dir: optional sub-folder of `{working_dir}/Runs`; when given the
            figure is written there as `cfmatrix_<model_name>.png`.
        dpi: resolution of the saved figure.
        model_name: name used in the saved file. When omitted, the
            notebook-level `model_name` variable is used, as before.

    Raises:
        ValueError: if saving is requested and no model name is available.
    """

    cm = confusion_matrix(y_true, y_pred)
    class_names = list(inv_dict_mapping_classes.values())

    if normalize:
        cm = cm.astype(float)
        row_sums = cm.sum(axis=1, keepdims=True)
        # `where=` without `out=` leaves the skipped cells uninitialised;
        # pre-filling with zeros makes empty rows render as 0.00.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    plt.figure(figsize=figsize)

    sns.heatmap(
        cm,
        annot=True,
        fmt=".2f" if normalize else "d",
        cmap="Blues",
        cbar=False,
        xticklabels=class_names,
        yticklabels=class_names
    )

    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    title = "Confusion Matrix (Normalized)" if normalize else "Confusion Matrix"
    plt.title(title)

    plt.tight_layout()

    if save_dir is not None:
      if model_name is None:
        # Backward compatibility: earlier callers relied on the notebook-level variable.
        model_name = globals().get("model_name")
      if model_name is None:
        raise ValueError("model_name is required to name the saved confusion matrix.")
      run_dir = f"{working_dir}/Runs/{save_dir}"
      os.makedirs(run_dir, exist_ok=True)
      file_name = f"cfmatrix_{model_name}.png"
      save_path = os.path.join(run_dir, file_name)
      # Save before plt.show(), which may leave nothing to save afterwards.
      plt.savefig(save_path, dpi=dpi, bbox_inches="tight")

    plt.show()

In [None]:
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Short model name -> Hugging Face id of the pretrained backbone.
pt_model_names = dict(
    ViT32='google/vit-base-patch32-224-in21k',
    ViT16='google/vit-base-patch16-224-in21k',
    Swinv2='microsoft/swinv2-tiny-patch4-window16-256',
    DINOv3='facebook/dinov3-vits16-pretrain-lvd1689m',
)

# Per-model settings of the final checkpoints: batch size, dropout, weight
# decay and the checkpoint file that run_inference loads.
final_models = dict(
    ViT16=dict(bs=64, dropout=0.1, wd=1e-4,
               file_name="model_bs_64_lr_1e-05_wd_0.0001_dp_0.1_pt_10_f_False_h3.pth"),
    ViT32=dict(bs=32, dropout=0.5, wd=1e-4,
               file_name="model_bs_32_lr_1e-05_wd_0.0001_dp_0.5_pt_10_f_False_h3.pth"),
    Swinv2=dict(bs=32, dropout=0.3, wd=1e-4,
                file_name="model_bs_32_lr_1e-05_wd_0.0001_dp_0.3_pt_10_f_False_h3.pth"),
    DINOv3=dict(bs=32, dropout=0.5, wd=1e-5,
                file_name="model_bs_32_lr_1e-05_wd_1e-05_dp_0.5_pt_10_f_False_h3.pth"),
)

def run_inference(model_name, n_split, save_dir):
  """Evaluate one final checkpoint on the test partition of split `n_split`.

  Args:
    model_name: key of `final_models` / `pt_model_names` (e.g. "ViT16").
    n_split: number of the split file to load with `load_split`.
    save_dir: sub-folder of `{working_dir}/Runs` for the report, the pickled
      outputs and the confusion matrix; pass None to save nothing.

  Returns:
    dict with arrays 'id', 'y_pred', 'y_pred_prob', 'y_true' concatenated over
    all test batches, plus lists 'y_true_c' / 'y_pred_c' holding class names.

  Side effects:
    Prints the classification report and shows the confusion matrix.
  """
  model_cfg = final_models[model_name]
  dataset_name = dataset_versions[0]
  pt_model_name = pt_model_names[model_name]
  splits = load_split(n_split)
  processor = load_processor(pt_model_name)

  # Only the test partition is used for inference.
  _, _, test_dataset = load_datasets(
      splits,
      dataset_name,
      processor,
      )

  test_dataloader = DataLoader(
      test_dataset,
      batch_size=model_cfg['bs'],
      )

  model, classifier = load_model_and_classifier(
      pt_model_name,
      model_cfg['dropout'],
      head_n=3
      )

  # NOTE(review): the checkpoint is always read from the "split4" folder,
  # whatever `n_split` is; confirm this is intended before using another split.
  # Single quotes inside the f-string keep it valid on Python < 3.12.
  checkpoint_path = f"{working_dir}/Runs/Phase4/{model_name}/v2.0/split4/{model_cfg['file_name']}"
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
  best_state = torch.load(checkpoint_path, map_location=device)
  model.load_state_dict(best_state['model'])
  classifier.load_state_dict(best_state['classifier'])

  model.eval()
  classifier.eval()

  with torch.no_grad():
      outputs = [
          predict_step(batch_data, model, processor, classifier, device)
          for batch_data in test_dataloader
      ]

  # Merge the per-batch dicts into one array per key.
  outputs_dict = {
      key: np.concatenate([batch_output[key] for batch_output in outputs])
      for key in outputs[0].keys()
  }

  outputs_dict['y_true_c'] = [inv_dict_mapping_classes[val] for val in outputs_dict['y_true']]
  outputs_dict['y_pred_c'] = [inv_dict_mapping_classes[val] for val in outputs_dict['y_pred']]

  # --- Classification report ---
  report_str = classification_report(outputs_dict['y_true_c'], outputs_dict['y_pred_c'], digits=4)
  print("\n=== Classification Report ===")
  print(report_str)

  # --- Save to disk ---
  if save_dir is not None:
    save_path = os.path.join(working_dir, "Runs", save_dir)
    os.makedirs(save_path, exist_ok=True)

    with open(os.path.join(save_path, f"clfreport_{model_name}.txt"), "w") as f:
      f.write(report_str)

    with open(os.path.join(save_path, f"output_dict_{model_name}.pkl"), "wb") as f:
      pickle.dump(outputs_dict, f)

  # NOTE: when called like this, plot_confusion_matrix names its output file
  # from the notebook-level `model_name` variable, not from this function's
  # argument. Keep the two in sync when saving.
  plot_confusion_matrix(
    outputs_dict['y_true'], outputs_dict['y_pred'],
    inv_dict_mapping_classes,
    figsize=(5,5),
    save_dir=save_dir,
    )

  return outputs_dict

In [None]:
# Parameters of a single inference run.
model_name, n_split, save_dir = "DINOv3", 4, "Test"

# Left disabled: the next cell reads the pickled outputs from Runs/Test
# instead of recomputing them.
# outputs_dict = run_inference(model_name, n_split, save_dir)

In [None]:
# Ensemble
# Load the pickled outputs of every final model and of the soft-voting ensemble.
outputs_dicts = {}

for m in final_models.keys():
  with open(f"{working_dir}/Runs/Test/output_dict_"+m+".pkl", 'rb') as f:
    outputs_dicts[m] = pickle.load(f)

with open(f"{working_dir}/Runs/Test/output_dict_ensemble_soft.pkl", 'rb') as f:
  outputs_dicts["Ensemble"] = pickle.load(f)

y_true = np.array(outputs_dicts["ViT16"]["y_true"])

# Every output must describe the same samples in the same order; otherwise
# comparing or combining the models is meaningless.
assert all(
    np.array_equal(np.array(model_outputs["y_true"]), y_true)
    for model_outputs in outputs_dicts.values()
), "y_true differs between the loaded output dicts"

# Each entry is read from its own output dict. "Ensemble" previously copied
# the ViT16 arrays instead of the loaded ensemble outputs.
result_order = ["ViT16", "Swinv2", "DINOv3", "ViT32", "Ensemble"]
preds = {name: np.array(outputs_dicts[name]["y_pred"]) for name in result_order}
probs = {name: np.array(outputs_dicts[name]["y_pred_prob"]) for name in result_order}

In [None]:
from scipy.stats import mode

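# Hard and soft ensembles.
# The commented-out code below computes both votes and writes the soft-voting
# results to Runs/Test/output_dict_ensemble_soft.pkl, the file the ensemble
# cell above loads. It is kept for reference only.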
# pred_matrix = np.stack(list(preds.values()), axis=0).astype(np.int64)
# y_pred_hard, _ = mode(pred_matrix, axis=0, keepdims=False)

# avg_probs = np.mean(np.stack(list(probs.values()), axis=0), axis=0)
# y_pred_soft = np.argmax(avg_probs, axis=1)

# y_pred_soft_c = [inv_dict_mapping_classes[val] for val in y_pred_soft]
# y_pred_soft_c

# print("=== Hard voting (3 models) ===")
# print(classification_report(y_true, y_pred_hard, digits=4))

# print("=== Soft voting (3 models) ===")
# print(classification_report(y_true, y_pred_soft, digits=4))

# soft_output_dict = {"y_true": y_true,
#                     "y_pred": y_pred_soft,
#                     "y_pred_prob": avg_probs,
#                     "y_pred_c": y_pred_soft_c}

# normalize=True
# figsize=(6,6)
#   # --- Classification report ---
# report_str = classification_report(y_true, y_pred_soft, digits=4)
# # --- Save to disk ---
# save_path = os.path.join(working_dir, "Runs", "Test")
# file_name = f"clfreport_ensemble_soft.txt"
# txt_path = os.path.join(f"{working_dir}/Runs/Test", file_name)
# with open(txt_path, "w") as f:
#   f.write(report_str)

# with open(f"{working_dir}/Runs/Test/output_dict_ensemble_soft.pkl", "wb") as f:
#   pickle.dump(soft_output_dict, f)

# cm = confusion_matrix(y_true, y_pred_soft)
# class_names = inv_dict_mapping_classes.values()

# if normalize:
#     cm = cm.astype(float)
#     row_sums = cm.sum(axis=1, keepdims=True)
#     cm = np.divide(cm, row_sums, where=row_sums != 0)

# plt.figure(figsize=figsize)

# sns.heatmap(
#     cm,
#     annot=True,
#     fmt=".2f" if normalize else "d",
#     cmap="Blues",
#     cbar=False,
#     xticklabels=class_names,
#     yticklabels=class_names
# )

# plt.xlabel("Predicted Labels")
# plt.ylabel("True Labels")
# title = "Confusion Matrix (Normalized)" if normalize else "Confusion Matrix"
# plt.title(title)

# plt.tight_layout()


# file_name = f"cfmatrix_ensemble_soft.png"
# save_path = os.path.join(f"{working_dir}/Runs/Test", file_name)
# plt.savefig(save_path, dpi=100, bbox_inches="tight")

# plt.show()

In [None]:
from sklearn.metrics import roc_auc_score

# Weighted one-vs-one and one-vs-rest ROC AUC for every loaded output dict.
print("--------- \t ROC AUC (ovo weighted) \t ROC AUC (ovr weighted)")
for m, model_outputs in outputs_dicts.items():
  r, r2 = (
      roc_auc_score(model_outputs['y_true'], model_outputs['y_pred_prob'],
                    multi_class=scheme, average="weighted")
      for scheme in ('ovo', 'ovr')
  )
  print(f"{m:<10} \t {r:.2f}                       \t {r2:.2f}")

--------- 	 ROC AUC (ovo weighted) 	 ROC AUC (ovr weighted)
ViT16      	 0.95                       	 0.95
ViT32      	 0.94                       	 0.95
Swinv2     	 0.97                       	 0.97
DINOv3     	 0.95                       	 0.95
Ensemble   	 0.97                       	 0.98
