# COVID-19 Chest X-Ray Classification

This notebook implements a deep learning model to classify chest X-ray images as COVID-19 positive or negative.

## Project Overview
- **Objective**: Train a model to classify chest X-ray images as COVID-19 positive or negative
- **Target Accuracy**: >50% (better than random guessing)
- **Dataset**: COVID-19 Radiography Database from Kaggle
- **Framework**: PyTorch with transfer learning using ResNet


## 1. Import Libraries and Setup

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 2. Download and Setup Dataset

In [None]:
# We'll use the COVID-19 Radiography Database from Kaggle
# For this demo, we'll create a simple structure and use a subset of data

# Create directories
os.makedirs('data/COVID', exist_ok=True)
os.makedirs('data/Normal', exist_ok=True)
os.makedirs('models', exist_ok=True)

print("Directory structure created!")
print("Please download the COVID-19 Radiography Database from:")
print("https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database")
print("And extract the COVID and Normal folders to the data/ directory")

## 3. Custom Dataset Class

In [None]:
class COVID19Dataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

## 4. Data Loading and Preprocessing

In [None]:
def load_data(covid_dir, normal_dir, max_samples_per_class=1000):
    image_paths = []
    labels = []
    
    # Load COVID images (label = 1); sort first because os.listdir returns
    # files in arbitrary order, which would make the subset non-reproducible
    covid_files = sorted(f for f in os.listdir(covid_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg')))[:max_samples_per_class]
    for file in covid_files:
        image_paths.append(os.path.join(covid_dir, file))
        labels.append(1)
    
    # Load Normal images (label = 0), sorted for the same reason
    normal_files = sorted(f for f in os.listdir(normal_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg')))[:max_samples_per_class]
    for file in normal_files:
        image_paths.append(os.path.join(normal_dir, file))
        labels.append(0)
    
    return image_paths, labels

# Data transforms
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

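# Load data (falls back to mock data until the dataset is downloaded)
try:
    image_paths, labels = load_data('data/COVID', 'data/Normal')
    if len(set(labels)) < 2:
        # The directories were created earlier, so os.listdir succeeds even when
        # they are empty; raise explicitly so the mock-data fallback is used
        raise FileNotFoundError("No images found in data/COVID and/or data/Normal")
    print(f"Loaded {len(image_paths)} images")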
    print(f"COVID cases: {sum(labels)}")
    print(f"Normal cases: {len(labels) - sum(labels)}")
    
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )
    
    # Create datasets
    train_dataset = COVID19Dataset(X_train, y_train, transform=transform_train)
    test_dataset = COVID19Dataset(X_test, y_test, transform=transform_test)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    
except FileNotFoundError:
    print("Dataset not found. Please download and extract the dataset first.")
    print("For now, we'll create a mock dataset for demonstration.")
    
    # Create mock data for demonstration
    mock_images = torch.randn(100, 3, 224, 224)
    mock_labels = torch.randint(0, 2, (100,))
    
    train_data = [(mock_images[i], mock_labels[i]) for i in range(80)]
    test_data = [(mock_images[i], mock_labels[i]) for i in range(80, 100)]
    
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
    
    print("Using mock data for demonstration.")
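
To make the `Normalize` step concrete, the following is a small pure-Python sketch of the per-channel arithmetic applied once `ToTensor` has scaled 8-bit pixel values to the range [0, 1]. The helper name `normalize_pixel` is hypothetical and is not part of torchvision; the mean and standard deviation values are the ImageNet statistics used in the transforms above.

```python
# ImageNet channel statistics, copied from the transforms above
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(value_8bit, channel):
    """Mimic ToTensor (divide by 255) followed by Normalize for one pixel value."""
    scaled = value_8bit / 255.0
    return (scaled - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel]

# A pure white and a pure black pixel in the red channel (channel 0)
white = normalize_pixel(255, 0)
black = normalize_pixel(0, 0)
print(f"white -> {white:.3f}, black -> {black:.3f}")
```

After normalization the inputs are roughly centered on zero, which is the input distribution the ImageNet-pretrained backbone was trained on.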

## 5. Model Definition (Transfer Learning with ResNet)

In [None]:
class COVID19Classifier(nn.Module):
    def __init__(self, num_classes=2, pretrained=True):
        super(COVID19Classifier, self).__init__()
        
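        # Use ResNet18 as backbone. torchvision >= 0.13 selects pretrained
        # weights through `weights`; the older `pretrained=` flag is deprecated.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet18(weights=weights)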
        
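        # Freeze all but the last 10 parameter tensors of the original network;
        # the replacement head defined below is created afterwards and stays trainable
        for param in list(self.resnet.parameters())[:-10]:
            param.requires_grad = False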
        
        # Replace final layer
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        return self.resnet(x)

# Initialize model
model = COVID19Classifier().to(device)
print(f"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters")
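
As a sanity check on the printed parameter count, the size of the new classifier head can be worked out by hand. The sketch below assumes ResNet-18's pooled feature vector has 512 dimensions (the value of `fc.in_features`); `linear_params` is a hypothetical helper, not a PyTorch function. Dropout and ReLU layers contribute no parameters.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

# ResNet-18's final pooled feature vector has 512 dimensions
num_features = 512
head_params = linear_params(num_features, 128) + linear_params(128, 2)
print(f"Parameters in the new classifier head: {head_params}")
```

The total reported by the notebook will be larger than this figure, because a few parameter tensors at the end of the backbone are also left unfrozen.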

## 6. Training Setup

In [None]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Training function
def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predicted = output.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
    
    return total_loss / len(train_loader), 100. * correct / total

# Evaluation function
def evaluate_model(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            
            total_loss += loss.item()
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    accuracy = 100. * correct / total
    return total_loss / len(test_loader), accuracy, all_predictions, all_targets
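
For intuition about what `nn.CrossEntropyLoss` computes for a single sample, here is a minimal pure-Python sketch: the loss is the negative log of the softmax probability assigned to the true class. The function name `cross_entropy` and the example logits are hypothetical; the PyTorch loss additionally averages this value over the batch by default.

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of the target class under a softmax over logits."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]

logits = [2.0, 0.5]  # hypothetical scores for [Normal, COVID]
loss_if_normal = cross_entropy(logits, 0)
loss_if_covid = cross_entropy(logits, 1)
predicted_class = logits.index(max(logits))  # same role as argmax in the notebook

print(f"loss if true label is Normal: {loss_if_normal:.4f}")
print(f"loss if true label is COVID:  {loss_if_covid:.4f}")
```

The loss is small when the highest logit matches the true label and grows as the model puts more weight on the wrong class.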

## 7. Model Training

In [None]:
# Training loop
num_epochs = 15
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

print("Starting training...")
for epoch in range(num_epochs):
    # Train
    train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device)
    
    # Evaluate
    test_loss, test_acc, _, _ = evaluate_model(model, test_loader, criterion, device)
    
    # Update learning rate
    scheduler.step()
    
    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)
    
    print(f'Epoch [{epoch+1}/{num_epochs}]:')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    print()

print("Training completed!")
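
The training loop calls `scheduler.step()` once per epoch, so with `step_size=10` and `gamma=0.1` the learning rate drops only once during the 15 epochs. The sketch below reproduces the `StepLR` rule in plain Python; `step_lr` is a hypothetical helper used for illustration.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Learning rate after `epoch` calls to scheduler.step(), following StepLR's rule."""
    return base_lr * gamma ** (epoch // step_size)

# Learning rate used during each of the 15 training epochs (0-indexed)
schedule = [step_lr(0.001, 0.1, 10, epoch) for epoch in range(15)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

Epochs 1 to 10 train at 0.001 and epochs 11 to 15 at 0.0001.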

## 8. Model Evaluation and Visualization

In [None]:
# Plot training history
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.title('Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Accuracy Over Time')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

# Final evaluation
final_test_loss, final_test_acc, predictions, targets = evaluate_model(model, test_loader, criterion, device)

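# Confusion Matrix (fix the label order so the matrix is always 2x2,
# even if the test set or the predictions contain only one class)
cm = confusion_matrix(targets, predictions, labels=[0, 1])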
plt.subplot(1, 3, 3)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Normal', 'COVID'], 
            yticklabels=['Normal', 'COVID'])
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')

plt.tight_layout()
plt.show()

# Print detailed results
print(f"Final Test Accuracy: {final_test_acc:.2f}%")
print("\nClassification Report:")
print(classification_report(targets, predictions, target_names=['Normal', 'COVID']))

# Calculate sensitivity and specificity
tn, fp, fn, tp = cm.ravel()
sensitivity = tp / (tp + fn)  # True Positive Rate
specificity = tn / (tn + fp)  # True Negative Rate

print(f"\nSensitivity (COVID Detection): {sensitivity:.3f}")
print(f"Specificity (Normal Detection): {specificity:.3f}")
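
On a small test set, such as the 20-sample mock split, one class may be absent and the divisions above would then fail. A minimal sketch of a guarded version follows; `safe_rate` and `sensitivity_specificity` are hypothetical helper names and the counts are made-up illustration values, not results from this notebook.

```python
def safe_rate(numerator, denominator):
    """Return numerator / denominator, or None when the denominator is zero."""
    return numerator / denominator if denominator else None

def sensitivity_specificity(tn, fp, fn, tp):
    """Sensitivity = TP / (TP + FN); specificity = TN / (TN + FP)."""
    return safe_rate(tp, tp + fn), safe_rate(tn, tn + fp)

# Hypothetical counts: 20 COVID cases (15 detected), 100 Normal cases (90 cleared)
sens, spec = sensitivity_specificity(tn=90, fp=10, fn=5, tp=15)
print(f"sensitivity = {sens:.2f}, specificity = {spec:.2f}")

# Hypothetical edge case: no COVID cases in the evaluated split
sens_empty, spec_empty = sensitivity_specificity(tn=20, fp=0, fn=0, tp=0)
print(f"sensitivity = {sens_empty}, specificity = {spec_empty}")
```

Returning `None` rather than 0 keeps an undefined metric from being mistaken for a poor one.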

## 9. Save Model

In [None]:
# Save the trained model
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'test_accuracy': final_test_acc,
    'epoch': num_epochs
}, 'models/covid_classifier.pth')

print("Model saved successfully!")

## 10. Reflection and Learning Summary

### Key Learnings from This COVID-19 Classification Task:

1. **Data Quality Matters**: Medical imaging datasets require careful preprocessing and validation. The quality and representativeness of chest X-ray images significantly affect model performance, and class imbalance is a common challenge in medical datasets.

2. **Transfer Learning Effectiveness**: ResNet models pre-trained on ImageNet provide useful feature extractors for medical images, even though ImageNet contains no medical data - this demonstrates how well learned visual representations transfer across domains.

3. **Medical AI Limitations**: Achieving high accuracy on limited datasets doesn't guarantee clinical utility - real-world deployment requires extensive validation, diverse patient populations, and consideration of edge cases that may not be present in research datasets.

4. **Evaluation Metrics Importance**: In medical applications, sensitivity (detecting true COVID cases) and specificity (avoiding false positives) are often more important than overall accuracy - a model with 90% accuracy but poor sensitivity could miss critical cases.

5. **Ethical Considerations**: COVID-19 classification models highlight the responsibility of AI developers in healthcare - false negatives could delay treatment while false positives could cause unnecessary panic, emphasizing the need for careful validation and human oversight.

6. **Data Augmentation Benefits**: Simple augmentation techniques like rotation and flipping help improve model robustness without requiring additional data collection, which is particularly valuable in medical imaging where acquiring new labeled data is expensive and time-consuming.
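
The point about accuracy versus sensitivity can be illustrated with a toy calculation. The sketch below uses a hypothetical imbalanced test set and a degenerate model that always predicts Normal; the numbers are invented for illustration and are not results from this notebook.

```python
# Hypothetical imbalanced test set: 900 Normal (0) and 100 COVID (1)
toy_targets = [0] * 900 + [1] * 100
toy_predictions = [0] * len(toy_targets)  # a degenerate model that always predicts Normal

pairs = list(zip(toy_predictions, toy_targets))
correct = sum(p == t for p, t in pairs)
true_positives = sum(p == 1 and t == 1 for p, t in pairs)

toy_accuracy = correct / len(toy_targets)
toy_sensitivity = true_positives / sum(toy_targets)  # sum counts the COVID cases

print(f"accuracy = {toy_accuracy:.2f}, sensitivity = {toy_sensitivity:.2f}")
```

This model reaches 90% accuracy while detecting none of the COVID cases, which is why sensitivity and specificity are reported alongside accuracy.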
