In [8]:
from prostate158.transforms import get_base_transforms
from prostate158.utils import load_config
from prostate158.model import get_model
from monai.transforms import Compose
from monai.inferers import sliding_window_inference
import torch
from ipywidgets import interact
from matplotlib import pyplot as plt

In [9]:
# Run on the GPU when CUDA is usable; otherwise everything stays on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [10]:
# Experiment configuration for the all-sequences model.
# Change the path to 'tumor.yaml' for tumor segmentation.
config_path = 'tests/config/config_all_sequences.yaml'
config = load_config(config_path)

# Chain the project's base transforms (built from the config) into one callable.
base_transforms = get_base_transforms(config=config)
transforms = Compose(base_transforms)

In [11]:
# Build the network described by `config` and restore the saved anatomy weights.
model_all_sequences = get_model(config)

# map_location remaps the stored tensors onto the active device, so a checkpoint
# written from a CUDA session still loads when the notebook falls back to the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load('models_saved/anatomy_all_sequences.pt', map_location=device)

# strict=False tolerates key mismatches between checkpoint and model; report them
# so a partially initialised network does not go unnoticed.
# Assumes get_model returns a torch.nn.Module, whose load_state_dict result exposes
# `missing_keys` / `unexpected_keys` - TODO confirm.
load_result = model_all_sequences.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Checkpoint/model key mismatch - "
        f"missing: {list(load_result.missing_keys)}, "
        f"unexpected: {list(load_result.unexpected_keys)}"
    )

model_all_sequences.to(device=device)
# Trailing semicolon suppresses the module repr that eval() returns.
model_all_sequences.eval();

In [12]:
# Example case: the three MRI sequences plus the reader-1 anatomy label.
case_dir = './tests/input/027'
images_all_sequences = transforms({
    't2': f'{case_dir}/t2.nii.gz',
    't2_anatomy_reader1': f'{case_dir}/t2_anatomy_reader1.nii.gz',
    'adc': f'{case_dir}/adc.nii.gz',
    'dwi': f'{case_dir}/dwi.nii.gz',
})

# Move each sequence to the inference device and prepend a batch axis.
image_t2w_2 = images_all_sequences['t2'].to(device=device).unsqueeze(0)
image_adc = images_all_sequences['adc'].to(device=device).unsqueeze(0)
image_dwi = images_all_sequences['dwi'].to(device=device).unsqueeze(0)

# Stack the sequences along dim 1 in the order t2, adc, dwi.
# Presumably dim 1 is the channel axis of a (batch, channel, H, W, D) tensor - TODO confirm.
images_all_sequences_input = torch.cat((image_t2w_2, image_adc, image_dwi), dim=1)

# Same tensor under the name the inference cell uses. The previous single-element
# torch.cat only duplicated the whole volume in device memory.
input_all_sequences = images_all_sequences_input

In [13]:
# Sliding-window settings passed to MONAI: the window size and the sliding-window batch size.
roi_size = (160, 160, 160)
sw_batch_size = 4

# Pure inference, so gradient tracking is switched off for the forward passes.
# The result name is kept as-is because the plotting cell reads it.
with torch.no_grad():
    val_outputs_1_sequence = sliding_window_inference(
        input_all_sequences,
        roi_size,
        sw_batch_size,
        model_all_sequences,
    )

In [14]:
def plot(index):
    """Draw one slice of the inputs, the label and the prediction side by side.

    Reads the notebook-level `images_all_sequences_input`, `images_all_sequences`
    and `val_outputs_1_sequence`.

    Args:
        index: Position along the last axis of the volumes to display.
    """
    # Slice first, then take the argmax over the channel axis. This matches
    # argmax(dim=1) on the full volume followed by slicing, but only the
    # displayed slice is reduced and copied to the CPU on each slider event.
    predicted_slice = val_outputs_1_sequence[0, :, :, :, index].argmax(dim=0)

    panels = [
        ("image t2w", images_all_sequences_input[0, 0, :, :, index].cpu()),
        ("image adc", images_all_sequences_input[0, 1, :, :, index].cpu()),
        ("image dwi", images_all_sequences_input[0, 2, :, :, index].cpu()),
        # The label was never moved to the device, so it is plotted as-is.
        ("label", images_all_sequences["t2_anatomy_reader1"][0, :, :, index]),
        ("output", predicted_slice.detach().cpu()),
    ]

    plt.figure("check", (18, 12))
    # Panels fill a 3x3 grid in reading order, so five panels occupy positions 1-5.
    for position, (title, slice_2d) in enumerate(panels, start=1):
        plt.subplot(3, 3, position)
        plt.title(title)
        plt.imshow(slice_2d)
    plt.show()

# The slider covers every valid index along the last axis of the label volume.
shape = images_all_sequences["t2_anatomy_reader1"].shape
# Trailing semicolon stops the cell from echoing the repr of the function
# that interact() returns; the widget itself is still displayed.
interact(plot, index=(0, shape[-1] - 1, 1));

interactive(children=(IntSlider(value=69, description='index', max=138), Output()), _dom_classes=('widget-inte…

<function __main__.plot(index)>