In [4]:
import numpy as np

import random
from deap import base, creator, tools, algorithms

import pandas as pd
from sklearn.model_selection import train_test_split

from jointwise.preprocessing import DataProcessor, ImprovedSubjectLevelSplitter, MRIAugmentationWithBboxes, BoundingBoxAwareOversampler, MRIDatasetWithBboxes
from jointwise.model import BaseModelBuilder, MultiTaskModel, MultiTaskTrainer

import matplotlib.patches as patches

from pathlib import Path
import cv2

import torch
from torch.utils.data import DataLoader

In [5]:
# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

random.seed(42)

# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Environment setup complete.")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

print(f"Using device: {device}")

Environment setup complete.
PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3060
CUDA memory: 12.0 GB
Using device: cuda


## Data Preprocessing

In [6]:
data_processor = DataProcessor('annotations/knee.csv', 'png-output')
annotations = data_processor.load_annotations()
subject_labels = data_processor.create_target_labels()
available_images = data_processor.get_available_images()

print(f"\nReady to process {len(available_images)} images from {len(subject_labels)} subject entries")

# Let's also check if there are any cases that were previously "Both"
print("\nChecking for cases with both ACL and Meniscus tears:")
original_files = set()
acl_files = set()
meniscus_files = set()

for file_id, label in subject_labels.items():
    if file_id.endswith('_ACL'):
        acl_files.add(file_id[:-4])  # Remove _ACL suffix
    elif file_id.endswith('_Meniscus'):
        meniscus_files.add(file_id[:-9])  # Remove _Meniscus suffix
    else:
        original_files.add(file_id)

both_cases = acl_files.intersection(meniscus_files)
print(f"Found {len(both_cases)} cases that have both ACL and Meniscus tears (now treated as separate entries)")
if len(both_cases) > 0:
    print(f"Examples: {list(both_cases)[:5]}")  # Show first 5 examples

Loaded 16167 annotations
Unique files: 974
Label distribution:
label
Meniscus Tear                                5658
Cartilage - Partial Thickness loss/defect    2985
Joint Effusion                               1311
Bone-Fracture/Contusion/dislocation          1060
Bone- Subchondral edema                       986
Periarticular cysts                           864
Ligament - ACL Low Grade sprain               765
Ligament - ACL High Grade Sprain              677
Cartilage - Full Thickness loss/defect        615
Ligament - MCL Low-Mod Grade Sprain           285
Displaced Meniscal Tissue                     232
Bone - Lesion                                 183
Ligament - PCL Low-Mod grade sprain           142
LCL Complex - Low-Mod Grade Sprain            130
Soft Tissue Lesion                             90
Muscle Strain                                  65
Joint Bodies                                   38
Patellar Retinaculum - High grade sprain       24
Ligament - PCL High Grade      

In [7]:
# Test the improved subject-level splitter with balanced oversampling
print("Testing data splitting with the new approach...")

# No filtering needed since DataProcessor no longer creates 'Neither' labels
final_images = available_images  # Direct use, no filtering step needed
print(f"Images ready for training: {len(final_images)}")

# Show label distribution
filtered_labels = [img['label'] for img in final_images]
label_dist = pd.Series(filtered_labels).value_counts()
print("Label distribution:")
print(label_dist)

# For stratification, we need to work at subject level
# Group images by subject and get one label per subject (handling dual pathology cases)
subject_groups = {}
for img in final_images:
    file_id = img['file_id']
    label = img['label']

    if file_id not in subject_groups:
        subject_groups[file_id] = []
    subject_groups[file_id].append(label)

# For stratification, use primary pathology (if both, use first one found)
subject_stratify_labels = []
original_subjects = []

for file_id, labels in subject_groups.items():
    original_subjects.append(file_id)
    # Use the first label for stratification (ACL vs Meniscus)
    subject_stratify_labels.append(labels[0])

print(f"Total original subjects: {len(original_subjects)}")

# Show subject stratification distribution
stratify_dist = pd.Series(subject_stratify_labels).value_counts()
print("Subject stratification distribution:")
print(stratify_dist)

train_data, temp_data = train_test_split(final_images, test_size=0.3, random_state=42, stratify=[img['label'] for img in final_images])
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42, stratify=[img['label'] for img in temp_data])

# And update the print statements to:
print(f"\nData split completed (image-level):")
print(f"Train: {len(train_data)} images")
print(f"Val: {len(val_data)} images")
print(f"Test: {len(test_data)} images")

train_labels = [img['label'] for img in train_data]
val_labels = [img['label'] for img in val_data]
test_labels = [img['label'] for img in test_data]

