## Installing and importing necessary libraries

### CORN loss function

In [None]:
!pip install coral-pytorch

### GradCAM

In [None]:
!pip install grad-cam

### PyTorch FLOPs counter (ptflops)

In [None]:
!pip install ptflops

### Optuna

In [None]:
!pip install optuna

### Imports

In [None]:
# basic libs
from google.colab import drive
from google.colab import auth
from google.auth import default
import gspread
from gspread_dataframe import set_with_dataframe
import os
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
import time
from PIL import Image
from collections import defaultdict
import random
from datetime import datetime
import pytz
import json
import shutil
import math

# pytorch libs
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader, Subset
from ptflops import get_model_complexity_info
from transformers import (ViTForImageClassification,
                          ViTImageProcessor,
                          DeiTForImageClassification,
                          DeiTForImageClassificationWithTeacher,
                          DeiTImageProcessor,
                          AutoFeatureExtractor,
                          AutoImageProcessor,
                          AutoModelForImageClassification)

# coral pytorch
from coral_pytorch.losses import corn_loss
from coral_pytorch.dataset import corn_label_from_logits

# pytorch grad cam
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# plotting libs
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             cohen_kappa_score, mean_absolute_error, confusion_matrix, classification_report,
                             roc_curve, auc)
from sklearn.preprocessing import label_binarize, OneHotEncoder
import gc as gc2
import timm

import optuna
import joblib

## Mounting Google Drive

In [None]:
# Mount Google Drive so the dataset and result folders under /content/drive are reachable
drive.mount('/content/drive')

In [None]:
root_folder = '/content/drive/MyDrive/pgc'  # project root on the mounted Drive
dataset_prefix = 'preprocessed_dataset'  # dataset folders are named <prefix>_<N>_classes
results_folder = 'trained_models'  # weights, plots and reports live under root_folder/results_folder
grad_cam_folder = 'grad_cam'  # Grad-CAM outputs live under root_folder/grad_cam_folder
nr_classes = 5  # class count of the active experiment (helpers accept 5, 4, 3 or 2)

In [None]:
def get_dataset_dir(nr_classes):
  """Return the preprocessed-dataset folder for an experiment with `nr_classes` classes."""
  return '{}/{}_{}_classes'.format(root_folder, dataset_prefix, nr_classes)

In [None]:
def get_results_dir(nr_classes, model_name):
  """Return the Drive folder holding trained artifacts of `model_name` for `nr_classes` classes."""
  return '{}/{}/{}_classes/{}'.format(root_folder, results_folder, nr_classes, model_name)

In [None]:
def get_grad_cam_dir(nr_classes, criteria):
  """Return the Grad-CAM output folder for `nr_classes` classes and a selection `criteria`."""
  return '{}/{}/{}_classes/{}'.format(root_folder, grad_cam_folder, nr_classes, criteria)

In [None]:
def get_output_dir_gradcam(criteria, model_name):
  """Return the Grad-CAM output folder for `criteria` and `model_name`.

  NOTE(review): unlike get_grad_cam_dir, this path has no `<N>_classes` level;
  confirm whether that is intentional.
  """
  return '{}/{}/{}/{}'.format(root_folder, grad_cam_folder, criteria, model_name)

## Setting random seed for reproducibility

In [None]:
# Seed Python, NumPy and PyTorch (CPU, plus every CUDA device when one is present)
# so sampling, augmentation and weight initialisation are repeatable across runs
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

## Hyperparameters

In [None]:
class Config:
    """Experiment-wide settings read by the data, model and training helpers."""
    data_dir = get_dataset_dir(nr_classes)
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')
    calib_dir = os.path.join(data_dir, 'calib')
    calib_dit = calib_dir  # legacy misspelling, kept so existing references keep working
    num_classes = nr_classes
    max_samples_per_class = 1700  # per-class cap applied to the train split (undersampling for class imbalance)
    batch_size = 28
    num_epochs = 60
    shuffle = True  # NOTE(review): load_dataset does not read this; it shuffles only the train split
    feature_extract = False  # default for init_model: False trains every parameter
    use_pretrained = True
    learning_rate = 0.0001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epsilon = 0.05  # 1 - epsilon = 0.95 target confidence; presumably used with the 'calib' split — confirm

## Data transforms

In [None]:
def get_transforms(model_name, dataset, processor=None):
  """Return the preprocessing pipeline for one dataset split.

  Args:
    model_name: without a processor, 'inception_v3' gets 299x299 inputs and every
      other model 224x224 with ImageNet normalisation. With a processor, names
      containing 'facebook/deit' read the input size from `processor.crop_size`;
      others read it from `processor.size`.
    dataset: 'train' (adds random horizontal flip and +/-10 degree rotation) or one
      of 'val', 'test', 'calib' (resize + normalise only).
    processor: optional Hugging Face image processor providing size and
      normalisation statistics.

  Returns:
    A `transforms.Compose` pipeline.

  Raises:
    KeyError: if `dataset` is not a known split name.
  """
  if not processor:
    image_size = (224, 224) if (model_name != 'inception_v3') else (299, 299)
    normalization_mean = [0.485, 0.456, 0.406]
    normalization_std = [0.229, 0.224, 0.225]
  else:
    size_spec = processor.crop_size if ('facebook/deit' in model_name) else processor.size
    # Resize(int) only rescales the shorter edge, which keeps the aspect ratio and
    # can yield differently shaped tensors per image; ask for an explicit (h, w)
    # so processor-based models get square inputs like the torchvision branch.
    image_size = (size_spec['height'], size_spec['width'])
    normalization_mean = processor.image_mean
    normalization_std = processor.image_std

  normalize = transforms.Normalize(normalization_mean, normalization_std)

  if dataset == 'train':
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        normalize
    ])
  if dataset in ('val', 'test', 'calib'):
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        normalize
    ])
  raise KeyError(dataset)

### Undersampling

In [None]:
def undersample_dataset(dataset, max_samples_per_class=None):
  """Randomly cap the number of training samples per class.

  Args:
    dataset: dict of splits; `dataset['train']` must expose `.imgs` as
      (path, label) pairs, as torchvision's ImageFolder does.
    max_samples_per_class: per-class cap; defaults to Config.max_samples_per_class.

  Returns:
    The same dict, with `dataset['train']` replaced in place by a `Subset`
    holding the kept indices. The other splits are left untouched.
  """
  if max_samples_per_class is None:
    max_samples_per_class = Config.max_samples_per_class

  train_split = dataset['train']

  # organize indices by class
  class_indices = defaultdict(list)
  for idx, (_, label) in enumerate(train_split.imgs):
    class_indices[label].append(idx)

  # sample classes in first-seen order so the global RNG is consumed in a fixed sequence
  kept_indices = []
  for indices in class_indices.values():
    kept_indices.extend(random.sample(indices, min(len(indices), max_samples_per_class)))

  dataset['train'] = Subset(train_split, kept_indices)
  return dataset

## Load the datasets

In [None]:
def load_dataset(model_name, processor=None, batch_size=Config.batch_size):
  """Build ImageFolder datasets and DataLoaders for the train/val/test/calib splits.

  The train split is undersampled via undersample_dataset before its loader is
  created; only the train loader shuffles.

  Returns:
    (dataloaders, dataset_sizes): two dicts keyed by split name.
  """
  splits = ('train', 'val', 'test', 'calib')

  image_folders = {}
  for split in splits:
    split_root = os.path.join(Config.data_dir, split)
    image_folders[split] = datasets.ImageFolder(split_root, get_transforms(model_name, split, processor))

  # applying undersampling
  image_folders = undersample_dataset(image_folders)

  dataloaders = {
      split: DataLoader(
          image_folders[split],
          batch_size=batch_size,
          shuffle=(split == 'train'),
          num_workers=2
      )
      for split in splits
  }
  dataset_sizes = {split: len(image_folders[split]) for split in splits}

  return dataloaders, dataset_sizes

## Early Stopping

