In [None]:
import os
import SimpleITK as sitk 
import pandas as pds 
import numpy as np 
from math import ceil 
import pickle
from tqdm import tqdm
from collections import Counter
from tools import resample_multiprocesses

import sklearn.metrics as skMetrics
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Build a per-patient lookup (labels + features) from the CSV, restricted to
# the patients named in the split files, and persist it as a pickle.
csvFile = pds.read_csv(os.path.join("..", "data", "tapvc_info.csv"))
tapvcInfo = {}
patientNameList = []
splitDir = os.path.join("..", "data", "TAPVC", "splits")
for fname in os.listdir(splitDir):
    # NOTE(review): pickle.load can execute arbitrary code; only load split
    # files that were produced by this project.
    with open(os.path.join(splitDir, fname), "rb") as f:
        patientNameList += pickle.load(f)
nameInCSV = list(csvFile["name"])

# Map every name to the first row in which it occurs (same row that
# list.index() would pick) so each lookup is O(1) instead of a linear scan.
rowOfName = {}
for index, name in enumerate(nameInCSV):
    rowOfName.setdefault(name, index)

count = 0
for name in patientNameList:
    # Only a missing name is reported as "No <name>"; any other problem (e.g. an
    # unexpected CSV layout) now raises instead of being silently swallowed.
    if name not in rowOfName:
        print("No ", name)
        continue

    rowItem = csvFile.iloc[rowOfName[name]]
    # Columns are addressed by position. Explicit .iloc is required because
    # integer keys on a label-indexed Series are deprecated as positional access.
    storeItem = {
        "name": rowItem.iloc[-1],
        "label1": rowItem.iloc[-3],
        "label2": rowItem.iloc[-2],
        "feature": list(rowItem.iloc[1:-4])
    }
    tapvcInfo[name] = storeItem
    # Counts successful lookups; a name repeated across split files is counted
    # each time even though it occupies a single dictionary entry.
    count = count + 1

print(tapvcInfo)
print("There are {} items.".format(count))

with open(os.path.join("..", "data", "tapvc_info.pth"), "wb") as f:
    pickle.dump(tapvcInfo, f)

In [None]:
'''
Arrange Model
'''

# Expose each fold's best checkpoint of the Tversky-trained coarse model under
# ../record/<dataset>_coarse_fold<i>.pth/model/ via hard links.
fold_num = 5
dataset_name = 'TAPVC'
data_spacing = '1_1_1'
model_tag = '{}_coarse_{}'.format(dataset_name, "tversky_0.3_0.7")

model_path = os.path.join('..', 'record', '{}_coarse_tversky'.format(dataset_name))

for i in range(fold_num):
    curr_path = os.path.join('..', 'record', '{}_coarse_fold{}.pth'.format(dataset_name, i))
    # makedirs(exist_ok=True) also repairs a fold directory that exists but
    # lacks one of its sub-directories, and makes the cell safe to re-run.
    os.makedirs(os.path.join(curr_path, 'data'), exist_ok=True)
    os.makedirs(os.path.join(curr_path, 'model'), exist_ok=True)

    src_model = os.path.join(model_path, '{}_fold{}.pth'.format(model_tag, i), 'model', 'best_model.pth')
    dst_model = os.path.join(curr_path, 'model', 'best_model.pth')

    if os.path.lexists(dst_model):
        # samefile() raises if the source is missing, so nothing is removed
        # unless a replacement is actually available.
        if os.path.samefile(src_model, dst_model):
            continue
        os.remove(dst_model)  # stale link from an earlier run
    os.link(src_model, dst_model)


In [None]:
'''
all data 
PA:1
'''

# data_path = os.path.join('..', 'data', 'Resampled')
# target_path = os.path.join('..', 'data', 'all_data')

# data_path = os.path.abspath(data_path)
# target_path = os.path.abspath(target_path)

# if os.path.exists(target_path) == False:
#     os.mkdir(target_path)

# for s in os.listdir(data_path):
#     if os.path.exists(os.path.join(target_path, s)) == False:
#         os.mkdir(os.path.join(target_path, s))

