In [1]:
import numpy as np
import cv2, os
import glob
from monai.data import DataLoader, Dataset, decollate_batch
import torch
import torchio as tio
from monai.transforms import (
    AsDiscreted,
    Compose,
    SaveImaged,
    Invertd,
)
from models.networks import P_RNet3D
from monai.inferers import sliding_window_inference
from geodis_toolkits  import get_geodismaps
from models.networks import P_RNet3D
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pnet_best_ckpt_dir = "./path/to/best_pnet_init_train"
# The checkpoint is chosen by lexicographic file-name order (last one wins).
# NOTE(review): this is only the "best"/latest checkpoint if the file names sort
# that way (e.g. zero-padded epochs) -- confirm against the training script.
_pnet_ckpt_candidates = sorted(glob.glob(f"{pnet_best_ckpt_dir}/*.pt"))
if not _pnet_ckpt_candidates:
    # Fail with an actionable message instead of a bare IndexError from [-1].
    raise FileNotFoundError(
        f"No P-Net checkpoint (*.pt) found in '{pnet_best_ckpt_dir}'"
    )
pnet_best_ckpt_path = _pnet_ckpt_candidates[-1]


In [None]:
def pnet_inference(
    image_path,
    save_path,
    pnet,
    transform,
    device
):
    """
    P-Net inference function

    Runs ``pnet`` on every image found under ``image_path``, maps each
    prediction back through the inverse of ``transform``, takes the argmax
    and writes the result into ``save_path``.

    Args:
        image_path:     dataset root directory; images are read from
                        ``<image_path>/imagesTr/*.nii.gz``
        save_path:      output directory handed to SaveImaged (files are written
                        with the "pnet" postfix, no per-case sub-folder)
        pnet:           trained pnet model (callable, e.g. torch.nn.Module)
        transform:      preprocessing transforms applied by the Dataset and undone
                        by Invertd (presumably a MONAI Compose -- confirm against
                        get_transform)
        device:         torch device (torch.device) the inputs are moved to

    Raises:
        FileNotFoundError: if no ``*.nii.gz`` image is found under
                           ``<image_path>/imagesTr``.
    """
    test_images = sorted(glob.glob(os.path.join(image_path, "imagesTr", "*.nii.gz")))
    if not test_images:
        # Without this check a wrong path silently produced no output at all.
        raise FileNotFoundError(
            f"No '*.nii.gz' images found in '{os.path.join(image_path, 'imagesTr')}'"
        )

    test_data = [{"image": image} for image in test_images]
    test_org_ds = Dataset(data=test_data, transform=transform)
    test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=1)

    post_transforms = Compose(
        [
            # Undo the preprocessing on the prediction, using the image's metadata.
            Invertd(
                keys="pred",
                transform=transform,
                orig_keys="image",
                meta_keys="pred_meta_dict",
                orig_meta_keys="image_meta_dict",
                meta_key_postfix="meta_dict",
                nearest_interp=False,
                to_tensor=True,
            ),
            # Collapse the channel scores into a discrete label map.
            AsDiscreted(keys="pred", argmax=True),
            SaveImaged(
                keys="pred",
                meta_keys="pred_meta_dict",
                output_dir=save_path,
                squeeze_end_dims=True,
                output_postfix="pnet",
                resample=False,
                separate_folder=False,
            ),
        ]
    )

    with torch.no_grad():
        for batch in test_org_loader:
            test_inputs = batch["image"].to(device)
            batch["pred"] = pnet(test_inputs)
            # The post transforms are run for their side effect (saving to disk).
            for item in decollate_batch(batch):
                post_transforms(item)

