In [None]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np

#DATASET LOADING & PREPROCESSING
# Seed torch's global RNG so DataLoader shuffling, random augmentation and the
# model initialisation in the next cell repeat across Restart & Run All
# (modulo nondeterministic accelerator kernels).
SEED = 42
torch.manual_seed(SEED)

DATA_DIR = "chest_xray"
BATCH_SIZE = 16
NUM_WORKERS = 4
# pin_memory only speeds up CUDA host-to-device copies; on MPS/CPU PyTorch
# warns and does not pin, so enable it only when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()

# ImageNet channel statistics, matching what the pretrained ResNet-18 expects.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random flips/rotations augment the training set only.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic, so validation/test metrics are measured
# on the actual images and do not change from run to run.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
transform = train_transform  # keep the original name available

# Load Dataset (ImageFolder: one sub-folder per class)
train_dataset = datasets.ImageFolder(root=os.path.join(DATA_DIR, "train"), transform=train_transform)
val_dataset = datasets.ImageFolder(root=os.path.join(DATA_DIR, "val"), transform=eval_transform)
test_dataset = datasets.ImageFolder(root=os.path.join(DATA_DIR, "test"), transform=eval_transform)

# Create DataLoaders (only the training set is shuffled)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Check the dataset size
print(f"Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)}")


In [None]:
#DEFINING MODEL
# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


def build_resnet18(pretrained, num_classes=2, dropout=0.5):
    """Return a ResNet-18 with a dropout + linear classification head, on `device`.

    Args:
        pretrained: when True, load torchvision's default ImageNet weights for
            the backbone; when False, use random initialisation.
        num_classes: width of the output layer (2 here: binary classification).
        dropout: dropout probability applied before the final linear layer.

    Returns:
        The model, moved to the module-level `device`.

    The head outputs raw logits (no Softmax). nn.CrossEntropyLoss applies
    log-softmax itself, so a Softmax layer would apply it twice and flatten the
    gradients. Predictions taken via argmax/max are unchanged without it.
    """
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(dropout),  # regularises the new head (prevents overfitting)
        nn.Linear(model.fc.in_features, num_classes),
    )
    # Module.to moves all parameters and buffers, so no per-parameter copy is needed.
    return model.to(device)


# Train-from-scratch baseline (random initialisation)
model_scratch = build_resnet18(pretrained=False)

# Transfer-learning model (ImageNet-pretrained backbone)
model_pretrained = build_resnet18(pretrained=True)


