In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import pandas as pd
import pydicom as dcm
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from torchvision.transforms import Compose, ToTensor, Normalize, RandomAffine, RandomResizedCrop
import torchmetrics
import importlib
import pnm.preprocess as preproc
importlib.reload(preproc)
import seaborn as sns

# Take a look at the x-ray images

In [None]:
raw_dir = 'data/stage_2_train_images'
label_path = 'data/stage_2_train_labels.csv'

num_images = 9
labels = pd.read_csv(label_path)
# os.listdir returns entries in arbitrary order; sorting makes the preview
# show the same images on every run.
dicom_files = sorted(file for file in os.listdir(raw_dir) if file.endswith('.dcm'))

# Plot the pixel arrays of the first `num_images` DICOM files on a 3x3 grid.
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for ax in axes.flat:
    # Hide every axis up front so unused panels stay blank when fewer than
    # nine images are available.
    ax.axis('off')

# Slicing (instead of indexing by position) avoids an IndexError when the
# directory holds fewer than `num_images` files.
for ax, dicom_file in zip(axes.flat, dicom_files[:num_images]):
    patient_id = os.path.splitext(dicom_file)[0]
    # Take the first matching row's 'Target' value as the image label.
    label = labels.loc[labels['patientId'] == patient_id, 'Target'].iloc[0]
    ds = dcm.dcmread(os.path.join(raw_dir, dicom_file))
    ax.imshow(ds.pixel_array, cmap='bone')
    ax.set_title(f'Label: {label}')

plt.tight_layout()
plt.show()


# Parameters

In [None]:
# Tunable settings for preprocessing and data loading.
# `raw_dir` and `label_path` are already defined in the image-preview cell
# above; the former self-assignments (`raw_dir = raw_dir`) were no-ops and
# have been dropped.
shape = (224, 224)             # image size passed to the preprocessing helpers
preproc_dir = 'preprocessed'   # output directory of the preprocessing step
batch_size = 64                # batch size used by both data loaders
num_workers = 4                # worker processes per DataLoader

# Preprocessing

In [None]:
# The preprocessing output directory doubles as a cache marker: the raw
# images are only converted when that directory does not exist yet.
if not Path(preproc_dir).exists():
    preproc.preprocess(raw_dir, label_path, preproc_dir, shape)

# Unpacked into `Normalize(...)` further down.
# NOTE(review): presumably a (mean, std) pair -- confirm in pnm.preprocess.
standard_params = preproc.compute_standard_params(preproc_dir, shape)


# Construct data loader

In [None]:
def load_img(file_name):
    """Read a ``.npy`` file and return its contents as a float32 array.

    Args:
        file_name: Path of the ``.npy`` file to load.

    Returns:
        numpy.ndarray with dtype ``np.float32``.
    """
    # NOTE: np.float16 input breaks later stages (RandomAffine,
    # Trainer.fit()), hence the up-cast to float32 here.
    stored = np.load(file_name)
    return stored.astype(np.float32)

In [None]:
# Training pipeline: tensor conversion and normalisation followed by light
# random augmentation (small rotation/shift/zoom, then a random crop that is
# resized back to the configured image size).
train_transforms = Compose([
    ToTensor(),
    Normalize(*standard_params),
    RandomAffine(degrees=5, translate=(0, 0.05), scale=(0.9, 1.1)),
    # Use the shared `shape` parameter instead of a hard-coded 224 so the
    # training images keep the same size as the validation images if
    # `shape` is ever changed.
    RandomResizedCrop(shape, scale=(0.35, 1.)),
])
# Validation pipeline: same conversion and normalisation, no augmentation.
val_transforms = Compose([ToTensor(), Normalize(*standard_params)])

In [None]:
# Each split lives in its own sub-folder of `preproc_dir`; DatasetFolder
# derives the class targets from the directory layout below that folder.
train_data = DatasetFolder(
    root=os.path.join(preproc_dir, 'train'),
    loader=load_img,
    extensions='.npy',
    transform=train_transforms,
)
val_data = DatasetFolder(
    root=os.path.join(preproc_dir, 'val'),
    loader=load_img,
    extensions='.npy',
    transform=val_transforms,
)

