In [1]:
# %% Imports
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

In [2]:
# %% Dataset Paths
data_dir = "../data/chest_xray"  # root folder

train_dir = os.path.join(data_dir, "train")
val_dir   = os.path.join(data_dir, "val")
test_dir  = os.path.join(data_dir, "test")

# %% Transforms
transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [3]:
# %% Datasets
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
val_dataset   = datasets.ImageFolder(root=val_dir, transform=transform)
test_dataset  = datasets.ImageFolder(root=test_dir, transform=transform)

# %% DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=16, shuffle=False)

print("Classes:", train_dataset.classes)  # ['NORMAL', 'PNEUMONIA']
print("Training samples:", len(train_dataset))
print("Validation samples:", len(val_dataset))
print("Test samples:", len(test_dataset))

print("Class distribution (train):", Counter(train_dataset.targets))

Classes: ['NORMAL', 'PNEUMONIA']
Training samples: 5216
Validation samples: 16
Test samples: 624
Class distribution (train): Counter({1: 3875, 0: 1341})


In [4]:
# %% Model Definition
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=2):  # 2 classes: NORMAL, PNEUMONIA
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32*32*32, 128), nn.ReLU(),
            nn.Linear(128, num_classes)
        )
    def forward(self, x):
        return self.net(x)

model = SimpleCNN(num_classes=len(train_dataset.classes))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
# %% Training Loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    all_preds, all_labels = [], []
    for imgs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch+1}, Train Loss: {total_loss/len(train_loader):.4f}, Train Acc: {acc:.4f}")


Epoch 1, Train Loss: 0.1132, Train Acc: 0.9563
Epoch 2, Train Loss: 0.0693, Train Acc: 0.9753
Epoch 3, Train Loss: 0.0503, Train Acc: 0.9799
Epoch 4, Train Loss: 0.0284, Train Acc: 0.9895
Epoch 5, Train Loss: 0.0234, Train Acc: 0.9921
Epoch 6, Train Loss: 0.0076, Train Acc: 0.9973
Epoch 7, Train Loss: 0.0176, Train Acc: 0.9946
Epoch 8, Train Loss: 0.0044, Train Acc: 0.9985
Epoch 9, Train Loss: 0.0091, Train Acc: 0.9964
Epoch 10, Train Loss: 0.0076, Train Acc: 0.9971


In [7]:
# %% Validation
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        outputs = model(imgs)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
val_acc = accuracy_score(all_labels, all_preds)
print(f"Validation Accuracy: {val_acc:.4f}")

Validation Accuracy: 0.6875


In [8]:

# %% Save Model
model_path = "../backend/models/cnn_chestxray.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'num_classes': len(train_dataset.classes)
}, model_path)
print(f"Model saved to {model_path}")

Model saved to ../backend/models/cnn_chestxray.pth


In [9]:
# %% Load Model Later
checkpoint = torch.load(model_path)
model = SimpleCNN(num_classes=checkpoint['num_classes'])
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.eval()
print("Model loaded and ready for inference.")

Model loaded and ready for inference.
