In [None]:
# Mount Google Drive; every dataset/CSV path in this notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# **Baseline code**

In [None]:
# Data preparation for the baseline: read the pre-split ImageFolder tree.
# NOTE: the split directory is created by `create_split` in a later cell; on a
# fresh Drive that cell has to run first, otherwise the ImageFolder calls fail.
import os

data_dir = "/content/drive/MyDrive/NIH_ChestXray_subset_split"

# Plain resize + ImageNet normalisation; the baseline uses no augmentation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"Classes: {train_dataset.classes}")

In [None]:
# ======================
# Quick Dataset Analysis and Baseline Reference
# ======================
print("=" * 60)
print("BASELINE DATASET ANALYSIS")
print("=" * 60)

# Quick dataset size check
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"Classes: {train_dataset.classes}")

# Class distribution of the training set. ImageFolder already stores every
# sample's integer label in `.targets`; counting those avoids iterating the
# dataset, which would open, resize and normalise every image just to read labels.
from collections import Counter
train_class_counts = Counter(train_dataset.targets)

print("\nTRAIN SET CLASS DISTRIBUTION:")
for i, class_name in enumerate(train_dataset.classes):
    count = train_class_counts[i]
    percentage = (count / len(train_dataset)) * 100
    print(f"  {class_name:<15}: {count:>5} samples ({percentage:>5.1f}%)")

# Imbalance ratio over *all* classes (a Counter omits classes it never saw).
# A class with no training images makes the ratio undefined, so report that
# instead of dividing by zero.
counts = [train_class_counts[i] for i in range(len(train_dataset.classes))]
if counts:
    max_count = max(counts)
    min_count = min(counts)
    if min_count > 0:
        imbalance_ratio = max_count / min_count
        print(f"\nImbalance ratio (max/min): {imbalance_ratio:.1f}:1")
    else:
        print("\nImbalance ratio (max/min): undefined (at least one class has no training images)")

print("\nClass indices mapping:")
for idx, class_name in enumerate(train_dataset.classes):
    print(f"  Index {idx}: {class_name}")

print("\n" + "=" * 60)
print("NOTE: Improved model MUST use these EXACT same numbers for fair comparison!")
print("=" * 60)

In [None]:
# Create train/val/test split if needed
import os
import shutil
import random
from sklearn.model_selection import train_test_split

def create_split(source_dir, dest_dir, train_frac=0.7, val_frac=0.15, seed=42):
    """Copy a class-per-folder image tree into stratified train/val/test sub-trees.

    Each class is split independently. After a seeded shuffle, the first
    ``int(n * train_frac)`` files go to train and the next ``int(n * val_frac)``
    go to val. The remainder goes to test. Only ``.png``/``.jpg``/``.jpeg``
    files (case-insensitive) are copied.

    Args:
        source_dir: directory containing one sub-directory per class.
        dest_dir: output root; ``dest_dir/{train,val,test}/<class>/`` is created.
        train_frac: fraction of each class placed in train.
        val_frac: fraction of each class placed in val (test receives the rest).
        seed: seed of the private RNG used for shuffling.

    Returns:
        None. Does nothing at all if ``dest_dir`` already exists. This includes
        a partial output left by an interrupted run, so delete it before re-running.

    NOTE(review): the split is per image, not per patient. If the source set
    contains several images per patient (typical for NIH chest X-ray data, so
    confirm), the same patient can end up in both train and test.
    """
    if os.path.exists(dest_dir):
        return

    # A private RNG leaves the notebook's global `random` state untouched.
    rng = random.Random(seed)
    os.makedirs(dest_dir, exist_ok=True)

    for split in ('train', 'val', 'test'):
        os.makedirs(os.path.join(dest_dir, split), exist_ok=True)

    # os.listdir order is filesystem-dependent; sort before shuffling so the
    # seeded split is identical on every machine.
    class_dirs = sorted(
        d for d in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, d))
    )

    for class_name in class_dirs:
        class_path = os.path.join(source_dir, class_name)
        image_files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

        rng.shuffle(image_files)
        n_total = len(image_files)
        n_train = int(n_total * train_frac)
        n_val = int(n_total * val_frac)

        splits = {
            'train': image_files[:n_train],
            'val': image_files[n_train:n_train + n_val],
            'test': image_files[n_train + n_val:],
        }

        for split, files in splits.items():
            split_class_dir = os.path.join(dest_dir, split, class_name)
            os.makedirs(split_class_dir, exist_ok=True)
            for file_name in files:
                shutil.copy2(os.path.join(class_path, file_name), os.path.join(split_class_dir, file_name))

