## Import

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2"

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch.nn.functional as F
import torch.nn as nn
from torch.cuda.amp import autocast 
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from torch.optim.lr_scheduler import _LRScheduler

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torchmetrics
from torchmetrics.classification import MultilabelAveragePrecision, MultilabelROC, MultilabelPrecisionRecallCurve, MultilabelConfusionMatrix, MultilabelHammingDistance
from torchmetrics.classification import MultilabelAccuracy, MultilabelPrecision, MultilabelRecall, MultilabelF1Score, ConfusionMatrix, Precision, MultilabelAUROC, MultilabelExactMatch
from torchmetrics.regression import MeanAbsoluteError, R2Score, MeanSquaredError, MeanAbsoluteError
import lightning as L
from lightning.pytorch.utilities.memory import garbage_collection_cuda

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

from sklearn.metrics import classification_report, ConfusionMatrixDisplay, multilabel_confusion_matrix
from sklearn.metrics import roc_curve, precision_recall_curve, roc_auc_score, auc, average_precision_score, roc_auc_score, hamming_loss
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn import metrics
from sklearn.utils.class_weight import compute_class_weight

from transformers import AutoConfig, AutoModel
from transformers import logging
logging.set_verbosity_error()

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
mpl.rcParams.update(mpl.rcParamsDefault)
import cv2

from sklearn.model_selection import StratifiedKFold

import math
import numpy as np
import pandas as pd
import random
import time
import itertools
import warnings
import ast

from progress_table import ProgressTable
import gc

warnings.filterwarnings('ignore') 
torch.cuda.empty_cache()

  check_for_updates()


## Set Seed

In [2]:
seed = 333

def set_seed(seed: int = 333) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

set_seed(seed)
g = torch.Generator()
g.manual_seed(seed)

Random seed set as 333


<torch._C.Generator at 0x7f987eeb7410>

## Parameters/Config

In [23]:
model_type = "microsoft/swin-base-patch4-window12-384-in22k"

id2label = {0: "Right Upper-0", 1: "Left Upper-1", 2: "Left Lower-2", 3: "Right Lower-3"}
label2id = {"Right Upper-0": 0, "Left Upper-1": 1, "Left Lower-2": 2, "Right Lower-3": 3}

# training and model parameters
patch_size = 16
num_labels = 4
num_heads = 4
learn_rate = 7e-5
steps_multipler = 1
pct_start = 0.05
num_folds = 10
max_epoch = 40
batch_size = 8
accumulation_steps = 4
drop = 0.2
weight_decay = 0.01

mmoe_params = {
    'input_size': 3072,
    'hidden_size': 768,
    'compressed_size': 384,
    'drop': 0.2,
    'num_experts': 7,
    'num_tasks': 4,
    'tower_hidden_size': 768,
    'num_heads': 2
}

## Load Original Dataset 

In [21]:
df = pd.read_csv("5505-Dataset.csv")
df['Labels'] = df['Labels'].apply(lambda x: ast.literal_eval(x))
df

Unnamed: 0,ImagePath,Right Upper-0,Left Upper-1,Left Lower-2,Right Lower-3,Labels
0,images/6b65f571-1001_6.jpg,0,1,1,1,"[0, 1, 1, 1]"
1,images/6f6d8ea6-1001_7.jpg,0,0,0,0,"[0, 0, 0, 0]"
2,images/33795f3b-1001_8.jpg,0,0,1,0,"[0, 0, 1, 0]"
3,images/eb160acb-1001_9.jpg,0,0,1,1,"[0, 0, 1, 1]"
4,images/78741275-1001_10.jpg,0,0,0,1,"[0, 0, 0, 1]"
...,...,...,...,...,...,...
5500,images/891dfade-1129_991.jpg,1,1,1,1,"[1, 1, 1, 1]"
5501,images/b4fd6b9c-1129_992.jpg,1,1,0,0,"[1, 1, 0, 0]"
5502,images/8dd1c1e0-1129_1001.jpg,0,0,1,0,"[0, 0, 1, 0]"
5503,images/1f21c608-1129_1002.jpg,1,1,1,1,"[1, 1, 1, 1]"


## Dataset Module

