# Dual-Head Oral Pathology Classifier

### `configs/config.py`

In [None]:
import os
import torch

# Base Paths (Updated for local usage or specific environment if needed)
# NOTE: Update BASE_PATH to match your actual data location
BASE_PATH = '/workspace'

# Dataset 1 Paths
# Dataset 1 is split into a benign folder and a malignant folder; the loader
# assigns binary labels from the folder and marks the subtype as -1 (unknown).
DS1_ORIGINAL_BENIGN = os.path.join(BASE_PATH, 'Dataset 1', 'original_data', 'benign_lesions')
DS1_ORIGINAL_MALIGNANT = os.path.join(BASE_PATH, 'Dataset 1', 'original_data', 'malignant_lesions')

# Dataset 2 Paths
# Each of these folders is expected to contain one sub-folder per entry of
# DS2_CLASSES.
# NOTE(review): the folder name 'Dataset 2 ' contains a trailing space --
# confirm this matches the directory name on disk.
DS2_TRAINING = os.path.join(BASE_PATH, 'Dataset 2 ', 'Training')
DS2_VALIDATION = os.path.join(BASE_PATH, 'Dataset 2 ', 'Validation')
DS2_TESTING = os.path.join(BASE_PATH, 'Dataset 2 ', 'Testing')

# Dataset configuration
# The position of a class in DS2_CLASSES is its integer subtype label.
DS2_CLASSES = ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']
# Subtypes that receive binary label 1 (malignant); all others receive 0.
MALIGNANT_SUBTYPES = ['MC', 'OC', 'CaS']
NUM_SUBTYPES = len(DS2_CLASSES)

# Model configuration
IMG_SIZE = 224  # side length (pixels) of the square model input
BATCH_SIZE = 128  # batch size for every DataLoader
NUM_WORKERS = 8  # DataLoader worker processes
# BACKBONE Options:
# 'resnet50'        - ResNet50 (Default)
# 'densenet121'     - DenseNet121
# 'convnext_tiny'   - ConvNeXt Tiny
# 'swin_t'          - Swin Transformer Tiny
# 'efficientnet_b0' - EfficientNet B0
# 'efficientnet_v2b2' - EfficientNet V2-B2
# 'efficientnet_v2b3' - EfficientNet V2-B3
# 'efficientnet_v2s'  - EfficientNet V2-S
BACKBONE = 'swin_t'
DROPOUT = 0.5  # dropout probability used after the backbone and inside both heads
USE_PRETRAINED = False  # forwarded to timm.create_model(pretrained=...)
# NOTE(review): DEVICE is not referenced by the other code visible in this
# file (get_device() is used instead) -- confirm whether it is still needed.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training configuration
LEARNING_RATE = 1e-4  # AdamW learning rate
WEIGHT_DECAY = 1e-4  # AdamW weight decay
NUM_EPOCHS = 200  # upper bound on epochs; also T_max of the cosine schedule
SEED = 42  # default seed for set_seed()

# Learning Rate Scheduler configuration
# SCHEDULER_TYPE is one of 'cosine', 'plateau', 'step', 'exponential';
# any other value disables the scheduler.
SCHEDULER_TYPE = 'cosine'  
SCHEDULER_PATIENCE = 5  # 'plateau' only: epochs without improvement before reducing the LR
SCHEDULER_FACTOR = 0.5  # LR multiplier for 'plateau'; also used as gamma for 'step'
SCHEDULER_STEP_SIZE = 10  # 'step' only: epochs between LR decays
SCHEDULER_GAMMA = 0.95  # 'exponential' only: per-epoch LR multiplier

# Early Stopping configuration
EARLY_STOPPING = True
EARLY_STOPPING_PATIENCE = 15  # consecutive non-improving epochs tolerated before stopping
EARLY_STOPPING_MIN_DELTA = 1e-4  # validation loss must drop by more than this to count as an improvement

# Paths for saving results
# These are derived from the default BACKBONE; the training and evaluation
# scripts recompute the directory from their --backbone argument.
SAVE_DIR = os.path.join(BASE_PATH, 'results', BACKBONE)
BEST_MODEL_PATH = os.path.join(SAVE_DIR, 'best_model.pth')
HISTORY_PLOT_PATH = os.path.join(SAVE_DIR, 'training_history.png')
CONFUSION_MATRIX_PATH = os.path.join(SAVE_DIR, 'confusion_matrices.png')

# Ensure save directory exists
# NOTE: this is a side effect that runs as soon as the configuration is executed.
os.makedirs(SAVE_DIR, exist_ok=True)


### `utils/common.py`

In [None]:
import os
import random
import numpy as np
import torch
from glob import glob

