# Kaggle & Colab Setup

In [None]:
%%capture
# KAGGLE IMPORTS
# Kaggle setup: clone the project repo, configure the kaggle CLI and download
# the competition data. `%%capture` hides ALL output of this cell, including
# error messages from the shell commands below.

# Clone repo into the Kaggle working directory
!git clone https://github.com/francinze/Ch_An2DL.git /kaggle/working/ch2

# Install kaggle API (version is not pinned)
!pip install -q kaggle

# Create the directory the kaggle CLI reads its credentials from
!mkdir -p /root/.config/kaggle

# Copy the kaggle.json API token shipped inside the repo into that directory.
# NOTE(review): this means an API token is committed to the git repository;
# if the repo is public the token is exposed. Prefer Kaggle Secrets or the
# KAGGLE_USERNAME / KAGGLE_KEY environment variables, and rotate the key.
!cp /kaggle/working/ch2/kaggle.json /root/.config/kaggle/

# Restrict permissions so the token is not readable by other users
!chmod 600 /root/.config/kaggle/kaggle.json

# Move into the cloned repo; later cells open some files (e.g.
# shrek_and_slimes.txt) with paths relative to this directory
%cd /kaggle/working/ch2/

# Create (relative) data/ and models/ folders inside the repo
!mkdir data
!mkdir models

# Download competition files into the ABSOLUTE /data directory.
# NOTE(review): the dataset lands in /data/, not in the ./data folder created
# above; make sure the PATH_PREFIX detection in the data-organization cell
# resolves to '/' in this environment.
!kaggle competitions download -c an2dl2526c2v2 -p /data

# Unzip dataset
!unzip -o /data/an2dl2526c2v2.zip -d /data/

In [None]:
# COLAB SETUP (disabled): everything below is wrapped in a string literal, so
# running this cell executes nothing. To use it on Colab, remove the triple
# quotes; `%%capture` must then be moved to the very first line of the cell,
# because IPython only accepts cell magics there.
# NOTE(review): like the Kaggle cell, this downloads to the absolute /data
# directory while creating a relative data/ folder — confirm the intended path.
'''
# COLAB IMPORTS
%%capture
!git clone https://github.com/francinze/Ch_An2DL.git
! pip install -q kaggle
! mkdir ~/.kaggle
! cp Ch_An2DL/kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
%cd /content/Ch_An2DL/
!mkdir data
!mkdir models
!kaggle competitions download -c an2dl2526c2v2 -p /data
!unzip -o /data/an2dl2526c2v2.zip -d /data/
'''

# Import Libraries and Data

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

## Organize Data by Type (Masks vs Images)

This cell copies the raw files from `data/train_data/` and `data/test_data/` into separate directories by type (the original files are left in place):
- `data/train_img/` - All training images (img_XXXX.png)
- `data/train_mask/` - All training masks (mask_XXXX.png)
- `data/test_img/` - All test images (img_XXXX.png)
- `data/test_mask/` - All test masks (mask_XXXX.png)

This structure ensures that:
1. The `DATA_TYPE` setting can switch the whole pipeline between images and masks simply by choosing a different source directory
2. Augmented data follows the same convention, with one directory per type (`data/train_data_augmented_img/` or `data/train_data_augmented_mask/`)
3. Image and mask files are never mixed in the same directory

In [None]:
import os
import shutil

# Detect where the dataset lives and set the path prefix accordingly.
# The Kaggle/Colab setup cells download the competition files into the
# absolute '/data/' directory but also create a relative 'data/' folder, so
# "is there a 'data' entry in the current directory?" can point at an empty
# folder. Prefer whichever location actually contains the labels file, and
# fall back to that directory-name heuristic only when neither does.
_LABELS_RELPATH = 'data/train_labels.csv'
if os.path.isfile(_LABELS_RELPATH):
    # Local environment (dataset next to the notebook)
    PATH_PREFIX = ''
elif os.path.isfile('/' + _LABELS_RELPATH):
    # Kaggle or Colab environment (dataset downloaded to /data/)
    PATH_PREFIX = '/'
elif 'data' not in os.listdir():
    # Kaggle or Colab environment (directory-name heuristic)
    PATH_PREFIX = '/'
else:
    # Local environment (directory-name heuristic)
    PATH_PREFIX = ''

print("="*80)
print("ORGANIZING DATA INTO SEPARATE DIRECTORIES BY TYPE")
print("="*80)

# Source directories as extracted from the competition archive
train_data_dir = PATH_PREFIX + 'data/train_data/'
test_data_dir = PATH_PREFIX + 'data/test_data/'

# Target directories: one per (split, file type)
train_img_dir = PATH_PREFIX + 'data/train_img/'
train_mask_dir = PATH_PREFIX + 'data/train_mask/'
test_img_dir = PATH_PREFIX + 'data/test_img/'
test_mask_dir = PATH_PREFIX + 'data/test_mask/'

# Create target directories if they don't exist (safe to re-run)
for directory in [train_img_dir, train_mask_dir, test_img_dir, test_mask_dir]:
    os.makedirs(directory, exist_ok=True)

# Helper that copies raw files into per-type folders
def organize_data_by_type(source_dir, img_dir, mask_dir):
    """
    Copy ``img_*`` files to ``img_dir`` and ``mask_*`` files to ``mask_dir``.

    The originals stay in ``source_dir``. A file that already exists in its
    destination is neither overwritten nor counted, so repeated calls are
    cheap. Sub-directories and files with any other prefix are ignored.

    Returns:
        tuple[int, int]: (images copied, masks copied) by this call;
        (0, 0) with a printed warning if ``source_dir`` does not exist.
    """
    if not os.path.exists(source_dir):
        print(f"⚠ Warning: Source directory not found: {source_dir}")
        return 0, 0

    # (filename prefix, destination) pairs; index 0 counts images, 1 masks
    routes = (('img_', img_dir), ('mask_', mask_dir))
    copied = [0, 0]

    for entry in os.listdir(source_dir):
        src = os.path.join(source_dir, entry)
        if not os.path.isfile(src):
            continue
        for slot, (prefix, dest_dir) in enumerate(routes):
            if not entry.startswith(prefix):
                continue
            dst = os.path.join(dest_dir, entry)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)
                copied[slot] += 1
            break  # a file belongs to at most one route

    return copied[0], copied[1]

def _count_files(directory):
    """Return the number of entries in ``directory`` (0 if it does not exist)."""
    return len(os.listdir(directory)) if os.path.exists(directory) else 0


# --- Training split: copy into per-type folders ---
print("\nOrganizing training data...")
train_img_moved, train_mask_moved = organize_data_by_type(
    source_dir=train_data_dir, img_dir=train_img_dir, mask_dir=train_mask_dir
)
print(f"  Images: {train_img_moved} files copied to {train_img_dir}")
print(f"  Masks: {train_mask_moved} files copied to {train_mask_dir}")

# --- Test split: copy into per-type folders ---
print("\nOrganizing test data...")
test_img_moved, test_mask_moved = organize_data_by_type(
    source_dir=test_data_dir, img_dir=test_img_dir, mask_dir=test_mask_dir
)
print(f"  Images: {test_img_moved} files copied to {test_img_dir}")
print(f"  Masks: {test_mask_moved} files copied to {test_mask_dir}")

# --- Summary: what is on disk now (includes copies made by earlier runs) ---
print("\n" + "="*80)
print("DATA ORGANIZATION SUMMARY")
print("="*80)
print(f"Train images: {_count_files(train_img_dir)} files in {train_img_dir}")
print(f"Train masks: {_count_files(train_mask_dir)} files in {train_mask_dir}")
print(f"Test images: {_count_files(test_img_dir)} files in {test_img_dir}")
print(f"Test masks: {_count_files(test_mask_dir)} files in {test_mask_dir}")
print("="*80)
print("✓ Data organization complete!")
print("  - Original files remain in train_data/ and test_data/")
print("  - Organized copies are in train_img/, train_mask/, test_img/, test_mask/")

ORGANIZING DATA INTO SEPARATE DIRECTORIES BY TYPE

Organizing training data...
  Images: 1163 files copied to data/train_img/
  Masks: 1163 files copied to data/train_mask/

Organizing test data...
  Images: 954 files copied to data/test_img/
  Masks: 954 files copied to data/test_mask/

DATA ORGANIZATION SUMMARY
Train images: 1163 files in data/train_img/
Train masks: 1163 files in data/train_mask/
Test images: 954 files in data/test_img/
Test masks: 954 files in data/test_mask/
✓ Data organization complete!
  - Original files remain in train_data/ and test_data/
  - Organized copies are in train_img/, train_mask/, test_img/, test_mask/

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split

# ===== SET DATA TYPE: "IMG" or "MASK" =====
DATA_TYPE = "MASK"  # Use "IMG" for images or "MASK" for masks
# ==========================================
# Fail fast on a typo: later cells compare against "MASK" and treat anything
# else as "IMG", so e.g. "mask" would otherwise silently train on images.
if DATA_TYPE not in ("IMG", "MASK"):
    raise ValueError(f'DATA_TYPE must be "IMG" or "MASK", got {DATA_TYPE!r}')

# Set directories based on DATA_TYPE (all later loading reads from these)
if DATA_TYPE == "MASK":
    train_dir = PATH_PREFIX + 'data/train_mask/'
    test_dir = PATH_PREFIX + 'data/test_mask/'
else:  # IMG
    train_dir = PATH_PREFIX + 'data/train_img/'
    test_dir = PATH_PREFIX + 'data/test_img/'

