In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author      : Han Liu
# Date Created: 04/27/2021


import os
import os.path as osp
from glob import glob
from time import time
import numpy as np
import ants
import skimage.morphology
import matplotlib.pyplot as plt
from utility import get_cc3d


GT_DIR = '/data/SkullSyn/bone_data/label'
RESULT_DIR = '/data/SkullSyn/results/JMI/Pred'
SAVE_DIR = '/data/SkullSyn/results/JMI/Save'

In [5]:
def extract_skull_mask(img: np.array, thresh:int=400, radius=4) -> np.array:
    msk = img.copy()
    msk[msk <= thresh] = 0
    msk[msk != 0] = 1
    msk = get_cc3d(msk, top=1)
#     struct = skimage.morphology.ball(radius=radius)
#     msk = skimage.morphology.binary_dilation(msk, struct)
    return msk


def evaluate(g_img: np.array, p_img: np.array):
    msk = extract_skull_mask(g_img).astype('uint8')
    dif = np.abs(g_img - p_img) * msk  # compute MAE in the skull mask region
    mae = np.sum(dif) / np.sum(msk)
    return msk, dif, mae


def eval_experiment(experiment_name: str, show_hist=False, save_files=False):
    g_paths = sorted(glob(GT_DIR + '/*.*'))
    assert osp.exists(RESULT_DIR + f'/{experiment_name}'), f'{experiment_name} does not exist...'
    p_paths = sorted(glob(RESULT_DIR + f'/{experiment_name}' + '/*.*'))
    # create a folder to save results; otherwise overwrite
    if not osp.exists(osp.join(SAVE_DIR, experiment_name)):
        os.mkdir(osp.join(SAVE_DIR, experiment_name))
    
    f = open(osp.join(SAVE_DIR, experiment_name, 'log.text'), 'w')
    f.write(f'Experiment name: {experiment_name}\n')
    print('start making inference...')
    mae = []
    g_imgs, p_imgs = None, None
    start = time()
    for g_path, p_path in zip(g_paths, p_paths):
        assert osp.basename(g_path)[:4] == osp.basename(p_path)[:4]
        image_id = osp.basename(g_path)[:4]
        print(f'processing image ID {image_id}...')
        g = ants.image_read(g_path)
        p = ants.image_read(p_path)
        info = [g.origin, g.spacing, g.direction]
        
        # post-processing results: remove noise in background
        p_img = extract_skull_mask(p.numpy(), thresh=350) * p.numpy()
        msk, dif, err = evaluate(g.numpy(), p_img)
        mae.append(err)
        
        if show_hist:
            g_imgs += g.numpy().flatten().tolist()
            p_imgs += p_img.flatten().tolist()
        
        if save_files:
            ants_fake = ants.from_numpy(p_img, origin=info[0], spacing=info[1], direction=info[2])
            ants.image_write(ants_fake, osp.join(SAVE_DIR, experiment_name, f'{image_id}_fake_CT.nii.gz'))
            ants_msk = ants.from_numpy(msk, origin=info[0], spacing=info[1], direction=info[2])
            ants.image_write(ants_msk, osp.join(SAVE_DIR, experiment_name, f'{image_id}_skull_mask.nii.gz'))
            ants_dif = ants.from_numpy(dif, origin=info[0], spacing=info[1], direction=info[2])
            ants.image_write(ants_dif, osp.join(SAVE_DIR, experiment_name, f'{image_id}_dif.nii.gz'))
        f.write(f'{image_id} MAE (skull): {err:.2f}\n')
    
    elapsed = time() - start
    f.write('============\n')
    f.write(f'MAE of {len(g_paths)} testing images: {np.mean(mae):.2f}\n')
    f.write(f'min MAE of {len(g_paths)} testing images: {np.min(mae):.2f}\n')
    f.write(f'max MAE of {len(g_paths)} testing images: {np.max(mae):.2f}\n')
    f.write(f'STD of {len(g_paths)} testing images: {np.std(mae):.2f}\n')
    f.write(f'time elapsed: {int(elapsed)} seconds\n')
    
    # plot histogram of groundtruth and prediction
    if show_hist:
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.set_title(f'Groundtruth CT vs Synthetic CT')
        ax.hist(x=[np.array(g_imgs), np.array(p_imgs)],
                bins=20, range=(100, 2000), alpha=0.9, stacked=True, label=['GT', 'SYN'])
        plt.legend(loc='upper right')
        plt.show()
        
    print(f'MAE of {len(g_paths)} testing images: {np.mean(mae):.2f}')
    print(f'min MAE of {len(g_paths)} testing images: {np.min(mae):.2f}')
    print(f'max MAE of {len(g_paths)} testing images: {np.max(mae):.2f}')
    print(f'STD of {len(g_paths)} testing images: {np.std(mae):.2f}')

### post-processing and evaluation

In [6]:
print(os.listdir(RESULT_DIR))
for exp_name in os.listdir(RESULT_DIR):
    print(f'current experiment name: {exp_name}')
    eval_experiment(exp_name, show_hist=False, save_files=True)


['cGAN_L1_1e2_Edge_MONAI', 'cGAN_L1_100', 'cGAN_L1_50_Edge', 'cGAN_L1_200', 'cGAN_L1_100_Edge', 'ResNet_Edge', 'cGAN_L1_50', 'cGAN_L1_200_Edge', 'cGAN_L1_500_Edge', 'ResNet']
current experiment name: cGAN_L1_1e2_Edge_MONAI
start making inference...
processing image ID 3565...
processing image ID 3566...
processing image ID 3569...
processing image ID 3572...
processing image ID 3574...
processing image ID 3578...
processing image ID 3579...
processing image ID 3580...
processing image ID 3581...
processing image ID 3583...
MAE of 10 testing images: 192.31
min MAE of 10 testing images: 159.80
max MAE of 10 testing images: 256.71
STD of 10 testing images: 28.21
current experiment name: cGAN_L1_100
start making inference...
processing image ID 3565...
processing image ID 3566...
processing image ID 3569...
processing image ID 3572...
processing image ID 3574...
processing image ID 3578...
processing image ID 3579...
processing image ID 3580...
processing image ID 3581...
processing image 

In [11]:
import SimpleITK as sitk
gt = sitk.ReadImage('/data/SkullSyn/bone_data/label/3565_real_CT.nii.gz')
data = sitk.GetArrayFromImage(gt)
msk = extract_skull_mask(data, 400, 1).astype('uint8')
sitk_msk = sitk.GetImageFromArray(msk)
sitk_msk.CopyInformation(gt)
sitk.WriteImage(sitk_msk, '/data/SkullSyn/bone_data/label/3565_mask.nii.gz')
print('done')


done


In [13]:
import SimpleITK as sitk
gt = sitk.ReadImage('/data/SkullSyn/results/JMI/Save/cGAN_L1_1e2_Edge_MONAI/3565_fake_CT.nii.gz')
data = sitk.GetArrayFromImage(gt)
msk = extract_skull_mask(data, 400, 1).astype('uint8')
sitk_msk = sitk.GetImageFromArray(msk)
sitk_msk.CopyInformation(gt)
sitk.WriteImage(sitk_msk, '/data/SkullSyn/results/JMI/Save/cGAN_L1_1e2_Edge_MONAI/3565_vis_mask.nii.gz')
print('done')


done