def set_seed(seed=SEED):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator (defaults to ``SEED``).
    """
    # Apply the same seed to each RNG source, in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN auto-tuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"Random seed set to {seed}")

def get_device():
    """Pick CUDA when available (otherwise CPU), report the choice, and return it.

    Returns:
        torch.device: ``cuda`` if a CUDA device is available, else ``cpu``.
    """
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')
    print(f"Using device: {device}")
    if has_cuda:
        # Only the first GPU's name is reported.
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    return device

def check_paths_exist(paths_list):
    """Print a found/missing marker for every path and report whether all exist.

    Args:
        paths_list: Iterable of filesystem paths to check.

    Returns:
        bool: True when every path exists (also True for an empty input).
    """
    print("\nChecking paths...")
    presence = []
    for candidate in paths_list:
        found = os.path.exists(candidate)
        marker = "✓" if found else "✗"
        print(f"{marker} {candidate}")
        presence.append(found)
    # Every path is printed before the verdict is computed, so no path is skipped.
    return all(presence)

def count_images_in_folder(folder):
    """Count image files directly inside ``folder`` (non-recursive).

    Files are matched by extension (jpg, jpeg, png, bmp, tif, tiff) in
    all-lowercase and all-uppercase spelling.

    Args:
        folder: Directory to scan.

    Returns:
        int: Number of distinct matching files; 0 when ``folder`` does not exist.
    """
    if not os.path.exists(folder):
        return 0
    extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif', '*.tiff']
    # Collect into a set: where glob matching is case-insensitive (e.g.
    # Windows) the lower- and upper-case patterns return the same files, and
    # summing the two result lengths would count every image twice.
    matches = set()
    for ext in extensions:
        matches.update(glob(os.path.join(folder, ext)))
        matches.update(glob(os.path.join(folder, ext.upper())))
    return len(matches)


### `utils/evaluation.py`

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
from PIL import Image

def evaluate_model(model, test_loader, device):
    """Run the model over a loader and collect predictions for both heads.

    Args:
        model: Module returning ``(binary_logits, subtype_logits)`` for a batch.
        test_loader: Iterable yielding ``(images, targets_binary, targets_subtype)``.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        dict: NumPy arrays under 'preds_binary', 'targets_binary',
        'preds_subtype' and 'targets_subtype'. Subtype entries only cover
        samples whose subtype target is not -1.
    """
    model.eval()
    binary_preds, binary_targets = [], []
    subtype_preds, subtype_targets = [], []

    with torch.no_grad():
        for images, targets_b, targets_s in tqdm(test_loader, desc="Testing"):
            logits_b, logits_s = model(images.to(device))

            binary_preds.extend(torch.argmax(logits_b, dim=1).cpu().numpy())
            binary_targets.extend(targets_b.numpy())

            # -1 marks samples without a subtype label; leave them out.
            labelled = targets_s != -1
            if labelled.any():
                subtype_preds.extend(torch.argmax(logits_s[labelled], dim=1).cpu().numpy())
                subtype_targets.extend(targets_s[labelled].numpy())

    collected = {
        'preds_binary': binary_preds,
        'targets_binary': binary_targets,
        'preds_subtype': subtype_preds,
        'targets_subtype': subtype_targets,
    }
    return {key: np.array(values) for key, values in collected.items()}

def plot_confusion_matrix(y_true, y_pred, classes, title, ax):
    """Draw an annotated confusion-matrix heatmap on ``ax``.

    Args:
        y_true: True integer class indices.
        y_pred: Predicted integer class indices.
        classes: Class names; ``classes[i]`` is the display name of label ``i``.
        title: Title for the axes.
        ax: Matplotlib axes to draw on.
    """
    # Pin the label set to 0..len(classes)-1. Without it, sklearn builds the
    # matrix only from labels that actually occur, so a class missing from
    # both y_true and y_pred would shrink the matrix and misalign the ticks.
    # Assumes labels are integer indices into `classes`, as produced by the
    # dataset loaders in this project.
    label_ids = list(range(len(classes)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)

def predict_single_image(model, image_path, device):
    """Classify one image file with both heads.

    Args:
        model: Module returning ``(binary_logits, subtype_logits)``.
        image_path: Path of the image to classify.
        device: Device on which inference runs.

    Returns:
        dict: ``{'binary': (name, prob), 'subtype': (name, prob)}`` where
        ``name`` is 'Malignant'/'Benign' or an entry of ``DS2_CLASSES`` and
        ``prob`` is the softmax probability of the predicted class.
    """
    model.eval()
    # The context manager closes the underlying file handle deterministically;
    # convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    tensor = val_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_b, pred_s = model(tensor)
        # Batch of one: take the single row of each probability matrix.
        prob_b = torch.softmax(pred_b, dim=1)[0]
        prob_s = torch.softmax(pred_s, dim=1)[0]

        idx_b = torch.argmax(prob_b).item()
        idx_s = torch.argmax(prob_s).item()

    return {
        'binary': ('Malignant' if idx_b == 1 else 'Benign', prob_b[idx_b].item()),
        'subtype': (DS2_CLASSES[idx_s], prob_s[idx_s].item())
    }


### `data/transforms.py`

In [None]:
import torchvision.transforms as transforms

# Channel statistics shared by both pipelines (the standard ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training images are first resized 32 px larger than the model input so that
# the random crop below has room to move.
_OVERSIZED = IMG_SIZE + 32

# Training pipeline: resize, random crop, flips, rotation, colour jitter,
# then tensor conversion and normalisation.
_train_steps = [
    transforms.Resize((_OVERSIZED, _OVERSIZED)),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation/test pipeline: deterministic resize, tensor conversion and
# normalisation only (no augmentation).
_eval_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
val_transform = transforms.Compose(_eval_steps)


### `data/dataset.py`

In [None]:
import os
from glob import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class OralPathologyDataset(Dataset):
    """Union dataset for dual-head multi-task learning.

    Each item is ``(image, binary_label, subtype_label)``. The three input
    sequences are index-aligned; a subtype label of -1 is used elsewhere in
    this project to mark samples without a subtype annotation.

    Args:
        image_paths: Sequence of image file paths.
        labels_binary: Sequence of binary labels, aligned with ``image_paths``.
        labels_subtype: Sequence of subtype labels, aligned with ``image_paths``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """
    def __init__(self, image_paths, labels_binary, labels_subtype, transform=None):
        self.image_paths = image_paths
        self.labels_binary = labels_binary
        self.labels_subtype = labels_subtype
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, binary_label, subtype_label)``."""
        # Close the file handle deterministically: with many DataLoader
        # workers, lazily-closed handles (e.g. multi-frame TIFFs) can pile up.
        # convert() returns an independent in-memory image.
        with Image.open(self.image_paths[idx]) as raw:
            image = raw.convert('RGB')
        # Compare against None explicitly so a transform object that happens
        # to be falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels_binary[idx], self.labels_subtype[idx]

