In [179]:
from prostate158.transforms import get_base_transforms
from prostate158.utils import load_config
from prostate158.model import get_model
from monai.transforms import Compose
from monai.inferers import sliding_window_inference
import torch
from ipywidgets import interact
from matplotlib import pyplot as plt

In [180]:
# Config file driving preprocessing and model construction.
# Swap in 'tumor.yaml' to run the tumor-segmentation setup instead.
CONFIG_PATH = 'tests/config/config.yaml'

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
config = load_config(CONFIG_PATH)
transforms = Compose(get_base_transforms(config=config))

In [181]:
CHECKPOINT_PATH = 'models_copy/anatomy_3_1_sequence.pt'

model = get_model(config)
# map_location is needed so a checkpoint saved from a GPU run can still be
# loaded when `device` fell back to the CPU.
# NOTE(review): the checkpoint is unpickled; only load files from a trusted source.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
# strict=False tolerates key mismatches, but doing so silently can leave parts of
# the network randomly initialised. Surface any mismatch so it is not missed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: checkpoint/model key mismatch for {CHECKPOINT_PATH}")
    print(f"  missing keys:    {load_result.missing_keys}")
    print(f"  unexpected keys: {load_result.unexpected_keys}")
model.to(device=device)
model.eval();

In [182]:
# Load one case and run it through the config's preprocessing transforms.
# The 't2_anatomy_reader1' entry is used as the reference label when plotting.
case_dir = './tests/input/027'
images = transforms({
    't2': f'{case_dir}/t2.nii.gz',
    't2_anatomy_reader1': f'{case_dir}/t2_anatomy_reader1.nii.gz',
})

# Add a leading batch axis, (C, ...) -> (1, C, ...); presumably channel-first.
image = images['t2'].to(device=device).unsqueeze(0)

# Join input sequences along the channel axis (dim=1). With a single sequence this
# is just a copy of `image`. The default dim=0 would instead stack any additional
# sequences as extra *batch* items rather than as input channels.
# NOTE: `input` shadows the Python builtin; the name is kept because the inference
# cell reads it.
input = torch.cat([image], dim=1)

In [183]:
# Sliding-window settings: size of each patch fed to the network and how many
# patches go through one forward pass.
roi_size = (160, 160, 160)
sw_batch_size = 4

# Inference only: no autograd graph is needed.
with torch.no_grad():
    val_outputs = sliding_window_inference(
        input,
        roi_size=roi_size,
        sw_batch_size=sw_batch_size,
        predictor=model,
    )

In [185]:
# Per-voxel predicted class: dim 1 of `val_outputs` holds the per-class scores.
# Computed once here rather than inside `plot`, which previously re-ran argmax
# over the whole volume on every slider move.
predicted_labels = torch.argmax(val_outputs, dim=1).detach().cpu()

# imshow rescales colours to each slice's own min/max by default, so the same class
# could change colour between slices and between the label and output panels.
# Pin the colour range to the class indices instead.
# NOTE(review): assumes reference label values use the same class indices as the
# model's output channels -- confirm.
num_classes = val_outputs.shape[1]
label_style = dict(vmin=0, vmax=num_classes - 1, interpolation="nearest")


def plot(index):
    """Show slice `index` (last axis) of the input, the reference label and the prediction.

    Args:
        index: position along the last spatial axis of the volumes.
    """
    # clear=True resets the reused "check" figure so axes do not stack on
    # backends that keep figures open between calls.
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), num="check", clear=True)
    panels = [
        ("image", image[0, 0, :, :, index].cpu(), dict(cmap="gray")),
        ("label", images["t2_anatomy_reader1"][0, :, :, index], label_style),
        ("output", predicted_labels[0, :, :, index], label_style),
    ]
    for ax, (title, data, style) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(data, **style)
    plt.show()

shape = images["t2_anatomy_reader1"].shape
interact(plot, index=(0, shape[-1] - 1, 1))

interactive(children=(IntSlider(value=69, description='index', max=138), Output()), _dom_classes=('widget-inte…