In [30]:
from segment2d import *
import numpy as np
import csv
from matplotlib import pyplot as plt
from ipywidgets import interact
# visualize the image and mask in z axis using interact, image and mask are in one slice
def plot_image_mask_z(image, mask, z, prediction=None):
    """Show slice ``z`` of ``image`` with ``mask`` overlaid, and optionally ``prediction``.

    Args:
        image: array indexed as ``image[..., z]``, drawn in grayscale.
        mask: label array indexed as ``mask[..., z]``, drawn semi-transparently
            on the left panel.
        z: index along the last axis.
        prediction: optional label array drawn on the right panel over the same
            image slice. When omitted, the right panel is hidden.
    """
    # Fix the colour scale over the whole volume(s). Without vmin/vmax, imshow
    # rescales every slice to its own min/max, so one label would get different
    # colours on different slices and between the two panels.
    label_max = max(1, np.max(mask))
    if prediction is not None:
        label_max = max(label_max, np.max(prediction))
    overlay_style = dict(cmap="jet", alpha=0.3, vmin=0, vmax=label_max)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image[..., z], cmap="gray")
    ax[0].set_title("Image")
    ax[0].imshow(mask[..., z], **overlay_style)
    if prediction is None:
        # Nothing to draw on the right: hide the empty axes instead of showing a blank frame.
        ax[1].axis("off")
    else:
        ax[1].imshow(image[..., z], cmap="gray")
        ax[1].set_title("Image prediction")
        ax[1].imshow(prediction[..., z], **overlay_style)

    plt.show()

In [31]:
def preprocess_data(image_path):
    """Load one volume and turn it into a per-slice batch for the segmenter.

    The volume is read with nibabel, passed through ``min_max_normalize`` and
    ``pad_background``; every slice along the last axis becomes one batch item.

    Args:
        image_path: path of the image file to load.

    Returns:
        dict with keys
        ``"crop_index"`` / ``"padded_index"``: as returned by ``pad_background``,
        ``"original_shape"``: shape of the normalized (unpadded) volume,
        ``"image"``: float tensor stacking one single-channel tensor per slice
        (``(n_slices, 1, 224, 224)`` according to the original notes; TODO confirm).
    """
    volume = min_max_normalize(nib.load(image_path).get_fdata())
    padded_image, crop_index, padded_index = pad_background(volume, dim2pad=cfg.DATA.DIM2PAD)

    # Keep the slice axis as a length-1 dimension and move it to the front so that
    # each slice becomes a (1, H, W) tensor.
    slice_tensors = [
        torch.from_numpy(padded_image[..., k : k + 1].transpose(-1, 0, 1))
        for k in range(padded_image.shape[-1])
    ]

    return {
        "crop_index": crop_index,
        "padded_index": padded_index,
        "original_shape": volume.shape,
        "image": torch.stack(slice_tensors).float(),
    }


def predict_data(data, segmenter, threshold=100, task="train_combine", min_size_remove=1000):
    """Segment one preprocessed volume and map the labels back to the original grid.

    Args:
        data: dict produced by ``preprocess_data``; the keys ``"image"``,
            ``"original_shape"``, ``"crop_index"`` and ``"padded_index"`` are read.
        segmenter: object exposing ``predict_patches``.
        threshold: minimum total number of voxels labelled 3 or 4. When fewer
            are predicted, all of them are relabelled 2.
        task: ``"train_combine"`` folds label 4 into label 3; any other value
            leaves labels 3 and 4 separate.
        min_size_remove: forwarded to ``remove_small_elements``. Defaults to the
            previously hard-coded value of 1000.

    Returns:
        The post-processed label volume after ``invert_padding``.
    """
    # (n, classes, H, W) according to the original notes; TODO confirm
    probability_output = segmenter.predict_patches(data["image"])
    # Most likely class per pixel, with the slice axis moved to the end: (H, W, n).
    seg = np.argmax(probability_output, axis=1).transpose(1, 2, 0)
    seg = remove_small_elements(seg, min_size_remove=min_size_remove)

    # A very small amount of label 3/4 is treated as noise and absorbed into label 2.
    disease = (seg == 3) | (seg == 4)
    if np.count_nonzero(disease) < threshold:
        seg[disease] = 2
    if task == "train_combine":
        seg[seg == 4] = 3
    return invert_padding(data["original_shape"], seg, data["crop_index"], data["padded_index"])

