In [1]:
import argparse
import sys
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk
from tqdm import tqdm
from pathlib import Path

# Project root = parent of the current working directory (normally the folder
# holding this notebook). It is prepended to sys.path so that the local `utils`
# package imported just below wins over any installed package of that name;
# these lines therefore have to run before the `from utils ...` imports.
thispath = Path('.').resolve()
base_path = thispath.parent
sys.path[:0] = [str(base_path)]
from utils import plots
from utils.metrics import dice_score, rel_abs_vol_dif, avd, haussdorf
from utils.plots import plot_dice, plot_hausorf, plot_ravd, plot_avd

In [None]:
# Dataset splits and atlases live under <project>/data; the elastix parameter
# maps used for registration (the Par0010 set) live under <project>/elastix.
data_path = base_path.joinpath('data')
train_set_path, val_set_path, test_set_path = (
    data_path.joinpath(split_dir)
    for split_dir in ('train_set', 'validation_set', 'test_set')
)
params_path = base_path.joinpath('elastix', 'parameter_maps', 'Par0010')
our_atlas_path = data_path.joinpath('ibsr_atlas')
mni_atlas_path = data_path.joinpath('mni_atlas')

In [None]:
# Qualitative comparison on the validation cases: one column per case, one row
# per image (T1 intensity, ground truth, and the three automatic segmentations),
# all showing the same slice index so the rows are directly comparable.
dl_path = base_path / 'experiments/dl/segmentations'
em_path = base_path / 'experiments/val_results/EM--tissue_models_init--atlas_after_misa/segmentations'
ma_path = base_path / 'experiments/simple_segmenters/multi_atlas_mi/segmentations'
t1_path = base_path / 'data/validation_set'

cases = ['IBSR_11', 'IBSR_12', 'IBSR_13', 'IBSR_14', 'IBSR_17']
slice_n = 20
# Axis of the numpy array (SimpleITK arrays are indexed z, y, x) that slice_n
# indexes. NOTE(review): which anatomical plane this gives depends on the
# image orientation of the IBSR volumes — confirm.
slice_axis = 0

row_labels = ['T1', 'Ground Truth', 'EM-based', 'Multi-Atlas', 'Deep Learning']
# squeeze=False keeps `ax` 2-D even if `cases` is shortened to one entry.
fg, ax = plt.subplots(len(row_labels), len(cases), figsize=(20, 20), squeeze=False)


def _read_volume(path):
    """Read an image with SimpleITK and return its voxel array."""
    # str() because older SimpleITK releases only accept plain string paths.
    return sitk.GetArrayFromImage(sitk.ReadImage(str(path)))


def _show_slice(panel, image, **imshow_kwargs):
    """Draw a 2-D `image` on `panel` with ticks and frame hidden."""
    panel.imshow(image, **imshow_kwargs)
    # Hide ticks/spines rather than calling axis('off'): axis('off') would also
    # hide the row label set with set_ylabel on the first column.
    panel.set_xticks([])
    panel.set_yticks([])
    for spine in panel.spines.values():
        spine.set_visible(False)


for col, case in enumerate(cases):
    t1 = _read_volume(t1_path / case / f'{case}.nii.gz')
    gt = _read_volume(t1_path / case / f'{case}_seg.nii.gz')
    em = _read_volume(em_path / f'{case}.nii.gz')
    ma = _read_volume(ma_path / f'{case}.nii.gz')
    dl = _read_volume(dl_path / f'{case}.nii.gz')

    # imshow needs 2-D data, so take the same slice from every volume.
    slices = [vol.take(slice_n, axis=slice_axis) for vol in (t1, gt, em, ma, dl)]

    # Share one colour range across the label maps of this case so the same
    # label value gets the same colour in every row, even when a method misses
    # a label on this slice.
    label_slices = slices[1:]
    label_vmin = min(s.min() for s in label_slices)
    label_vmax = max(s.max() for s in label_slices)

    for row, (image, label) in enumerate(zip(slices, row_labels)):
        panel = ax[row, col]
        if row == 0:
            _show_slice(panel, image, cmap='gray')
            panel.set_title(case)
        else:
            # 'nearest' avoids blending discrete label values at boundaries.
            _show_slice(panel, image, cmap='viridis', vmin=label_vmin,
                        vmax=label_vmax, interpolation='nearest')
        if col == 0:
            panel.set_ylabel(label)

fg.tight_layout()
plt.show()
        