train_labels = pd.read_csv(PATH_PREFIX + 'data/train_labels.csv')

print(f"Environment detected. Using path prefix: '{PATH_PREFIX}'")
print(f"Using DATA_TYPE: {DATA_TYPE}")
print(f"Train directory: {train_dir}")
print(f"Test directory: {test_dir}")

# Display dataset info
print(f"\nTotal training samples: {len(train_labels)}")
print(f"\nClass distribution:")
print(train_labels['label'].value_counts())

# Check properties of one file from the selected directory. The array is
# materialized once inside `with` so the file handle is closed afterwards.
sample_name = 'mask_0000.png' if DATA_TYPE == "MASK" else 'img_0000.png'
with Image.open(os.path.join(train_dir, sample_name)) as sample_file:
    sample_array = np.array(sample_file)
if DATA_TYPE == "MASK":
    print(f"\nMask shape: {sample_array.shape}")
    print(f"Mask dtype: {sample_array.dtype}")
    print(f"Mask unique values: {np.unique(sample_array)}")
else:
    print(f"\nImage shape: {sample_array.shape}")
    print(f"Image dtype: {sample_array.dtype}")

# Visualize a few samples (always show both img and mask for reference)
train_img_dir_viz = PATH_PREFIX + 'data/train_img/'
train_mask_dir_viz = PATH_PREFIX + 'data/train_mask/'
N_VIZ_SAMPLES = 3
n_show = min(N_VIZ_SAMPLES, len(train_labels))

# squeeze=False keeps `axes` 2-D even when only one column is shown
fig, axes = plt.subplots(2, n_show, figsize=(5 * n_show, 10), squeeze=False)
for i in range(n_show):
    img_name = train_labels.iloc[i]['sample_index']
    label = train_labels.iloc[i]['label']

    # imshow converts the PIL image to an array immediately, so the files can
    # be closed right after plotting
    with Image.open(os.path.join(train_img_dir_viz, img_name)) as img:
        axes[0, i].imshow(img)
    axes[0, i].set_title(f'{img_name}\n{label}')
    axes[0, i].axis('off')

    with Image.open(os.path.join(train_mask_dir_viz, img_name.replace('img_', 'mask_'))) as mask:
        axes[1, i].imshow(mask, cmap='gray')
    axes[1, i].set_title(f'Mask for {img_name}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

# Preprocessing

## Remove Shrek & Slimes

In [None]:
# Indices of contaminated training samples ("Shrek & slimes"), one integer per
# line in shrek_and_slimes.txt; blank or non-numeric lines are skipped.
with open('shrek_and_slimes.txt', 'r') as f:
    contaminated_indices = [
        int(token)
        for token in (raw_line.strip() for raw_line in f)
        if token and token.isdigit()
    ]

print(f"Found {len(contaminated_indices)} contaminated samples to remove")

# Both the image and the mask copy of each contaminated sample are deleted,
# whichever DATA_TYPE is currently selected.
train_img_dir_clean = PATH_PREFIX + 'data/train_img/'
train_mask_dir_clean = PATH_PREFIX + 'data/train_mask/'

removed_count = 0
for idx in contaminated_indices:
    candidate_paths = (
        os.path.join(train_img_dir_clean, f'img_{idx:04d}.png'),
        os.path.join(train_mask_dir_clean, f'mask_{idx:04d}.png'),
    )
    for candidate in candidate_paths:
        if not os.path.exists(candidate):
            continue  # already gone (e.g. the cell was re-run)
        os.remove(candidate)
        removed_count += 1

print(f"Removed {removed_count} files from organized directories")

# Drop the same samples from the label table (sample_index is like img_XXXX.png)
sample_numbers = train_labels['sample_index'].str.extract(r'(\d+)')[0].astype(int)
train_labels = train_labels[~sample_numbers.isin(contaminated_indices)]
print(f"Training labels updated: {len(train_labels)} samples remaining")

## Augmentation

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Analyze class distribution after removal
class_distribution = train_labels['label'].value_counts().sort_index()
print("\n" + "="*60)
print("Class Distribution After Removal of Contaminated Images")
print("="*60)
print(class_distribution)
print(f"\nTotal samples: {len(train_labels)}")

# Calculate statistics
print("\n" + "="*60)
print("STATISTICS FOR AUGMENTATION")
print("="*60)

# Class with the most samples (majority)
max_class = class_distribution.max()
max_class_name = class_distribution.idxmax()
print(f"\nClass with the most samples (Majority): {max_class_name} ({max_class} samples)")

# Class with the fewest samples (minority)
min_class = class_distribution.min()
min_class_name = class_distribution.idxmin()
print(f"Class with the fewest samples (Minority): {min_class_name} ({min_class} samples)")

# Imbalance ratio
imbalance_ratio = max_class / min_class
print(f"\nImbalance ratio (Max/Min): {imbalance_ratio:.2f}x")

# Augmentation proposal (keep in sync with the transforms in the augmentation cell)
print("\n" + "="*60)
print("RECOMMENDED AUGMENTATION STRATEGY")
print("="*60)
print("\nAugmentations to apply (as suggested by the professor):")
print("  1. Horizontal Flip (p=0.5)")
print("  2. Vertical Flip (p=0.5)")
print("  3. Random Translation (0.2, 0.2)")
print("  4. Random Zoom/Scale (0.8, 1.2)")
print("  [EXCLUDE: Random Rotation - would change dimensions]\n")

# STRATEGY: All classes grow until reaching the same target number for ALL
print("\n" + "="*80)
print("BALANCED STRATEGY: ALL CLASSES GROW TO A FIXED AND EQUAL NUMBER")
print("="*80)

# ===== MODIFY THE TARGET NUMBER OF SAMPLES HERE =====
target_samples = 1000  # Desired number of samples for EACH class
# =====================================================

print(f"\nTarget: {target_samples} samples for EACH class")

# Per-class plan. Classes already at or above the target are NOT downsampled,
# so their projected size ('final') is their original size, not the target.
augmentation_strategy_balanced = {}
total_to_generate = 0
for class_name, n_samples in class_distribution.items():
    n_augmentations = max(0, target_samples - n_samples)  # never negative
    augmentation_strategy_balanced[class_name] = {
        'original': n_samples,
        'target': target_samples,
        'augment_count': n_augmentations,
        'ratio_multiplier': n_augmentations / n_samples if n_samples > 0 else 0,
        'final': n_samples + n_augmentations,
    }
    total_to_generate += n_augmentations

# Projection of the dataset after augmentation
print("\n" + "="*80)
print("DATASET AFTER BALANCED AUGMENTATION")
print("="*80)
print(f"{'Class':<20} {'Original':<15} {'New Augment':<15} {'Augmentations per image':<25} {'Total':<15}")
print("-" * 80)

total_original = 0
total_augmented = 0
for class_name, plan in augmentation_strategy_balanced.items():
    total_original += plan['original']
    total_augmented += plan['final']
    print(f"{class_name:<20} {plan['original']:<15} {plan['augment_count']:<15} {plan['ratio_multiplier']:<25.2f} {plan['final']:<15}")

mean_ratio = np.mean([plan['ratio_multiplier'] for plan in augmentation_strategy_balanced.values()])
print("-" * 80)
print(f"{'TOTAL':<20} {total_original:<15} {total_to_generate:<15} {mean_ratio:<25.2f} {total_augmented:<15}")

# Shared y-limit for both panels: the target line and the balanced bars can
# exceed the largest original class, and must not be clipped.
y_top = max(max_class, target_samples) * 1.1

# Visualize the distribution before and after
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Before
class_distribution.plot(kind='bar', ax=axes[0], color='steelblue')
axes[0].set_title('Class Distribution - BEFORE Augmentation', fontsize=12, fontweight='bold')
axes[0].set_ylabel('Number of samples')
axes[0].set_xlabel('Class')
axes[0].axhline(y=target_samples, color='red', linestyle='--', linewidth=2, label=f'Target: {target_samples}')
axes[0].set_ylim([0, y_top])
axes[0].legend()
axes[0].grid(axis='y', alpha=0.3)

# After: projected per-class totals (original + generated)
after_augmentation_balanced = {
    name: plan['final'] for name, plan in augmentation_strategy_balanced.items()
}
after_series = pd.Series(after_augmentation_balanced)
after_series.plot(kind='bar', ax=axes[1], color='seagreen')
axes[1].set_title('Class Distribution - AFTER Balanced Augmentation', fontsize=12, fontweight='bold')
axes[1].set_ylabel('Number of samples')
axes[1].set_xlabel('Class')
axes[1].axhline(y=target_samples, color='red', linestyle='--', linewidth=2, label=f'Target: {target_samples}')
axes[1].set_ylim([0, y_top])
axes[1].legend()
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Create folder for augmented data if it doesn't exist.
# Use a DATA_TYPE-specific directory to keep IMG and MASK augmentations separate.
augmented_dir = PATH_PREFIX + f'data/train_data_augmented_{DATA_TYPE.lower()}/'

# The loading cell reads every file in augmented_dir, so leftovers from an
# earlier run (e.g. with a different target_samples) would silently unbalance
# the dataset. This cell regenerates all augmentations anyway, so clear them
# first. Set to False to keep existing files.
CLEAR_EXISTING_AUGMENTATIONS = True

# torchvision's random transforms draw from torch's global RNG; seeding it
# makes the generated augmentations reproducible across runs.
AUGMENTATION_SEED = 42

if not os.path.exists(augmented_dir):
    os.makedirs(augmented_dir)
    print(f"Created directory: {augmented_dir}")
else:
    existing_files = len(os.listdir(augmented_dir))
    print(f"Directory already exists: {augmented_dir}")
    print(f"Found {existing_files} existing augmented files for DATA_TYPE={DATA_TYPE}")
    if CLEAR_EXISTING_AUGMENTATIONS:
        n_cleared = 0
        for stale_name in os.listdir(augmented_dir):
            stale_path = os.path.join(augmented_dir, stale_name)
            if os.path.isfile(stale_path):
                os.remove(stale_path)
                n_cleared += 1
        print(f"Cleared {n_cleared} existing files before regenerating")

torch.manual_seed(AUGMENTATION_SEED)

# Define augmentations (applied cyclically across generated samples)
augmentation_transforms = {
    'flip': transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
    ]),
    'translation': transforms.Compose([
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=None),
    ]),
    'zoom': transforms.Compose([
        transforms.RandomAffine(degrees=0, translate=None, scale=(0.8, 1.2)),
    ]),
    'combined': transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.2)),
    ])
}
aug_types = list(augmentation_transforms.keys())