# Build the train/val/test tree once; create_split returns immediately when the
# destination directory already exists.
source_data_dir = "/content/drive/MyDrive/NIH_ChestXray_subset"
split_data_dir = "/content/drive/MyDrive/NIH_ChestXray_subset_split"

create_split(source_dir=source_data_dir, dest_dir=split_data_dir)
print("Data split completed")

In [None]:
# ======================
# 2. Model definition
# ======================
# Four target classes: No Finding, Pneumonia, Effusion, Cardiomegaly (per the project deliverable).
num_classes = 4

# ResNet-18 from random initialisation (no ImageNet weights), with its
# 1000-way head swapped for a num_classes-way linear layer.
# NOTE(review): the improved model later starts from pretrained weights, so part
# of any gap between the two reflects initialisation, not only the added techniques.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

In [None]:
# ======================
# 3. Loss and optimizer
# ======================
# Deliberately unweighted cross-entropy: the baseline should expose how class
# imbalance affects a naive setup.
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:

# ======================
# 4. Training and evaluation functions
# ======================
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass of ``model`` over ``loader``.

    Args:
        model: classifier mapping an image batch to logits.
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss taking ``(logits, labels)`` and returning a batch-mean scalar.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter, so the function no longer needs a notebook-global ``device``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples of the epoch. Each batch loss
        is weighted by its batch size, so a smaller final batch is not over-counted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion returns a batch mean; scale it back so the epoch figure is a per-sample mean.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / total, correct / total