print(f"\nTrain label distribution:")
print(pd.Series(train_labels).value_counts())
print(f"\nVal label distribution:")
print(pd.Series(val_labels).value_counts())
print(f"\nTest label distribution:")
print(pd.Series(test_labels).value_counts())

# Apply augmentation-based oversampling to balance the training set
print(f"\nClass distribution before balancing:")
train_label_counts = pd.Series(train_labels).value_counts()
for label, count in train_label_counts.items():
    print(f"{label}: {count}")

# For now, let's just show the results without oversampling to verify the clean pipeline
print(f"\n✅ SUCCESS: Clean data pipeline working without 'Neither' filtering!")
print(f"\nFinal dataset sizes (before balancing):")
print(f"Train: {len(train_data)}")
print(f"Val: {len(val_data)}")
print(f"Test: {len(test_data)}")

Testing data splitting with the new approach...
Images ready for training: 6296
Label distribution:
Meniscus_tear    4858
ACL_tear         1438
Name: count, dtype: int64
Total original subjects: 753
Subject stratification distribution:
Meniscus_tear    583
ACL_tear         170
Name: count, dtype: int64

Data split completed (image-level):
Train: 4407 images
Val: 944 images
Test: 945 images

Train label distribution:
Meniscus_tear    3400
ACL_tear         1007
Name: count, dtype: int64

Val label distribution:
Meniscus_tear    729
ACL_tear         215
Name: count, dtype: int64

Test label distribution:
Meniscus_tear    729
ACL_tear         216
Name: count, dtype: int64

Class distribution before balancing:
Meniscus_tear: 3400
ACL_tear: 1007

✅ SUCCESS: Clean data pipeline working without 'Neither' filtering!

Final dataset sizes (before balancing):
Train: 4407
Val: 944
Test: 945


In [8]:
# Demonstrate how augmentations affect bounding boxes
def find_image_with_bboxes():
    """Find an image that has bounding box annotations"""

    print("Searching for images with bounding boxes...")

    # Search through final images to find one with bounding boxes
    for i, sample_item in enumerate(final_images[:20]):
        original_file_id = sample_item['file_id']
        if original_file_id.endswith('_ACL') or original_file_id.endswith('_Meniscus'):
            # Remove suffix for annotation lookup
            base_file_id = original_file_id.replace('_ACL', '').replace('_Meniscus', '')
        else:
            base_file_id = original_file_id

        # Find bounding boxes for this image
        image_path = sample_item['path']
        image_name = Path(image_path).stem  # e.g., "file1000002_032"

        # Extract file and slice info for annotation lookup
        parts = image_name.split('_')
        if len(parts) >= 2:
            file_part = parts[0]  # e.g., "file1000002"
            slice_part = int(parts[1])  # e.g., 32

            # Look for bounding boxes in annotations
            relevant_annotations = annotations[
                (annotations['file'] == file_part) &
                (annotations['slice'] == slice_part)
            ]

            if len(relevant_annotations) > 0:
                # Convert annotations to bounding box format
                bboxes = []
                bbox_labels = []

                for _, row in relevant_annotations.iterrows():
                    x, y, w, h = row['x'], row['y'], row['width'], row['height']
                    # Convert to [x_min, y_min, x_max, y_max] format
                    bbox = [x, y, x + w, y + h]
                    bboxes.append(bbox)
                    bbox_labels.append(row['label'])

                print(f"Found {len(bboxes)} bounding boxes for image {image_name}")
                return sample_item, bboxes, bbox_labels

    print("No images with bounding boxes found in the first 20 samples")
    return None, [], []

In [9]:
# Demonstration: Proper workflow with improved classes
def demonstrate_improved_workflow():
    """Show the correct workflow using both classes for their intended purposes"""

    print("="*70)
    print("IMPROVED WORKFLOW DEMONSTRATION")
    print("="*70)

    # Step 1: Use ImprovedSubjectLevelSplitter for data splitting ONLY
    print("\n🔄 STEP 1: Subject-Level Data Splitting")
    print("-" * 40)

    improved_splitter = ImprovedSubjectLevelSplitter(
        final_images, test_size=0.2, val_size=0.2, random_state=42
    )
    train_split, val_split, test_split = improved_splitter.split_subjects()

    # Step 2: Use BoundingBoxAwareOversampler for class balancing
    print("\n🔄 STEP 2: Augmentation-Based Class Balancing")
    print("-" * 40)

    # Create bbox-aware augmentation for oversampling
    bbox_augmentation = MRIAugmentationWithBboxes(image_size=(320, 320))
    oversample_transform = bbox_augmentation.get_train_augmentation_with_bboxes()

    # Create the enhanced oversampler
    bbox_oversampler = BoundingBoxAwareOversampler(data_processor, oversample_transform)

    # Balance only the training set (val/test should remain unbalanced for proper evaluation)
    print("Balancing training set only...")
    balanced_train = bbox_oversampler.oversample_with_augmentation(train_split)

    print(f"\nFinal dataset sizes:")
    print(f"Balanced Train: {len(balanced_train)}")
    print(f"Val (unbalanced): {len(val_split)}")
    print(f"Test (unbalanced): {len(test_split)}")

    return balanced_train, val_split, test_split