print("\n" + "="*80)
print("STARTING AUGMENTATION PROCESS")
print("="*80)

# Loop through each class and generate augmentations
total_augmented = 0

for class_name in sorted(augmentation_strategy_balanced.keys()):
    info = augmentation_strategy_balanced[class_name]
    n_augment = info['augment_count']

    if n_augment == 0:
        print(f"\n{class_name}: No augmentation needed (already at target)")
        continue

    print(f"\n{'-'*80}")
    print(f"Class: {class_name}")
    print(f"Augmentations to generate: {n_augment}")
    print(f"{'-'*80}")

    # Get original images of this class
    class_samples = train_labels[train_labels['label'] == class_name]['sample_index'].tolist()
    n_original = len(class_samples)
    if n_original == 0:
        print(f"  No original samples left for {class_name}; skipping")
        continue

    # Spread n_augment as evenly as possible: every image gets
    # `base_per_img` augmentations and the first `n_with_extra` images one more.
    base_per_img, n_with_extra = divmod(n_augment, n_original)

    aug_count = 0  # augmentations actually saved for this class
    for img_idx, img_name in enumerate(class_samples):
        # train_labels uses img_ names; masks share the same number
        file_name = img_name.replace('img_', 'mask_') if DATA_TYPE == "MASK" else img_name
        img_path = os.path.join(train_dir, file_name)

        if not os.path.exists(img_path):
            print(f"  File not found: {file_name}")
            continue

        # convert() returns a new image, so the source file can be closed here
        with Image.open(img_path) as src:
            img_pil = src.convert('L' if DATA_TYPE == "MASK" else 'RGB')

        n_to_generate = base_per_img + (1 if img_idx < n_with_extra else 0)
        base_name = file_name.replace('.png', '')

        for aug_num in range(n_to_generate):
            if aug_count >= n_augment:  # never exceed the planned amount
                break
            # Choose an augmentation type cyclically
            aug_type = aug_types[aug_count % len(aug_types)]
            img_augmented = augmentation_transforms[aug_type](img_pil)
            augmented_img_name = f"{base_name}_aug_{aug_num}_{aug_type}.png"
            img_augmented.save(os.path.join(augmented_dir, augmented_img_name))
            aug_count += 1

        # Progress update
        if (img_idx + 1) % max(1, n_original // 5) == 0 or img_idx == n_original - 1:
            print(f"  Processed {img_idx + 1}/{n_original} original samples ({aug_count} augmentations generated)")

    total_augmented += aug_count
    print(f"  {class_name}: Completed! {aug_count} augmentations generated")

print("\n" + "="*80)
print(f"AUGMENTATION COMPLETED!")
print(f"Total augmented images generated: {total_augmented}")
print(f"Save directory: {augmented_dir}")
print("="*80)

# Verify file count
augmented_files = os.listdir(augmented_dir)
print(f"\nFiles in augmented folder: {len(augmented_files)}")
print(f"First 5 files: {sorted(augmented_files)[:5]}")

In [None]:
from torch.utils.data import TensorDataset

# Spatial size every image/mask is resized to before batching
IMG_SIZE = (224, 224)  # Standard size for many CNN architectures
# Mini-batch size shared by all DataLoaders
BATCH_SIZE = 32

# ===== GPU OPTIMIZATION SETTINGS =====
# With CUDA: 4 loader workers per visible GPU (capped at 8) and pinned host
# memory for faster host-to-device copies. Without CUDA: load in the main
# process (no multiprocessing overhead) and skip pinning.
import torch

_has_cuda = torch.cuda.is_available()
NUM_WORKERS = min(8, 4 * torch.cuda.device_count()) if _has_cuda else 0
PIN_MEMORY = _has_cuda
# persistent_workers is only meaningful when worker processes exist
PERSISTENT_WORKERS = NUM_WORKERS > 0
# ======================================

# Load original + augmented images into tensors
print("\n" + "="*80)
print("LOADING BALANCED DATASET (Original + Augmented)")
print("="*80)
print(f"Current DATA_TYPE: {DATA_TYPE}")
print(f"Augmented directory: {augmented_dir}")

# Check if augmented directory exists and validate files
if not os.path.exists(augmented_dir):
    print(f"\nWARNING: Augmented directory does not exist!")
    print(f"Expected: {augmented_dir}")
    print(f"No augmented data will be loaded. Only original images will be used.")
    augmented_files = []
else:
    # Sorted so the row order (and hence the seeded split downstream) does not
    # depend on the filesystem's directory-listing order.
    augmented_files = sorted(os.listdir(augmented_dir))
    print(f"Augmented images found: {len(augmented_files)}")

    if len(augmented_files) > 0:
        expected_prefix = 'mask_' if DATA_TYPE == "MASK" else 'img_'
        # Check every file, not only one: a directory holding a mix of img_ and
        # mask_ files must be rejected as well.
        mismatched_files = [name for name in augmented_files if not name.startswith(expected_prefix)]
        if mismatched_files:
            sample_file = mismatched_files[0]
            print(f"\nERROR: Augmented files don't match DATA_TYPE={DATA_TYPE}!")
            print(f"Found files starting with '{sample_file.split('_')[0]}_' but expected '{expected_prefix}'")
            print(f"To fix: Either regenerate augmentations or change DATA_TYPE setting.")
            raise ValueError(f"Augmented data mismatch: files don't match DATA_TYPE={DATA_TYPE}")
        else:
            print(f"Validation passed: Augmented files match DATA_TYPE={DATA_TYPE}")

# Map each original sample name (img_XXXX.png) to its class once, instead of
# filtering the whole DataFrame for every augmented file. keep='first' mirrors
# a "first matching row" lookup if a name were ever duplicated.
label_by_sample = (
    train_labels.drop_duplicates('sample_index', keep='first')
    .set_index('sample_index')['label']
    .to_dict()
)

augmented_rows = []
n_orphaned = 0
for aug_img_name in augmented_files:
    # Format: {prefix}_{number}_aug_{aug_num}_{aug_type}.png -> {prefix}_{number}.png
    base_name = aug_img_name.split('_aug_')[0] + '.png'
    # train_labels uses the img_ prefix, so map mask_XXXX back to img_XXXX
    search_name = base_name.replace('mask_', 'img_') if DATA_TYPE == "MASK" else base_name
    if search_name in label_by_sample:
        augmented_rows.append({'sample_index': aug_img_name, 'label': label_by_sample[search_name]})
    else:
        # Original no longer in train_labels (e.g. removed as contaminated)
        n_orphaned += 1

augmented_df = pd.DataFrame(augmented_rows, columns=['sample_index', 'label'])
if augmented_rows:
    train_labels_augmented = pd.concat([train_labels.copy(), augmented_df], ignore_index=True)
else:
    train_labels_augmented = train_labels.reset_index(drop=True)

if n_orphaned:
    print(f"Skipped {n_orphaned} augmented files whose original sample is not in train_labels")
print(f"\nOriginal dataset: {len(train_labels)} samples")
print(f"Augmented dataset: {len(train_labels_augmented)} samples")
print(f"\nDistribution in augmented dataset:")
print(train_labels_augmented['label'].value_counts().sort_index())

# Load images into tensors (original + augmented)
def load_augmented_images_to_tensor(train_dir, augmented_dir, labels_df, img_size=IMG_SIZE, data_type="MASK"):
    """Load the original and augmented samples listed in ``labels_df`` into tensors.

    Args:
        train_dir: directory with the original files (``img_XXXX.png`` or ``mask_XXXX.png``).
        augmented_dir: directory with augmented files (their names contain ``_aug_``).
        labels_df: DataFrame with ``sample_index`` (file name; originals use the
            ``img_`` prefix even for masks) and ``label`` columns.
        img_size: size passed to ``PIL.Image.resize`` (width, height).
        data_type: "MASK" loads grayscale masks replicated to 3 channels;
            any other value loads RGB images.

    Returns:
        (images_tensor, labels_tensor, label_map): float tensor (N, 3, H, W)
        scaled to [0, 1], long tensor (N,) of class ids, and the label->id dict.
        Rows whose file is missing or cannot be opened are skipped with a
        printed warning, so N may be smaller than ``len(labels_df)``.

    Raises:
        ValueError: if not a single image could be loaded.
        KeyError: if a label is not one of the four known classes.
    """
    label_map = {'Triple negative': 0, 'Luminal A': 1, 'Luminal B': 2, 'HER2(+)': 3}
    images = []
    labels = []

    for img_name, label in zip(labels_df['sample_index'], labels_df['label']):
        if '_aug_' in img_name:
            # Augmented files already carry the correct img_/mask_ prefix
            img_path = os.path.join(augmented_dir, img_name)
        else:
            # Original sample: labels use img_ names, masks share the number
            file_name = img_name.replace('img_', 'mask_') if data_type == "MASK" else img_name
            img_path = os.path.join(train_dir, file_name)

        if not os.path.exists(img_path):
            print(f"⚠ Warning: Image not found: {img_path}")
            continue
        try:
            # `with` closes the file handle as soon as the pixels are copied out
            with Image.open(img_path) as src:
                if data_type == "MASK":
                    gray = np.array(src.convert('L').resize(img_size, Image.BILINEAR))
                    img_array = np.stack([gray, gray, gray], axis=-1)
                else:
                    img_array = np.array(src.convert('RGB').resize(img_size, Image.BILINEAR))
        except Exception as e:
            print(f"Warning: Failed to load image {img_path}: {e}")
            continue
        images.append(img_array)
        labels.append(label)

    if not images:
        raise ValueError(
            "No images could be loaded for the given labels_df; "
            "check train_dir / augmented_dir and data_type."
        )

    # (N, H, W, C) uint8 -> (N, C, H, W) float in [0, 1]
    images_tensor = torch.from_numpy(np.stack(images)).permute(0, 3, 1, 2).float() / 255.0
    labels_tensor = torch.tensor([label_map[label] for label in labels], dtype=torch.long)

    return images_tensor, labels_tensor, label_map

print("\nLoading images into tensors...")

# ----- Leakage-free train/validation split -----
# Every augmented file is derived from one original image. Randomly splitting
# the already-augmented pool lets flipped/shifted copies of a validation image
# land in the training set, which leaks it and inflates validation scores.
# Instead: split the ORIGINAL samples (stratified by class), train on the
# training originals + their augmentations, and validate only on untouched
# held-out originals. The validation set therefore keeps the original class
# imbalance, so prefer macro-averaged metrics when evaluating.
VAL_FRACTION = 0.2
SPLIT_SEED = 42

sample_names = train_labels_augmented['sample_index']
# Numeric id shared by an original (img_0012.png) and all its augmentations
# (img_0012_aug_*.png / mask_0012_aug_*.png)
sample_ids = sample_names.str.extract(r'^(?:img|mask)_(\d+)', expand=False)
is_augmented = sample_names.str.contains('_aug_', regex=False)

original_rows = train_labels_augmented[~is_augmented]
_train_ids, val_ids = train_test_split(
    sample_ids[~is_augmented].to_numpy(),
    test_size=VAL_FRACTION,
    random_state=SPLIT_SEED,
    stratify=original_rows['label'].to_numpy(),
)
in_val_group = sample_ids.isin(set(val_ids))

train_rows = train_labels_augmented[~in_val_group]
val_rows = train_labels_augmented[in_val_group & ~is_augmented]
n_leaky_dropped = int((in_val_group & is_augmented).sum())

X_train, y_train, label_map = load_augmented_images_to_tensor(
    train_dir, augmented_dir, train_rows, IMG_SIZE, DATA_TYPE
)
X_val, y_val, _val_label_map = load_augmented_images_to_tensor(
    train_dir, augmented_dir, val_rows, IMG_SIZE, DATA_TYPE
)

# Combined tensors kept under their previous names for any later use
X_train_augmented = torch.cat([X_train, X_val])
y_train_augmented = torch.cat([y_train, y_val])

print(f"Images tensor shape: {X_train_augmented.shape}")
print(f"Labels tensor shape: {y_train_augmented.shape}")

print(f"\nTrain set augmented: {X_train.shape[0]} samples")
print(f"Validation set (originals only): {X_val.shape[0]} samples")
print(f"Dropped {n_leaky_dropped} augmentations of validation images to avoid train/val leakage")

# Create new DataLoaders with GPU optimizations
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Shared DataLoader settings; persistent_workers is only valid with workers
common_loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY
}
if NUM_WORKERS > 0:
    common_loader_kwargs['persistent_workers'] = PERSISTENT_WORKERS