def get_image_files(folder):
    """Return the image files directly inside ``folder``, sorted and de-duplicated.

    Files are matched by extension (jpg, jpeg, png, bmp, tif, tiff) in
    all-lowercase and all-uppercase spelling.

    Args:
        folder: Directory to scan (non-recursive).

    Returns:
        list[str]: Sorted list of matching paths; empty when ``folder`` does
        not exist.
    """
    if not os.path.exists(folder):
        return []
    extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif', '*.tiff']
    # A set removes the duplicates produced where glob matching is
    # case-insensitive (both patterns then return the same files).
    found = set()
    for ext in extensions:
        found.update(glob(os.path.join(folder, ext)))
        found.update(glob(os.path.join(folder, ext.upper())))
    # Sort so the result does not depend on filesystem enumeration order:
    # the seeded train/val/test splits are only reproducible if the input
    # order is identical on every run and every machine.
    return sorted(found)

def load_dataset1_split(split='train', test_size=0.10, val_size=0.10, random_state=42):
    """Return one partition of Dataset 1 as ``(paths, binary_labels, subtype_labels)``.

    All Dataset 1 images are gathered (benign -> 0, malignant -> 1, subtype
    always -1) and split into train/val/test with two stratified, seeded
    splits. The test set is completely held out and should ONLY be used for
    final evaluation.

    Args:
        split: 'train' or 'val'; any other value returns the test partition.
        test_size: Fraction of all images held out for testing.
        val_size: Fraction of all images used for validation.
        random_state: Seed for both splits.

    Returns:
        tuple[list, list, list]: paths, binary labels and subtype labels.

    Raises:
        ValueError: If no images are found, or the split cannot be computed.
    """
    # Sort (and de-duplicate) so the partition does not depend on filesystem
    # enumeration order. The training and evaluation scripts each recompute
    # the split, and a different order would move images between train and test.
    benign_paths = sorted(set(get_image_files(DS1_ORIGINAL_BENIGN)))
    malignant_paths = sorted(set(get_image_files(DS1_ORIGINAL_MALIGNANT)))
    all_paths = benign_paths + malignant_paths
    if not all_paths:
        raise ValueError(
            f"No Dataset 1 images found under {DS1_ORIGINAL_BENIGN!r} or "
            f"{DS1_ORIGINAL_MALIGNANT!r}; check BASE_PATH."
        )
    all_binary = [0] * len(benign_paths) + [1] * len(malignant_paths)
    all_subtype = [-1] * len(all_paths)

    # First split: separate test set (held out completely)
    temp_paths, test_paths, temp_bin, test_bin, temp_sub, test_sub = train_test_split(
        all_paths, all_binary, all_subtype, test_size=test_size, random_state=random_state, stratify=all_binary
    )

    # Second split: divide remaining into train and validation.
    # val_size is a fraction of the full dataset, so rescale it to the
    # fraction of what is left after removing the test set.
    val_size_adj = val_size / (1 - test_size)
    train_paths, val_paths, train_bin, val_bin, train_sub, val_sub = train_test_split(
        temp_paths, temp_bin, temp_sub, test_size=val_size_adj, random_state=random_state, stratify=temp_bin
    )

    if split == 'train':
        return train_paths, train_bin, train_sub
    if split == 'val':
        return val_paths, val_bin, val_sub
    # NOTE(review): every other value (including typos) yields the test set.
    return test_paths, test_bin, test_sub