In [None]:
# Shuffle only the training data.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# The validation loader must keep dataset order: predictions collected from
# it are later indexed with the same positions as `val_data`, which only
# lines up when the loader does not shuffle.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)


In [None]:
# Class balance per split: (unique targets, counts) for train, then val.
tuple(np.unique(split.targets, return_counts=True) for split in (train_data, val_data))

In [None]:
# Sanity check of the augmentation: show three randomly picked images
# (channel 0) from the first training batch together with their labels.
for imgs, targets in train_loader:
    for col in range(1, 4):
        rand_idx = np.random.randint(imgs.shape[0])
        plt.subplot(1, 3, col)
        plt.title(f'label: {targets[rand_idx]}')
        plt.imshow(imgs[rand_idx, 0, :, :], cmap='bone')
        plt.axis('off')
    # Only the first batch is needed.
    break

# Create and train model

In [None]:
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.models import resnet18
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from torch import tensor, sigmoid
from torch.nn import BCEWithLogitsLoss, Conv2d, Linear
from torch.optim import Adam

In [None]:
class PneumoniaClassifier(LightningModule):
    """Binary classifier built on a pretrained ResNet-18 backbone.

    The backbone weights are frozen; the first convolution (replaced to
    accept single-channel input) and the final linear layer (replaced to
    emit one logit) are created after freezing and therefore stay trainable.

    Args:
        weight: Positive-class weight passed to ``BCEWithLogitsLoss`` as
            ``pos_weight``.
        metrics: Optional mapping ``name -> torchmetrics metric``. Defaults
            to binary accuracy. Each metric is cloned once for training and
            once for validation so the two stages never share state.
    """

    def __init__(self, weight=1, metrics=None):
        super().__init__()

        self.model = resnet18(pretrained=True)
        # Freeze the pretrained weights
        for param in self.model.parameters():
            param.requires_grad = False

        # Replacement layers are created after freezing, so they are trainable.
        self.model.conv1 = Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # TODO: open question whether the head should have a bias
        # (the course code uses bias=True).
        self.model.fc = Linear(512, 1, bias=False)

        metrics = {'acc': Accuracy('binary')} if metrics is None else metrics
        # Register metrics as submodules (ModuleDict) so they follow the
        # model to its device, and keep independent copies per stage:
        # validation runs inside the training epoch and would otherwise
        # pollute the training metric state.
        self.train_metrics = torch.nn.ModuleDict(
            {name: metric.clone() for name, metric in metrics.items()})
        self.val_metrics = torch.nn.ModuleDict(
            {name: metric.clone() for name, metric in metrics.items()})
        # Explicit float dtype: `tensor(1)` would create an integer tensor.
        self.loss_fn = BCEWithLogitsLoss(
            pos_weight=torch.as_tensor(weight, dtype=torch.float32))

    def forward(self, x):
        """Return the raw logits of the wrapped ResNet for input batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage, metrics):
        """Compute the loss for one batch and log loss and per-batch metrics.

        Logged keys are ``{stage}_loss`` and ``{stage}_{metric name}``.
        """
        img, label = batch
        logit = self(img)[:, 0]
        loss = self.loss_fn(logit, label.float())

        self.log(f'{stage}_loss', loss)
        # The probabilities are identical for every metric; compute them once.
        prob = sigmoid(logit)
        for name, metric in metrics.items():
            self.log(f'{stage}_{name}', metric(prob, label.int()), prog_bar=True)
        return loss

    def _log_epoch_metrics(self, stage, metrics):
        """Log each metric accumulated over the epoch, then reset it."""
        for name, metric in metrics.items():
            self.log(f'batch_{stage}_{name}', metric.compute())
            metric.reset()

    def training_step(self, batch, batch_idx):
        """Lightning hook: loss (and logged metrics) for one training batch."""
        return self._shared_step(batch, 'train', self.train_metrics)

    def on_train_epoch_end(self):
        """Lightning hook: log and reset the epoch-level training metrics."""
        self._log_epoch_metrics('train', self.train_metrics)

    def on_training_epoch_end(self):
        """Backward-compatible alias of ``on_train_epoch_end``.

        This was the original method name; Lightning never invokes it
        because the hook is called ``on_train_epoch_end``.
        """
        self.on_train_epoch_end()

    def validation_step(self, batch, batch_idx):
        """Lightning hook: loss (and logged metrics) for one validation batch."""
        return self._shared_step(batch, 'val', self.val_metrics)

    def on_validation_epoch_end(self):
        """Lightning hook: log and reset the epoch-level validation metrics."""
        self._log_epoch_metrics('val', self.val_metrics)

    def configure_optimizers(self):
        """Lightning hook: Adam over the wrapped model's parameters (lr=1e-3)."""
        return Adam(self.model.parameters(), lr=1e-3)

