In [23]:
from prostate158.transforms import get_base_transforms
from prostate158.utils import load_config
from prostate158.model import get_model
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd
from monai.inferers import sliding_window_inference
import torch
import SimpleITK as sitk

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [25]:
config = load_config('tests/config/config.yaml') # change to 'tumor.yaml' for tumor segmentation
transforms = Compose(get_base_transforms(config=config))

In [26]:
model = get_model(config)
model.load_state_dict(torch.load('anatomy.pt'), strict=False)
model.to(device=device)
model.eval();

In [27]:
images = transforms({'t2': './tests/input/picai/10000_1000000_t2w.mha', 't2_anatomy_reader1': './tests/input/picai/10000_1000000.nii.gz'})
image = images['t2'].to(device=device).unsqueeze(0)
input = torch.cat([image])

In [28]:
with torch.no_grad():
    roi_size = (128, 128, 128)
    sw_batch_size = 4
    val_outputs = sliding_window_inference(input, roi_size, sw_batch_size, model, overlap=0.5)

In [29]:
print(images['t2_anatomy_reader1'].shape)
print(images['t2'].shape)
print(images['t2_anatomy_reader1'][:,:,10].any())

torch.Size([1, 360, 360, 159])
torch.Size([1, 360, 360, 199])
metatensor(False)


In [33]:
from matplotlib import pyplot as plt
from ipywidgets import interact

# swap 1s and 2s in label image
temp_value = 999
images["t2_anatomy_reader1"][images["t2_anatomy_reader1"] == 1] = temp_value
images["t2_anatomy_reader1"][images["t2_anatomy_reader1"] == 2] = 1
images["t2_anatomy_reader1"][images["t2_anatomy_reader1"] == temp_value] = 2

def plot(index):
    plt.figure("check", figsize=(18, 6))
    plt.subplot(1, 3, 1)
    plt.title(f"image")
    plt.imshow(image[0, 0, :, :, index].cpu())
    plt.subplot(1, 3, 2)
    plt.title(f"label")
    plt.imshow(images["t2_anatomy_reader1"][0, :, :, index])
    plt.subplot(1, 3, 3)
    plt.title(f"output")
    plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, index])
    plt.show()

shape = images["t2_anatomy_reader1"].shape
interact(plot, index=(0, shape[-1] - 1, 1))

interactive(children=(IntSlider(value=79, description='index', max=158), Output()), _dom_classes=('widget-inte…

<function __main__.plot(index)>