# Run the demonstration
print("Creating improved workflow...")
improved_train, improved_val, improved_test = demonstrate_improved_workflow()

Creating improved workflow...
IMPROVED WORKFLOW DEMONSTRATION

🔄 STEP 1: Subject-Level Data Splitting
----------------------------------------
Total original subjects: 753
Subject stratification distribution:
Meniscus_tear    583
ACL_tear         170
Name: count, dtype: int64

Data split completed:
Train: 3788 images from 451 original subjects
Val: 1218 images from 151 original subjects
Test: 1290 images from 151 original subjects

Train label distribution:
Meniscus_tear    2899
ACL_tear          889
Name: count, dtype: int64

Val label distribution:
Meniscus_tear    939
ACL_tear         279
Name: count, dtype: int64

Test label distribution:
Meniscus_tear    1020
ACL_tear          270
Name: count, dtype: int64

🔄 STEP 2: Augmentation-Based Class Balancing
----------------------------------------
Balancing training set only...

Class Distribution before augmentation-based oversampling:
	Meniscus_tear: 2899
	ACL_tear: 889
Generating 2010 augmented samples for class 'ACL_tear'...


TypeError: string indices must be integers, not 'str'

In [None]:
# Initialize model builder
model_builder = BaseModelBuilder(num_classes = 2)

## Model Specifications

In [None]:
from collections import Counter

# Get all labels from the training dataset
all_labels = [item['label'] for item in train_data]
class_counts = Counter(all_labels)
total_samples = len(all_labels)
num_classes = len(class_counts)

# Calculate weights (inverse frequency)
multitask_class_weights = []
unique_labels = sorted(class_counts.keys())  # Ensure consistent order
for label in unique_labels:
    count = class_counts[label]
    weight = total_samples / (num_classes * count)
    multitask_class_weights.append(weight)

multitask_class_weights = torch.tensor(multitask_class_weights, dtype=torch.float32).to(device)
print(f"Multi-task class weights: {multitask_class_weights}")