train_loader_kwargs = {**common_loader_kwargs, 'shuffle': True}
val_loader_kwargs = {**common_loader_kwargs, 'shuffle': False}

train_loader = DataLoader(train_dataset, **train_loader_kwargs)
val_loader = DataLoader(val_dataset, **val_loader_kwargs)

print(f"Optimization: {NUM_WORKERS} workers, pin_memory={PIN_MEMORY}, persistent_workers={PERSISTENT_WORKERS}")

print(f"\nCreated DataLoaders:")
print(f"Val batches: {len(val_loader)}")
print(f"Train batches: {len(train_loader)}")

In [None]:
# Load all images of one type from a directory into a tensor
def load_images_to_tensor(data_dir, img_size=IMG_SIZE, data_type=None):
    """Load every ``img_*`` (or ``mask_*``) file in ``data_dir`` into a tensor.

    Args:
        data_dir: directory to read from.
        img_size: size passed to ``PIL.Image.resize`` (width, height).
        data_type: "MASK" loads ``mask_*`` files as grayscale replicated to
            3 channels; any other value loads ``img_*`` files as RGB.
            Defaults to the notebook-level ``DATA_TYPE``.

    Returns:
        (images_tensor, image_files): float tensor (N, 3, H, W) scaled to
        [0, 1], and the sorted file names in the same order as the tensor rows.

    Raises:
        ValueError: if ``data_dir`` contains no file with the expected prefix.
    """
    if data_type is None:
        data_type = DATA_TYPE

    prefix = 'mask_' if data_type == "MASK" else 'img_'
    # Sorted so that row i of the tensor matches image_files[i] deterministically
    image_files = sorted(f for f in os.listdir(data_dir) if f.startswith(prefix))
    if not image_files:
        raise ValueError(f"No '{prefix}*' files found in {data_dir}")

    images = []
    for img_name in image_files:
        # `with` closes each file once its pixels have been copied out
        with Image.open(os.path.join(data_dir, img_name)) as src:
            if data_type == "MASK":
                # Grayscale mask replicated to 3 channels for model compatibility
                gray = np.array(src.convert('L').resize(img_size, Image.BILINEAR))
                images.append(np.stack([gray, gray, gray], axis=-1))
            else:
                images.append(np.array(src.convert('RGB').resize(img_size, Image.BILINEAR)))

    # (N, H, W, C) uint8 -> (N, C, H, W) float in [0, 1]
    images_tensor = torch.from_numpy(np.stack(images)).permute(0, 3, 1, 2).float() / 255.0

    return images_tensor, image_files

# ---- Test set: load every file of the selected DATA_TYPE ----
print(f"\nLoading test data using DATA_TYPE: {DATA_TYPE}")
X_test, test_filenames = load_images_to_tensor(test_dir)
print(f"Test images shape: {X_test.shape}")

test_dataset = TensorDataset(X_test)

# Same GPU-aware settings as the train/val loaders. Test data is never
# shuffled, so batch order stays aligned with `test_filenames`.
test_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
if NUM_WORKERS > 0:
    test_loader_kwargs.update(persistent_workers=PERSISTENT_WORKERS)

test_loader = DataLoader(test_dataset, **test_loader_kwargs)

print(f"\nDataLoader created:")
print(f"Test batches: {len(test_loader)}")
print(f"Optimization: {NUM_WORKERS} workers, pin_memory={PIN_MEMORY}")

In [None]:
# Per-sample input shape (batch dimension dropped) and number of classes
input_shape = X_train.shape[1:]  # (C, H, W)
num_classes = len(label_map)

# ===== MULTI-GPU SETUP =====
# Tensors and the model go to cuda:0; the model cell wraps the network in
# DataParallel when more than one GPU is visible.
if not torch.cuda.is_available():
    device = torch.device('cpu')
    num_gpus = 0
    print("No GPU available, using CPU")
else:
    num_gpus = torch.cuda.device_count()
    device = torch.device('cuda:0')
    print(f"Found {num_gpus} GPU(s) available:")
    for gpu_index in range(num_gpus):
        print(f" GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
    if num_gpus <= 1:
        print(f"Single GPU training")
    else:
        print(f"Multi-GPU training enabled: Will use {num_gpus} GPUs with DataParallel")
# ===========================
# ===========================

In [None]:
import torch.nn as nn

# ----- Optimisation schedule -----
LEARNING_RATE = 1e-3   # initial learning rate
EPOCHS = 500           # maximum number of training epochs
PATIENCE = 50          # presumably early-stopping patience in epochs — confirm in the training loop

# ----- Regularisation -----
DROPOUT_RATE = 0.2     # dropout probability used in the classifier head
L1_LAMBDA = 0          # L1 penalty weight (0 = no penalty)
L2_LAMBDA = 0          # L2 penalty weight (0 = no penalty)

# Multi-class loss; expects raw logits and integer class targets
criterion = nn.CrossEntropyLoss()

# Echo the configuration so it is recorded in the cell output
for caption, value in (
    ("Epochs:", EPOCHS),
    ("Batch Size:", BATCH_SIZE),
    ("Learning Rate:", LEARNING_RATE),
    ("Dropout Rate:", DROPOUT_RATE),
    ("L1 Penalty:", L1_LAMBDA),
    ("L2 Penalty:", L2_LAMBDA),
):
    print(caption, value)

# Download Pretrained Models

In [None]:
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights, EfficientNet_B0_Weights, EfficientNet_B3_Weights, VGG16_Weights

# ===== SELECT THE BACKBONE BY NAME =====
# Options:
#   "resnet18"        - smaller, faster
#   "resnet50"        - deeper, more powerful
#   "efficientnet_b0" - efficient, good balance
#   "efficientnet_b3" - more powerful EfficientNet
#   "vgg16"           - classic architecture
MODEL_NAME = "resnet18"
# =======================================

# Factories are lambdas so only the selected model's weights are downloaded.
PRETRAINED_MODEL_FACTORIES = {
    "resnet18": lambda: models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1),
    "resnet50": lambda: models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2),
    "efficientnet_b0": lambda: models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1),
    "efficientnet_b3": lambda: models.efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1),
    "vgg16": lambda: models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
}
if MODEL_NAME not in PRETRAINED_MODEL_FACTORIES:
    raise ValueError(
        f"Unknown MODEL_NAME {MODEL_NAME!r}; choose one of {sorted(PRETRAINED_MODEL_FACTORIES)}"
    )

