The aim of this notebook is to visualize model predictions and evaluate them
(sanity check).

# Imports & constants

In [1]:
import sys

# Make the parent directory importable so the `src.*` imports used by this
# notebook resolve. The membership check keeps the cell idempotent:
# re-running it no longer appends a duplicate '..' entry to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import torch
import matplotlib.pyplot as plt

import src.utils as utils

from functools import partial
from pathlib import Path

from monai.networks.nets import SwinUNETR
from monai.data import DataLoader, Dataset, decollate_batch
from monai.transforms import (
    LoadImaged,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    EnsureTyped,
    ScaleIntensityRanged,
    Orientationd,
    Spacingd,
    SpatialPadd
)
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete
from monai.utils.enums import MetricReduction

from src.loaders import get_finetune_data

In [3]:
# Relative path of the checkpoint whose weights are evaluated below.
CHKPT_PATH = '../chkpts/test.pt'

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Prepare data, model & metrics

In [4]:
# Only the second value returned by get_finetune_data is used in this notebook.
_, data = get_finetune_data('../data/finetune/')

# Keys of the dictionary entries that most transforms are applied to.
pair_keys = ('img', 'label')

preprocessing_steps = [
    LoadImaged(keys=pair_keys),
    EnsureChannelFirstd(keys=pair_keys),
    Orientationd(keys=pair_keys, axcodes='RAS'),
    # Interpolation modes are given per key: bilinear for the image,
    # nearest for the label.
    Spacingd(
        keys=pair_keys,
        pixdim=(1.5, 1.5, 2),
        mode=('bilinear', 'nearest'),
    ),
    # Intensities in [-175, 250] are rescaled to [0, 1]; values outside are clipped.
    ScaleIntensityRanged(
        keys=['img'],
        a_min=-175,
        a_max=250,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    CropForegroundd(keys=pair_keys, source_key='img'),
    # Pad to the same 96^3 size that is used as the sliding-window ROI.
    SpatialPadd(keys=pair_keys, spatial_size=(96, 96, 96)),
    EnsureTyped(keys=pair_keys, track_meta=False),
]
transforms = Compose(preprocessing_steps)

ds = Dataset(data=data, transform=transforms)

# One case per batch, in dataset order, loaded in the main process.
loader = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

In [5]:
# Network hyper-parameters; presumably these must match the ones the
# checkpoint was trained with -- TODO confirm against the training config.
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=14,
    feature_size=12,
    num_heads=(3, 3, 3, 3),
)
model = model.to(device)

# The checkpoint is deserialized onto the CPU and then copied into the model.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load(CHKPT_PATH, map_location=torch.device('cpu'))
)
model.eval()

# Sliding-window inference with everything but the input volume pre-bound,
# so it can be called as model_infer(volume).
model_infer = partial(
    sliding_window_inference,
    predictor=model,
    roi_size=[96, 96, 96],
    sw_batch_size=1,
    overlap=0.2,
)

In [6]:
# Post-processing: labels are one-hot encoded into 14 channels; predictions
# are arg-maxed first and then one-hot encoded the same way.
post_pred = AsDiscrete(argmax=True, to_onehot=14)
post_label = AsDiscrete(to_onehot=14)

# Dice metric, background included, mean-reduced; aggregate() also returns
# the not-NaN count because get_not_nans=True.
acc_fn = DiceMetric(
    include_background=True,
    reduction=MetricReduction.MEAN,
    get_not_nans=True,
)

# Visualize & evaluate

In [9]:
# Accumulates the per-case Dice scores; the mean is reported at the end.
avg_agg = utils.AverageAggregator()

for data_dict in loader:
    img, label = data_dict['img'].to(device), data_dict['label'].to(device)

    with torch.no_grad():
        pred = model_infer(img)

    # Split the batch into per-sample tensors and bring labels and predictions
    # into the same one-hot representation for the Dice computation.
    label_list = [post_label(label_tensor) for label_tensor in decollate_batch(label)]
    pred_list = [post_pred(pred_tensor) for pred_tensor in decollate_batch(pred)]

    # Store visualizations
    # The file id does not depend on the slice, so it is computed once per case.
    # NOTE(review): Path.stem strips only the last suffix, so a '.nii.gz' input
    # would keep '.nii' in the output name -- left unchanged on purpose.
    file_id = Path(data_dict["img_meta_dict"]["filename_or_obj"][0]).stem

    # Every 50th slice along the last axis (0, 50, 100, ...).
    for i in range(0, img.shape[-1], 50):
        # Tensors live on `device`; matplotlib converts its input through NumPy,
        # which fails for CUDA tensors, so each slice is moved to the CPU first.
        img_slice = img[0, 0, :, :, i].cpu()
        label_slice = label[0, 0, :, :, i].cpu()
        pred_slice = torch.argmax(pred[0, :, :, :, i], dim=0).cpu()

        fig, axs = plt.subplots(1, 3, figsize=(10, 10))
        axs[0].imshow(img_slice, cmap='gray', vmin=0, vmax=1)
        axs[0].set_title('Original slice')
        axs[1].imshow(label_slice)
        axs[1].set_title('Label')
        axs[2].imshow(pred_slice)
        axs[2].set_title('Prediction')

        fig.savefig(f'{file_id}_{i}.png')
        # Close explicitly so figures do not pile up in memory across cases.
        plt.close(fig)

    # The metric is reset per case so that aggregate() covers this case only.
    acc_fn.reset()
    acc_fn(y_pred=pred_list, y=label_list)
    acc, not_nans = acc_fn.aggregate()
    assert not_nans == 1
    avg_agg.update(acc.item())

print(f'Mean validation dice score: {avg_agg.item():.4f}')

Mean validation dice score: 0.0906
