In this notebook I classify chest X-ray images, taken from the
Kaggle Chest X-Ray Images (Pneumonia) dataset, as Pneumonia or Normal.
The dataset is available at https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
It is organized into 3 folders (train, test, val), each containing one subfolder per image category (Pneumonia/Normal).
There are 5,863 X-ray images (JPEG) in 2 categories (Pneumonia/Normal).


Two major transfer learning scenarios are compared in this notebook:

-  **Finetuning the convnet**: Instead of random initialization, we
   initialize the network with pretrained weights, such as those of a
   network trained on the ImageNet-1k dataset. The rest of the training
   proceeds as usual.
-  **ConvNet as fixed feature extractor**: Here, we will freeze the weights
   for all of the network except that of the final fully connected
   layer or layers. These last fully connected layers are replaced with a new layer or layers
   with random weights and only these layers are trained.


In [None]:
# Mount Google Drive so that the paths under /content/drive used later in
# this notebook (dataset folder, model checkpoints) are reachable.
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
# List the dataset root on Drive to confirm the split folders are present.
os.listdir('/content/drive/MyDrive/datasets/chest_xray')

['train', 'val', 'test']

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score
from torch.autograd import Variable
import itertools
from mlxtend.plotting import plot_confusion_matrix

# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda:0


In [None]:
# --- General Hyperparameters ---
batch_size = 4
num_workers = 4
num_epochs = 10
# Early-stopping patience passed to train_model. NOTE: because it equals
# num_epochs, early stopping is effectively disabled for these runs.
max_epochs_stop = 10

# --- Fine-tuning Model Parameters ---
finetune_model_name = 'FineTuned_model.pth'
finetune_learning_rate = 0.0001
finetune_scheduler_step_size = 7
finetune_scheduler_gamma = 0.1

# --- Fixed Feature Extractor Model Parameters ---
fixed_extractor_model_name = 'Pretrained_model.pth'
fixed_extractor_learning_rate = 1e-3
fixed_extractor_momentum = 0.9
fixed_extractor_scheduler_step_size = 7
fixed_extractor_scheduler_gamma = 0.1

In [None]:
# Data augmentation and normalization.
# The channel mean/std are the ImageNet statistics that torchvision's
# pretrained models expect.
_imagenet_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Evaluation pipeline shared by 'val' and 'test'. Resize(256) only fixes the
# shorter edge, so a crop is still needed to get square inputs; CenterCrop
# keeps evaluation deterministic (a random crop would make validation/test
# metrics vary from run to run).
_eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    _imagenet_normalize,
])

data_transforms = {
    # Random crop + horizontal flip are augmentation and apply to training only.
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _imagenet_normalize,
    ]),
    'val': _eval_transform,
    'test': _eval_transform,
}

# root_dir is where model checkpoints are written; data_dir holds the dataset.
root_dir = '/content/drive/MyDrive/models'
data_dir = '/content/drive/MyDrive/datasets/chest_xray'

# Load datasets
_splits = ['train', 'val', 'test']
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in _splits
}
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),  # only the training split benefits from shuffling
        num_workers=num_workers,
    )
    for split in _splits
}
dataset_sizes = {split: len(image_datasets[split]) for split in _splits}
class_names = image_datasets['train'].classes

print("Dataset sizes:", dataset_sizes)
print("Class names:", class_names)

Dataset sizes: {'train': 5216, 'val': 16, 'test': 624}
Class names: ['NORMAL', 'PNEUMONIA']