model_pretrained = PRETRAINED_MODEL_FACTORIES[MODEL_NAME]()

print(f"Loaded pretrained model: {MODEL_NAME}")
print(f"Model architecture:\n{model_pretrained}")

## Transfer Learning Setup

In [None]:
import torch.nn as nn

# If this cell is re-run after the model was wrapped in DataParallel, `.fc` / `.classifier`
# would not be reachable and the model would be wrapped twice; work on the inner module.
if isinstance(model_pretrained, nn.DataParallel):
    model_pretrained = model_pretrained.module

# ===== STEP 1: Freeze all layers in the feature extractor =====
for param in model_pretrained.parameters():
    param.requires_grad = False

print("\nAll feature extractor weights frozen")

# ===== STEP 2: Replace the classifier head =====
# Different models have different classifier layer names. The new head's layers are freshly
# created (requires_grad=True), so after this step they are the only trainable parameters.
if MODEL_NAME.startswith('resnet'):
    # ResNet has 'fc' as final layer
    num_features = model_pretrained.fc.in_features
    model_pretrained.fc = nn.Sequential(
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(256, num_classes)
    )
    print(f"Replaced ResNet classifier: {num_features} -> 256 -> {num_classes}")

elif MODEL_NAME.startswith('efficientnet'):
    # EfficientNet's 'classifier' Sequential holds its Linear layer at index 1
    num_features = model_pretrained.classifier[1].in_features
    model_pretrained.classifier = nn.Sequential(
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(256, num_classes)
    )
    print(f"Replaced EfficientNet classifier: {num_features} -> 256 -> {num_classes}")

elif MODEL_NAME.startswith('vgg'):
    # VGG's 'classifier' Sequential starts with the Linear on the flattened features
    num_features = model_pretrained.classifier[0].in_features
    model_pretrained.classifier = nn.Sequential(
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(256, num_classes)
    )
    print(f"Replaced VGG classifier: {num_features} -> 512 -> 256 -> {num_classes}")

else:
    # Without a new head every parameter would stay frozen and the optimizer
    # (which only receives trainable parameters) would have nothing to train.
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}: expected a resnet*, efficientnet* or vgg* backbone"
    )

# Move model to device FIRST
model_pretrained = model_pretrained.to(device)

# Then wrap with DataParallel if multiple GPUs are available
if num_gpus > 1:
    model_pretrained = nn.DataParallel(model_pretrained)
    print(f"Model wrapped with DataParallel for {num_gpus} GPUs")

print(f"\nModel ready for transfer learning on {device} ({num_gpus} GPU(s))")

In [None]:
from torchsummary import summary

# Display model architecture summary
print("\n" + "="*80)
print("MODEL SUMMARY")
print("="*80)
# torchsummary builds its dummy input on the given device and defaults to "cuda";
# pass the active device type so the summary also works on CPU-only sessions.
summary(model_pretrained, input_size=input_shape, device=device.type)