def load_dataset2_split(split='train', test_size=0.20, val_size=0.20, random_state=42):
    """Return one partition of Dataset 2 as ``(paths, binary_labels, subtype_labels)``.

    MERGES the Training + Validation folders, then splits the merged data
    into train/val/test with two seeded splits stratified by subtype. The
    pre-made Testing folder is deliberately ignored (it is considered
    suspicious because of its identical distributions).

    Args:
        split: 'train' or 'val'; any other value returns the test partition.
        test_size: Fraction of the merged images held out for testing.
        val_size: Fraction of the merged images used for validation.
        random_state: Seed for both splits.

    Returns:
        tuple[list, list, list]: paths, binary labels (1 when the subtype is
        in ``MALIGNANT_SUBTYPES``) and subtype labels (index in ``DS2_CLASSES``).

    Raises:
        ValueError: If no images are found, or the split cannot be computed.
    """
    image_paths, labels_binary, labels_subtype = [], [], []

    # Merge Training + Validation folders (ignore the suspicious Testing folder)
    for base_path in [DS2_TRAINING, DS2_VALIDATION]:
        for idx, subtype in enumerate(DS2_CLASSES):
            subtype_path = os.path.join(base_path, subtype)
            # Sort (and de-duplicate) so the partition does not depend on
            # filesystem enumeration order. The training and evaluation
            # scripts each recompute the split, and a different order would
            # move images between train and test.
            imgs = sorted(set(get_image_files(subtype_path)))
            image_paths.extend(imgs)
            labels_subtype.extend([idx] * len(imgs))
            labels_binary.extend([1 if subtype in MALIGNANT_SUBTYPES else 0] * len(imgs))

    if not image_paths:
        raise ValueError(
            f"No Dataset 2 images found under {DS2_TRAINING!r} or "
            f"{DS2_VALIDATION!r}; check BASE_PATH."
        )

    # Now split this merged data properly
    # First split: separate test set (held out completely)
    temp_paths, test_paths, temp_bin, test_bin, temp_sub, test_sub = train_test_split(
        image_paths, labels_binary, labels_subtype,
        test_size=test_size, random_state=random_state, stratify=labels_subtype
    )

    # Second split: divide remaining into train and validation.
    # val_size is a fraction of the merged dataset, so rescale it to the
    # fraction of what is left after removing the test set.
    val_size_adj = val_size / (1 - test_size)
    train_paths, val_paths, train_bin, val_bin, train_sub, val_sub = train_test_split(
        temp_paths, temp_bin, temp_sub,
        test_size=val_size_adj, random_state=random_state, stratify=temp_sub
    )

    if split == 'train':
        return train_paths, train_bin, train_sub
    if split == 'val':
        return val_paths, val_bin, val_sub
    # NOTE(review): every other value (including typos) yields the test set.
    return test_paths, test_bin, test_sub


### `models/architecture.py`

In [None]:
import torch
import torch.nn as nn
import timm

# Lookup table: short backbone name accepted by this project -> model name
# passed to timm.create_model.
BACKBONE_MAP = dict(
    resnet50='resnet50',
    densenet121='densenet121',
    convnext_tiny='convnext_tiny',
    swin_t='swin_tiny_patch4_window7_224',
    efficientnet_b0='efficientnet_b0',
    efficientnet_v2b2='tf_efficientnetv2_b2',
    efficientnet_v2b3='tf_efficientnetv2_b3',
    efficientnet_v2s='tf_efficientnetv2_s',
)

