In [2]:
import os
from segment_anything.build_sam import sam_model_registry
from scripts.experiments.mask_aug.inference import class_inference, load_model
from scripts.sam_train import SamTrain
from segment_anything.modeling.sam import Sam
VAL_ROOT = "../dataset/FLARE22-version1/ReleaseValGT-20cases"
# One case identifier drives all three per-case paths so the cached volume,
# image and label can never silently refer to different cases.
CASE_ID = "FLARETs_0002"
VOLUME_CACHE = os.path.join(VAL_ROOT, f"images/{CASE_ID}_0000.cache.pt")
IMAGE_PATH = os.path.join(VAL_ROOT, f"images/{CASE_ID}_0000.nii.gz")
MASK_PATH = os.path.join(VAL_ROOT, f"labels/{CASE_ID}.nii.gz")

# SAM ViT-B checkpoint passed as `checkpoint=`; the experiment checkpoint below is
# passed as `custom=` (how the two are combined lives in sam_model_registry — TODO confirm).
SAM_BASE_CHECKPOINT = "../sam_vit_b_01ec64.pth"
# Earlier alternative: "../runs/transfer/imp-230603-150046/model-20.pt"
MODEL_PATH = "../runs/exps-230701-165310/model-20.pt"

model: Sam = sam_model_registry["vit_b"](
    checkpoint=SAM_BASE_CHECKPOINT, custom=MODEL_PATH
)
sam_train = SamTrain(sam_model=model)


In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scripts.datasets.constant import IMAGE_TYPE
from scripts.datasets.preprocess_raw import FLARE22_Preprocess
from scripts.utils import torch_try_load

# %matplotlib inline


preprocessor = FLARE22_Preprocess()

# Intensity preset applied by the preprocessor (presumably a CT window — TODO confirm).
preprocess_config = IMAGE_TYPE.ABDOMEN_SOFT_TISSUES_ABDOMEN_LIVER
volumes, masks = preprocessor.run_with_config(
    image_file=IMAGE_PATH,
    gt_file=MASK_PATH,
    config_name=preprocess_config,
)

# Cached volume for the same case, loaded onto the CPU.
# NOTE(review): not referenced by any visible cell — remove if unused downstream.
cache_volume = torch_try_load(VOLUME_CACHE, 'cpu')


In [4]:
from scripts.experiments.mask_aug.inference import get_all_organ_range


# Per-organ slice range, computed on `masks` as currently loaded (i.e. before the
# flip applied further down). Presumably indexed by label value — TODO confirm.
organ_slice_range = get_all_organ_range(masks)
starts, ends = organ_slice_range

In [None]:
from scripts.datasets.constant import FLARE22_LABEL_ENUM
%matplotlib widget
from ipywidgets import widgets
import matplotlib.pyplot as plt

# Organ label selected for the interactive viewer.
organ_idx = FLARE22_LABEL_ENUM.LIVER.value
# Initial slider fraction.
# NOTE(review): shadowed by `f`'s own `percent` parameter, so this global is unused.
percent = 0.0
# Left panel: full label map; right panel: binary mask of a single label.
fig, axes = plt.subplots(nrows=1, ncols=2)

def f(percent):
    """Redraw the slice viewer at a fractional position inside one organ's slice range.

    Used as the ``widgets.interact`` callback. Reads the notebook globals
    ``organ_idx``, ``starts``, ``ends``, ``masks``, ``fig`` and ``axes`` at call time.

    Args:
        percent: Position in [0, 1]; 0 maps to ``starts[organ_idx]`` and 1 to
            ``ends[organ_idx]``.
    """
    # Use the selected organ for both the slice range and the binary mask instead of
    # a hard-coded label 1. NOTE(review): assumes starts/ends are indexed by label
    # value and that ends[...] is a valid slice index (inclusive end) — TODO confirm.
    start, end = starts[organ_idx], ends[organ_idx]
    idx = int((end - start) * percent) + start
    # NOTE(review): `masks` is looked up at call time; if the flip cell has run after
    # starts/ends were computed, these indices refer to mirrored slices.
    frame = masks[idx]

    # Clear first; otherwise every slider move stacks another image and marker.
    for ax in axes:
        ax.clear()
    axes[0].imshow(frame)
    axes[1].imshow(frame == organ_idx)
    # NOTE(review): marker at (x=274, y=344); the single-slice cell further down plots
    # (344, 274). One of the two has x/y swapped — TODO confirm which.
    axes[1].plot(274, 344, marker="o", markersize=5)
    axes[1].set_title(f'Frame: {idx}')
    fig.canvas.draw()

