In [1]:
import monai
import numpy as np
import os
from tqdm.notebook import tqdm
import torch
from monai.networks.nets import SwinUNETR

from monai.transforms import (
    AsDiscrete,
    Compose,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    EnsureChannelFirstd,
    SpatialPadd,
)

 missing cuda symbols while dynamic loading
 cuFile initialization failed


In [None]:
# Sliding-window settings. `num_samples_per_image` is what gets passed to
# sliding_window_inference as its sw_batch_size (windows per forward pass),
# and `patch_size` is both the window size and the model's img_size.
num_samples_per_image = 2
patch_size = [160, 64, 128]

# Post-processing for a 13-channel output head.
# NOTE(review): MSD spleen labels are binary (background + spleen); 13 classes
# presumably reflects a multi-organ training label set — confirm against the
# checkpoints.
post_label = AsDiscrete(to_onehot=13)
post_pred = AsDiscrete(argmax=True, to_onehot=13)
post_pred_to_save = AsDiscrete(argmax=True)

# Preprocessing applied to every volume before inference.
val_transforms = Compose([
    # Read both volumes from disk.
    LoadImaged(keys=["image", "label"]),
    # Files carry no channel axis; prepend one.
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    # Resample to 1.5 isotropic spacing (presumably mm): interpolate intensities,
    # nearest-neighbour for labels so class ids are preserved.
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 1.5), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    # Clip the image to [-982, 1094] (presumably HU) and map linearly onto [0, 1].
    ScaleIntensityRanged(keys=["image"], a_min=-982, a_max=1094, b_min=0.0, b_max=1.0, clip=True),
    # Pad any spatial axis smaller than the patch so a full window always fits.
    SpatialPadd(keys=["image", "label"], spatial_size=patch_size),
])

# Folder containing the unlabeled test volumes to segment.
source_pred_folder = "/MSD_SPLEEN/test_set"

# Number GPUs by PCI bus ID instead of CUDA's default fastest-first order.
# NOTE(review): CUDA reads this only when it initialises, so it must be set
# before the first CUDA call in the kernel — moving it above the imports would
# be safer.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Architecture must match the training run so the fold checkpoints load.
# `use_checkpoint` turns on gradient checkpointing, which only affects
# backpropagation; presumably kept here to mirror the training configuration.
model = SwinUNETR(
    img_size=patch_size,
    in_channels=1,
    out_channels=13,
    feature_size=48,
    use_checkpoint=True,
)
model = model.to(device)

# Ensemble inference. For each test volume:
# 1. Run every fold's checkpoint over the volume.
# 2. Average the fold outputs (MeanEnsembled over the raw model outputs).
# 3. Take the argmax.
# 4. Resample the label map back onto the original, un-preprocessed image grid
#    and save it.
num_folds = 5
checkpoint_template = "/experiments_Swin-UNETR/fold_{fold}/output/swin_unetr_fold_{fold}_best_metric_model.pth"
output_dir = "/swin_unetr/inference/msd_spleen_test_set/preds"

# Read each fold's weights once rather than once per image.
# map_location="cpu" makes this work on CPU-only machines and keeps GPU memory
# free; load_state_dict copies the tensors onto the model's device.
# A missing checkpoint now fails fast here instead of once per image.
fold_state_dicts = [
    torch.load(checkpoint_template.format(fold=fold), map_location="cpu")["state_dict"]
    for fold in range(num_folds)
]
fold_pred_keys = [f"fold_{fold}_pred" for fold in range(num_folds)]

# Transform objects are stateless across images, so build them once.
ensemble = monai.transforms.MeanEnsembled(keys=fold_pred_keys, output_key="ensembled_pred")
load_original = monai.transforms.LoadImage()
resample_to_original = monai.transforms.ResampleToMatch()
save_prediction = monai.transforms.SaveImage(output_dir=output_dir, output_postfix="", separate_folder=False)

model.eval()

# Skip hidden entries (e.g. macOS "._*" resource-fork files that ship in some
# dataset archives) and sub-directories; they are not loadable volumes.
test_images = [
    name
    for name in sorted(os.listdir(source_pred_folder))
    if not name.startswith(".") and os.path.isfile(os.path.join(source_pred_folder, name))
]

failed_images = []
for pred_image in tqdm(test_images):
    image_path = os.path.join(source_pred_folder, pred_image)
    try:
        # The image path is also fed as "label" because val_transforms expects
        # both keys; the label output is unused for this unlabeled test data.
        test_data = val_transforms({"image": image_path, "path": image_path, "label": image_path})
        # Add a batch axis and move to the device once, shared by all folds.
        batch = test_data["image"].unsqueeze(0).to(device)
        with torch.no_grad():
            for fold, state_dict in enumerate(tqdm(fold_state_dicts, desc="folds", leave=False)):
                model.load_state_dict(state_dict)
                test_data[fold_pred_keys[fold]] = monai.inferers.sliding_window_inference(
                    batch,
                    roi_size=patch_size,
                    sw_batch_size=num_samples_per_image,
                    predictor=model,
                    overlap=0.5,
                ).cpu()
        test_data = ensemble(test_data)
        # Drop the batch axis, then argmax over the class channel.
        prediction = post_pred_to_save(test_data["ensembled_pred"][0])

        # Resample back to the raw volume's grid (nearest keeps class ids intact).
        test_volume_no_preprocess = load_original(image_path)
        prediction = resample_to_original(
            img=prediction, img_dst=test_volume_no_preprocess.unsqueeze(0), mode="nearest"
        )
        save_prediction(prediction)
    except Exception as e:
        # Best-effort over the test set: report which file failed and move on.
        print(f"{pred_image}: {e}")
        failed_images.append(pred_image)

if failed_images:
    print(f"Failed on {len(failed_images)} image(s): {failed_images}")