#     PV = sitk.ReadImage(os.path.join(data_path, s, 'PV.nii.gz'))
#     LA = sitk.ReadImage(os.path.join(data_path, s, 'LA.nii.gz'))

#     spacing = PV.GetSpacing()
#     PV = sitk.GetArrayFromImage(PV)
#     LA = sitk.GetArrayFromImage(LA)

#     label = np.zeros_like(PV)
#     label[PV.astype('bool')] = 1
#     label[LA.astype('bool')] = 2

#     label = label.astype('uint8')
#     label = sitk.GetImageFromArray(label)
#     label.SetSpacing(spacing)

#     sitk.WriteImage(label, os.path.join(target_path, s, 'mask.nii.gz'))
#     os.symlink(os.path.join(data_path, s, 'image.nii.gz'), os.path.join(target_path, s, 'im.nii.gz'))


    

In [None]:
# '''
# data information
# '''

# Collect per-case geometry and trimmed foreground-intensity statistics and
# write them to data_inform.csv.
data_path = os.path.join('..', 'data', 'TAPVC', 'all_data_0.35_0.35_0.625')

data = {
    'fname':[],
    'size_x':[],
    'size_y':[],
    'size_z':[],
    'spacing_x':[],
    'spacing_y':[],
    'spacing_z':[],
    'I_max':[],
    'I_min':[],
    'I_mid':[],
    'I_mean':[],
    'I_var':[]
}

for fname in tqdm(os.listdir(data_path)):
    image = sitk.ReadImage(os.path.join(data_path, fname, 'im.nii.gz'))
    label = sitk.ReadImage(os.path.join(data_path, fname, 'mask.nii.gz'))

    # A mismatched image/mask pair is reported by name. Its geometry is still
    # recorded, but the intensity statistics are left as NaN because indexing
    # the image with a differently shaped boolean mask would raise.
    size_matches = image.GetSize() == label.GetSize()
    if not size_matches:
        print(fname)

    size_x, size_y, size_z = image.GetSize()
    spacing_x, spacing_y, spacing_z = image.GetSpacing()

    # Order: max, min, median, mean, variance.
    intensity_stats = (np.nan,) * 5
    if size_matches:
        image = sitk.GetArrayFromImage(image)
        label = sitk.GetArrayFromImage(label)

        label = label.astype('bool')
        label_pixels = np.sort(image[label])
        # Drop the lowest and highest 5% of the sorted foreground intensities.
        trim = int(len(label_pixels) * 0.05)
        label_pixels = label_pixels[trim:(len(label_pixels) - trim)]

        # An empty mask has no foreground voxels; keep NaN instead of letting
        # max()/min() raise on an empty array.
        if label_pixels.size > 0:
            intensity_stats = (
                label_pixels.max(),
                label_pixels.min(),
                np.median(label_pixels),
                label_pixels.mean(),
                label_pixels.var(),
            )

    data['fname'].append(fname)
    data['size_x'].append(size_x)
    data['size_y'].append(size_y)
    data['size_z'].append(size_z)
    data['spacing_x'].append(spacing_x)
    data['spacing_y'].append(spacing_y)
    data['spacing_z'].append(spacing_z)
    data['I_max'].append(intensity_stats[0])
    data['I_min'].append(intensity_stats[1])
    data['I_mid'].append(intensity_stats[2])
    data['I_mean'].append(intensity_stats[3])
    data['I_var'].append(intensity_stats[4])

data = pds.DataFrame(data)
data.to_csv('data_inform.csv')


In [None]:
# test_fname = []

# for fname in os.listdir('../data/7.4newdata'):
#     test_fname.append(fname)

# with open(os.path.join('../data/splits', 'test.pth'), 'wb') as f:
#     pickle.dump(test_fname, f)


In [None]:
# '''
# Image Resample
# '''

# data_path = os.path.join('..', 'data', 'TAPVC', 'all_data')
# target_path = os.path.join('..', 'data', 'TAPVC', 'all_data_0.35_0.35_0.625')

