In [1]:
import numpy as np
import torch
import cv2
from matplotlib import pyplot as plt

from datasets.segmentation_dataset import SegmentationDataset
from models.unet3d import Unet3D
from train_segentation import validation
from utils.preprocessing import clahe3d, minmax_normalize, rescale, pad16
from utils.visualization import visualize
from utils.data_loaders import load_dcm_as_rsa_voxel_image
from utils.constants import root_dir

In [12]:
# Identifier of the scan to load (a DICOM-style UID string).
SCAN_UID = "1.2.826.0.1.3680043.21040"

# Load the scan as a voxel volume together with the spacings the loader reports.
image, spacings = load_dcm_as_rsa_voxel_image(SCAN_UID)

# Preprocessing pipeline, applied in order. `image` is rebound after every step so
# the previous intermediate volume can be freed instead of lingering in the kernel.
#   - min-max intensity normalization
#   - trilinear resampling using the loader's spacings (target value 1 —
#     presumably 1 unit per voxel; TODO confirm against utils.preprocessing.rescale)
#   - 3D CLAHE; the meaning of the positional arguments (32, 2, 512) is defined
#     by utils.preprocessing.clahe3d — confirm there before tuning
for preprocess_step in (
    minmax_normalize,
    lambda volume: rescale(volume, spacings, 1, "trilinear"),
    lambda volume: clahe3d(volume, 32, 2, 512),
):
    image = preprocess_step(image)

In [13]:
# Inspect the preprocessed volume (after normalization, resampling and CLAHE)
# with the project's `visualize` helper before running segmentation.
visualize(image)

In [14]:
# Build the 3D U-Net in eval mode and load the best segmentation checkpoint onto the CPU.
# The file is passed straight to `load_state_dict`, so it only needs to contain tensors;
# `weights_only=True` (torch >= 1.13) restricts unpickling to tensors and plain
# containers. This avoids executing arbitrary pickled code from the file and silences
# the FutureWarning recent torch versions emit when the flag is omitted.
model = Unet3D().eval()
state_dict = torch.load(
    "weights/segmentation_best_weights.pth",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)

def segment(model, image):
    """Run a 3D segmentation model on one volume and return a per-voxel label map.

    Args:
        model: a torch module (the caller is responsible for putting it in eval mode)
            that maps a (1, 1, *padded_shape) float tensor to per-class scores with
            the class axis at dim 1. Its weights may live on any device.
        image: a 3D array-like volume; it is copied into a float32 tensor.

    Returns:
        A numpy array with the same shape as ``image`` holding the argmax class
        index for every voxel.
    """
    w, h, d = image.shape
    # torch.tensor always copies, so the caller's array is never aliased or mutated.
    volume = torch.tensor(image, dtype=torch.float32)
    # pad16 presumably pads each spatial axis up to a multiple of 16 so the U-Net's
    # down/up-sampling shapes line up — TODO confirm in utils.preprocessing.
    padded_volume = pad16(volume)[None, None, ...]  # add batch and channel axes

    # Feed the input on whatever device the model's weights are on (CPU if none can
    # be found), so this helper keeps working after e.g. model.cuda().
    parameters = model.parameters() if hasattr(model, "parameters") else ()
    device = next((p.device for p in parameters), torch.device("cpu"))
    with torch.no_grad():
        logits = model(padded_volume.to(device))

    # Crop the padding away before the argmax (same result, less work).
    # NOTE(review): cropping with [:w, :h, :d] assumes pad16 pads only at the high end
    # of each axis; if it pads symmetrically this crop is misaligned.
    labels = logits[0, :, :w, :h, :d].argmax(dim=0)
    # .cpu() is a no-op for CPU tensors and required before .numpy() for GPU tensors.
    return labels.cpu().numpy()

# Segment the preprocessed volume with the loaded model; keyword arguments make the
# role of each argument explicit at the call site.
mask = segment(model=model, image=image)

In [15]:
# Visualize the preprocessed volume together with the predicted label map `mask`
# (how the two are combined — overlay or side by side — is defined by
# utils.visualization.visualize).
visualize(image, mask)