In [None]:
# Instantiate the classifier with its defaults (weight=1, binary accuracy metric).
pnm_model = PneumoniaClassifier()

In [None]:
# Keep up to ten checkpoints, ranked by lowest logged 'val_loss'.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=10)

In [None]:
# Train for five epochs, logging every step to TensorBoard under ./logs and
# checkpointing through `checkpoint_callback`.
trainer = Trainer(
    max_epochs=5,
    log_every_n_steps=1,
    logger=TensorBoardLogger(save_dir="./logs"),
    callbacks=[checkpoint_callback],
)
trainer.fit(pnm_model, train_loader, val_loader)

# Save model

In [None]:
# Weights are written to ./model under a name carrying the current
# timestamp (YYYYmmddHHMMSS), so successive runs do not overwrite each other.
model_dir = Path('model')
current_date = format(datetime.datetime.now(), "%Y%m%d%H%M%S")
file_name = 'resnet_' + current_date + '.pth'

In [None]:
# Create the target directory if needed, then persist only the weights
# (state_dict). The bare name `save` was never imported; the function lives
# in the `torch` namespace.
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(pnm_model.state_dict(), model_dir / file_name)


# Load model

In [None]:
# Pick the device first so the checkpoint can be mapped straight onto it
# (a checkpoint saved on GPU would otherwise fail to load on a CPU-only host).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pnm_model = PneumoniaClassifier()
pnm_model.load_state_dict(torch.load(model_dir / file_name, map_location=device))
# A freshly constructed module is in training mode; switch to eval mode so
# layers such as batch norm behave deterministically during evaluation.
pnm_model.to(device).eval()

# Evaluation

In [None]:
def preds_n_labels(model, loader):
    """Run ``model`` over every batch of ``loader`` and collect the results.

    The model is switched to eval mode (and left there) so that layers with
    train-time behaviour, e.g. dropout or batch norm, do not distort the
    predictions. Batches are moved to the device the model's parameters
    live on, so no global ``device`` variable is required.

    Args:
        model: Module returning logits of shape ``(batch, 1)``; column 0 is
            taken as the logit of each sample.
        loader: Iterable yielding ``(img, label)`` batches.

    Returns:
        Tuple ``(preds, labels)`` of 1-D CPU tensors: sigmoid probabilities
        and the corresponding labels, in iteration order.
    """
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for img, label in tqdm(loader):
            logit = model(img.to(model_device))[:, 0]
            preds.append(sigmoid(logit).cpu())
            labels.append(label.cpu())

    # torch.cat rejects an empty list; mirror the original's empty result.
    if not preds:
        return torch.tensor([]), torch.tensor([])
    return torch.cat(preds), torch.cat(labels)

In [None]:
# Collect predicted probabilities and ground-truth labels for both splits.
train_preds, train_labels = preds_n_labels(model=pnm_model, loader=train_loader)
val_preds, val_labels = preds_n_labels(model=pnm_model, loader=val_loader)

