In [12]:
import numpy as np
import pdb
import numpy as np
from skimage.measure import label
from utils import segmentation_metrics
import torch
import SimpleITK as sitk

def get_metrics(segmentation, mask, n_class=4):
    """Average SimpleITK overlap/distance metrics over the foreground classes.

    Args:
        segmentation: predicted label array; value ``k`` (1..n_class) marks class k.
        mask: ground-truth label array with the same encoding and shape.
        n_class: number of foreground classes to score (labels 1..n_class).

    Returns:
        numpy array of length 5 holding the class-averaged
        [Jaccard, Dice, Hausdorff distance, false-negative error,
        false-positive error].

    Notes:
        No spacing is set on the SimpleITK images, so the Hausdorff distance
        is in index (voxel) units.
        NOTE(review): behaviour when a class is absent from both inputs is left
        to SimpleITK — confirm it does not raise or return NaN.
    """
    overlap = sitk.LabelOverlapMeasuresImageFilter()
    hausdorff = sitk.HausdorffDistanceImageFilter()
    per_class_rows = []
    for class_id in range(1, n_class + 1):
        # Binary masks for this class only (int16 for SimpleITK).
        true_img = sitk.GetImageFromArray((mask == class_id).astype(np.int16))
        pred_img = sitk.GetImageFromArray((segmentation == class_id).astype(np.int16))
        overlap.Execute(pred_img, true_img)
        hausdorff.Execute(pred_img, true_img)
        per_class_rows.append([
            overlap.GetJaccardCoefficient(),
            overlap.GetDiceCoefficient(),
            hausdorff.GetHausdorffDistance(),
            overlap.GetFalseNegativeError(),
            overlap.GetFalsePositiveError(),
        ])
    # reshape keeps the (0, 5) shape when n_class == 0, like the zeros() table did.
    table = np.array(per_class_rows, dtype=np.float64).reshape(n_class, 5)
    return np.mean(table, 0)

def multi_scale(images, if_zoom=True, if_flip=True):
    """Build test-time-augmented copies of a batch at several sizes and flips.

    Args:
        images: array laid out as (batch, channels, height, width); the layout
            is implied by the channels-last transpose below.
        if_zoom: when True, resize to each square size in ``zoom_scale``;
            when False, only to 512x512.
        if_flip: when True, also emit a horizontally flipped copy per size.

    Returns:
        ``(total_images, total_flip_flag, total_zoom_flag)``: parallel lists
        with one entry per (size, flip) combination, i.e.
        - the augmented batch in (batch, channels, size, size) layout;
        - the ``[row_step, col_step]`` slicing steps used for flipping;
        - the ``(fy, fx)`` factors that undo the resize (used by ``recover``).
    """
    zoom_scale = [512, 544, 576, 608, 640] if if_zoom else [512]  #, 672, 704, 736,768]
    flip_scale = [[1, -1], [1, 1]] if if_flip else [[1, 1]]

    # cv2.resize expects (H, W[, C]) images, so go channels-last once.
    hwc_images = images.transpose(0, 2, 3, 1)
    height, width = hwc_images.shape[1:3]

    total_images = []
    total_flip_flag = []
    total_zoom_flag = []
    for cur_zoom in zoom_scale:
        # Every image in an ndarray batch has the same size, so the factors
        # only depend on the target size.
        zoom_y = cur_zoom / float(height)
        zoom_x = cur_zoom / float(width)
        for cur_flip in flip_scale:
            new_images = []
            for image in hwc_images:
                resized = cv2.resize(
                    image,
                    None,
                    fx=zoom_x,
                    fy=zoom_y,
                    interpolation=cv2.INTER_LINEAR)
                if resized.ndim == 2:
                    # cv2 drops a singleton channel axis; restore it so the
                    # channels-first transpose below works for 1-channel input.
                    resized = resized[:, :, np.newaxis]
                new_images.append(
                    np.ascontiguousarray(resized[::cur_flip[0], ::cur_flip[1]]))
            total_zoom_flag.append((1 / zoom_y, 1 / zoom_x))
            total_flip_flag.append(cur_flip)
            total_images.append(np.stack(new_images, 0).transpose(0, 3, 1, 2))
    return total_images, total_flip_flag, total_zoom_flag


def recover(images, total_flip_flag, total_zoom_flag):
    """Undo ``multi_scale`` augmentations on predictions and average them.

    Args:
        images: sequence of prediction batches, one per augmentation; each
            batch item is (channels, H, W), as implied by ``transpose(1, 2, 0)``.
        total_flip_flag: per-augmentation ``[row_step, col_step]`` flip steps.
        total_zoom_flag: per-augmentation ``(fy, fx)`` resize factors that map
            the augmented size back to the original size.

    Returns:
        float32 array (batch, channels, H, W): the mean over augmentations.
    """
    total_labels = []
    for cur_images, flip_flag, zoom_flag in zip(images, total_flip_flag,
                                                total_zoom_flag):
        zoom_y, zoom_x = zoom_flag[0], zoom_flag[1]
        new_labels = []
        for cur_img in cur_images:
            hwc = cur_img.transpose(1, 2, 0).astype(np.float32)
            restored = cv2.resize(
                hwc,
                None,
                fx=zoom_x,
                fy=zoom_y,
                interpolation=cv2.INTER_LINEAR)
            if restored.ndim == 2:
                # cv2 drops a singleton channel axis; restore it so 1-channel
                # predictions survive the transpose below.
                restored = restored[:, :, np.newaxis]
            # Reversing an axis with a [::-1] step is its own inverse, so
            # applying the same steps undoes the flip.
            restored = np.ascontiguousarray(
                restored[::flip_flag[0], ::flip_flag[1]])
            new_labels.append(restored.transpose(2, 0, 1))
        total_labels.append(np.stack(new_labels, 0))
    total_labels = np.stack(total_labels, 0)
    return np.mean(total_labels, 0)


def getLargestCC(segmentation, connectivity):
    """Keep only the largest connected component of ``segmentation``.

    Args:
        segmentation: array passed to ``skimage.measure.label``; pixels equal
            to 0 are background (skimage's default).
        connectivity: neighbourhood connectivity forwarded to ``label``.

    Returns:
        float64 array shaped like ``segmentation`` with 1.0 on the largest
        component and 0.0 elsewhere. If there is no foreground at all, the
        result is all zeros (previously 0/0 = NaN).
    """
    labels = label(segmentation, connectivity=connectivity)
    component_ids, counts = np.unique(labels, return_counts=True)
    # Drop background by value, not by position: if the input has no 0
    # pixels, the first unique id is a real component and must be kept.
    is_foreground = component_ids != 0
    component_ids = component_ids[is_foreground]
    counts = counts[is_foreground]
    if component_ids.size == 0:
        return np.zeros(labels.shape, dtype=np.float64)
    # argmax returns the first maximum, i.e. the smallest id among ties,
    # matching the previous strict-less-than scan.
    largest = component_ids[np.argmax(counts)]
    return (labels == largest).astype(np.float64)


import argparse
import os
import time
import numpy as np
import cv2
import nibabel as nib
import shutil
import ntpath
import sys
import pdb
import logging

from importlib import import_module
import torch
from torch.backends import cudnn
from torch import optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms
from data_utils.torch_data import THOR_Data, get_cross_validation_paths, get_global_alpha
import data_utils.transforms as tr
from models.loss_funs import DependentLoss
from utils import setgpu, get_threshold, metric, segmentation_metrics

#################################################################
# ######################testing config###########################
# Make only GPU 0 visible. This must run before CUDA is first initialised in
# the kernel; otherwise it has no effect.
gpus = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = gpus

# Use the GPU when one is visible and fall back to the CPU otherwise
# (`"cuda" if True else "cpu"` could never choose the CPU).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): machine-specific absolute path; consider making it configurable.
data_path = "/root/workspace/Python3/SegTHOR/data/data_npy/"
# Selects the held-out split returned by get_cross_validation_paths
# (presumably the cross-validation fold index).
test_flag = 0
train_files, test_files = get_cross_validation_paths(test_flag)
# Training variant. It is part of the checkpoint path. In the evaluation loop,
# 'MTL_WMCE' routes classification outputs through DependentLoss, and 'SM'
# disables false-positive filtering.
model_flag = 'MTL_WMCE'
net_name = 'DenseUNet121'
loss_name = 'CombinedLoss'
# Erase segmentation classes that the classification head predicts as absent.
if_fpf = True
saved_checkpoint = '45.ckpt'
# Per-class thresholds for the classification probabilities; set to None to
# threshold every class at 0.5 instead.
dynamic_threshold = [0.2398, 0.2151, 0.1941, 0.1636]
precise_net_path = os.path.join("SavePath", model_flag, net_name, str(test_flag), saved_checkpoint)
# When 1, compute `alpha` from the training files and hand it to the model
# and DependentLoss; otherwise both receive None.
if_dependent = 1
if if_dependent == 1:
    alpha = get_global_alpha(train_files, data_path)
    alpha = torch.from_numpy(alpha).float().to(DEVICE)
    # Fixed weights: keep them out of the autograd graph.
    alpha.requires_grad = False
else:
    alpha = None
        
################################################################

from torch.nn import DataParallel
model = import_module('models.model_loader')
# Build the network with 5 segmentation output classes, plus its loss (unused
# below), and the DependentLoss used to turn classification outputs into
# probabilities in the evaluation loop.
precise_net, loss = model.get_full_model(
    net_name, loss_name, n_classes=5, alpha=alpha)
c_loss = DependentLoss(alpha)
# map_location restores tensors onto DEVICE instead of the device recorded in
# the file. Without it, loading fails when that GPU is not visible (only
# CUDA_VISIBLE_DEVICES=gpus is exposed) or when running on the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(precise_net_path, map_location=DEVICE)
precise_net.load_state_dict(checkpoint['state_dict'])
precise_net = precise_net.to(DEVICE)
precise_net = DataParallel(precise_net)
precise_net.eval()


################# evaluation data (held-out split) #############
# Same normalisation for every channel, then tensor conversion.
# ToTensor2(5) presumably encodes the 5 segmentation classes — TODO confirm.
eval_transform_steps = [
    tr.Normalize(mean=(0.12, 0.12, 0.12), std=(0.018, 0.018, 0.018)),
    tr.ToTensor2(5),
]
composed_transforms_tr = transforms.Compose(eval_transform_steps)
eval_dataset = THOR_Data(
    path=data_path,
    file_list=test_files,
    transform=composed_transforms_tr,
)

# Fixed (unshuffled) iteration order over the test files.
eval_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False, num_workers=4)

# Multi-scale inference over the evaluation set, in four steps per batch:
# 1. run the network on every augmented copy produced by multi_scale;
# 2. map the segmentation probabilities back with recover;
# 3. with FPF, erase classes that the classification head rejects;
# 4. collect predictions and ground truth for get_metrics.
predict_c = []
label_c = []
predictions = []
targets = []

for batch_idx, sample in enumerate(eval_loader):
    if batch_idx % 50 == 0:
        print(batch_idx)
    data = sample['image']
    target_c = sample['label_c']
    target_s = sample['label_s']
    label_c.append(target_c.numpy())
    # Ground-truth segmentation as per-pixel class indices.
    targets.append(torch.argmax(target_s, 1).numpy())
    # Only the classification target is needed on the device (by c_loss).
    target_c = target_c.to(DEVICE)

    # multi_scale/recover work on numpy arrays, so the batch stays on the CPU
    # here; each augmented copy is moved to DEVICE for its forward pass.
    cur_img2_scales, total_flip_flag, total_zoom_flag = multi_scale(
        data.cpu().numpy(), if_zoom=True, if_flip=False)
    predict_slabel_scales = []
    predict_clabel_scales = []
    with torch.no_grad():
        for cur_img2 in cur_img2_scales:
            data = torch.from_numpy(cur_img2.astype('float32')).to(DEVICE)
            output_s, output_c = precise_net(data)
            if model_flag != 'MTL_WMCE':
                c_p = output_c
            else:
                # NOTE(review): the ground-truth target_c is passed to
                # DependentLoss at test time. Confirm the returned
                # probabilities do not depend on it; otherwise test labels
                # leak into the predictions.
                _, c_p = c_loss(output_c, target_c)
            predict_clabel_scales.append(c_p.cpu().numpy())
            predict_slabel_scales.append(torch.softmax(output_s, 1).cpu().numpy())
    # Average the classification outputs over all scales.
    predict_clabel = np.mean(np.array(predict_clabel_scales), 0)

    # `is None` rather than `== None`: an ndarray of thresholds would make
    # `== None` elementwise and the `if` ambiguous.
    if dynamic_threshold is None:
        predict_clabel = (predict_clabel > 0.5).astype('uint8')
    else:
        for class_idx in range(predict_clabel.shape[1]):
            predict_clabel[:, class_idx] = (
                predict_clabel[:, class_idx] > dynamic_threshold[class_idx]).astype('uint8')

    recover_label = recover(predict_slabel_scales, total_flip_flag,
                            total_zoom_flag)
    # Per-pixel class = argmax over the channel axis of the averaged probabilities.
    predict_slabel = np.argmax(recover_label, 1)
    if model_flag != 'SM' and if_fpf:
        # False-positive filtering: column class_idx of predict_clabel gates
        # segmentation label class_idx + 1 (label 0 is background).
        for class_idx in range(predict_clabel.shape[1]):
            for slice_idx in range(predict_clabel.shape[0]):
                if predict_clabel[slice_idx, class_idx] == 0:
                    slice_pred = predict_slabel[slice_idx]  # view: writes go through
                    slice_pred[slice_pred == class_idx + 1] = 0
    predictions.append(predict_slabel.astype('uint8'))

metric5 = get_metrics(np.concatenate(predictions, 0), np.concatenate(targets, 0))


# Report the class-averaged metrics in the column order used by get_metrics.
# The fifth value is the false-positive error, which used to be printed under
# the FalseNegativeError label.
metric_names = ['JaccardCoefficient', 'DiceCoefficient', 'HausdorffDistance',
                'FalseNegativeError', 'FalsePositiveError']
for metric_name, metric_value in zip(metric_names, metric5):
    print('the %-18s is --> %4f' % (metric_name, metric_value))
print(metric5)


the data length is 1910
0
50
100
150
200
the JaccardCoefficient is --> 0.776123
the DiceCoefficient    is --> 0.870466
the HausdorffDistance  is --> 55.274412
the FalseNegativeError is --> 0.129951
the FalseNegativeError is --> 0.128882
[ 0.77612284  0.87046613 55.27441216  0.12995103  0.1288817 ]


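### ResUNet101 results

Each table below has one row per cross-validation fold, presumably `test_flag` 0–3 in order. The five columns are those returned by `get_metrics`: Jaccard, Dice, Hausdorff distance, false-negative error and false-positive error.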

In [4]:
##########################ResUNet101 #################################
# Per-fold test metrics for ResUNet101, keyed by training configuration.
# Columns (get_metrics order): Jaccard, Dice, Hausdorff distance,
# false-negative error, false-positive error. Rows: folds (presumably
# test_flag 0-3). A dict keeps every table instead of overwriting one name.
resunet101_metrics = {
    'single model': np.array([
        [7.31430518e-01, 8.40647343e-01, 1.14178972e+02, 1.00342060e-01, 2.08221512e-01],
        [0.72034801,     0.83231599,     121.5669528,    0.12256956,     0.20370237],
        [6.92351886e-01, 8.10962348e-01, 1.15613286e+02, 9.93265212e-02, 2.55561990e-01],
        [7.25348248e-01, 8.34327775e-01, 1.09636641e+02, 9.69096626e-02, 2.15380337e-01],
    ]),
    'MTL': np.array([
        [0.75861127,     0.85884688,     90.99232376,    0.12306459,     0.15707381],
        [0.72888126,     0.84025922,     125.48822938,   0.14224875,     0.17362341],
        [7.17157967e-01, 8.31320742e-01, 1.46567672e+02, 1.32204524e-01, 1.97503320e-01],
        [7.24491683e-01, 8.33975656e-01, 1.41680772e+02, 1.28239266e-01, 1.91245398e-01],
    ]),
    'MTL+FPF': np.array([
        [0.76999855,     0.86617058,     52.59950327,    0.15219046,     0.11296789],
        [7.43979731e-01, 8.50141732e-01, 1.00585833e+02, 1.96857087e-01, 9.68112147e-02],
        [0.75068629,     0.85405626,     65.38742573,    0.17099281,     0.11697367],
        [0.74775378,     0.85067017,     64.31540543,    0.17185095,     0.1170839],
    ]),
    'MTL+FPF+DT': np.array([
        [0.78566751,     0.87569909,     56.14309948,    0.13483721,     0.11245999],
        [7.60042191e-01, 8.60182037e-01, 1.02360403e+02, 1.81614731e-01, 9.30837886e-02],
        [0.76659042,     0.86434276,     68.56163885,    0.15658578,     0.11282196],
        [0.76666007,     0.86231065,     65.45028801,    0.15302318,     0.11479879],
    ]),
    'MTL+WMCE': np.array([
        [0.75339612,     0.85604584,     80.51317592,    0.10797239,     0.17677737],
        [0.75441814,     0.85544218,     116.19220473,   0.13268507,     0.1557596],
        [7.01862859e-01, 8.21504855e-01, 1.31137754e+02, 1.05141096e-01, 2.37912911e-01],
        # NOTE(review): identical to the row above — likely a copy-paste of
        # fold 2 in place of the real fold-3 result; re-run to confirm.
        [7.01862859e-01, 8.21504855e-01, 1.31137754e+02, 1.05141096e-01, 2.37912911e-01],
    ]),
    'MTL+FPF+WMCE': np.array([
        [0.77424051,     0.86908471,     44.80134109,    0.13363046,     0.12731052],
        [0.76220773,     0.86133249,     62.75743568,    0.17285679,     0.10102821],
        [0.74695307,     0.85108103,     40.92081843,    0.15293728,     0.14244711],
        [0.75343537,     0.85507376,     42.47629328,    0.15754154,     0.12261252],
    ]),
    'MTL+FPF+DT+WMCE': np.array([
        [0.78638208,     0.87631012,     41.06253553,    0.12164051,     0.12558231],
        [0.77705479,     0.87079372,     63.75076624,    0.15144445,     0.10556498],
        [0.75522246,     0.85740417,     51.92375764,    0.10810269,     0.17186874],
        [0.77438209,     0.86837714,     41.29515775,    0.12651826,     0.12975795],
    ]),
}
# As before, `total_metric` ends up holding the last table.
total_metric = resunet101_metrics['MTL+FPF+DT+WMCE']


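### DenseUNet121 results

Each table below has one row per cross-validation fold; row 0 matches the `test_flag = 0` run above. The five columns are those returned by `get_metrics`: Jaccard, Dice, Hausdorff distance, false-negative error and false-positive error.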

In [13]:
########################## DenseUNet121 #################################
# Per-fold test metrics for DenseUNet121, keyed by training configuration.
# Columns (get_metrics order): Jaccard, Dice, Hausdorff distance,
# false-negative error, false-positive error. Rows: folds; row 0 of
# 'MTL+FPF+DT+WMCE' is the test_flag = 0 run recorded earlier in the notebook.
# A dict keeps every table instead of overwriting one name.
denseunet121_metrics = {
    'single model': np.array([
        [7.17490673e-01, 8.31883873e-01, 1.45306604e+02, 1.09846864e-01, 2.15679765e-01],
        [7.26173721e-01, 8.35858499e-01, 1.18550526e+02, 1.17134829e-01, 2.03368553e-01],
        [6.86296240e-01, 8.07879000e-01, 1.49835477e+02, 1.03727603e-01, 2.59314505e-01],
        [7.18709849e-01, 8.29536663e-01, 1.10491102e+02, 1.04489176e-01, 2.18552513e-01],
    ]),
    'MTL': np.array([
        [0.7505091,      0.85404752,     120.035943,     0.1302208,      0.15879294],
        [0.73628023,     0.84336362,     96.31935153,    0.16284808,     0.14834658],
        [0.71036,        0.82676592,     108.05665879,   0.15303793,     0.1885888],
        [7.23193417e-01, 8.33911501e-01, 1.37578859e+02, 1.31574970e-01, 1.87929912e-01],
    ]),
    'MTL+FPF': np.array([
        [0.77825119,     0.87129557,     50.11546152,    0.1429695,      0.11314711],
        [0.76269908,     0.86080058,     61.41174945,    0.18152781,     0.09067137],
        [0.75032118,     0.85367757,     55.83202764,    0.16861209,     0.12164168],
        [0.76623875,     0.86209232,     64.76080052,    0.14656549,     0.12326966],
    ]),
    'MTL+FPF+DT': np.array([
        [0.77535875,     0.86908253,     46.68483234,    0.15364329,     0.10565098],
        [0.75099711,     0.85308906,     46.74360535,    0.1984873,      0.08595137],
        [0.74380985,     0.84894431,     53.78052821,    0.17959358,     0.11881974],
        [0.76361954,     0.86000116,     46.6714574,     0.15710904,     0.11524336],
    ]),
    'MTL+WMCE': np.array([
        [0.76938816,     0.86639259,     89.82851383,    0.12409367,     0.14258617],
        [0.75651879,     0.85731363,     124.18500598,   0.12582819,     0.158396],
        [0.71415536,     0.82967031,     112.136927,     0.14248751,     0.19470395],
        [0.7598401,      0.85950636,     110.23313698,   0.11648709,     0.15833222],
    ]),
    'MTL+FPF+WMCE': np.array([
        [0.77435325,     0.86975299,     52.4529059,     0.14886215,     0.11015945],
        [0.76291228,     0.86182452,     63.20779552,    0.15999374,     0.11493437],
        [0.73187873,     0.84027925,     66.34389246,    0.18452412,     0.13258422],
        [0.76172713,     0.8610716,      52.72142478,    0.15530784,     0.11656709],
    ]),
    'MTL+FPF+DT+WMCE': np.array([
        [0.77612284,     0.87046613,     55.27441216,    0.12995103,     0.1288817],
        [0.7686837,      0.86561119,     65.9472254,     0.15466388,     0.11299229],
        [0.75067235,     0.85315483,     67.54198204,    0.15120168,     0.14183097],
        [0.77159702,     0.86709052,     59.03450308,    0.12110735,     0.13945892],
    ]),
}
# As before, `total_metric` ends up holding the last table.
total_metric = denseunet121_metrics['MTL+FPF+DT+WMCE']

In [14]:
# Mean of each metric column over the rows (folds) of `total_metric`.
fold_mean_metrics = total_metric.mean(axis=0)
print(fold_mean_metrics)

[ 0.76676898  0.86408067 61.94953067  0.13923099  0.13079097]
