In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from templates import *
from templates_cls import *
import monai.transforms as T
from monai.utils.misc import first

import matplotlib.pyplot as plt

from monai.networks.nets import DenseNet121
from sklearn.metrics import mean_absolute_error, f1_score

In [11]:
batch_size = 32

train_spider_files = create_spider_files()
train_transforms = T.Compose([
        # each image has a single ivd label
        T.LoadImaged(keys=['image', 'mask'], ensure_channel_first=True),
        T.Orientationd(keys=['image', 'mask'], axcodes='RAS'),
        # median dataset spacing
        T.Spacingd(keys=['image', 'mask'],pixdim=(3.32, 0.625, 0.625), mode=("bilinear", "nearest")),
        T.ScaleIntensityRangePercentilesd(keys='image', lower=0, upper=99.5, b_min=0, b_max=1),
        # remove other labels from mask
        CropMaskByLabel(mask_key='mask', label_key='ivd_label', label_lambda_func=lambda x: x + 200),
        # some augmentations
        #T.RandGaussianNoised(keys=['image'], mean=0.0, std=0.015, prob=prob),
        #T.RandRotated(keys=['image', 'mask'], range_x=30 * (np.pi / 180), mode=["bilinear", "nearest"], prob=prob),
        # center and crop image around ivd
        T.CropForegroundd(keys=['image', 'mask'], source_key='mask', margin=(0, 80, 80), allow_smaller=False),
        T.CenterSpatialCropd(keys=['image', 'mask'], roi_size=(-1, 80, 80)),
        # get a single slice
        T.CenterSpatialCropd(keys=['image', 'mask'], roi_size=(1, -1, -1)),
        #T.RandSpatialCropd(keys=['image', 'mask'], roi_size=(1, -1, -1)),
        AssertEmptyImaged(),
        # resize
        T.Resized(keys=['image', 'mask'], spatial_size=(1, 64, 64), anti_aliasing=True),
        T.ToTensord(keys=['image', 'mask']),
        T.SqueezeDimd(keys=['image', 'mask'], dim=1), 
])

dataset_train = data.CacheDataset(train_spider_files, transform=train_transforms)
train_loader = DataLoader(dataset_train, batch_size, shuffle=False, num_workers=8, pin_memory=True)

Loading dataset: 100%|██████████| 1185/1185 [15:29<00:00,  1.28it/s]


In [4]:
val_spider_files = create_spider_files(split="validation")
batch_size = 32
val_transforms = T.Compose([
        # each image has a single ivd label
        T.LoadImaged(keys=['image', 'mask'], ensure_channel_first=True),
        T.Orientationd(keys=['image', 'mask'], axcodes='RAS'),
        # median dataset spacing
        T.Spacingd(keys=['image', 'mask'],pixdim=(3.32, 0.625, 0.625), mode=("bilinear", "nearest")),
        T.ScaleIntensityRangePercentilesd(keys='image', lower=0, upper=99.5, b_min=0, b_max=1),
        # remove other labels from mask
        CropMaskByLabel(mask_key='mask', label_key='ivd_label', label_lambda_func=lambda x: x + 200),
        # center and crop image around ivd
        T.CropForegroundd(keys=['image', 'mask'], source_key='mask', margin=(0, 80, 80), allow_smaller=False),
        T.CenterSpatialCropd(keys=['image', 'mask'], roi_size=(-1, 80, 80)),
        # get a single slice
        T.CenterSpatialCropd(keys=['image', 'mask'], roi_size=(1, -1, -1)),
        AssertEmptyImaged(),
        # resize
        T.Resized(keys=['image', 'mask'], spatial_size=(1, 64, 64), anti_aliasing=True),
        T.ToTensord(keys=['image', 'mask']),
        T.SqueezeDimd(keys=['image', 'mask'], dim=1), 
])

dataset_val = data.CacheDataset(val_spider_files, transform=val_transforms)

val_loader = DataLoader(dataset_val, batch_size, shuffle=False, num_workers=8, pin_memory=True)

Missing grade, skipping IVD


Loading dataset: 100%|██████████| 261/261 [02:00<00:00,  2.17it/s]