In [None]:
#TRAINING & EVALUATION
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader`: train when `optimizer` is given, else evaluate.

    Returns:
        (mean batch loss, accuracy in percent). Both are nan when `loader`
        yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # toggles dropout / batch-norm behaviour
    loss_sum, correct, seen, batches = 0.0, 0, 0, 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # Build the autograd graph only when we are going to backpropagate.
        with torch.set_grad_enabled(training):
            outputs = model(images)
            loss = criterion(outputs, labels)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_sum += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
        batches += 1

    if batches == 0:
        return float("nan"), float("nan")
    return loss_sum / batches, 100.0 * correct / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=30, model_name="Model", plot=True):
    """Train `model` in place, evaluating on `val_loader` after every epoch.

    Args:
        model: network to train; batches are moved to the device of its parameters.
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches; may be empty.
        criterion: loss called as criterion(outputs, labels).
        optimizer: optimizer already bound to model.parameters().
        epochs: number of passes over train_loader.
        model_name: label used in the per-epoch log lines and the plot title.
        plot: when True, show the train/validation loss curves at the end.

    Returns:
        (train_losses, val_losses): per-epoch mean batch losses. An epoch whose
        validation loader yields no batches records nan.
    """
    device = next(model.parameters()).device
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"{model_name} - Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

    if plot:
        # Plot Training and Validation Loss
        epoch_axis = range(1, epochs + 1)
        fig, ax = plt.subplots()
        ax.plot(epoch_axis, train_losses, label="Train Loss")
        ax.plot(epoch_axis, val_losses, label="Validation Loss")
        ax.set(xlabel="Epochs", ylabel="Loss", title=f"Training & Validation Loss Curve - {model_name}")
        ax.legend()
        plt.show()

    return train_losses, val_losses

# Define loss function & optimizer
# Hyperparameters shared by both runs so the scratch-vs-pretrained comparison
# is like-for-like.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
EPOCHS = 30

# CrossEntropyLoss applies log-softmax internally and expects raw logits.
criterion = nn.CrossEntropyLoss()
optimizer_scratch = torch.optim.Adam(model_scratch.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Train Model from Scratch
train_model(model_scratch, train_loader, val_loader, criterion, optimizer_scratch, epochs=EPOCHS, model_name="Scratch Model")

# Train Pretrained Model
train_model(model_pretrained, train_loader, val_loader, criterion, optimizer_pretrained, epochs=EPOCHS, model_name="Pretrained Model")


In [None]:
def test_model(model, test_loader, model_name="Model"):
    model.eval()
    correct = 0
    total = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)

            for i in range(len(labels)):
                label = labels[i].item()
                class_total[label] += 1
                if predicted[i] == label:
                    class_correct[label] += 1

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    overall_accuracy = 100 * correct / total
    pneumonia_acc = 100 * class_correct[1] / class_total[1] if class_total[1] > 0 else 0
    normal_acc = 100 * class_correct[0] / class_total[0] if class_total[0] > 0 else 0

    print(f"{model_name} - Overall Test Accuracy: {overall_accuracy:.2f}%")
    print(f"{model_name} - Pneumonia Accuracy: {pneumonia_acc:.2f}%")
    print(f"{model_name} - Normal Accuracy: {normal_acc:.2f}%")

# Evaluate each trained model on the held-out test set, scratch first.
for run_name, trained_model in (("Scratch Model", model_scratch), ("Pretrained Model", model_pretrained)):
    test_model(trained_model, test_loader, model_name=run_name)


In [None]:
#MISCLASSIFICATION VISUALIZATION
def show_misclassified_images(model, test_loader, num_images=5, class_names=None,
                              mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                              plot=True):
    """Collect the first misclassified test samples and optionally plot them.

    Args:
        model: classifier whose output's argmax over dim 1 is the predicted class.
        test_loader: iterable of (images, labels) batches.
        num_images: maximum number of misclassified samples to collect and plot.
        class_names: optional index->name sequence used in titles; defaults to
            test_loader.dataset.classes when that attribute exists (ImageFolder).
        mean, std: per-channel Normalize statistics, undone before display.
        plot: when True and at least one sample was found, draw the figure.

    Returns:
        list of (image_tensor_on_cpu, true_index, predicted_index) tuples, in
        loader order, at most `num_images` long.
    """
    device = next(model.parameters()).device
    if class_names is None:
        class_names = getattr(getattr(test_loader, "dataset", None), "classes", None)

    def label_text(idx):
        return class_names[idx] if class_names is not None else str(idx)

    model.eval()
    misclassified = []
    with torch.no_grad():
        for images, labels in test_loader:
            if len(misclassified) >= num_images:
                break  # collected enough; skip the rest of the test set
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            wrong = (predicted != labels).nonzero(as_tuple=True)[0].tolist()
            for i in wrong[: num_images - len(misclassified)]:
                misclassified.append((images[i].cpu(), labels[i].item(), predicted[i].item()))

    if plot and misclassified:
        # Undo Normalize so imshow receives RGB values in [0, 1] rather than z-scores.
        mean_t = torch.tensor(mean).view(-1, 1, 1)
        std_t = torch.tensor(std).view(-1, 1, 1)
        ncols = len(misclassified)
        # squeeze=False keeps `axes` 2-D even when only one image is shown.
        fig, axes = plt.subplots(1, ncols, figsize=(3 * ncols, 5), squeeze=False)
        for ax, (img, true_idx, pred_idx) in zip(axes[0], misclassified):
            img = (img * std_t + mean_t).clamp(0, 1)
            ax.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
            ax.set_title(f"True: {label_text(true_idx)}, Pred: {label_text(pred_idx)}")
            ax.axis("off")
        plt.show()

    return misclassified

# Inspect the pretrained model's mistakes on the test set.
show_misclassified_images(model_pretrained, test_loader)