# if os.path.exists(target_path) == False:
#     os.mkdir(target_path)

# path_list = []
# for fname in os.listdir(data_path):
#     path_list.append(os.path.join(data_path, fname))

# resample_multiprocesses.Resample(path_list, target_path, (0.35,0.35,0.625), 8)




In [None]:
# '''
# data split
# '''

# fold_num = 5

# tapvcInfo = {}
# with open(os.path.join("..", "data", "tapvc_info.pth"), "rb") as f:
#     tapvcInfo = pickle.load(f)

# fname_list = list(tapvcInfo.keys())
# label1List = []
# label0List = []

# for patientName in fname_list:
#     if 0 == tapvcInfo[patientName]["label2"]:
#         label0List.append(patientName)
#     else:
#         label1List.append(patientName)

# print(len(label0List), len(label1List))

# target_path = os.path.join('..', 'data', 'TAPVC', 'splits_cls')
# # fname_list = sorted(fname_list, key = lambda x: x)

# if os.path.exists(target_path) == False:
#     os.mkdir(target_path)

# fold_num0_list = []
# fold_num1_list = []
# for i in range(fold_num):
#     if i == fold_num - 1:
#         fold_num0_list.append(len(label0List) - sum(fold_num0_list))
#         fold_num1_list.append(len(label1List) - sum(fold_num1_list))
#     else:
#         fold_num0_list.append(ceil(len(label0List) / fold_num))
#         fold_num1_list.append(ceil(len(label1List) / fold_num))



# label0List = np.array(label0List)
# label1List = np.array(label1List)
# index0 = list(range(len(label0List)))
# index1 = list(range(len(label1List)))
# np.random.shuffle(index1)
# np.random.shuffle(index0)

# for i in range(len(fold_num0_list)):
#     num0 = fold_num0_list[i]
#     num1 = fold_num1_list[i]

#     start0 = sum(fold_num0_list[0:i])
#     end0 = start0 + num0

#     start1 = sum(fold_num1_list[0:i])
#     end1 = start1 + num1


#     f_list = list(label0List[start0:end0]) + list(label1List[start1:end1])

#     with open(os.path.join(target_path, 'fold{}.pth'.format(i)), 'wb') as f:
#         pickle.dump(f_list, f)

# for fname in os.listdir(os.path.join(target_path)):
#     with open(os.path.join(target_path, fname), 'rb') as f:
#         f_list = pickle.load(f)

#     print(fname, len(f_list))
#     print(f_list)

In [None]:
'''
coarse_label arrange
'''

# Gather the per-fold coarse predictions into one directory (hard links) and
# merge the per-fold summary CSVs into a single summary.csv.
coarse_model_tag = 'TAPVC_coarse_fold{}.pth'
num_model = 5
target_path = os.path.join('..', 'data', 'TAPVC_coarse')
os.makedirs(target_path, exist_ok=True)

# Every case must provide the required files; the optional ones are linked
# independently of each other whenever they exist.
required_files = ('predict.nii.gz', 'im.nii.gz')
optional_files = ('mask.nii.gz', 'probmap.nii.gz')

data_list = []
link_pairs = []  # (source, destination) hard links to create
for i in range(num_model):
    source_path = os.path.join('..', 'record', coarse_model_tag.format(i), 'data')
    source_path = os.path.abspath(source_path)

    for fname in os.listdir(source_path):
        if fname.endswith('csv'):
            data_list.append(pds.read_csv(os.path.join(source_path, fname)))
            link_pairs.append((os.path.join(source_path, fname),
                               os.path.join(target_path, "summary_{}.csv".format(i))))
            continue

        # Anything that is not a CSV is treated as a per-case directory.
        os.makedirs(os.path.join(target_path, fname), exist_ok=True)
        for required in required_files:
            link_pairs.append((os.path.join(source_path, fname, required),
                               os.path.join(target_path, fname, required)))
        for optional in optional_files:
            optional_src = os.path.join(source_path, fname, optional)
            if os.path.exists(optional_src):
                link_pairs.append((optional_src, os.path.join(target_path, fname, optional)))

