# Model Training for Pneumonia Detection

This notebook trains a deep learning model for pneumonia detection on chest X-ray images, using transfer learning with a ResNet50 or EfficientNet-B0 backbone (selected via `MODEL_NAME` in the configuration cell).

In [None]:
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

sys.path.append('../')

from src.dataset import get_dataloaders
from src.model import get_model
from src.train import PneumoniaTrainer
from src.utils import set_seed, get_device, plot_metrics, plot_confusion_matrix, compute_metrics
import numpy as np
from tqdm import tqdm

# Report the installed PyTorch build and whether a CUDA device is visible.
torch_version = torch.__version__
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch_version}")
print(f"CUDA available: {cuda_available}")

# Seed the random number generators (via the project's set_seed helper) so
# runs are reproducible, then resolve the device used by every later cell.
set_seed(42)
device = get_device()

## Load Data

In [None]:
# Experiment configuration; these constants are read by later cells.
DATA_DIR = '../data/chest_xray'
MODEL_NAME = 'resnet50'  # alternatively 'efficientnet_b0'
BATCH_SIZE = 32
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
NUM_WORKERS = 4
IMAGE_SIZE = 224

# Build the train / validation / test loaders from DATA_DIR.
print("Loading datasets...")
train_loader, val_loader, test_loader = get_dataloaders(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    image_size=IMAGE_SIZE,
)

# Report how many batches each split yields.
for split_name, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{split_name} batches: {len(split_loader)}")

## Initialize Model and Trainer

In [None]:
# Build the two-class classifier for the configured backbone and move it to
# the selected device (nn.Module.to returns the module itself).
print(f"Creating {MODEL_NAME} model...")
model = get_model(MODEL_NAME, num_classes=2).to(device)

# Parameter counts: all parameters, and only those with requires_grad set.
all_params = list(model.parameters())
total_params = sum(p.numel() for p in all_params)
trainable_params = sum(p.numel() for p in all_params if p.requires_grad)

print(f"\nModel: {MODEL_NAME}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Wrap the model in the project trainer using the configured learning rate.
trainer = PneumoniaTrainer(model=model, device=device, learning_rate=LEARNING_RATE)

print("\nTrainer initialized!")

## Train the Model

Training time depends on your hardware; on a GPU, expect roughly 10–15 minutes per epoch.

In [None]:
# Directory handed to the trainer for checkpoints. The evaluation cell later
# loads '../checkpoints/best_model.pt', so keep this path in sync with it.
CHECKPOINT_DIR = '../checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # no error if it already exists

# Train. `patience` is forwarded to PneumoniaTrainer.train; presumably it
# controls early stopping on validation performance (confirm in src.train).
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    checkpoint_dir=CHECKPOINT_DIR,
    patience=5,
)

# The original literal was mojibake ("âœ…" = UTF-8 bytes of U+2705 mis-decoded
# as cp1252); restore the intended check-mark emoji.
print("\n\u2705 Training completed!")

## Visualize Training History

In [None]:
# Pull the loss/accuracy series recorded by trainer.train() for each split.
train_losses, val_losses = history['train_loss'], history['val_loss']
train_accs, val_accs = history['train_acc'], history['val_acc']

# Render the training curves; save_path is where plot_metrics writes the figure.
plot_metrics(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accs=train_accs,
    val_accs=val_accs,
    save_path='../training_curves.png',
)

## Evaluate on Test Set

In [None]:
# Restore the best weights saved during training into the same model object
# the trainer wraps, mapping tensors onto the current device.
checkpoint_path = '../checkpoints/best_model.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print("Best model loaded!")

# Score the restored model on the held-out test split.
test_metrics = trainer.validate(test_loader)

# Print a fixed-width report of the test metrics.
rule = "=" * 60
print("\n" + rule)
print("TEST SET RESULTS")
print(rule)
for metric_label, metric_key in (
    ("Loss", "loss"),
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1-Score", "f1"),
):
    print(f"Test {metric_label}: {test_metrics[metric_key]:.4f}")
print(rule)