In [None]:
# P-Net preprocessing factory. The R-Net cell further down imports a different
# `get_transform` (from data_loaders.transforms_r) under the same name.
from data_loaders.transforms import get_transform
# NOTE(review): first argument is presumably the number of input channels
# (1 = image only) -- confirm against models.networks.P_RNet3D.
pnet = P_RNet3D(1,32,2,True).to(device)
# map_location lets a checkpoint saved on GPU be loaded when `device` is the CPU;
# without it torch.load tries to restore tensors onto the device they were saved from.
pnet.load_state_dict(torch.load(pnet_best_ckpt_path, map_location=device))
pnet.eval()
test_images = "path/to/dataset/"
test_transform = get_transform("post")
# NOTE(review): rnet_inference reads P-Net masks from "<dataset>/pnet_masks";
# predictions written here go to `save_path_pnet` -- confirm how they get there.
save_path_pnet = "./save_path"
pnet_inference(image_path=test_images,
            save_path=save_path_pnet,
            pnet=pnet,
            transform=test_transform,
            device=device)

In [None]:
# Z-normalisation whose statistics are computed only over values > 0.
norm_transform = tio.ZNormalization(masking_method=lambda x: x > 0)
# Same device selection as in the first cell; kept so this cell sets it explicitly.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
rnet_best_ckpt_dir = "./path/to/best_rnet_init_train"
# The checkpoint is chosen by lexicographic file-name order (last one wins).
# NOTE(review): confirm that the file names sort so that the last is the best one.
_rnet_ckpt_candidates = sorted(glob.glob(f"{rnet_best_ckpt_dir}/*.pt"))
if not _rnet_ckpt_candidates:
    # Fail with an actionable message instead of a bare IndexError from [-1].
    raise FileNotFoundError(
        f"No R-Net checkpoint (*.pt) found in '{rnet_best_ckpt_dir}'"
    )
rnet_best_ckpt_path = _rnet_ckpt_candidates[-1]