for src, dst in link_pairs:
    if os.path.lexists(dst):
        # samefile() raises if the source is missing, so an existing link is
        # only removed when it can actually be replaced.
        if os.path.samefile(src, dst):
            continue
        os.remove(dst)  # stale link from an earlier run
    os.link(src, dst)

# pds.concat raises ValueError if no summary CSV was found in any fold.
data_list = pds.concat(data_list)
data_list.to_csv(os.path.join(target_path, 'summary.csv'))


In [None]:
'''
ROI statics
'''

# Inputs of the ROI statistics below.
# expand_num: per-axis number of voxels added on both sides of the predicted
# bounding box before cropping (indexed in array-axis order).
expand_num = (0,) * 3
# Directory with one sub-directory per case holding the reference mask.nii.gz.
data_path = os.path.join(
    '..', 'data', 'TAPVC', 'all_data_0.35_0.35_0.625'
)
# Directory with one sub-directory per case holding predict.nii.gz.
predict_path = os.path.join(
    '..', 'data', 'TAPVC_fine', "TAPVC_fine_duc_ds_gatt_2_4_best"
)

def getROI(label):
    """Return the bounding interval of the non-zero region along every axis.

    Parameters
    ----------
    label : array-like
        Array of any dimensionality; every non-zero entry counts as foreground.

    Returns
    -------
    list of (int, int)
        One ``(lower, upper)`` pair per axis of ``label``. ``lower`` is the
        first index along that axis containing foreground and ``upper`` is one
        past the last such index, so ``label[lower:upper]`` spans the region.
        An axis without any foreground (all-zero input or a zero-length axis)
        yields the empty interval ``(0, 0)`` instead of arbitrary margins.
    """
    # Test for "any non-zero" rather than "sum != 0" so that values of opposite
    # sign cannot cancel each other out.
    foreground = np.asarray(label) != 0

    margin_list = []
    for axis in range(foreground.ndim):
        # Collapse every other axis: occupied[k] is an index along `axis`
        # whose slice contains at least one foreground entry.
        other_axes = tuple(a for a in range(foreground.ndim) if a != axis)
        occupied = np.flatnonzero(np.any(foreground, axis=other_axes))

        if occupied.size == 0:
            margin_list.append((0, 0))
        else:
            margin_list.append((int(occupied[0]), int(occupied[-1]) + 1))

    return margin_list

# For every predicted case: crop the reference mask to the (optionally
# expanded) bounding box of the prediction and record the crop size together
# with the fraction of reference foreground that falls inside the crop.
roi_rows = []

for fid in tqdm(os.listdir(predict_path)):
    # Summary CSV files live next to the per-case directories; skip them.
    if fid.endswith('csv'):
        continue

    prediction = sitk.GetArrayFromImage(
        sitk.ReadImage(os.path.join(predict_path, fid, 'predict.nii.gz'))
    )
    label = sitk.GetArrayFromImage(
        sitk.ReadImage(os.path.join(data_path, fid, 'mask.nii.gz'))
    )

    # Binarise both volumes: any non-zero class becomes 1.
    prediction = prediction.astype('bool').astype('int')
    label = label.astype('bool').astype('int')

    # A shape mismatch is only reported; processing continues regardless.
    if prediction.shape != label.shape:
        print(fid)

    # Widen each margin by expand_num and clamp it to the volume extent.
    new_margin_list = [
        (
            max(0, lower - expand_num[axis]),
            min(prediction.shape[axis], upper + expand_num[axis]),
        )
        for axis, (lower, upper) in enumerate(getROI(prediction))
    ]

    (z0, z1), (y0, y1), (x0, x1) = (
        new_margin_list[0], new_margin_list[1], new_margin_list[2]
    )
    # It is the reference label that gets cropped with the prediction's box.
    cropped_label = label[z0:z1, y0:y1, x0:x1]

    roi_rows.append({
        'id': fid,
        'z': cropped_label.shape[0],
        'y': cropped_label.shape[1],
        'x': cropped_label.shape[2],
        'rate': cropped_label.sum() / label.sum(),
    })

