In [None]:
import pandas as pd
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import SimpleITK
import itertools
import sys
import torch
from torchvision import transforms
from PIL import Image
from matplotlib import cm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from pathlib import Path

SOURCE_PATH = Path(os.getcwd()) / "src"

if SOURCE_PATH not in sys.path:
    sys.path.append(SOURCE_PATH)

from src.extraction import get_images_lists_from_path, get_images_lists_from_more_paths

from src.plots import plot_observation

from src.loading import load_images_from_paths

%load_ext autoreload
%autoreload 2

In [None]:
# Image type to run on; also the name of the sub-folder under data_extracted/.
type_to_use = "t2"
data_root = Path(os.getcwd()) / "data_extracted"
seg_path = data_root / "seg"
input_path = data_root / type_to_use
# Input images and their ground-truth segmentations
# (presumably index-aligned — TODO confirm in src.loading).
images, segs = load_images_from_paths(input_path, seg_path)

modelname = "t2_20e_mobilenet_mse.pt"
# The checkpoint is a fully pickled nn.Module (the result is used as a model
# below), not a state_dict. torch >= 2.6 defaults to weights_only=True, which
# refuses such files, so it is disabled explicitly (parameter exists since 1.13).
# NOTE: this unpickles arbitrary objects — only load checkpoints you trust.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load(
    Path(os.getcwd()) / "models" / modelname,
    map_location="cpu",
    weights_only=False,
)
model.eval()

output_save_path = Path(os.getcwd()) / "output"
# Idempotent: safe to re-run the cell, and creates missing parent directories.
output_save_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Only converts the PIL image to a [C, H, W] tensor; no mean/std normalisation
# is applied.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

indexes_predict = [32, 36, 48]
# Percentiles of the per-pixel score used to bin the prediction into discrete
# highlight levels for display.
THRESHOLD_PERCENTILES = (90, 95, 99)


def quantize_by_percentiles(scores, percentiles=THRESHOLD_PERCENTILES):
    """Snap each score down to the highest percentile threshold it exceeds.

    Thresholds are the given percentiles of ``scores`` itself. A value that is
    strictly greater than a threshold (and not greater than the next higher
    one) is replaced by that threshold. Values not above the lowest threshold
    become 0.

    Args:
        scores: array-like of per-pixel scores.
        percentiles: percentiles in [0, 100]; order does not matter.

    Returns:
        np.ndarray with the same shape as ``scores``, containing only 0 and the
        threshold values.
    """
    scores = np.asarray(scores)
    thresholds = np.percentile(scores, sorted(percentiles))
    quantized = np.zeros(scores.shape, dtype=thresholds.dtype)
    for threshold in thresholds:
        # Ascending order: a higher threshold overwrites a lower one.
        quantized = np.where(scores > threshold, threshold, quantized)
    return quantized


# Feed inputs to whichever device the model's weights live on.
device = next(model.parameters()).device

for i in indexes_predict:
    input_image = Image.fromarray(images[i])
    true_segments = Image.fromarray(segs[i])

    # Add a batch dimension: [C, H, W] -> [1, C, H, W].
    input_batch = preprocess(input_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # "out" holds the model's prediction; [0] selects the single batch item.
        output = model(input_batch)["out"][0]
    # Per-pixel maximum over axis 0 (assumed to be the channel axis — TODO
    # confirm); moved to CPU so it can be converted to NumPy.
    output_predictions = torch.amax(output, 0).cpu().numpy()

    output_pred = quantize_by_percentiles(output_predictions)

    fig, ax = plt.subplots(1, 3, figsize=(15, 4))
    panels = (
        ("input image", input_image),
        ("segmented output", output_pred),
        ("ground truth", true_segments),
    )
    for axis, (title, picture) in zip(ax, panels):
        axis.set_title(title)
        axis.axis("off")
        axis.imshow(picture)
    plt.show()
    # Release the figure; many figures in a loop otherwise accumulate in memory.
    plt.close(fig)

    np.save(output_save_path / f"segmented_{i}.npy", output_pred)