# Count trainable vs frozen parameters
total_params = sum(p.numel() for p in model_pretrained.parameters())
trainable_params = sum(p.numel() for p in model_pretrained.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print("\n" + "="*80)
print("PARAMETER STATISTICS")
print("="*80)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters (classifier only): {trainable_params:,}")
print(f"Frozen parameters (feature extractor): {frozen_params:,}")
print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")
print("="*80)

In [None]:
# AdamW only receives the parameters that still require gradients -- after the
# transfer-learning setup that is just the new classifier head (backbone is frozen).
trainable_params = [param for param in model_pretrained.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    weight_decay=L2_LAMBDA,  # AdamW's decoupled weight decay acts as the L2 penalty
)

# Loss scaler for mixed precision; a pass-through unless training on CUDA
scaler = torch.amp.GradScaler(enabled=device.type == 'cuda')

print(f"✓ Optimizer configured to train only classifier layers")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay (L2): {L2_LAMBDA}")

In [None]:
# GPU Memory and Utilization Monitoring
def _report_gpu_memory():
    """Print allocated, reserved and peak-allocated memory (GiB) for every visible CUDA device."""
    banner = "="*80
    gib = 1024**3
    print("\n" + banner)
    print("GPU STATUS BEFORE TRAINING")
    print(banner)
    for gpu_idx in range(torch.cuda.device_count()):
        print(f"\nGPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
        print(f"  Memory Allocated: {torch.cuda.memory_allocated(gpu_idx) / gib:.2f} GB")
        print(f"  Memory Reserved: {torch.cuda.memory_reserved(gpu_idx) / gib:.2f} GB")
        print(f"  Max Memory Allocated: {torch.cuda.max_memory_allocated(gpu_idx) / gib:.2f} GB")
    print(banner)


# Nothing to report on CPU-only machines
if torch.cuda.is_available():
    _report_gpu_memory()

# Training

In [None]:
# Running record of the best model across the experiments below, compared on
# validation F1; -inf guarantees the first trained model becomes the initial best.
best_model, best_performance = None, float('-inf')

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, l1_lambda=0, l2_lambda=0):
    """
    Perform one complete training epoch through the entire training dataset.

    Args:
        model (nn.Module): The neural network model to train
        train_loader (DataLoader): DataLoader yielding (inputs, targets) training batches
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss)
        optimizer (torch.optim.Optimizer): Optimization algorithm (e.g., AdamW, SGD)
        scaler (GradScaler): Gradient scaler for mixed precision training
        device (torch.device): Computing device ('cuda' for GPU, 'cpu' for CPU)
        l1_lambda (float): L1 regularization coefficient; 0 disables the term
        l2_lambda (float): L2 regularization coefficient; 0 disables the term

    Returns:
        tuple: (average_loss, f1) - mean per-sample training loss (including any
        regularization terms) and weighted F1 score for this epoch

    Note:
        Regularization is applied only to parameters with requires_grad=True: frozen
        backbone weights cannot be moved by the penalty, so including them would only
        inflate the reported loss and cost a full pass over the backbone every step.
        Relies on `f1_score` (sklearn.metrics) being in scope at call time.
    """
    model.train()  # Set model to training mode

    # Collect the penalised parameters once per epoch, and only if a penalty is enabled
    use_regularization = bool(l1_lambda) or bool(l2_lambda)
    reg_params = [p for p in model.parameters() if p.requires_grad] if use_regularization else []

    running_loss = 0.0
    all_predictions = []
    all_targets = []

    # Iterate through training batches
    for inputs, targets in train_loader:
        # Move data to device (GPU/CPU)
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear gradients from previous step
        optimizer.zero_grad(set_to_none=True)

        # Forward pass with mixed precision (if CUDA available)
        with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
            logits = model(inputs)
            loss = criterion(logits, targets)

        # Add L1 and L2 regularization over trainable parameters only
        if l1_lambda:
            loss = loss + l1_lambda * sum(p.abs().sum() for p in reg_params)
        if l2_lambda:
            loss = loss + l2_lambda * sum(p.pow(2).sum() for p in reg_params)

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Accumulate metrics (loss is a batch mean, so weight it by batch size)
        running_loss += loss.item() * inputs.size(0)
        predictions = logits.argmax(dim=1)
        all_predictions.append(predictions.cpu().numpy())
        all_targets.append(targets.cpu().numpy())

    # Calculate epoch metrics
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_f1 = f1_score(
        np.concatenate(all_targets),
        np.concatenate(all_predictions),
        average='weighted'
    )

    return epoch_loss, epoch_f1

In [None]:
def validate_one_epoch(model, val_loader, criterion, device):
    """
    Perform one complete validation epoch through the entire validation dataset.

    Args:
        model (nn.Module): The neural network model to evaluate (switched to eval mode here)
        val_loader (DataLoader): DataLoader yielding (inputs, targets) validation batches
        criterion (nn.Module): Loss function used to calculate validation loss
        device (torch.device): Computing device ('cuda' for GPU, 'cpu' for CPU)

    Returns:
        tuple: (average_loss, f1) - mean per-sample validation loss and weighted F1
        score for this epoch (not accuracy)

    Note:
        Gradient computation is disabled for efficiency. Relies on `f1_score`
        (sklearn.metrics) being in scope at call time.
    """
    model.eval()  # Set model to evaluation mode

    running_loss = 0.0
    all_predictions = []
    all_targets = []

    # Disable gradient computation for validation
    with torch.no_grad():
        for inputs, targets in val_loader:
            # Move data to device
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass with mixed precision (if CUDA available)
            with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
                logits = model(inputs)
                loss = criterion(logits, targets)

            # Accumulate metrics (loss is a batch mean, so weight it by batch size)
            running_loss += loss.item() * inputs.size(0)
            predictions = logits.argmax(dim=1)
            all_predictions.append(predictions.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    # Calculate epoch metrics
    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_f1 = f1_score(
        np.concatenate(all_targets),
        np.concatenate(all_predictions),
        average='weighted'
    )

    return epoch_loss, epoch_f1

In [None]:
def fit(model, train_loader, val_loader, epochs, criterion, optimizer, scaler, device,
        l1_lambda=0, l2_lambda=0, patience=0, evaluation_metric="val_f1", mode='max',
        restore_best_weights=True, writer=None, verbose=10, experiment_name=""):
    """
    Train the model on the training data and validate it after every epoch.

    Args:
        model (nn.Module): The neural network model to train
        train_loader (DataLoader): DataLoader yielding training (inputs, targets) batches
        val_loader (DataLoader): DataLoader yielding validation (inputs, targets) batches
        epochs (int): Maximum number of training epochs
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss)
        optimizer (torch.optim.Optimizer): Optimization algorithm (e.g., AdamW, SGD)
        scaler (GradScaler): Gradient scaler for mixed precision training
        device (torch.device): Computing device ('cuda' for GPU, 'cpu' for CPU)
        l1_lambda (float): L1 regularization coefficient (default: 0)
        l2_lambda (float): L2 regularization coefficient (default: 0)
        patience (int): Epochs without improvement before early stopping; 0 disables
            early stopping and best-checkpointing (default: 0)
        evaluation_metric (str): History key monitored for early stopping: one of
            'train_loss', 'val_loss', 'train_f1', 'val_f1' (default: "val_f1")
        mode (str): 'max' maximizes the metric; any other value minimizes it (default: 'max')
        restore_best_weights (bool): Reload the best checkpoint at the end; only applies
            when patience > 0 (default: True)
        writer (SummaryWriter, optional): TensorBoard writer; when given, per-epoch losses
            and F1 scores are logged to it and it is closed at the end (default: None)
        verbose (int, optional): Print progress every `verbose` epochs and on epoch 1;
            0 disables progress output (default: 10)
        experiment_name (str, optional): Prefix of the checkpoint file
            `models/<experiment_name>_model.pt` (default: "")

    Returns:
        tuple: (model, training_history) - Trained model and per-epoch metrics history

    Raises:
        ValueError: if early stopping is enabled and `evaluation_metric` is not a history key.
    """
    import os

    # Initialize metrics tracking
    training_history = {
        'train_loss': [], 'val_loss': [],
        'train_f1': [], 'val_f1': []
    }

    # Fail before spending an epoch of compute on an unknown metric name
    if patience > 0 and evaluation_metric not in training_history:
        raise ValueError(
            f"evaluation_metric must be one of {list(training_history)}, got {evaluation_metric!r}"
        )

    # Checkpoint location; make sure the directory exists before the first save
    checkpoint_path = os.path.join("models", experiment_name + '_model.pt')
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Early-stopping state (only consulted when patience > 0); best_epoch stays 0 until
    # a checkpoint has actually been written during THIS run
    patience_counter = 0
    best_metric = float('-inf') if mode == 'max' else float('inf')
    best_epoch = 0

    print(f"Training {epochs} epochs...")

    # Main training loop: iterate through epochs
    for epoch in range(1, epochs + 1):

        # Forward pass through training data, compute gradients, update weights
        train_loss, train_f1 = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, l1_lambda, l2_lambda
        )

        # Evaluate model on validation data without updating weights
        val_loss, val_f1 = validate_one_epoch(
            model, val_loader, criterion, device
        )

        # Store metrics for plotting and analysis
        training_history['train_loss'].append(train_loss)
        training_history['val_loss'].append(val_loss)
        training_history['train_f1'].append(train_f1)
        training_history['val_f1'].append(val_f1)

        # Log to TensorBoard when a writer was supplied
        if writer is not None:
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('F1/train', train_f1, epoch)
            writer.add_scalar('F1/val', val_f1, epoch)

        # Print progress every N epochs or on first epoch
        if verbose > 0 and (epoch % verbose == 0 or epoch == 1):
            print(f"Epoch {epoch:3d}/{epochs} | "
                f"Train: Loss={train_loss:.4f}, F1 Score={train_f1:.4f} | "
                f"Val: Loss={val_loss:.4f}, F1 Score={val_f1:.4f}")

        # Early stopping logic: monitor metric and save best model
        if patience > 0:
            current_metric = training_history[evaluation_metric][-1]
            is_improvement = (current_metric > best_metric) if mode == 'max' else (current_metric < best_metric)

            if is_improvement:
                best_metric = current_metric
                best_epoch = epoch
                torch.save(model.state_dict(), checkpoint_path)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch} epochs.")
                    break

    # Restore best model weights if early stopping was used. Only load a checkpoint
    # written during this run -- otherwise a stale file from an earlier run (or none at
    # all, e.g. when the metric was NaN every epoch) would be loaded.
    if restore_best_weights and patience > 0:
        if best_epoch > 0:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print(f"Best model restored from epoch {best_epoch} with {evaluation_metric} {best_metric:.4f}")
        else:
            print(f"No improvement in {evaluation_metric} was recorded; keeping final weights.")

    # Save final model if no early stopping
    if patience == 0:
        torch.save(model.state_dict(), checkpoint_path)

    # Close TensorBoard writer
    if writer is not None:
        writer.close()

    return model, training_history

## Fitting

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Set experiment name for this run
EXPERIMENT_NAME = f"pretrained_{MODEL_NAME}_augmented"

# Train with augmented (balanced) dataset
print("\n" + "="*80)
print(f"TRAINING WITH PRETRAINED {MODEL_NAME.upper()} - TRANSFER LEARNING")
print("="*80)
print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
print(f"Strategy: Frozen feature extractor + Trainable classifier")
print("="*80 + "\n")

# Train model and track training history using AUGMENTED dataset
model_pretrained, history = fit(
    model=model_pretrained,
    train_loader=train_loader,  # ← USE AUGMENTED LOADER
    val_loader=val_loader,      # ← USE AUGMENTED LOADER
    epochs=EPOCHS,
    criterion=criterion,
    optimizer=optimizer,
    scaler=scaler,
    device=device,
    verbose=1,
    experiment_name=EXPERIMENT_NAME,
    patience=PATIENCE
)

# With early stopping, fit() restores the weights of the best validation-F1 epoch, so the
# returned model's score is the maximum val F1, not the last epoch's. Without early
# stopping the model keeps its final-epoch weights.
current_val_f1 = max(history['val_f1']) if PATIENCE > 0 else history['val_f1'][-1]

# Update best model if current performance is superior
if current_val_f1 > best_performance:
    best_model = model_pretrained
    best_performance = current_val_f1
    print(f"\n✓ New best model saved with F1 Score: {best_performance:.4f}")

# Identify High-Loss Samples (Data Quality Check)

In [None]:
def calculate_per_sample_loss(model, dataset, criterion, device):
    """
    Calculate the loss of every individual sample in a dataset, one sample at a time.

    Args:
        model (nn.Module): trained classifier; switched to eval mode here
        dataset: indexable dataset whose items are (input_tensor, target) pairs, where
            target is an int-like class index (Python int, numpy integer or 0-d tensor,
            possibly on the GPU)
        criterion (nn.Module): loss function, applied to a batch of exactly one sample
        device (torch.device): device the forward pass runs on

    Returns:
        losses: numpy float array of per-sample losses
        predictions: numpy int array of predicted labels
        targets: numpy int array of true labels
    """
    model.eval()
    use_amp = device.type == 'cuda'

    losses = []
    predictions = []
    targets = []

    with torch.no_grad():
        for i in range(len(dataset)):
            inputs, target = dataset[i]
            # Normalise the label to a plain int: keeps the returned targets a clean
            # integer array (usable as dict keys) even for tensor-stored labels.
            target = int(target)
            inputs = inputs.unsqueeze(0).to(device)  # Add batch dimension
            target_tensor = torch.tensor([target], device=device)

            # Forward pass
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                logits = model(inputs)
                loss = criterion(logits, target_tensor)

            losses.append(loss.item())
            predictions.append(logits.argmax(dim=1).item())
            targets.append(target)

    return np.array(losses), np.array(predictions), np.array(targets)

print("Calculating per-sample losses on training set...")
train_losses, train_preds, train_targets = calculate_per_sample_loss(
    model_pretrained, train_dataset, criterion, device
)

# Summary statistics of the per-sample loss distribution
print(f"\nLoss statistics:")
print(f"Mean loss: {train_losses.mean():.4f}")
print(f"Median loss: {np.median(train_losses):.4f}")
print(f"Max loss: {train_losses.max():.4f}")
print(f"Min loss: {train_losses.min():.4f}")
print(f"Std loss: {train_losses.std():.4f}")

# Rank samples by loss, highest first, and keep the worst `top_k`
top_k = 50  # Number of worst samples to examine
worst_indices = np.argsort(train_losses)[::-1][:top_k]

print(f"\n{'='*80}")
print(f"TOP {top_k} HIGHEST LOSS SAMPLES (Potential Data Quality Issues)")
print(f"{'='*80}")
print(f"{'Index':<10} {'Loss':<12} {'True Label':<20} {'Predicted':<20} {'Correct':<10}")
print('-' * 80)