data = pds.DataFrame(roi_rows, columns=['id', 'z', 'y', 'x', 'rate'])
data.to_csv('ROI_inform.csv')

In [None]:
'''
Fine label arrange
'''
# Gather the per-fold classification outputs into one directory (hard links)
# and merge the per-fold summary CSVs into a single summary.csv.
coarse_model_tag = 'TAPVC_classify_softmax_resin_cam0.1_v3_fold{}.pth'
num_model = 5
target_path = os.path.join('..', 'data', 'TAPVC_classify_softmax_resin_cam0.1_v3')
os.makedirs(target_path, exist_ok=True)

data_list = []
link_pairs = []  # (source, destination) hard links to create
for i in range(num_model):
    source_path = os.path.join('..', 'record', coarse_model_tag.format(i), 'data')
    source_path = os.path.abspath(source_path)

    for fname in os.listdir(source_path):
        if fname.endswith('csv'):
            data_list.append(pds.read_csv(os.path.join(source_path, fname)))
            link_pairs.append((os.path.join(source_path, fname),
                               os.path.join(target_path, "summary_{}.csv".format(i))))
            continue

        # Anything that is not a CSV is treated as a per-case directory whose
        # entire content is mirrored.
        os.makedirs(os.path.join(target_path, fname), exist_ok=True)
        for subfname in os.listdir(os.path.join(source_path, fname)):
            link_pairs.append((os.path.join(source_path, fname, subfname),
                               os.path.join(target_path, fname, subfname)))

for src, dst in link_pairs:
    if os.path.lexists(dst):
        # samefile() raises if the source is missing, so an existing link is
        # only removed when it can actually be replaced.
        if os.path.samefile(src, dst):
            continue
        os.remove(dst)  # stale link from an earlier run
    os.link(src, dst)

# pds.concat raises ValueError if no summary CSV was found in any fold.
data_list = pds.concat(data_list)
data_list.to_csv(os.path.join(target_path, 'summary.csv'))


In [None]:
# Generate Result Report

# Evaluate one experiment: compare the stored class-1 probability of every
# patient with its reference "label2", draw ROC and precision/recall curves,
# and print threshold-based metrics.
resultTag = "TAPVC_classify_softmax_resin_cam0.1_v3"
threshold = 0.5
resultPath = os.path.join("..", "data", resultTag)

# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by this project.
with open(os.path.join("..", "data", "tapvc_info.pth"), "rb") as f:
    referenceResult = pickle.load(f)

referenceList, resultList, logAns = [], [], []
for patientName in os.listdir(resultPath):
    # Everything except the summary CSV files is a per-patient directory.
    if not patientName.endswith(".csv"):
        groundTruth = referenceResult[patientName]["label2"]
        referenceList.append(groundTruth)

        with open(os.path.join(resultPath, patientName, "prob.pth"), "rb") as f:
            prob = pickle.load(f)
        print("{}_{}_{}_{}".format(patientName, groundTruth, prob[0], prob[1]))

        # Keep only the second entry as a plain Python number.
        prob = prob[1].item()
        resultList.append(prob)
        logAns.append("{} {} {}".format(patientName, groundTruth, prob))

fpr, tpr, _ = skMetrics.roc_curve(referenceList, resultList, drop_intermediate=False)
auc = skMetrics.auc(fpr, tpr)

precision, recall, thresholds = skMetrics.precision_recall_curve(referenceList, resultList)

# Left panel: ROC curve; right panel: precision over recall. Both panels get
# the same dashed diagonal and axis limits.
plt.figure()
for panel, (xs, ys) in enumerate([(fpr, tpr), (recall, precision)], start=1):
    plt.subplot(1, 2, panel)
    plt.plot(xs, ys)
    plt.plot([0,1], [0,1], linestyle='--')
    plt.xlim([0.0,1.0])
    plt.ylim([0.0,1.05])