In [None]:
test_spider_files = create_spider_files(split="test")
batch_size = 32
val_transforms = T.Compose([
        # each image has a single ivd label
        T.LoadImaged(keys=['image', 'mask'], ensure_channel_first=True),
        T.Orientationd(keys=['image', 'mask'], axcodes='RAS'),
        # median dataset spacing
        T.Spacingd(keys=['image', 'mask'],pixdim=(3.32, 0.625, 0.625), mode=("bilinear", "nearest")),
        T.ScaleIntensityRangePercentilesd(keys='image', lower=0, upper=99.5, b_min=0, b_max=1),
        # remove other labels from mask
        CropMaskByLabel(mask_key='mask', label_key='ivd_label', label_lambda_func=lambda x: x + 200),
        # center and crop image around ivd
        T.CropForegroundd(keys=['image', 'mask'], source_key='mask', margin=(0, 80, 80), allow_smaller=False),
        T.CenterSpatialCropd(keys=['image', 'mask'], roi_size=(-1, 80, 80)),
        # get a single slice
        T.CenterSpatialCropd(keys=['image', 'mask'], roi_size=(1, -1, -1)),
        AssertEmptyImaged(),
        # resize
        T.Resized(keys=['image', 'mask'], spatial_size=(1, 64, 64), anti_aliasing=True),
        T.ToTensord(keys=['image', 'mask']),
        T.SqueezeDimd(keys=['image', 'mask'], dim=1), 
])

dataset_val = data.CacheDataset(val_spider_files, transform=val_transforms)

val_loader = DataLoader(dataset_val, batch_size, shuffle=False, num_workers=8, pin_memory=True)

# Regression

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import DenseNet121
from sklearn.metrics import mean_absolute_error
from torch.utils.data import DataLoader

# Model setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1).to(device)
loss_function = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Learning Rate Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

# Early Stopping Setup
early_stopping_patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

max_epochs = 100
for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch_data in train_loader:
        images, labels = batch_data['image'].to(device), batch_data["pfirrman_grade"].float().to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels.unsqueeze(1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Learning rate scheduler step
    scheduler.step()

    # Validation loop
    model.eval()
    val_loss = 0.0
    total_preds = []
    total_labels = []
    with torch.no_grad():
        for batch_data in val_loader:
            images, labels = batch_data['image'].to(device), batch_data['pfirrman_grade'].float().to(device)
            outputs = model(images)
            loss = loss_function(outputs, labels.unsqueeze(1))
            val_loss += loss.item()

            preds = outputs.round()
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader)
    val_mae = mean_absolute_error(total_labels, total_preds)

    print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}')

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == early_stopping_patience:
            print('Early stopping!')
            break


Epoch 1/100, Train Loss: 1.7215, Val Loss: 1.9843, Val MAE: 2.0728
Epoch 2/100, Train Loss: 1.5297, Val Loss: 1.8285, Val MAE: 1.7625
Epoch 3/100, Train Loss: 1.3695, Val Loss: 1.7144, Val MAE: 1.5977
Epoch 4/100, Train Loss: 1.2271, Val Loss: 1.5794, Val MAE: 1.4866
Epoch 5/100, Train Loss: 1.0990, Val Loss: 1.4542, Val MAE: 1.4176
Epoch 6/100, Train Loss: 0.9809, Val Loss: 1.3546, Val MAE: 1.3678
Epoch 7/100, Train Loss: 0.8794, Val Loss: 1.2732, Val MAE: 1.2644
Epoch 8/100, Train Loss: 0.7857, Val Loss: 1.1813, Val MAE: 1.1648
Epoch 9/100, Train Loss: 0.7033, Val Loss: 1.1009, Val MAE: 1.0920
Epoch 10/100, Train Loss: 0.6310, Val Loss: 1.0574, Val MAE: 1.0575
Epoch 11/100, Train Loss: 0.5675, Val Loss: 1.0387, Val MAE: 1.0192
Epoch 12/100, Train Loss: 0.5126, Val Loss: 0.9766, Val MAE: 0.9579
Epoch 13/100, Train Loss: 0.4665, Val Loss: 0.9199, Val MAE: 0.8966
Epoch 14/100, Train Loss: 0.4256, Val Loss: 0.9255, Val MAE: 0.9042
Epoch 15/100, Train Loss: 0.3949, Val Loss: 0.9323, Val M

In [21]:
# Save the model state dictionary
torch.save(model.state_dict(), "spider_resnet_reg.pt")

Evaluation

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load("resent/spider_resnet_reg.pt"))
model.eval()

