In [10]:
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import wandb

from glob import glob
from random import seed, shuffle
from pytorch_warmup_scheduler import WarmupScheduler
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms,models
from tqdm import tqdm

In [11]:
# Device Setup: run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


Training Configs

In [12]:
train_path = '/datasets/chest_xray/train/'
val_path = '/datasets/chest_xray/val/'

# Seed Python's `random` (used to shuffle the dataset file list) and PyTorch
# (weight initialisation of the new head, random augmentations) so that a
# Restart & Run All reproduces the same run.
SEED = 42
seed(SEED)
torch.manual_seed(SEED)

WARMUP_EPOCHS = 3
NUM_EPOCHS = 12
LEARNING_RATE = 1e-5 # initial lr
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 4
IMAGE_SHAPE = (224, 224)

# Per-channel mean / std handed to Normalize. These are the standard ImageNet
# statistics, matching the ImageNet-pretrained ResNet weights loaded below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: tensor conversion, resize, normalisation, then a random
# horizontal flip as the only augmentation.
TRAIN_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SHAPE),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    transforms.RandomHorizontalFlip()
])
# Validation / test pipeline: identical preprocessing without augmentation.
VAL_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SHAPE),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])
# 0.0 disables label smoothing (targets stay exactly 0 / 1).
LABEL_SMOOTHING = 0.0

# Create logging tools
# wandb.init(
#     project = "pneumonia-detection-finetune",

#     config = {
#         'learning_rate': LEARNING_RATE,
#         'weight_decay': WEIGHT_DECAY,
#         'epochs': NUM_EPOCHS,
#         'warmup_epochs': WARMUP_EPOCHS,
#         'batch_size': BATCH_SIZE,
#         'image_size': IMAGE_SHAPE,
#         'transform': TRAIN_TRANSFORM,
#         'label_smoothing': LABEL_SMOOTHING,
#         'optimizer': 'Adam',
#         'architecture': "Resnet18",
#         'dataset': 'chest_xray'
#     },
    
#     settings = wandb.Settings(disable_job_creation=True)
# )

Custom dataset object to use with torch.utils.data.DataLoader

In [13]:
def binary_label_smoothing(val, alpha):
    '''Pull a binary target towards 0.5 by the smoothing factor ``alpha``.

    Args:
        val: the hard target (0 or 1).
        alpha: smoothing strength; 0 leaves ``val`` unchanged, 1 maps every
            target to 0.5.

    Returns:
        ``(1 - alpha) * val + alpha / 2``.
    '''
    keep_fraction = 1 - alpha
    uniform_share = alpha / 2
    return keep_fraction * val + uniform_share

class PneumoniaDataset(Dataset):
    '''Chest x-ray dataset with a binary NORMAL / PNEUMONIA label per image.

    Every ``*.jpeg`` file under ``<data_path>/NORMAL`` and
    ``<data_path>/PNEUMONIA`` is paired with its (optionally smoothed) label.
    The combined list is shuffled once at construction time. Indexing returns
    ``(image, label)`` where the image is loaded as RGB and passed through
    ``transform`` when one is given.
    '''
    def __init__(self, data_path, transform=None, label_smoothing=0.0):
        '''Index the image files below ``data_path``.

        Args:
            data_path: directory containing ``NORMAL`` and ``PNEUMONIA`` folders.
            transform: optional callable applied to each loaded RGB image.
            label_smoothing: smoothing factor forwarded to
                ``binary_label_smoothing`` for both class labels.

        Raises:
            FileNotFoundError: if no ``*.jpeg`` files are found.
        '''
        self.data = data_path
        self.transform = transform

        normal_val = binary_label_smoothing(0, label_smoothing)
        pneumonia_val = binary_label_smoothing(1, label_smoothing)

        # glob() returns files in a filesystem-dependent order; sorting first
        # makes the shuffle below reproducible when the RNG is seeded.
        normal = [(path, normal_val) for path in sorted(glob(data_path+'/NORMAL/*.jpeg'))]
        pneumonia = [(path, pneumonia_val) for path in sorted(glob(data_path+'/PNEUMONIA/*.jpeg'))]
        self.paths_with_labels = normal + pneumonia
        if not self.paths_with_labels:
            # Fail here with a clear message instead of later inside the
            # training loop with an unrelated-looking error.
            raise FileNotFoundError(
                f'No .jpeg images found under {data_path!r} (expected NORMAL/ and PNEUMONIA/ subfolders)'
            )
        shuffle(self.paths_with_labels)

    def __len__(self):
        '''Number of indexed images.'''
        return len(self.paths_with_labels)

    def __getitem__(self, index):
        '''Load image ``index`` as RGB, apply the transform, and return it with its label.

        Raises:
            IOError: if the image file cannot be read or decoded.
        '''
        path, label = self.paths_with_labels[index]
        im = cv2.imread(path)
        if im is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError(f'Could not read image file: {path}')
        # OpenCV loads images as BGR; convert to the RGB channel order.
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        if self.transform:
            im = self.transform(im)
        return im, label

