# ResNet-18 Training for Brain Tumor Classification

This notebook trains a ResNet-18 model from scratch to classify brain MRI images into 4 categories:
- **NO_TUMOR**: Healthy brain (no tumor detected)
- **GLIOMA**: Glioma tumor type
- **MENINGIOMA**: Meningioma tumor type
- **PITUITARY**: Pituitary tumor type

## Overview

This notebook includes:
1. Data loading and preprocessing (using the same pipeline as the VGG16 preprocessing)
2. A from-scratch ResNet-18 implementation built from BasicBlocks with skip connections
3. Training pipeline with validation
4. Model saving based on best validation accuracy
5. Test-set evaluation with a confusion matrix and classification report

## Requirements

Make sure you have installed all required packages:
```bash
pip install torch torchvision seaborn pandas scikit-learn matplotlib numpy tqdm
```

Or install from requirements.txt:
```bash
pip install -r requirements.txt
```


## 1. Initialization


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from tqdm import tqdm
import os
import random
from pathlib import Path

# Reproducibility helper: seed every random number generator this notebook touches.
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random``, the
            PyTorch CPU generator and every CUDA device generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Pick the compute device: the first visible CUDA GPU if there is one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')


## 2. Configuration


In [None]:
# Paths: dataset root and where the best checkpoint is written
DATA_DIR = 'data/vgg16_classification'
MODEL_SAVE_PATH = 'resnet_model.pth'

# Training hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
NUM_CLASSES = 4

# Human-readable class names.
# NOTE(review): the datasets assign label indices in sorted order of the class
# names they find, so this list's order may not match label indices -- confirm
# before using it to label model outputs.
CLASS_NAMES = ['NO_TUMOR', 'GLIOMA', 'MENINGIOMA', 'PITUITARY']

# Echo the configuration, one setting per line
print("\n".join([
    "Configuration:",
    f"  Data directory: {DATA_DIR}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Number of epochs: {NUM_EPOCHS}",
    f"  Number of classes: {NUM_CLASSES}",
    f"  Model save path: {MODEL_SAVE_PATH}",
]))


## 3. Custom Dataset Class (from preprocess_vgg16 pipeline)


In [None]:
# Custom dataset class that uses CSV metadata to prevent data leakage
from torch.utils.data import Dataset
from PIL import Image

class FilteredImageFolder(Dataset):
    """
    Image dataset built from CSV metadata, with cross-split leakage filtering.

    Rows are selected by their ``split`` column and then filtered so that a
    source image does not appear in more than one split:

    * ``train``: if ``metadata_df`` has an ``original_filename`` column and
      ``data/dataset_metadata.csv`` exists, rows whose ``original_filename`` is
      a val/test ``filename`` in that file are dropped.
    * ``val`` / ``test``: if ``data/augmented_dataset_metadata.csv`` exists,
      rows whose ``filename`` is a train ``original_filename`` in that file are
      dropped.

    Rows whose image file is missing on disk are skipped with a warning.

    Label indices follow the sorted class names of the *whole* ``metadata_df``
    (or an explicit ``classes`` list), not only the selected split, so every
    split built from the same metadata shares one class-to-index mapping even
    if a split happens to lack a class.
    """

    # Metadata files consulted for leakage filtering. These paths are relative
    # to the working directory and do not depend on ``base_dir``.
    _ORIGINAL_METADATA_PATH = 'data/dataset_metadata.csv'
    _AUGMENTED_METADATA_PATH = 'data/augmented_dataset_metadata.csv'

    def __init__(self, metadata_df, split_name, transform=None,
                 base_dir='data/vgg16_classification', classes=None):
        """
        Args:
            metadata_df: DataFrame with columns image_path, full_path, class,
                split, filename and (optionally) original_filename.
            split_name: 'train', 'val', or 'test'.
            transform: Optional callable applied to each loaded PIL image.
            base_dir: Directory that ``image_path`` is relative to; used when a
                row has no ``full_path``.
            classes: Optional explicit list of class names defining the label
                indices. Defaults to the sorted unique ``class`` values of
                ``metadata_df``.

        Raises:
            ValueError: if ``classes`` is given and the split contains a class
                that is not listed in it.
        """
        self.split_name = split_name
        self.transform = transform
        self.base_dir = base_dir

        split_df = metadata_df[metadata_df['split'] == split_name].copy()
        split_df = self._drop_leaked_rows(metadata_df, split_df, split_name)

        # Build the mapping from all metadata rows, not just this split's rows,
        # so a split that is missing a class keeps the same indices as the others.
        if classes is None:
            classes = sorted(metadata_df['class'].unique().tolist())
        self.classes = list(classes)
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        unknown = set(split_df['class']) - set(self.class_to_idx)
        if unknown:
            raise ValueError(
                f"{split_name} split contains classes {sorted(map(str, unknown))} "
                f"that are not in classes={self.classes}"
            )

        self.samples = []
        for _, row in split_df.iterrows():
            img_path = self._resolve_path(row, base_dir)
            if os.path.exists(img_path):
                self.samples.append((img_path, self.class_to_idx[row['class']]))
            else:
                print(f"Warning: Image not found: {img_path}")

        print(f"Loaded {len(self.samples)} images for {split_name} split (filtered from {len(split_df)} rows in CSV)")

    @classmethod
    def _drop_leaked_rows(cls, metadata_df, split_df, split_name):
        """Return ``split_df`` without rows whose source image belongs to another split."""
        if split_name == 'train':
            # Only augmented metadata records which source image each row came from.
            if ('original_filename' in metadata_df.columns
                    and os.path.exists(cls._ORIGINAL_METADATA_PATH)):
                orig_df = pd.read_csv(cls._ORIGINAL_METADATA_PATH)
                held_out = set(
                    orig_df.loc[orig_df['split'].isin(['val', 'test']), 'filename'].unique()
                )
                split_df = split_df[~split_df['original_filename'].isin(held_out)]
        elif split_name in ('val', 'test'):
            if os.path.exists(cls._AUGMENTED_METADATA_PATH):
                aug_df = pd.read_csv(cls._AUGMENTED_METADATA_PATH)
                train_originals = set(
                    aug_df.loc[aug_df['split'] == 'train', 'original_filename'].unique()
                )
                split_df = split_df[~split_df['filename'].isin(train_originals)]
        return split_df

    @staticmethod
    def _resolve_path(row, base_dir):
        """Return the row's image path (``full_path`` if set, else base_dir/image_path) using '/' separators."""
        if pd.notna(row.get('full_path')):
            img_path = row['full_path']
        else:
            img_path = os.path.join(base_dir, row['image_path'])
        return img_path.replace('\\', '/')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set."""
        img_path, label = self.samples[idx]

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            # Best-effort: an unreadable file is reported and replaced by a
            # black 224x224 image so a single bad file does not abort an epoch.
            print(f"Error loading {img_path}: {e}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        return image, label


## 4. Data Loading and Preprocessing


In [None]:
# Data transforms (same as train_cnn.ipynb).
# Training images are pre-augmented offline by augment_training_data.py (the
# train_augmented directory), so the training pipeline only resizes and
# normalizes with the ImageNet channel statistics.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test: resize + normalization only (no augmentation)
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets using CSV metadata to prevent data leakage
print("Loading datasets using CSV metadata files to prevent data leakage...")

# Load metadata CSV files
augmented_metadata_path = 'data/augmented_dataset_metadata.csv'
original_metadata_path = 'data/dataset_metadata.csv'

if os.path.exists(augmented_metadata_path):
    aug_metadata_df = pd.read_csv(augmented_metadata_path)
    print(f"Loaded augmented metadata: {len(aug_metadata_df)} rows")
else:
    print(f"Warning: {augmented_metadata_path} not found. Creating empty DataFrame.")
    aug_metadata_df = pd.DataFrame()

if os.path.exists(original_metadata_path):
    orig_metadata_df = pd.read_csv(original_metadata_path)
    print(f"Loaded original metadata: {len(orig_metadata_df)} rows")
else:
    print(f"Warning: {original_metadata_path} not found. Creating empty DataFrame.")
    orig_metadata_df = pd.DataFrame()

# Train split: FilteredImageFolder over the augmented metadata when available.
if len(aug_metadata_df) > 0:
    train_dataset = FilteredImageFolder(aug_metadata_df, 'train', transform=train_transform, base_dir=DATA_DIR)
else:
    print("Falling back to ImageFolder for train (no augmented metadata found)")
    train_dir = os.path.join(DATA_DIR, 'train_augmented')
    if not os.path.exists(train_dir):
        train_dir = os.path.join(DATA_DIR, 'train')
    train_dataset = ImageFolder(root=train_dir, transform=train_transform)

# Val/test splits: FilteredImageFolder over the original metadata when available.
# The ImageFolder fallback performs no cross-split leakage filtering.
if len(orig_metadata_df) > 0:
    val_dataset = FilteredImageFolder(orig_metadata_df, 'val', transform=val_test_transform, base_dir=DATA_DIR)
    test_dataset = FilteredImageFolder(orig_metadata_df, 'test', transform=val_test_transform, base_dir=DATA_DIR)
else:
    print("Falling back to ImageFolder for val/test (no original metadata found)")
    val_dataset = ImageFolder(root=os.path.join(DATA_DIR, 'val'), transform=val_test_transform)
    test_dataset = ImageFolder(root=os.path.join(DATA_DIR, 'test'), transform=val_test_transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f'\nTrain samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Number of classes: {len(train_dataset.classes)}')
print(f'Class names: {train_dataset.classes}')

# Each dataset numbers its labels from its own class list. If the lists differ
# between splits, or do not match the model's output size, labels and metrics
# would be silently wrong, so fail loudly here instead.
if not (list(train_dataset.classes) == list(val_dataset.classes) == list(test_dataset.classes)):
    raise ValueError(
        "Class lists differ between splits, so label indices would not line up: "
        f"train={train_dataset.classes}, val={val_dataset.classes}, test={test_dataset.classes}"
    )
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Found {len(train_dataset.classes)} classes in the training data "
        f"but NUM_CLASSES={NUM_CLASSES}"
    )


## 5. ResNet-18 Architecture


In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual unit used by ResNet-18 and ResNet-34.

    Main path: 3x3 conv (given stride) -> BN -> ReLU -> 3x3 conv -> BN.
    The input, passed through ``downsample`` when one is supplied, is added to
    the main path and a final ReLU is applied.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module that reshapes the input for the shortcut;
            ``None`` means the input is added unchanged.
    """
    # Output width multiplier relative to ``out_channels`` (1 for basic blocks).
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # First 3x3 convolution carries the block's stride.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Second 3x3 convolution keeps resolution and width.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut projection used when the main path changes shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class ResNet18(nn.Module):
    """ResNet-18 built from scratch.

    Layout:
    - Stem: 7x7 stride-2 convolution, BN, ReLU, 3x3 stride-2 max pooling
    - Four residual stages (``layer1``..``layer4``), two BasicBlocks each,
      with widths 64, 128, 256, 512; stages 2-4 start with stride 2
    - Global average pooling followed by a linear classifier

    Convolutions use He/Kaiming (fan_out, ReLU) initialization, which suits
    ReLU networks trained from scratch.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: (in_channels, out_channels, blocks, first-block stride)
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Stack ``blocks`` BasicBlocks; only the first may stride or change width."""
        widened = out_channels * BasicBlock.expansion
        needs_projection = stride != 1 or in_channels != widened

        # 1x1 conv + BN shortcut so the first block's identity matches its output.
        shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )
            if needs_projection
            else None
        )

        stage = [BasicBlock(in_channels, out_channels, stride, shortcut)]
        stage.extend(BasicBlock(widened, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm affine, small-normal linear weights."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        # Stem
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Head: pool to 1x1, flatten, project to class logits
        return self.fc(torch.flatten(self.avgpool(x), 1))

# Build the network with one output logit per class and place it on the active device
model = ResNet18(num_classes=NUM_CLASSES).to(device)

# Show the module tree
print("ResNet-18 Model Architecture:")
print(model)

# Parameter counts: everything, and the subset that receives gradients
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(f"\nTotal parameters: {total_params:,}\nTrainable parameters: {trainable_params:,}")


## 6. Training Configuration


In [None]:
# Objective: multi-class cross-entropy on raw logits
criterion = nn.CrossEntropyLoss()

# Adam with light weight decay (same as train_cnn.ipynb)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)

# Halve the learning rate after 5 epochs without validation-loss improvement
# (same as train_cnn.ipynb)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

for summary_line in (
    "Loss function: CrossEntropyLoss",
    f"Optimizer: Adam (lr={LEARNING_RATE}, weight_decay=1e-5)",
    "Learning rate scheduler: ReduceLROnPlateau (mode='min', factor=0.5, patience=5)",
):
    print(summary_line)


In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    The returned loss is a per-sample average: each batch's loss is weighted
    by its batch size, so a smaller final batch is not over-weighted. This
    assumes ``criterion`` averages over the batch (``reduction='mean'``, as
    the ``CrossEntropyLoss()`` used in this notebook does).

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``; accuracy is a percentage.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc='Training'):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics (weight the batch-mean loss by the batch size)
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        predicted = outputs.detach().argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError('train_epoch: loader yielded no samples')

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Runs in eval mode under ``torch.no_grad()``. The returned loss is a
    per-sample average: each batch's loss is weighted by its batch size, so
    a smaller final batch is not over-weighted. This assumes ``criterion``
    averages over the batch (``reduction='mean'``, as the
    ``CrossEntropyLoss()`` used in this notebook does).

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``; accuracy is a percentage.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('validate_epoch: loader yielded no samples')

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


## 7. Training Loop


In [None]:
# Training history: one entry per epoch, used for the curves below and stored
# in every checkpoint.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0.0
# 1-based epoch of the most recent checkpoint; None until the first save.
best_epoch = None

print("Starting training...")
print(f"Training for {NUM_EPOCHS} epochs")
print("-" * 60)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

    # Update learning rate (ReduceLROnPlateau uses validation loss)
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Save the best model by validation accuracy. The first epoch always saves,
    # so the evaluation cell has a checkpoint to load even if validation
    # accuracy never rises above 0%.
    if best_epoch is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save({
            'epoch': best_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'history': history
        }, MODEL_SAVE_PATH)
        print(f"Saved best model (Val Acc: {val_acc:.2f}%)")

print("\n" + "=" * 60)
print("Training completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
if best_epoch is None:
    print("No checkpoint was saved (no epochs were run).")
else:
    print(f"Model saved to: {MODEL_SAVE_PATH} (epoch {best_epoch})")


## 8. Training Curves Visualization


In [None]:
# Plot training curves. Epochs are numbered from 1 to match the "Epoch k/N"
# training log and the epoch number stored in the checkpoint.
epochs = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curve
axes[0].plot(epochs, history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(epochs, history['val_loss'], label='Validation Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy curve
axes[1].plot(epochs, history['train_acc'], label='Train Accuracy', linewidth=2)
axes[1].plot(epochs, history['val_acc'], label='Validation Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('resnet_training_curves.png', dpi=300, bbox_inches='tight')
plt.show()


## 9. Load Best Model and Evaluate on Test Set


In [None]:
# Load the best checkpoint written during training. map_location lets a
# checkpoint saved from a GPU be restored on whichever device is active now.
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")
print(f"Best validation accuracy: {checkpoint['val_acc']:.2f}%")

# Collect predictions on the test set
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        # Only the images need the model's device; labels are consumed on the
        # CPU by sklearn below.
        outputs = model(images.to(device))
        predicted = outputs.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Calculate and print overall accuracy score
accuracy = accuracy_score(all_labels, all_preds)
print(f"\nTest Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")


## 10. Confusion Matrix


In [None]:
# Label indices were assigned by the dataset from its own class list (sorted
# class names), so take the axis labels from the dataset rather than
# CLASS_NAMES, whose order is not guaranteed to match.
class_names = list(test_dataset.classes)

# labels= keeps the matrix n_classes x n_classes (matching the tick labels)
# even if some class never occurs in the true labels or predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - ResNet-18 Model', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('resnet_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()


## 11. Classification Report


In [None]:
# Name each class using the dataset's own index-to-name mapping (CLASS_NAMES may be
# ordered differently). labels= keeps the report aligned with those names and
# avoids a length-mismatch error when a class is absent from the test
# labels or predictions.
report_class_names = list(test_dataset.classes)

print("Classification Report:")
print("=" * 70)
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(report_class_names))),
    target_names=report_class_names,
))
print("=" * 70)


## 12. Summary

### Model Performance Summary:
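- **Test Accuracy**: printed in Section 9 (the "Test Accuracy" line)
- **Best Validation Accuracy**: printed at the end of training and again when the best checkpoint is reloaded in Section 9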
- **Model saved to**: `resnet_model.pth`

### Architecture Features:
- Full ResNet-18 implementation from scratch
- BasicBlock with skip connections (residual connections)
- Batch Normalization after all convolutional layers
- He/Kaiming initialization for optimal training from scratch
- Adam optimizer (lr=0.001, weight_decay=1e-5) with a ReduceLROnPlateau scheduler on validation loss (factor=0.5, patience=5)

### Files Generated:
1. Trained model: `resnet_model.pth`
2. Training curves: `resnet_training_curves.png`
3. Confusion matrix: `resnet_confusion_matrix.png`

All files are saved in the current directory.