In [None]:
def rnet_inference(
    image_path,
    save_path,
    rnet,
    transform,
    norm_transform,
    device,
    roi_size=(96, 96, 96),
    sw_batch_size=2,
    overlap=0.6,
):
    """
    R-Net inference function

    For every case under ``image_path`` this builds the R-Net input (image,
    P-Net mask and the two normalised geodesic distance maps, concatenated on
    the channel axis), runs sliding-window inference with ``rnet``, inverts
    ``transform`` on the prediction, takes the argmax and saves the result
    into ``save_path``.

    Args:
        image_path:     dataset root directory; images are read from
                        ``imagesTr/*.nii.gz``, labels from ``labelsTr/*.nii.gz``
                        and P-Net predictions from ``pnet_masks/*.nii.gz``.
                        The three sorted listings are paired by position.
        save_path:      output directory handed to SaveImaged (files are written
                        with the "rnet" postfix, no per-case sub-folder)
        rnet:           trained rnet model (callable, e.g. torch.nn.Module)
        transform:      preprocessing transforms applied by the Dataset and undone
                        by Invertd (presumably a MONAI Compose -- confirm against
                        get_transform)
        norm_transform: normalization applied to each geodesic distance map
        device:         torch device (torch.device) used for the network inputs
        roi_size:       sliding-window patch size passed to sliding_window_inference
        sw_batch_size:  number of windows per forward pass
        overlap:        overlap ratio between neighbouring windows

    Raises:
        FileNotFoundError: if no ``*.nii.gz`` image is found under
                           ``<image_path>/imagesTr``.
        ValueError:        if the number of images, labels and P-Net masks differ
                           (positional pairing would silently mismatch cases).
    """
    test_images = sorted(glob.glob(os.path.join(image_path, "imagesTr", "*.nii.gz")))
    test_labels = sorted(glob.glob(os.path.join(image_path, "labelsTr", "*.nii.gz")))
    pnet_masks = sorted(glob.glob(os.path.join(image_path, "pnet_masks", "*.nii.gz")))

    if not test_images:
        raise FileNotFoundError(
            f"No '*.nii.gz' images found in '{os.path.join(image_path, 'imagesTr')}'"
        )
    if not (len(test_images) == len(pnet_masks) == len(test_labels)):
        # zip() would silently drop the surplus and pair the rest by position,
        # which can attach the wrong mask/label to an image.
        raise ValueError(
            "Mismatched case counts under "
            f"'{image_path}': {len(test_images)} images, "
            f"{len(pnet_masks)} pnet_masks, {len(test_labels)} labels"
        )

    test_data = [
        {"image": image, "P_mask": mask_name, "label": label_name}
        for image, mask_name, label_name in zip(test_images, pnet_masks, test_labels)
    ]
    test_org_ds = Dataset(data=test_data, transform=transform)
    test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=1)

    post_transforms = Compose(
        [
            # Undo the preprocessing on the prediction, using the image's metadata.
            Invertd(
                keys="pred",
                transform=transform,
                orig_keys="image",
                meta_keys="pred_meta_dict",
                orig_meta_keys="image_meta_dict",
                meta_key_postfix="meta_dict",
                nearest_interp=False,
                to_tensor=True,
            ),
            # Collapse the channel scores into a discrete label map.
            AsDiscreted(keys="pred", argmax=True),
            SaveImaged(
                keys="pred",
                meta_keys="pred_meta_dict",
                output_dir=save_path,
                squeeze_end_dims=True,
                output_postfix="rnet",
                resample=False,
                separate_folder=False,
            ),
        ]
    )

    with torch.no_grad():
        for batch in test_org_loader:
            loaded_image = batch["image"]
            test_inputs = loaded_image.to(device)
            true_labels = batch["label"].type(torch.long)
            pnet_pred_labels = batch["P_mask"].type(torch.long)

            # get_geodismaps works on NumPy arrays, so feed it host copies directly
            # rather than moving to `device` and back.
            fore_dist_map, back_dist_map = get_geodismaps(
                loaded_image.cpu().numpy(),
                true_labels.squeeze(dim=1).cpu().numpy(),
                pnet_pred_labels.squeeze(dim=1).cpu().numpy(),
            )

            # NOTE(review): assumes the maps come back with a singleton axis 1 that
            # is dropped for normalisation and restored afterwards -- confirm.
            fore_dist_map = norm_transform(torch.Tensor(fore_dist_map).squeeze(dim=1))
            back_dist_map = norm_transform(torch.Tensor(back_dist_map).squeeze(dim=1))
            fore_dist_map = fore_dist_map.unsqueeze(dim=1)
            back_dist_map = back_dist_map.unsqueeze(dim=1)

            # Four channels: image, P-Net mask, foreground map, background map.
            # The integer mask is cast explicitly instead of relying on torch.cat's
            # implicit type promotion.
            rnet_inputs = torch.cat([
                test_inputs,
                pnet_pred_labels.to(device=device, dtype=test_inputs.dtype),
                fore_dist_map.to(device),
                back_dist_map.to(device)
            ], dim=1)

            batch["pred"] = sliding_window_inference(
                rnet_inputs, roi_size, sw_batch_size, rnet, overlap=overlap
            )
            # The post transforms are run for their side effect (saving to disk).
            for item in decollate_batch(batch):
                post_transforms(item)

In [None]:
# R-Net preprocessing factory; rebinds the `get_transform` name that the P-Net
# cell imported from data_loaders.transforms.
from data_loaders.transforms_r import get_transform
# NOTE(review): first argument is presumably the number of input channels;
# rnet_inference concatenates 4 tensors on the channel axis -- confirm against
# models.networks.P_RNet3D.
rnet = P_RNet3D(4,32,2,True).to(device)
# map_location lets a checkpoint saved on GPU be loaded when `device` is the CPU;
# without it torch.load tries to restore tensors onto the device they were saved from.
rnet.load_state_dict(torch.load(rnet_best_ckpt_path, map_location=device))
rnet.eval()
test_images = "path/to/dataset/"
test_transform = get_transform("post")
save_path_rnet = "path/to/save_path_rnet"
rnet_inference(image_path=test_images,
            save_path=save_path_rnet,
            rnet=rnet,
            transform=test_transform,
            norm_transform= norm_transform,
            device=device)