# `interact` calls `f` once with the initial slider value, so the first frame is
# drawn by `f` itself; a separate initial imshow here only left a stale image
# underneath. The trailing `;` hides the repr of the function `interact` returns.
widgets.interact(f, percent=widgets.FloatSlider(min=0.0, max=1.0, step=0.02, value=0.0));

In [5]:
# A separate EDA showed that flipping the image and mask volumes along the slice axis yields more consistently oriented data, so both are flipped from here on.

volumes_raw, masks_raw = preprocessor.run_with_config(
    image_file=IMAGE_PATH,
    gt_file=MASK_PATH,
    config_name=IMAGE_TYPE.ABDOMEN_SOFT_TISSUES_ABDOMEN_LIVER,
)

# Flip both arrays along their first (slice) axis. Re-loading above keeps this
# cell idempotent: re-running it does not flip the data back.
volumes, masks = volumes_raw[::-1], masks_raw[::-1]

def centroid_of_volume(mask_volume: np.ndarray, class_number: int):
    """Centroid of all voxels equal to ``class_number`` in a label array.

    Args:
        mask_volume: Label array of any rank (a 3-D volume in this notebook).
        class_number: Label value whose voxel indices are averaged.

    Returns:
        ``(centroid, centroid_percent)``, each a float array with one entry per axis
        of ``mask_volume``:

        - ``centroid``: mean voxel index per axis, rounded up with ``np.ceil``;
        - ``centroid_percent``: the un-rounded mean divided by the axis length.

        If the label does not occur, both arrays are all-NaN. They are returned
        explicitly, so no "Mean of empty slice" RuntimeWarning is emitted.
    """
    mask_volume = np.asarray(mask_volume)
    coors = np.argwhere(mask_volume == class_number)
    if coors.size == 0:
        missing = np.full(mask_volume.ndim, np.nan)
        return missing, missing.copy()

    centroid = coors.mean(axis=0)
    # Fraction is taken before rounding so it reflects the exact mean position.
    centroid_percent = centroid / mask_volume.shape
    return np.ceil(centroid), centroid_percent

# centroid_of_volume(masks, 1)

def centroid_by_path(image_file, gt_file, is_reversed=True, num_organs=13):
    """Per-organ centroids for one case, loaded and preprocessed from disk.

    Args:
        image_file: Path to the image volume.
        gt_file: Path to the matching label volume.
        is_reversed: Flip the label volume along its first axis before measuring,
            matching the flip applied to the working volume in this notebook.
        num_organs: Highest organ label to measure (labels ``1..num_organs``).
            The default of 13 reproduces the original ``range(1, 14)``.

    Returns:
        ``[centroids, centroid_percents]``: two lists of length ``num_organs + 1``,
        as returned by ``centroid_of_volume`` for each label. Index 0 of each list
        is a ``[0.0, 0.0, 0.0]`` placeholder so that list index equals organ label.
        Labels absent from the case yield NaN entries.
    """
    # Only the labels are needed for centroids; the image volume is discarded.
    _, masks = preprocessor.run_with_config(
        image_file=image_file,
        gt_file=gt_file,
        config_name=IMAGE_TYPE.ABDOMEN_SOFT_TISSUES_ABDOMEN_LIVER,
    )
    if is_reversed:
        masks = masks[::-1]

    centroids = [[0.0, 0.0, 0.0]]
    centroid_percents = [[0.0, 0.0, 0.0]]
    for target_organ in range(1, num_organs + 1):
        cent, cent_per = centroid_of_volume(masks, target_organ)
        centroids.append(cent)
        centroid_percents.append(cent_per)

    return [centroids, centroid_percents]

# centroid_by_path(IMAGE_PATH, MASK_PATH)


In [6]:
import glob

from tqdm import tqdm


TRAIN_ROOT = "../dataset/FLARE22-version1/FLARE22_LabeledCase50/"
image_list = sorted(glob.glob(os.path.join(TRAIN_ROOT, "images", "*.nii.gz")))
label_list = sorted(glob.glob(os.path.join(TRAIN_ROOT, "labels", "*.nii.gz")))