In [None]:
def multitask_collate_fn(batch):
    images = torch.stack([item['image'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    # For bounding boxes, pad or select first bbox (if multiple per image)
    bboxes = []
    for item in batch:
        if len(item['bboxes']) > 0:
            # Use first bbox, or pad if needed
            bboxes.append(torch.tensor(item['bboxes'][0], dtype=torch.float32))
        else:
            bboxes.append(torch.zeros(4, dtype=torch.float32))  # No bbox, fill with zeros
    bboxes = torch.stack(bboxes)
    return {'image': images, 'label': labels, 'bboxes': bboxes}

In [None]:
multitask_batch_size = 8  # or your preferred batch size

# Create multi-task datasets
multitask_train_dataset = MRIDatasetWithBboxes(train_data, data_processor, transform=MRIAugmentationWithBboxes().get_train_augmentation_with_bboxes())
multitask_val_dataset = MRIDatasetWithBboxes(val_data, data_processor, transform=MRIAugmentationWithBboxes().get_val_augmentation_with_bboxes())

# Create multi-task DataLoaders
multitask_train_loader = DataLoader(
    multitask_train_dataset,
    batch_size=multitask_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    collate_fn=multitask_collate_fn
)
multitask_val_loader = DataLoader(
    multitask_val_dataset,
    batch_size=multitask_batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    collate_fn=multitask_collate_fn
)

## Hyperparameter Tuning using the Genetic Algorithm

In [None]:
def multitask_evaluate_individual(individual):
    # Unpack hyperparameters
    learning_rate, dropout_rate, classification_weight, bbox_weight, weight_decay = individual

    backbones = ['xception']
    val_accs = []

    for backbone in backbones:
        model = MultiTaskModel(
            backbone_name=backbone,
            num_classes=2,
            num_bbox_coords=4,
            dropout_rate=dropout_rate,
            freeze_backbone=True
        ).to(device)

        trainer = MultiTaskTrainer(
            model,
            device,
            class_weights=multitask_class_weights,
            classification_weight=classification_weight,
            bbox_weight=bbox_weight
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        try:
            history = trainer.train(
                multitask_train_loader,
                multitask_val_loader,
                epochs=3,  # Fast evaluation
                learning_rate=learning_rate,
                patience=2
            )
            val_acc = max(history['val_acc'])
        except Exception as e:
            print(f"Error during GA evaluation for {backbone}: {e}")
            val_acc = 0.0

        val_accs.append(val_acc)
        del model, trainer, optimizer
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

    # Aggregate fitness (mean validation accuracy across backbones)
    avg_val_acc = np.mean(val_accs)
    return (avg_val_acc,)

def custom_mutate(individual, mu, sigma, indpb):
    # Apply Gaussian mutation as before
    individual, = tools.mutGaussian(individual, mu, sigma, indpb)

    # Clamp the learning rate to be within its valid range (1e-5, 1e-3)
    individual[0] = max(1e-5, min(individual[0], 1e-3))
    individual[1] = max(0.3, min(individual[1], 0.7)) # dropout_rate
    individual[2] = max(0.5, min(individual[2], 2.0)) # classification_weight
    individual[3] = max(0.5, min(individual[3], 2.0)) # bbox_weight
    individual[4] = max(1e-6, min(individual[4], 1e-2)) # weight_decay

    return individual,

# DEAP setup
if hasattr(creator, 'FitnessMax'):
    del creator.FitnessMax
if hasattr(creator, 'Individual'):
    del creator.Individual

creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", list, fitness=creator.FitnessMax)

toolbox = base.Toolbox()
toolbox.register("learning_rate", random.uniform, 1e-5, 1e-3)
toolbox.register("dropout_rate", random.uniform, 0.3, 0.7)
toolbox.register("classification_weight", random.uniform, 0.5, 2.0)
toolbox.register("bbox_weight", random.uniform, 0.5, 2.0)
toolbox.register("weight_decay", random.uniform, 1e-6, 1e-2)
toolbox.register("individual", tools.initCycle, creator.Individual,
                 (toolbox.learning_rate, toolbox.dropout_rate,
                  toolbox.classification_weight, toolbox.bbox_weight,
                  toolbox.weight_decay), n=1)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", multitask_evaluate_individual)
toolbox.register("mate", tools.cxTwoPoint)

# Use the custom mutation function
toolbox.register("mutate", custom_mutate, mu=0, sigma=0.1, indpb=0.2)

toolbox.register("select", tools.selTournament, tournsize=3)

# Run GA optimization
print("Starting genetic algorithm for MultiTaskModel hyperparameters...")
population = toolbox.population(n=4)
stats = tools.Statistics(lambda ind: ind.fitness.values)
stats.register("avg", np.mean)
stats.register("max", np.max)

population, logbook = algorithms.eaSimple(
    population, toolbox,
    cxpb=0.7, mutpb=0.3,
    ngen=2, stats=stats, verbose=True
)

best_ind = tools.selBest(population, 1)[0]
best_params = {
    'learning_rate': best_ind[0],
    'dropout_rate': best_ind[1],
    'classification_weight': best_ind[2],
    'bbox_weight': best_ind[3],
    'weight_decay': best_ind[4],
    'fitness': best_ind.fitness.values[0]
}
print("Best hyperparameters found by GA:")
for k, v in best_params.items():
    print(f"{k}: {v}")

## Model Backbone Training

In [None]:
backbones = ['resnext50', 'densenet201', 'efficientnet_b0', 'xception']
results = {}

for backbone in backbones:
    print(f"\nTraining MultiTaskModel with backbone: {backbone}")
    multitask_model = MultiTaskModel(
        backbone_name=backbone,
        num_classes=2,
        num_bbox_coords=4,
        dropout_rate=best_params['dropout_rate'],      # Use optimized dropout
        freeze_backbone=True
    ).to(device)

    multitask_trainer = MultiTaskTrainer(
        multitask_model,
        device,
        class_weights=multitask_class_weights,
        classification_weight=best_params['classification_weight'],  # Use optimized weight
        bbox_weight=best_params['bbox_weight']                       # Use optimized weight
    )

    history = multitask_trainer.train(
        multitask_train_loader,
        multitask_val_loader,
        epochs=30,
        learning_rate=best_params['learning_rate'],    # Use optimized learning rate
        patience=10,
        save_path=f'best_multitask_model_{backbone}.pth'
    )
    results[backbone] = history