plt.show()

# Hard decisions at the chosen threshold. From here on `precision` and
# `recall` hold scalars instead of the curve arrays.
binaryResultList = [score > threshold for score in resultList]

accuracy = skMetrics.accuracy_score(referenceList, binaryResultList)
f1 = skMetrics.f1_score(referenceList, binaryResultList)
precision = skMetrics.precision_score(referenceList, binaryResultList)
recall = skMetrics.recall_score(referenceList, binaryResultList)

report = "Acc:{}\nAUC:{}\nF1:{}\nPrecision:{}\nRecall:{}\n".format(
    accuracy, auc, f1, precision, recall
)
print(report)

In [None]:
# Generate all Result Report

def kROC(reference, prob, k = 0.01):
    """Sample ROC points on a regular grid of thresholds.

    Parameters
    ----------
    reference : sequence
        Ground-truth labels; the arithmetic below treats them as 0/1.
    prob : sequence
        Scores, one per entry of ``reference``.
    k : float
        Spacing of the thresholds ``0, k, 2k, ...`` generated by
        ``np.arange(0, 1.00001, k)``.

    Returns
    -------
    (fprList, tprList)
        Two lists with one value per threshold: the false-positive rate and
        the true-positive rate of the decision ``prob > threshold``.
    """
    truth = np.array(reference)
    scores = np.array(prob)

    # Quantities that do not depend on the threshold.
    negatives = 1 - truth
    positive_total = truth.sum()
    negative_total = negatives.sum()

    tprList = []
    fprList = []
    for cut in np.arange(0, 1.00001, k):
        flagged = scores > cut
        tprList.append((flagged * truth).sum() / positive_total)
        fprList.append((negatives * flagged).sum() / negative_total)

    return fprList, tprList

# Overlay the ROC curves of several experiments in a single figure.
# resultTagList holds the result directories under ../data and
# resultLabelList the legend entry for each of them (same order).
resultTagList = [
    "TAPVC_classify_softmax_resin_cam0.1_v3_None",
    # "TAPVC_classify_softmax_resin_cam0.1_v3_loss1",
    # "TAPVC_classify_softmax_resin_cam0.1_v3_loss2",
    "TAPVC_classify_softmax_resin_cam0.1_v3"]
resultLabelList = [
    "Baseline",
    # "Loss1",
    # "Loss2",
    "Loss1+Loss2",
]

# The reference labels are the same for every experiment, so load them once.
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by this project.
with open(os.path.join("..", "data", "tapvc_info.pth"), "rb") as f:
    referenceResult = pickle.load(f)

plt.figure(figsize=(6.4,6.4))
for resultTag, resultLabel in zip(resultTagList, resultLabelList):
    logAns = []

    resultPath = os.path.join("..", "data", resultTag)
    referenceList = []
    resultList = []
    for patientName in os.listdir(resultPath):
        # Summary CSV files live next to the per-patient directories.
        if patientName.endswith(".csv"):
            continue
        referenceList.append(referenceResult[patientName]["label2"])
        with open(os.path.join(resultPath, patientName, "prob.pth"), "rb") as f:
            prob = pickle.load(f)
            # Keep only the second entry as a plain Python number.
            prob = prob[1].item()
            resultList.append(prob)
        logAns.append("{} {} {}".format(patientName, referenceList[-1], resultList[-1]))
    fpr, tpr, _ = skMetrics.roc_curve(referenceList, resultList, drop_intermediate=False)
    # fpr, tpr = kROC(referenceList, resultList, k = 0.00001)
    plt.plot(fpr, tpr, label=resultLabel)

plt.legend()
plt.plot([0,1], [0,1], linestyle='--')
plt.xlim([0.0,1.0])
# The curves are drawn as plot(fpr, tpr): x is the false-positive rate and
# y the true-positive rate.
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.ylim([0.0,1.05])
plt.title("ROC Curves of Baseline Method and Proposed Method")

# plt.show()