def evaluate(model, loader, criterion, device=None):
    """Compute loss and accuracy of ``model`` on ``loader`` without updating it.

    Runs in eval mode under ``torch.no_grad()``.

    Args:
        model: classifier mapping an image batch to logits.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss taking ``(logits, labels)`` and returning a batch-mean scalar.
        device: where each batch is moved; defaults to the device of the
            model's first parameter.

    Returns:
        ``(mean_loss, accuracy)`` over all samples. Each batch loss is weighted
        by its size, so a short final batch is not over-counted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return loss_sum / total, correct / total



In [None]:
# ======================
# 5. Training loop
# ======================
# Fixed-length training with per-epoch train/val reporting. There is no
# checkpointing: the weights left after the final epoch are the ones evaluated
# on the test set.
num_epochs = 5
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    print(
        f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
        f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}"
    )

In [None]:
# ======================
# 6. Final test evaluation with Detailed Metrics
# ======================

# Get detailed predictions for baseline model
def evaluate_with_predictions(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` and also collect every prediction and label.

    Runs in eval mode under ``torch.no_grad()``.

    Args:
        model: classifier mapping an image batch to logits.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss taking ``(logits, labels)`` and returning a batch-mean scalar.
        device: where each batch is moved; defaults to the device of the
            model's first parameter.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``. The loss is a per-sample
        mean (batch losses weighted by batch size). ``all_preds`` and ``all_labels``
        are flat lists of numpy integer scalars in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

            # Keep host-side copies for sklearn metrics
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("evaluate_with_predictions: loader yielded no samples")
    return loss_sum / total, correct / total, all_preds, all_labels

# Evaluate baseline model with detailed metrics
test_loss, test_acc, test_preds_baseline, test_labels_baseline = evaluate_with_predictions(model, test_loader, criterion)

print("=" * 60)
print("BASELINE MODEL RESULTS")
print("=" * 60)
print(f"Final Test: Loss={test_loss:.4f}, Acc={test_acc:.4f}")
print()

# Import necessary libraries for detailed metrics
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Class names in label-index order, taken from the ImageFolder that produced the
# integer labels (alphabetical folder order) instead of a hand-typed copy.
class_names = list(train_dataset.classes)
# Pin the label set so every class gets a report row, a confusion-matrix
# row/column and a slot in the per-class metric arrays, even if it never
# appears in the test labels or predictions.
class_indices = list(range(len(class_names)))

# Detailed classification report
print("Classification Report:")
print(classification_report(
    test_labels_baseline, test_preds_baseline,
    labels=class_indices, target_names=class_names, zero_division=0,
))

# Show ImageFolder's class order (alphabetical) for cross-checking
print("ImageFolder class order:", train_dataset.classes)

# Confusion Matrix
cm_baseline = confusion_matrix(test_labels_baseline, test_preds_baseline, labels=class_indices)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_baseline, annot=True, fmt='d', cmap='Reds',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Baseline Model - Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Per-class metrics, in the same ImageFolder class order
precision_baseline, recall_baseline, f1_baseline, support_baseline = precision_recall_fscore_support(
    test_labels_baseline, test_preds_baseline, labels=class_indices, zero_division=0
)

print("\nPer-class Metrics:")
for i, class_name in enumerate(class_names):
    print(f"{class_name}:")
    print(f"  Precision: {precision_baseline[i]:.4f}")
    print(f"  Recall:    {recall_baseline[i]:.4f}")
    print(f"  F1-score:  {f1_baseline[i]:.4f}")
    print(f"  Support:   {support_baseline[i]}")
    print()

# **Improved model**

In [None]:
# ======================
# Improved Model Implementation
# ======================
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, WeightedRandomSampler
from PIL import Image
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight

# Load and analyze CSV data
csv_path = "/content/drive/MyDrive/NIH_ChestXray/Data_Entry_2017.csv"
if not os.path.exists(csv_path):
    # Stop here with the offending path rather than letting pd.read_csv fail below.
    raise FileNotFoundError(f"Metadata CSV not found: {csv_path}")
print("Success")

df = pd.read_csv(csv_path)

print("CSV Data Analysis:")
print(f"Total samples: {len(df)}")
print(f"Unique images: {df['Image Index'].nunique()}")
print(f"\nFinding Labels distribution:")
print(df['Finding Labels'].value_counts().head(10))

# Target diseases for 4-class classification, listed in the same (alphabetical)
# order ImageFolder uses to assign label indices.
TARGET_DISEASES = ['Cardiomegaly', 'Effusion', 'No Finding', 'Pneumonia']

def process_labels(finding_labels):
    """Collapse a '|'-separated 'Finding Labels' string into one target class.

    When several target findings co-occur, the first match in the priority
    order Pneumonia > Cardiomegaly > Effusion > No Finding wins. Missing values
    and strings with none of the target findings map to 'No Finding'.
    """
    if pd.isna(finding_labels):
        return 'No Finding'

    present = {part.strip() for part in str(finding_labels).split('|')}
    priority = ('Pneumonia', 'Cardiomegaly', 'Effusion', 'No Finding')
    return next((disease for disease in priority if disease in present), 'No Finding')

# Collapse each row's multi-label string into a single target class
df['Processed_Label'] = df['Finding Labels'].map(process_labels)

# Report the collapsed-label distribution over the whole CSV
label_counts = df['Processed_Label'].value_counts()
print(f"\nProcessed Label Distribution:")
for label, count in label_counts.items():
    percentage = (count / len(df)) * 100
    print(f"  {label}: {count} ({percentage:.1f}%)")

# Forward and reverse class <-> index maps, in TARGET_DISEASES order
class_to_idx = dict(zip(TARGET_DISEASES, range(len(TARGET_DISEASES))))
idx_to_class = {index: name for name, index in class_to_idx.items()}
print(f"\nClass mapping: {class_to_idx}")

# ======================
# Custom Dataset Class with CSV Metadata (Multimodal)
# ======================
class ChestXrayDataset(Dataset):
    """CSV-driven dataset yielding ``(image, metadata, label)`` triples.

    ``metadata`` is a float tensor ``[normalised age, gender (M=1, else 0),
    view position (PA=1, else 0)]``. ``label`` is ``class_to_idx[Processed_Label]``.

    Args:
        dataframe: rows with 'Image Index', 'Patient Age', 'Patient Gender',
            'View Position' and 'Processed_Label' columns.
        img_dir: directory containing the files named in 'Image Index'.
        transform: optional callable applied to the loaded PIL image.
        class_to_idx: mapping from 'Processed_Label' values to integer labels.
            When omitted, the labels present in ``dataframe`` are indexed in
            sorted order, the same convention ImageFolder uses for folder names.
    """

    def __init__(self, dataframe, img_dir, transform=None, class_to_idx=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

        # The parameter shadows any global `class_to_idx`, so a default mapping
        # must be built here; otherwise self.class_to_idx would stay None.
        if class_to_idx is None:
            if 'Processed_Label' in self.df.columns:
                present = self.df['Processed_Label'].dropna().unique()
            else:
                present = []
            class_to_idx = {label: idx for idx, label in enumerate(sorted(present))}
        self.class_to_idx = class_to_idx

        # Age normalisation statistics for this dataframe. One row (NaN std) or
        # identical ages (std 0) would make every normalised age nan/inf, so
        # fall back to a unit scale in that case.
        self.age_mean = self.df['Patient Age'].mean()
        age_std = self.df['Patient Age'].std()
        self.age_std = age_std if pd.notna(age_std) and age_std > 0 else 1.0

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['Image Index'])

        # Load image; the context manager closes the file handle PIL opens.
        try:
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Best-effort: substitute a blank image so one bad file doesn't stop an epoch
            image = Image.new('RGB', (224, 224), color='black')

        if self.transform:
            image = self.transform(image)

        # Metadata features following the project deliverable
        age = (row['Patient Age'] - self.age_mean) / self.age_std
        gender = 1.0 if row['Patient Gender'] == 'M' else 0.0
        view_position = 1.0 if row['View Position'] == 'PA' else 0.0
        metadata = torch.FloatTensor([age, gender, view_position])

        label = self.class_to_idx[row['Processed_Label']]

        return image, metadata, label

# Enhanced data transforms with augmentation for improved model.
# Training: mild rotation/flip/intensity jitter, then a random 224x224 crop out of
# a 256x256 resize.
# NOTE(review): RandomHorizontalFlip mirrors cardiac laterality; confirm this is
# acceptable for the Cardiomegaly class.
train_transform_improved = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(degrees=10),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.2),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation uses the deterministic counterpart of the training pipeline: the
# same 256x256 resize, then the central 224x224 crop. Resizing straight to 224
# would show the network anatomy at ~0.875x the scale it was trained on.
val_test_transform_improved = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Multimodal dataset class and transforms defined.")
print("Metadata features: Age (normalized), Gender (M=1, F=0), View Position (PA=1, AP=0)")
print("Modified to use same subset as baseline model for fair comparison")

# ======================
# Use EXACT Same ImageFolder as Baseline - SIMPLE APPROACH
# ======================

# Reuse the baseline cell's `data_dir` so both models read identical files; fall
# back to the same hard-coded path when the baseline cells were skipped.
try:
    baseline_data_dir = data_dir
except NameError:
    baseline_data_dir = "/content/drive/MyDrive/NIH_ChestXray_subset_split"
    print(f"Using fallback data directory: {baseline_data_dir}")
else:
    print(f"Using baseline data_dir: {baseline_data_dir}")

# No transform on these folders: ImageFolderWithMetadata applies its own.
from torchvision import datasets
baseline_train_imagefolder = datasets.ImageFolder(f"{baseline_data_dir}/train")
baseline_val_imagefolder = datasets.ImageFolder(f"{baseline_data_dir}/val")
baseline_test_imagefolder = datasets.ImageFolder(f"{baseline_data_dir}/test")

print(f"Using EXACT same baseline datasets:")
print(f"Train: {len(baseline_train_imagefolder)} samples")
print(f"Val: {len(baseline_val_imagefolder)} samples")
print(f"Test: {len(baseline_test_imagefolder)} samples")

def create_metadata_lookup(csv_df):
    """Build ``{image_name: {'age', 'gender', 'view'}}`` from the metadata CSV.

    Values come from 'Patient Age', 'Patient Gender' and 'View Position'. The
    defaults 50 / 'M' / 'PA' apply only when a whole column is absent; per-row
    missing values (NaN) are passed through unchanged. When an image name is
    repeated, the last row wins.

    Columns are zipped directly instead of using ``DataFrame.iterrows``, which
    builds a Series per row and is very slow on the full ~100k-row CSV.
    """
    n_rows = len(csv_df)

    def _column_or_default(name, default):
        # Mirrors Series.get(name, default): default only if the column is missing.
        if name in csv_df.columns:
            return csv_df[name].tolist()
        return [default] * n_rows

    ages = _column_or_default('Patient Age', 50)
    genders = _column_or_default('Patient Gender', 'M')
    views = _column_or_default('View Position', 'PA')

    return {
        img_name: {'age': age, 'gender': gender, 'view': view}
        for img_name, age, gender, view in zip(csv_df['Image Index'], ages, genders, views)
    }

# Build the image-name -> metadata dictionary once so per-sample lookups are dict lookups
metadata_lookup = create_metadata_lookup(csv_df=df)
print(f"Created metadata lookup for {len(metadata_lookup)} images")

# ======================
# Simple Multimodal Dataset using ImageFolder + Metadata
# ======================
class ImageFolderWithMetadata(Dataset):
    """Wrap an ImageFolder so each sample also carries CSV metadata.

    Items are ``(image, metadata, label)``. ``label`` is the ImageFolder's integer
    label. ``metadata`` is a float tensor ``[normalised age, gender (M=1, else 0),
    view (PA=1, else 0)]``. Images are matched to metadata by file basename.

    Args:
        imagefolder_dataset: object exposing ``.samples`` (a list of ``(path, label)``)
            and ``len()``, e.g. ``torchvision.datasets.ImageFolder``.
        metadata_lookup: ``{image_name: {'age', 'gender', 'view'}}`` mapping.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, imagefolder_dataset, metadata_lookup, transform=None):
        self.imagefolder = imagefolder_dataset
        self.metadata_lookup = metadata_lookup
        self.transform = transform

        # Age normalisation statistics over this split's images that have a usable age
        ages = []
        for img_path, _ in self.imagefolder.samples:
            info = self.metadata_lookup.get(os.path.basename(img_path))
            if info is not None and pd.notna(info['age']):
                ages.append(float(info['age']))

        if ages:
            self.age_mean = float(np.mean(ages))
            age_std = float(np.std(ages))
            # A single image or identical ages gives std 0; dividing by it would
            # make every normalised age nan/inf, so use a unit scale instead.
            self.age_std = age_std if age_std > 0 else 1.0
        else:
            self.age_mean = 50.0  # Default
            self.age_std = 15.0   # Default

    def __len__(self):
        return len(self.imagefolder)

    def __getitem__(self, idx):
        # Path and label come straight from ImageFolder
        img_path, label = self.imagefolder.samples[idx]
        img_name = os.path.basename(img_path)

        # Load and transform image; the context manager closes the file PIL opens
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Metadata from the lookup, with defaults when the image is not in the CSV
        if img_name in self.metadata_lookup:
            metadata_info = self.metadata_lookup[img_name]
            age = float(metadata_info['age']) if pd.notna(metadata_info['age']) else 50.0
            gender = metadata_info['gender']
            view = metadata_info['view']
        else:
            age = 50.0
            gender = 'M'
            view = 'PA'

        age_normalized = (age - self.age_mean) / self.age_std
        gender_encoded = 1.0 if gender == 'M' else 0.0
        view_encoded = 1.0 if view == 'PA' else 0.0

        metadata = torch.FloatTensor([age_normalized, gender_encoded, view_encoded])

        return image, metadata, label

