<a href="https://colab.research.google.com/github/bhattarai-aavash/Pneumonia_Classifier/blob/main/Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Load Dataset**

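The dataset used in this tutorial is the Kaggle [Chest X-Ray Images (Pneumonia)](https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia) dataset. Download it and place it in Google Drive at `MyDrive/chest_xray`, keeping its `train/`, `val/` and `test/` sub-folders (each with one sub-folder per class).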

In [None]:
# Mount Google Drive at /content/drive. Both the dataset directory and the training
# checkpoint used later in this notebook live under /content/drive/MyDrive.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


# **Importing Dependencies**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets

# **Create Data Transforms**

In [None]:
# Per-channel normalization statistics of ImageNet. The ResNet-50 backbone used
# below is ImageNet-pretrained, so inputs are normalized the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing shared by the evaluation splits (val and test):
# resize, take the central 224x224 crop, convert to a tensor and normalize.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

data_transforms = {
    # Training adds random augmentation: a random crop resized to 224x224 plus a
    # random horizontal flip.
    # NOTE(review): RandomResizedCrop's default scale range can keep only a small
    # patch of the image, which may cut away most of the lung field on chest
    # X-rays. Consider a narrower `scale=` if validation accuracy suffers.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    'val': eval_transform,
    'test': eval_transform,
}

In [None]:
# Root of the chest X-ray dataset on the mounted Drive. The next cell reads its
# train/, val/ and test/ sub-folders.
data_dir = "/content/drive/MyDrive/chest_xray"

In [None]:
# One ImageFolder per split, read from <data_dir>/<split>/<class folder>/...,
# each using the transform pipeline defined for that split.
image_datasets = {
    split: datasets.ImageFolder(root=data_dir + '/' + split, transform=data_transforms[split])
    for split in ('train', 'val', 'test')
}

# ImageFolder derives label indices from each split's own sub-folder names. A
# split with a missing or extra class folder would therefore silently shift
# labels relative to 'train', so fail loudly here instead.
for split in ('val', 'test'):
    assert image_datasets[split].class_to_idx == image_datasets['train'].class_to_idx, (
        f"'{split}' classes {image_datasets[split].classes} differ from "
        f"'train' classes {image_datasets['train'].classes}; labels would be misaligned"
    )


In [None]:
# Sanity-check the training split: how many class folders ImageFolder found and
# their names (these names define the label indices).
train_split = image_datasets['train']
print(len(train_split.classes))
print(train_split.classes)

# Load one transformed (image, label) pair to confirm the pipeline runs end to end.
sample_image, sample_label = train_split[0]

2
['NORMAL', 'PNEUMONIA']


# **Create Dataloaders**

In [None]:
from torch.utils.data import DataLoader

batch_size = 32
num_workers = 1
# Page-locked host memory makes host-to-GPU copies cheaper. It is only useful
# when a CUDA device is present, so enable it conditionally.
pin_memory = torch.cuda.is_available()

# Shuffle only the training data. Evaluation order does not change aggregate
# metrics, and a fixed order keeps per-sample outputs aligned with the dataset.
train_loader = DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(image_datasets['test'], batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)


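# **Train the Model**

Fine-tune an ImageNet-pretrained ResNet-50 whose final layer is replaced by a new head with one output per class. After every epoch a checkpoint is written to Google Drive. If a checkpoint already exists, training resumes from the last completed epoch. Note that the `val` split is very small: every reported validation accuracy below is a multiple of 1/16, so treat it only as a rough signal.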

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights
from tqdm.auto import tqdm
import os

# Load ResNet-50 with ImageNet weights. IMAGENET1K_V1 is the checkpoint that the
# deprecated `pretrained=True` flag selected, so the starting weights are unchanged.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Replace the last fully connected layer with one that outputs one logit per class.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(image_datasets['train'].classes))

# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Checkpoint on Drive so training can resume in a new runtime.
checkpoint_path = '/content/drive/MyDrive/checkpoint.pth'

if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on a GPU runtime load on a CPU-only
    # runtime (and vice versa) instead of failing during deserialization.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']  # number of epochs already completed
    epoch_loss = checkpoint['loss']    # history of per-epoch mean training loss
else:
    start_epoch = 0
    epoch_loss = []


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in training mode, gradients are
    enabled, the weights are updated after every batch, and a progress bar is
    shown. Otherwise the model is put in eval mode and run without gradients.
    Both metrics are averaged over ``len(loader.dataset)`` samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_correct = 0

    with torch.set_grad_enabled(training):
        with tqdm(loader, unit="batch", disable=not training) as batches:
            for inputs, labels in batches:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    batches.set_postfix({"loss": loss.item()})

                # Weight by batch size so the result is a per-sample mean even
                # when the last batch is smaller.
                total_loss += loss.item() * inputs.size(0)
                total_correct += (outputs.argmax(dim=1) == labels).sum().item()

    num_samples = len(loader.dataset)
    return total_loss / num_samples, total_correct / num_samples


# Train the model
num_epochs = 12

for epoch in range(start_epoch, num_epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

    # Print epoch statistics
    print('Epoch {}/{} - Train Loss: {:.4f} - Train Acc: {:.4f} - Val Loss: {:.4f} - Val Acc: {:.4f}'
          .format(epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc))

    # Store epoch loss for tracking
    epoch_loss.append(train_loss)

    # Write the checkpoint to a temporary file, then rename it over the real one.
    # An interrupted save (e.g. a runtime disconnect) therefore cannot leave a
    # truncated checkpoint behind.
    tmp_checkpoint_path = checkpoint_path + '.tmp'
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': epoch_loss
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)



100%|██████████| 163/163 [01:55<00:00,  1.41batch/s, loss=0.553]


Epoch 9/12 - Train Loss: 0.1113 - Train Acc: 0.9615 - Val Loss: 0.5295 - Val Acc: 0.8125


100%|██████████| 163/163 [01:56<00:00,  1.40batch/s, loss=0.233]


Epoch 10/12 - Train Loss: 0.0997 - Train Acc: 0.9630 - Val Loss: 0.4677 - Val Acc: 0.8125


100%|██████████| 163/163 [01:55<00:00,  1.42batch/s, loss=0.0338]


Epoch 11/12 - Train Loss: 0.0901 - Train Acc: 0.9651 - Val Loss: 0.3115 - Val Acc: 0.8750


100%|██████████| 163/163 [01:55<00:00,  1.41batch/s, loss=0.105]


Epoch 12/12 - Train Loss: 0.0936 - Train Acc: 0.9634 - Val Loss: 0.1756 - Val Acc: 0.9375
