# RNN with GRU Model Training for Brain Tumor Classification

This notebook trains a hybrid CNN-GRU model to classify brain MRI images into 4 categories:
- **NO_TUMOR**: Healthy brain (no tumor detected)
- **GLIOMA**: Glioma tumor type
- **MENINGIOMA**: Meningioma tumor type
- **PITUITARY**: Pituitary tumor type

## Overview

This notebook implements a **CNN-GRU hybrid architecture**:
1. **CNN Feature Extractor**: Extracts spatial features from images using convolutional layers
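2. **GRU Layers**: Process the flattened CNN feature vector with a two-layer GRU (a GRU has fewer gates than an LSTM, so it typically has fewer parameters and trains faster)
3. **Fully Connected Layers**: Final classification into 4 classes

This approach pairs the spatial feature extraction of CNNs with GRU layers. Note that, as implemented, the full 128 x 7 x 7 feature map is flattened into a single vector, so the GRU processes a sequence of length 1 and its recurrence across time steps is not used.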
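The data flow through these three stages can be traced with plain shape arithmetic. The sketch below assumes the settings used later in this notebook: 224 x 224 inputs, three `MaxPool2d(2, 2)` layers, `AdaptiveAvgPool2d((7, 7))` and 128 output channels. `trace_shapes` is a hypothetical helper and is not part of the notebook.

```python
def trace_shapes(img_size=224, num_pools=3, channels=128, pooled_size=7):
    """Hypothetical helper: follow per-sample tensor shapes through the CNN-GRU pipeline."""
    # Each conv block ends in MaxPool2d(kernel_size=2, stride=2), which halves height and width
    spatial = img_size
    for _ in range(num_pools):
        spatial //= 2
    # AdaptiveAvgPool2d((7, 7)) then fixes the map at 7 x 7 whatever the incoming size
    feature_size = channels * pooled_size * pooled_size
    # view(batch, -1, feature_size) divides the per-sample element count by feature_size
    elements_per_sample = channels * pooled_size * pooled_size
    seq_len = elements_per_sample // feature_size
    return {
        'after_conv_blocks': (channels, spatial, spatial),
        'after_adaptive_pool': (channels, pooled_size, pooled_size),
        'gru_input': (seq_len, feature_size),
    }

print(trace_shapes())
```

`feature_size` is defined as the full per-sample element count, so `seq_len` always comes out as 1 for this model.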

## Requirements

Make sure you have installed all required packages:
```bash
pip install torch torchvision seaborn pandas scikit-learn matplotlib numpy tqdm
```

Or install from requirements.txt:
```bash
pip install -r requirements.txt
```

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from tqdm import tqdm
import os
import random
from pathlib import Path

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## 2. Configuration

In [None]:
# Configuration
DATA_DIR = 'data/vgg16_classification'
MODEL_DIR = 'models/gru'
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, 'gru_brain_tumor_classifier.pth')
HISTORY_SAVE_PATH = os.path.join(MODEL_DIR, 'gru_training_history.csv')

# Training hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
NUM_CLASSES = 4

# Class names
CLASS_NAMES = ['NO_TUMOR', 'GLIOMA', 'MENINGIOMA', 'PITUITARY']

# Create models directory if it doesn't exist
os.makedirs(MODEL_DIR, exist_ok=True)

### 2.1 Custom Dataset Class Using CSV Metadata

In [None]:

# Custom dataset class that uses CSV metadata to prevent data leakage
from torch.utils.data import Dataset
from PIL import Image

class FilteredImageFolder(Dataset):
    """
    Custom dataset that loads images based on CSV metadata to prevent data leakage.
    Only loads images whose original_filename is NOT in other splits.
    """
    def __init__(self, metadata_df, split_name, transform=None, base_dir='data/vgg16_classification'):
        """
        Args:
            metadata_df: DataFrame with columns: image_path, full_path, class, split, filename, original_filename (optional)
            split_name: 'train', 'val', or 'test'
            transform: Image transforms
            base_dir: Base directory for images
        """
        self.split_name = split_name
        self.transform = transform
        self.base_dir = base_dir
        
        # Filter by split
        split_df = metadata_df[metadata_df['split'] == split_name].copy()
        
        # For train: use augmented metadata, filter by original_filename not in test/val
        if split_name == 'train':
            # Get original filenames from test and val splits (from original metadata)
            if 'original_filename' in metadata_df.columns:
                # This is augmented metadata
                orig_metadata_path = 'data/dataset_metadata.csv'
                if os.path.exists(orig_metadata_path):
                    orig_df = pd.read_csv(orig_metadata_path)
                    test_originals = set(orig_df[orig_df['split'] == 'test']['filename'].unique())
                    val_originals = set(orig_df[orig_df['split'] == 'val']['filename'].unique())
                    excluded_originals = test_originals.union(val_originals)
                    # Filter: only keep images whose original_filename is NOT in test/val
                    split_df = split_df[~split_df['original_filename'].isin(excluded_originals)]
        
        # For val/test: use original metadata, ensure no overlap with train
        elif split_name in ['val', 'test']:
            # Get train original filenames (from augmented metadata if available)
            aug_metadata_path = 'data/augmented_dataset_metadata.csv'
            if os.path.exists(aug_metadata_path):
                aug_df = pd.read_csv(aug_metadata_path)
                train_originals = set(aug_df[aug_df['split'] == 'train']['original_filename'].unique())
                # Filter: only keep images whose filename is NOT in train
                split_df = split_df[~split_df['filename'].isin(train_originals)]
        
        self.samples = []
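        # Derive classes from the full metadata (all splits), not just this split, so every split
        # gets the same class-to-index mapping even if one split happens to lack a class
        self.classes = sorted(metadata_df['class'].unique().tolist())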
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        
        for _, row in split_df.iterrows():
            # Use full_path if available, otherwise construct from base_dir and image_path
            if pd.notna(row.get('full_path')):
                img_path = row['full_path']
            else:
                img_path = os.path.join(base_dir, row['image_path'])
            
            # Normalize path separators
            img_path = img_path.replace('\\', '/')
            
            if os.path.exists(img_path):
                label = self.class_to_idx[row['class']]
                self.samples.append((img_path, label))
            else:
                print(f"Warning: Image not found: {img_path}")
        
        print(f"Loaded {len(self.samples)} images for {split_name} split (filtered from {len(split_df)} rows in CSV)")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (224, 224), (0, 0, 0))
        
        if self.transform:
            image = self.transform(image)
        
        return image, label
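The train-side rule in `FilteredImageFolder` amounts to a set-membership filter on `original_filename`. Here is a minimal pandas sketch with toy metadata. The rows and the helper `filter_train_rows` are hypothetical. Only the column names (`filename`, `original_filename`, `split`) follow the docstring above.

```python
import pandas as pd

# Toy stand-ins for data/dataset_metadata.csv and data/augmented_dataset_metadata.csv
toy_orig = pd.DataFrame({
    'filename': ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'],
    'split': ['train', 'train', 'val', 'test'],
})
toy_aug = pd.DataFrame({
    'filename': ['a_aug0.jpg', 'a_aug1.jpg', 'b_aug0.jpg', 'c_aug0.jpg'],
    'original_filename': ['a.jpg', 'a.jpg', 'b.jpg', 'c.jpg'],
    'split': ['train', 'train', 'train', 'train'],
})

def filter_train_rows(aug_df, orig_df):
    """Hypothetical helper: drop augmented train rows whose source image is in val or test."""
    held_out = set(orig_df.loc[orig_df['split'].isin(['val', 'test']), 'filename'])
    train_rows = aug_df[aug_df['split'] == 'train']
    return train_rows[~train_rows['original_filename'].isin(held_out)]

clean_train = filter_train_rows(toy_aug, toy_orig)
print(clean_train['filename'].tolist())  # c_aug0.jpg is dropped because c.jpg is a val image
```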


## 3. Data Loading and Preprocessing

In [None]:
# Data transforms
# Training: resize + normalization only. Augmentation was applied offline by augment_training_data.py,
# so the training split (train_augmented) already contains the pre-augmented images.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test: resize + normalization only (no augmentation)
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets using CSV metadata to prevent data leakage
import json
import time

# Debug log (JSON lines) for leakage checks and test metrics, stored next to the model outputs
LOG_PATH = os.path.join(MODEL_DIR, 'debug.log')

def log_debug(location, message, data, hypothesis_id='A'):
    """Append one JSON record to LOG_PATH; logging failures are printed, never raised."""
    try:
        os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
        with open(LOG_PATH, 'a') as f:
            log_entry = {
                'timestamp': int(time.time() * 1000),
                'location': location,
                'message': message,
                'data': data,
                'hypothesisId': hypothesis_id
            }
            # default= converts NumPy scalars/arrays (e.g. np.int64 labels) that json cannot serialize
            f.write(json.dumps(log_entry, default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o)) + '\n')
    except Exception as e:
        print(f"Logging error: {e}")

print("Loading datasets using CSV metadata files to prevent data leakage...")

# Load metadata CSV files
augmented_metadata_path = 'data/augmented_dataset_metadata.csv'
original_metadata_path = 'data/dataset_metadata.csv'

if os.path.exists(augmented_metadata_path):
    aug_metadata_df = pd.read_csv(augmented_metadata_path)
    print(f"Loaded augmented metadata: {len(aug_metadata_df)} rows")
else:
    print(f"Warning: {augmented_metadata_path} not found. Creating empty DataFrame.")
    aug_metadata_df = pd.DataFrame()

if os.path.exists(original_metadata_path):
    orig_metadata_df = pd.read_csv(original_metadata_path)
    print(f"Loaded original metadata: {len(orig_metadata_df)} rows")
else:
    print(f"Warning: {original_metadata_path} not found. Creating empty DataFrame.")
    orig_metadata_df = pd.DataFrame()

# Use FilteredImageFolder for train (from augmented metadata)
if len(aug_metadata_df) > 0:
    train_dataset = FilteredImageFolder(aug_metadata_df, 'train', transform=train_transform, base_dir=DATA_DIR)
else:
    print("Falling back to ImageFolder for train (no augmented metadata found)")
    train_dir = os.path.join(DATA_DIR, 'train_augmented')
    if not os.path.exists(train_dir):
        train_dir = os.path.join(DATA_DIR, 'train')
    train_dataset = ImageFolder(root=train_dir, transform=train_transform)

# Use FilteredImageFolder for val and test (from original metadata)
if len(orig_metadata_df) > 0:
    val_dataset = FilteredImageFolder(orig_metadata_df, 'val', transform=val_test_transform, base_dir=DATA_DIR)
    test_dataset = FilteredImageFolder(orig_metadata_df, 'test', transform=val_test_transform, base_dir=DATA_DIR)
else:
    print("Falling back to ImageFolder for val/test (no original metadata found)")
    val_dataset = ImageFolder(root=os.path.join(DATA_DIR, 'val'), transform=val_test_transform)
    test_dataset = ImageFolder(root=os.path.join(DATA_DIR, 'test'), transform=val_test_transform)

# Verify no data leakage after filtering
def extract_original_filenames_from_dataset(dataset, split_name):
    """Recover original filenames from dataset samples ('<stem>_aug<...>.<ext>' -> '<stem>.<ext>')."""
    original_filenames = set()
    for path, _ in dataset.samples:
        filename = os.path.basename(path)
        if '_aug' in filename:
            # Strip the augmentation suffix but keep the file extension
            filename = filename.split('_aug')[0] + os.path.splitext(filename)[1]
        original_filenames.add(filename)
    return original_filenames

train_originals = extract_original_filenames_from_dataset(train_dataset, 'train')
val_originals = extract_original_filenames_from_dataset(val_dataset, 'val')
test_originals = extract_original_filenames_from_dataset(test_dataset, 'test')

train_test_overlap = train_originals.intersection(test_originals)
train_val_overlap = train_originals.intersection(val_originals)
val_test_overlap = val_originals.intersection(test_originals)

log_debug('train_gru.ipynb:7', 'Post-fix data leakage check', {
    'train_unique_originals': len(train_originals),
    'val_unique_originals': len(val_originals),
    'test_unique_originals': len(test_originals),
    'train_test_overlap_count': len(train_test_overlap),
    'train_val_overlap_count': len(train_val_overlap),
    'val_test_overlap_count': len(val_test_overlap),
    'has_data_leakage': bool(train_test_overlap or train_val_overlap or val_test_overlap),
    'using_csv_metadata': isinstance(train_dataset, FilteredImageFolder) and isinstance(test_dataset, FilteredImageFolder),
    'train_samples': len(train_dataset),
    'val_samples': len(val_dataset),
    'test_samples': len(test_dataset)
}, 'B')

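# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f'\nTrain samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Number of classes: {len(train_dataset.classes)}')
print(f'Class names: {train_dataset.classes}')

# Label indices follow each dataset's sorted class list, not the order in the configuration cell.
# Align CLASS_NAMES with the datasets so plots and reports map indices to the right names.
assert train_dataset.classes == val_dataset.classes == test_dataset.classes, 'Class order differs between splits'
CLASS_NAMES = list(train_dataset.classes)

# Report filename overlap for every pair of splits, not only train/test
overlaps = {'train/test': train_test_overlap, 'train/val': train_val_overlap, 'val/test': val_test_overlap}
leaks = {pair: len(files) for pair, files in overlaps.items() if files}
if leaks:
    print(f'\nWARNING: Overlapping original filenames found: {leaks}')
else:
    print('\nSUCCESS: No original filename is shared between train, val, and test.')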
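The leakage check compares sets of recovered original file names, and the same idea extends to every pair of splits. The sketch below uses a hypothetical `overlap_report` helper and toy file names. It only detects reuse of the same file name, so near-duplicate scans saved under different names would not be caught.

```python
from itertools import combinations

def overlap_report(split_files):
    """Hypothetical helper: list file names shared by each pair of splits (empty dict = no overlap)."""
    report = {}
    for (name_a, files_a), (name_b, files_b) in combinations(split_files.items(), 2):
        shared = files_a & files_b
        if shared:
            report[f'{name_a}/{name_b}'] = sorted(shared)
    return report

# Toy example: x.jpg appears in both train and test
splits = {
    'train': {'a.jpg', 'b.jpg', 'x.jpg'},
    'val': {'c.jpg'},
    'test': {'d.jpg', 'x.jpg'},
}
print(overlap_report(splits))  # {'train/test': ['x.jpg']}
```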

### 3.1 Class Distribution

In [None]:
# Display class distribution (counts labels from dataset.samples instead of loading every image)
def get_class_distribution(dataset):
    class_counts = {}
    for _, label in dataset.samples:
        class_name = dataset.classes[label]
        class_counts[class_name] = class_counts.get(class_name, 0) + 1
    return class_counts

train_dist = get_class_distribution(train_dataset)
val_dist = get_class_distribution(val_dataset)
test_dist = get_class_distribution(test_dataset)

# Create visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, (dist, title) in enumerate([(train_dist, 'Train'), (val_dist, 'Validation'), (test_dist, 'Test')]):
    classes = list(dist.keys())
    counts = list(dist.values())
    axes[idx].bar(classes, counts, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])
    axes[idx].set_title(f'{title} Set Class Distribution', fontsize=14, fontweight='bold')
    axes[idx].set_xlabel('Class', fontsize=12)
    axes[idx].set_ylabel('Number of Samples', fontsize=12)
    axes[idx].tick_params(axis='x', rotation=45)
    for i, v in enumerate(counts):
        axes[idx].text(i, v, str(v), ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Print summary
print("\nClass Distribution Summary:")
print(f"{'Class':<15} {'Train':<10} {'Val':<10} {'Test':<10}")
print("-" * 50)
for class_name in CLASS_NAMES:
    print(f"{class_name:<15} {train_dist.get(class_name, 0):<10} {val_dist.get(class_name, 0):<10} {test_dist.get(class_name, 0):<10}")
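Raw counts make it hard to compare splits of different sizes. Normalizing each split to class proportions shows whether the splits share the same class balance. This is a sketch with toy counts in the same dict format that `get_class_distribution` returns. The helper name `distribution_table` is hypothetical.

```python
import pandas as pd

def distribution_table(**splits):
    """Hypothetical helper: class-by-split counts plus each class's share of its split."""
    counts = pd.DataFrame(splits).fillna(0).astype(int)
    shares = counts / counts.sum(axis=0)  # divide each column by that split's total
    return counts.join(shares.add_suffix('_share').round(3))

# Toy counts, not the real dataset
table = distribution_table(
    train={'GLIOMA': 300, 'MENINGIOMA': 300, 'NO_TUMOR': 200, 'PITUITARY': 200},
    val={'GLIOMA': 30, 'MENINGIOMA': 30, 'NO_TUMOR': 20, 'PITUITARY': 20},
)
print(table)
```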

## 4. CNN-GRU Model Architecture

In [None]:
class BrainTumorCNNGRU(nn.Module):
    def __init__(self, num_classes=4):
        super(BrainTumorCNNGRU, self).__init__()
        
        # CNN Feature Extractor
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        # Adaptive pooling to get fixed size feature maps
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))
        
        # Flatten spatial dimensions for GRU
        # After pooling: 128 channels * 7 * 7 = 6272 features
        self.feature_size = 128 * 7 * 7
        
        # Two-layer GRU (fewer gates, and therefore fewer parameters, than an equivalent LSTM)
        self.gru = nn.GRU(input_size=self.feature_size, hidden_size=256, 
                         num_layers=2, batch_first=True, dropout=0.3)
        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        # CNN feature extraction
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.adaptive_pool(x)
        
        # Reshape for GRU: (batch, 128, 7, 7) -> (batch, seq_len, features).
        # feature_size equals 128 * 7 * 7, so -1 resolves to 1: the GRU sees a single time step
        batch_size = x.size(0)
        x = x.view(batch_size, -1, self.feature_size)  # (batch, 1, feature_size)
        
        # GRU processing
        gru_out, h_n = self.gru(x)
        # Use the last output
        gru_out = gru_out[:, -1, :]  # (batch, hidden_size)
        
        # Classification
        x = self.fc(gru_out)
        return x

# Initialize model
model = BrainTumorCNNGRU(num_classes=NUM_CLASSES).to(device)

# Print model architecture
print("CNN-GRU Model Architecture:")
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
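`feature_size` equals the full number of elements in the pooled map. As a result, `x.view(batch_size, -1, self.feature_size)` produces a sequence of length 1, and the GRU runs for a single time step. The sketch below contrasts that reshape with one hypothetical way to give the GRU a real sequence, using one time step per feature-map row. The row-wise layout is an illustration, not the notebook's model. It would need retraining before any comparison.

```python
import torch
import torch.nn as nn

batch, channels, height, width = 2, 128, 7, 7
feature_map = torch.randn(batch, channels, height, width)  # stand-in for the pooled CNN output

# Current model: the whole map becomes one 6272-dim vector, i.e. a length-1 sequence
single_step = feature_map.view(batch, -1, channels * height * width)
print(single_step.shape)

# Hypothetical alternative: each of the 7 rows is a time step with 128 * 7 features
row_sequence = feature_map.permute(0, 2, 1, 3).reshape(batch, height, channels * width)
print(row_sequence.shape)

gru = nn.GRU(input_size=channels * width, hidden_size=256, num_layers=2, batch_first=True, dropout=0.3)
gru_out, h_n = gru(row_sequence)
last_step = gru_out[:, -1, :]  # output after the final row, as in the notebook's forward()
print(last_step.shape)
```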

## 5. Training Configuration

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

print("Loss function: CrossEntropyLoss")
print(f"Optimizer: Adam (lr={LEARNING_RATE}, weight_decay=1e-5)")
print("Learning rate scheduler: ReduceLROnPlateau (factor=0.5, patience=5)")
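With `mode='min'`, `factor=0.5` and `patience=5`, the scheduler halves the learning rate once the validation loss has failed to improve for more than five consecutive epochs. The quick sketch below uses a constant (never improving) loss to show when the first reduction happens. The dummy parameter and loss values are illustrative only.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A single dummy parameter is enough to drive the scheduler (names chosen to avoid notebook globals)
demo_param = nn.Parameter(torch.zeros(1))
demo_optimizer = optim.Adam([demo_param], lr=1e-3)
demo_scheduler = optim.lr_scheduler.ReduceLROnPlateau(demo_optimizer, mode='min', factor=0.5, patience=5)

lrs = []
for epoch in range(8):
    demo_scheduler.step(1.0)  # validation loss that never improves after the first epoch
    lrs.append(demo_optimizer.param_groups[0]['lr'])

# Calls 1-6 keep lr=1e-3; on call 7 the loss has failed to improve 6 times (> patience), so lr is halved
print(lrs)
```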

## 6. Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(loader, desc='Training'):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics: weight the batch-mean loss by batch size so a short final batch is not over-counted
        running_loss += loss.item() * labels.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def validate_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Per-sample weighting, as in train_epoch
            running_loss += loss.item() * labels.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc
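Averaging per-batch mean losses gives a smaller final batch the same weight as a full one. Weighting each batch by its size gives the true per-sample mean instead. With a batch size of 32 the gap is usually small, but the arithmetic below (illustrative numbers only) shows how the two averages can diverge.

```python
batch_sizes = [32, 32, 4]            # e.g. the last batch of an epoch is smaller
batch_mean_losses = [1.0, 1.0, 3.0]  # illustrative per-batch mean losses

# Mean of batch means: every batch counts equally, regardless of size
mean_of_means = sum(batch_mean_losses) / len(batch_mean_losses)

# Per-sample mean: weight each batch loss by the number of samples it covers
per_sample_mean = sum(batch_loss * size for batch_loss, size in zip(batch_mean_losses, batch_sizes)) / sum(batch_sizes)

print(f'{mean_of_means:.4f} vs {per_sample_mean:.4f}')  # 1.6667 vs 1.1176
```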

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0.0
patience_counter = 0
early_stopping_patience = 10

print("Starting training...")
print(f"Training for {NUM_EPOCHS} epochs")
print("-" * 60)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'history': history
        }, MODEL_SAVE_PATH)
        print(f"Saved best model (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= early_stopping_patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break

print("\n" + "=" * 60)
print("Training completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
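The early-stopping rule in the loop above resets the counter on a new best validation accuracy and stops after `early_stopping_patience` epochs without one. It can be isolated into a small class for testing. `EarlyStopper` is a hypothetical helper, and the accuracy values are made up.

```python
class EarlyStopper:
    """Hypothetical helper: same rule as the training loop, stop after `patience` epochs without a new best."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = float('-inf')
        self.counter = 0

    def step(self, val_acc):
        """Record one epoch's validation accuracy; return True when training should stop."""
        if val_acc > self.best:
            self.best = val_acc
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

stopper = EarlyStopper(patience=2)
acc_history = [50.0, 60.0, 59.0, 58.0, 61.0]
# 0-based index of the epoch at which training would stop
stopped_at = next(i for i, acc in enumerate(acc_history) if stopper.step(acc))
print(stopped_at)
```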

### 6.1 Save Training History

In [None]:
# Save training history to CSV
history_df = pd.DataFrame(history)
history_df.to_csv(HISTORY_SAVE_PATH, index=False)
print(f"Training history saved to {HISTORY_SAVE_PATH}")
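The CSV has one row per completed epoch, so the epoch of the saved checkpoint can be recovered later without loading the `.pth` file. The checkpoint is only saved on a strict improvement, which matches `idxmax` returning the first maximum. Below is a sketch on a toy history written to a temporary directory. The values are made up.

```python
import os
import tempfile
import pandas as pd

# Toy history with the same columns the notebook writes
demo_history = pd.DataFrame({
    'train_loss': [1.20, 0.80, 0.60, 0.50],
    'train_acc': [55.0, 70.0, 78.0, 82.0],
    'val_loss': [1.10, 0.85, 0.90, 0.95],
    'val_acc': [58.0, 68.0, 72.0, 71.0],
})

with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, 'history.csv')
    demo_history.to_csv(csv_path, index=False)
    reloaded = pd.read_csv(csv_path)

# Rows are 0-indexed; the training loop reports epochs starting at 1
best_epoch = int(reloaded['val_acc'].idxmax()) + 1
print(f"Best epoch: {best_epoch} (val_acc={reloaded['val_acc'].max():.2f}%)")
```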

## 7. Training Curves Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curve
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Validation Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy curve
axes[1].plot(history['train_acc'], label='Train Accuracy', linewidth=2)
axes[1].plot(history['val_acc'], label='Validation Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'gru_training_curves.png'), dpi=300, bbox_inches='tight')
plt.show()

## 8. Load Best Model and Evaluate on Test Set

In [None]:
# Load best model (map_location lets a GPU-trained checkpoint load on a CPU-only machine)
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")
print(f"Best validation accuracy: {checkpoint['val_acc']:.2f}%")

# Evaluate on test set
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Define y_true and y_pred for evaluation
y_true = all_labels
y_pred = all_preds

# Calculate and print overall accuracy score
accuracy = accuracy_score(y_true, y_pred)

log_debug('train_gru.ipynb:21', 'Test accuracy calculation', {
    'accuracy': float(accuracy),
    'accuracy_percent': float(accuracy * 100),
    'total_samples': len(y_true),
    'correct_predictions': sum(1 for i in range(len(y_true)) if y_true[i] == y_pred[i]),
    'incorrect_predictions': sum(1 for i in range(len(y_true)) if y_true[i] != y_pred[i]),
    'y_true_sample': y_true[:10],
    'y_pred_sample': y_pred[:10]
}, 'D')

print(f"\nOverall Accuracy Score: {accuracy:.4f} ({accuracy*100:.2f}%)")

## 9. Confusion Matrix

In [None]:
# Generate confusion matrix using y_true and y_pred
cm = confusion_matrix(y_true, y_pred)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - GRU Model', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'gru_confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

# Print confusion matrix as table
print("\nConfusion Matrix:")
print(f"{'':<15}", end='')
for name in CLASS_NAMES:
    print(f"{name:<15}", end='')
print()
for i, name in enumerate(CLASS_NAMES):
    print(f"{name:<15}", end='')
    for j in range(len(CLASS_NAMES)):
        print(f"{cm[i][j]:<15}", end='')
    print()
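Overall accuracy and per-class recall can both be read directly off the confusion matrix. The diagonal holds the correct predictions, and each row sums to that class's support. A toy sketch, cross-checked against `accuracy_score`:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Toy labels for 4 classes (indices, not names)
toy_true = [0, 0, 1, 1, 2, 2, 3, 3]
toy_pred = [0, 1, 1, 1, 2, 0, 3, 3]
toy_cm = confusion_matrix(toy_true, toy_pred)

# Rows are true classes and columns are predictions, so correct predictions sit on the diagonal
accuracy_from_cm = np.trace(toy_cm) / toy_cm.sum()
per_class_recall = np.diag(toy_cm) / toy_cm.sum(axis=1)

print(accuracy_from_cm, accuracy_score(toy_true, toy_pred))  # 0.75 0.75
print(per_class_recall)
```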

## 10. Classification Report

In [None]:
# Calculate and print the full classification report (showing precision, recall, and f1-score for all classes)
print("Full Classification Report:")
print("=" * 70)
print(classification_report(y_true, y_pred, target_names=CLASS_NAMES))
print("=" * 70)

# Also generate report as dictionary for detailed access
report = classification_report(y_true, y_pred, 
                                target_names=CLASS_NAMES, 
                                output_dict=True)

print("Classification Report:")
print("=" * 70)
print(f"{'Class':<15} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print("-" * 70)

for class_name in CLASS_NAMES:
    metrics = report[class_name]
    print(f"{class_name:<15} {metrics['precision']:<12.4f} {metrics['recall']:<12.4f} "
          f"{metrics['f1-score']:<12.4f} {int(metrics['support']):<10}")

print("-" * 70)
print(f"{'Accuracy':<15} {'':<12} {'':<12} {report['accuracy']:<12.4f} {len(y_true):<10}")
print(f"{'Macro Avg':<15} {report['macro avg']['precision']:<12.4f} "
      f"{report['macro avg']['recall']:<12.4f} {report['macro avg']['f1-score']:<12.4f} "
      f"{int(report['macro avg']['support']):<10}")
print(f"{'Weighted Avg':<15} {report['weighted avg']['precision']:<12.4f} "
      f"{report['weighted avg']['recall']:<12.4f} {report['weighted avg']['f1-score']:<12.4f} "
      f"{int(report['weighted avg']['support']):<10}")
print("=" * 70)

# Save report to file
report_path = os.path.join(MODEL_DIR, 'gru_classification_report.txt')
with open(report_path, 'w') as f:
    f.write(classification_report(y_true, y_pred, target_names=CLASS_NAMES))
print(f"\nClassification report saved to {report_path}")
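The macro average treats every class equally, while the weighted average weights each class by its support. The two therefore diverge when the test set is imbalanced. This toy sketch rebuilds both from the per-class F1 scores:

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy labels: class 0 dominates the support
toy_true = [0, 0, 0, 0, 1, 1, 2, 3]
toy_pred = [0, 0, 0, 1, 1, 1, 2, 3]

per_class_f1 = f1_score(toy_true, toy_pred, average=None)
support = np.bincount(toy_true)  # number of true samples per class

macro_f1 = per_class_f1.mean()                          # every class counts equally
weighted_f1 = np.average(per_class_f1, weights=support)  # classes weighted by support

assert np.isclose(macro_f1, f1_score(toy_true, toy_pred, average='macro'))
assert np.isclose(weighted_f1, f1_score(toy_true, toy_pred, average='weighted'))
print(f'macro F1 = {macro_f1:.3f}, weighted F1 = {weighted_f1:.3f}')
```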

## 11. Sample Predictions Visualization

In [None]:
# Visualize sample predictions
model.eval()
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.ravel()

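# Draw 10 random test samples (test_loader is not shuffled, so its first batch may cover only one class)
sample_idx = random.sample(range(len(test_dataset)), k=min(10, len(test_dataset)))
samples = [test_dataset[i] for i in sample_idx]
images = torch.stack([img for img, _ in samples]).to(device)
labels = torch.tensor([label for _, label in samples])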

with torch.no_grad():
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)

# Denormalize for visualization
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for idx in range(min(10, len(images))):
    img = images[idx].cpu()
    img = img * std + mean  # Denormalize
    img = torch.clamp(img, 0, 1)
    img = img.permute(1, 2, 0).numpy()
    
    true_label = CLASS_NAMES[labels[idx]]
    pred_label = CLASS_NAMES[predicted[idx]]
    confidence = probabilities[idx][predicted[idx]].item() * 100
    
    # Color: green if correct, red if wrong
    color = 'green' if predicted[idx] == labels[idx] else 'red'
    
    axes[idx].imshow(img)
    axes[idx].set_title(f'True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)', 
                        color=color, fontsize=10, fontweight='bold')
    axes[idx].axis('off')

plt.suptitle('Sample Predictions on Test Set', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, 'gru_sample_predictions.png'), dpi=300, bbox_inches='tight')
plt.show()

## 12. Summary

### Model Performance Summary:
- **Test Accuracy**: printed as "Overall Accuracy Score" in Section 8
- **Best Validation Accuracy**: printed at the end of training in Section 6
- **Model saved to**: `models/gru/gru_brain_tumor_classifier.pth`
- **Training history saved to**: `models/gru/gru_training_history.csv`

### Files Generated:
1. Trained model: `models/gru/gru_brain_tumor_classifier.pth`
2. Training history: `models/gru/gru_training_history.csv`
3. Training curves: `models/gru/gru_training_curves.png`
4. Confusion matrix: `models/gru/gru_confusion_matrix.png`
5. Classification report: `models/gru/gru_classification_report.txt`
6. Sample predictions: `models/gru/gru_sample_predictions.png`

All files are saved in the `models/gru/` directory and can be used for reporting and further analysis.

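### GRU vs LSTM:
- **GRU** has three gate blocks per layer versus four for LSTM, so at the same input and hidden size its recurrent layers have 25% fewer parameters and are usually faster per step
- **GRU** keeps only a hidden state, while **LSTM** maintains a separate cell state alongside the hidden state
- **LSTM** is often credited with capturing longer dependencies, but GRU performs comparably on many tasks; in this notebook the GRU sees a sequence of length 1, so long-range memory does not come into play
- Both architectures are available for comparison: `train_rnn_lstm.ipynb` and `train_gru.ipynb`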
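The parameter saving comes from the gate count. Per layer, PyTorch's `nn.GRU` stores three weight blocks (reset, update, new), whereas `nn.LSTM` stores four (input, forget, cell, output). At equal sizes the GRU therefore has exactly 3/4 of the LSTM's recurrent parameters. The sketch below uses the GRU configuration from this notebook. The LSTM line is a hypothetical counterpart with the same sizes and is not necessarily the configuration used in `train_rnn_lstm.ipynb`.

```python
import torch.nn as nn

def count_params(module):
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

# Same sizes as the GRU in this notebook: input 128 * 7 * 7, hidden 256, 2 layers
gru_layer = nn.GRU(input_size=128 * 7 * 7, hidden_size=256, num_layers=2, batch_first=True)
lstm_layer = nn.LSTM(input_size=128 * 7 * 7, hidden_size=256, num_layers=2, batch_first=True)

gru_params = count_params(gru_layer)
lstm_params = count_params(lstm_layer)
print(f'GRU: {gru_params:,}  LSTM: {lstm_params:,}  ratio: {gru_params / lstm_params:.2f}')
```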