In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, max_epochs_stop, model_name):
    """Train ``model``, validating after every epoch and checkpointing the best weights.

    Relies on the notebook-level globals ``dataloaders``, ``dataset_sizes``,
    ``device`` and ``root_dir``.

    NOTE: a later cell in this notebook defines ``train_model`` again; the
    later definition is the one in effect after Run All.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler, stepped once per epoch after that
            epoch's training batches.
        num_epochs: maximum number of epochs.
        max_epochs_stop: stop once validation accuracy has not improved for
            this many consecutive epochs.
        model_name: checkpoint file name, written under ``root_dir``.

    Returns:
        Tuple ``(best_acc, epochs_array, acc_val, acc_train, loss_val,
        loss_train, time_elapsed)``: best validation accuracy, the epoch
        indices that were run, per-epoch validation/training accuracy and
        loss lists, and the elapsed wall-clock time in seconds. On return the
        model holds the best checkpointed weights, if any were saved.
    """
    since = time.time()
    checkpoint_path = os.path.join(root_dir, model_name)
    checkpoint_saved = False
    best_acc = 0.0
    epochs_no_improve = 0
    loss_train, acc_train = [], []
    loss_val, acc_val = [], []
    epochs_array = []
    stop_early = False

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')

        for phase in ['train', 'val']:
            if phase == 'train':
                epochs_array.append(epoch)
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; weight by batch size for a dataset mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                # Step the scheduler only after the epoch's optimizer steps;
                # stepping first would shift the LR schedule by one epoch.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'train':
                loss_train.append(epoch_loss)
                acc_train.append(epoch_acc.item())
            else:
                loss_val.append(epoch_loss)
                acc_val.append(epoch_acc.item())

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase != 'val':
                continue

            if epoch_acc > best_acc:
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
                epochs_no_improve = 0
                best_acc = epoch_acc
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= max_epochs_stop:
                    print("Early Stopping!")
                    stop_early = True

        if stop_early:
            break

    # Restore the best weights on every exit path, but only if this run wrote
    # a checkpoint (avoids a missing-file error or loading a stale file).
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    time_elapsed = time.time() - since
    return best_acc, epochs_array, acc_val, acc_train, loss_val, loss_train, time_elapsed