In [18]:
# Initialize dataset
class DentalTrainDM(Dataset):
    
    def __init__(self, df, transform):
        self.df = df
        self.ImagePath = self.df["ImagePath"]
        self.Labels = self.df["Labels"]
        self.transform = transform
        self.resize_transform = A.Compose([
            A.Resize(height = 256, width = 256),
            A.Normalize(normalization='min_max', mean=(0.5,), std=(0.5,), p = 1.0),
            ToTensorV2(p = 1.0)
        ])
        
    def __getitem__(self, index):
        image_path = self.ImagePath[index]
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # Load as BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   # Convert BGR to RGB

        full_pixel_values = self.transform(image=image)['image'] # Apply transformations
        
        # Split the image into left and right halves
        if len(full_pixel_values.shape) == 3:  # The transformed image has 3 dimensions (C, H, W)
            height, width = full_pixel_values.shape[1], full_pixel_values.shape[2]
            left_half = full_pixel_values[:, :, :width // 2]  # Left half
            right_half = full_pixel_values[:, :, width // 2:]  # Right half
        else:  # The transformed image is still 2D (H, W), no channel dimension
            height, width = full_pixel_values.shape
            left_half = full_pixel_values[:, :width // 2]  # Left half
            right_half = full_pixel_values[:, width // 2:]  # Right half

        Labels = torch.tensor(self.Labels[index], dtype=torch.long)
    
        return {
            "ImagePath": image_path,
            "ImagePixels": full_pixel_values,
            "LeftHalfPixels": left_half,
            "RightHalfPixels": right_half,
            "Labels": Labels
        }
            
    def __len__(self):
        return len(self.Labels)

In [19]:
class DentalValDM(Dataset):
    
    def __init__(self, df, fold, transform):
        self.df = df
        self.ImagePath = self.df["ImagePath"]
        self.Labels = self.df["Labels"]
        self.transform = transform
        self.fold = fold
        self.resize_transform = A.Compose([
            A.Resize(height = 256, width = 256),
            A.Normalize(normalization='min_max', mean=(0.5,), std=(0.5,), p = 1.0),
            ToTensorV2(p = 1.0)
        ])
        
    def __getitem__(self, index):
        fold_id = self.fold
        image_path = self.ImagePath[index]
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # Load as BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)    # Convert BGR to RGB

        # Apply transformations
        full_pixel_values = self.transform(image=image)['image']
        resized_full_pixel_values = self.resize_transform(image=image)['image']

        # Split the image into left and right halves
        if len(full_pixel_values.shape) == 3:  # The transformed image has 3 dimensions (C, H, W)
            height, width = full_pixel_values.shape[1], full_pixel_values.shape[2]
            left_half = full_pixel_values[:, :, :width // 2]  # Left half
            right_half = full_pixel_values[:, :, width // 2:]  # Right half
        else:  # The transformed image is still 2D (H, W), no channel dimension
            height, width = full_pixel_values.shape
            left_half = full_pixel_values[:, :width // 2]  # Left half
            right_half = full_pixel_values[:, width // 2:]  # Right half
        
        Labels = torch.tensor(self.Labels[index], dtype=torch.long)
        
        return {
            "Fold_ID": fold_id,
            "ImagePath": image_path,
            "ImagePixels": full_pixel_values,
            "LeftHalfPixels": left_half,
            "RightHalfPixels": right_half,
            "Labels": Labels
        }
            
    def __len__(self):
        return len(self.Labels)

## Transforms/Augmentation

In [12]:
# training augmentation
train_transform = A.Compose([
    A.Resize(height = 512, width = 1024, p = 1.0),
    A.ShiftScaleRotate(shift_limit =  (-0.02, 0.02), 
                       scale_limit = (-0, 0),  
                       rotate_limit = (2),  
                       interpolation = 1,  
                       border_mode = cv2.BORDER_CONSTANT,
                       # border_mode = 1,
                       value = [232, 232, 232],  
                       mask_value = 0, 
                       shift_limit_x = (-0, 0),  
                       shift_limit_y = (-0.02, 0.02), 
                       rotate_method = "largest_box", 
                       always_apply = False,  
                       p = 0.6),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p = 0.3),
    A.RandomBrightnessContrast(brightness_limit = (-0.1, 0.2), contrast_limit = (-0.1, 0.1), p = 0.3),
    A.Blur(blur_limit = 3, p = 0.3),
    A.Normalize(normalization='min_max', mean=(0.5,), std=(0.5,), p = 1.0),
    ToTensorV2(p = 1.0)
])

# validation augmentation
val_transform = A.Compose([
    A.Resize(height = 512, width = 1024),
    A.Normalize(normalization='min_max', mean=(0.5,), std=(0.5,), p = 1.0),
    ToTensorV2(p = 1.0)
])

## NN.Module

In [13]:
# 4 towers: Right Upper, Left Upper, Left Lower, Right Lower
class Right_Upper_0(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, output_size=1):
        super(Right_Upper_0, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads = num_heads, batch_first=True)
        self.Linear1 = nn.Linear(input_size, hidden_size)
        self.Linear2 = nn.Linear(hidden_size, hidden_size)
        self.Classifier = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(p=0.2)
    
    def forward(self, x):
        x = self.Linear1(x)
        x = self.dropout(x)
        x = self.Linear2(x)
        x = self.dropout(x)
        x = self.Classifier(x)  
        return {"logits": x}

class Left_Upper_1(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, output_size=1):
        super(Left_Upper_1, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)
        self.Linear1 = nn.Linear(input_size, hidden_size)
        self.Linear2 = nn.Linear(hidden_size, hidden_size)
        self.Classifier = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(p=0.2)
    
    def forward(self, x):
        x = self.Linear1(x)
        x = self.dropout(x)
        x = self.Linear2(x)
        x = self.dropout(x)
        x = self.Classifier(x)
        return {"logits": x}

class Left_Lower_2(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, output_size=1):
        super(Left_Lower_2, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)
        self.Linear1 = nn.Linear(input_size, hidden_size)
        self.Linear2 = nn.Linear(hidden_size, hidden_size)
        self.Classifier = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(p=0.2)
    
    def forward(self, x):
        x = self.Linear1(x)
        x = self.dropout(x)
        x = self.Linear2(x)
        x = self.dropout(x)
        x = self.Classifier(x)
        return {"logits": x}

class Right_Lower_3(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, output_size=1):
        super(Right_Lower_3, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)
        self.Linear1 = nn.Linear(input_size, hidden_size)
        self.Linear2 = nn.Linear(hidden_size, hidden_size)
        self.Classifier = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(p=0.2)
    
    def forward(self, x):
        x = self.Linear1(x)
        x = self.dropout(x)
        x = self.Linear2(x)
        x = self.dropout(x)
        x = self.Classifier(x)
        return {"logits": x}

In [14]:
# Expert Module
class Expert(nn.Module):
    def __init__(self, input_size, hidden_size, compressed_size, drop):
        super(Expert, self).__init__()
        self.compression = nn.Linear(input_size, compressed_size)
        self.expansion = nn.Linear(compressed_size, hidden_size)
        self.gelu = nn.GELU()
        self.final_layer = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p = drop)
    
    def forward(self, x):
        x = self.compression(x)
        x = self.expansion(x)
        x = self.gelu(x)
        x = self.final_layer(x)
        x = self.dropout(x)
        return x

# Gate Module
class Gate(nn.Module):
    def __init__(self, input_size, hidden_size, num_experts):
        super(Gate, self).__init__()
        self.Linear1 = nn.Linear(input_size, hidden_size)
        self.gelu = nn.GELU()
        self.Linear2 = nn.Linear(hidden_size, num_experts)
    
    def forward(self, x):
        x = self.Linear1(x)
        x = self.gelu(x)
        x = self.Linear2(x)
        x = F.softmax(x, dim = 1)
        return x

# MMoE Module
class MMoE(nn.Module):
    def __init__(self, input_size, hidden_size, compressed_size, drop, num_experts, num_tasks, tower_hidden_size, num_heads):
        super(MMoE, self).__init__()
        self.num_experts = num_experts
        
        # Experts
        self.experts = nn.ModuleList([
            Expert(input_size, hidden_size, compressed_size, drop) for _ in range(num_experts)
        ])
        
        # Gates
        self.gates = nn.ModuleList([
            Gate(input_size, hidden_size, num_experts) for _ in range(num_tasks)
        ])

        self.right_upper_tower = Right_Upper_0(hidden_size, tower_hidden_size, num_heads)
        self.left_upper_tower = Left_Upper_1(hidden_size, tower_hidden_size, num_heads)
        self.left_lower_tower = Left_Lower_2(hidden_size, tower_hidden_size, num_heads)
        self.right_lower_tower = Right_Lower_3(hidden_size, tower_hidden_size, num_heads)
        
        # Task-specific Towers
        self.towers = nn.ModuleList([
            self.right_upper_tower,  
            self.left_upper_tower,  
            self.left_lower_tower,
            self.right_lower_tower
        ])
    
    def forward(self, last_hidden_state):
        # Pass full sequence to Experts
        expert_outputs = torch.stack([expert(last_hidden_state) for expert in self.experts], dim=2)
      
        # Extract CLS token for Gates
        cls_token = last_hidden_state

        # Map tower names to their indices
        tower_names = ["Right-Upper", "Left-Upper", "Left-Lower", "Right-Lower"]
        
        task_outputs = {}
        for i, (gate, tower_name) in enumerate(zip(self.gates, tower_names)):

            # Gate Output
            gate_output = gate(cls_token)
            
            # Unsqueeze Gate Output for Weighted Sum Computation
            gate_output_unsqueeze = gate_output.unsqueeze(1)
            
            # Broadcast Gate Output
            gate_output_broadcasted = gate_output_unsqueeze.expand(-1, expert_outputs.size(1), -1)
            
            # Weighted Sum of Expert Outputs
            weighted_expert_output = torch.einsum(
                "bne,bne->bn", gate_output_broadcasted, expert_outputs
            )
            
            # Pass through the Task-Specific Tower
            tower_output = self.towers[i](weighted_expert_output)
            task_outputs[tower_name] = tower_output
    
        return task_outputs

In [15]:
class GEGLU(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features * 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return self.dropout(x * F.gelu(gate))

In [16]:
class DentalTransformer(torch.nn.Module):
    def __init__(self, model_type, num_class = 4, dropout = 0.2, layer_start = 4, hidden_size = 1024):
        super(DentalTransformer, self).__init__()

        self.model = AutoModel.from_pretrained(model_type, output_attentions = True, output_hidden_states = True, 
                                              id2label = id2label, label2id = label2id, image_size = (512, 1024))

        self.mmoe = MMoE(**mmoe_params) 

        self.classifier = nn.Sequential(
            GEGLU(3072, 1536),  # GEGLU layer instead of Linear(3072, 1536) + GELU
            nn.Dropout(p=dropout),
            GEGLU(1536, 768),   # GEGLU layer instead of Linear(1536, 768) + GELU
            nn.Dropout(p=dropout),
            nn.Linear(768, num_class)  # Output layer remains the same
        )
        
    def forward(self, full_pixels, left_pixels, right_pixels, interpolate_pos_encoding=True):
        
        main_re_hidden = self.model(full_pixels)["reshaped_hidden_states"][4]
        left_re_hidden = self.model(left_pixels)["reshaped_hidden_states"][4]
        right_re_hidden = self.model(right_pixels)["reshaped_hidden_states"][4]

        full_pool = main_re_hidden.mean(dim=(-2, -1))   # Shape: [B, 1024]
        left_pool = left_re_hidden.mean(dim=(-2, -1))     # Shape: [B, 1024]
        right_pool = right_re_hidden.mean(dim=(-2, -1))   # Shape: [B, 1024]

        concat_pool = torch.cat((left_pool, full_pool, right_pool), dim=1)
        # print(concat_pool.shape)

        mmoe_out = self.mmoe(concat_pool)
        # print(mmoe_out)
        
        # outputs = self.classifier(concat_pool)  
        
        return mmoe_out

## Loss Function

In [25]:
def bce_loss(outputs, targets):
    targets = targets.float()
    loss = nn.BCEWithLogitsLoss()(outputs, targets)
    return loss

## Training & Validation Function

In [26]:
table = ProgressTable(num_decimal_places=4, interactive=1)
table.add_columns("Fd", "Ep")
table.add_columns("T Loss")
table.add_columns("T Acc", "T F1", "T Ham")  # Classification metrics
table.add_columns("T Steps")
table.add_columns("Time/Ep")
table.add_columns("V Loss")
table.add_columns("V SubAcc", "V Pre", "V Rec", "V F1", "V Ham", "V MacroAcc")          # Validation classification metrics
table.add_columns("V Steps")
table.add_columns("BestEpoch")
table.add_columns("Learning Rate")

def train_model(start_epoch, n_epochs, valid_acc_max_input, training_loader, validation_loader, model, 
                optimizer, scheduler = None, checkpoint_path = None, best_model_path = None, run_id = None, continue_train = False, 
                model_state_dict = None, optimizer_state_dict = None, accumulation_steps = 2):
    
    fold = run_id+1

    if continue_train == True:
        optimizer.load_state_dict(optimizer_state_dict)
        model.load_state_dict(model_state_dict)
    
    # initialize tracker for minimum validation loss
    valid_acc_max = valid_acc_max_input 
    
    # initialize metrics
    t_sub_acc = MultilabelExactMatch(num_labels = 4).to(device)
    t_ham_mi = MultilabelHammingDistance(num_labels = 4 , average='micro').to(device)

    v_sub_acc = MultilabelExactMatch(num_labels = 4).to(device)
    v_ham_mi = MultilabelHammingDistance(num_labels = 4 , average='micro').to(device)
    v_macro_acc = MultilabelAccuracy(num_labels = 4, average = "macro").to(device)

    for epoch in range(start_epoch, n_epochs+1):

        # Reset metrics
        t_sub_acc.reset()
        t_ham_mi.reset()

        v_sub_acc.reset()
        v_ham_mi.reset()
        v_macro_acc.reset()
        
        t_epoch_f1 = []

        v_epoch_f1 = []
        v_epoch_precision = []
        v_epoch_recall = []
        
        table.update("Ep", f"{epoch}/{n_epochs}")
        table.update("Fd", fold)
        e_t0 = time.time()
        
        # step counter
        t_steps = 0
        v_steps = 0
        
        t_f1_cumulative = 0
        v_f1_cumulative = 0
        v_precision_cumulative = 0
        v_recall_cumulative = 0
        
        # train
        model.train()
        for i, data in enumerate(training_loader, 0):
            
            t_full_pixels = data["ImagePixels"].to(device, non_blocking=True)
            t_left_pixels = data["LeftHalfPixels"].to(device, non_blocking=True)
            t_right_pixels = data["RightHalfPixels"].to(device, non_blocking=True)
            t_class_labels = data["Labels"].to(device, dtype = torch.long, non_blocking=True)

            if i % accumulation_steps == 0:
                model.zero_grad(set_to_none=True)

                # training step
                t_steps += 1
                table.update("T Steps", f"{t_steps}/{int((len(training_loader) / accumulation_steps))}") # f"{t_progress:.2%}"   
                
            with torch.autocast("cuda"):

                outputs = model(full_pixels = t_full_pixels, left_pixels = t_left_pixels, right_pixels = t_right_pixels)
                t_class_logits = torch.cat((outputs["Right-Upper"]["logits"], outputs["Left-Upper"]["logits"], outputs["Left-Lower"]["logits"], outputs["Right-Lower"]["logits"]), 1)
                t_class_preds = torch.sigmoid(t_class_logits)

                t_class_loss = bce_loss(t_class_logits, t_class_labels)                 
                table.update("T Loss", f"{t_class_loss.item():.4f}", aggregate = "mean")
            
            # torch metrics
            t_sub_acc.update(t_class_preds, t_class_labels) 
            t_acc = t_sub_acc.compute()
            table.update("T Acc", f"{t_acc.item():.4f}", aggregate = "mean")

            t_ham_mi.update(t_class_preds, t_class_labels)
            t_ham = t_ham_mi.compute()
            table.update("T Ham", f"{t_ham.item():.4f}", aggregate = "mean")

            # sklearn metrics
            t_preds_np = (t_class_preds.detach().cpu().numpy() >= 0.5).astype(int)
            t_labels_np = t_class_labels.detach().cpu().numpy()
                
            t_f1 = f1_score(t_labels_np, t_preds_np, average="samples")
            t_f1_cumulative += t_f1
            t_f1_running_avg = t_f1_cumulative / (i + 1)
            table.update("T F1", f"{t_f1_running_avg:.4f}", aggregate = "mean")
            t_epoch_f1.append(t_f1)
            
            # backprop
            t_class_loss = t_class_loss / accumulation_steps
            t_class_loss.backward() 
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if i % accumulation_steps == accumulation_steps - 1 or (i + 1) == len(training_loader):
                optimizer.step()
                if scheduler is not None:
                    if scheduler.last_epoch < scheduler.total_steps:  # Prevent stepping beyond total_steps
                        scheduler.step()
                    
                opt_lr = optimizer.state_dict()["param_groups"][0]["lr"]
                table.update("Learning Rate", f"{opt_lr:.8f}")
        
        # epoch time
        epoch_time = time.time() - e_t0
        table.update("Time/Ep",f"{epoch_time:.2f}")
        
        # validation
        model.eval()
        with torch.no_grad():
            
            for i, data in enumerate(validation_loader, 0):
                    
                v_full_pixels = data["ImagePixels"].to(device, non_blocking=True)
                v_left_pixels = data["LeftHalfPixels"].to(device, non_blocking=True)
                v_right_pixels = data["RightHalfPixels"].to(device, non_blocking=True)
                v_class_labels = data["Labels"].to(device, dtype = torch.long, non_blocking=True)      
                    
                with torch.autocast("cuda"):
    
                    outputs = model(full_pixels = v_full_pixels, left_pixels = v_left_pixels, right_pixels = v_right_pixels)
                    v_class_logits = torch.cat((outputs["Right-Upper"]["logits"], outputs["Left-Upper"]["logits"], outputs["Left-Lower"]["logits"], outputs["Right-Lower"]["logits"]), 1)
                    v_class_preds = torch.sigmoid(v_class_logits)
    
                    v_class_loss = bce_loss(v_class_logits, v_class_labels)
                    table.update("V Loss", f"{v_class_loss.item():.4f}", aggregate = "mean")
                    
                # validation loss/step 
                v_steps += 1   
                table.update("V Steps", f"{v_steps}/{(len(validation_loader))}") # f"{v_progress:.2%}"
                                  
                # torch metrics
                v_sub_acc.update(v_class_preds, v_class_labels) 
                v_acc = v_sub_acc.compute()
                table.update("V SubAcc", f"{v_acc:.4f}", aggregate = "mean")

                v_ham_mi.update(v_class_preds, v_class_labels)
                v_ham = v_ham_mi.compute()
                table.update("V Ham", f"{v_ham.item():.4f}", aggregate = "mean")

                v_macro_acc.update(v_class_preds, v_class_labels)
                v_ma_acc = v_macro_acc.compute()
                table.update("V MacroAcc", f"{v_ma_acc.item():.4f}", aggregate = "mean")

                # sklearn metrics
                v_preds_np = (v_class_preds.detach().cpu().numpy() >= 0.5).astype(int)
                v_labels_np = v_class_labels.detach().cpu().numpy()
                
                v_f1 = f1_score(v_labels_np, v_preds_np, average="samples")
                v_f1_cumulative += v_f1
                v_f1_running_avg = v_f1_cumulative / (i + 1)
                table.update("V F1", f"{v_f1_running_avg:.4f}", aggregate = "mean")
                v_epoch_f1.append(v_f1)
                
                v_precision = precision_score(v_labels_np, v_preds_np, average="samples")
                v_precision_cumulative += v_precision
                v_precision_running_avg = v_precision_cumulative / (i + 1)
                table.update("V Pre", f"{v_precision_running_avg:4f}", aggregate = "mean")
                v_epoch_precision.append(v_precision)
                        
                v_recall = recall_score(v_labels_np, v_preds_np, average="samples")
                v_recall_cumulative += v_recall
                v_recall_running_avg = v_recall_cumulative / (i + 1)
                table.update("V Rec", v_recall_running_avg, aggregate = "mean")
                v_epoch_recall.append(v_recall)                
                
        # save the model if validation f1 is higer than max
        if valid_acc_max <= v_epoch_ma_acc.item():
            torch.save(model.state_dict(), f'./Saved/f{fold}_best.pth')
            best_epoch = epoch
            valid_acc_max = v_epoch_ma_acc.item()
        table.update("BestEpoch", best_epoch)
        
        table.next_row()

    del model
    del training_loader
    del validation_loader
    collect_garbage()
        
    return best_epoch

## Load/Save checkpoint function

In [27]:
# load checkpoint
def load_checkpoint(checkpoint_path, model, optimizer = None):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint)
    return model

# save checkpoint
def save_checkpoint(state, is_best, checkpoint_path, best_model_path):
    f_path = checkpoint_path
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)

## Prediction Function

In [28]:
def prediction(best_model, validation_loader):
    
    best_model.eval()

    imagepath_list = []
    
    preds_list = []
    targets_list = []
    probs_list = []
    
    fold_list = []
    
    softmax = torch.nn.Softmax(dim = 1)
    
    steps = 0
    
    with torch.no_grad():
        for batch_idx, data in enumerate(validation_loader, 0):
            steps += 1

            fold = data["Fold_ID"]
            imagepath = data["ImagePath"]
            full_pixels = data["ImagePixels"].to(device)
            left_pixels = data["LeftHalfPixels"].to(device, non_blocking=True)
            right_pixels = data["RightHalfPixels"].to(device, non_blocking=True)
            labels = data["Labels"].to(device, dtype = torch.long)      
            
            outputs = best_model(full_pixels = full_pixels, left_pixels = left_pixels, right_pixels = right_pixels)
            class_logits = torch.cat((outputs["Right-Upper"]["logits"], outputs["Left-Upper"]["logits"], outputs["Left-Lower"]["logits"], outputs["Right-Lower"]["logits"]), 1)
            sigmoid_preds = torch.sigmoid(class_logits)

            # Append classification outputs
            imagepath_list.extend(imagepath)
            preds_list.extend(sigmoid_preds)
            targets_list.extend(labels)
            fold_list.extend(fold)
            
    return preds_list, targets_list, imagepath_list, fold_list

## Compute Standard Deviation Function

In [29]:
def compute_std(metrics_df, max_epoch):
    last_epoch_metrics = metrics_df.loc[(metrics_df['Epoch'] == f"{max_epoch}/{max_epoch}")]
    print("Validation Metrics Standard Deviation")
    print(f"Accuracy:  {np.std(last_epoch_metrics['V Acc'].to_list()):.4f}")
    print(f"Precision: {np.std(last_epoch_metrics['V Precision'].to_list()):.4f}")
    print(f"Recall:    {np.std(last_epoch_metrics['V Recall'].to_list()):.4f}")
    print(f"F1 Score:  {np.std(last_epoch_metrics['V F1'].to_list()):.4f}")

In [31]:
# clear memory
def collect_garbage():
    garbage_collection_cuda()
    time.sleep(3)
    torch.cuda.empty_cache()
    garbage_collection_cuda()
    gc.collect()

## Training

In [None]:
# Gather all predictions/targets for each fold
all_preds_list = []
all_targets_list = []
all_probs_list = []
all_path_list = []
all_fold_list = []
    
skfold = MultilabelStratifiedKFold(n_splits = num_folds, random_state = seed, shuffle = True)
y = np.array(df["Labels"].tolist())
for fold, (train_ids, val_ids) in enumerate(skfold.split(df, y)):
    
    train_dataset = df.loc[train_ids].copy().reset_index(drop = True)
    val_dataset = df.loc[val_ids].copy().reset_index(drop = True)
    
    train_datamodule = DentalTrainDM(train_dataset, transform = train_transform)
    val_datamodule = DentalValDM(val_dataset, fold = fold, transform = val_transform)
    
    
    training_loader = torch.utils.data.DataLoader(train_datamodule, 
                                                  batch_size = batch_size, 
                                                  num_workers = 16,
                                                  prefetch_factor = 2,
                                                  pin_memory = True,
                                                  shuffle = True)
    
    validation_loader = torch.utils.data.DataLoader(val_datamodule,
                                                    batch_size = batch_size, 
                                                    num_workers = 16,
                                                    prefetch_factor = 2,
                                                    pin_memory = True,
                                                    shuffle = True
                                                   )
    
    # initialize model
    model = DentalTransformer(model_type, hidden_size = 768)
    model = nn.DataParallel(model, device_ids=[0, 1, 2])
    model = torch.compile(model)
    model = model.to(device)
    
    # initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = learn_rate, weight_decay = weight_decay)

    batches_per_epoch = len(training_loader) 
    total_steps = int((batches_per_epoch / accumulation_steps) * max_epoch * steps_multipler)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = learn_rate, total_steps = total_steps, pct_start=pct_start)

    # training
    best_epoch = train_model(start_epoch = 1, n_epochs = max_epoch, valid_acc_max_input = 0,           
                             training_loader = training_loader, validation_loader = validation_loader, 
                             model = model, optimizer = optimizer, scheduler = scheduler, run_id = fold, 
                             accumulation_steps = accumulation_steps)    
    
    # load maxepoch/best model
    best_model_path = f'./Saved/f{fold+1}_best.pth'
    best_model = load_checkpoint(best_model_path, model)
    preds, targets, imagepath, run_id = prediction(best_model, validation_loader)

    # collect all predictions and targets over all folds
    all_preds_list.extend(preds)
    all_targets_list.extend(targets)
    all_path_list.extend(imagepath)
    all_fold_list.extend(run_id)
    
    # compute single fold validation performance
    targets_stacked = torch.stack(targets)  # Shape: [num_samples, num_labels]
    preds_stacked = torch.stack(preds)  # Shape: [num_samples, num_labels]
    
    # Apply sigmoid and threshold to predictions
    preds_probs = torch.sigmoid(preds_stacked)
    preds_binary = (preds_probs >= 0.5).int()
    
    # Convert to NumPy
    targets_np = targets_stacked.detach().cpu().numpy()
    preds_np = preds_binary.detach().cpu().numpy()

    performance = classification_report(targets_np, preds_np, 
                                        target_names = target_names, digits = 6, output_dict = True)
    performance_report = pd.DataFrame(performance).transpose()
    performance_report.to_csv(f'./Saved/BestPerformance_{version}_{architecture}_f{fold+1}.csv', index=True)

    del model
    del best_model
    del training_loader
    del validation_loader
    collect_garbage()

╭──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬──────────┬────────────┬──────────┬───────────┬───────────────╮
│    Fd    │    Ep    │  T Loss  │  T Acc   │   T F1   │  T Ham   │ T Steps  │ Time/Ep  │  V Loss  │ V SubAcc │  V Pre   │  V Rec   │   V F1   │  V Ham   │ V MacroAcc │ V Steps  │ BestEpoch │ Learning Rate │
├──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼──────────┼────────────┼──────────┼───────────┼───────────────┤
│    1     │   1/40   │          │          │          │          │  1/155   │          │          │          │          │          │          │          │            │          │           │               │

W0408 22:05:28.858000 2607602 site-packages/torch/_logging/_internal.py:1081] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


















































































│    1     │   1/40   │  0.9095  │  0.3876  │  0.1205  │  0.3610  │ 156/155  │  603.14  │  0.2094  │  0.3575  │ 0.422727 │  0.6302  │  0.4827  │  0.3689  │   0.6311   │  69/69   │     1     │   0.00003683  │
│    1     │   2/40   │  0.1865  │  0.6325  │  0.4555  │  0.1406  │ 156/155  │  515.21  │  0.6793  │  0.7514  │ 0.580616 │  0.6142  │  0.5881  │  0.0799  │   0.9201   │  69/69   │     2     │   0.00007000  │
│    1     │   3/40   │  0.0098  │  0.7920  │  0.5595  │  0.0698  │ 156/155  │  518.09  │  0.0351  │  0.8094  │ 0.599832 │  0.6136  │  0.6001  │  0.0631  │   0.9369   │  69/69   │     3     │   0.00006988  │
│    1     │   4/40   │  0.0261  │  0.8545  │  0.5865  │  0.0451  │ 156/155  │  506.77  │  0.0101  │  0.8457  │ 0.608825 │  0.5877  │  0.5926  │  0.0499  │   0.9501   │  69/69   │     4     │   0.00006951  │
│    1     │   5/40   │  0.0015  │  0.8885  │  0.5988  │  0.0346  │ 156/155  │  504.04  │  0.4101  │  0.8802  │ 0.626812 │  0.6209  │  0.6195  │  0.0340  │   0.9660   │

In [None]:
table.close()
run.stop()

In [None]:
metrics_df = table.to_df()
metrics_df.to_csv(f'Saved/{version}_{max_epoch}e_{fold+1}f_metrics.csv', index=False)
metrics_df

In [1]:
pred_tar_df =  pd.DataFrame(list(zip(np.array(all_fold_list), np.array(all_path_list), 
                                        torch.stack(all_preds_list).detach().cpu().numpy(), 
                                        torch.stack(all_targets_list).detach().cpu().numpy()
                                       )
                                   ),
                               columns =["Fold", "ImagePaths", "Predictions", "Targets"])

pred_tar_df.to_csv(f'Saved/preds_targets_probs_{max_epoch}e_{fold+1}f.csv', index=False)
pred_tar_df

NameError: name 'pd' is not defined