Instantiate data and model

In [14]:
train_data = PneumoniaDataset(train_path, TRAIN_TRANSFORM, label_smoothing=LABEL_SMOOTHING)
val_data = PneumoniaDataset(val_path, VAL_TRANSFORM, label_smoothing=LABEL_SMOOTHING)
# shuffle=True draws a fresh sample order every epoch (the dataset itself is
# only shuffled once, at construction). drop_last=True discards a trailing
# partial batch so every batch holds exactly BATCH_SIZE samples.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE, drop_last=True)

# Set weights to None for random weights and 'DEFAULT' for ImageNet pretrain weights.
model = models.resnet18(weights='DEFAULT')
# Replace the classification layer with a small head that ends in a sigmoid,
# producing one probability per image (paired with BCELoss below).
model.fc = nn.Sequential(nn.Flatten(),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid()
)
model = model.to(device)
loss_fn = torch.nn.BCELoss()
# Use the configured WEIGHT_DECAY rather than a duplicated literal so the
# config cell stays the single source of truth.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# NOTE(review): train_epoch calls warmup_scheduler.step() once per batch, so
# confirm whether WarmupScheduler counts its second argument in epochs or steps.
warmup_scheduler = WarmupScheduler(optimizer, WARMUP_EPOCHS)

Define training loop with appropriate loss/accuracy logging (Weights & Biases)

In [15]:
def num_correct(preds, labels):
    '''Count predictions that fall on the same side of 0.5 as their label.

    Both predictions and labels are thresholded at 0.5, so the result is the
    same for hard (0 / 1) and smoothed targets and does not depend on the
    global smoothing setting. A prediction of exactly 0.5 is treated as the
    negative class.

    Args:
        preds: predicted probabilities, any shape (flattened internally).
        labels: target values with the same number of elements as ``preds``.

    Returns:
        The number of matching predictions as a Python ``int``.
    '''
    # detach(): counting does not need gradients. reshape(-1): accept (B,) and (B, 1).
    preds = torch.as_tensor(preds).detach().reshape(-1)
    labels = torch.as_tensor(labels).to(preds.device).reshape(-1)
    predicted_positive = preds > 0.5
    actual_positive = labels > 0.5
    # One vectorised comparison instead of a Python loop with a device
    # synchronisation per element.
    return int((predicted_positive == actual_positive).sum().item())


def train_epoch(model, optimizer, loss_fn, train_dataloader, epoch):
    '''Run one optimisation pass over ``train_dataloader``.

    Relies on the notebook-level ``device``, ``NUM_EPOCHS`` and
    ``warmup_scheduler`` (stepped once per batch, after the optimizer).

    Args:
        model: network being trained; expected to emit one value per sample.
        optimizer: optimizer updating ``model``'s parameters.
        loss_fn: loss taking ``(preds, labels)``; assumed to average over the
            batch (the default reduction of ``BCELoss``).
        train_dataloader: iterable of ``(images, labels)`` batches.
        epoch: zero-based epoch index, used only for the progress bar label.

    Returns:
        ``(avg_loss, accuracy)`` averaged per sample over the epoch;
        ``(0.0, 0.0)`` if the dataloader yields no batches.
    '''
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device).float()
        batch_size = labels.shape[0]

        # Forward pass; reshape(-1) flattens (B, 1) -> (B,) for any batch size.
        preds = model(images).reshape(-1)

        # Calculate loss and accuracy
        batch_loss = loss_fn(preds, labels)
        total_correct += num_correct(preds.detach(), labels)
        # .item() stores a plain float so no autograd history is retained
        # across batches. Weighting by batch size turns the per-batch mean
        # back into a per-sample sum.
        total_loss += batch_loss.item() * batch_size
        total_samples += batch_size

        # Backprop
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        warmup_scheduler.step()

    if total_samples == 0:
        return 0.0, 0.0

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    return avg_loss, accuracy