In [None]:
def evaluate(preds, labels, threshold=0.25):
    """Print classification metrics and plot confusion matrices.

    Accuracy, precision, recall and the confusion matrix are each computed
    twice: at the standard 0.5 decision threshold and at ``threshold``.
    AUROC is threshold-independent and reported once.

    Args:
        preds: 1-D tensor of predicted probabilities.
        labels: 1-D tensor of ground-truth binary labels.
        threshold: Alternative decision threshold to compare against 0.5.
            Defaults to 0.25, the value that was previously hard-coded.
    """
    default_threshold = 0.5
    thresholds = (default_threshold, threshold)

    def metric_at(metric_cls, thresh):
        # Build a fresh binary metric at `thresh` and evaluate it once.
        return metric_cls('binary', threshold=thresh)(preds, labels)

    acc, acc_thresh = (metric_at(torchmetrics.Accuracy, t) for t in thresholds)
    precision, precision_thresh = (metric_at(torchmetrics.Precision, t) for t in thresholds)
    recall, recall_thresh = (metric_at(torchmetrics.Recall, t) for t in thresholds)
    auc = torchmetrics.AUROC('binary')(preds, labels)
    confusion_matrices = [metric_at(torchmetrics.ConfusionMatrix, t) for t in thresholds]

    print(f"Accuracy: {acc:.4f} vs. {acc_thresh:.4f} (threshold={threshold})")
    print(f"Precision: {precision:.4f} vs. {precision_thresh:.4f} (threshold={threshold})")
    print(f"Recall: {recall:.4f} vs. {recall_thresh:.4f} (threshold={threshold})")
    print(f"AUC: {auc:.4f}")

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, cm, thresh in zip(axes, confusion_matrices, thresholds):
        # Hand seaborn a plain integer array rather than a torch tensor.
        sns.heatmap(cm.cpu().numpy(), annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_title(f'Confusion Matrix (threshold={thresh})')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')

    plt.tight_layout()
    plt.show()


In [None]:
# Report metrics and confusion matrices for the training split, then for the
# validation split.
evaluate(train_preds, train_labels)
evaluate(val_preds, val_labels)

In [None]:
def plot_roc(train_preds, val_preds, train_labels, val_labels):
    """Plot the training and validation ROC curves on one figure.

    The AUROC of each split is shown in the legend. (The original computed a
    single AUC from the names ``preds``/``labels``, which are not parameters
    of this function, so the call failed or used unrelated globals.)

    Args:
        train_preds: 1-D tensor of predicted probabilities (training split).
        val_preds: 1-D tensor of predicted probabilities (validation split).
        train_labels: 1-D tensor of binary labels (training split).
        val_labels: 1-D tensor of binary labels (validation split).
    """
    splits = (
        ('Training', train_preds, train_labels),
        ('Validation', val_preds, val_labels),
    )
    for split_name, split_preds, split_labels in splits:
        fpr, tpr, _ = torchmetrics.ROC('binary')(split_preds, split_labels)
        split_auc = torchmetrics.AUROC('binary')(split_preds, split_labels)
        plt.plot(fpr, tpr, label=f'{split_name} (AUC: {split_auc:.2f})')

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend()
    plt.show()


In [None]:
# Compare the ROC curves of the training and validation splits.
plot_roc(train_preds, val_preds, train_labels, val_labels)

In [None]:
# Show nine random validation images with the model's decision (at the 0.25
# threshold) and the true label.
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
random_indices = np.random.choice(len(val_data), size=9, replace=False)

# Predict directly on the image being displayed instead of indexing into
# `val_preds`: those were collected through a data loader and only line up
# with dataset indices when that loader does not shuffle.
pnm_model.eval()
with torch.no_grad():
    for ax, i in zip(axes.flatten(), random_indices):
        img, true_label = val_data[i]
        prob = sigmoid(pnm_model(img.unsqueeze(0).to(device))[:, 0]).item()
        ax.imshow(img[0], cmap='bone')
        ax.set_title(f"Prediction: {int(prob > 0.25)}, True Label: {true_label}")
        ax.axis('off')

plt.tight_layout()
plt.show()