In [None]:
# Function to plot confusion matrix
# NOTE: this definition shadows the `plot_confusion_matrix` imported from
# mlxtend.plotting in the imports cell.
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, save_name='confusion_matrix.png'):
    """Render a confusion matrix as an annotated heat map and save it to disk.

    Args:
        cm: 2-D array-like; rows are true labels, columns are predicted labels.
        classes: tick labels, one per class, in matrix order.
        normalize: if True, divide each row by its sum so cells show the
            fraction of each true class (rows summing to zero stay at 0).
        title: figure title.
        cmap: matplotlib colormap for the heat map.
        save_name: output image path.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm,
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )

    # Start from a fresh figure so nothing left on the current one leaks in.
    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Cells darker than the midpoint get white text so the numbers stay legible.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(save_name)
    plt.clf()
    plt.close()

# Function to plot ROC curve
def roc_plotter(fpr, tpr, auc_score, savename):
    """Draw the ROC curve against the chance diagonal and save it to ``savename``.

    Args:
        fpr: false-positive rates (x values).
        tpr: true-positive rates (y values).
        auc_score: area under the curve, shown in the legend.
        savename: output image path.
    """
    fig, ax = plt.subplots()
    curve_label = f'ROC curve (area = {auc_score:.2f})'
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    # Diagonal reference line for a random classifier.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic on Test set')
    ax.legend(loc="lower right")
    fig.savefig(savename)
    plt.close(fig)

# Function to plot loss and accuracy over epochs
def loss_acc_over_epocs_plotter(epochs_array, m_val, m_train, metric, save_name):
    """Plot validation and training curves of one metric and save the figure.

    Args:
        epochs_array: epoch indices (x values).
        m_val: per-epoch validation values of the metric.
        m_train: per-epoch training values of the metric.
        metric: metric name, used in the legend, y-axis label and title.
        save_name: output image path.
    """
    fig, ax = plt.subplots()
    for prefix, values in (('Validation ', m_val), ('Training ', m_train)):
        ax.plot(epochs_array, values, label=prefix + metric)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric)
    ax.set_title(f'{metric} over Epochs')
    ax.legend()
    ax.grid()
    fig.savefig(save_name)
    plt.close(fig)

def test_model(model):
    """Evaluate ``model`` on the test split and return its accuracy.

    Uses the notebook-level ``dataloaders``, ``dataset_sizes`` and ``device``.
    Prints the accuracy and returns it as a 0-dim double tensor.
    """
    print("Testing...")
    model.eval()
    n_correct = 0
    for batch_images, batch_labels in dataloaders['test']:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        with torch.no_grad():
            logits = model(batch_images)
            predicted = torch.max(logits, 1)[1]
        n_correct += (predicted == batch_labels.data).sum()
    acc_test = n_correct.double() / dataset_sizes['test']
    print(f'Test Acc: {acc_test:.4f}')
    return acc_test

def confusion_matrix_roc_curve_calc(model, cm_save_name='Confusion_Matrix.png', roc_save_name='ROC.png'):
    """Compute and save the test-set confusion matrix and ROC curve for ``model``.

    Uses the notebook-level ``dataloaders``, ``device`` and ``class_names``.
    The confusion matrix is built from hard (argmax) predictions. The ROC
    curve and AUC are built from the softmax probability of class index 1,
    because ROC analysis needs a continuous score: hard 0/1 predictions give
    a single operating point rather than a curve.

    Assumes a two-class model whose output column 1 is the positive class
    (PNEUMONIA in the class order printed when the datasets are loaded).

    Args:
        model: trained classifier returning per-class logits.
        cm_save_name: output path for the confusion-matrix image.
        roc_save_name: output path for the ROC image.
    """
    print("Calculating confusion matrix and ROC curve...")
    true_labels, test_predictions, positive_scores = [], [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, test_pred = torch.max(outputs, 1)
            probabilities = torch.softmax(outputs, dim=1)

            true_labels.extend(labels.cpu().numpy())
            test_predictions.extend(test_pred.cpu().numpy())
            positive_scores.extend(probabilities[:, 1].cpu().numpy())

    CM = confusion_matrix(true_labels, test_predictions)
    plot_confusion_matrix(cm=CM, classes=class_names, title='Confusion Matrix', save_name=cm_save_name)

    fpr, tpr, _ = roc_curve(true_labels, positive_scores)
    auc_score = roc_auc_score(true_labels, positive_scores)
    roc_plotter(fpr, tpr, auc_score, roc_save_name)

In [None]:
# NOTE: this re-defines train_model from an earlier cell and shadows it;
# keep only one definition when tidying the notebook.
def train_model(model, criterion, optimizer, scheduler, num_epochs, max_epochs_stop, model_name):
    """Train ``model``, validating after every epoch and checkpointing the best weights.

    Relies on the notebook-level globals ``dataloaders``, ``dataset_sizes``,
    ``device`` and ``root_dir``.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler, stepped once per epoch after that
            epoch's training batches.
        num_epochs: maximum number of epochs.
        max_epochs_stop: stop once validation accuracy has not improved for
            this many consecutive epochs.
        model_name: checkpoint file name, written under ``root_dir``.

    Returns:
        Tuple ``(best_acc, epochs_array, acc_val, acc_train, loss_val,
        loss_train, time_elapsed)``: best validation accuracy, the epoch
        indices that were run, per-epoch validation/training accuracy and
        loss lists, and the elapsed wall-clock time in seconds. On return the
        model holds the best checkpointed weights, if any were saved.
    """
    since = time.time()
    checkpoint_path = os.path.join(root_dir, model_name)
    checkpoint_saved = False
    best_acc = 0.0
    epochs_no_improve = 0
    loss_train, acc_train = [], []
    loss_val, acc_val = [], []
    epochs_array = []
    stop_early = False

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')

        for phase in ['train', 'val']:
            if phase == 'train':
                epochs_array.append(epoch)
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; weight by batch size for a dataset mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                # Step the scheduler only after the epoch's optimizer steps;
                # stepping first would shift the LR schedule by one epoch.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'train':
                loss_train.append(epoch_loss)
                acc_train.append(epoch_acc.item())
            else:
                loss_val.append(epoch_loss)
                acc_val.append(epoch_acc.item())

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase != 'val':
                continue

            if epoch_acc > best_acc:
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
                epochs_no_improve = 0
                best_acc = epoch_acc
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= max_epochs_stop:
                    print("Early Stopping!")
                    stop_early = True

        if stop_early:
            break

    # Restore the best weights on every exit path, but only if this run wrote
    # a checkpoint (avoids a missing-file error or loading a stale file).
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    time_elapsed = time.time() - since
    return best_acc, epochs_array, acc_val, acc_train, loss_val, loss_train, time_elapsed

def _train_evaluate_and_report(model, criterion, optimizer, scheduler, model_name,
                               acc_plot_name, loss_plot_name, cm_plot_name, roc_plot_name):
    """Train one model, plot its curves, then score the best checkpoint on the test set."""
    best_acc, epochs, acc_val, acc_train, loss_val, loss_train, time_elapsed = train_model(
        model, criterion, optimizer, scheduler,
        num_epochs=num_epochs, max_epochs_stop=max_epochs_stop, model_name=model_name)

    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best validation accuracy: {best_acc:.4f}')

    loss_acc_over_epocs_plotter(epochs, acc_val, acc_train, 'Accuracy', acc_plot_name)
    loss_acc_over_epocs_plotter(epochs, loss_val, loss_train, 'Loss', loss_plot_name)

    # Evaluate the checkpoint with the best validation accuracy, not the
    # weights from the final epoch.
    model.load_state_dict(torch.load(os.path.join(root_dir, model_name)))
    test_acc = test_model(model)
    print(f'Test accuracy: {test_acc:.4f}')
    confusion_matrix_roc_curve_calc(model, cm_plot_name, roc_plot_name)


def run_experiments():
    """Run both transfer-learning experiments on a pretrained ResNet-18.

    1. Fine-tuning: every layer is trainable, optimised with Adam.
    2. Fixed feature extractor: the backbone is frozen and only the new final
       fully connected layer is trained, with SGD + momentum.

    Each experiment trains, saves accuracy/loss curves, reloads the best
    checkpoint, and reports test accuracy, a confusion matrix and a ROC curve.
    """
    # Fine-tuning the convnet
    print("Finetuning the convnet...")
    model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Replace the 1000-way ImageNet head with one sized for this dataset.
    model_ft.fc = nn.Linear(model_ft.fc.in_features, len(class_names))
    model_ft = model_ft.to(device)

    optimizer_ft = optim.Adam(model_ft.parameters(), lr=finetune_learning_rate)
    scheduler_ft = lr_scheduler.StepLR(optimizer_ft, step_size=finetune_scheduler_step_size, gamma=finetune_scheduler_gamma)

    _train_evaluate_and_report(
        model_ft, nn.CrossEntropyLoss(), optimizer_ft, scheduler_ft,
        model_name=finetune_model_name,
        acc_plot_name='FineTuned_NN_Accuracy_over_epochs.png',
        loss_plot_name='FineTuned_NN_Loss_over_epochs.png',
        cm_plot_name='FineTune_ConfusionMatrix.png',
        roc_plot_name='FineTuned_ROC.png',
    )

    # ConvNet as fixed feature extractor
    print('\nConvNet as fixed feature extractor...')
    model_conv = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze the backbone; the replacement fc layer created below is
    # trainable because new modules default to requires_grad=True.
    for param in model_conv.parameters():
        param.requires_grad = False

    model_conv.fc = nn.Linear(model_conv.fc.in_features, len(class_names))
    model_conv = model_conv.to(device)

    # Only the new head's parameters are handed to the optimizer.
    optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=fixed_extractor_learning_rate, momentum=fixed_extractor_momentum)
    scheduler_conv = lr_scheduler.StepLR(optimizer_conv, step_size=fixed_extractor_scheduler_step_size, gamma=fixed_extractor_scheduler_gamma)

    _train_evaluate_and_report(
        model_conv, nn.CrossEntropyLoss(), optimizer_conv, scheduler_conv,
        model_name=fixed_extractor_model_name,
        acc_plot_name='Pretrained_NN_Accuracy_over_epochs.png',
        loss_plot_name='Pretrained_NN_Loss_over_epochs.png',
        cm_plot_name='Pretrained_ConfusionMatrix.png',
        roc_plot_name='Pretrained_ROC.png',
    )

# Entry point: run both experiments when this cell is executed.
if __name__ == "__main__":
    run_experiments()

Finetuning the convnet...
Epoch 0/9
train Loss: 0.1652 Acc: 0.9383
val Loss: 0.7518 Acc: 0.7500
Epoch 1/9
train Loss: 0.0921 Acc: 0.9693
val Loss: 0.4302 Acc: 0.8750
Epoch 2/9
train Loss: 0.0711 Acc: 0.9801
val Loss: 0.4851 Acc: 0.8125
Epoch 3/9
train Loss: 0.0522 Acc: 0.9833
val Loss: 0.3520 Acc: 0.8750
Epoch 4/9
train Loss: 0.0444 Acc: 0.9845
val Loss: 0.2346 Acc: 0.9375
Epoch 5/9
train Loss: 0.0446 Acc: 0.9856
val Loss: 0.0692 Acc: 0.9375
Epoch 6/9
train Loss: 0.0122 Acc: 0.9964
val Loss: 0.0243 Acc: 1.0000
Epoch 7/9
train Loss: 0.0077 Acc: 0.9975
val Loss: 0.0507 Acc: 0.9375
Epoch 8/9
train Loss: 0.0088 Acc: 0.9979
val Loss: 0.0313 Acc: 1.0000
Epoch 9/9
train Loss: 0.0058 Acc: 0.9990
val Loss: 0.0987 Acc: 0.9375
Training complete in 18m 11s
Best validation accuracy: 1.0000
Testing...
Test Acc: 0.8205
Test accuracy: 0.8205
Calculating confusion matrix and ROC curve...

ConvNet as fixed feature extractor...
Epoch 0/9




train Loss: 0.3503 Acc: 0.8606
val Loss: 0.3984 Acc: 0.8750
Epoch 1/9
train Loss: 0.3002 Acc: 0.8840
val Loss: 0.3280 Acc: 0.9375
Epoch 2/9
train Loss: 0.2956 Acc: 0.8880
val Loss: 0.2509 Acc: 0.8750
Epoch 3/9
train Loss: 0.3232 Acc: 0.8806
val Loss: 0.5215 Acc: 0.8125
Epoch 4/9
train Loss: 0.3447 Acc: 0.8692
val Loss: 0.4470 Acc: 0.8750
Epoch 5/9
train Loss: 0.2877 Acc: 0.8967
val Loss: 0.2164 Acc: 0.8750
Epoch 6/9
train Loss: 0.2338 Acc: 0.9135
val Loss: 0.2559 Acc: 0.8750
Epoch 7/9
train Loss: 0.2351 Acc: 0.9091
val Loss: 0.4723 Acc: 0.8125
Epoch 8/9
train Loss: 0.2159 Acc: 0.9160
val Loss: 0.3105 Acc: 0.7500
Epoch 9/9
train Loss: 0.2179 Acc: 0.9166
val Loss: 0.5350 Acc: 0.7500
Training complete in 16m 56s
Best validation accuracy: 0.9375
Testing...
Test Acc: 0.8702
Test accuracy: 0.8702
Calculating confusion matrix and ROC curve...