total_preds = []
total_labels = []
with torch.no_grad():
    for batch_data in val_loader:
        images, labels = batch_data['image'].to(device), batch_data['pfirrman_grade'].float().to(device)
        preds = model(images)
        total_preds.extend(preds.cpu().numpy())
        total_labels.extend(labels.cpu().numpy())

val_mae = mean_absolute_error(total_labels, total_preds)
val_f1 = f1_score(total_labels, np.round(total_preds), average="macro")

print("MAE: ", val_mae)
print("F1: ", val_f1)

MAE:  0.8309463
F1:  0.3000367846158562


# Classification

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import DenseNet121
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader

# Model setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1).to(device)
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Learning Rate Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

# Early Stopping Setup
early_stopping_patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

max_epochs = 100
for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch_data in train_loader:
        images, labels = batch_data['image'].to(device), batch_data["pfirrman_grade"].to(device)
        labels = (labels > 0)
        labels = labels.float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels.unsqueeze(1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Learning rate scheduler step
    scheduler.step()

    # Validation loop
    model.eval()
    val_loss = 0.0
    total_preds = []
    total_labels = []
    with torch.no_grad():
        for batch_data in val_loader:
            images, labels = batch_data['image'].to(device), batch_data['pfirrman_grade'].to(device)
            labels = (labels > 0)
            labels = labels.float()
            outputs = model(images)
            loss = loss_function(outputs, labels.unsqueeze(1))
            val_loss += loss.item()

            preds = torch.sigmoid(outputs)
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader)
    val_acc = accuracy_score(total_labels, np.round(total_preds))

    print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val ACC: {val_acc:.4f}')

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == early_stopping_patience:
            print('Early stopping!')
            break


Epoch 1/100, Train Loss: 0.5686, Val Loss: 0.5228, Val ACC: 0.8659
Epoch 2/100, Train Loss: 0.4911, Val Loss: 0.4779, Val ACC: 0.8659
Epoch 3/100, Train Loss: 0.4415, Val Loss: 0.4424, Val ACC: 0.8659
Epoch 4/100, Train Loss: 0.4036, Val Loss: 0.4169, Val ACC: 0.8659
Epoch 5/100, Train Loss: 0.3697, Val Loss: 0.3993, Val ACC: 0.8659
Epoch 6/100, Train Loss: 0.3400, Val Loss: 0.3826, Val ACC: 0.8659
Epoch 7/100, Train Loss: 0.3139, Val Loss: 0.3688, Val ACC: 0.8659
Epoch 8/100, Train Loss: 0.2918, Val Loss: 0.3580, Val ACC: 0.8659
Epoch 9/100, Train Loss: 0.2718, Val Loss: 0.3492, Val ACC: 0.8659
Epoch 10/100, Train Loss: 0.2501, Val Loss: 0.3395, Val ACC: 0.8659
Epoch 11/100, Train Loss: 0.2307, Val Loss: 0.3338, Val ACC: 0.8621
Epoch 12/100, Train Loss: 0.2129, Val Loss: 0.3285, Val ACC: 0.8621
Epoch 13/100, Train Loss: 0.1972, Val Loss: 0.3230, Val ACC: 0.8659
Epoch 14/100, Train Loss: 0.1817, Val Loss: 0.3165, Val ACC: 0.8621
Epoch 15/100, Train Loss: 0.1670, Val Loss: 0.3127, Val A

In [23]:
# Save the model state dictionary
torch.save(model.state_dict(), "resent/spider_resnet_cls.pt")

In [24]:
from sklearn.metrics import roc_auc_score, f1_score

model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load("resent/spider_resnet_cls.pt"))
model.eval()

total_preds = []
total_labels = []
with torch.no_grad():
    for batch_data in val_loader:
        images, labels = batch_data['image'].to(device), batch_data['pfirrman_grade'].to(device)
        labels = (labels > 0)
        labels = labels.float()
        outputs = model(images)
        preds = torch.sigmoid(outputs).round()
        total_preds.extend(preds.cpu().numpy())
        total_labels.extend(labels.cpu().numpy())

auc = roc_auc_score(total_labels, total_preds)
f1 = f1_score(total_labels, total_preds)

print("AUC: ", auc)
print("F1: ", f1)

AUC:  0.6108723135271807
F1:  0.9276595744680851