class MultiTaskOralClassifier(nn.Module):
    """
    Dual-head multi-task classifier for oral pathology images.

    A single timm backbone (created with ``num_classes=0``) feeds, after
    dropout, two independent heads that run in parallel:
    a 2-way benign/malignant head and a ``num_subtypes``-way subtype head.

    Args:
        backbone: Key of ``BACKBONE_MAP``; falls back to ``BACKBONE`` when falsy.
        num_subtypes: Size of the subtype head; ``NUM_SUBTYPES`` when None.
        dropout: Dropout probability; ``DROPOUT`` when None.
        pretrained: Whether timm loads pretrained weights; ``USE_PRETRAINED``
            when None.

    Raises:
        ValueError: If the resolved backbone name is not in ``BACKBONE_MAP``.
    """

    @staticmethod
    def _build_head(in_features, out_features, dropout_p):
        """Build one head: Linear(in, 512) -> ReLU -> Dropout -> Linear(512, out)."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(512, out_features)
        )

    def __init__(self, backbone=None, num_subtypes=None, dropout=None, pretrained=None):
        super().__init__()

        # Explicit arguments win; otherwise fall back to the config values.
        self.backbone_name = backbone or BACKBONE
        self.num_subtypes = NUM_SUBTYPES if num_subtypes is None else num_subtypes
        self.dropout_val = DROPOUT if dropout is None else dropout
        self.use_pretrained = USE_PRETRAINED if pretrained is None else pretrained

        # Translate the short name into the timm model name.
        timm_model_name = BACKBONE_MAP.get(self.backbone_name)
        if timm_model_name is None:
            raise ValueError(f"Unsupported backbone: {self.backbone_name}. Supported: {list(BACKBONE_MAP.keys())}")

        print(f"Initializing {self.backbone_name} ({timm_model_name}) with pretrained={self.use_pretrained}")

        # num_classes=0 asks timm for the feature extractor without a classifier.
        self.backbone = timm.create_model(
            timm_model_name,
            pretrained=self.use_pretrained,
            num_classes=0
        )

        # Width of the backbone's output feature vector.
        num_features = self.backbone.num_features

        self.dropout_layer = nn.Dropout(p=self.dropout_val)

        # Head 1 (benign vs malignant) is built before head 2 (subtype).
        self.head_binary = self._build_head(num_features, 2, self.dropout_val)
        self.head_subtype = self._build_head(num_features, self.num_subtypes, self.dropout_val)

        print(f"Model initialized with {self.backbone_name} backbone (features={num_features})")

    def forward(self, x):
        """Return ``(binary_logits, subtype_logits)`` for the input batch ``x``."""
        shared = self.dropout_layer(self.backbone(x))
        return self.head_binary(shared), self.head_subtype(shared)

    def _set_backbone_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for weight in self.backbone.parameters():
            weight.requires_grad = trainable

    def freeze_backbone(self):
        """Stop gradient computation for all backbone parameters."""
        self._set_backbone_trainable(False)
        print("Backbone frozen.")

    def unfreeze_backbone(self):
        """Re-enable gradient computation for all backbone parameters."""
        self._set_backbone_trainable(True)
        print("Backbone unfrozen.")


### `models/loss.py`

In [None]:
import torch
import torch.nn as nn

class MultiTaskLoss(nn.Module):
    """
    Weighted sum of the two task losses.

    * Binary term: cross-entropy over every sample in the batch.
    * Subtype term: cross-entropy with ``ignore_index=-1``, so samples whose
      subtype target is -1 are masked out of that term.

    Args:
        weight_binary: Multiplier for the binary loss.
        weight_subtype: Multiplier for the subtype loss.
    """
    def __init__(self, weight_binary=1.0, weight_subtype=1.0):
        super().__init__()
        self.weight_binary = weight_binary
        self.weight_subtype = weight_subtype

        self.criterion_binary = nn.CrossEntropyLoss()
        self.criterion_subtype = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, pred_binary, pred_subtype, target_binary, target_subtype):
        """Return ``(total_loss, loss_binary, loss_subtype)`` for one batch."""
        loss_binary = self.criterion_binary(pred_binary, target_binary)
        loss_subtype = self.criterion_subtype(pred_subtype, target_subtype)

        # When every subtype target in the batch is ignored, the mean is
        # taken over zero elements and comes back as NaN; count it as zero
        # so the total stays finite.
        if torch.isnan(loss_subtype):
            loss_subtype = torch.tensor(0.0, device=pred_binary.device)

        weighted_binary = self.weight_binary * loss_binary
        weighted_subtype = self.weight_subtype * loss_subtype
        total_loss = weighted_binary + weighted_subtype

        return total_loss, loss_binary, loss_subtype


### `engine/trainer.py`

In [None]:
import torch
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import numpy as np

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module returning ``(binary_logits, subtype_logits)``.
        train_loader: Iterable yielding ``(images, targets_binary, targets_subtype)``.
        criterion: Callable returning ``(total_loss, binary_loss, subtype_loss)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch is moved to.

    Returns:
        tuple: ``(avg_loss, avg_binary_loss, avg_subtype_loss, binary_acc,
        subtype_acc)``. Losses are averaged over the number of batches;
        subtype accuracy covers only samples whose subtype target is not -1
        and is 0.0 when there are none.
    """
    model.train()
    sum_loss = 0.0
    sum_loss_b = 0.0
    sum_loss_s = 0.0
    binary_preds, binary_truth = [], []
    subtype_preds, subtype_truth = [], []

    progress = tqdm(train_loader, desc="Training", leave=False)
    for images, targets_b, targets_s in progress:
        images = images.to(device)
        targets_b = targets_b.to(device)
        targets_s = targets_s.to(device)

        optimizer.zero_grad()
        logits_b, logits_s = model(images)
        loss, loss_b, loss_s = criterion(logits_b, logits_s, targets_b, targets_s)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        sum_loss += batch_loss
        sum_loss_b += loss_b.item()
        # A NaN subtype loss contributes nothing to the running sum.
        if not torch.isnan(loss_s):
            sum_loss_s += loss_s.item()

        binary_preds.extend(torch.argmax(logits_b, dim=1).cpu().numpy())
        binary_truth.extend(targets_b.cpu().numpy())

        # Only samples with a real subtype label (not -1) are scored.
        labelled = targets_s != -1
        if labelled.any():
            subtype_preds.extend(torch.argmax(logits_s[labelled], dim=1).cpu().numpy())
            subtype_truth.extend(targets_s[labelled].cpu().numpy())

        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    num_batches = len(train_loader)
    avg_loss = sum_loss / num_batches
    avg_loss_b = sum_loss_b / num_batches
    avg_loss_s = sum_loss_s / num_batches
    acc_b = accuracy_score(binary_truth, binary_preds)
    if subtype_truth:
        acc_s = accuracy_score(subtype_truth, subtype_preds)
    else:
        acc_s = 0.0

    return avg_loss, avg_loss_b, avg_loss_s, acc_b, acc_s

