## ⚠️ Data Leakage Prevention

**Changes Made to Prevent Data Leakage:**

1. **❌ No Augmented Data Used**
   - Previously: original and augmented images were mixed and then split, so augmented copies of the same image could land in both train and test
   - Now: only the original Dataset 1 images are used

2. **✅ Proper Dataset 1 Splitting**
   - Previously: all of DS1 was added to training, and the same data was then split again for testing
   - Now: DS1 is split into train/val/test (70%/15%/15%) before it is used anywhere

3. **✅ Separate Validation Sets**
   - Previously: only DS2 had a validation set; DS1 had none
   - Now: both DS1 and DS2 have validation sets

4. **✅ No Overlapping Data**
   - Train, validation, and test sets are disjoint
   - Each image appears in exactly one split (for DS2, this relies on the assumption in note 5)

5. **⚠️ Dataset 2 Note**
   - Dataset 2 uses pre-split folders (Training/Validation/Testing)
   - Assumes the folder structure has no duplicate images across splits
   - If uncertain, verify no overlapping images between folders
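
The check suggested in note 5 can be done by hashing file contents and looking for hashes that occur in more than one split folder. This is a minimal sketch, assuming that exact byte-level copies are the concern; resized or re-encoded copies would need a perceptual hash instead. `file_sha256` and `find_cross_split_duplicates` are hypothetical helpers, not part of the notebook.

```python
import hashlib
import os
from collections import defaultdict


def file_sha256(path, chunk_size=1 << 20):
    """Hash a file's bytes in chunks so large images need not fit in memory."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()


def find_cross_split_duplicates(split_dirs):
    """Return {hash: [(split, path), ...]} for content found in more than one split.

    split_dirs maps a split name to its root folder; class subfolders are searched recursively.
    """
    seen = defaultdict(list)
    for split, root in split_dirs.items():
        for dirpath, _, filenames in os.walk(root):
            for name in filenames:
                path = os.path.join(dirpath, name)
                seen[file_sha256(path)].append((split, path))
    # Duplicates inside a single split are not leakage, so require two different splits
    return {
        digest: entries for digest, entries in seen.items()
        if len({split for split, _ in entries}) > 1
    }

# In the notebook:
# dupes = find_cross_split_duplicates({'Training': DS2_TRAINING,
#                                      'Validation': DS2_VALIDATION,
#                                      'Testing': DS2_TESTING})
# print(f"{len(dupes)} images appear in more than one DS2 split")
```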

---

# Dual-Head Oral Pathology Classifier
## Flat Multi-Task Learning (Joint Learning)

This notebook implements a computer vision system to analyze oral lesions using two distinct datasets.

**Architecture:** Shared Backbone with Two Independent Parallel Heads

- **Head 1 (Binary):** Is it Malignant or Benign?
- **Head 2 (Multi-Class):** What is the specific subtype?

**Datasets:**
- **Dataset 1 (DS1):** Labeled as Malignant or Benign only
- **Dataset 2 (DS2):** Labeled with specific pathology types (MC, OC, CaS, CoS, etc.)

## 1. Environment Setup & Installations

In [1]:
# Install required packages for Google Colab
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install timm
!pip install albumentations
!pip install scikit-learn
!pip install matplotlib seaborn
!pip install tqdm
!pip install pandas


In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Import all required libraries
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms as transforms
import torchvision.models as models

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
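
`set_seed` seeds the main process. PyTorch already derives a distinct torch seed for each DataLoader worker, but NumPy and Python `random` state inside workers is not reseeded the same way, which matters if any NumPy- or `random`-based augmentation (for example from the installed `albumentations` package) runs there. PyTorch's reproducibility notes recommend a `worker_init_fn` plus an explicit `generator`; the sketch below follows that recipe. `seed_worker` and `loader_generator` are names introduced here, and the notebook's DataLoaders would need them passed in to take effect.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    """Seed NumPy and Python's random module from torch's per-worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Dedicated generator: fixes the shuffle order and the base seed handed to workers
loader_generator = torch.Generator()
loader_generator.manual_seed(42)

# Usage (hypothetical change to the DataLoader cell):
# DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
#            num_workers=NUM_WORKERS, worker_init_fn=seed_worker,
#            generator=loader_generator)
```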

## 2. Dataset Paths Configuration

In [None]:
# Define dataset paths (Google Drive)
BASE_PATH = '/content/drive/MyDrive/dataset'

# Dataset 1 Paths (Binary: Malignant/Benign)
DS1_ORIGINAL_BENIGN = os.path.join(BASE_PATH, 'Dataset 1', 'original_data', 'benign_lesions')
DS1_ORIGINAL_MALIGNANT = os.path.join(BASE_PATH, 'Dataset 1', 'original_data', 'malignant_lesions')
DS1_AUGMENTED_BENIGN = os.path.join(BASE_PATH, 'Dataset 1', 'augmented_data', 'augmented_benign')
DS1_AUGMENTED_MALIGNANT = os.path.join(BASE_PATH, 'Dataset 1', 'augmented_data', 'augmented_malignant')

# Dataset 2 Paths (Multi-class subtypes)
# Note: There's a space after 'Dataset 2' in the folder name
DS2_TRAINING = os.path.join(BASE_PATH, 'Dataset 2 ', 'Training')
DS2_VALIDATION = os.path.join(BASE_PATH, 'Dataset 2 ', 'Validation')
DS2_TESTING = os.path.join(BASE_PATH, 'Dataset 2 ', 'Testing')

# Dataset 2 class names (subtypes)
DS2_CLASSES = ['CaS', 'CoS', 'Gum', 'MC', 'OC', 'OLP', 'OT']

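# Define which DS2 classes are treated as Malignant for Head 1
# NOTE: the abbreviations are not documented here; verify them against the dataset's own
# documentation. In the publicly available 7-class oral disease image dataset that uses these
# same folder names, MC is usually Mouth Cancer and OC Oral Cancer, but CaS is usually
# Canker Sores (a benign condition), so listing 'CaS' here may label benign images as malignant.
MALIGNANT_SUBTYPES = ['MC', 'OC', 'CaS']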

print(f"DS2 Classes: {DS2_CLASSES}")
print(f"Malignant Subtypes: {MALIGNANT_SUBTYPES}")

In [None]:
# Verify paths exist
paths_to_check = [
    DS1_ORIGINAL_BENIGN, DS1_ORIGINAL_MALIGNANT,
    DS1_AUGMENTED_BENIGN, DS1_AUGMENTED_MALIGNANT,
    DS2_TRAINING, DS2_VALIDATION, DS2_TESTING
]

for path in paths_to_check:
    exists = os.path.exists(path)
    status = "✓" if exists else "✗"
    print(f"{status} {path}")

## 3. Data Exploration

In [None]:
# Count images in Dataset 1
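def count_images(folder):
    """Count image files in a folder (extension match is case-insensitive)."""
    if not os.path.exists(folder):
        return 0
    extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    # One directory scan per folder: globbing '*.jpg' and '*.JPG' separately counts every
    # file twice where glob matching is case-insensitive (e.g. on Windows), and misses '.Jpg'
    return sum(
        1 for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() in extensions
        and os.path.isfile(os.path.join(folder, name))
    )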

print("=== Dataset 1 Statistics ===")
ds1_stats = {
    'Original Benign': count_images(DS1_ORIGINAL_BENIGN),
    'Original Malignant': count_images(DS1_ORIGINAL_MALIGNANT),
    'Augmented Benign': count_images(DS1_AUGMENTED_BENIGN),
    'Augmented Malignant': count_images(DS1_AUGMENTED_MALIGNANT)
}
for k, v in ds1_stats.items():
    print(f"  {k}: {v}")

print(f"\n  Total DS1: {sum(ds1_stats.values())}")

In [None]:
# Count images in Dataset 2
print("=== Dataset 2 Statistics ===")

for split_name, split_path in [('Training', DS2_TRAINING), ('Validation', DS2_VALIDATION), ('Testing', DS2_TESTING)]:
    print(f"\n{split_name}:")
    total = 0
    for cls in DS2_CLASSES:
        cls_path = os.path.join(split_path, cls)
        count = count_images(cls_path)
        total += count
        malignant_marker = "*" if cls in MALIGNANT_SUBTYPES else " "
        print(f"  {malignant_marker} {cls}: {count}")
    print(f"  Total: {total}")

print("\n* = Malignant subtype")

## 4. Custom Dataset Class (Union Dataset)

In [None]:
class OralPathologyDataset(Dataset):
    """
    Union Dataset for Dual-Head Multi-Task Learning.
    
    Handles both Dataset 1 (binary only) and Dataset 2 (with subtypes).
    
    Labels:
        - label_head1 (binary): 0 = Benign, 1 = Malignant
        - label_head2 (subtype): 0 to N-1 for DS2, -1 for DS1 (ignored)
    """
    
    def __init__(self, image_paths, labels_binary, labels_subtype, transform=None):
        """
        Args:
            image_paths: List of image file paths
            labels_binary: List of binary labels (0=Benign, 1=Malignant)
            labels_subtype: List of subtype labels (0 to N-1, or -1 to ignore)
            transform: Optional transforms to apply
        """
        self.image_paths = image_paths
        self.labels_binary = labels_binary
        self.labels_subtype = labels_subtype
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Load image
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        # Get labels
        label_binary = self.labels_binary[idx]
        label_subtype = self.labels_subtype[idx]
        
        return image, label_binary, label_subtype
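
`__getitem__` returns plain Python ints for both labels, including the `-1` placeholder. The sketch below uses toy tensors in place of images to show that PyTorch's default batching turns those ints into `int64` tensors (the dtype `CrossEntropyLoss` expects for class-index targets) and that `-1` survives batching, so DS1 rows can still be masked out later.

```python
import torch
from torch.utils.data import default_collate

# Three fake samples in the (image, label_binary, label_subtype) format:
# two from DS1 (subtype -1) and one from DS2 (subtype index 3)
samples = [
    (torch.zeros(3, 8, 8), 0, -1),
    (torch.zeros(3, 8, 8), 1, -1),
    (torch.zeros(3, 8, 8), 1, 3),
]

images, labels_binary, labels_subtype = default_collate(samples)
print(images.shape, labels_binary.dtype, labels_subtype.tolist())

# Boolean mask selecting the samples that carry a subtype label (DS2 only)
has_subtype = labels_subtype != -1
print(has_subtype.tolist())
```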

In [None]:
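def get_image_files(folder):
    """Return a sorted list of image file paths in a folder.

    Sorting makes the file order, and therefore every train_test_split
    computed over it, identical across calls and sessions.
    """
    if not os.path.exists(folder):
        return []
    extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    # Single case-insensitive scan: no duplicate entries from separate '*.jpg' / '*.JPG' globs
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() in extensions
        and os.path.isfile(os.path.join(folder, name))
    )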


def load_dataset1_split(split='train', test_size=0.15, val_size=0.15, random_state=42):
    """
    Load Dataset 1 (Binary labels only) with proper train/val/test split.
    ONLY uses original data (NO augmented data to prevent leakage).
    
    Args:
        split: 'train', 'val', or 'test'
        test_size: Proportion for test set
        val_size: Proportion for validation set
        random_state: Random seed for reproducibility
    
    Returns:
        image_paths, labels_binary, labels_subtype (-1 for all)
    """
    from sklearn.model_selection import train_test_split
    
    # Load ONLY original data (no augmented)
    benign_paths = get_image_files(DS1_ORIGINAL_BENIGN)
    malignant_paths = get_image_files(DS1_ORIGINAL_MALIGNANT)
    
    # Combine paths and create labels
    all_paths = benign_paths + malignant_paths
    all_binary = [0] * len(benign_paths) + [1] * len(malignant_paths)
    all_subtype = [-1] * len(all_paths)  # No subtype for DS1
    
    # First split: separate test set
    temp_paths, test_paths, temp_binary, test_binary, temp_subtype, test_subtype = train_test_split(
        all_paths, all_binary, all_subtype,
        test_size=test_size,
        random_state=random_state,
        stratify=all_binary
    )
    
    # Second split: separate train and validation from remaining data
    val_size_adjusted = val_size / (1 - test_size)  # Adjust val size
    train_paths, val_paths, train_binary, val_binary, train_subtype, val_subtype = train_test_split(
        temp_paths, temp_binary, temp_subtype,
        test_size=val_size_adjusted,
        random_state=random_state,
        stratify=temp_binary
    )
    
    # Return requested split
    if split == 'train':
        print(f"Dataset 1 (TRAIN - Original only): {len(train_paths)} images")
        print(f"  Benign: {train_binary.count(0)}, Malignant: {train_binary.count(1)}")
        return train_paths, train_binary, train_subtype
    elif split == 'val':
        print(f"Dataset 1 (VAL - Original only): {len(val_paths)} images")
        print(f"  Benign: {val_binary.count(0)}, Malignant: {val_binary.count(1)}")
        return val_paths, val_binary, val_subtype
    else:  # test
        print(f"Dataset 1 (TEST - Original only): {len(test_paths)} images")
        print(f"  Benign: {test_binary.count(0)}, Malignant: {test_binary.count(1)}")
        return test_paths, test_binary, test_subtype


def load_dataset2(split='Training'):
    """
    Load Dataset 2 (Both binary and subtype labels).
    
    Args:
        split: 'Training', 'Validation', or 'Testing'
    
    Returns:
        image_paths, labels_binary, labels_subtype
    """
    if split == 'Training':
        base_path = DS2_TRAINING
    elif split == 'Validation':
        base_path = DS2_VALIDATION
    else:
        base_path = DS2_TESTING
    

    image_paths = []
    labels_binary = []
    labels_subtype = []
    class_counts = {}

    for subtype_idx, subtype_name in enumerate(DS2_CLASSES):
        subtype_path = os.path.join(base_path, subtype_name)
        subtype_images = get_image_files(subtype_path)

        # Determine binary label based on subtype
        binary_label = 1 if subtype_name in MALIGNANT_SUBTYPES else 0

        image_paths.extend(subtype_images)
        labels_binary.extend([binary_label] * len(subtype_images))
        labels_subtype.extend([subtype_idx] * len(subtype_images))
        class_counts[subtype_name] = len(subtype_images)

    print(f"Dataset 2 ({split}) loaded: {class_counts}")
    print(f"  Total: {len(image_paths)} images")

    return image_paths, labels_binary, labels_subtype

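`load_dataset1_split` recomputes both `train_test_split` calls every time it is called, so the train, val, and test lists only line up if `get_image_files` returns the files in the same order on each call. A sketch of an alternative that performs the split once, returns all three parts, and verifies that they are disjoint and together cover every image. `split_three_ways` is a hypothetical helper; its 70/15/15 defaults mirror those of `load_dataset1_split`.

```python
from sklearn.model_selection import train_test_split


def split_three_ways(paths, labels, test_size=0.15, val_size=0.15, random_state=42):
    """Stratified train/val/test split computed once; returns {split: (paths, labels)}."""
    temp_p, test_p, temp_y, test_y = train_test_split(
        paths, labels, test_size=test_size, random_state=random_state, stratify=labels
    )
    # val_size is a fraction of the FULL dataset, so rescale it to the remaining pool:
    # 0.15 / 0.85 ≈ 0.176 of the remaining 85% ≈ 15% of the total
    train_p, val_p, train_y, val_y = train_test_split(
        temp_p, temp_y, test_size=val_size / (1 - test_size),
        random_state=random_state, stratify=temp_y
    )
    splits = {'train': (train_p, train_y), 'val': (val_p, val_y), 'test': (test_p, test_y)}

    # Leakage guard: no path in two splits, and no path dropped
    sets = {name: set(p) for name, (p, _) in splits.items()}
    if sets['train'] & sets['val'] or sets['train'] & sets['test'] or sets['val'] & sets['test']:
        raise ValueError("splits overlap")
    if sum(len(s) for s in sets.values()) != len(set(paths)):
        raise ValueError("some paths were lost or duplicated")
    return splits
```

In the notebook it could be called once with the benign and malignant path lists and `[0] * len(benign_paths) + [1] * len(malignant_paths)` as labels, and the three results reused.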
## 5. Data Transforms

In [None]:
# Image size for the model
IMG_SIZE = 224

# Training transforms with augmentation
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test transforms (no augmentation)
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Transforms defined.")

## 6. Create DataLoaders

In [None]:
# Batch size
BATCH_SIZE = 32
NUM_WORKERS = 2  # Adjust based on your system

# Load Dataset 1 splits (ORIGINAL DATA ONLY - NO AUGMENTATION)
print("Loading Dataset 1 (Original only - NO augmented data)...")
ds1_train_paths, ds1_train_binary, ds1_train_subtype = load_dataset1_split('train')
ds1_val_paths, ds1_val_binary, ds1_val_subtype = load_dataset1_split('val')
ds1_test_paths, ds1_test_binary, ds1_test_subtype = load_dataset1_split('test')

# Load Dataset 2 splits (uses pre-split folders)
print("\nLoading Dataset 2 (Pre-split folders)...")
ds2_train_paths, ds2_train_binary, ds2_train_subtype = load_dataset2('Training')
ds2_val_paths, ds2_val_binary, ds2_val_subtype = load_dataset2('Validation')
ds2_test_paths, ds2_test_binary, ds2_test_subtype = load_dataset2('Testing')

In [None]:
# Combine DS1 and DS2 for each split (train/val/test)
print("\n" + "="*60)
print("COMBINING DATASETS (NO DATA LEAKAGE)")
print("="*60)

# Training set: DS1 train + DS2 train
train_paths = ds1_train_paths + ds2_train_paths
train_binary = ds1_train_binary + ds2_train_binary
train_subtype = ds1_train_subtype + ds2_train_subtype

print(f"\nTraining Set: {len(train_paths)} images")
print(f"  - From DS1: {len(ds1_train_paths)} (original only, subtype=-1)")
print(f"  - From DS2: {len(ds2_train_paths)} (with subtype labels)")

# Validation set: DS1 val + DS2 val
val_paths = ds1_val_paths + ds2_val_paths
val_binary = ds1_val_binary + ds2_val_binary
val_subtype = ds1_val_subtype + ds2_val_subtype

print(f"\nValidation Set: {len(val_paths)} images")
print(f"  - From DS1: {len(ds1_val_paths)} (original only, subtype=-1)")
print(f"  - From DS2: {len(ds2_val_paths)} (with subtype labels)")

# Test set: DS1 test + DS2 test (combined for overall evaluation)
test_paths_combined = ds1_test_paths + ds2_test_paths
test_binary_combined = ds1_test_binary + ds2_test_binary
test_subtype_combined = ds1_test_subtype + ds2_test_subtype

print(f"\nTest Set (Combined): {len(test_paths_combined)} images")
print(f"  - From DS1: {len(ds1_test_paths)} (original only, subtype=-1)")
print(f"  - From DS2: {len(ds2_test_paths)} (with subtype labels)")

# Create PyTorch datasets
train_dataset = OralPathologyDataset(train_paths, train_binary, train_subtype, transform=train_transform)
val_dataset = OralPathologyDataset(val_paths, val_binary, val_subtype, transform=val_transform)

# Separate test sets for individual evaluation
test_dataset_ds1 = OralPathologyDataset(ds1_test_paths, ds1_test_binary, ds1_test_subtype, transform=val_transform)
test_dataset_ds2 = OralPathologyDataset(ds2_test_paths, ds2_test_binary, ds2_test_subtype, transform=val_transform)
test_dataset_combined = OralPathologyDataset(test_paths_combined, test_binary_combined, test_subtype_combined, transform=val_transform)

print("\n✓ Datasets created from disjoint splits")
print("✓ DS1 uses ONLY original images (no augmented)")
print("✓ DS1 train/val/test come from one stratified split; DS2 relies on its pre-split folders")

In [None]:
# Create DataLoaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# Separate test loaders for individual dataset evaluation
test_loader_ds1 = DataLoader(
    test_dataset_ds1, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=True
)

test_loader_ds2 = DataLoader(
    test_dataset_ds2, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# Combined test loader for overall evaluation
test_loader_combined = DataLoader(
    test_dataset_combined,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"\nDataLoader Summary:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test DS1 batches: {len(test_loader_ds1)}")
print(f"  Test DS2 batches: {len(test_loader_ds2)}")
print(f"  Test Combined batches: {len(test_loader_combined)}")
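
The success messages printed when the datasets were combined are statements, not checks. A small sketch that actually verifies the combined path lists are pairwise disjoint before training starts; `assert_disjoint` is a hypothetical helper. Because it compares file paths, it will not catch the same image stored under two different names; a content-hash comparison would be needed for that.

```python
from itertools import combinations


def assert_disjoint(**splits):
    """Raise ValueError if any file path appears in more than one named split."""
    as_sets = {name: set(paths) for name, paths in splits.items()}
    for (name_a, paths_a), (name_b, paths_b) in combinations(as_sets.items(), 2):
        shared = paths_a & paths_b
        if shared:
            example = next(iter(shared))
            raise ValueError(
                f"{len(shared)} path(s) shared by '{name_a}' and '{name_b}', e.g. {example}"
            )
    return True

# In the notebook (lists come from the cell that combines DS1 and DS2):
# assert_disjoint(train=train_paths, val=val_paths, test=test_paths_combined)
```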

## 7. Model Architecture (Y-Shape Dual Head)

In [None]:
class MultiTaskOralClassifier(nn.Module):
    """
    Dual-Head Multi-Task Model for Oral Pathology Classification.
    
    Architecture:
        - Shared Backbone: ResNet50 (pretrained)
        - Head 1 (Binary): Malignant vs Benign
        - Head 2 (Multi-class): Subtype classification
    
    Both heads are INDEPENDENT and trained in PARALLEL (Flat MTL).
    """
    
    def __init__(self, num_subtypes=7, backbone='resnet50', dropout=0.5):
        super(MultiTaskOralClassifier, self).__init__()
        
        # Load pretrained backbone
        if backbone == 'resnet50':
            self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            num_features = self.backbone.fc.in_features
            # Remove the final classification layer
            self.backbone.fc = nn.Identity()
        elif backbone == 'convnext_tiny':
            self.backbone = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
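            num_features = self.backbone.classifier[2].in_features
            # Replace only the final Linear layer; classifier[0] (LayerNorm2d) and
            # classifier[1] (Flatten) are kept so the backbone outputs flat (N, 768) features
            self.backbone.classifier[2] = nn.Identity()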
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")
        
        self.dropout = nn.Dropout(p=dropout)
        
        # Head 1: Binary Classification (Malignant vs Benign)
        self.head_binary = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, 2)  # 2 classes: Benign, Malignant
        )
        
        # Head 2: Subtype Classification
        self.head_subtype = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, num_subtypes)  # N classes
        )
        
        print(f"Model initialized with {backbone} backbone")
        print(f"  - Feature size: {num_features}")
        print(f"  - Head 1 (Binary): 2 classes")
        print(f"  - Head 2 (Subtype): {num_subtypes} classes")
    
    def forward(self, x):
        # Extract features using shared backbone
        features = self.backbone(x)
        features = self.dropout(features)
        
        # Pass through BOTH heads INDEPENDENTLY
        out_binary = self.head_binary(features)
        out_subtype = self.head_subtype(features)
        
        return out_binary, out_subtype
    
    def freeze_backbone(self):
        """Freeze backbone weights for transfer learning."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("Backbone frozen.")
    
    def unfreeze_backbone(self):
        """Unfreeze backbone weights for fine-tuning."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen.")

In [None]:
# Initialize the model
NUM_SUBTYPES = len(DS2_CLASSES)  # 7 classes

model = MultiTaskOralClassifier(
    num_subtypes=NUM_SUBTYPES,
    backbone='resnet50',
    dropout=0.5
).to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 8. Loss Functions (with Masking)

In [None]:
class MultiTaskLoss(nn.Module):
    """
    Combined loss for Multi-Task Learning.
    
    - Loss 1 (Binary): CrossEntropyLoss for all samples
    - Loss 2 (Subtype): CrossEntropyLoss with ignore_index=-1
      This ensures DS1 samples (with subtype=-1) don't contribute to Head 2 loss.
    """
    
    def __init__(self, weight_binary=1.0, weight_subtype=1.0):
        super(MultiTaskLoss, self).__init__()
        
        self.weight_binary = weight_binary
        self.weight_subtype = weight_subtype
        
        # Loss for binary classification (applied to ALL samples)
        self.criterion_binary = nn.CrossEntropyLoss()
        
        # Loss for subtype classification (IGNORES samples with label=-1)
        # This is the CRITICAL "Masking Trick"
        self.criterion_subtype = nn.CrossEntropyLoss(ignore_index=-1)
    
    def forward(self, pred_binary, pred_subtype, target_binary, target_subtype):
        """
        Calculate combined loss.
        
        Args:
            pred_binary: Predictions from Head 1 (N, 2)
            pred_subtype: Predictions from Head 2 (N, num_subtypes)
            target_binary: Binary labels (N,)
            target_subtype: Subtype labels (N,), -1 for DS1 samples
        
        Returns:
            total_loss, loss_binary, loss_subtype
        """
        # Binary loss (all samples contribute)
        loss_binary = self.criterion_binary(pred_binary, target_binary)
        
        # Subtype loss (only DS2 samples contribute, DS1 samples with -1 are ignored)
        loss_subtype = self.criterion_subtype(pred_subtype, target_subtype)
        
        # Handle case when all samples are DS1 (no valid subtype labels)
        if torch.isnan(loss_subtype):
            loss_subtype = torch.tensor(0.0, device=pred_binary.device)
        
        # Combined loss
        total_loss = (self.weight_binary * loss_binary) + (self.weight_subtype * loss_subtype)
        
        return total_loss, loss_binary, loss_subtype


# Initialize loss function
criterion = MultiTaskLoss(weight_binary=1.0, weight_subtype=1.0)
print("Multi-Task Loss initialized with masking for subtype head.")
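
A small numeric check of the masking behaviour described in the docstring. With `ignore_index=-1` and the default `reduction='mean'`, the loss is averaged over the non-ignored samples only, so it equals the loss computed on the DS2 rows alone. When every target is `-1`, the mean is 0/0 and comes out as NaN, which is exactly the case the `torch.isnan` guard in `MultiTaskLoss` handles.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 7)                    # batch of 4, 7 subtypes
targets = torch.tensor([-1, 2, -1, 5])        # rows 0 and 2 are DS1 (no subtype label)

masked = nn.CrossEntropyLoss(ignore_index=-1)(logits, targets)
keep = targets != -1
subset = nn.CrossEntropyLoss()(logits[keep], targets[keep])
print(masked.item(), subset.item())           # equal up to float rounding

# A batch made entirely of DS1 samples: nothing left to average over
all_ds1 = nn.CrossEntropyLoss(ignore_index=-1)(logits, torch.full((4,), -1, dtype=torch.long))
print(all_ds1)
```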

## 9. Optimizer and Scheduler

In [None]:
# Hyperparameters
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 30

# Optimizer: AdamW
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Learning rate scheduler: Cosine Annealing
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-6
)

print(f"Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")
print(f"Scheduler: CosineAnnealingLR (T_max={NUM_EPOCHS})")
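
For reference, `CosineAnnealingLR` follows lr_t = eta_min + (lr_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, where t counts `scheduler.step()` calls. Since the training loop steps once per epoch, the learning rate falls from 1e-4 to 1e-6 over the 30 epochs, reaching `eta_min` after the last epoch. A sketch comparing the scheduler against that closed form on a dummy parameter:

```python
import math

import torch

T_MAX, LR_MAX, LR_MIN = 30, 1e-4, 1e-6

# Dummy parameter so the optimizer has something to hold
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=LR_MAX, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=T_MAX, eta_min=LR_MIN)


def cosine_lr(t):
    """Closed-form cosine-annealed learning rate after t scheduler steps."""
    return LR_MIN + (LR_MAX - LR_MIN) * (1 + math.cos(math.pi * t / T_MAX)) / 2


observed = []
for epoch in range(T_MAX):
    opt.step()      # param has no gradient, so this leaves it unchanged
    sched.step()    # one step per epoch, as in the training loop
    observed.append(opt.param_groups[0]['lr'])

expected = [cosine_lr(t) for t in range(1, T_MAX + 1)]
print(f"epoch 1: {observed[0]:.3e}, epoch 15: {observed[14]:.3e}, epoch 30: {observed[-1]:.3e}")
```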

## 10. Training Functions

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """
    Train for one epoch.
    
    Returns:
        avg_loss, avg_loss_binary, avg_loss_subtype, acc_binary, acc_subtype
    """
    model.train()
    
    running_loss = 0.0
    running_loss_binary = 0.0
    running_loss_subtype = 0.0
    
    all_preds_binary = []
    all_targets_binary = []
    all_preds_subtype = []
    all_targets_subtype = []
    
    pbar = tqdm(train_loader, desc="Training", leave=False)
    
    for images, targets_binary, targets_subtype in pbar:
        images = images.to(device)
        targets_binary = targets_binary.to(device)
        targets_subtype = targets_subtype.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass - get predictions from BOTH heads
        pred_binary, pred_subtype = model(images)
        
        # Calculate loss (with masking for subtype)
        loss, loss_binary, loss_subtype = criterion(
            pred_binary, pred_subtype, targets_binary, targets_subtype
        )
        
        # Backward pass - updates shared backbone AND both heads
        loss.backward()
        optimizer.step()
        
        # Track losses
        running_loss += loss.item()
        running_loss_binary += loss_binary.item()
        running_loss_subtype += loss_subtype.item() if not torch.isnan(loss_subtype) else 0
        
        # Track predictions for accuracy
        preds_binary = torch.argmax(pred_binary, dim=1)
        all_preds_binary.extend(preds_binary.cpu().numpy())
        all_targets_binary.extend(targets_binary.cpu().numpy())
        
        # Only track subtype predictions for DS2 samples (where target != -1)
        mask = targets_subtype != -1
        if mask.sum() > 0:
            preds_subtype = torch.argmax(pred_subtype[mask], dim=1)
            all_preds_subtype.extend(preds_subtype.cpu().numpy())
            all_targets_subtype.extend(targets_subtype[mask].cpu().numpy())
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # Calculate metrics
    num_batches = len(train_loader)
    avg_loss = running_loss / num_batches
    avg_loss_binary = running_loss_binary / num_batches
    avg_loss_subtype = running_loss_subtype / num_batches
    
    acc_binary = accuracy_score(all_targets_binary, all_preds_binary)
    acc_subtype = accuracy_score(all_targets_subtype, all_preds_subtype) if all_targets_subtype else 0.0
    
    return avg_loss, avg_loss_binary, avg_loss_subtype, acc_binary, acc_subtype
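
`avg_loss_subtype` divides by the number of batches, but a batch that happens to contain only DS1 images contributes a subtype loss of 0, which pulls the reported average down. If a per-sample average is preferred, each batch's mean loss can be weighted by the number of samples it was computed over. A minimal sketch of that bookkeeping; `WeightedMean` is a hypothetical helper, not something the notebook defines. Inside `train_one_epoch` it would be updated with `loss_subtype.item()` and `int(mask.sum())`.

```python
class WeightedMean:
    """Running mean where each batch's mean loss is weighted by its sample count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_mean, n):
        # A batch with n == 0 (e.g. no DS2 samples) adds nothing to either sum
        self.total += float(batch_mean) * n
        self.count += n

    @property
    def value(self):
        return self.total / self.count if self.count else 0.0


# Three batches: subtype loss 1.2 over 10 DS2 samples, 0.0 over 0 (all DS1), 0.6 over 20
subtype_meter = WeightedMean()
for batch_loss, n_labelled in [(1.2, 10), (0.0, 0), (0.6, 20)]:
    subtype_meter.update(batch_loss, n_labelled)
# Per-sample: (12 + 12) / 30 = 0.8; per-batch averaging would give (1.2 + 0 + 0.6) / 3 = 0.6
print(subtype_meter.value)
```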

In [None]:
def validate(model, val_loader, criterion, device):
    """
    Validate the model.
    
    Returns:
        avg_loss, acc_binary, acc_subtype
    """
    model.eval()
    
    running_loss = 0.0
    
    all_preds_binary = []
    all_targets_binary = []
    all_preds_subtype = []
    all_targets_subtype = []
    
    with torch.no_grad():
        for images, targets_binary, targets_subtype in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            targets_binary = targets_binary.to(device)
            targets_subtype = targets_subtype.to(device)
            
            # Forward pass
            pred_binary, pred_subtype = model(images)
            
            # Calculate loss
            loss, _, _ = criterion(pred_binary, pred_subtype, targets_binary, targets_subtype)
            running_loss += loss.item()
            
            # Track predictions
            preds_binary = torch.argmax(pred_binary, dim=1)
            all_preds_binary.extend(preds_binary.cpu().numpy())
            all_targets_binary.extend(targets_binary.cpu().numpy())
            
            # Track subtype predictions
            mask = targets_subtype != -1
            if mask.sum() > 0:
                preds_subtype = torch.argmax(pred_subtype[mask], dim=1)
                all_preds_subtype.extend(preds_subtype.cpu().numpy())
                all_targets_subtype.extend(targets_subtype[mask].cpu().numpy())
    
    avg_loss = running_loss / len(val_loader)
    acc_binary = accuracy_score(all_targets_binary, all_preds_binary)
    acc_subtype = accuracy_score(all_targets_subtype, all_preds_subtype) if all_targets_subtype else 0.0
    
    return avg_loss, acc_binary, acc_subtype

## 11. Training Loop

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_loss_binary': [],
    'train_loss_subtype': [],
    'train_acc_binary': [],
    'train_acc_subtype': [],
    'val_loss': [],
    'val_acc_binary': [],
    'val_acc_subtype': [],
    'lr': []
}

# Best model tracking
best_val_loss = float('inf')
best_model_path = '/content/drive/MyDrive/dataset/best_model.pth'

print(f"Starting training for {NUM_EPOCHS} epochs...")
print("="*70)

In [None]:
# Main training loop
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 40)
    
    # Train
    train_loss, train_loss_b, train_loss_s, train_acc_b, train_acc_s = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    
    # Validate
    val_loss, val_acc_b, val_acc_s = validate(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    
    # Log history
    history['train_loss'].append(train_loss)
    history['train_loss_binary'].append(train_loss_b)
    history['train_loss_subtype'].append(train_loss_s)
    history['train_acc_binary'].append(train_acc_b)
    history['train_acc_subtype'].append(train_acc_s)
    history['val_loss'].append(val_loss)
    history['val_acc_binary'].append(val_acc_b)
    history['val_acc_subtype'].append(val_acc_s)
    history['lr'].append(current_lr)
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f} (Binary: {train_loss_b:.4f}, Subtype: {train_loss_s:.4f})")
    print(f"Train Acc:  Binary: {train_acc_b:.4f}, Subtype: {train_acc_s:.4f}")
    print(f"Val Loss:   {val_loss:.4f}")
    print(f"Val Acc:    Binary: {val_acc_b:.4f}, Subtype: {val_acc_s:.4f}")
    print(f"LR: {current_lr:.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
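            # Cast to Python float: accuracy_score may return a NumPy scalar,
            # which torch.load(..., weights_only=True) refuses to unpickle
            'val_acc_binary': float(val_acc_b),
            'val_acc_subtype': float(val_acc_s),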
        }, best_model_path)
        print(f"✓ Best model saved! (val_loss: {val_loss:.4f})")

print("\n" + "="*70)
print("Training complete!")

## 12. Training Visualization

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss plot
ax1 = axes[0, 0]
ax1.plot(history['train_loss'], label='Train Total Loss', color='blue')
ax1.plot(history['val_loss'], label='Val Loss', color='orange')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Total Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Individual losses
ax2 = axes[0, 1]
ax2.plot(history['train_loss_binary'], label='Binary Loss', color='green')
ax2.plot(history['train_loss_subtype'], label='Subtype Loss', color='red')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.set_title('Train Loss by Head')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Binary accuracy
ax3 = axes[1, 0]
ax3.plot(history['train_acc_binary'], label='Train', color='blue')
ax3.plot(history['val_acc_binary'], label='Val', color='orange')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Accuracy')
ax3.set_title('Binary Head Accuracy (Head 1)')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Subtype accuracy
ax4 = axes[1, 1]
ax4.plot(history['train_acc_subtype'], label='Train', color='blue')
ax4.plot(history['val_acc_subtype'], label='Val', color='orange')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Accuracy')
ax4.set_title('Subtype Head Accuracy (Head 2)')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/dataset/training_history.png', dpi=150)
plt.show()

## 13. Model Evaluation

In [None]:
# Load best model (map_location lets a GPU-trained checkpoint load on a CPU-only runtime)
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Best val_loss: {checkpoint['val_loss']:.4f}")

In [None]:
def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """
    Comprehensive model evaluation.
    
    Returns predictions and targets for both heads.
    """
    model.eval()
    
    all_preds_binary = []
    all_targets_binary = []
    all_preds_subtype = []
    all_targets_subtype = []
    all_probs_binary = []
    all_probs_subtype = []
    
    with torch.no_grad():
        for images, targets_binary, targets_subtype in tqdm(test_loader, desc=f"Evaluating {dataset_name}"):
            images = images.to(device)
            
            # Forward pass
            pred_binary, pred_subtype = model(images)
            
            # Get probabilities
            probs_binary = torch.softmax(pred_binary, dim=1)
            probs_subtype = torch.softmax(pred_subtype, dim=1)
            
            # Get predictions
            preds_binary = torch.argmax(pred_binary, dim=1)
            preds_subtype = torch.argmax(pred_subtype, dim=1)
            
            # Store all predictions
            all_preds_binary.extend(preds_binary.cpu().numpy())
            all_targets_binary.extend(targets_binary.numpy())
            all_probs_binary.extend(probs_binary.cpu().numpy())
            
            # Store subtype predictions (only for valid labels)
            mask = targets_subtype != -1
            if mask.sum() > 0:
                all_preds_subtype.extend(preds_subtype[mask].cpu().numpy())
                all_targets_subtype.extend(targets_subtype[mask].numpy())
                all_probs_subtype.extend(probs_subtype[mask].cpu().numpy())
    
    return {
        'preds_binary': np.array(all_preds_binary),
        'targets_binary': np.array(all_targets_binary),
        'probs_binary': np.array(all_probs_binary),
        'preds_subtype': np.array(all_preds_subtype),
        'targets_subtype': np.array(all_targets_subtype),
        'probs_subtype': np.array(all_probs_subtype) if all_probs_subtype else None
    }
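`evaluate_model` keeps subtype predictions only where the target is not `-1`. A minimal NumPy sketch of that filtering step, using a hypothetical helper `masked_subtype_accuracy` and toy arrays (not results from this notebook):

```python
import numpy as np

def masked_subtype_accuracy(targets, preds, ignore_value=-1):
    """Accuracy over samples whose subtype target is not `ignore_value`.

    Returns None when no sample carries a subtype label (e.g. a DS1-only set),
    because accuracy is undefined for an empty selection.
    """
    targets = np.asarray(targets)
    preds = np.asarray(preds)
    mask = targets != ignore_value  # True only for samples with a real subtype label
    if not mask.any():
        return None
    return float((preds[mask] == targets[mask]).mean())

# Toy batch: the first two samples mimic DS1 images (no subtype label).
print(masked_subtype_accuracy([-1, -1, 0, 3, 3], [2, 5, 0, 3, 1]))  # 2 of 3 labeled samples correct
print(masked_subtype_accuracy([-1, -1], [0, 1]))  # None
```

Assuming every DS1 image carries subtype target `-1`, this is why `evaluate_model` on the DS1 loader returns empty subtype arrays and why subtype metrics are reported for DS2 only.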

In [None]:
# Evaluate on DS1 Test Set (Binary Head)
print("\n" + "="*70)
print("EVALUATION ON DATASET 1 TEST SET (Binary Classification)")
print("="*70)

results_ds1 = evaluate_model(model, test_loader_ds1, device, "DS1")

# Binary metrics for DS1
acc_binary_ds1 = accuracy_score(results_ds1['targets_binary'], results_ds1['preds_binary'])
print(f"\nBinary Accuracy (DS1): {acc_binary_ds1:.4f}")

print("\nClassification Report (Binary - DS1):")
print(classification_report(
    results_ds1['targets_binary'],
    results_ds1['preds_binary'],
    labels=[0, 1],  # keep both rows even if one class is absent from the split
    target_names=['Benign', 'Malignant']
))
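In the binary report, the recall of `Malignant` is the sensitivity and the recall of `Benign` is the specificity. A small sketch that computes both directly, assuming the encoding `0 = Benign`, `1 = Malignant` implied by `target_names` above; `binary_sensitivity_specificity` is a hypothetical helper and the labels are toy values:

```python
import numpy as np

def binary_sensitivity_specificity(y_true, y_pred, positive=1):
    """Sensitivity (true-positive rate) and specificity (true-negative rate).

    Returns NaN for a rate whose denominator is zero.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    pos = y_true == positive
    neg = ~pos
    tp = np.sum(pos & (y_pred == positive))   # malignant correctly flagged
    tn = np.sum(neg & (y_pred != positive))   # benign correctly cleared
    sensitivity = tp / pos.sum() if pos.any() else float("nan")
    specificity = tn / neg.sum() if neg.any() else float("nan")
    return float(sensitivity), float(specificity)

# Toy labels (illustrative only): 3 malignant, 2 benign.
sens, spec = binary_sensitivity_specificity([1, 1, 1, 0, 0], [1, 0, 1, 0, 1])
print(f"sensitivity={sens:.3f}, specificity={spec:.3f}")
```

Called as `binary_sensitivity_specificity(results_ds1['targets_binary'], results_ds1['preds_binary'])`, it should agree with the recall column of the report.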

In [None]:
# Evaluate on DS2 Test Set (Both Heads)
print("\n" + "="*70)
print("EVALUATION ON DATASET 2 TEST SET (Both Classifications)")
print("="*70)

results_ds2 = evaluate_model(model, test_loader_ds2, device, "DS2")

# Binary metrics for DS2
acc_binary_ds2 = accuracy_score(results_ds2['targets_binary'], results_ds2['preds_binary'])
print(f"\nBinary Accuracy (DS2): {acc_binary_ds2:.4f}")

print("\nClassification Report (Binary - DS2):")
print(classification_report(
    results_ds2['targets_binary'],
    results_ds2['preds_binary'],
    labels=[0, 1],  # keep both rows even if one class is absent from the split
    target_names=['Benign', 'Malignant']
))

# Subtype metrics for DS2
acc_subtype_ds2 = accuracy_score(results_ds2['targets_subtype'], results_ds2['preds_subtype'])
print(f"\nSubtype Accuracy (DS2): {acc_subtype_ds2:.4f}")

print("\nClassification Report (Subtype - DS2):")
print(classification_report(
    results_ds2['targets_subtype'],
    results_ds2['preds_subtype'],
    labels=list(range(len(DS2_CLASSES))),  # report every subtype, even ones absent from the test split
    target_names=DS2_CLASSES
))

## 14. Confusion Matrices

In [None]:
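def plot_confusion_matrix(y_true, y_pred, classes, title, ax):
    """Plot a confusion matrix whose rows/columns follow the order of `classes`."""
    # labels= keeps the matrix square over all classes, so tick labels stay aligned
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))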
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=classes, yticklabels=classes, ax=ax
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)


# Plot confusion matrices
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# DS1 Binary
plot_confusion_matrix(
    results_ds1['targets_binary'], results_ds1['preds_binary'],
    ['Benign', 'Malignant'], 'DS1 - Binary Classification', axes[0]
)

# DS2 Binary
plot_confusion_matrix(
    results_ds2['targets_binary'], results_ds2['preds_binary'],
    ['Benign', 'Malignant'], 'DS2 - Binary Classification', axes[1]
)

# DS2 Subtype
plot_confusion_matrix(
    results_ds2['targets_subtype'], results_ds2['preds_subtype'],
    DS2_CLASSES, 'DS2 - Subtype Classification', axes[2]
)

plt.tight_layout()
plt.savefig('/content/drive/MyDrive/dataset/confusion_matrices.png', dpi=150)
plt.show()
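If a class is absent from a test split and never predicted, a confusion matrix built only from the observed labels shrinks, and the heatmap tick labels no longer line up with its rows. Building the matrix over a fixed label set `range(n_classes)` avoids this; that is what the `labels=` argument of scikit-learn's `confusion_matrix` does. A NumPy-only sketch with a hypothetical helper `fixed_confusion_matrix`:

```python
import numpy as np

def fixed_confusion_matrix(y_true, y_pred, n_classes):
    """Rows = true class, columns = predicted class, always n_classes x n_classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when the same (true, pred) pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

# Toy example: class 2 never appears, but the matrix keeps a row and column for it.
cm = fixed_confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1], n_classes=3)
print(cm)
```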

## 15. Summary of Results

In [None]:
# Final Summary
print("\n" + "="*70)
print("FINAL RESULTS SUMMARY")
print("="*70)

print("\n📊 Model: Dual-Head Multi-Task Classifier (Flat MTL)")
print("   Backbone: ResNet50 (pretrained)")
print("   Head 1: Binary (Malignant/Benign)")
print(f"   Head 2: Subtype ({len(DS2_CLASSES)} classes)")

print("\n📈 Performance Metrics:")
print("\n   Head 1 - Binary Classification:")
print(f"   ├── DS1 Test Accuracy: {acc_binary_ds1:.4f} ({acc_binary_ds1*100:.2f}%)")
print(f"   └── DS2 Test Accuracy: {acc_binary_ds2:.4f} ({acc_binary_ds2*100:.2f}%)")

print("\n   Head 2 - Subtype Classification:")
print(f"   └── DS2 Test Accuracy: {acc_subtype_ds2:.4f} ({acc_subtype_ds2*100:.2f}%)")

print("\n📁 Saved Files:")
print(f"   ├── Best Model: {best_model_path}")
print("   ├── Training History: /content/drive/MyDrive/dataset/training_history.png")
print("   └── Confusion Matrices: /content/drive/MyDrive/dataset/confusion_matrices.png")

print("\n" + "="*70)

## 16. Inference Example

In [None]:
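def predict_single_image(model, image_path, device):
    """
    Predict both heads for a single image file.

    Returns the binary label and the subtype label, each with its softmax
    confidence, plus the full subtype distribution. The subtype head always
    outputs one of DS2_CLASSES, and the two heads are not forced to agree.
    """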
    model.eval()
    
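    # Load and transform image.
    # Assumes val_transform accepts a PIL image (torchvision-style); an albumentations
    # pipeline would need val_transform(image=np.array(image))['image'] instead.
    image = Image.open(image_path).convert('RGB')
    image_tensor = val_transform(image).unsqueeze(0).to(device)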
    
    with torch.no_grad():
        pred_binary, pred_subtype = model(image_tensor)
        
        # Get probabilities
        probs_binary = torch.softmax(pred_binary, dim=1)[0]
        probs_subtype = torch.softmax(pred_subtype, dim=1)[0]
        
        # Get predictions
        pred_b = torch.argmax(probs_binary).item()
        pred_s = torch.argmax(probs_subtype).item()
    
    binary_labels = ['Benign', 'Malignant']
    
    return {
        'binary_prediction': binary_labels[pred_b],
        'binary_confidence': probs_binary[pred_b].item(),
        'subtype_prediction': DS2_CLASSES[pred_s],
        'subtype_confidence': probs_subtype[pred_s].item(),
        'all_subtype_probs': {cls: probs_subtype[i].item() for i, cls in enumerate(DS2_CLASSES)}
    }


# Example usage (uncomment and provide an image path)
# test_image = '/content/drive/MyDrive/dataset/Dataset 2 /Testing/MC/sample.jpg'
# result = predict_single_image(model, test_image, device)
# print(f"Binary: {result['binary_prediction']} ({result['binary_confidence']*100:.2f}%)")
# print(f"Subtype: {result['subtype_prediction']} ({result['subtype_confidence']*100:.2f}%)")
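Because the two heads predict independently, `predict_single_image` can return `Benign` from Head 1 together with a subtype that the notebook's `MALIGNANT_SUBTYPES` treats as malignant, or the reverse. One hedged way to surface such cases is to sum the subtype probabilities over the malignant subtypes and compare. This sketch assumes `MALIGNANT_SUBTYPES` holds class-name strings matching `DS2_CLASSES`; the helper `check_head_agreement`, the class names `A`/`B`/`C` and the probabilities below are hypothetical:

```python
def check_head_agreement(result, malignant_subtypes, threshold=0.5):
    """Compare the binary head with a malignancy score implied by the subtype head.

    `result` is assumed to have the keys returned by predict_single_image.
    """
    subtype_probs = result["all_subtype_probs"]
    # Probability mass the subtype head assigns to malignant subtypes
    p_malignant_from_subtype = sum(
        p for cls, p in subtype_probs.items() if cls in malignant_subtypes
    )
    subtype_says_malignant = p_malignant_from_subtype >= threshold
    binary_says_malignant = result["binary_prediction"] == "Malignant"
    return {
        "p_malignant_from_subtype": p_malignant_from_subtype,
        "heads_agree": subtype_says_malignant == binary_says_malignant,
    }

# Illustrative input only: class names and probabilities are made up.
example = {
    "binary_prediction": "Benign",
    "all_subtype_probs": {"A": 0.7, "B": 0.2, "C": 0.1},
}
print(check_head_agreement(example, malignant_subtypes={"A"}))
```

With a real prediction this would be called as `check_head_agreement(result, set(MALIGNANT_SUBTYPES))`.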

In [None]:
# Visualize some predictions
def visualize_predictions(model, test_loader, device, num_samples=8):
    """
    Visualize model predictions on sample images.
    """
    model.eval()
    
    # Get one batch; never request more samples than the batch contains
    images, targets_binary, targets_subtype = next(iter(test_loader))
    num_samples = min(num_samples, images.size(0))
    images = images[:num_samples].to(device)
    
    with torch.no_grad():
        pred_binary, pred_subtype = model(images)
        preds_b = torch.argmax(pred_binary, dim=1).cpu()
        preds_s = torch.argmax(pred_subtype, dim=1).cpu()
    
    # Undo ImageNet normalization for display (assumes val_transform used these statistics)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images_denorm = images.cpu() * std + mean
    
    binary_labels = ['Benign', 'Malignant']
    
    # Grid with 4 columns and as many rows as num_samples needs
    ncols = 4
    nrows = (num_samples + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axes.flatten()
    
    for i in range(num_samples):
        img = images_denorm[i].permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)
        
        axes[i].imshow(img)
        
        true_b = binary_labels[int(targets_binary[i])]
        pred_b = binary_labels[int(preds_b[i])]
        true_s = DS2_CLASSES[int(targets_subtype[i])] if int(targets_subtype[i]) != -1 else 'N/A'
        pred_s = DS2_CLASSES[int(preds_s[i])]
        
        # Title colour reflects the binary head only
        color = 'green' if int(preds_b[i]) == int(targets_binary[i]) else 'red'
        
        axes[i].set_title(f"True: {true_b} / {true_s}\nPred: {pred_b} / {pred_s}",
                          color=color, fontsize=10)
        axes[i].axis('off')
    
    # Hide grid cells left empty when num_samples is not a multiple of ncols
    for ax in axes[num_samples:]:
        ax.axis('off')
    
    plt.tight_layout()
    plt.savefig('/content/drive/MyDrive/dataset/prediction_samples.png', dpi=150)
    plt.show()


# Visualize predictions
visualize_predictions(model, test_loader_ds2, device)
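`visualize_predictions` reverses the input normalization with `x * std + mean` using the ImageNet channel statistics. That is only correct if `val_transform` normalized with exactly those values. A NumPy sketch of the normalize/denormalize round trip on a channels-first array; `normalize` and `denormalize` are hypothetical helper names:

```python
import numpy as np

# ImageNet statistics, as hard-coded in visualize_predictions
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img_chw):
    """(x - mean) / std per channel, like a Normalize transform."""
    return (img_chw - MEAN) / STD

def denormalize(img_chw):
    """Inverse of normalize: x * std + mean, clipped to the displayable [0, 1] range."""
    return np.clip(img_chw * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
img = rng.random((3, 4, 4))             # fake image with values in [0, 1)
restored = denormalize(normalize(img))
print(np.allclose(img, restored))       # True: the round trip recovers the input
hwc = restored.transpose(1, 2, 0)       # channels-last layout expected by imshow
print(hwc.shape)                        # (4, 4, 3)
```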

## 17. Save Final Model

In [None]:
# Save final model with all metadata.
# Note: `model` currently holds the best-validation weights loaded in Section 13,
# while the optimizer and scheduler states come from the last training epoch.
final_model_path = '/content/drive/MyDrive/dataset/final_model.pth'

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'history': history,
    'config': {
        'num_subtypes': NUM_SUBTYPES,
        'ds2_classes': DS2_CLASSES,
        'malignant_subtypes': MALIGNANT_SUBTYPES,
        'img_size': IMG_SIZE,
        'backbone': 'resnet50'
    },
    'results': {
        # float() stores plain Python numbers rather than NumPy scalars
        'ds1_binary_accuracy': float(acc_binary_ds1),
        'ds2_binary_accuracy': float(acc_binary_ds2),
        'ds2_subtype_accuracy': float(acc_subtype_ds2)
    }
}, final_model_path)

print(f"Final model saved to: {final_model_path}")
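Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which only unpickles an allowlist of types; NumPy scalars stored in a checkpoint are a common cause of load errors under that default. Whether `history` here contains NumPy values depends on the training loop, so treat this as an assumption. A NumPy-only sketch of a recursive converter; `to_builtin` is a hypothetical helper:

```python
import numpy as np

def to_builtin(obj):
    """Recursively convert NumPy scalars/arrays inside dicts, lists and tuples to Python types."""
    if isinstance(obj, dict):
        return {k: to_builtin(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_builtin(v) for v in obj)
    if isinstance(obj, np.generic):      # np.float64, np.int64, np.bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj

results = {"ds1_binary_accuracy": np.float64(0.5), "per_epoch": [np.float32(1.25), 2.0]}
clean = to_builtin(results)
print(type(clean["ds1_binary_accuracy"]).__name__)  # float
```

In the save cell it could wrap the metadata, e.g. `'history': to_builtin(history)`.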

---

## Notes

### Key Points of Flat Multi-Task Learning:

1. **No Dependency:** The model does NOT check if it's "Malignant" first before checking the "Type."

2. **Parallel Output:** The model looks at the image and **simultaneously** outputs two answers:
   - Answer A: "It looks Malignant/Benign"
   - Answer B: "It looks like [subtype] type"

3. **Masking Trick:** Using `ignore_index=-1` in CrossEntropyLoss ensures that:
   - DS1 samples contribute to Head 1 (Binary) training
   - DS1 samples do NOT contribute to Head 2 (Subtype) training
   - DS2 samples contribute to BOTH heads

4. **Advantages:**
   - Prevents error propagation from one head to another
   - Shared backbone learns features useful for both tasks
   - Both heads can be trained with different amounts of data
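The masking behaviour of `ignore_index=-1` can be reproduced in a few lines of NumPy: samples whose subtype target is `-1` are dropped before the mean is taken, so they add nothing to the Head 2 loss or its gradient. This is a sketch of the semantics of `CrossEntropyLoss(ignore_index=-1)` with the default `reduction='mean'`, not the notebook's training code. It also exposes one caveat: if a batch contains only DS1 images, every subtype target is ignored and PyTorch's mean reduction returns NaN (0/0), so a guard like the one below, or batches that mix both datasets, may be needed.

```python
import numpy as np

def masked_cross_entropy(logits, targets, ignore_index=-1):
    """Mean cross-entropy over samples whose target != ignore_index.

    Mirrors CrossEntropyLoss(ignore_index=...) with reduction='mean', except that
    an all-ignored batch returns 0.0 instead of NaN.
    """
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    keep = targets != ignore_index
    if not keep.any():
        return 0.0  # nothing labelled for this head in the batch
    z = logits[keep]
    # numerically stable log-softmax
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(z)), targets[keep]]
    return float(nll.mean())

# Toy batch of 3 samples over 4 subtypes; the first mimics a DS1 image (target -1).
logits = np.array([[5.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0]])
targets = np.array([-1, 2, 1])
print(masked_cross_entropy(logits, targets))  # log(4): only the two uniform rows count
```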

---

## Summary: Flat Multi-Task Learning Architecture

**Key Points:**
1. ✓ **No Hierarchical Dependency:** Both heads predict independently
2. ✓ **Shared Backbone:** ResNet50 extracts features for both tasks
3. ✓ **Masked Loss:** DS1 images don't contribute to Head 2's gradient
4. ✓ **Parallel Training:** Both heads update simultaneously during backpropagation
5. ✓ **Independent Evaluation:** Each head's performance measured separately

**Architecture Flow:**
```
Image → ResNet50 → Features → ┬→ Head 1 (Binary) → Benign/Malignant
                              └→ Head 2 (Subtype) → CaS/CoS/Gum/MC/OC/OLP/OT
```
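The flow above can be sketched with plain NumPy: one shared feature vector per image feeds two independent linear heads, and neither head reads the other's output. The 2048-dimensional size matches ResNet50's pooled feature output and the 7 subtypes follow the diagram; the random weights, the helper `dual_head_forward`, and the assumption that each head is a single linear layer are illustrative only (the notebook's heads may contain more layers).

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def dual_head_forward(features, w_bin, b_bin, w_sub, b_sub):
    """Two parallel heads on the same features; no head uses the other's output."""
    logits_binary = features @ w_bin + b_bin    # (N, 2): Benign / Malignant
    logits_subtype = features @ w_sub + b_sub   # (N, n_subtypes)
    return logits_binary, logits_subtype

rng = np.random.default_rng(0)
n_images, feat_dim, n_subtypes = 4, 2048, 7     # 2048 = ResNet50 pooled feature size
features = rng.standard_normal((n_images, feat_dim))  # stand-in for backbone output
w_bin, b_bin = rng.standard_normal((feat_dim, 2)) * 0.01, np.zeros(2)
w_sub, b_sub = rng.standard_normal((feat_dim, n_subtypes)) * 0.01, np.zeros(n_subtypes)

logits_b, logits_s = dual_head_forward(features, w_bin, b_bin, w_sub, b_sub)
print(softmax(logits_b).shape, softmax(logits_s).shape)  # (4, 2) (4, 7)
```

Changing the subtype weights leaves the binary logits untouched, which is the "no error propagation between heads" property in concrete form; the heads interact only through the shared backbone during training.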

**Advantages:**
- No error propagation between heads
- Simpler training compared to hierarchical models
- Both tasks benefit from shared feature learning
- Flexible: Can use predictions from either head independently