# Wrap the baseline's own ImageFolder splits so both models see identical files;
# only the training split gets the augmenting transform.
train_dataset_improved = ImageFolderWithMetadata(
    imagefolder_dataset=baseline_train_imagefolder,
    metadata_lookup=metadata_lookup,
    transform=train_transform_improved,
)
val_dataset_improved = ImageFolderWithMetadata(
    imagefolder_dataset=baseline_val_imagefolder,
    metadata_lookup=metadata_lookup,
    transform=val_test_transform_improved,
)
test_dataset_improved = ImageFolderWithMetadata(
    imagefolder_dataset=baseline_test_imagefolder,
    metadata_lookup=metadata_lookup,
    transform=val_test_transform_improved,
)

print(f"Created improved datasets with metadata:")
print(f"Train: {len(train_dataset_improved)} samples")
print(f"Val: {len(val_dataset_improved)} samples")
print(f"Test: {len(test_dataset_improved)} samples")

# ======================
# Create Weighted Sampler for Balanced Training (Multimodal) - IMPROVED
# ======================
def create_weighted_sampler(dataset, target_labels, weight_smoothing='sqrt'):
    """Build a ``WeightedRandomSampler`` that over-samples minority classes.

    Args:
        dataset: the dataset the sampler will index; must support ``len()``.
            It is only indexed (``dataset[i][2]`` = label) when ``target_labels`` is None.
        target_labels: per-sample integer labels aligned with ``dataset``. Pass
            them whenever available, because reading labels through ``dataset[i]``
            loads and transforms every sample. ``None`` falls back to that slow scan.
        weight_smoothing: 'sqrt' -> 1/sqrt(count), 'log' -> 1/log(count + 1);
            any other value -> plain inverse frequency 1/count.

    Returns:
        A WeightedRandomSampler that draws ``len(dataset)`` indices with replacement.

    Raises:
        ValueError: if there are no labels, or their count differs from ``len(dataset)``.
    """
    if target_labels is None:
        labels = [dataset[i][2] for i in range(len(dataset))]
    else:
        labels = [int(label) for label in target_labels]
    if not labels:
        raise ValueError("create_weighted_sampler: no labels to weight")
    if len(labels) != len(dataset):
        raise ValueError(
            f"create_weighted_sampler: {len(labels)} labels for a dataset of {len(dataset)} samples"
        )

    class_counts = np.bincount(labels)
    print(f"Class counts for weighted sampler: {class_counts}")

    # A class with zero samples is never drawn, but 1/sqrt(0) or 1/log(1) would
    # be inf (with a RuntimeWarning); clamp so every reported weight is finite.
    safe_counts = np.maximum(class_counts, 1)
    if weight_smoothing == 'sqrt':
        # Square root damps extreme weights
        class_weights = 1.0 / np.sqrt(safe_counts)
    elif weight_smoothing == 'log':
        class_weights = 1.0 / np.log(safe_counts + 1)
    else:
        # Plain inverse frequency (strongest correction)
        class_weights = 1.0 / safe_counts

    print(f"Smoothed class weights ({weight_smoothing}): {class_weights}")

    sample_weights = [class_weights[label] for label in labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

# Weighted sampler for training (sqrt smoothing). Labels are read from the wrapped
# ImageFolder's `.targets` list, so this line itself decodes no images.
train_labels_list = list(train_dataset_improved.imagefolder.targets)
train_sampler = create_weighted_sampler(train_dataset_improved, train_labels_list, weight_smoothing='sqrt')

# Custom collate function for multimodal data
def multimodal_collate_fn(batch):
    """Collate ``(image, metadata, label)`` samples into three batched tensors.

    Returns ``(images, metadata, labels)``: images and metadata stacked along a new
    leading batch dimension, labels as an int64 tensor.
    """
    image_seq, meta_seq, label_seq = zip(*batch)
    return (
        torch.stack(image_seq, dim=0),
        torch.stack(meta_seq, dim=0),
        torch.LongTensor(label_seq),
    )

# Create data loaders with improved settings.
# Pinned host memory only speeds up host-to-GPU copies, so enable it only when CUDA
# is available (on CPU-only runtimes PyTorch just warns that it has no effect).
use_pin_memory = torch.cuda.is_available()

train_loader_improved = DataLoader(
    train_dataset_improved,
    batch_size=32,
    sampler=train_sampler,
    num_workers=2,
    pin_memory=use_pin_memory,
    collate_fn=multimodal_collate_fn
)

val_loader_improved = DataLoader(
    val_dataset_improved,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory,
    collate_fn=multimodal_collate_fn
)

test_loader_improved = DataLoader(
    test_dataset_improved,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory,
    collate_fn=multimodal_collate_fn
)

print("Multimodal weighted sampler and data loaders created with improved weighting.")
print("Using SAME SUBSET as baseline model for fair comparison")
print(f"Train batches: {len(train_loader_improved)}")
print(f"Val batches: {len(val_loader_improved)}")
print(f"Test batches: {len(test_loader_improved)}")
print("Each batch contains: (images, metadata, labels)")

# Verify the training class distribution from ImageFolder's stored `.targets`
# (indexing the dataset would load and augment every image just to read labels),
# naming classes by the ImageFolder's own class list.
if len(train_dataset_improved) > 0:
    train_folder = train_dataset_improved.imagefolder
    train_label_counts = np.bincount(train_folder.targets, minlength=len(train_folder.classes))
    print(f"Training set class distribution: {dict(zip(train_folder.classes, train_label_counts))}")
else:
    print("No training data found - check subset directory and CSV matching")

# ======================
# Multimodal Improved Model Architecture
# ======================
class MultimodalChestXrayModel(nn.Module):
    """Late-fusion classifier: ResNet-50 image features plus an MLP over metadata.

    The image branch yields a 2048-d pooled feature vector. The metadata branch
    maps the 3 metadata values (normalised age, gender bit, view-position bit) to
    32-d. The concatenated 2080-d vector then passes through a dropout-regularised
    3-layer MLP that produces ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        pretrained: forwarded to ``torchvision.models.resnet50``.

    NOTE(review): ``image_backbone`` stays registered as a submodule. Its unused
    final ``fc`` layer is therefore still counted by ``parameters()``, and the
    shared layers appear under both ``image_backbone.*`` and ``image_features.*``
    in ``state_dict()``. ``forward`` never uses ``fc``.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()

        # --- image branch ---
        self.image_backbone = models.resnet50(pretrained=pretrained)
        # Freeze all backbone tensors except the last 40 in registration order.
        # Those 40 include the final fc layer's weight/bias, so 38
        # feature-extractor tensors remain trainable (TODO confirm intended cut).
        backbone_params = list(self.image_backbone.parameters())
        for frozen in backbone_params[:-40]:
            frozen.requires_grad = False
        # Every child except the last one (the classification layer).
        self.image_features = nn.Sequential(*list(self.image_backbone.children())[:-1])
        image_feature_dim = 2048  # ResNet-50 pooled feature width

        # --- metadata branch: age (normalised), gender (binary), view position (PA=1, AP=0) ---
        metadata_dim = 3
        self.metadata_branch = nn.Sequential(
            nn.Linear(metadata_dim, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.2),
        )

        # --- fusion head over the concatenated 2048 + 32 = 2080 features ---
        fused_dim = image_feature_dim + 32
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, image, metadata):
        """Return class logits for an image batch and its matching metadata rows."""
        visual = self.image_features(image).flatten(1)
        tabular = self.metadata_branch(metadata)
        return self.classifier(torch.cat([visual, tabular], dim=1))

# Create multimodal improved model, one output per ImageFolder class
improved_class_names = list(train_dataset_improved.imagefolder.classes)
model_improved = MultimodalChestXrayModel(num_classes=len(improved_class_names), pretrained=True)
model_improved = model_improved.to(device)

# Labels for the class weights come from ImageFolder's `.targets`; indexing
# train_dataset_improved would decode and augment every training image.
train_labels = list(train_dataset_improved.imagefolder.targets)

# 'balanced' weights are n_samples / (n_classes * class_count). CrossEntropyLoss
# needs exactly one weight per output class, so every class must be present.
present_classes = np.unique(train_labels)
if len(present_classes) != len(improved_class_names):
    raise ValueError(
        "Every class needs at least one training image: CrossEntropyLoss expects "
        f"{len(improved_class_names)} weights but only {len(present_classes)} classes occur."
    )
balanced_weights = compute_class_weight('balanced', classes=present_classes, y=train_labels)

# Blend toward uniform weights to avoid over-correcting (the weighted sampler
# already rebalances the training stream).
alpha = 0.3  # 0 = no weighting, 1 = full balanced weighting
uniform_weights = np.ones_like(balanced_weights)
smoothed_weights = alpha * balanced_weights + (1 - alpha) * uniform_weights

class_weights_tensor = torch.FloatTensor(smoothed_weights).to(device)

print(f"Original balanced weights: {dict(zip(improved_class_names, balanced_weights))}")
print(f"Smoothed weights (alpha={alpha}): {dict(zip(improved_class_names, smoothed_weights))}")

# Improved loss and optimizer with smoothed weights
criterion_improved = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer_improved = optim.AdamW(model_improved.parameters(), lr=0.0005, weight_decay=0.01)

# Halve the LR when validation loss stops improving for 3 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_improved, mode='min', patience=3, factor=0.5
)

print("Multimodal improved model architecture ready (Image + Metadata).")
print(f"Model parameters: {sum(p.numel() for p in model_improved.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model_improved.parameters() if p.requires_grad):,}")

# ======================
# Enhanced Training Functions (Multimodal)
# ======================
def train_one_epoch_improved(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``(images, metadata, labels)`` batches.

    Gradients are clipped to a total norm of 1.0 before each optimiser step, and
    the current batch loss is printed every 100 batches.

    Args:
        model: module called as ``model(images, metadata)`` and returning logits.
        loader: iterable of ``(images, metadata, labels)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss ``criterion(logits, labels)`` returning a batch scalar.
        device: device every batch tensor is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples of the epoch. Each batch loss
        is weighted by its batch size. That is the exact per-sample mean for an
        unweighted criterion and an approximation when the criterion has class weights.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, metadata, labels) in enumerate(loader):
        images, metadata, labels = images.to(device), metadata.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, metadata)
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

        # Print progress every 100 batches
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}')

    if total == 0:
        raise ValueError("train_one_epoch_improved: loader yielded no samples")
    return loss_sum / total, correct / total

def evaluate_improved(model, loader, criterion, device, verbose=False, class_names=None):
    """Evaluate a multimodal model and collect its predictions.

    Runs in eval mode under ``torch.no_grad()``.

    Args:
        model: module called as ``model(images, metadata)`` and returning logits.
        loader: iterable of ``(images, metadata, labels)`` batches.
        criterion: loss ``criterion(logits, labels)`` returning a batch scalar.
        device: device every batch tensor is moved to.
        verbose: if True, print predicted vs. true class counts.
        class_names: names used for the verbose counts; defaults to the
            notebook-level ``TARGET_DISEASES``.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``. The loss weights each
        batch by its size. ``all_preds`` and ``all_labels`` are flat lists of
        numpy integer scalars in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, metadata, labels in loader:
            images, metadata, labels = images.to(device), metadata.to(device), labels.to(device)
            outputs = model(images, metadata)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size

            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("evaluate_improved: loader yielded no samples")
    epoch_loss = loss_sum / total
    epoch_acc = correct / total

    # Print class prediction distribution for monitoring
    if verbose:
        names = TARGET_DISEASES if class_names is None else class_names
        pred_counts = np.bincount(all_preds, minlength=len(names))
        label_counts = np.bincount(all_labels, minlength=len(names))
        print(f"  Predicted distribution: {dict(zip(names, pred_counts))}")
        print(f"  True label distribution: {dict(zip(names, label_counts))}")

    return epoch_loss, epoch_acc, all_preds, all_labels

print("Enhanced multimodal training functions defined.")
print("Functions now handle both image and metadata inputs.")

# ======================
# Data Integrity Check Before Training
# ======================
print("=" * 60)
print("DATA INTEGRITY CHECK")
print("=" * 60)

print("Baseline vs Improved Dataset Comparison:")
# `data_dir` only exists if the baseline cells ran. The ImageFolder-loading cell
# already tolerates its absence, so this report does too instead of raising NameError.
print(f"Baseline uses: {globals().get('data_dir', '<baseline data_dir not defined>')}")
print(f"Improved uses: {baseline_data_dir} (SAME AS BASELINE)")

# Quick sanity check
if len(train_dataset_improved) == 0:
    print("ERROR: No training data found! Check CSV and image matching.")
else:
    print(f"✓ Training data loaded: {len(train_dataset_improved)} samples")

    # Test loading a sample
    try:
        sample_img, sample_meta, sample_label = train_dataset_improved[0]
        print(f"✓ Sample data shapes: Image={sample_img.shape}, Metadata={sample_meta.shape}, Label={sample_label}")
    except Exception as e:
        print(f"ERROR loading sample: {e}")

print("=" * 60)

# ======================
# Train Improved Model
# ======================
print("Starting training of improved model...")
print("=" * 60)

num_epochs_improved = 10
# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. With 0.0, a run whose val accuracy never exceeds 0 would leave no
# 'best_improved_model.pth' for the evaluation step to load.
best_val_acc = float('-inf')
patience_counter = 0
patience_limit = 5

train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs_improved):
    print(f"\nEpoch {epoch+1}/{num_epochs_improved}")
    print("-" * 40)

    # Training phase
    train_loss, train_acc = train_one_epoch_improved(
        model_improved, train_loader_improved, optimizer_improved, criterion_improved, device
    )

    # Validation phase (prediction histograms every other epoch)
    val_loss, val_acc, val_preds, val_labels = evaluate_improved(
        model_improved, val_loader_improved, criterion_improved, device, verbose=(epoch % 2 == 0)
    )

    # ReduceLROnPlateau tracks validation loss
    scheduler.step(val_loss)

    # Save metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Early stopping on validation accuracy; keep the best weights on disk
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model_improved.state_dict(), 'best_improved_model.pth')
        print(f"New best validation accuracy: {best_val_acc:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience_limit:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print(f"\nTraining completed. Best validation accuracy: {best_val_acc:.4f}")

# ======================
# Evaluate Improved Model
# ======================
# Reload the best checkpoint from training. map_location lets a checkpoint saved
# on a GPU load onto whatever `device` this runtime has.
model_improved.load_state_dict(torch.load('best_improved_model.pth', map_location=device))

# Evaluate on test set
test_loss, test_acc, test_preds_improved, test_labels_improved = evaluate_improved(
    model_improved, test_loader_improved, criterion_improved, device, verbose=True
)

print("=" * 60)
print("IMPROVED MODEL RESULTS")
print("=" * 60)
print(f"Final Test: Loss={test_loss:.4f}, Acc={test_acc:.4f}")
print()

# Class names in label-index order come from the ImageFolder that produced the
# labels; TARGET_DISEASES is a hand-maintained copy of that order.
improved_class_names = list(test_dataset_improved.imagefolder.classes)
# Pin the label set so every class keeps its row/column and metric slot even
# if it is absent from the test labels and predictions.
improved_class_indices = list(range(len(improved_class_names)))

# Detailed classification report
print("Classification Report:")
print(classification_report(
    test_labels_improved, test_preds_improved,
    labels=improved_class_indices, target_names=improved_class_names, zero_division=0,
))

# Confusion Matrix
cm_improved = confusion_matrix(test_labels_improved, test_preds_improved, labels=improved_class_indices)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_improved, annot=True, fmt='d', cmap='Blues',
            xticklabels=improved_class_names, yticklabels=improved_class_names)
