In [None]:
# Mount Google Drive at /content/drive; the dataset paths used by later cells
# all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
# Debug: sanity-check that torch and torchvision.models imported correctly
print("PyTorch version:", torch.__version__)
print("Models type:", type(models))
# First ten public (non-underscore) attribute names exposed by `models`
print(
    "Available models:",
    list(filter(lambda name: not name.startswith('_'), dir(models)))[:10],
)
print("ResNet18 available:", hasattr(models, 'resnet18'))

# **Baseline code**

In [None]:
# Data preparation
import os

# Root of the pre-split dataset: one sub-folder per split, one folder per class inside
data_dir = "/content/drive/MyDrive/NIH_ChestXray_subset_split"

# Deterministic preprocessing shared by all three splits (no augmentation):
# resize to 224x224, convert to tensor, normalise each channel.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# One ImageFolder per split; class labels come from the sub-folder names
train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=transform)
val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "val"), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "test"), transform=transform)

# Only the training loader shuffles; val/test keep a fixed order
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32, num_workers=2)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=32, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32, num_workers=2)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"Classes: {train_dataset.classes}")

In [None]:
# ======================
# Dataset Analysis: Check class distribution and labels
# ======================
import os
from collections import Counter

def analyze_dataset(data_dir):
    """Print per-split class counts and imbalance statistics for an image-folder dataset.

    ``data_dir`` is expected to contain ``train``/``val``/``test`` sub-directories,
    each holding one sub-directory per class. A file counts as an image when its
    name ends in .png, .jpg or .jpeg (case-insensitive).

    Args:
        data_dir: Root directory of the split dataset.

    Returns:
        dict mapping every split that exists to a ``{class_name: image_count}``
        dict (classes in sorted order). Missing splits are reported with a
        warning and left out of the result.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    summary = {}

    print("=" * 60)
    print("DATASET ANALYSIS")
    print("=" * 60)

    for split in ['train', 'val', 'test']:
        split_path = os.path.join(data_dir, split)
        if not os.path.isdir(split_path):
            print(f"Warning: {split} directory not found!")
            continue

        print(f"\n{split.upper()} SET:")
        print("-" * 30)

        # Sorted so the report order is stable across filesystems
        class_dirs = sorted(
            d for d in os.listdir(split_path)
            if os.path.isdir(os.path.join(split_path, d))
        )

        class_counts = {}
        for class_name in class_dirs:
            class_path = os.path.join(split_path, class_name)
            class_counts[class_name] = sum(
                1 for f in os.listdir(class_path)
                if f.lower().endswith(image_extensions)
            )
        total_samples = sum(class_counts.values())

        # Print class distribution
        for class_name, count in class_counts.items():
            percentage = (count / total_samples) * 100 if total_samples > 0 else 0
            print(f"  {class_name:<15}: {count:>5} samples ({percentage:>5.1f}%)")

        print(f"  {'TOTAL':<15}: {total_samples:>5} samples")

        # Check for class imbalance
        if class_counts:
            max_class = max(class_counts, key=class_counts.get)
            min_class = min(class_counts, key=class_counts.get)
            smallest = class_counts[min_class]
            if smallest > 0:
                imbalance_ratio = class_counts[max_class] / smallest
                print(f"  Imbalance ratio (max/min): {imbalance_ratio:.1f}:1")
            else:
                # An empty class folder would otherwise divide by zero
                print("  Imbalance ratio (max/min): undefined (at least one class has no images)")
            print(f"  Most frequent: {max_class} ({class_counts[max_class]} samples)")
            print(f"  Least frequent: {min_class} ({class_counts[min_class]} samples)")

        summary[split] = class_counts

    return summary

# Run the directory-level analysis on the split dataset
print("Analyzing NIH Chest X-ray dataset...")
analyze_dataset(data_dir)

# Show which integer index the ImageFolder dataset uses for each class name
print(f"\n{'=' * 60}")
print("PYTORCH IMAGEFOLDER CLASS MAPPING")
print("=" * 60)
print("Class indices mapping:")
if 'train_dataset' not in globals():
    # The data preparation cell has not been executed in this kernel yet
    print("  Train dataset not loaded yet. Run the data preparation cell first.")
else:
    for idx in range(len(train_dataset.classes)):
        class_name = train_dataset.classes[idx]
        print(f"  Index {idx}: {class_name}")

In [None]:
# Create train/val/test split if needed
import os
import shutil
import random
from sklearn.model_selection import train_test_split

def create_split(source_dir, dest_dir):
    """Copy a class-per-folder image set into a 70/15/15 train/val/test layout.

    For every class sub-directory of ``source_dir`` the image files (.png, .jpg,
    .jpeg, case-insensitive) are shuffled with a fixed seed and copied to
    ``dest_dir/<split>/<class_name>/``. The last slice (test) receives whatever
    remains after the truncated 70% and 15% slices.

    File and class names are sorted before shuffling because ``os.listdir``
    returns entries in arbitrary order; without sorting, the same seed can give
    a different split on another filesystem. A private ``random.Random(42)`` is
    used so the global random state is left untouched.

    Args:
        source_dir: Directory containing one sub-directory per class.
        dest_dir: Directory to create. If it already exists the function
            returns immediately without copying anything (an existing but
            incomplete split is therefore not repaired).
    """
    if os.path.exists(dest_dir):
        return

    rng = random.Random(42)
    image_extensions = ('.png', '.jpg', '.jpeg')

    for split in ('train', 'val', 'test'):
        os.makedirs(os.path.join(dest_dir, split), exist_ok=True)

    class_dirs = sorted(
        d for d in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, d))
    )

    for class_name in class_dirs:
        class_path = os.path.join(source_dir, class_name)
        image_files = sorted(
            f for f in os.listdir(class_path)
            if f.lower().endswith(image_extensions)
        )

        rng.shuffle(image_files)
        n_total = len(image_files)
        n_train = int(n_total * 0.7)
        n_val = int(n_total * 0.15)

        splits = {
            'train': image_files[:n_train],
            'val': image_files[n_train:n_train + n_val],
            'test': image_files[n_train + n_val:],
        }

        for split, files in splits.items():
            split_class_dir = os.path.join(dest_dir, split, class_name)
            os.makedirs(split_class_dir, exist_ok=True)
            for file in files:
                # copy2 keeps file metadata (timestamps) along with the content
                shutil.copy2(os.path.join(class_path, file), os.path.join(split_class_dir, file))

# Source folder (one sub-folder per class) and the destination for the split copy
source_data_dir = "/content/drive/MyDrive/NIH_ChestXray_subset"
split_data_dir = source_data_dir + "_split"

create_split(source_dir=source_data_dir, dest_dir=split_data_dir)
print("Data split completed")

In [None]:
# ======================
# 2. Model definition
# ======================
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Four output classes: No Finding, Pneumonia, Effusion, Cardiomegaly (per the project deliverable)
num_classes = 4

# ResNet-18 built with pretrained=False; its final layer is replaced by a
# fresh linear head with one logit per class.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
model = model.to(device)

In [None]:
# ======================
# 3. Loss and optimizer
# ======================
# The baseline deliberately keeps the loss unweighted so that the effect of
# class imbalance shows up in the results.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:

# ======================
# 4. Training and evaluation functions
# ======================
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and report mean loss and accuracy.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on a notebook-level ``device`` variable.

    Args:
        model: Module mapping a batch of images to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable taking ``(logits, labels)``.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the number of
        batches, and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples (previously a ZeroDivisionError).
    """
    model.train()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss, correct, total, num_batches = 0.0, 0, 0, 0
    for images, labels in loader:
        images, labels = images.to(run_device), labels.to(run_device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        num_batches += 1

    if num_batches == 0 or total == 0:
        raise ValueError("train_one_epoch received an empty loader")
    return running_loss / num_batches, correct / total

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on a notebook-level ``device`` variable.

    Args:
        model: Module mapping a batch of images to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the number of
        batches, and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss, correct, total, num_batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    if num_batches == 0 or total == 0:
        raise ValueError("evaluate received an empty loader")
    return running_loss / num_batches, correct / total



In [None]:
# ======================
# 5. Training loop
# ======================
# Train for a fixed number of epochs, validating after each one and printing
# a one-line summary of the four metrics.
num_epochs = 5
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

In [None]:
# ======================
# 6. Final test evaluation with Detailed Metrics
# ======================

# Get detailed predictions for baseline model
def evaluate_with_predictions(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` and also collect per-sample predictions.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on a notebook-level ``device`` variable.

    Args:
        model: Module mapping a batch of images to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)`` where the loss is the
        per-batch loss averaged over the number of batches and the two lists
        hold the predicted and true class indices in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss, correct, total, num_batches = 0.0, 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

            # Store predictions and labels as plain Python ints
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if num_batches == 0 or total == 0:
        raise ValueError("evaluate_with_predictions received an empty loader")
    return running_loss / num_batches, correct / total, all_preds, all_labels

# Import necessary libraries for detailed metrics
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Evaluate baseline model with detailed metrics
test_loss, test_acc, test_preds_baseline, test_labels_baseline = evaluate_with_predictions(model, test_loader, criterion)

print("=" * 60)
print("BASELINE MODEL RESULTS")
print("=" * 60)
print(f"Final Test: Loss={test_loss:.4f}, Acc={test_acc:.4f}")
print()

# Class names, listed in label-index order (index i <-> class_names[i])
class_names =  ['Cardiomegaly', 'Effusion', 'No Finding', 'Pneumonia']

# Guard against the hard-coded names drifting from the folder-derived classes
dataset_classes = getattr(test_loader.dataset, 'classes', None)
if dataset_classes is not None and list(dataset_classes) != class_names:
    print(f"Warning: class_names {class_names} differ from dataset classes {list(dataset_classes)}")

# Pass the full label set explicitly: without it, a class that never appears in
# the test labels or predictions is dropped from the metric arrays, which breaks
# both target_names and the per-class loop below.
label_ids_baseline = list(range(len(class_names)))

# Detailed classification report
print("Classification Report:")
print(classification_report(test_labels_baseline, test_preds_baseline,
                            labels=label_ids_baseline, target_names=class_names,
                            zero_division=0))

# Confusion Matrix
cm_baseline = confusion_matrix(test_labels_baseline, test_preds_baseline, labels=label_ids_baseline)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_baseline, annot=True, fmt='d', cmap='Reds', 
            xticklabels=class_names, yticklabels=class_names)
plt.title('Baseline Model - Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Calculate per-class metrics (arrays are indexed by label id)
precision_baseline, recall_baseline, f1_baseline, support_baseline = precision_recall_fscore_support(
    test_labels_baseline, test_preds_baseline, labels=label_ids_baseline, zero_division=0
)

print("\nPer-class Metrics:")
for i, class_name in enumerate(class_names):
    print(f"{class_name}:")
    print(f"  Precision: {precision_baseline[i]:.4f}")
    print(f"  Recall:    {recall_baseline[i]:.4f}")
    print(f"  F1-score:  {f1_baseline[i]:.4f}")
    print(f"  Support:   {support_baseline[i]}")
    print()

# **Improved model**

In [None]:
# **Multimodal Model**

# ======================
# Improved Model Implementation
# ======================
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, WeightedRandomSampler
from PIL import Image
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight

# Load and analyze CSV data
csv_path = "/content/drive/MyDrive/NIH_ChestXray/Data_Entry_2017.csv"
if not os.path.exists(csv_path):
    # Fail here with an actionable message instead of printing "Failed" and
    # then attempting to read the missing file anyway.
    raise FileNotFoundError(
        f"Metadata CSV not found at {csv_path}. Is Google Drive mounted and the path correct?"
    )
print("Success")

df = pd.read_csv(csv_path)

print("CSV Data Analysis:")
print(f"Total samples: {len(df)}")
print(f"Unique images: {df['Image Index'].nunique()}")
print("\nFinding Labels distribution:")
print(df['Finding Labels'].value_counts().head(10))

# Target diseases for 4-class classification; list order defines the label indices
TARGET_DISEASES = ['No Finding', 'Pneumonia', 'Effusion', 'Cardiomegaly']

def process_labels(finding_labels):
    """Collapse a raw 'Finding Labels' value into one of the four target classes.

    The value is treated as a '|'-separated list of findings. When several
    target findings are present, the first match in the order Pneumonia,
    Cardiomegaly, Effusion, No Finding wins. Missing values, and values that
    contain none of the target findings (i.e. only other diseases), are
    returned as 'No Finding'.
    """
    fallback = 'No Finding'
    if pd.isna(finding_labels):
        return fallback

    # Individual findings with surrounding whitespace removed
    present = {part.strip() for part in str(finding_labels).split('|')}

    # Highest-priority class first
    ranked = ('Pneumonia', 'Cardiomegaly', 'Effusion', 'No Finding')
    return next((name for name in ranked if name in present), fallback)

# Collapse every row's raw findings into one of the four target classes
df['Processed_Label'] = df['Finding Labels'].apply(process_labels)

# Report how many rows ended up in each processed class
label_counts = df['Processed_Label'].value_counts()
print("\nProcessed Label Distribution:")
for label, count in label_counts.items():
    percentage = (count / len(df)) * 100
    print(f"  {label}: {count} ({percentage:.1f}%)")

# Label <-> index lookups; the index is the class's position in TARGET_DISEASES
class_to_idx = dict(zip(TARGET_DISEASES, range(len(TARGET_DISEASES))))
idx_to_class = {index: name for name, index in class_to_idx.items()}
print(f"\nClass mapping: {class_to_idx}")

# ======================
# Custom Dataset Class with CSV Metadata (Multimodal)
# ======================
class ChestXrayDataset(Dataset):
    """Image + tabular-metadata dataset backed by a dataframe.

    Each item is ``(image, metadata, label)`` where ``metadata`` is a float32
    tensor ``[age_z, gender, view_position]``: standardised age, gender encoded
    as M=1 / anything else=0, and view position encoded as PA=1 / anything else=0.

    Args:
        dataframe: Rows with the columns 'Image Index', 'Patient Age',
            'Patient Gender', 'View Position' and 'Processed_Label'.
        img_dir: Directory containing the files named in 'Image Index'.
        transform: Optional callable applied to the loaded PIL image.
        class_to_idx: Mapping from label name to integer index. Defaults to
            the enumeration of ``DEFAULT_CLASSES``.
        age_mean, age_std: Optional statistics used to standardise age. Pass
            the training set's values when building validation/test datasets;
            when omitted they are computed from ``dataframe``.
    """

    # Default label order used when no class_to_idx mapping is supplied
    DEFAULT_CLASSES = ('No Finding', 'Pneumonia', 'Effusion', 'Cardiomegaly')

    def __init__(self, dataframe, img_dir, transform=None, class_to_idx=None,
                 age_mean=None, age_std=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        if class_to_idx is None:
            # Previously `class_to_idx or class_to_idx`, which left None in place
            class_to_idx = {name: idx for idx, name in enumerate(self.DEFAULT_CLASSES)}
        self.class_to_idx = class_to_idx

        # Preprocess metadata for normalization
        if age_mean is None:
            age_mean = self.df['Patient Age'].mean()
        if age_std is None:
            age_std = self.df['Patient Age'].std()
        # std is NaN for a single row and 0 for constant ages; either would
        # turn every normalised age into NaN/inf, so fall back to unit scale.
        if pd.isna(age_std) or age_std == 0:
            age_std = 1.0
        self.age_mean = float(age_mean)
        self.age_std = float(age_std)

    def __len__(self):
        return len(self.df)

    def get_labels(self):
        """Return every sample's integer label, in order, without loading any image."""
        return [self.class_to_idx[name] for name in self.df['Processed_Label']]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['Image Index'])

        # Load image
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Best-effort fallback: substitute a blank image so iteration
            # continues. The sample keeps its real label.
            image = Image.new('RGB', (224, 224), color='black')

        if self.transform:
            image = self.transform(image)

        # Age: standardised with the statistics chosen in __init__
        age = (row['Patient Age'] - self.age_mean) / self.age_std

        # Gender: binary encoding (M=1, anything else=0)
        gender = 1.0 if row['Patient Gender'] == 'M' else 0.0

        # View Position: binary encoding (PA=1, anything else=0)
        view_position = 1.0 if row['View Position'] == 'PA' else 0.0

        # Combine metadata into tensor
        metadata = torch.tensor([float(age), gender, view_position], dtype=torch.float32)

        # Get label
        label = self.class_to_idx[row['Processed_Label']]

        return image, metadata, label

# Training-time pipeline for the improved model: resize slightly larger than
# the target, apply light random augmentation, then take a random 224x224 crop.
# NOTE(review): the horizontal flip mirrors the image left-to-right; confirm
# that is acceptable for these chest X-ray labels.
train_transform_improved = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Validation/test pipeline: deterministic resize and normalisation only
val_test_transform_improved = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("Multimodal dataset class and transforms defined.")
print("Metadata features: Age (normalized), Gender (M=1, F=0), View Position (PA=1, AP=0)")

# ======================
# Create Balanced Dataset Splits
# ======================
from sklearn.model_selection import train_test_split

# Filter dataframe to only include our target diseases and available images
img_dir = "/content/drive/MyDrive/NIH_ChestXray/images"
available_images = set(os.listdir(img_dir)) if os.path.exists(img_dir) else set()

# Keep only rows whose image file exists. If no image is found at all, the
# full table is used unchanged (every image load will then fall back to the
# dataset's blank-image path).
df_filtered = df[df['Image Index'].isin(available_images)] if available_images else df
print(f"Images available: {len(df_filtered)} out of {len(df)}")

# Create stratified splits to maintain class balance
# NOTE(review): rows are split per image. If the CSV holds several images of
# the same patient, one patient can land in more than one split; consider a
# patient-grouped split.
# First split: train + val vs test (80-20)
train_val_df, test_df = train_test_split(
    df_filtered,
    test_size=0.2,
    stratify=df_filtered['Processed_Label'],
    random_state=42
)

# Second split: train vs val (75-25 of remaining, i.e. 60/20/20 overall)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.25,
    stratify=train_val_df['Processed_Label'],
    random_state=42
)

print("\nDataset splits:")
print(f"Train: {len(train_df)} samples")
print(f"Validation: {len(val_df)} samples")
print(f"Test: {len(test_df)} samples")

# Show class distribution in each split
for split_name, split_df in [("Train", train_df), ("Validation", val_df), ("Test", test_df)]:
    print(f"\n{split_name} set distribution:")
    dist = split_df['Processed_Label'].value_counts()
    for label, count in dist.items():
        percentage = (count / len(split_df)) * 100
        print(f"  {label}: {count} ({percentage:.1f}%)")

# Create datasets
train_dataset_improved = ChestXrayDataset(train_df, img_dir, train_transform_improved, class_to_idx)
val_dataset_improved = ChestXrayDataset(val_df, img_dir, val_test_transform_improved, class_to_idx)
test_dataset_improved = ChestXrayDataset(test_df, img_dir, val_test_transform_improved, class_to_idx)

# Each dataset computes age mean/std from its own rows. Standardise the
# held-out splits with the TRAINING statistics instead, so the age feature
# means the same thing at train and evaluation time.
val_dataset_improved.age_mean = train_dataset_improved.age_mean
val_dataset_improved.age_std = train_dataset_improved.age_std
test_dataset_improved.age_mean = train_dataset_improved.age_mean
test_dataset_improved.age_std = train_dataset_improved.age_std

# ======================
# Create Weighted Sampler for Balanced Training (Multimodal)
# ======================
def create_weighted_sampler(dataset, target_labels=None):
    """Build a WeightedRandomSampler that counteracts class imbalance.

    Every sample is weighted by the inverse of its class frequency, so that in
    expectation each class is drawn equally often.

    Args:
        dataset: Dataset the sampler will be used with. It is only indexed when
            ``target_labels`` is None; labels are then read as ``dataset[i][2]``,
            which loads every item and can be very slow for image datasets.
        target_labels: Per-sample labels in dataset order (any hashable values,
            e.g. class names or integer indices). When given, no item is loaded.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(labels)`` samples with replacement.

    Raises:
        ValueError: If there are no labels, or if ``target_labels`` and
            ``dataset`` disagree in length.
    """
    from collections import Counter

    if target_labels is None:
        # Fallback: read the label (third element) from each dataset item
        labels = [dataset[i][2] for i in range(len(dataset))]
    else:
        labels = list(target_labels)
        if dataset is not None and hasattr(dataset, '__len__') and len(dataset) != len(labels):
            raise ValueError(
                f"target_labels has {len(labels)} entries but dataset has {len(dataset)} samples"
            )

    if not labels:
        raise ValueError("Cannot build a weighted sampler from an empty label list")

    # Counting only the classes that occur avoids the 1/0 weights that an
    # index-based bincount yields for absent classes.
    class_counts = Counter(labels)
    sample_weights = [1.0 / class_counts[label] for label in labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

# Build the class-balancing sampler for the training set, handing over the
# processed label column in dataframe order.
train_sampler = create_weighted_sampler(
    dataset=train_dataset_improved,
    target_labels=list(train_df['Processed_Label']),
)

# Custom collate function for multimodal data
def multimodal_collate_fn(batch):
    """Merge a list of ``(image, metadata, label)`` samples into batched tensors.

    Returns a tuple ``(images, metadata, labels)``: images and metadata are
    stacked along a new leading dimension, labels become a LongTensor.
    """
    # Transpose the list of samples into three parallel sequences
    image_seq, metadata_seq, label_seq = zip(*batch)

    stacked_images = torch.stack(image_seq, dim=0)
    label_tensor = torch.LongTensor(label_seq)
    stacked_metadata = torch.stack(metadata_seq, dim=0)

    return stacked_images, stacked_metadata, label_tensor

# Data loaders for the multimodal datasets. All three share batch size, worker
# count, pinned memory and the custom collate function; only the training
# loader uses the weighted sampler, the other two iterate in fixed order.
train_loader_improved = DataLoader(
    dataset=train_dataset_improved,
    sampler=train_sampler,
    batch_size=32,
    collate_fn=multimodal_collate_fn,
    num_workers=2,
    pin_memory=True,
)

val_loader_improved = DataLoader(
    dataset=val_dataset_improved,
    shuffle=False,
    batch_size=32,
    collate_fn=multimodal_collate_fn,
    num_workers=2,
    pin_memory=True,
)

test_loader_improved = DataLoader(
    dataset=test_dataset_improved,
    shuffle=False,
    batch_size=32,
    collate_fn=multimodal_collate_fn,
    num_workers=2,
    pin_memory=True,
)

print("Multimodal weighted sampler and data loaders created.")
print(f"Train batches: {len(train_loader_improved)}")
print(f"Val batches: {len(val_loader_improved)}")
print(f"Test batches: {len(test_loader_improved)}")
print("Each batch contains: (images, metadata, labels)")

# ======================
# Multimodal Improved Model Architecture (Following Project Deliverable)
# ======================
class MultimodalChestXrayModel(nn.Module):
    """Two-branch classifier fusing ResNet-50 image features with tabular metadata.

    The image branch is a ResNet-50 with its last child module removed; the
    metadata branch is a small MLP over three features (age, gender, view
    position). Both feature vectors are concatenated and passed through an MLP
    head that outputs ``num_classes`` logits.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()

        # Image branch - CNN for visual feature extraction
        self.image_backbone = models.resnet50(pretrained=pretrained)

        # Freeze everything except the last 20 parameter tensors of the backbone
        backbone_params = list(self.image_backbone.parameters())
        for frozen in backbone_params[:-20]:
            frozen.requires_grad_(False)

        # Feature extractor = all backbone children except the last one.
        # NOTE(review): the dropped last child stays registered under
        # image_backbone, so its parameters are still part of this model
        # although forward() never calls it.
        backbone_children = list(self.image_backbone.children())
        self.image_features = nn.Sequential(*backbone_children[:-1])

        image_feature_dim = 2048  # ResNet50 output dimension

        # Metadata branch - MLP for tabular data
        metadata_dim = 3  # age (normalized), gender (binary), view_position (binary: PA=1, AP=0)
        self.metadata_branch = nn.Sequential(
            nn.Linear(metadata_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Fusion head operates on the concatenated features (2048 + 32 = 2080)
        combined_feature_dim = image_feature_dim + 32
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(combined_feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, image, metadata):
        """Return class logits for a batch of images and their metadata rows."""
        # Image branch, flattened to (batch_size, 2048)
        visual = self.image_features(image).flatten(start_dim=1)

        # Metadata branch
        tabular = self.metadata_branch(metadata)

        # Fusion by concatenation, then classification
        return self.classifier(torch.cat((visual, tabular), dim=1))

# Create multimodal improved model
model_improved = MultimodalChestXrayModel(num_classes=len(TARGET_DISEASES), pretrained=True)
model_improved = model_improved.to(device)

# Calculate class weights for weighted loss.
# Labels are read straight from the dataframe: indexing the dataset instead
# (dataset[i][2]) would open and transform every training image just to get
# its label. The order matches the dataset, which is built from train_df.
train_labels = [class_to_idx[name] for name in train_df['Processed_Label']]
present_classes = np.unique(train_labels)
balanced_weights = compute_class_weight(
    'balanced', classes=present_classes, y=np.asarray(train_labels)
)

# Scatter into a full-length vector so that entry i always belongs to
# TARGET_DISEASES[i], even if some class has no training sample (it then keeps
# a neutral weight of 1.0 and the loss still receives one weight per class).
class_weights = np.ones(len(TARGET_DISEASES), dtype=np.float64)
class_weights[present_classes] = balanced_weights
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

print(f"Class weights: {dict(zip(TARGET_DISEASES, class_weights))}")

# Improved loss and optimizer
# NOTE(review): the training loader already rebalances classes through the
# weighted sampler; weighting the loss as well corrects for imbalance twice.
# Confirm this is intended.
criterion_improved = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer_improved = optim.AdamW(model_improved.parameters(), lr=0.001, weight_decay=0.01)

# Learning rate scheduler: halve the LR after 3 epochs without improvement
# in the monitored (minimised) value
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_improved, mode='min', patience=3, factor=0.5
)

print("Multimodal improved model architecture ready (Image + Metadata).")
print(f"Model parameters: {sum(p.numel() for p in model_improved.parameters()):,}")

# ======================
# Enhanced Training Functions (Multimodal)
# ======================
def train_one_epoch_improved(model, loader, optimizer, criterion, device):
    """Train the multimodal model for one epoch.

    Args:
        model: Module called as ``model(images, metadata)`` returning logits.
        loader: Sized iterable of ``(images, metadata, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-batch loss averaged over the number
        of batches, and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples (previously a ZeroDivisionError).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (images, metadata, labels) in enumerate(loader):
        images, metadata, labels = images.to(device), metadata.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, metadata)  # Pass both image and metadata
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += loss.item()
        # detach() replaces the legacy `.data` access; no gradient is needed here
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        num_batches += 1

        # Print progress every 100 batches
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}')

    if num_batches == 0 or total == 0:
        raise ValueError("train_one_epoch_improved received an empty loader")

    epoch_loss = running_loss / num_batches
    epoch_acc = correct / total

    return epoch_loss, epoch_acc

def evaluate_improved(model, loader, criterion, device):
    """Evaluate the multimodal model without gradient tracking.

    Args:
        model: Module called as ``model(images, metadata)`` returning logits.
        loader: Iterable of ``(images, metadata, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_labels)`` where the loss is the
        per-batch loss averaged over the number of batches and the two lists
        hold predicted and true class indices in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, metadata, labels in loader:
            images, metadata, labels = images.to(device), metadata.to(device), labels.to(device)
            outputs = model(images, metadata)  # Pass both image and metadata
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            # Already inside no_grad, so the legacy `.data` access is unnecessary
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if num_batches == 0 or total == 0:
        raise ValueError("evaluate_improved received an empty loader")

    epoch_loss = running_loss / num_batches
    epoch_acc = correct / total

    return epoch_loss, epoch_acc, all_preds, all_labels

print("Enhanced multimodal training functions defined.")
print("Functions now handle both image and metadata inputs.")

# ======================
# Train Improved Model
# ======================
print("Starting training of improved model...")
print("=" * 60)

num_epochs_improved = 10
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With 0.0 and a strict '>' comparison, a run whose validation
# accuracy is 0 in every epoch never saved a model, and loading
# 'best_improved_model.pth' afterwards would fail.
best_val_acc = float('-inf')
patience_counter = 0
patience_limit = 5  # epochs without a new best validation accuracy before stopping

# Per-epoch history, used for the training curves
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs_improved):
    print(f"\nEpoch {epoch+1}/{num_epochs_improved}")
    print("-" * 40)

    # Training phase
    train_loss, train_acc = train_one_epoch_improved(
        model_improved, train_loader_improved, optimizer_improved, criterion_improved, device
    )

    # Validation phase
    val_loss, val_acc, _, _ = evaluate_improved(
        model_improved, val_loader_improved, criterion_improved, device
    )

    # Update learning rate scheduler (monitors validation loss)
    scheduler.step(val_loss)

    # Save metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Early stopping and model saving (both keyed on validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        # Save best model
        torch.save(model_improved.state_dict(), 'best_improved_model.pth')
        print(f"New best validation accuracy: {best_val_acc:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience_limit:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

print(f"\nTraining completed. Best validation accuracy: {best_val_acc:.4f}")

# ======================
# Evaluate Improved Model
# ======================
# Load best model; map_location keeps the load working when the checkpoint was
# written on a different device than the one currently in use.
model_improved.load_state_dict(torch.load('best_improved_model.pth', map_location=device))

# Evaluate on test set
test_loss, test_acc, test_preds_improved, test_labels_improved = evaluate_improved(
    model_improved, test_loader_improved, criterion_improved, device
)

print("=" * 60)
print("IMPROVED MODEL RESULTS")
print("=" * 60)
print(f"Final Test: Loss={test_loss:.4f}, Acc={test_acc:.4f}")
print()

# Pass the full label set explicitly: without it, a class that never appears in
# the test labels or predictions is dropped from the metric arrays, which breaks
# both target_names and the per-class loop below.
label_ids_improved = list(range(len(TARGET_DISEASES)))

# Detailed classification report
print("Classification Report:")
print(classification_report(test_labels_improved, test_preds_improved,
                            labels=label_ids_improved, target_names=TARGET_DISEASES,
                            zero_division=0))

# Confusion Matrix
cm_improved = confusion_matrix(test_labels_improved, test_preds_improved, labels=label_ids_improved)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_improved, annot=True, fmt='d', cmap='Blues',
            xticklabels=TARGET_DISEASES, yticklabels=TARGET_DISEASES)
plt.title('Improved Model - Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Calculate per-class metrics (arrays are indexed by label id)
precision_improved, recall_improved, f1_improved, support_improved = precision_recall_fscore_support(
    test_labels_improved, test_preds_improved, labels=label_ids_improved, zero_division=0
)

print("\nPer-class Metrics:")
for i, class_name in enumerate(TARGET_DISEASES):
    print(f"{class_name}:")
    print(f"  Precision: {precision_improved[i]:.4f}")
    print(f"  Recall:    {recall_improved[i]:.4f}")
    print(f"  F1-score:  {f1_improved[i]:.4f}")
    print(f"  Support:   {support_improved[i]}")
    print()

# Plot training curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss', marker='o')
plt.plot(val_losses, label='Val Loss', marker='s')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 2)
plt.plot(train_accuracies, label='Train Acc', marker='o')
plt.plot(val_accuracies, label='Val Acc', marker='s')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 3)
# Compare F1 scores (only when the baseline cells were run in this kernel)
if 'f1_baseline' in locals():
    # The two models index their classes differently: f1_baseline follows
    # `class_names` (the order used for the baseline report) while f1_improved
    # follows TARGET_DISEASES. Re-key the baseline scores by class name so each
    # pair of bars compares the same class; plotting f1_baseline as-is would
    # attribute its scores to the wrong diseases.
    baseline_f1_by_class = dict(zip(class_names, f1_baseline))
    f1_baseline_aligned = [baseline_f1_by_class.get(name, 0.0) for name in TARGET_DISEASES]

    comparison_data = {
        'Baseline': f1_baseline_aligned,
        'Improved': f1_improved
    }
    x = np.arange(len(TARGET_DISEASES))
    width = 0.35

    plt.bar(x - width/2, comparison_data['Baseline'], width, label='Baseline', alpha=0.8)
    plt.bar(x + width/2, comparison_data['Improved'], width, label='Improved', alpha=0.8)

    plt.xlabel('Classes')
    plt.ylabel('F1-Score')
    plt.title('F1-Score Comparison')
    plt.xticks(x, TARGET_DISEASES, rotation=45)
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()