# Map class indices back to their names for a readable report
reverse_label_map = {v: k for k, v in label_map.items()}
problematic_samples = [
    {
        'dataset_index': sample_idx,
        'loss': train_losses[sample_idx],
        'true_label': reverse_label_map[train_targets[sample_idx]],
        'predicted_label': reverse_label_map[train_preds[sample_idx]],
        'correct': train_targets[sample_idx] == train_preds[sample_idx],
    }
    for sample_idx in worst_indices
]

# Only the 20 worst rows are printed; all `top_k` stay in problematic_samples
for row in problematic_samples[:20]:
    print(f"{row['dataset_index']:<10} {row['loss']:<12.4f} {row['true_label']:<20} {row['predicted_label']:<20} {str(row['correct']):<10}")

print('=' * 80)

In [None]:
# Visualize the worst samples (at most 25, on a 5x5 grid)
n_show = min(25, len(worst_indices))
fig, axes = plt.subplots(5, 5, figsize=(20, 20))
axes = axes.flatten()

print(f"\nVisualizing top {n_show} highest-loss samples...")

for i in range(n_show):
    idx = worst_indices[i]
    loss = train_losses[idx]
    true_label = reverse_label_map[train_targets[idx]]
    pred_label = reverse_label_map[train_preds[idx]]

    # Get the image tensor
    img_tensor, _ = train_dataset[idx]

    # Convert tensor to displayable image (C, H, W) -> (H, W, C)
    img = img_tensor.permute(1, 2, 0).cpu().numpy()
    # imshow rejects (H, W, 1) arrays, so draw single-channel images as 2-D grayscale
    if img.shape[-1] == 1:
        img = img[..., 0]
    # NOTE(review): if the dataset normalises images (e.g. ImageNet mean/std), values fall
    # outside [0, 1] and imshow clips them -- TODO confirm and de-normalise for display.

    # Display image
    axes[i].imshow(img, cmap='gray' if img.ndim == 2 else None)
    axes[i].set_title(
        f"Rank {i+1}: Loss={loss:.3f}\n"
        f"True: {true_label}\n"
        f"Pred: {pred_label}",
        fontsize=9,
        color='red' if true_label != pred_label else 'green'
    )
    axes[i].axis('off')

# Hide grid cells that received no sample (when fewer than 25 samples exist)
for ax in axes[n_show:]:
    ax.axis('off')

plt.tight_layout()
plt.suptitle(f'Top {n_show} Highest Loss Training Samples', fontsize=16, y=1.001)
plt.show()

# Plot loss distribution
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Histogram of losses
axes[0].hist(train_losses, bins=100, color='steelblue', alpha=0.7, edgecolor='black')
axes[0].axvline(train_losses.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {train_losses.mean():.3f}')
axes[0].axvline(np.median(train_losses), color='green', linestyle='--', linewidth=2, label=f'Median: {np.median(train_losses):.3f}')
axes[0].set_xlabel('Loss')
axes[0].set_ylabel('Number of Samples')
axes[0].set_title('Distribution of Per-Sample Losses')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Sorted losses
sorted_losses = np.sort(train_losses)
axes[1].plot(sorted_losses, color='steelblue', linewidth=2)
axes[1].axhline(train_losses.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {train_losses.mean():.3f}')
axes[1].set_xlabel('Sample Rank (sorted)')
axes[1].set_ylabel('Loss')
axes[1].set_title('Sorted Per-Sample Losses')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Remove High-Loss Samples (Data Cleaning)

In [None]:
# Define threshold for removing high-loss samples
# Option 1: Remove top N samples with highest loss
REMOVE_TOP_N = 100  # Adjust this value based on visual inspection

# Option 2: Remove samples above a certain loss percentile
LOSS_PERCENTILE_THRESHOLD = 95  # Remove top 5% highest losses

# Choose method: "top_n" or "percentile"
METHOD = "top_n"  # Remove top N samples
# METHOD = "percentile"  # Remove by percentile

# The keep-mask is built from indices into train_dataset / train_losses and applied to
# X_train / y_train, so they must describe the same samples in the same order.
assert len(train_losses) == len(train_dataset) == len(X_train), \
    "train_losses, train_dataset and X_train must be aligned sample-for-sample"

print(f"{'='*80}")
print(f"REMOVING HIGH-LOSS SAMPLES")
print(f"{'='*80}")

if METHOD == "top_n":
    # All samples sorted by loss, highest first
    all_worst_indices = np.argsort(train_losses)[::-1]
    # Clamp N into [0, dataset size]: a negative slice bound would remove the wrong samples
    n_to_remove = max(0, min(REMOVE_TOP_N, len(train_losses)))
    samples_to_remove = all_worst_indices[:n_to_remove]
    # Smallest loss among removed samples (inf when nothing is removed)
    threshold_loss = train_losses[samples_to_remove[-1]] if n_to_remove > 0 else float('inf')
    print(f"Method: Remove top {n_to_remove} samples")
    print(f"Loss threshold: {threshold_loss:.4f}")
elif METHOD == "percentile":
    # Remove samples above percentile threshold
    threshold_loss = np.percentile(train_losses, LOSS_PERCENTILE_THRESHOLD)
    samples_to_remove = np.where(train_losses > threshold_loss)[0]
    print(f"Method: Remove samples above {LOSS_PERCENTILE_THRESHOLD}th percentile")
    print(f"Loss threshold: {threshold_loss:.4f}")
else:
    raise ValueError(f"Unknown METHOD {METHOD!r}: expected 'top_n' or 'percentile'")

print(f"Samples to remove: {len(samples_to_remove)}")
print(f"Original training set size: {len(train_dataset)}")
print(f"New training set size: {len(train_dataset) - len(samples_to_remove)}")

# Create mask for samples to keep
keep_mask = np.ones(len(train_dataset), dtype=bool)
keep_mask[samples_to_remove] = False

# Filter the datasets
X_train_cleaned = X_train[keep_mask]
y_train_cleaned = y_train[keep_mask]

print(f"\nCleaned dataset shapes:")
print(f"X_train: {X_train_cleaned.shape}")
print(f"y_train: {y_train_cleaned.shape}")

# Check class distribution after cleaning
print(f"\nClass distribution after cleaning:")
unique, counts = np.unique(y_train_cleaned.cpu().numpy(), return_counts=True)
for label_idx, count in zip(unique, counts):
    label_name = reverse_label_map[label_idx]
    print(f"  {label_name}: {count} samples")

# Create new cleaned DataLoader
train_dataset_cleaned = TensorDataset(X_train_cleaned, y_train_cleaned)

train_loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'shuffle': True,
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY
}
if NUM_WORKERS > 0:
    train_loader_kwargs['persistent_workers'] = PERSISTENT_WORKERS

train_loader_cleaned = DataLoader(train_dataset_cleaned, **train_loader_kwargs)

print(f"\nNew DataLoader created:")
print(f"Train batches: {len(train_loader_cleaned)}")
print(f"{'='*80}")

## Retrain with Cleaned Data

In [None]:
# Retrain on the cleaned data starting from a FRESH pretrained backbone.
# Building a new model (instead of re-editing `model_pretrained` in place) matters because:
#   * `model_pretrained` may be wrapped in nn.DataParallel, which has no `.fc` / `.classifier`;
#   * `best_model` may be the very same object, and editing it in place would corrupt it;
#   * the first run already changed that model (its trained head and, for backbones with
#     BatchNorm, the running statistics that are updated in train mode).
import torch.nn as nn
import torchvision.models as models

# weights="DEFAULT" resolves to the same ImageNet checkpoints used in the model-selection
# cell for resnet18, resnet50, efficientnet_b0, efficientnet_b3 and vgg16
cnn_model_cleaned = getattr(models, MODEL_NAME)(weights="DEFAULT")

# ===== STEP 1: Freeze all layers in the feature extractor =====
for param in cnn_model_cleaned.parameters():
    param.requires_grad = False

print("\nAll feature extractor weights frozen")

# ===== STEP 2: Replace the classifier head =====
# Different models have different classifier layer names
if MODEL_NAME.startswith('resnet'):
    # ResNet has 'fc' as final layer
    num_features = cnn_model_cleaned.fc.in_features
    cnn_model_cleaned.fc = nn.Sequential(
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(256, num_classes)
    )
    print(f"Replaced ResNet classifier: {num_features} -> 256 -> {num_classes}")

elif MODEL_NAME.startswith('efficientnet'):
    # EfficientNet's 'classifier' Sequential holds its Linear layer at index 1
    num_features = cnn_model_cleaned.classifier[1].in_features
    cnn_model_cleaned.classifier = nn.Sequential(
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(num_features, 256),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(256, num_classes)
    )
    print(f"Replaced EfficientNet classifier: {num_features} -> 256 -> {num_classes}")

elif MODEL_NAME.startswith('vgg'):
    # VGG's 'classifier' Sequential starts with the Linear on the flattened features
    num_features = cnn_model_cleaned.classifier[0].in_features
    cnn_model_cleaned.classifier = nn.Sequential(
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Dropout(DROPOUT_RATE),
        nn.Linear(256, num_classes)
    )
    print(f"Replaced VGG classifier: {num_features} -> 512 -> 256 -> {num_classes}")

else:
    # Without a new head nothing would be trainable
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}: expected a resnet*, efficientnet* or vgg* backbone"
    )

# Move model to device FIRST
cnn_model_cleaned = cnn_model_cleaned.to(device)