In [32]:
# Evaluation setup. `task` selects the label scheme and the matching checkpoint:
#   "train_combine" -> 4 output classes (label 4 is folded into label 3)
#   "train_full"    -> 5 output classes
task = "train_full"
num_classes = 4 if task == "train_combine" else 5

# Image paths of the test subjects; contour paths are derived by swapping the folder name.
# newline="" is what the csv module expects for files it parses.
with open("./test.csv", mode="r", newline="") as f:
    reader = csv.DictReader(f)
    list_test_subject = [row["path"] for row in reader]
list_mask_test_dataset = [x.replace("Images", "Contours") for x in list_test_subject]
test_dataset = EMIDEC_Test_Loader(list_test_subject)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FCDenseNet(in_channels=cfg.DATA.INDIM_MODEL, n_classes=num_classes)

if task == "train_combine":
    checkpoint = "./weights_train_combine/dice_0.7721.ckpt"
else:
    checkpoint = "./weights_train_full/myo_0.9266.ckpt"
segmenter = Segmenter.load_from_checkpoint(
    checkpoint_path=checkpoint,
    model=model,
    class_weight=cfg.DATA.CLASS_WEIGHT,
    num_classes=num_classes,
    learning_rate=0.001,
    factor_lr=0.5,
    patience_lr=50,
)
segmenter = segmenter.to(device)
# load_from_checkpoint returns a new object, so eval() has to be called on that
# object (and thereby on the wrapped network) before it is used for inference.
segmenter.eval()

# Combined model: MI + PMO merged into a single class

In [55]:
# Subject IDs evaluated in this section, in evaluation order.
# NOTE(review): the list mixes "Case_P..." and "Case_N..." IDs despite the MI_ prefix.
# Confirm that the N cases are meant to be included.
MI_test_pts = """
    Case_P050 Case_P087 Case_P001 Case_P010 Case_P017 Case_P029
    Case_P090 Case_P038 Case_N052 Case_N016 Case_P100 Case_P043
    Case_P051 Case_N030 Case_P007 Case_P088 Case_N025 Case_P076
    Case_N046 Case_N054 Case_N049 Case_N041 Case_N023 Case_P026
    Case_P031 Case_N024 Case_P064 Case_P021 Case_P015 Case_P094
""".split()
# Evaluation setup for the combined task: 4 output classes, label 4 folded into label 3.
task = "train_combine"
num_classes = 4 if task == "train_combine" else 5

# Image paths of the test subjects; contour paths are derived by swapping the folder name.
# newline="" is what the csv module expects for files it parses.
with open("./test.csv", mode="r", newline="") as f:
    reader = csv.DictReader(f)
    list_test_subject = [row["path"] for row in reader]
list_mask_test_dataset = [x.replace("Images", "Contours") for x in list_test_subject]
test_dataset = EMIDEC_Test_Loader(list_test_subject)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FCDenseNet(in_channels=cfg.DATA.INDIM_MODEL, n_classes=num_classes)

if task == "train_combine":
    checkpoint = "./weights_train_combine/dice_0.7721.ckpt"
else:
    checkpoint = "./weights_train_full/myo_0.9266.ckpt"
segmenter = Segmenter.load_from_checkpoint(
    checkpoint_path=checkpoint,
    model=model,
    class_weight=cfg.DATA.CLASS_WEIGHT,
    num_classes=num_classes,
    learning_rate=0.001,
    factor_lr=0.5,
    patience_lr=50,
)
segmenter = segmenter.to(device)
# load_from_checkpoint returns a new object, so eval() has to be called on that
# object (and thereby on the wrapped network) before it is used for inference.
segmenter.eval()

In [57]:
# Per-subject Dice for the combined task, restricted to the subjects in MI_test_pts.
# dice_scores keeps labels 1/2/3 separate; dice_scores_combined folds label 3 into label 2.
dice_scores = {"dice_myocardium": [], "dice_lv": [], "dice_mi": []}
dice_scores_combined = {"dice_myocardium": [], "dice_lv": []}