plt.title('Improved Model - Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Calculate per-class metrics
precision_improved, recall_improved, f1_improved, support_improved = precision_recall_fscore_support(
    test_labels_improved, test_preds_improved, labels=improved_class_indices, zero_division=0
)

print("\nPer-class Metrics:")
for i, class_name in enumerate(improved_class_names):
    print(f"{class_name}:")
    print(f"  Precision: {precision_improved[i]:.4f}")
    print(f"  Recall:    {recall_improved[i]:.4f}")
    print(f"  F1-score:  {f1_improved[i]:.4f}")
    print(f"  Support:   {support_improved[i]}")
    print()

# Plot training curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss', marker='o')
plt.plot(val_losses, label='Val Loss', marker='s')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 2)
plt.plot(train_accuracies, label='Train Acc', marker='o')
plt.plot(val_accuracies, label='Val Acc', marker='s')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 3)
# Compare F1 scores, only when the baseline cell ran and both cover the same classes
if 'f1_baseline' in globals() and len(f1_baseline) == len(f1_improved):
    x = np.arange(len(improved_class_names))
    width = 0.35

    plt.bar(x - width/2, f1_baseline, width, label='Baseline', alpha=0.8)
    plt.bar(x + width/2, f1_improved, width, label='Improved', alpha=0.8)

    plt.xlabel('Classes')
    plt.ylabel('F1-Score')
    plt.title('F1-Score Comparison')
    plt.xticks(x, improved_class_names, rotation=45)
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()