def val_epoch(model, loss_fn, val_dataloader, epoch):
    '''Evaluate ``model`` on ``val_dataloader`` without updating it.

    Relies on the notebook-level ``device`` and ``NUM_EPOCHS``.

    Args:
        model: network to evaluate; expected to emit one value per sample.
        loss_fn: loss taking ``(preds, labels)``; assumed to average over the
            batch (the default reduction of ``BCELoss``).
        val_dataloader: iterable of ``(images, labels)`` batches.
        epoch: zero-based epoch index, used only for the progress bar label.

    Returns:
        ``(avg_loss, accuracy)`` averaged per sample; ``(0.0, 0.0)`` if the
        dataloader yields no batches.
    '''
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    pbar = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    # no_grad: evaluation needs no autograd graph, which saves memory and time.
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device).float()
            batch_size = labels.shape[0]

            # Forward pass; reshape(-1) flattens (B, 1) -> (B,) for any batch size.
            preds = model(images).reshape(-1)

            # Calculate loss and accuracy
            total_correct += num_correct(preds, labels)
            batch_loss = loss_fn(preds, labels)
            # Weight the per-batch mean by batch size to get a per-sample sum.
            total_loss += batch_loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        return 0.0, 0.0

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    return avg_loss, accuracy
        
def train(model, optimizer, loss_fn, num_epochs):
    '''Train and validate for ``num_epochs`` epochs, then save a checkpoint.

    Uses the notebook-level ``train_dataloader`` and ``val_dataloader``.
    Per-epoch metrics are printed, and logged to Weights & Biases when a run
    is active. The final model and optimizer state are written to
    ``./checkpoints/final_finetune.pt``.

    Args:
        model: network to train.
        optimizer: optimizer updating ``model``'s parameters.
        loss_fn: loss passed through to the epoch functions.
        num_epochs: number of epochs to run.
    '''
    import os  # local import: only needed to create the checkpoint directory

    for epoch in range(num_epochs):
        # Training
        train_loss, train_acc = train_epoch(model, optimizer, loss_fn, train_dataloader, epoch)
        val_loss, val_acc = val_epoch(model, loss_fn, val_dataloader, epoch)
        info = {'train_acc': train_acc,
                'train_loss': train_loss,
                'val_acc': val_acc,
                'val_loss': val_loss}
        # wandb.log fails when no run has been started (wandb.init is
        # commented out in the config cell), so only log to an active run.
        if wandb.run is not None:
            wandb.log(info)
        print(info)
        print()
    if wandb.run is not None:
        wandb.finish()

    save_path = './checkpoints/final_finetune.pt'
    # Create the directory first so a finished training run is not lost to a
    # missing-folder error at save time.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, save_path)



Run training

In [16]:
# Fine-tune for NUM_EPOCHS epochs. train() prints the per-epoch metrics and
# writes the final checkpoint to ./checkpoints/final_finetune.pt.
train(model, optimizer, loss_fn, NUM_EPOCHS)

Evaluate trained model on test dataset

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

test_path = '/datasets/chest_xray/test/'
# Use the same label encoding as training/validation so accuracy is computed
# consistently across all three splits.
test_data = PneumoniaDataset(test_path, VAL_TRANSFORM, label_smoothing=LABEL_SMOOTHING)
# Keep every test sample: no drop_last, so the reported accuracy covers the
# whole test set even when its size is not a multiple of BATCH_SIZE.
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

# Rebuild the architecture used for training; the weights come from the checkpoint.
model = models.resnet18(weights=None)
model.fc = nn.Sequential(nn.Flatten(),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid()
)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load('./checkpoints/final_finetune.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

total_correct = 0
total_samples = 0
# no_grad: inference only, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        images = images.to(device)
        labels = labels.to(device).float()

        # Forward pass; reshape(-1) flattens (B, 1) -> (B,) for any batch size.
        preds = model(images).reshape(-1)

        # Calculate accuracy
        total_correct += num_correct(preds, labels)
        total_samples += labels.shape[0]

# max(..., 1) avoids a ZeroDivisionError if the test set is empty.
accuracy = total_correct / max(total_samples, 1)
print(f"Accuracy on test set: {accuracy:.2%}")