for i, image_path in enumerate(list_test_subject):
    id_patient = image_path.split("/")[-3]
    if id_patient not in MI_test_pts:
        continue

    test_image = nib.load(image_path).get_fdata()
    mask_image = nib.load(image_path.replace("Images", "Contours")).get_fdata()
    mask_image[mask_image == 4] = 3  # reference labels follow the same 4 -> 3 merge as the prediction
    data = preprocess_data(image_path)
    seg = predict_data(data, segmenter, threshold=76, task=task).astype(np.uint8)

    # Same call order as the keys below: label 2, label 1, label 3.
    dice_myo, dice_lv, dice_mi = (
        dice_volume(mask_image, seg, class_index=label) for label in (2, 1, 3)
    )
    dice_scores["dice_myocardium"].append(dice_myo)
    dice_scores["dice_lv"].append(dice_lv)
    dice_scores["dice_mi"].append(dice_mi)

    # Fold label 3 into label 2 on copies, leaving seg / mask_image untouched.
    seg_combined = seg.copy()
    seg_combined[seg == 3] = 2
    mask_combined = mask_image.copy()
    mask_combined[mask_image == 3] = 2

    dice_myo_combined = dice_volume(mask_combined, seg_combined, class_index=2)
    dice_lv_combined = dice_volume(mask_combined, seg_combined, class_index=1)
    dice_scores_combined["dice_myocardium"].append(dice_myo_combined)
    dice_scores_combined["dice_lv"].append(dice_lv_combined)

    print(f"{id_patient} myo: {dice_myo:0.4f}, lv: {dice_lv:0.4f}, mi: {dice_mi:0.4f}")

Case_N024 myo: 0.8879, lv: 0.9636, mi: 1.0000
Case_N054 myo: 0.8900, lv: 0.9506, mi: 1.0000
Case_N023 myo: 0.8569, lv: 0.9519, mi: 1.0000
Case_P100 myo: 0.8660, lv: 0.9364, mi: 0.4352
Case_P007 myo: 0.8541, lv: 0.9413, mi: 0.4738
Case_P038 myo: 0.8035, lv: 0.9442, mi: 0.7634
Case_P021 myo: 0.8877, lv: 0.9255, mi: 0.2780
Case_P087 myo: 0.8590, lv: 0.9326, mi: 0.6469


In [31]:
# Mean Dice per structure over every subject accumulated in dice_scores.
for structure, per_subject in dice_scores.items():
    print(f"mean dice {structure}: {np.mean(per_subject):0.4f}")

mean dice dice_myocardium: 0.8321
mean dice dice_lv: 0.9368
mean dice dice_mi: 0.7060


In [None]:
# Browse the slices of the subject most recently processed by the evaluation loop
# (test_image, mask_image and seg are whatever that loop assigned last).
last_slice = test_image.shape[-1] - 1
interact(lambda z: plot_image_mask_z(test_image, mask_image, z, seg), z=(0, last_slice))

