In [2]:
import torch
from monai.networks.nets import SwinUNETR
import json
import os
import monai
import shutil

from monai.transforms import (
    Compose,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    EnsureChannelFirstd,
    SpatialPadd,
)

import numpy as np
from tqdm import tqdm

 missing cuda symbols while dynamic loading
 cuFile initialization failed


In [3]:
splits_final = json.load(open(os.path.join("/nnUNet/preprocessed_data/Dataset060_Merged_Def/", "splits_Dataset060_Merged_Def.json")))

In [5]:
patch_size = [160, 64, 128]

val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 1.5),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-982,
            a_max=1094,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        SpatialPadd(keys=["image", "label"], spatial_size=patch_size),
    ]
)

## Calculate predictions of validation set per fold

In [None]:
num_samples_per_image = 2

post_label = monai.transforms.AsDiscrete(to_onehot=13)
post_pred = monai.transforms.AsDiscrete(argmax=True, to_onehot=13)
post_pred_to_save = monai.transforms.AsDiscrete(argmax=True)

data_dir = "/nnUNet/raw_data/Dataset060_Merged_Def/"

for fold in range(0, 5):

    results = {}
    results['metric_per_case'] = []

    # Initialize model with weights of each fold
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SwinUNETR(
    img_size=patch_size,
    in_channels=1,
    out_channels=13,
    feature_size=48,
    use_checkpoint=True,
    ).to(device)

    weight = torch.load(f"/experiments_Swin-UNETR/fold_{fold}/output/swin_unetr_fold_{fold}_best_metric_model.pth")
    model.load_state_dict(weight['state_dict'])


    # Load validation set

    val_images = []
    val_labels = []

    for image_name in splits_final[fold]["val"]:
        val_images.append(os.path.join(data_dir, "imagesTr", image_name + "_0000.nii.gz"))
        val_labels.append(os.path.join(data_dir, "labelsTr", image_name + ".nii.gz"))

    val_images = sorted(val_images)
    val_labels = sorted(val_labels)

    val_files = [
    {"image": image_name, "label": label_name, 'path': image_name}
    for image_name, label_name in zip(val_images, val_labels)
    ]

    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, num_workers=7)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1)

    # Calculate validation Dice score

    model.eval()
    with torch.no_grad():
        for i, test_data in enumerate(val_loader):
            results['metric_per_case'].append({})
            
            test_volume = test_data["image"]
            test_label = test_data["label"]
            test_volume_no_preprocess = monai.transforms.LoadImage()(test_data["path"])
            test_volume, test_label = (
                test_volume.to(device),
                test_label.to(device),
            )

            test_outputs = monai.inferers.sliding_window_inference(test_volume, patch_size, num_samples_per_image, model, overlap=0.5)

            # CODE TO SAVE IMAGES
            prediction = post_pred_to_save(test_outputs[0])

            prediction = monai.transforms.ResampleToMatch()(img=prediction, img_dst=torch.unsqueeze(test_volume_no_preprocess, axis=0), mode='nearest')

            monai.transforms.SaveImage(output_dir=f"/swin_unetr/inference/fold_{fold}/", output_postfix='', separate_folder=False)(prediction)

In [5]:
def nnunet_dice(mask_pred: np.ndarray, mask_ref: np.ndarray):
    dice_scores = []
    for organ_class in range(1, 13):
        gt = (mask_ref == organ_class)
        pred = (mask_pred == organ_class)
        use_mask = np.ones_like(gt, dtype=bool)
        
        tp = np.sum((gt & pred) & use_mask)
        fp = np.sum(((~gt) & pred) & use_mask)
        fn = np.sum((gt & (~pred)) & use_mask)
        
        if tp + fp + fn == 0:
            dice_scores.append(np.nan)
        else:
            dice_scores.append(2 * tp / (2 * tp + fp + fn))
    
    return dice_scores

In [20]:
# CODE TO CALCULATE VALIDATION DICE SCORES AND SAVE THEM IN A JSON FILE

import monai.transforms


load_img = monai.transforms.LoadImage(ensure_channel_first=True)

for fold in tqdm(range(4, 5)):
    results = {}
    results['metric_per_case'] = []

    preds_folder = f"/swin_unetr/inference/fold_{fold}/preds"
    gt_folder = f"/swin_unetr/inference/fold_{fold}/gt"
    for i, pred_name in enumerate(sorted(os.listdir(preds_folder))):
        results['metric_per_case'].append({})
        
        gt = load_img(os.path.join(gt_folder, pred_name))
        pred = load_img(os.path.join(preds_folder, pred_name))
        dice_score = nnunet_dice(pred, gt)

        results['metric_per_case'][i]['metrics'] = {}
        for j in range(1, 13):
            results['metric_per_case'][i]['metrics'][str(j)] = dice_score[j-1]
        results['metric_per_case'][i]['prediction_file'] = os.path.join(preds_folder, pred_name)
        results['metric_per_case'][i]['reference_file'] = os.path.join(gt_folder, pred_name)
        
    with open(f"/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_{fold}.json", "w") as outfile:
        json.dump(results, outfile)
        print('File saved successfully.')

100%|██████████| 1/1 [04:06<00:00, 246.41s/it]

File saved successfully.





In [4]:
# CODE TO READ JSON FILES AND APPEND DICE SCORES PER ORGAN - ITS GOAL IS TO CALCULATE THE MEAN DICE PER ORGAN FOR ALL FOLDS

res = []
means = []

for i in range(0, 5):
    results = {}
    results['metric_per_case'] = []
    
    summary = json.load(open(f"/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_{i}.json"))
    for case in summary['metric_per_case']:
        res = []
        for key in range(1, 13):
            res.append(case['metrics'][str(key)])
        means.append(res)

In [9]:
print(np.mean(means, axis=0).round(3))
print(np.mean(means).round(3))

[0.961 0.952 0.947 0.826 0.833 0.976 0.921 0.949 0.903 0.861 0.788 0.785]
0.892


## Merge fold summaries into a single one (cross val)

In [24]:
f0 = json.load(open("/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_0.json"))
f1 = json.load(open("/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_1.json"))
f2 = json.load(open("/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_2.json"))
f3 = json.load(open("/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_3.json"))
f4 = json.load(open("/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_fold_4.json"))

In [26]:
cross_val = {'metric_per_case': []}

for fold in [f0, f1, f2, f3, f4]:
    for case in fold['metric_per_case']:
        cross_val['metric_per_case'].append(case)

In [29]:
with open(f"/experiments_Swin-UNETR/validation_dice_calculation/validation_results_per_fold/summary_cross_val.json", "w") as outfile:
        json.dump(cross_val, outfile)