In [None]:
class EarlyStopping:
    """Signal a stop once the validation loss has stopped improving for `patience` calls.

    The first call, and every call whose loss is at least `min_delta` below the best
    loss seen so far, saves `model.state_dict()` to
    f'{model_name}_{loss_function}.pth' in the current working directory and resets
    the counter. Any other call increments the counter. `early_stop` becomes True
    once the counter reaches `patience`.
    """

    def __init__(self, patience=5, min_delta=0, verbose=True, loss_function='cross_entropy'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.min_delta = min_delta
        self.val_loss_min = float('inf')
        self.loss_function = loss_function

    def __call__(self, val_loss, model, model_name):
        # The first observation is always a checkpoint. Otherwise the file the caller
        # copies afterwards would be missing (or stale) whenever epoch 1 is the best.
        improved = self.best_loss is None or val_loss <= self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.save_checkpoint(val_loss, model, model_name)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model, model_name):
        """Write the model's state_dict to f'{model_name}_{loss_function}.pth' in the CWD."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.4f} --> {val_loss:.4f}).  Saving model ...')
        model_file = f'{model_name}_{self.loss_function}.pth'
        torch.save(model.state_dict(), model_file)
        self.val_loss_min = val_loss

## Building the models

In [None]:
def get_model(model_name):
  """Instantiate the backbone named `model_name`.

  torchvision models are built with `pretrained=Config.use_pretrained`. Names
  containing 'facebook/deit' load a Hugging Face DeiT checkpoint with
  `Config.num_classes` labels. Names containing 'davit' or 'gcvit' are created
  through timm.

  Raises:
    ValueError: for any other name.
  """
  torchvision_names = ('resnet34', 'resnet50', 'resnet101', 'vgg16', 'vgg19',
                       'densenet121', 'densenet169', 'inception_v3', 'swin_b', 'maxvit_t')

  if model_name in torchvision_names:
    # resolve the constructor lazily so only the requested architecture is looked up
    builder = getattr(models, model_name)
    return builder(pretrained=Config.use_pretrained)

  if 'facebook/deit' in model_name:
    return DeiTForImageClassification.from_pretrained(
        model_name, num_labels=Config.num_classes, ignore_mismatched_sizes=True)

  if 'davit' in model_name or 'gcvit' in model_name:
    return timm.create_model(model_name, pretrained=Config.use_pretrained, num_classes=Config.num_classes)

  raise ValueError(f"Invalid model name: {model_name}")

In [None]:
def set_fc_layer(model, model_name, loss_function):
  """Replace the model's final classification layer with a freshly initialised Linear.

  The new layer has Config.num_classes outputs for 'cross_entropy' and
  Config.num_classes - 1 outputs otherwise (CORN predicts K-1 binary tasks).

  Raises:
    ValueError: if `model_name` matches no known architecture family.
  """
  num_classes = Config.num_classes if loss_function == 'cross_entropy' else Config.num_classes-1

  if 'resnet' in model_name:
    model.fc = nn.Linear(model.fc.in_features, num_classes)
  elif 'vgg' in model_name:
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
  elif 'densenet' in model_name:
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)
  elif 'inception' in model_name:
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # The training loop adds a 0.4-weighted loss on the auxiliary logits, so the
    # auxiliary head must predict the same label space as the main head.
    aux_head = getattr(model, 'AuxLogits', None)
    if aux_head is not None:
      aux_head.fc = nn.Linear(aux_head.fc.in_features, num_classes)
  elif 'swin' in model_name:
    model.head = nn.Linear(model.head.in_features, num_classes)
  elif 'facebook/deit' in model_name:
    model.classifier = nn.Linear(model.config.hidden_size, num_classes)
  elif 'maxvit_t' in model_name:
    model.classifier[5] = nn.Linear(model.classifier[5].in_features, num_classes)
  elif 'davit' in model_name or 'gcvit' in model_name:
    model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
  else:
    raise ValueError(f"Invalid model name: {model_name}")

  return model

In [None]:
def init_model(model_name, loss_function='cross_entropy', feature_extract=Config.feature_extract):
  """Build `model_name` with a new output layer for `loss_function` and move it to Config.device.

  Args:
    model_name: any name accepted by get_model.
    loss_function: forwarded to set_fc_layer.
    feature_extract: when True, freeze every parameter except those inside a
      module named 'fc', 'layer4', 'classifier' or 'head'. When False, train everything.

  Returns:
    The model on Config.device. Each parameter's trainable/frozen status is printed.
  """
  model = get_model(model_name)
  model = set_fc_layer(model, model_name, loss_function)

  # Modules that stay trainable in feature-extraction mode. Names are compared against
  # whole components of the dotted parameter name. A plain substring test on 'fc' also
  # matches layers like 'mlp.fc1' (the usual naming of timm transformer MLPs), and
  # swin_b's replaced classifier ('head.weight') matched none of the old substrings.
  trainable_modules = {'fc', 'layer4', 'classifier', 'head'}

  for name, param in model.named_parameters():
    if feature_extract:
      param.requires_grad = not trainable_modules.isdisjoint(name.split('.'))
    else:
      param.requires_grad = True

  for name, param in model.named_parameters():
    print(f"{name}: {'trainable' if param.requires_grad else 'frozen'}")

  model = model.to(Config.device)

  return model

## Utils

In [None]:
def save_in_google_drive(model_name, file_name):
  """Copy a locally written artifact into the Drive results folder of `model_name`.

  Args:
    model_name: sub-folder under the `<nr_classes>_classes` results directory
      ('' puts the file directly in that directory).
    file_name: path, relative to the working directory, of a .pth, .png or .json
      file. The same name is used at the destination.

  Raises:
    ValueError: for any other file extension.
  """
  if not file_name.endswith(('.pth', '.png', '.json')):
    raise ValueError(f"Invalid file type: {file_name}")

  dest_folder = get_results_dir(nr_classes, model_name)
  os.makedirs(dest_folder, exist_ok=True)
  # os.path.join inserts the separator the old .json branch was missing
  shutil.copy(file_name, os.path.join(dest_folder, file_name))

In [None]:
def plot_and_save_training_curves(model_name, loss_function, train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss and accuracy side by side, save the PNG and copy it to Drive.

    The figure is written to f'{model_name}_curves_{loss_function}.png' in the working
    directory, then handed to save_in_google_drive. Axis labels are in Portuguese.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    panels = (
        (axes[0], train_losses, val_losses,
         'Erro de treinamento', 'Erro de validação', 'Erro de treinamento e validação', 'Erro'),
        (axes[1], train_accs, val_accs,
         'Acurácia de treinamento', 'Acurácia de validação', 'Acurácia de treinamento e validação', 'Acurácia'),
    )
    for ax, train_series, val_series, train_label, val_label, title, y_label in panels:
        ax.plot(train_series, label=train_label)
        ax.plot(val_series, label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Época')
        ax.set_ylabel(y_label)
        ax.legend()

    # the chart is first written to the local (Colab) working directory
    png_name = f'{model_name}_curves_{loss_function}.png'
    plt.savefig(png_name)

    plt.show()
    plt.close()

    save_in_google_drive(model_name, png_name)

In [None]:
def save_results(base_file_name, results):
  """Dump `results` as indented JSON to a Sao Paulo-timestamped file and copy it to Drive."""
  now_local = datetime.now(pytz.timezone('America/Sao_Paulo'))
  json_name = '{}_{}.json'.format(base_file_name, now_local.strftime('%Y-%m-%d_%H-%M-%S'))

  with open(json_name, 'w') as handle:
    json.dump(results, handle, indent=4)

  # an empty model name places the file directly in the <nr_classes>_classes folder
  save_in_google_drive('', json_name)

In [None]:
def get_model_processor(model_name, model_src):
  """Return the Hugging Face image processor for `model_name`, or None for 'pytorch' models.

  The distilled DeiT checkpoint is loaded through AutoFeatureExtractor; any other
  non-pytorch model goes through AutoImageProcessor.
  """
  if model_src != 'pytorch':
    if model_name == 'facebook/deit-base-distilled-patch16-224':
      return AutoFeatureExtractor.from_pretrained(model_name)
    return AutoImageProcessor.from_pretrained(model_name)
  return None

In [None]:
def get_results_from_json_file(file_path='/content/classification_report_2025-06-20_02-40-18.json'):
  """Load a previously saved results JSON, or return {} when the file does not exist.

  Args:
    file_path: report to load; defaults to the report path used so far.

  Returns:
    The decoded JSON object, or an empty dict when `file_path` is missing.
  """
  if not os.path.exists(file_path):
    return {}

  with open(file_path, 'r') as f:
    return json.load(f)

In [None]:
def get_classes_to_exclude():
  """Return the grade indices (0-4) left out of the report names for the active `nr_classes`.

  Raises:
    ValueError: if nr_classes is not 2, 3, 4 or 5.
  """
  excluded_by_class_count = {
      5: [],
      4: [1],
      3: [0, 1],
      2: [0, 1, 2],
  }
  if nr_classes not in excluded_by_class_count:
    raise ValueError(f"Invalid number of classes: {nr_classes}")
  return excluded_by_class_count[nr_classes]

In [None]:
def get_cm_labels():
  """Return the string class labels used for the confusion matrix, or None for 5 classes.

  NOTE(review): for 2 classes this gives ["0", "1"], while get_classes_to_exclude()
  drops grades 0-2 (report names "3", "4"). Confirm which naming is intended.

  Raises:
    ValueError: if nr_classes is not 2, 3, 4 or 5.
  """
  labels_by_class_count = {
      5: None,
      4: ["0", "2", "3", "4"],
      3: ["2", "3", "4"],
      2: ["0", "1"],
  }
  if nr_classes not in labels_by_class_count:
    raise ValueError(f"Invalid number of classes: {nr_classes}")
  return labels_by_class_count[nr_classes]

In [None]:
def sanitize_model_name(model_name):
  """Make a model identifier filesystem-friendly by turning every '/' and '.' into '-'."""
  return model_name.translate(str.maketrans({'/': '-', '.': '-'}))

In [None]:
def load_model(model, path, map_location='cuda' if torch.cuda.is_available() else 'cpu'):
  """Load a saved state_dict from `path` into `model`, switch it to eval mode and return it.

  The path is printed first. `map_location` defaults to 'cuda' when a GPU was
  available at definition time, else 'cpu'.
  """
  print(path)
  state = torch.load(path, map_location=map_location)
  model.load_state_dict(state)
  model.eval()
  return model

In [None]:
def map_model_labels(y_true, y_pred, labels):
  """Translate class indices into `labels` entries; return the inputs unchanged when labels is None."""
  if labels is None:
    return y_true, y_pred
  lookup = labels.__getitem__
  return list(map(lookup, y_true)), list(map(lookup, y_pred))

## Model evaluation

In [None]:
def measure_inference_time(model, dataloader, warmup=5, repeat=50, device=None):
  """Time forward passes of `model` on the first batch of `dataloader`.

  Args:
    model: network to time; switched to eval mode.
    dataloader: iterable yielding (inputs, labels); only the first batch is used.
    warmup: untimed forward passes run first.
    repeat: timed forward passes (must be >= 1).
    device: where the batch is placed; defaults to Config.device.

  Returns:
    (average seconds per batch, average seconds per sample).

  Raises:
    ValueError: if `repeat` < 1.
  """
  if repeat < 1:
    raise ValueError(f"repeat must be >= 1, got {repeat}")
  device = Config.device if device is None else torch.device(device)

  model.eval()

  inputs, _ = next(iter(dataloader))
  inputs = inputs.to(device)

  def _wait_for_device():
    # CUDA kernels launch asynchronously; without a synchronize the timer would
    # only measure launch overhead, not the actual forward pass
    if inputs.is_cuda:
      torch.cuda.synchronize(inputs.device)

  with torch.no_grad():
    for _ in range(warmup):
      model(inputs)
    _wait_for_device()

    start = time.perf_counter()
    for _ in range(repeat):
      model(inputs)
    _wait_for_device()
    end = time.perf_counter()

  total_time = end - start
  avg_inference_time = total_time / repeat
  time_per_sample = avg_inference_time / inputs.shape[0]

  print(f"Inferência média por batch: {avg_inference_time:.6f} segundos")
  print(f"Inferência média por amostra: {time_per_sample:.6f} segundos")
  return avg_inference_time, time_per_sample

In [None]:
def gen_confusion_matrix(model_name, loss_function, y_true, y_pred, labels=None):
  """Draw an annotated confusion-matrix heatmap, save it as PNG and copy it to Drive.

  `labels` fixes the row/column order and tick labels; when None, sklearn's order
  and seaborn's automatic ticks are used. The image is written to
  f'{model_name}_cm_{loss_function}.png'.
  """
  matrix = confusion_matrix(y_true, y_pred, labels=labels)
  tick_labels = 'auto' if labels is None else labels

  plt.figure(figsize=(10, 8))
  sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues', xticklabels=tick_labels, yticklabels=tick_labels)
  plt.title('Matriz de Confusão')
  plt.xlabel('Predição')
  plt.ylabel('Real')

  png_name = f'{model_name}_cm_{loss_function}.png'
  plt.savefig(png_name)
  plt.show()
  plt.close()

  save_in_google_drive(model_name, png_name)

In [None]:
def gen_classification_report(model_name, loss_function, y_true, y_pred, classes_to_exclude):
  """Return sklearn's classification report dict extended with kappa, QWK and MAE.

  Args:
    model_name, loss_function: unused; kept for signature compatibility.
    y_true, y_pred: labels as integers or, after map_model_labels, digit strings
      such as "2".
    classes_to_exclude: grade indices (0-4) omitted from the report's target names.

  Returns:
    The `classification_report(..., output_dict=True)` dict with extra keys
    'kappa', 'qwk' and 'mae'.
  """
  # calculate overall metrics
  kappa = cohen_kappa_score(y_true, y_pred)
  # NOTE(review): sklearn weights QWK by position in the sorted label list, not by
  # label value, so with string labels like "0","2" the gap 0->2 counts as one step
  qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic')
  # mean_absolute_error rejects string arrays; convert digit-string labels to their
  # numeric value so MAE is measured in grade units
  mae = mean_absolute_error(np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float))

  report = classification_report(
    y_true,
    y_pred,
    target_names=[f'{i}' for i in range(5) if i not in classes_to_exclude],
    output_dict=True
  )
  report['kappa'] = kappa
  report['qwk'] = qwk
  report['mae'] = mae

  return report

In [None]:
def evaluate_model(model, model_src, model_name, loss_function, test_loader, classes_to_exclude=None, cm_labels=None):
  """Predict on `test_loader`, save a confusion matrix and return the classification report.

  Args:
    model: trained network; switched to eval mode.
    model_src: 'hugging_face' models are unwrapped via `.logits`.
    model_name: used in the confusion-matrix file name.
    loss_function: selects how logits are turned into labels (see get_predictions).
    test_loader: iterable of (inputs, labels) batches.
    classes_to_exclude: grade indices dropped from report target names (default none).
    cm_labels: optional index -> label-string mapping applied before scoring.

  Returns:
    The dict produced by gen_classification_report.
  """
  if classes_to_exclude is None:
    classes_to_exclude = []

  model.eval()
  all_preds, all_labels = [], []

  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(Config.device), labels.to(Config.device)

      outputs = model(inputs)
      if model_src == 'hugging_face':
        outputs = outputs.logits

      preds = get_predictions(outputs, loss_function)
      all_preds.extend(preds.cpu().numpy())
      all_labels.extend(labels.cpu().numpy())

  all_preds = np.array(all_preds)
  all_labels = np.array(all_labels)

  all_labels, all_preds = map_model_labels(all_labels, all_preds, cm_labels)
  gen_confusion_matrix(model_name, loss_function, all_labels, all_preds, cm_labels)

  return gen_classification_report(model_name, loss_function, all_labels, all_preds, classes_to_exclude)

In [None]:
def run_model_evaluation():
  """Reload the saved weights of every entry in models_list and evaluate them on the test split.

  For each model and each loss in ('cross_entropy', 'corn'), the state dict is read
  from get_results_dir(nr_classes, <sanitized name>)/<sanitized name>_<loss>.pth, and
  the resulting report is printed as JSON.
  """
  classes_to_exclude = get_classes_to_exclude()
  cm_labels = get_cm_labels()

  for model_name, model_info in models_list.items():
    # models_list maps names to metadata dicts; the processor lookup and
    # evaluate_model expect the source string ('pytorch' / 'hugging_face')
    model_src = model_info['source'] if isinstance(model_info, dict) else model_info

    processor = get_model_processor(model_name, model_src)
    dataloaders, dataset_sizes = load_dataset(model_name, processor)
    model_name_sanitized = sanitize_model_name(model_name)

    for loss_function in ['cross_entropy', 'corn']:
      model = init_model(model_name, loss_function)

      base_path = get_results_dir(nr_classes, model_name_sanitized)
      file_name = f"{model_name_sanitized}_{loss_function}.pth"
      path = os.path.join(base_path, file_name)
      print(path)

      model = load_model(model, path)
      report = evaluate_model(model, model_src, model_name_sanitized, loss_function, dataloaders['test'], classes_to_exclude, cm_labels)
      print(json.dumps(report, indent=4))

### Run evaluation of the trained models

In [None]:
# run_model_evaluation()

## Model training

In [None]:
def print_training_time(start_time, end_time):
  """Print the elapsed time between two `time.time()` stamps, in minutes."""
  elapsed_minutes = (end_time - start_time) / 60
  print(f"\nTraining Time: {elapsed_minutes:.2f} minutes")

In [None]:
def compute_loss(outputs, labels, criterion, loss_function):
  """Apply `criterion` to `outputs`/`labels` with the calling convention of `loss_function`.

  Raises:
    ValueError: if `loss_function` is neither 'cross_entropy' nor 'corn'
      (previously this silently returned None).
  """
  if loss_function == 'cross_entropy':
    return criterion(outputs, labels)
  if loss_function == 'corn':
    # the CORN criterion also needs the total number of ordinal classes
    return criterion(outputs, labels, num_classes=Config.num_classes)
  raise ValueError(f"Invalid loss function: {loss_function}")

In [None]:
def get_predictions(outputs, loss_function):
  """Convert model logits into predicted class indices.

  'cross_entropy' takes the arg-max over dim 1; 'corn' decodes the K-1 CORN logits
  with corn_label_from_logits.

  Raises:
    ValueError: for any other `loss_function` (previously an UnboundLocalError).
  """
  if loss_function == 'cross_entropy':
    _, preds = torch.max(outputs, 1)
    return preds
  if loss_function == 'corn':
    return corn_label_from_logits(outputs)
  raise ValueError(f"Invalid loss function: {loss_function}")

In [None]:
def train_one_epoch(model, dataloader, datasize, loss_function, criterion, optimizer, model_name, model_src, device):
  """Run one optimisation pass over `dataloader` and return (mean loss, accuracy).

  Args:
    model: network being trained; put into train mode here.
    dataloader: iterable of (inputs, labels) batches (wrapped in tqdm for progress).
    datasize: number of samples in the split; divides the summed loss and correct count.
    loss_function: 'cross_entropy' or 'corn', forwarded to compute_loss/get_predictions.
    criterion: loss callable passed through to compute_loss.
    optimizer: zeroed and stepped once per batch.
    model_name: compared with 'inception_v3' to enable the auxiliary-logit loss.
    model_src: 'hugging_face' outputs are unwrapped via `.logits`.
    device: inputs and labels are moved here.

  Returns:
    (epoch_loss, epoch_acc), with epoch_acc converted to a Python float.
  """
  model.train()
  running_loss = 0.0
  running_corrects = 0

  for inputs, labels in tqdm(dataloader):
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()

    with torch.set_grad_enabled(True):
      outputs = model(inputs)
      if model_src == 'hugging_face':
        outputs = outputs.logits

      if model_name == 'inception_v3':
        # here the output is expected to carry both .logits and .aux_logits;
        # the auxiliary loss is added with weight 0.4
        outputs, aux_logits = outputs.logits, outputs.aux_logits
        loss = compute_loss(outputs, labels, criterion, loss_function)
        aux_loss = compute_loss(aux_logits, labels, criterion, loss_function)
        loss += 0.4 * aux_loss
      else:
        loss = compute_loss(outputs, labels, criterion, loss_function)

      preds = get_predictions(outputs, loss_function)
      loss.backward()
      optimizer.step()

    # loss.item() is presumably a batch mean, so it is scaled back to a per-batch sum
    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data)

  # NOTE(review): if the loader yields no batches, running_corrects stays an int and .double() fails
  epoch_loss = running_loss / datasize
  epoch_acc = running_corrects.double() / datasize
  return epoch_loss, epoch_acc.item()

In [None]:
def validate_one_epoch(model, dataloader, datasize, loss_function, criterion, optimizer, model_src, device):
  """Evaluate `model` on `dataloader` without gradients and return (mean loss, accuracy).

  Args:
    model: network to evaluate; switched to eval mode.
    dataloader: iterable of (inputs, labels) batches.
    datasize: number of samples in the split; divides the summed loss and correct count.
    loss_function: 'cross_entropy' or 'corn'.
    criterion: loss callable passed through to compute_loss.
    optimizer: unused. Kept so the signature mirrors train_one_epoch; no gradients
      are produced here, so there is nothing to zero.
    model_src: 'hugging_face' outputs are unwrapped via `.logits`.
    device: inputs and labels are moved here (previously ignored in favour of Config.device).

  Returns:
    (epoch_loss, epoch_acc) as Python floats.
  """
  model.eval()
  running_loss = 0.0
  running_corrects = 0

  with torch.no_grad():
    for inputs, labels in tqdm(dataloader):
      inputs, labels = inputs.to(device), labels.to(device)

      outputs = model(inputs)
      if model_src == 'hugging_face':
        outputs = outputs.logits

      loss = compute_loss(outputs, labels, criterion, loss_function)
      preds = get_predictions(outputs, loss_function)

      running_loss += loss.item() * inputs.size(0)
      running_corrects += torch.sum(preds == labels).item()

  epoch_loss = running_loss / datasize
  epoch_acc = running_corrects / datasize
  return epoch_loss, epoch_acc

In [None]:
def train_model(model_metadata):
  """Train a model for up to Config.num_epochs epochs with early stopping on validation loss.

  Args:
    model_metadata: dict with keys 'model', 'model_src', 'model_name', 'dataloaders'
      (needs 'train' and 'val'), 'dataset_sizes', 'loss_function', 'criterion',
      'optimizer' and 'scheduler' (stepped once per epoch).

  Returns:
    (model, train_losses, val_losses, train_accs, val_accs, training_seconds).
    The returned model holds the weights of the epoch with the highest validation
    accuracy. Separately, EarlyStopping writes the lowest-validation-loss weights to
    f'{model_name}_{loss_function}.pth' in the working directory.
  """
  # destructuring arguments
  model, model_src, model_name = (
    model_metadata['model'],
    model_metadata['model_src'],
    model_metadata['model_name']
  )
  dataloaders, dataset_sizes = (
    model_metadata['dataloaders'],
    model_metadata['dataset_sizes']
  )
  loss_function, criterion, optimizer, scheduler = (
    model_metadata['loss_function'],
    model_metadata['criterion'],
    model_metadata['optimizer'],
    model_metadata['scheduler']
  )

  def _snapshot_weights():
    # state_dict() returns references to the live tensors, so storing it directly
    # would follow every later update and make the final restore a no-op. Clone
    # the tensors to CPU so no second copy of the weights stays on the GPU.
    return {key: tensor.detach().cpu().clone() for key, tensor in model.state_dict().items()}

  best_model_wts = _snapshot_weights()
  best_acc = 0.0
  train_losses, val_losses, train_accs, val_accs = [], [], [], []

  early_stopping = EarlyStopping(patience=5, verbose=True, loss_function=loss_function)
  start_time = time.time()
  print(f"loss function={loss_function}")

  for epoch in range(1, Config.num_epochs+1):
    print(f"Epoch {epoch}/{Config.num_epochs}" + "-"*10)

    train_loss, train_acc = train_one_epoch(model, dataloaders['train'], dataset_sizes['train'], loss_function, criterion, optimizer, model_name, model_src, Config.device)
    val_loss, val_acc = validate_one_epoch(model, dataloaders['val'], dataset_sizes['val'], loss_function, criterion, optimizer, model_src, Config.device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    scheduler.step()

    print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    # record the best epoch before a possible early stop, so the final epoch also counts
    if val_acc > best_acc:
      best_acc = val_acc
      best_model_wts = _snapshot_weights()

    early_stopping(val_loss, model, model_name)
    if early_stopping.early_stop:
      print("Early stopping")
      break

  end_time = time.time()
  print_training_time(start_time, end_time)
  print(f"Best val Acc: {best_acc:.4f}")
  model.load_state_dict(best_model_wts)
  return model, train_losses, val_losses, train_accs, val_accs, end_time - start_time

In [None]:
# @title
# def train_model():
#   # destructuring arguments
#   # model, model_src, model_name = (
#   #   model_data['model'],
#   #   model_data['model_src'],
#   #   model_data['model_name']
#   # )
#   # dataloaders, dataset_sizes = (
#   #   model_data['dataloaders'],
#   #   model_data['dataset_sizes']
#   # )
#   # loss_function, criterion, optimizer, scheduler = (
#   #   model_data['loss_function'],
#   #   model_data['criterion'],
#   #   model_data['optimizer'],
#   #   model_data['scheduler']
#   # )
#   # num_epochs = model_data['num_epochs']

#   best_model_wts = model.state_dict()
#   best_acc = 0.0
#   train_losses, val_losses, train_accs, val_accs = [], [], [], []

#   early_stopping = EarlyStopping(patience=5, verbose=True, loss_function=loss_function)
#   start_time = time.time()

#   for epoch in range(1, Config.num_epochs+1):
#     print(f"Epoch {epoch}/{Config.num_epochs}" + "-"*10)

#     for phase in ['train', 'val']:
#       if phase == 'train':
#         model.train()
#       else:
#         model.eval()

#       running_loss = 0.0
#       running_corrects = 0

#       for inputs, labels in tqdm(dataloaders[phase]):
#         inputs = inputs.to(Config.device)
#         labels = labels.to(Config.device)

#         # clear the gradients of all optimized parameters
#         optimizer.zero_grad()

#         with torch.set_grad_enabled(phase == 'train'):
#           # forward pass: obtain the model predictions for the input data
#           outputs = model(inputs)
#           outputs = outputs.logits if model_src == 'hugging_face' else outputs

#           if model_name == 'inception_v3' and phase == 'train':
#             _, preds = torch.max(outputs.logits, 1) if loss_function == 'cross_entropy' else (None, corn_label_from_logits(outputs.logits))

#             # compute the loss between the model predictions and the true labels
#             loss = criterion(outputs.logits, labels) if loss_function == 'cross_entropy' else criterion(outputs.logits, labels, num_classes=Config.num_classes)
#           else:
#             _, preds = torch.max(outputs, 1) if loss_function == 'cross_entropy' else (None, corn_label_from_logits(outputs))

#             # compute the loss between the model predictions and the true labels
#             loss = criterion(outputs, labels) if loss_function == 'cross_entropy' else criterion(outputs, labels, num_classes=Config.num_classes)

#           if phase == 'train':
#             if model_name == 'inception_v3':
#               aux_logits = outputs.aux_logits
#               aux_loss = criterion(aux_logits, labels) if loss_function == 'cross_entropy' else criterion(aux_logits, labels, num_classes=Config.num_classes)
#               loss += 0.4 * aux_loss

#             # backward pass: compute gradients of the loss with respect to model parameters
#             loss.backward()

#             # update the model parameters using the computed gradients
#             optimizer.step()

#         running_loss += loss.item() * inputs.size(0)
#         running_corrects += torch.sum(preds == labels.data)

#       epoch_loss = running_loss / dataset_sizes[phase]
#       epoch_acc = running_corrects.double() / dataset_sizes[phase]

#       if phase == 'train':
#         scheduler.step()
#         train_losses.append(epoch_loss)
#         train_accs.append(epoch_acc.item())
#       else:
#         val_losses.append(epoch_loss)
#         val_accs.append(epoch_acc.item())

#       print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

#       # Early stopping
#       if phase == 'val':
#         early_stopping(epoch_loss, model, model_name)
#         if early_stopping.early_stop:
#           print("Early stopping")
#           end_time = time.time()
#           total_time = end_time - start_time
#           print_training_time(start_time, end_time)
#           print(f"Best val Acc: {best_acc:.4f}")
#           model.load_state_dict(best_model_wts)
#           return model, train_losses, val_losses, train_accs, val_accs, total_time

#       if phase == 'val' and epoch_acc > best_acc:
#         best_acc = epoch_acc
#         best_model_wts = model.state_dict()

#     print("")

#   # record the end time
#   end_time = time.time()
#   total_time = end_time - start_time
#   print_training_time(start_time, end_time)

#   print(f"Best val Acc: {best_acc:.4f}")
#   model.load_state_dict(best_model_wts)

#   return model, train_losses, val_losses, train_accs, val_accs, total_time

### Model list

In [None]:
# Candidate architectures, keyed by the name passed to get_model. Each entry records
# the family ('cnn'/'vit'), where the weights come from ('pytorch' covers torchvision
# and timm; 'hugging_face' the transformers hub), the filesystem-safe short name, and
# whether the training loop should skip it.
models_list = {
    name: {'type': family, 'source': source, 'shortname': shortname, 'skip': skip}
    for name, family, source, shortname, skip in (
        ('resnet34', 'cnn', 'pytorch', 'resnet34', True),
        ('resnet50', 'cnn', 'pytorch', 'resnet50', True),
        ('resnet101', 'cnn', 'pytorch', 'resnet101', True),
        ('vgg16', 'cnn', 'pytorch', 'vgg16', True),
        ('vgg19', 'cnn', 'pytorch', 'vgg19', True),
        ('densenet121', 'cnn', 'pytorch', 'densenet121', True),
        ('densenet169', 'cnn', 'pytorch', 'densenet169', True),
        ('inception_v3', 'cnn', 'pytorch', 'inception_v3', True),
        ('facebook/deit-base-distilled-patch16-224', 'vit', 'hugging_face', 'facebook-deit-base-distilled-patch16-224', True),
        ('davit_base.msft_in1k', 'vit', 'pytorch', 'davit_base-msft_in1k', False),
        ('maxvit_t', 'vit', 'pytorch', 'maxvit_t', True),
        ('gcvit_base.in1k', 'vit', 'pytorch', 'gcvit_base-in1k', True),
        ('swin_b', 'vit', 'pytorch', 'swin_b', True),
    )
}

In [None]:
# Losses trained for each model: standard cross-entropy and CORN (coral_pytorch's ordinal loss)
loss_functions = ['cross_entropy', 'corn']

In [None]:
def run_training():
  """Train and test every non-skipped model in models_list with each loss in `loss_functions`.

  Per model and loss this does the following:
    - trains the model and copies its checkpoint to Drive;
    - saves the training curves;
    - evaluates on the test split and adds training time, FLOPs and parameter counts.
  Results start from get_results_from_json_file(). Entries whose metadata sets
  'skip' keep whatever results were loaded. Everything is written once at the end
  via save_results.
  """
  results = get_results_from_json_file()
  classes_to_exclude = get_classes_to_exclude()
  cm_labels = get_cm_labels()

  for model_name, model_info in models_list.items():
    # models_list maps names to metadata dicts; downstream helpers need the
    # source string ('pytorch' / 'hugging_face'), not the whole dict
    if isinstance(model_info, dict):
      if model_info.get('skip', False):
        print(f"Skipping {model_name}")
        continue
      model_src = model_info['source']
    else:
      model_src = model_info

    processor = get_model_processor(model_name, model_src)
    dataloaders, dataset_sizes = load_dataset(model_name, processor)
    model_name_sanitized = sanitize_model_name(model_name)

    results[model_name_sanitized] = {}

    # inception_v3 is fed 299x299 images (see get_transforms); everything else 224x224
    input_res = (3, 299, 299) if model_name == 'inception_v3' else (3, 224, 224)

    for loss_function in loss_functions:
      model = init_model(model_name, loss_function)
      criterion = nn.CrossEntropyLoss() if loss_function == 'cross_entropy' else corn_loss
      optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)
      scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

      # defining model metadata dict
      model_metadata = {
          'model': model,
          'model_src': model_src,
          'model_name': model_name_sanitized,
          'dataloaders': dataloaders,
          'dataset_sizes': dataset_sizes,
          'loss_function': loss_function,
          'criterion': criterion,
          'optimizer': optimizer,
          'scheduler': scheduler
      }

      # train the model using transfer learning
      model, train_losses, val_losses, train_accs, val_accs, train_time = train_model(model_metadata)

      save_in_google_drive(model_name_sanitized, f'{model_name_sanitized}_{loss_function}.pth')
      plot_and_save_training_curves(model_name_sanitized, loss_function, train_losses, val_losses, train_accs, val_accs)

      # evaluate the model using test dataloader
      report = evaluate_model(model, model_src, model_name_sanitized, loss_function, dataloaders['test'], classes_to_exclude, cm_labels)

      # get model complexity info
      flops, params = get_model_complexity_info(model, input_res, as_strings=True, backend='pytorch', print_per_layer_stat=False)

      report['train_time'] = train_time
      report['flops'] = flops
      report['params'] = params

      results[model_name_sanitized][loss_function] = report

  # save report
  save_results('classification_report', results)
  print('Training completed. Reports saved.')

### Run training

In [None]:
# run_training()

# Optuna

In [None]:
def objective(trial):
  """Optuna objective: train one sampled configuration and return its validation QWK.

  Samples the backbone, loss function, learning rate, weight decay, the
  ``feature_extract`` flag and the batch size from ``trial``, trains the model
  with ``train_model`` and scores it on the *validation* loader with
  ``evaluate_model`` (the test split is kept out of hyper-parameter selection).

  Args:
    trial: Optuna trial supplied by ``study.optimize``.

  Returns:
    ``report['qwk']`` from ``evaluate_model``; the study maximizes it.
  """
  classes_to_exclude = get_classes_to_exclude()
  cm_labels = get_cm_labels()

  # === Suggest hyperparameters with Optuna ===
  model_name = trial.suggest_categorical("model_name", ["densenet169", "davit_base.msft_in1k", "densenet121", "gcvit_base.in1k", "maxvit_t"])
  loss_function = trial.suggest_categorical("loss_function", ["cross_entropy", "corn"])
  # suggest_loguniform is deprecated since Optuna 3.0; suggest_float(..., log=True)
  # samples the same log-uniform distribution under the same parameter names.
  learning_rate = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
  weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
  feature_extract = trial.suggest_categorical("feature_extract", [True, False])
  batch_size = trial.suggest_categorical("batch_size", [16, 28, 32, 64])

  model_src = "pytorch"
  model_name_sanitized = sanitize_model_name(model_name)

  # === Init model ===
  processor = get_model_processor(model_name, model_src)
  dataloaders, dataset_sizes = load_dataset(model_name, processor, batch_size=batch_size)
  model = init_model(model_name, loss_function, feature_extract)

  criterion = nn.CrossEntropyLoss() if loss_function == 'cross_entropy' else corn_loss
  optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

  model_metadata = {
    'model': model,
    'model_src': model_src,
    'model_name': model_name_sanitized,
    'dataloaders': dataloaders,
    'dataset_sizes': dataset_sizes,
    'loss_function': loss_function,
    'criterion': criterion,
    'optimizer': optimizer,
    'scheduler': scheduler
  }

  try:
    # === Training === (train_model also returns loss/accuracy curves and train time)
    model, *_ = train_model(model_metadata)

    # === Evaluate on the validation split ===
    report = evaluate_model(model, model_src, model_name_sanitized, loss_function, dataloaders['val'], classes_to_exclude, cm_labels)
  finally:
    # === Free GPU memory, also when the trial raises (e.g. CUDA OOM) ===
    # The optimizer (and the scheduler wrapping it) keep references to the
    # parameters and the Adam state, so they must be dropped as well before
    # empty_cache() can give that memory back.
    del model, dataloaders, model_metadata, optimizer, scheduler
    gc2.collect()
    torch.cuda.empty_cache()

  # Returning the metric to maximize
  return report['qwk']

## Run optimization

In [None]:
def run_optuna(n_trials=30, timeout=36000, study_path="optuna_study.pkl"):
  """Run the Optuna hyper-parameter search over ``objective`` and save the study.

  Args:
    n_trials: maximum number of trials (default 30).
    timeout: wall-clock budget in seconds (default 36000 s = 10 h). The search
      stops when either ``n_trials`` or ``timeout`` is reached.
    study_path: file the finished study is written to with ``joblib.dump``.
      Pass a falsy value to skip saving.

  Returns:
    The finished Optuna study.
  """
  # Neither optuna nor joblib is imported in the notebook's import cell.
  # Importing them here makes a missing package fail immediately instead of
  # raising a NameError only after hours of optimization.
  import joblib
  import optuna

  study = optuna.create_study(direction="maximize")
  study.optimize(objective, n_trials=n_trials, timeout=timeout)

  print("Melhores hiperparâmetros encontrados:")
  print(study.best_params)

  # Optional: persist the study so it can be reloaded later with joblib.load().
  if study_path:
    joblib.dump(study, study_path)

  return study

In [None]:
# run_optuna()

# Conformal prediction

### Calculate non-conformity scores

In [None]:
def xe_calc_conformity_score(logits, y):
  """Non-conformity score of label ``y`` for a single softmax head.

  The classes are ranked by softmax probability (highest first). The score is
  the total probability mass of every class ranked at or above ``y``. This is
  the non-randomized APS score.

  Args:
    logits: 1-D tensor of class logits for one sample.
    y: integer true label.

  Returns:
    The score as a Python float, or ``None`` if ``y`` is not a valid class index.
  """
  ranked_probs, ranked_classes = torch.sort(F.softmax(logits, dim=0), descending=True)
  positions = (ranked_classes == y).nonzero(as_tuple=True)[0]
  if positions.numel() == 0:
    return None
  cutoff = positions[0].item() + 1
  return ranked_probs[:cutoff].sum().item()

In [None]:
def corn_calc_conformity_score(logits, y):
  """Aggregate CORN non-conformity score of label ``y`` for one sample.

  For each of the ``nr_classes - 1`` thresholds ``k`` with ``p = sigmoid(logits[k])``:
  add ``p`` when ``k >= y``, and add ``1 - p`` when ``k < y``. The result is the
  sum over all thresholds.

  Args:
    logits: 1-D tensor of CORN logits for one sample.
    y: integer true label.

  Returns:
    The summed score as a Python float.
  """
  exceed_probs = torch.sigmoid(logits)
  return sum(
      (exceed_probs[k].item() if k >= y else 1 - exceed_probs[k].item()
       for k in range(nr_classes - 1)),
      0.0,
  )

In [None]:
def calc_nonconformity_scores(model, model_src, calib_loader, loss_function):
  """Compute one non-conformity score per calibration sample.

  Args:
    model: trained model; it is put in eval mode and run without gradients.
    model_src: ``'hugging_face'`` when the model output carries ``.logits``.
    calib_loader: DataLoader yielding ``(inputs, labels)`` batches.
    loss_function: ``'corn'`` uses ``corn_calc_conformity_score``. Any other value
      uses ``xe_calc_conformity_score``.

  Returns:
    A flat list of scores, in loader order.
  """
  score_fn = corn_calc_conformity_score if loss_function == 'corn' else xe_calc_conformity_score

  model.eval()
  scores = []

  with torch.no_grad():
    for inputs, labels in tqdm(calib_loader):
      inputs, labels = inputs.to(Config.device), labels.to(Config.device)
      logits = model(inputs)

      if model_src == 'hugging_face':
        logits = logits.logits  # shape: (B, K-1) for CORN heads

      for sample_logits, target in zip(logits, labels.tolist()):
        scores.append(score_fn(sample_logits, target))

  return scores

In [None]:
def calc_per_threshold_scores(model, model_src, calib_loader):
  """Collect CORN non-conformity scores separately for each threshold.

  For threshold ``k`` with ``p = sigmoid(logit_k)`` and true label ``y``, the
  score is ``1 - p`` when ``y > k`` and ``p`` otherwise.

  Args:
    model: trained CORN model; put in eval mode and run without gradients.
    model_src: ``'hugging_face'`` when the model output carries ``.logits``.
    calib_loader: DataLoader yielding ``(inputs, labels)`` batches.

  Returns:
    ``nr_classes - 1`` lists, one per threshold, each holding one score per
    calibration sample.
  """
  model.eval()
  n_thresholds = nr_classes - 1
  buckets = [[] for _ in range(n_thresholds)]

  with torch.no_grad():
    for inputs, labels in tqdm(calib_loader):
      inputs, labels = inputs.to(Config.device), labels.to(Config.device)
      logits = model(inputs)

      if model_src == 'hugging_face':
        logits = logits.logits  # shape: (B, K-1)

      batch_probs = torch.sigmoid(logits).tolist()

      for row, y in zip(batch_probs, labels.tolist()):
        for k in range(n_thresholds):
          p = row[k]
          buckets[k].append(1 - p if y > k else p)

  return buckets

In [None]:
def get_prediction_set(probs, q_hat):
  """Build a greedy conformal prediction set from class probabilities.

  Classes are added in descending probability order until their cumulative
  mass reaches ``q_hat``. The most likely class is therefore always included.

  Args:
    probs: 1-D tensor of class probabilities for one sample.
    q_hat: calibrated cumulative-probability threshold.

  Returns:
    List of class indices in the set, most probable first.
  """
  selected = []
  running_total = 0

  for cls in torch.argsort(probs, descending=True).tolist():
    selected.append(cls)
    running_total += probs[cls].item()
    if running_total >= q_hat:
      return selected

  return selected

In [None]:
def get_prediction_interval(logits, q_hat):
  """Ordinal prediction interval ``[0, ..., upper]`` driven by a single threshold.

  ``1 - sigmoid(logit_k)`` is accumulated over the thresholds. ``upper`` is the
  first ``k`` at which the running total exceeds ``q_hat``. If the total never
  exceeds it, ``upper`` is the number of logits, i.e. all classes are kept.

  Args:
    logits: 1-D tensor of CORN logits, shape ``(K-1,)``.
    q_hat: calibrated threshold.

  Returns:
    ``list(range(upper + 1))``.
  """
  exceed_probs = torch.sigmoid(logits)
  accumulated = 0.0
  stop_at = len(exceed_probs)

  for k, p in enumerate(exceed_probs.tolist()):
    accumulated += 1 - p
    if accumulated > q_hat:
      stop_at = k
      break

  return list(range(stop_at + 1))

In [None]:
def predict_interval_from_thresholds(logits, thresholds):
  """Ordinal prediction interval from per-threshold conformal cut-offs.

  Thresholds are scanned in order. While ``1 - sigmoid(logit_k) <= tau_k``, the
  upper end of the interval is pushed to ``k + 1``. The scan stops at the first
  threshold that fails. The lower end is always class 0.

  Args:
    logits: 1-D tensor of CORN logits for one sample.
    thresholds: sequence of per-threshold cut-offs (``tau_k``).

  Returns:
    ``[0, 1, ..., upper]``; ``[0]`` when the first threshold already fails.
  """
  exceed_probs = torch.sigmoid(logits)
  upper = 0

  for k, tau_k in enumerate(thresholds):
    if not 1 - exceed_probs[k].item() <= tau_k:
      break
    upper = k + 1

  return list(range(upper + 1))

In [None]:
def run_conformal_prediction(dataloader, model, model_src, thresholds, loss_function, verbose=True):
  """Build conformal prediction sets on ``dataloader`` and measure empirical coverage.

  Args:
    dataloader: DataLoader yielding ``(inputs, labels)`` batches.
    model: trained model; put in eval mode and run without gradients.
    model_src: ``'hugging_face'`` when the model output carries ``.logits``.
    thresholds: output of ``compute_thresholds``.
      - CORN uses every entry through ``predict_interval_from_thresholds``.
      - Other losses use ``thresholds[0]`` as the cumulative-probability cut-off
        for ``get_prediction_set``.
    loss_function: ``'corn'`` or a softmax-based loss.
    verbose: print each predicted set next to its true label (default ``True``).

  Returns:
    ``(set_sizes, coverage)``:
      - ``set_sizes``: the size of every prediction set.
      - ``coverage``: the fraction of samples whose true label lies inside its
        set; ``0.0`` for an empty loader.
  """
  model.eval()
  set_sizes = []
  hits = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in tqdm(dataloader):
      inputs, labels = inputs.to(Config.device), labels.to(Config.device)
      outputs = model(inputs)
      if model_src == 'hugging_face':
        outputs = outputs.logits

      # Softmax over the whole batch is computed once per batch, not once per sample.
      probs = None if loss_function == 'corn' else F.softmax(outputs, dim=1)

      for i, true_label in enumerate(labels.tolist()):
        if loss_function == 'corn':
          pred_set = predict_interval_from_thresholds(outputs[i], thresholds)
        else:
          pred_set = get_prediction_set(probs[i], thresholds[0])

        if verbose:
          print(f"Predicted set: {pred_set}, True label: {true_label}")
        set_sizes.append(len(pred_set))
        hits += 1 if true_label in pred_set else 0
        total += 1

  # Guard against an empty loader instead of raising ZeroDivisionError.
  coverage = hits / total if total else 0.0
  return set_sizes, coverage

In [None]:
def plot_and_save_set_sizes(set_sizes, model_name, loss_function):
  """Plot a histogram of prediction-set sizes, save it as PNG and upload it.

  Bins are one unit wide, from the smallest to the largest observed size.
  ``set_sizes`` must therefore be non-empty, because ``min``/``max`` are taken
  over it. The figure is written to ``{model_name}_cp_{loss_function}.png`` in
  the working directory and then passed to ``save_in_google_drive``.
  """
  plt.figure(figsize=(6, 6))
  smallest, largest = min(set_sizes), max(set_sizes)
  plt.hist(set_sizes, bins=np.arange(smallest, largest + 2))
  plt.xlabel('Tamanho do conjunto de predições')
  plt.ylabel('Frequência')
  plt.title('Distribuição do tamanho do conjunto de predições')

  png_name = f'{model_name}_cp_{loss_function}.png'
  plt.savefig(png_name)
  plt.show()
  plt.close()

  save_in_google_drive(model_name, png_name)

In [None]:
def compute_thresholds(per_threshold_scores, epsilon):
  """Turn calibration non-conformity scores into conformal thresholds.

  Each score list gets one threshold. That means one list per CORN threshold,
  or a single list for cross-entropy. The threshold is the empirical quantile
  at level ``ceil((n + 1) * (1 - epsilon)) / n``, the finite-sample-corrected
  level used in split conformal prediction.

  When ``n < (1 - epsilon) / epsilon`` that level exceeds 1: there are too few
  calibration samples for the requested coverage. ``np.quantile`` rejects
  such levels, so the level is clipped to 1, i.e. the largest calibration
  score is used.

  Args:
    per_threshold_scores: iterable of non-empty sequences of scores.
    epsilon: target miscoverage rate in ``(0, 1)``.

  Returns:
    List with one threshold per score list.

  Raises:
    ValueError: if a score list is empty.
  """
  thresholds = []
  for scores_k in per_threshold_scores:
    n = len(scores_k)
    if n == 0:
      raise ValueError("compute_thresholds: received an empty list of calibration scores")
    q_level = min(1.0, math.ceil((n + 1) * (1 - epsilon)) / n)
    thresholds.append(np.quantile(scores_k, q_level))
  return thresholds

In [None]:
def squash_set_sizes(set_sizes):
  """Count how many prediction sets have each size.

  Returns a list of length ``nr_classes + 1``. Entry ``s`` is the number of
  sets of size ``s``. A size larger than ``nr_classes`` raises ``IndexError``.
  """
  histogram = [0] * (nr_classes + 1)
  for size in set_sizes:
    histogram[size] += 1
  return histogram

In [None]:
def conformal_prediction():
  """Calibrate and evaluate conformal predictors for every active model.

  For each ``models_list`` entry without ``skip``, and each loss in
  ``loss_functions``:
    1. Load the checkpoint ``{shortname}_{loss}.pth`` from
       ``get_results_dir(nr_classes, shortname)``.
    2. Compute non-conformity scores on the ``'calib'`` split.
    3. Convert them into thresholds with ``compute_thresholds(..., Config.epsilon)``.
    4. Build prediction sets on the ``'test'`` split.
    5. Plot the set-size histogram.

  Thresholds, coverage and the set-size counts are stored with
  ``save_results('conformal_prediction', ...)``.
  """
  summary = {}

  for model_name, model_attr in models_list.items():
    if model_attr.get('skip') == True:
      continue

    short = model_attr['shortname']
    source = model_attr['source']
    dataloaders, _ = load_dataset(model_name, get_model_processor(model_name, source))
    summary[short] = {}

    for loss_function in loss_functions:
      print(f"model name: {short} - loss function: {loss_function}")
      base_model = init_model(model_name, loss_function)
      checkpoint = os.path.join(get_results_dir(nr_classes, short), f"{short}_{loss_function}.pth")
      model = load_model(base_model, checkpoint)

      # Non-conformity scores on the calibration split: CORN keeps one score
      # list per threshold, cross-entropy a single list.
      if loss_function == 'corn':
        calib_scores = calc_per_threshold_scores(model, source, dataloaders['calib'])
      else:
        calib_scores = [calc_nonconformity_scores(model, source, dataloaders['calib'], loss_function)]
      thresholds = compute_thresholds(calib_scores, Config.epsilon)

      set_sizes, coverage = run_conformal_prediction(dataloaders['test'], model, source, thresholds, loss_function)
      plot_and_save_set_sizes(set_sizes, short, loss_function)

      summary[short][loss_function] = {
        'q_hat': thresholds,
        'coverage': coverage,
        'set_sizes': squash_set_sizes(set_sizes),
      }

  save_results('conformal_prediction', summary)
  print(f'Conformal prediction completed. Reports saved.')

### Run conformal prediction

In [None]:
# Runs conformal prediction for every non-skipped model (calibration on the
# 'calib' split, evaluation on 'test').
# NOTE(review): unlike the other runner cells in this notebook, this call is not
# commented out, so it executes on "Run all" — confirm that is intended.
conformal_prediction()

# Inference Time

In [None]:
def run_model_inference_time():
  """Measure inference latency on the test split for every model/loss checkpoint.

  For each ``models_list`` entry not flagged with ``skip`` and each loss in
  ``loss_functions``:
    1. The checkpoint ``{shortname}_{loss}.pth`` is loaded.
    2. It is timed with ``measure_inference_time(..., warmup=5, repeat=50)``.

  All results are written with ``save_results('inference_time', ...)``.
  """
  results = {}

  for model_name, model_attr in models_list.items():
    # Honour the 'skip' flag the same way conformal_prediction, exec_grad_cam
    # and run_auc_roc do, so excluded models are not loaded here either.
    if model_attr.get('skip') == True:
      continue

    processor = get_model_processor(model_name, model_attr['source'])
    dataloaders, _ = load_dataset(model_name, processor)

    results[model_attr['shortname']] = {}

    for loss_function in loss_functions:
      model = init_model(model_name, loss_function)

      base_path = get_results_dir(nr_classes, model_attr['shortname'])
      file_name = f"{model_attr['shortname']}_{loss_function}.pth"
      path = os.path.join(base_path, file_name)

      model = load_model(model, path)

      avg_inference_time, time_per_sample = measure_inference_time(model, dataloaders['test'], warmup=5, repeat=50)

      results[model_attr['shortname']][loss_function] = {
          'avg_inference_time': avg_inference_time,
          'time_per_sample': time_per_sample
      }

  save_results('inference_time', results)
  print(f'Inference time completed. Reports saved.')

In [None]:
# run_model_inference_time()

# Visualizing with Grad-CAM

In [None]:
# Grad-CAM image sets: criteria name -> image sub-folder.
# exec_grad_cam passes each value to get_grad_cam_dir to locate the images, and
# run_batch_gradcam uses each key as the prefix of the saved heat-map file names.
grad_cam_criteria = {
    'per-kl-criteria': 'per_kl_class'
}

In [None]:
class HuggingFaceWrapper(nn.Module):
  """Adapter that makes a Hugging Face image classifier return plain logits.

  Hugging Face models return an output object. Code that expects
  ``model(x) -> Tensor``, such as pytorch-grad-cam, needs the raw ``.logits``.
  """

  def __init__(self, model):
    super().__init__()
    # When `model` is an nn.Module it is registered as a submodule.
    self.model = model

  def forward(self, x):
    hf_output = self.model(x)
    return hf_output.logits

In [None]:
class CORNWrapper(nn.Module):
  """Transparent pass-through around a model that already returns CORN logits."""

  def __init__(self, model):
    super().__init__()
    # Base model; presumably returns CORN logits of shape [B, K-1] — confirm.
    self.model = model

  def forward(self, x):
    corn_logits = self.model(x)
    return corn_logits

In [None]:
class CORNOutputTarget:
  """Grad-CAM target for models trained with the CORN loss.

  pytorch-grad-cam calls a target on each sample's model output and
  backpropagates the returned score. A CORN head emits K-1 logits, so there is
  no single class logit to select, as ``ClassifierOutputTarget`` does for
  softmax heads. This target returns a differentiable proxy score instead:

    * class 0: ``1 - sigmoid(logit_0)``
    * class k > 0: ``sum(sigmoid(logit_j) for j < k)``

  NOTE: the proxy for k > 0 is not the CORN class probability
  ``P(y = k) = prod_{j<k} sigmoid(logit_j) * (1 - sigmoid(logit_k))``.

  Args:
    class_index: 0-based class whose evidence should be highlighted.
    epsilon: stored for API compatibility; not used by the score.
  """

  def __init__(self, class_index, epsilon=1e-1):
    self.class_index = class_index  # 0-based class index
    self.epsilon = epsilon

  def __call__(self, model_output):
    """Return the proxy score for ``model_output`` of shape ``[K-1]`` or ``[B, K-1]``.

    A 1-D output yields a scalar tensor; a batched output yields one score per row.
    """
    probs = torch.sigmoid(model_output)

    if self.class_index == 0:
      return 1 - probs[..., 0]

    return probs[..., :self.class_index].sum(dim=-1)

In [None]:
def resize_image(image, model_name, processor):
  """Resize ``image`` to the square input size expected by ``model_name``.

  - Without a processor: 299x299 for ``inception_v3``, 224x224 for every other model.
  - With a processor: the side is ``processor.crop_size['height']`` for DeiT
    models (name contains ``'facebook/deit'``), otherwise ``processor.size['height']``.
    The height value is used for both dimensions.
  """
  if processor:
    size_cfg = processor.crop_size if 'facebook/deit' in model_name else processor.size
    side = size_cfg['height']
  else:
    side = 299 if model_name == 'inception_v3' else 224
  return image.resize((side, side))

In [None]:
def load_images_from_folder(folder_path):
  """Return the paths of all entries in ``folder_path`` whose name ends in ``.png``.

  Only the folder itself is scanned (no recursion). The match is case-sensitive
  and results follow ``os.listdir`` order.
  """
  return [
      os.path.join(folder_path, entry)
      for entry in os.listdir(folder_path)
      if entry.endswith('.png')
  ]

In [None]:
def image_to_tensor(img_path, model_name, processor):
  """Load an image file and prepare it for Grad-CAM.

  Returns:
    ``(tensor, array, original)``:
      - ``tensor``: the resized image passed through the ``'test'`` transforms,
        with a leading batch dimension.
      - ``array``: the resized image as a float32 array scaled to ``[0, 1]``,
        used as the overlay background.
      - ``original``: the original RGB PIL image.
  """
  original = Image.open(img_path).convert('RGB')
  resized = resize_image(original, model_name, processor)
  test_transform = get_transforms(model_name, 'test')
  batch = test_transform(resized).unsqueeze(0)
  overlay = np.asarray(resized, dtype=np.float32) / 255.0
  return batch, overlay, original

In [None]:
def get_target_layer(model, model_name):
  """Return the Grad-CAM target layer for a supported architecture.

  Hugging Face ViT/DeiT models are reached through ``model.model``, i.e. they
  are expected wrapped in ``HuggingFaceWrapper``.

  Args:
    model: the model to explain.
    model_name: key from ``models_list``.

  Returns:
    A one-element list with the layer to hook, or ``None`` when ``model_name``
    is not supported.
  """
  layer_getters = (
      (('resnet34', 'resnet50', 'resnet101'), lambda m: m.layer4[-1]),
      (('vgg16', 'vgg19', 'densenet121', 'densenet169'), lambda m: m.features[-1]),
      (('inception_v3',), lambda m: m.Mixed_7c),
      (('google/vit-base-patch16-224',), lambda m: m.model.vit.encoder.layer[-1].layernorm_before),
      (('facebook/deit-base-distilled-patch16-224',), lambda m: m.model.deit.encoder.layer[-1].layernorm_before),
      (('davit_base.msft_in1k',), lambda m: m.stages[3].blocks[0][1].norm1),
      (('maxvit_t',), lambda m: m.blocks[-1].layers[-1].layers[-1].attn_layer[0]),
      (('gcvit_base.in1k',), lambda m: m.stages[-1].blocks[-1].norm2),
      (('swin_b',), lambda m: m.features[-1][-1].norm1),
  )

  for names, getter in layer_getters:
    if model_name in names:
      return [getter(model)]
  return None

In [None]:
def reshape_transform(tensor):
  """Turn transformer activations into a ``[B, C, H, W]`` map for Grad-CAM.

  * 4-D input ``[B, H, W, C]`` is permuted to ``[B, C, H, W]``. When ``H == 1``
    (the maxvit_t case), the singleton axis is squeezed and the result is
    handled as a token sequence.
  * 3-D input ``[B, seq_len, C]`` has its leading special tokens dropped, based
    on ``seq_len``. The remaining tokens are laid out on a square grid.

  Raises:
    ValueError: if ``seq_len`` is not a supported length, or the remaining
      token count is not a perfect square.
  """
  if tensor.dim() == 4:
    # tensor: [B, H, W, C]
    if tensor.shape[1] != 1:
      return tensor.permute(0, 3, 1, 2)
    tensor = tensor.squeeze(1)  # maxvit_t: [B, 1, N, C] -> [B, N, C]

  B, seq_len, hidden_dim = tensor.size()

  if seq_len == 197:    # standard ViT: CLS token + 196 (14x14) patch tokens
    tensor = tensor[:, 1:, :]
  elif seq_len == 198:  # DeiT distilled: CLS + distillation tokens
    tensor = tensor[:, 2:, :]
  elif seq_len == 49:   # davit and others: 7x7 tokens, no CLS prepended
    pass
  elif seq_len == 77:   # maxvit
    # NOTE(review): 77 - 1 = 76 tokens is not a perfect square, so this branch
    # always ends in the ValueError below — confirm the expected maxvit length.
    tensor = tensor[:, 1:, :]
  else:
    raise ValueError(f"Invalid sequence length: {seq_len}")

  h = w = int(tensor.shape[1] ** 0.5)
  if h * w != tensor.shape[1]:
    raise ValueError(f"Cannot reshape: {tensor.shape[1]} tokens is not a square")

  # [B, N, C] -> [B, C, N] -> [B, C, h, w]
  return tensor.permute(0, 2, 1).reshape(B, hidden_dim, h, w)

In [None]:
def reshape_transform_swinb(tensor, height=7, width=7):
  """Reshape ``[B, 1 + height*width, C]`` tokens into a ``[B, C, height, width]`` map.

  The first token is dropped. The rest are laid out row-major on a
  ``height x width`` grid, and channels are moved to dim 1 as in CNN feature maps.
  """
  batch, channels = tensor.size(0), tensor.size(2)
  grid = tensor[:, 1:, :].reshape(batch, height, width, channels)
  return grid.permute(0, 3, 1, 2)

In [None]:
def run_gradcam(model, model_attr, loss_function, target_layers, input_tensor, img_np, file_name):
  """Compute a Grad-CAM heat map for the model's predicted class on one image.

  The heat map is overlaid on ``img_np``, shown, saved as ``file_name`` and
  uploaded with ``save_in_google_drive`` to the model's ``shortname`` folder.

  Args:
    model: model to explain. Hugging Face models are expected to arrive wrapped
      in ``HuggingFaceWrapper``, so that ``model(x)`` returns logits.
    model_attr: ``models_list`` entry.
      - ``'shortname'`` selects the Drive folder and the Swin reshape.
      - ``'type' == 'vit'`` selects ``reshape_transform``.
    loss_function: ``'corn'`` uses ``CORNOutputTarget``; otherwise
      ``ClassifierOutputTarget``.
    target_layers: layers hooked by Grad-CAM (see ``get_target_layer``).
    input_tensor: batched image tensor; moved to ``Config.device``.
    img_np: float image in ``[0, 1]`` used as the overlay background.
    file_name: name of the PNG written to the working directory.
  """
  input_tensor = input_tensor.to(Config.device)
  input_tensor.requires_grad = True

  model.eval()
  model.zero_grad()

  # Predict first, without autograd and before Grad-CAM registers its hooks:
  # only the predicted class is needed here. HuggingFaceWrapper.forward already
  # returns `.logits`, so wrapped models need no special case.
  with torch.no_grad():
    outputs = model(input_tensor)

  pred_class = get_predictions(outputs, loss_function).item()
  print(f"Predicted class: {pred_class}")

  # NOTE(review): confirm 'swin_bo' is the intended shortname for the Swin branch.
  if model_attr.get('shortname') == 'swin_bo':
    transform = reshape_transform_swinb
  elif model_attr.get('type') == 'vit':
    transform = reshape_transform
  else:
    transform = None  # CNN activations are used as-is
  cam = GradCAM(model=model, target_layers=target_layers, reshape_transform=transform)

  if loss_function == 'corn':
    targets = [CORNOutputTarget(pred_class)]
  else:
    targets = [ClassifierOutputTarget(pred_class)]

  grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
  cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

  plt.imshow(cam_image)
  plt.axis('off')
  plt.title(f'Escala KL predita: {pred_class}')
  plt.savefig(file_name)

  plt.show()
  plt.close()

  save_in_google_drive(model_attr.get('shortname'), file_name)

In [None]:
def run_batch_gradcam(criteria_name, folder_path, model_name, model_attr, loss_function):
  """Run Grad-CAM on every ``.png`` image in ``folder_path`` for one model/loss pair.

  Steps:
    1. Load the checkpoint ``{shortname}_{loss}.pth`` from
       ``get_results_dir(nr_classes, shortname)``.
    2. Wrap Hugging Face models in ``HuggingFaceWrapper``.
    3. Resolve the target layer.
    4. Call ``run_gradcam`` per image, writing
       ``{criteria}_gradcam_loss_{loss}_{image name}``.

  Returns early, with a message, when the folder has no images or the model has
  no known target layer.
  """
  shortname = model_attr.get('shortname')
  source = model_attr.get('source')
  processor = get_model_processor(model_name, source)

  image_paths = load_images_from_folder(folder_path)
  if not image_paths:
    print(f"Nenhuma imagem encontrada na pasta: {folder_path}")
    return

  model = init_model(model_name, loss_function)
  checkpoint_path = os.path.join(get_results_dir(nr_classes, shortname), f"{shortname}_{loss_function}.pth")
  model = load_model(model, checkpoint_path)

  if source == 'hugging_face':
    model = HuggingFaceWrapper(model)

  print(model)

  target_layers = get_target_layer(model, model_name)
  print(target_layers)

  if target_layers is None:
    print(f"Nenhuma camada de destino encontrada para o modelo: {model_name}")
    return

  for img_path in image_paths:
    input_tensor, img_np, _ = image_to_tensor(img_path, model_name, processor)
    heatmap_name = f"{criteria_name}_gradcam_loss_{loss_function}_{os.path.basename(img_path)}"
    run_gradcam(model, model_attr, loss_function, target_layers, input_tensor, img_np, heatmap_name)

### Run Grad-CAM

In [None]:
def exec_grad_cam():
  """Run Grad-CAM for every non-skipped model, loss function and criteria folder.

  Image folders are resolved with ``get_grad_cam_dir(Config.num_classes, path)``
  for each entry of ``grad_cam_criteria``.
  """
  for model_name, model_attr in models_list.items():
    if model_attr.get('skip') == True:
      continue

    jobs = [
        (loss_function, criteria_name, criteria_path)
        for loss_function in loss_functions
        for criteria_name, criteria_path in grad_cam_criteria.items()
    ]
    for loss_function, criteria_name, criteria_path in jobs:
      run_batch_gradcam(criteria_name,
                        get_grad_cam_dir(Config.num_classes, criteria_path),
                        model_name, model_attr, loss_function)

In [None]:
# exec_grad_cam()

In [None]:
# gc2.collect()

# AUC-ROC curve

In [None]:
def corn_probs(logits: torch.Tensor) -> torch.Tensor:
  """Convert CORN logits into a per-class probability distribution.

  A CORN head outputs K-1 logits for K ordered classes.
  ``sigmoid(logits[:, k])`` is the *conditional* probability
  ``P(y > k | y > k-1)``. The unconditional exceedance probabilities are
  therefore the cumulative product ``P(y > k) = prod_{j<=k} sigmoid(logits[:, j])``,
  the same quantity ``coral_pytorch``'s ``corn_label_from_logits`` thresholds.
  From these:

      P(y = 0)   = 1 - P(y > 0)
      P(y = k)   = P(y > k-1) * (1 - sigmoid(logits[:, k]))   for 0 < k < K-1
      P(y = K-1) = P(y > K-2)

  Every row sums to 1.

  Args:
    logits: tensor of shape ``[B, K-1]``.

  Returns:
    Tensor of shape ``[B, K]`` with class probabilities.
  """
  cond_gt = torch.sigmoid(logits)            # P(y > k | y > k-1)
  prob_gt = torch.cumprod(cond_gt, dim=1)    # P(y > k)

  # P(y > k-1) for k = 0..K-2, using P(y > -1) = 1 for the first class.
  reach = torch.cat([torch.ones_like(prob_gt[:, :1]), prob_gt[:, :-1]], dim=1)

  # Classes 0..K-2 stop at their own threshold; class K-1 passes all of them.
  return torch.cat([reach * (1.0 - cond_gt), prob_gt[:, -1:]], dim=1)  # shape [B, K]

In [None]:
def plot_auc_roc(model, model_attr, test_loader, loss_function):
  """Compute and plot ROC curves (with AUC) for an image classifier on a test loader.

  Args:
    model: trained PyTorch model; moved to ``Config.device`` and put in eval mode.
    model_attr: ``models_list`` entry.
      - ``'source' == 'hugging_face'`` means outputs carry ``.logits``.
      - ``'shortname'`` names the output file and the Drive folder.
    test_loader: DataLoader yielding ``(inputs, labels)`` batches.
    loss_function: ``'corn'`` converts the K-1 CORN logits to class
      probabilities with ``corn_probs``; otherwise a softmax is applied.

  With ``Config.num_classes == 2`` a single binary curve is drawn from the
  positive-class score; otherwise there is one one-vs-rest curve per class.
  The figure is saved as ``{shortname}_auc_roc_{loss_function}.png`` and
  uploaded with ``save_in_google_drive``.
  """
  # label_binarize lives in sklearn.preprocessing, which the notebook's import
  # cell does not import; import it here so the multiclass branch cannot NameError.
  from sklearn.preprocessing import label_binarize

  model.eval()
  model.to(Config.device)

  score_batches = []
  label_batches = []

  with torch.no_grad():
    for inputs, labels in test_loader:
      outputs = model(inputs.to(Config.device))

      if model_attr['source'] == 'hugging_face':
        outputs = outputs.logits

      if loss_function == 'corn':
        scores = corn_probs(outputs)
      else:
        scores = F.softmax(outputs, dim=1)

      score_batches.append(scores.cpu().numpy())
      label_batches.append(labels.cpu().numpy())

  # [N, K] class scores and [N] integer labels
  y_scores = np.concatenate(score_batches, axis=0)
  y_true = np.concatenate(label_batches, axis=0)

  plt.figure(figsize=(8, 6))

  if Config.num_classes == 2:
    # Binary: use the positive-class score (or the single output column).
    y_scores = y_scores[:, 0] if y_scores.shape[1] == 1 else y_scores[:, 1]

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"(AUC = {roc_auc:.2f})")
  else:
    # Multiclass: one-vs-rest curve per class.
    assert y_scores.shape == (len(y_true), Config.num_classes), \
        f"Expected y_scores shape ({len(y_true)}, {Config.num_classes}), got {y_scores.shape}"

    y_true_bin = label_binarize(y_true, classes=np.arange(Config.num_classes))
    assert y_true_bin.shape == y_scores.shape, \
        f"Shape mismatch: y_true_bin {y_true_bin.shape}, y_scores {y_scores.shape}"

    for i in range(Config.num_classes):
      fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_scores[:, i])
      roc_auc = auc(fpr, tpr)
      plt.plot(fpr, tpr, lw=2, label=f"KL {i} (AUC = {roc_auc:.2f})")

  # Random classifier line
  plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Aleatório")

  plt.xlabel("FPR")
  plt.ylabel("TPR")
  plt.title("Curva AUC-ROC")
  plt.legend(loc="lower right")
  plt.grid(True)
  plt.tight_layout()

  file_name = f"{model_attr['shortname']}_auc_roc_{loss_function}.png"
  plt.savefig(file_name)

  plt.show()
  plt.close()

  save_in_google_drive(model_attr['shortname'], file_name)

In [None]:
def run_auc_roc():
  """Plot test-set ROC curves for every non-skipped model and every loss function.

  For each ``models_list`` entry without ``skip``, the dataloaders are built
  once. Then, for each loss in ``loss_functions``, the checkpoint
  ``{shortname}_{loss}.pth`` is loaded from ``get_results_dir(nr_classes, shortname)``
  and handed to ``plot_auc_roc``.
  """
  for model_name, model_attr in models_list.items():
    if model_attr.get('skip') == True:
      continue

    shortname = model_attr['shortname']
    dataloaders, _ = load_dataset(model_name, get_model_processor(model_name, model_attr['source']))

    for loss_function in loss_functions:
      fresh_model = init_model(model_name, loss_function)
      weights_path = os.path.join(get_results_dir(nr_classes, shortname), f"{shortname}_{loss_function}.pth")
      trained_model = load_model(fresh_model, weights_path)
      plot_auc_roc(trained_model, model_attr, dataloaders['test'], loss_function)

## Run AUC-ROC

In [None]:
# run_auc_roc()