# Phase 3: Transfer Learning with ResNet50

**Project:** AI-Powered Pneumonia Detection from Chest X-Rays  
**Author:** Georgios Kitsakis  
**Date:** 2025-10-29

## Objectives
1. Use a pre-trained ResNet50 model
2. Fine-tune it for pneumonia detection
3. Target 90-95% test accuracy
4. Generate detailed evaluation metrics
5. Create a confusion matrix and classification report

## Why Transfer Learning?
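- **Pre-trained on ImageNet**: The convolutional layers already encode general-purpose visual features (edges, textures, shapes) that can be reused for X-ray images
- **Faster Training**: Only the last residual stage (`layer4`) and a new classifier head are trained; all earlier layers stay frozen
- **Better Accuracy**: A pre-trained backbone typically outperforms a small CNN trained from scratch on a dataset of this size
- **Industry Standard**: ResNet-style backbones are a common starting point for image classification in production systems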

## 0. Setup (Google Colab Compatible)

In [None]:
# Check if running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running in Google Colab (a GPU is used only if a GPU runtime is selected)")
    
    # Mount Google Drive (optional - to save models)
    from google.colab import drive
    drive.mount('/content/drive')
    
    # You'll need to upload your dataset to Colab or Google Drive
    # For now, we'll assume dataset is uploaded to /content/data/
    BASE_DIR = '/content/data'
    MODELS_DIR = '/content/drive/MyDrive/pneumonia_models'  # Save to Drive
    REPORTS_DIR = '/content/drive/MyDrive/pneumonia_reports'
    
except ImportError:
    IN_COLAB = False
    print("Running locally")
    BASE_DIR = '../data'
    MODELS_DIR = '../models'
    REPORTS_DIR = '../reports'

import os
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(REPORTS_DIR, exist_ok=True)

## 1. Import Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score

from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"PyTorch Version: {torch.__version__}")

# Configure matplotlib
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## 2. Configuration

In [None]:
# Dataset paths
TRAIN_DIR = os.path.join(BASE_DIR, 'train')
TEST_DIR = os.path.join(BASE_DIR, 'test')
VAL_DIR = os.path.join(BASE_DIR, 'val')

# Hyperparameters
IMG_SIZE = 224  # Input resolution the ImageNet-pretrained ResNet50 weights were trained on
BATCH_SIZE = 32
NUM_WORKERS = 2 if torch.cuda.is_available() else 0  # Use workers on GPU
NUM_EPOCHS = 10  # Fewer epochs needed with transfer learning
LEARNING_RATE = 0.0001

print(f"\nTraining Configuration:")
print(f"  Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Workers: {NUM_WORKERS}")

## 3. Data Loading
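
To make the normalization step in the cell below concrete: after `ToTensor()` scales intensities to [0, 1] and the grayscale channel is replicated three times, each copy is shifted and scaled by a different ImageNet mean and standard deviation, so the three channels are no longer identical. The following is a minimal pure-Python sketch of that arithmetic; the helper name `normalize_pixel` and the sample intensity 0.5 are illustrative assumptions, not part of the notebook.

```python
# ImageNet channel statistics, the same values passed to transforms.Normalize
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(value, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std per channel to one replicated grayscale pixel.

    `value` is an intensity already scaled to [0, 1].
    """
    return [(value - m) / s for m, s in zip(mean, std)]


pixel = 0.5  # hypothetical mid-gray intensity
normalized = normalize_pixel(pixel)
print([round(c, 3) for c in normalized])
```

Because the three means differ, one gray value maps to three different normalized values, which is expected and harmless for the pre-trained filters.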

In [None]:
# Transforms - ResNet expects 3-channel images and ImageNet normalization
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3 channels
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])

val_test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Custom Dataset
class ChestXRayDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['NORMAL', 'PNEUMONIA']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        
        self.images = []
        self.labels = []
        
        for class_name in self.classes:
            class_path = os.path.join(root_dir, class_name)
            if not os.path.exists(class_path):
                continue
                
            for img_name in os.listdir(class_path):
                if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):  # case-insensitive extension match
                    self.images.append(os.path.join(class_path, img_name))
                    self.labels.append(self.class_to_idx[class_name])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert('L')  # Load as grayscale
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Create datasets
print("\nLoading datasets...")
train_dataset = ChestXRayDataset(TRAIN_DIR, transform=train_transform)
val_dataset = ChestXRayDataset(VAL_DIR, transform=val_test_transform)
test_dataset = ChestXRayDataset(TEST_DIR, transform=val_test_transform)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())

print(f"\nDataset sizes:")
print(f"  Train: {len(train_dataset)} images ({len(train_loader)} batches)")
print(f"  Val: {len(val_dataset)} images ({len(val_loader)} batches)")
print(f"  Test: {len(test_dataset)} images ({len(test_loader)} batches)")

## 4. Build Transfer Learning Model

We'll use **ResNet50** pre-trained on ImageNet and fine-tune it for our task.
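
The freezing pattern used in the cell below can be tried on a much smaller network first. This is a minimal sketch, assuming a two-layer stand-in model (it is not ResNet50, and the module names `backbone` and `head` are hypothetical): parameters are frozen by name, and the trainable count is then checked.

```python
import torch.nn as nn

# Tiny stand-in for a pre-trained network (NOT ResNet50)
model_sketch = nn.Sequential()
model_sketch.add_module("backbone", nn.Linear(8, 4))  # plays the role of the frozen layers
model_sketch.add_module("head", nn.Linear(4, 2))      # plays the role of the new classifier

# Freeze every parameter that does not belong to the head
for name, param in model_sketch.named_parameters():
    if not name.startswith("head"):
        param.requires_grad = False

total = sum(p.numel() for p in model_sketch.parameters())
trainable = sum(p.numel() for p in model_sketch.parameters() if p.requires_grad)
print(f"total={total}, trainable={trainable}")
```

Only parameters with `requires_grad=True` receive gradients, which is why the optimizer later in the notebook is built from the filtered parameter list.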

In [None]:
# Load pre-trained ResNet50
print("Loading pre-trained ResNet50...")
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze everything except the last residual stage (layer4); the original fc is
# replaced below, and the new head's parameters are trainable by default
print("Freezing early layers...")
for name, param in model.named_parameters():
    if 'layer4' not in name and 'fc' not in name:
        param.requires_grad = False

# Replace final fully connected layer for binary classification
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 2)  # 2 classes: NORMAL, PNEUMONIA
)

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel: ResNet50 (Transfer Learning)")
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")
print(f"Frozen Parameters: {total_params - trainable_params:,}")
print(f"\nTraining only: layer4 and custom FC layers")

## 5. Training Setup
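
The scheduler configured in the cell below halves the learning rate once the validation loss has failed to improve for more than `patience` consecutive epochs. The following pure-Python sketch imitates that rule under simplifying assumptions: it ignores the scheduler's improvement threshold, cooldown, and minimum learning rate, and the function name and loss values are hypothetical.

```python
def simulate_plateau_schedule(val_losses, lr=1e-4, factor=0.5, patience=2):
    """Return the learning rate in effect after each epoch (simplified rule)."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only when the number of bad epochs exceeds the patience
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


# Hypothetical validation losses: improvement stalls after epoch 2
lrs = simulate_plateau_schedule([0.50, 0.40, 0.41, 0.42, 0.43, 0.39])
print(lrs)
```

With `patience=2`, the reduction happens on the third consecutive epoch without improvement, not the second.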

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

print("Training Setup:")
print(f"  Loss: Cross Entropy")
print(f"  Optimizer: Adam (lr={LEARNING_RATE})")
print(f"  Scheduler: ReduceLROnPlateau (factor=0.5, patience=2)")

## 6. Training Functions
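
Both functions in the cell below accumulate `loss.item() * images.size(0)` and divide by the dataset size. Since `nn.CrossEntropyLoss` returns the mean loss of a batch by default, this yields a per-sample average that stays correct when the last batch is smaller. A small sketch with made-up batch losses and sizes (the helper name `epoch_mean_loss` is an assumption for illustration):

```python
def epoch_mean_loss(batch_mean_losses, batch_sizes):
    """Per-sample mean loss over an epoch from per-batch mean losses."""
    total_loss = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
    return total_loss / sum(batch_sizes)


# Hypothetical epoch: two full batches and one short final batch
batch_mean_losses = [0.5, 0.3, 0.9]
batch_sizes = [32, 32, 8]

weighted = epoch_mean_loss(batch_mean_losses, batch_sizes)
naive = sum(batch_mean_losses) / len(batch_mean_losses)
print(f"weighted={weighted:.4f}, naive={naive:.4f}")
```

The naive average of batch means over-weights the short final batch, which is the bias the weighting avoids.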

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(dataloader, desc='Training', leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.2f}%'})
    
    return running_loss / len(dataloader.dataset), 100 * correct / total

def validate_epoch(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation', leave=False)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*correct/total:.2f}%'})
    
    return running_loss / len(dataloader.dataset), 100 * correct / total

print("Training functions defined.")

## 7. Train the Model

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0.0
best_model_path = os.path.join(MODELS_DIR, 'resnet50_best.pth')

print("\n" + "="*70)
print("Starting Training - ResNet50 Transfer Learning")
print("="*70)

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}]")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Print results
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss
        }, best_model_path)
        print(f"  ✓ Best model saved! (Val Acc: {val_acc:.2f}%)")

elapsed_time = time.time() - start_time

print("\n" + "="*70)
print("Training Complete!")
print("="*70)
print(f"Total Time: {elapsed_time/60:.2f} minutes")
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

## 8. Visualize Training

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

epochs_range = range(1, len(history['train_loss']) + 1)  # matches the epochs actually recorded

# Loss
ax1.plot(epochs_range, history['train_loss'], 'b-', label='Training Loss', linewidth=2, marker='o')
ax1.plot(epochs_range, history['val_loss'], 'r-', label='Validation Loss', linewidth=2, marker='s')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training and Validation Loss (ResNet50)', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Accuracy
ax2.plot(epochs_range, history['train_acc'], 'b-', label='Training Accuracy', linewidth=2, marker='o')
ax2.plot(epochs_range, history['val_acc'], 'r-', label='Validation Accuracy', linewidth=2, marker='s')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('Training and Validation Accuracy (ResNet50)', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(REPORTS_DIR, 'training_curves_resnet50.png'), dpi=300, bbox_inches='tight')
plt.show()

## 9. Detailed Test Evaluation
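
The cell below reports precision, recall, and F1 with `average='weighted'`, meaning each class's score is weighted by its number of true samples. As a rough illustration of that averaging for precision, here is a pure-Python sketch on a five-sample toy example; the function name and labels are hypothetical, and a class with no predictions is scored as 0.

```python
def weighted_precision(y_true, y_pred, classes=(0, 1)):
    """Support-weighted average of per-class precision."""
    score = 0.0
    for c in classes:
        predicted_c = sum(1 for p in y_pred if p == c)
        correct_c = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        support_c = sum(1 for t in y_true if t == c)
        precision_c = correct_c / predicted_c if predicted_c else 0.0
        score += support_c * precision_c
    return score / len(y_true)


# Toy labels: 0 = NORMAL, 1 = PNEUMONIA
y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
wp = weighted_precision(y_true, y_pred)
print(f"weighted precision = {wp:.3f}")
```

On an imbalanced test set the weighted average leans toward the majority class, so the per-class rows of the classification report remain worth reading.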

In [None]:
# Load best model
checkpoint = torch.load(best_model_path, map_location=device)  # load onto the active device (CPU or GPU)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']} (Val Acc: {checkpoint['val_acc']:.2f}%)")

# Evaluate on test set
model.eval()
all_predictions = []
all_labels = []
all_probs = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
        
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Calculate metrics (precision, recall, and F1 are support-weighted averages over both classes)
test_acc = accuracy_score(all_labels, all_predictions) * 100
precision = precision_score(all_labels, all_predictions, average='weighted') * 100
recall = recall_score(all_labels, all_predictions, average='weighted') * 100
f1 = f1_score(all_labels, all_predictions, average='weighted') * 100

print(f"\n{'='*70}")
print(f"Test Set Evaluation - ResNet50")
print(f"{'='*70}")
print(f"Accuracy:  {test_acc:.2f}%")
print(f"Precision: {precision:.2f}%")
print(f"Recall:    {recall:.2f}%")
print(f"F1-Score:  {f1:.2f}%")
print(f"{'='*70}")

## 10. Confusion Matrix
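
For reference, the matrix produced in the cell below has true labels as rows and predicted labels as columns, so with NORMAL = 0 and PNEUMONIA = 1 the layout is `[[TN, FP], [FN, TP]]`. The sketch below rebuilds that layout and the row percentages in plain Python on toy labels; the helper names are assumptions for illustration only.

```python
def binary_confusion_matrix(y_true, y_pred):
    """Rows are true labels, columns are predicted labels."""
    cm = [[0, 0], [0, 0]]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def row_percentages(cm):
    """Share of each true class assigned to each predicted class."""
    return [[100 * v / sum(row) if sum(row) else 0.0 for v in row] for row in cm]


# Toy labels: 0 = NORMAL, 1 = PNEUMONIA
cm_toy = binary_confusion_matrix([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
pct_toy = row_percentages(cm_toy)
(tn_toy, fp_toy), (fn_toy, tp_toy) = cm_toy
print(cm_toy, pct_toy)
```

The bottom-right percentage is the share of true PNEUMONIA cases that were predicted as PNEUMONIA.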

In [None]:
# Generate confusion matrix
cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])  # always 2x2: rows = true, columns = predicted
class_names = ['NORMAL', 'PNEUMONIA']

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'}, linewidths=1, linecolor='black')
plt.title('Confusion Matrix - ResNet50', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Label', fontsize=13, fontweight='bold')
plt.ylabel('True Label', fontsize=13, fontweight='bold')

# Add percentages
for i in range(len(class_names)):
    for j in range(len(class_names)):
        percentage = cm[i, j] / cm[i].sum() * 100
        plt.text(j + 0.5, i + 0.7, f'({percentage:.1f}%)', 
                ha='center', va='center', fontsize=11, color='red')

plt.tight_layout()
plt.savefig(os.path.join(REPORTS_DIR, 'confusion_matrix_resnet50.png'), dpi=300, bbox_inches='tight')
plt.show()

# Print classification report
print("\nClassification Report:")
print(classification_report(all_labels, all_predictions, target_names=class_names, digits=4))

## 11. Calculate Sensitivity & Specificity
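
The four ratios computed in the cell below can be checked by hand on round numbers. This is a small sketch with invented counts (not results from this model), treating PNEUMONIA as the positive class; the function name `medical_metrics` is hypothetical.

```python
def medical_metrics(tp, tn, fp, fn):
    """Screening metrics as fractions, with PNEUMONIA as the positive class."""
    return {
        "sensitivity": tp / (tp + fn),  # share of pneumonia cases detected
        "specificity": tn / (tn + fp),  # share of normal cases cleared
        "ppv": tp / (tp + fp),          # how often a positive call is right
        "npv": tn / (tn + fn),          # how often a negative call is right
    }


# Invented counts: 100 pneumonia cases and 100 normal cases
metrics = medical_metrics(tp=90, tn=80, fp=20, fn=10)
for name, value in metrics.items():
    print(f"{name}: {value:.1%}")
```

Sensitivity and specificity depend only on the model's behavior within each true class, whereas PPV and NPV also shift with the class balance of the test set.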

In [None]:
# Extract values from confusion matrix
tn, fp, fn, tp = cm.ravel()

# Calculate metrics
sensitivity = tp / (tp + fn) * 100  # True Positive Rate (Recall for PNEUMONIA)
specificity = tn / (tn + fp) * 100  # True Negative Rate
ppv = tp / (tp + fp) * 100  # Positive Predictive Value (Precision for PNEUMONIA)
npv = tn / (tn + fn) * 100  # Negative Predictive Value

print(f"\n{'='*70}")
print(f"Medical Metrics (PNEUMONIA Detection)")
print(f"{'='*70}")
print(f"Sensitivity (Recall):           {sensitivity:.2f}%  [TP / (TP + FN)]")
print(f"Specificity:                    {specificity:.2f}%  [TN / (TN + FP)]")
print(f"Positive Predictive Value:      {ppv:.2f}%  [TP / (TP + FP)]")
print(f"Negative Predictive Value:      {npv:.2f}%  [TN / (TN + FN)]")
print(f"\nConfusion Matrix Values:")
print(f"  True Positives (TP):  {tp}")
print(f"  True Negatives (TN):  {tn}")
print(f"  False Positives (FP): {fp}")
print(f"  False Negatives (FN): {fn}")
print(f"{'='*70}")

## 12. Save Results Summary

In [None]:
# Create results summary
results_summary = f"""
# ResNet50 Transfer Learning Results

**Author:** Georgios Kitsakis  
**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  
**Model:** ResNet50 (Transfer Learning)

## Training Configuration
- Image Size: {IMG_SIZE}x{IMG_SIZE}
- Batch Size: {BATCH_SIZE}
- Epochs: {NUM_EPOCHS}
- Learning Rate: {LEARNING_RATE}
- Device: {device}
- Training Time: {elapsed_time/60:.2f} minutes

## Model Architecture
- Base: ResNet50 (pre-trained on ImageNet)
- Trainable Parameters: {trainable_params:,}
- Total Parameters: {total_params:,}

## Results

### Overall Metrics
- **Test Accuracy:** {test_acc:.2f}%
- **Precision:** {precision:.2f}%
- **Recall:** {recall:.2f}%
- **F1-Score:** {f1:.2f}%

### Medical Metrics
- **Sensitivity (TPR):** {sensitivity:.2f}%
- **Specificity (TNR):** {specificity:.2f}%
- **Positive Predictive Value:** {ppv:.2f}%
- **Negative Predictive Value:** {npv:.2f}%

### Confusion Matrix
```
                Predicted
              NORMAL  PNEUMONIA
Actual NORMAL    {tn}      {fp}
       PNEUMONIA {fn}      {tp}
```

## Key Findings
- Test accuracy: {test_acc:.2f}% (compare against the baseline CNN to quantify the gain from transfer learning)
- Pneumonia sensitivity: {sensitivity:.2f}%; specificity: {specificity:.2f}%
- High sensitivity is crucial for medical applications because it minimizes false negatives (missed pneumonia cases)

## Next Steps
- Implement Grad-CAM for explainability
- Deploy in Streamlit web application
- Consider ensemble methods for further improvement
"""

# Save to file
results_path = os.path.join(REPORTS_DIR, 'results_resnet50.md')
with open(results_path, 'w', encoding='utf-8') as f:
    f.write(results_summary)

print(f"\nResults summary saved to: {results_path}")
print("\n" + "="*70)
print("PHASE 3 COMPLETE: Transfer Learning with ResNet50")
print("="*70)