def validate(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` without updating any weights.

    Args:
        model: Module returning ``(binary_logits, subtype_logits)``.
        val_loader: Iterable yielding ``(images, targets_binary, targets_subtype)``.
        criterion: Callable whose first return value is the total loss.
        device: Device every batch is moved to.

    Returns:
        tuple: ``(avg_loss, binary_acc, subtype_acc)``. The loss is averaged
        over the number of batches; subtype accuracy covers only samples
        whose subtype target is not -1 and is 0.0 when there are none.
    """
    model.eval()
    loss_sum = 0.0
    binary_preds, binary_truth = [], []
    subtype_preds, subtype_truth = [], []

    with torch.no_grad():
        for images, targets_b, targets_s in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            targets_b = targets_b.to(device)
            targets_s = targets_s.to(device)

            logits_b, logits_s = model(images)
            batch_loss = criterion(logits_b, logits_s, targets_b, targets_s)[0]
            loss_sum += batch_loss.item()

            binary_preds.extend(torch.argmax(logits_b, dim=1).cpu().numpy())
            binary_truth.extend(targets_b.cpu().numpy())

            # Only samples with a real subtype label (not -1) are scored.
            labelled = targets_s != -1
            if labelled.any():
                subtype_preds.extend(torch.argmax(logits_s[labelled], dim=1).cpu().numpy())
                subtype_truth.extend(targets_s[labelled].cpu().numpy())

    avg_loss = loss_sum / len(val_loader)
    acc_b = accuracy_score(binary_truth, binary_preds)
    if subtype_truth:
        acc_s = accuracy_score(subtype_truth, subtype_preds)
    else:
        acc_s = 0.0

    return avg_loss, acc_b, acc_s


### `train.py`

In [None]:
"""
Training script - Does NOT evaluate on test set.
Test set should only be evaluated once at the very end using evaluate_final.py
"""
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

def main():
    """Train the dual-head classifier on the train/val partitions only.

    Steps: seed the RNGs, parse ``--backbone``, build the combined
    Dataset 1 + Dataset 2 train/val loaders, create the model, loss,
    optimizer and scheduler, then run the epoch loop with best-model
    checkpointing (on validation loss) and optional early stopping.
    The test partition is never loaded here.

    Returns:
        dict: Per-epoch history with keys 'train_loss', 'val_loss',
        'val_acc_b', 'val_acc_s' and 'lr'.
    """
    import argparse
    import os

    set_seed()
    device = get_device()

    # Argument Parsing
    parser = argparse.ArgumentParser(description='Train Oral Pathology Model')
    parser.add_argument('--backbone', type=str, default=BACKBONE, help='Backbone model name')
    # parse_known_args() ignores arguments this parser does not define, such
    # as the ones a Jupyter kernel puts into sys.argv; parse_args() would
    # abort with SystemExit in that situation.
    args, _unknown = parser.parse_known_args()

    current_backbone = args.backbone
    print(f"Using Backbone: {current_backbone}")

    # Recalculate paths based on the chosen backbone.
    # BASE_PATH is read the same way as every other config value in this
    # function (BACKBONE, BATCH_SIZE, ...), instead of through a separate
    # `from configs import config` import.
    current_save_dir = os.path.join(BASE_PATH, 'results', current_backbone)
    current_best_model_path = os.path.join(current_save_dir, 'best_model.pth')

    os.makedirs(current_save_dir, exist_ok=True)
    print(f"Results will be saved to: {current_save_dir}")

    # 1. Load Datasets - ONLY TRAIN AND VAL (NO TEST!)
    print("Loading datasets...")
    print(" Test set is NOT loaded during training to prevent data leakage!")

    # Dataset 1 (Binary only)
    d1_train_p, d1_train_b, d1_train_s = load_dataset1_split('train')
    d1_val_p, d1_val_b, d1_val_s = load_dataset1_split('val')

    # Dataset 2 (Both labels) - now properly split from merged Training+Validation
    d2_train_p, d2_train_b, d2_train_s = load_dataset2_split('train')
    d2_val_p, d2_val_b, d2_val_s = load_dataset2_split('val')

    # Combine
    train_paths = d1_train_p + d2_train_p
    train_binary = d1_train_b + d2_train_b
    train_subtype = d1_train_s + d2_train_s

    val_paths = d1_val_p + d2_val_p
    val_binary = d1_val_b + d2_val_b
    val_subtype = d1_val_s + d2_val_s

    # Create Datasets
    train_ds = OralPathologyDataset(train_paths, train_binary, train_subtype, transform=train_transform)
    val_ds = OralPathologyDataset(val_paths, val_binary, val_subtype, transform=val_transform)

    # DataLoaders
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    print(f"Train images: {len(train_ds)}")
    print(f"Val images: {len(val_ds)}")
    print(f"Test images: Not loaded (use evaluate_final.py after training)")

    # 2. Model Setup
    model = MultiTaskOralClassifier(backbone=current_backbone).to(device)
    criterion = MultiTaskLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Setup scheduler based on configuration
    if SCHEDULER_TYPE == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)
        print(f"Scheduler: CosineAnnealingLR")
    elif SCHEDULER_TYPE == 'plateau':
        # `verbose=True` is deprecated in recent PyTorch releases; the current
        # learning rate is already printed at the start of every epoch below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=SCHEDULER_FACTOR,
                                                         patience=SCHEDULER_PATIENCE)
        print(f"Scheduler: ReduceLROnPlateau (patience={SCHEDULER_PATIENCE}, factor={SCHEDULER_FACTOR})")
    elif SCHEDULER_TYPE == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_FACTOR)
        print(f"Scheduler: StepLR (step_size={SCHEDULER_STEP_SIZE}, gamma={SCHEDULER_FACTOR})")
    elif SCHEDULER_TYPE == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=SCHEDULER_GAMMA)
        print(f"Scheduler: ExponentialLR (gamma={SCHEDULER_GAMMA})")
    else:
        scheduler = None
        print("No scheduler used")

    # 3. Training Loop
    best_loss = float('inf')
    epochs_no_improve = 0
    epochs_run = 0  # stays 0 if NUM_EPOCHS is 0, so the summary below never hits an unbound name
    history = {'train_loss': [], 'val_loss': [], 'val_acc_b': [], 'val_acc_s': [], 'lr': []}

    if EARLY_STOPPING:
        print(f"Early stopping enabled (patience={EARLY_STOPPING_PATIENCE}, min_delta={EARLY_STOPPING_MIN_DELTA})")

    for epoch in range(NUM_EPOCHS):
        epochs_run = epoch + 1
        # Learning rate in effect for this epoch (read before the scheduler steps).
        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} | LR: {current_lr:.2e}")

        t_loss, t_loss_b, t_loss_s, t_acc_b, t_acc_s = train_one_epoch(model, train_loader, criterion, optimizer, device)
        v_loss, v_acc_b, v_acc_s = validate(model, val_loader, criterion, device)

        # Step scheduler (ReduceLROnPlateau needs the monitored metric)
        if scheduler is not None:
            if SCHEDULER_TYPE == 'plateau':
                scheduler.step(v_loss)
            else:
                scheduler.step()

        # Track history
        history['train_loss'].append(t_loss)
        history['val_loss'].append(v_loss)
        history['val_acc_b'].append(v_acc_b)
        history['val_acc_s'].append(v_acc_s)
        history['lr'].append(current_lr)

        print(f"Train Loss: {t_loss:.4f} (B: {t_loss_b:.4f}, S: {t_loss_s:.4f}) | Acc B: {t_acc_b:.4f}, S: {t_acc_s:.4f}")
        print(f"Val Loss: {v_loss:.4f} | Acc B: {v_acc_b:.4f}, S: {v_acc_s:.4f}")

        # Check for improvement: the validation loss must beat the best one
        # by more than EARLY_STOPPING_MIN_DELTA to count (and to be saved).
        if v_loss < (best_loss - EARLY_STOPPING_MIN_DELTA):
            best_loss = v_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), current_best_model_path)
            print("✓ Best model saved")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s)")

        # Early stopping check
        if EARLY_STOPPING and epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            print(f"No improvement for {EARLY_STOPPING_PATIENCE} consecutive epochs")
            break

    print("\n" + "="*60)
    print("TRAINING COMPLETE")
    print("="*60)
    print(f"Total epochs: {epochs_run}/{NUM_EPOCHS}")
    print(f"Best validation loss: {best_loss:.4f}")
    print(f"Model saved to: {current_best_model_path}")
    if EARLY_STOPPING and epochs_no_improve >= EARLY_STOPPING_PATIENCE:
        print(f"Stopped early due to no improvement")
    print("\n  IMPORTANT: To evaluate on test set, run:")
    print("   python evaluate_final.py")
    print("\nThis ensures test set is only used ONCE for final evaluation.")

    # Hand the collected curves back to the caller instead of discarding them.
    return history

if __name__ == "__main__":
    main()


### `evaluate_final.py`

In [None]:
"""
FINAL TEST SET EVALUATION
 Run this ONLY ONCE after training is complete!
Multiple runs on the test set lead to overfitting through hyperparameter tuning.
"""
import torch
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
import os

def main():
    """Evaluate the best checkpoint of a backbone on the held-out test partitions.

    Steps: parse ``--backbone`` / ``--no-confirm``, optionally ask for
    confirmation, locate ``best_model.pth`` for the chosen backbone, build
    the combined Dataset 1 + Dataset 2 test loader, load the weights, compute
    weighted precision/recall/F1 and accuracy for both heads, print them and
    write them to ``evaluation_results.txt`` in the backbone's results folder.

    Returns early (without evaluating) when the user declines the prompt or
    the checkpoint file is missing.
    """
    import argparse
    import os

    print("\n" + "="*70)
    print(" FINAL TEST SET EVALUATION - USE ONLY ONCE!")
    print("="*70)
    print("This script evaluates on the held-out test set.")
    print("Running this multiple times compromises the validity of results.\n")

    # Argument Parsing
    parser = argparse.ArgumentParser(description='Evaluate Oral Pathology Model')
    parser.add_argument('--backbone', type=str, default=BACKBONE, help='Backbone model name')
    parser.add_argument('--no-confirm', action='store_true', help='Skip confirmation prompt')
    # parse_known_args() ignores arguments this parser does not define, such
    # as the ones a Jupyter kernel puts into sys.argv; parse_args() would
    # abort with SystemExit in that situation.
    args, _unknown = parser.parse_known_args()

    current_backbone = args.backbone
    print(f"Using Backbone: {current_backbone}")

    if not args.no_confirm:
        response = input("Continue with test evaluation? (yes/no): ")
        if response.lower() not in ['yes', 'y']:
            print("Evaluation cancelled.")
            return
    else:
        print("Skipping confirmation prompt (--no-confirm used).")

    # Recalculate paths based on the chosen backbone.
    # BASE_PATH is read the same way as every other config value in this
    # function (BACKBONE, BATCH_SIZE, ...), instead of through a separate
    # `from configs import config` import.
    current_save_dir = os.path.join(BASE_PATH, 'results', current_backbone)
    current_best_model_path = os.path.join(current_save_dir, 'best_model.pth')

    if not os.path.exists(current_best_model_path):
        print(f" Model file not found at {current_best_model_path}")
        print(f"   Make sure you have trained the {current_backbone} model first.")
        return

    set_seed()
    device = get_device()

    # Load test data ONLY
    print("\nLoading held-out test datasets...")
    d1_test_p, d1_test_b, d1_test_s = load_dataset1_split('test')
    d2_test_p, d2_test_b, d2_test_s = load_dataset2_split('test')  # Now from merged+split data

    test_ds_combined = OralPathologyDataset(
        d1_test_p + d2_test_p,
        d1_test_b + d2_test_b,
        d1_test_s + d2_test_s,
        transform=val_transform
    )

    test_loader = DataLoader(test_ds_combined, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=True)

    print(f"Test images: {len(test_ds_combined)}")
    print(f"  - Dataset 1: {len(d1_test_p)}")
    print(f"  - Dataset 2: {len(d2_test_p)}")

    # Load model
    print(f"\nLoading model from {current_best_model_path}...")
    model = MultiTaskOralClassifier(backbone=current_backbone).to(device)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only host.
    model.load_state_dict(torch.load(current_best_model_path, map_location=device))

    # Evaluate
    print("\nEvaluating on held-out test set...")
    results = evaluate_model(model, test_loader, device)

    # Print Results
    print("\n" + "="*60)
    print("BINARY CLASSIFICATION RESULTS (Benign vs Malignant)")
    print("="*60)
    acc_b = accuracy_score(results['targets_binary'], results['preds_binary'])
    prec_b = precision_score(results['targets_binary'], results['preds_binary'], average='weighted', zero_division=0)
    rec_b = recall_score(results['targets_binary'], results['preds_binary'], average='weighted', zero_division=0)
    f1_b = f1_score(results['targets_binary'], results['preds_binary'], average='weighted', zero_division=0)

    print(f"Accuracy:  {acc_b:.4f}")
    print(f"Precision: {prec_b:.4f}")
    print(f"Recall:    {rec_b:.4f}")
    print(f"F1-Score:  {f1_b:.4f}")

    print("\n" + "="*60)
    print("SUBTYPE CLASSIFICATION RESULTS")
    print("="*60)
    acc_s = accuracy_score(results['targets_subtype'], results['preds_subtype'])
    prec_s = precision_score(results['targets_subtype'], results['preds_subtype'], average='weighted', zero_division=0)
    rec_s = recall_score(results['targets_subtype'], results['preds_subtype'], average='weighted', zero_division=0)
    f1_s = f1_score(results['targets_subtype'], results['preds_subtype'], average='weighted', zero_division=0)

    print(f"Accuracy:  {acc_s:.4f}")
    print(f"Precision: {prec_s:.4f}")
    print(f"Recall:    {rec_s:.4f}")
    print(f"F1-Score:  {f1_s:.4f}")

    # Pin the label set to the subtype indices. With target_names but no
    # labels, classification_report raises when a class occurs in neither the
    # targets nor the predictions. The report is built once and reused for
    # both the console and the results file.
    subtype_label_ids = list(range(len(DS2_CLASSES)))
    subtype_report = classification_report(results['targets_subtype'], results['preds_subtype'],
                                           labels=subtype_label_ids,
                                           target_names=DS2_CLASSES, zero_division=0)

    print("\nPer-Class Report:")
    print(subtype_report)

    # Save results to file
    results_file = os.path.join(current_save_dir, 'evaluation_results.txt')
    with open(results_file, 'w') as f:
        f.write("BINARY CLASSIFICATION RESULTS (Benign vs Malignant)\n")
        f.write("="*60 + "\n")
        f.write(f"Accuracy:  {acc_b:.4f}\n")
        f.write(f"Precision: {prec_b:.4f}\n")
        f.write(f"Recall:    {rec_b:.4f}\n")
        f.write(f"F1-Score:  {f1_b:.4f}\n\n")

        f.write("SUBTYPE CLASSIFICATION RESULTS\n")
        f.write("="*60 + "\n")
        f.write(f"Accuracy:  {acc_s:.4f}\n")
        f.write(f"Precision: {prec_s:.4f}\n")
        f.write(f"Recall:    {rec_s:.4f}\n")
        f.write(f"F1-Score:  {f1_s:.4f}\n\n")
        f.write("Per-Class Report:\n")
        f.write(subtype_report)

    print(f"\n✓ Results saved to {results_file}")

if __name__ == "__main__":
    main()