interactive(children=(IntSlider(value=3, description='z', max=7), Output()), _dom_classes=('widget-interact',)…

<function __main__.<lambda>(z)>

# Full model: MI and PMO as separate classes

In [62]:
# Subject IDs evaluated in this section, in evaluation order.
MI_test_pts = [
    "Case_P050",
    "Case_P087",
    "Case_P001",
    "Case_P010",
    "Case_P017",
    "Case_P029",
    "Case_P090",
    "Case_P038",
    "Case_N052",
    "Case_N016",
    "Case_P100",
    "Case_P043",
    "Case_P051",
    "Case_N030",
    "Case_P007",
    "Case_P088",
    "Case_N025",
    "Case_P076",
    "Case_N046",
    "Case_N054",
    "Case_N049",
    "Case_N041",
    "Case_N023",
    "Case_P026",
    "Case_P031",
    "Case_N024",
    "Case_P064",
    "Case_P021",
    "Case_P015",
    "Case_P094",
]

# Root folder of the dataset; each subject lives in <root>/<id>/Images/<id>.nii.gz.
DATASET_ROOT = "./emidec-dataset-1.0.1"

# Image path for every listed subject. The subject ID stays the third-from-last
# path component, which is what the evaluation cells parse it from.
list_test_subject = [
    f"{DATASET_ROOT}/{patient_id}/Images/{patient_id}.nii.gz" for patient_id in MI_test_pts
]

In [None]:

# Alternative subject list: every path in the "path" column of test.csv.
# NOTE(review): running this cell replaces the list_test_subject built from MI_test_pts
# in the cell above. Confirm which of the two lists the evaluation below should use.
with open("./test.csv", mode="r") as f:
    list_test_subject = [entry["path"] for entry in csv.DictReader(f)]



In [63]:
# Evaluation setup for the full task: 5 output classes, labels 3 and 4 kept separate.
task = "train_full"
num_classes = 4 if task == "train_combine" else 5

# Contour paths are derived from the image paths by swapping the folder name.
list_mask_test_dataset = [x.replace("Images", "Contours") for x in list_test_subject]
test_dataset = EMIDEC_Test_Loader(list_test_subject)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FCDenseNet(in_channels=cfg.DATA.INDIM_MODEL, n_classes=num_classes)

if task == "train_combine":
    checkpoint = "./weights_train_combine/dice_0.7721.ckpt"
else:
    checkpoint = "./weights_train_full/myo_0.9266.ckpt"
segmenter = Segmenter.load_from_checkpoint(
    checkpoint_path=checkpoint,
    model=model,
    class_weight=cfg.DATA.CLASS_WEIGHT,
    num_classes=num_classes,
    learning_rate=0.001,
    factor_lr=0.5,
    patience_lr=50,
)
segmenter = segmenter.to(device)
# load_from_checkpoint returns a new object, so eval() has to be called on that
# object (and thereby on the wrapped network) before it is used for inference.
segmenter.eval()

In [65]:
# Per-subject Dice for the full task, reported under three label groupings:
#   dice_scores          - labels 1, 2, 3 and 4 scored separately
#   dice_scores_combined - labels 3 and 4 folded into label 2
#   dice_scores_disease  - label 4 folded into label 3
dice_scores = {"dice_myocardium": [], "dice_lv": [], "dice_mi": [], "dice_pmo": []}
dice_scores_combined = {"dice_myocardium": [], "dice_lv": []}
dice_scores_disease = {"dice_myocardium": [], "dice_lv": [], "dice_MI": []}

for i, image_path in enumerate(list_test_subject):
    id_patient = image_path.split("/")[-3]
    test_image = nib.load(image_path).get_fdata()
    mask_image = nib.load(list_mask_test_dataset[i]).get_fdata()
    data = preprocess_data(image_path)
    seg = predict_data(data, segmenter, threshold=50, task=task).astype(np.uint8)

    pmo_voxels = np.sum(mask_image == 4)
    if pmo_voxels > 0:
        print("number of PMO: ", pmo_voxels)

    # Separate labels; same call order as before: 2, 1, 3, 4.
    dice_myo, dice_lv, dice_mi, dice_pmo = (
        dice_volume(mask_image, seg, class_index=label) for label in (2, 1, 3, 4)
    )
    dice_scores["dice_myocardium"].append(dice_myo)
    dice_scores["dice_lv"].append(dice_lv)
    dice_scores["dice_mi"].append(dice_mi)
    dice_scores["dice_pmo"].append(dice_pmo)

    # Labels 3 and 4 folded into label 2 (on copies).
    seg_combined = seg.copy()
    seg_combined[np.isin(seg, (3, 4))] = 2
    mask_combined = mask_image.copy()
    mask_combined[np.isin(mask_image, (3, 4))] = 2
    dice_myo_combined = dice_volume(mask_combined, seg_combined, class_index=2)
    dice_lv_combined = dice_volume(mask_combined, seg_combined, class_index=1)
    dice_scores_combined["dice_myocardium"].append(dice_myo_combined)
    dice_scores_combined["dice_lv"].append(dice_lv_combined)

    # Label 4 folded into label 3 (on copies).
    seg_disease = seg.copy()
    seg_disease[seg == 4] = 3
    mask_disease = mask_image.copy()
    mask_disease[mask_image == 4] = 3
    dice_myo_disease = dice_volume(mask_disease, seg_disease, class_index=2)
    dice_lv_disease = dice_volume(mask_disease, seg_disease, class_index=1)
    dice_mi_disease = dice_volume(mask_disease, seg_disease, class_index=3)
    dice_scores_disease["dice_myocardium"].append(dice_myo_disease)
    dice_scores_disease["dice_lv"].append(dice_lv_disease)
    dice_scores_disease["dice_MI"].append(dice_mi_disease)

    print(f"{id_patient} myo: {dice_myo:0.4f}, lv: {dice_lv:0.4f}, mi: {dice_mi:0.4f}, pmo: {dice_pmo:0.4f}")

number of PMO:  442
Case_P050 myo: 0.8979, lv: 0.9774, mi: 0.9055, pmo: 0.8743
Case_P087 myo: 0.8448, lv: 0.9111, mi: 0.5111, pmo: 1.0000
number of PMO:  1505
Case_P001 myo: 0.8817, lv: 0.9770, mi: 0.8998, pmo: 0.9300
number of PMO:  18
Case_P010 myo: 0.9382, lv: 0.9633, mi: 0.8323, pmo: 0.9714
Case_P017 myo: 0.9041, lv: 0.9695, mi: 0.8161, pmo: 1.0000
number of PMO:  189
Case_P029 myo: 0.8982, lv: 0.9766, mi: 0.8730, pmo: 0.8883
Case_P090 myo: 0.8584, lv: 0.9604, mi: 0.8261, pmo: 1.0000
Case_P038 myo: 0.7803, lv: 0.9415, mi: 0.7635, pmo: 0.0000
Case_N052 myo: 0.9184, lv: 0.9589, mi: 1.0000, pmo: 1.0000
Case_N016 myo: 0.9182, lv: 0.9500, mi: 1.0000, pmo: 1.0000
number of PMO:  29
Case_P100 myo: 0.8538, lv: 0.9321, mi: 0.5455, pmo: 0.0000
number of PMO:  252
Case_P043 myo: 0.8710, lv: 0.9742, mi: 0.8791, pmo: 0.8816
Case_P051 myo: 0.9123, lv: 0.9689, mi: 0.8291, pmo: 1.0000
Case_N030 myo: 0.9217, lv: 0.9500, mi: 1.0000, pmo: 1.0000
Case_P007 myo: 0.8790, lv: 0.9522, mi: 0.4664, pmo: 1.0

In [39]:
# Number of voxels labelled 4 in the most recent prediction.
(seg == 4).sum()

0

In [10]:
# Inspect one subject (second-to-last entry of the test list) with label 4 folded into label 3.
# This cell deliberately does NOT reset or append to dice_scores / dice_scores_combined /
# dice_scores_disease: those hold the results of the evaluation loop, and overwriting them
# here would make the mean-Dice cells below summarise this single subject instead.
i = -2
image_path = list_test_subject[i]
id_patient = image_path.split("/")[-3]
test_image = nib.load(image_path).get_fdata()
mask_image = nib.load(list_mask_test_dataset[i]).get_fdata()
data = preprocess_data(image_path)
# Pass the active task so `seg` uses the label scheme of the loaded model rather than the
# "train_combine" default.
# NOTE(review): threshold is kept at the function default (100); the evaluation loop for
# this section uses threshold=50. Confirm which value is intended here.
seg = predict_data(data, segmenter, threshold=100, task=task).astype(np.uint8)

# Fold label 4 into label 3 on copies; these two volumes are also what the viewer cell shows.
seg_disease = seg.copy()
seg_disease[seg_disease == 4] = 3
mask_disease = mask_image.copy()
mask_disease[mask_disease == 4] = 3

dice_myo_disease = dice_volume(mask_disease, seg_disease, class_index=2)
dice_lv_disease = dice_volume(mask_disease, seg_disease, class_index=1)
dice_mi_disease = dice_volume(mask_disease, seg_disease, class_index=3)
print(f"patient {id_patient} dice myo: {dice_myo_disease:0.4f}, dice lv: {dice_lv_disease:0.4f}, dice mi: {dice_mi_disease:0.4f}")


patient Case_P072 dice myo: 0.7927, dice lv: 0.9452, dice mi: 0.7727


In [60]:
# Mean Dice per structure with every label scored separately.
for structure, per_subject in dice_scores.items():
    print(f"mean dice {structure}: {np.mean(per_subject):0.4f}")



mean dice dice_myocardium: 0.8248
mean dice dice_lv: 0.9313
mean dice dice_mi: 0.4985
mean dice dice_pmo: 0.6211


In [61]:
# Mean Dice per structure for the grouping stored in dice_scores_combined.
for structure, per_subject in dice_scores_combined.items():
    print(f"mean dice {structure}: {np.mean(per_subject):0.4f}")

mean dice dice_myocardium: 0.8374
mean dice dice_lv: 0.9313


In [56]:
# Mean Dice per structure for the grouping stored in dice_scores_disease.
for structure, per_subject in dice_scores_disease.items():
    print(f"mean dice {structure}: {np.mean(per_subject):0.4f}")

mean dice dice_myocardium: 0.8248
mean dice dice_lv: 0.9314
mean dice dice_MI: 0.5071


In [None]:



# Browse the slices of the most recently processed subject, showing the reference and
# predicted labels after label 4 has been folded into label 3.
last_slice = test_image.shape[-1] - 1
interact(lambda z: plot_image_mask_z(test_image, mask_disease, z, seg_disease), z=(0, last_slice))

interactive(children=(IntSlider(value=4, description='z', max=8), Output()), _dom_classes=('widget-interact',)…

<function __main__.<lambda>(z)>