# Images and labels are paired purely by sorted order, so check the pairing up front.
# Assumes image names are the label name plus a suffix (as in the validation case:
# FLARETs_0002_0000.nii.gz vs FLARETs_0002.nii.gz) — TODO confirm for training files.
assert image_list, f"no images found under {TRAIN_ROOT}"
assert len(image_list) == len(label_list), (len(image_list), len(label_list))
for image_file, gt_file in zip(image_list, label_list):
    label_stem = os.path.basename(gt_file)[: -len(".nii.gz")]
    assert os.path.basename(image_file).startswith(label_stem), (image_file, gt_file)

result = []
for image_file, gt_file in tqdm(zip(image_list, label_list), total=len(image_list)):
    result.append(centroid_by_path(image_file, gt_file))
# Shape: (n_cases, 2 [voxel coords | fractional coords], 14 [placeholder + labels 1..13],
#         3 [mask axes in order: slice, row, column]).
result = np.array(result)


50it [01:17,  1.55s/it]


In [10]:
import pandas


# Fractional centroid (mean index / axis length) of label 1 across all training cases.
organ_idx = 1
pd.DataFrame(
    result[:, 1, organ_idx, :],
    # Mask axis order: slice axis first, then image rows, then image columns.
    columns=["z", "y", "x"],
).describe()

Unnamed: 0,0,1,2
count,50.0,50.0,50.0
mean,0.352355,0.544902,0.672188
std,0.058914,0.040365,0.029776
min,0.21249,0.447049,0.59208
25%,0.321103,0.521037,0.655094
50%,0.352454,0.549945,0.676175
75%,0.38275,0.56672,0.69251
max,0.49027,0.654832,0.724651


In [None]:
# result axes: (case, [voxel coords | fractional coords], organ label, [z, y, x]),
# where z is the slice axis (masks[i] is a single slice) and y/x are image rows/columns.
organ_idx = FLARE22_LABEL_ENUM.LIVER.value
organ_pct = result[:, 1, organ_idx, :]
zs, ys, xs = organ_pct[:, 0], organ_pct[:, 1], organ_pct[:, 2]

# nanmean: cases in which the organ is absent have NaN centroids, which would
# otherwise turn the mean itself into NaN.
mean_coordinate = np.nanmean(organ_pct, axis=0)
z_mean, y_mean, x_mean = mean_coordinate

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.scatter(xs, ys, zs, marker='o', alpha=0.5, label="per case")
ax.scatter(x_mean, y_mean, z_mean, marker='^', s=20, label="mean")
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.0)
ax.set_zlim(0.0, 1.0)
ax.set_xlabel("x (fraction of columns)")
ax.set_ylabel("y (fraction of rows)")
ax.set_zlabel("z (fraction of slices)")
ax.set_title(f"Fractional centroid of organ label {organ_idx}")
ax.legend()
plt.show()

In [None]:
# Inspect one slice of `masks` (flipped by the earlier cell when run in order):
# full label map on the left, label-1 mask plus a probe point on the right.
# NOTE(review): slice 141 is a hard-coded choice, and this marker is at (x=344, y=274)
# whereas the slider viewer plots (274, 344) — one has x/y swapped; TODO confirm.
# Also note this rebinds the global `fig`/`axes` that the slider callback draws into.
fig, axes = plt.subplots(1, 2)
left_ax, right_ax = axes
slice_idx = 141
label_slice = masks[slice_idx]
left_ax.imshow(label_slice)
right_ax.imshow(label_slice == 1.0)
right_ax.plot(344, 274, marker="o", markersize=5)
plt.show()

In [None]:
def confidence_score(logits, threshold):
    """Heuristic confidence of a thresholded prediction map.

    Confidence is high when pixels classed as foreground (``>= threshold``) have
    high scores and pixels classed as background (``< threshold``) have low scores.

    Args:
        logits: Array-like of per-pixel scores. The ``1 - mean`` background term
            assumes scores are probabilities in [0, 1] (e.g. after a sigmoid),
            not raw logits — TODO confirm at call sites.
        threshold: Foreground/background cut-off.

    Returns:
        The mean of the foreground score (mean of foreground pixels) and the
        background score (1 - mean of background pixels). A side with no pixels
        is left out instead of turning the result into NaN. An empty input returns NaN.
    """
    scores = np.asarray(logits, dtype=float)
    foreground = scores[scores >= threshold]
    background = scores[scores < threshold]

    parts = []
    if foreground.size:
        parts.append(foreground.mean())
    if background.size:
        parts.append(1.0 - background.mean())
    return float(np.mean(parts)) if parts else float("nan")