# Then wrap with DataParallel if multiple GPUs are available
if num_gpus > 1:
    cnn_model_cleaned = nn.DataParallel(cnn_model_cleaned)
    print(f"Model wrapped with DataParallel for {num_gpus} GPUs")

print(f"\nModel ready for transfer learning on {device} ({num_gpus} GPU(s))")

# Fresh optimizer over the trainable (head) parameters only, as in the first run, and a fresh scaler
optimizer_cleaned = torch.optim.AdamW(
    [p for p in cnn_model_cleaned.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
    weight_decay=L2_LAMBDA,
)
scaler_cleaned = torch.amp.GradScaler(enabled=(device.type == 'cuda'))

# Set experiment name for cleaned model
EXPERIMENT_NAME_CLEANED = f"{EXPERIMENT_NAME}_cleaned"

print("\n" + "="*80)
print("TRAINING WITH CLEANED DATASET (High-loss samples removed)")
print("="*80)
print(f"Train loader: {len(train_loader_cleaned)} batches")
print(f"Val loader: {len(val_loader)} batches (unchanged)")
print("="*80 + "\n")

# Train model with cleaned data
cnn_model_cleaned, history_cleaned = fit(
    model=cnn_model_cleaned,
    train_loader=train_loader_cleaned,  # ← CLEANED LOADER
    val_loader=val_loader,              # Validation set unchanged
    epochs=EPOCHS,
    criterion=criterion,
    optimizer=optimizer_cleaned,
    scaler=scaler_cleaned,
    device=device,
    verbose=1,
    experiment_name=EXPERIMENT_NAME_CLEANED,
    patience=PATIENCE
)

# With early stopping fit() restores the best-val-F1 weights, so compare that score
cleaned_val_f1 = max(history_cleaned['val_f1']) if PATIENCE > 0 else history_cleaned['val_f1'][-1]

# Update best model if current performance is superior
if cleaned_val_f1 > best_performance:
    best_model = cnn_model_cleaned
    best_performance = cleaned_val_f1
    print(f"\n New best model saved with F1 Score: {best_performance:.4f}")
    print(f" Improvement from data cleaning!")

In [None]:
# Compare original vs cleaned training
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Loss comparison
axes[0, 0].plot(history['train_loss'], label='Original - Train', alpha=0.6, linestyle='--', color='#1f77b4')
axes[0, 0].plot(history['val_loss'], label='Original - Val', alpha=0.8, color='#1f77b4')
axes[0, 0].plot(history_cleaned['train_loss'], label='Cleaned - Train', alpha=0.6, linestyle='--', color='#ff7f0e')
axes[0, 0].plot(history_cleaned['val_loss'], label='Cleaned - Val', alpha=0.8, color='#ff7f0e')
axes[0, 0].set_title('Loss Comparison: Original vs Cleaned Data')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# F1 Score comparison
axes[0, 1].plot(history['train_f1'], label='Original - Train', alpha=0.6, linestyle='--', color='#1f77b4')
axes[0, 1].plot(history['val_f1'], label='Original - Val', alpha=0.8, color='#1f77b4')
axes[0, 1].plot(history_cleaned['train_f1'], label='Cleaned - Train', alpha=0.6, linestyle='--', color='#ff7f0e')
axes[0, 1].plot(history_cleaned['val_f1'], label='Cleaned - Val', alpha=0.8, color='#ff7f0e')
axes[0, 1].set_title('F1 Score Comparison: Original vs Cleaned Data')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('F1 Score')
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)

# Training loss only (zoomed)
axes[1, 0].plot(history['train_loss'], label='Original', alpha=0.8, color='#1f77b4')
axes[1, 0].plot(history_cleaned['train_loss'], label='Cleaned', alpha=0.8, color='#ff7f0e')
axes[1, 0].set_title('Training Loss: Original vs Cleaned Data')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Training Loss')
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3)

# Validation F1 only (zoomed)
axes[1, 1].plot(history['val_f1'], label='Original', alpha=0.8, color='#1f77b4', marker='o')
axes[1, 1].plot(history_cleaned['val_f1'], label='Cleaned', alpha=0.8, color='#ff7f0e', marker='s')
axes[1, 1].set_title('Validation F1 Score: Original vs Cleaned Data')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Validation F1 Score')
axes[1, 1].legend()
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Print summary comparison
print("\n" + "="*80)
print("TRAINING COMPARISON SUMMARY")
print("="*80)
print(f"\nOriginal Dataset:")
print(f"  Best Val F1: {max(history['val_f1']):.4f}")
print(f"  Final Val F1: {history['val_f1'][-1]:.4f}")
print(f"  Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"  Final Val Loss: {history['val_loss'][-1]:.4f}")

print(f"\nCleaned Dataset (removed {len(samples_to_remove)} high-loss samples):")
print(f"  Best Val F1: {max(history_cleaned['val_f1']):.4f}")
print(f"  Final Val F1: {history_cleaned['val_f1'][-1]:.4f}")
print(f"  Final Train Loss: {history_cleaned['train_loss'][-1]:.4f}")
print(f"  Final Val Loss: {history_cleaned['val_loss'][-1]:.4f}")

improvement = max(history_cleaned['val_f1']) - max(history['val_f1'])
print(f"\nImprovement: {improvement:+.4f} ({improvement*100:+.2f}%)")
print("="*80)

## Plotting

In [None]:
import seaborn as sns  # used for the confusion-matrix heatmap below

# Get validation predictions from the best model
val_preds = []
val_targets = []
best_model.eval()

with torch.no_grad():
    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        logits = best_model(inputs)
        preds = logits.argmax(dim=1).cpu().numpy()

        val_preds.append(preds)
        # .cpu() keeps this working when the dataset tensors already live on the GPU
        val_targets.append(targets.cpu().numpy())

val_preds = np.concatenate(val_preds)
val_targets = np.concatenate(val_targets)

# Calculate overall validation set metrics
val_acc = accuracy_score(val_targets, val_preds)
val_prec = precision_score(val_targets, val_preds, average='weighted')
val_rec = recall_score(val_targets, val_preds, average='weighted')
val_f1 = f1_score(val_targets, val_preds, average='weighted')

print(f"Accuracy over the validation set: {val_acc:.4f}")
print(f"Precision over the validation set: {val_prec:.4f}")
print(f"Recall over the validation set: {val_rec:.4f}")
print(f"F1 score over the validation set: {val_f1:.4f}")

# Generate confusion matrix
cm = confusion_matrix(val_targets, val_preds)
labels = np.array([f"{num}" for num in cm.flatten()]).reshape(cm.shape)

# Visualize confusion matrix
plt.figure(figsize=(8, 7))
sns.heatmap(cm, annot=labels, fmt='', cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix — Validation Set')
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Learning curves for the run recorded in `history`: loss (left) and F1 (right).
# Training curves are faded and dashed, validation curves solid; each pair
# shares one colour so train/val for the same metric read as a unit.
# NOTE(review): the x axis is the list index; this assumes `history` stores one
# entry per epoch -- confirm against the training loop.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(18, 5))

for ax, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'f1', 'F1 Score')):
    ax.plot(history[f'train_{metric}'], label=f'Training {metric}',
            alpha=0.3, color='#ff7f0e', linestyle='--')
    ax.plot(history[f'val_{metric}'], label=f'Validation {metric}',
            alpha=0.9, color='#ff7f0e')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()
    ax.grid(alpha=0.3)

# Both legends sit inside their axes, so tight_layout alone is sufficient;
# no extra right-hand margin needs to be reserved.
plt.tight_layout()
plt.show()

# Inference

In [None]:
# Predict one class index per test sample with the best checkpoint.
# The test loader yields tuples; the input tensor is the first element.
best_model.eval()

with torch.no_grad():
    # The comprehension runs inside the no_grad context, so no autograd graph
    # is built; per-batch argmax results are stitched into one flat array.
    test_preds = np.concatenate([
        best_model(batch[0].to(device)).argmax(dim=1).cpu().numpy()
        for batch in test_loader
    ])

In [None]:
# Invert label_map (label name -> class index) so predicted indices can be
# turned back into label names.
reverse_label_map = dict(zip(label_map.values(), label_map.keys()))

# Submission IDs reference image files, so swap any 'mask' token for 'img'.
# NOTE(review): assumes test_filenames is in the same order test_loader yields
# samples (i.e. the loader is not shuffled) -- confirm.
test_filenames = [name.replace('mask', 'img') for name in test_filenames]

predicted_labels = [reverse_label_map[idx] for idx in test_preds]
submission_df = pd.DataFrame({'sample_index': test_filenames, 'label': predicted_labels})

# Encode the run's hyperparameters and validation F1 in the output filename so
# submissions from different experiments can be told apart.
filename_parts = [
    f"submission_{EXPERIMENT_NAME}",
    f"data_{DATA_TYPE}",
    f"bs_{BATCH_SIZE}",
    f"lr_{LEARNING_RATE}",
    f"drop_{DROPOUT_RATE}",
    f"l1_{L1_LAMBDA}",
    f"l2_{L2_LAMBDA}",
    f"epochs_{EPOCHS}",
    f"patience_{PATIENCE}",
    f"imgsize_{IMG_SIZE[0]}x{IMG_SIZE[1]}",
    f"f1_{val_f1:.4f}"
]
submission_filename = "_".join(filename_parts) + ".csv"

submission_df.to_csv(submission_filename, index=False)

print(f"Submission file created: {submission_filename}")
print(f"Total predictions: {len(submission_df)}")
print("\nFirst few predictions:")
print(submission_df.head(10))