In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from models.unet import UNet

import numpy as np
from torchvision import transforms
import torch

from utils.dataset_full import LandmarksDataset, ToTensor

In [3]:
IMG_PATH = "../Chest-xray-landmark-dataset/Images"
LABEL_PATH = "../Chest-xray-landmark-dataset/landmarks"


def read_split(path):
    """Return the lines of a split file (one image name per line).

    The file is closed deterministically, unlike a bare ``open(...).read()``.
    """
    with open(path, 'r') as f:
        return f.read().splitlines()


def make_dataset(images, organ):
    """Build a LandmarksDataset over ``images`` for the given organ code.

    ``organ`` is forwarded unchanged to LandmarksDataset ('L', 'LH' or 'LHC'
    in this notebook); samples are converted with ``ToTensor``.
    """
    return LandmarksDataset(images,
                            img_path=IMG_PATH,
                            label_path=LABEL_PATH,
                            organ=organ,
                            transform=transforms.Compose([ToTensor()]))


train_lungs = read_split("train_images_lungs.txt")
train_heart = read_split("train_images_heart.txt")

# To also evaluate on training images, append train_lungs / train_heart here.
test_lungs = read_split("test_images_lungs.txt")
test_heart = read_split("test_images_heart.txt")

# Split the test lists by the source-dataset tag contained in each file name.
test_mont = [image for image in test_lungs if "MCU" in image]
test_shen = [image for image in test_lungs if "CHN" in image]
test_jsrt = [image for image in test_heart if "JP" in image]
test_pad = [image for image in test_heart if "JP" not in image]

dataset_mont = make_dataset(test_mont, 'L')
dataset_shen = make_dataset(test_shen, 'L')
dataset_jsrt = make_dataset(test_jsrt, 'LHC')
dataset_pad = make_dataset(test_pad, 'LH')

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
def _load_unet(n_classes, weights_file):
    """Create a UNet with ``n_classes`` output channels on ``device``, load the
    state dict stored at ``weights_file`` and put the network in eval mode."""
    net = UNet(n_classes=n_classes).to(device)
    net.load_state_dict(torch.load(weights_file, map_location=device))
    net.eval()
    return net


# Each checkpoint is the "bestDice.pt" file of its training run
# (presumably the epoch with the best validation Dice).
unet_full_L = _load_unet(1, "weights/UNet_HT/UNET_HT_L_BOTH/bestDice.pt")
unet_full_LH = _load_unet(2, "weights/UNet_HT/UNET_HT_LH_FULL/bestDice.pt")
unet_full_LHC = _load_unet(3, "weights/UNet_HT/UNET_HT_LHC_FULL/bestDice.pt")
unet_LH = _load_unet(2, "weights/UNet_HT/UNET_HT_LH_STRICT/bestDice.pt")
unet_LHC = _load_unet(3, "weights/UNet_HT/UNET_HT_LHC_STRICT/bestDice.pt")

print('Model loaded')

Model loaded


In [5]:
# Pair every object with its display name so the two can never drift apart.
# Later cells may index these lists positionally, so keep the order stable.
_model_entries = [
    ('L (Full)', unet_full_L),
    ('LH (Strict)', unet_LH),
    ('LH (Full)', unet_full_LH),
    ('LHC (Strict)', unet_LHC),
    ('LHC (Full)', unet_full_LHC),
]
model_names = [name for name, _ in _model_entries]
model_list = [model for _, model in _model_entries]

# The parenthesised tag lists the annotated structures: L=lungs, H=heart, C=clavicles.
_dataset_entries = [
    ("Montgomery (L)", dataset_mont),
    ("Shenzhen (L)", dataset_shen),
    ("Padchest (LH)", dataset_pad),
    ("JSRT (LHC)", dataset_jsrt),
]
dataset_names = [name for name, _ in _dataset_entries]
datasets = [ds for _, ds in _dataset_entries]

In [6]:
from medpy.metric import dc, hd
import cv2

def getDenseMask(RL, LL, H = None, CLA1 = None, CLA2 = None, imagesize = 1024):
    """Rasterise landmark contours into a dense label mask.

    Each contour is filled (thickness -1) with its label: 1 for RL and LL,
    2 for H, 3 for CLA1 and CLA2. Contours are drawn in that order, so later
    labels overwrite earlier ones where they overlap (heart over lungs,
    clavicles over both).

    Args:
        RL, LL, H, CLA1, CLA2: numpy arrays of contour points; each is reshaped
            to (-1, 1, 2), so they are assumed to hold (x, y) pixel coordinates
            (TODO confirm axis order against the landmark files). Any contour
            given as None is skipped.
        imagesize: side length of the square output mask. It was previously
            ignored (1024 was hard-coded); the default keeps that size.

    Returns:
        float64 numpy array of shape (imagesize, imagesize) with values in {0, 1, 2, 3}.
    """
    mask = np.zeros([imagesize, imagesize])

    layers = [(RL, 1), (LL, 1), (H, 2), (CLA1, 3), (CLA2, 3)]
    for points, label in layers:
        if points is None:
            continue
        # OpenCV expects integer contour points as int32; numpy's default 'int'
        # is int64 on 64-bit Linux. The cast truncates, as astype('int') did.
        contour = points.reshape(-1, 1, 2).astype(np.int32)
        mask = cv2.drawContours(mask, [contour], -1, label, -1)

    return mask

def evalImageMetrics(output, target, missing=-1):
    """Compute Dice and Hausdorff distance per structure between two label masks.

    Labels follow the encoding used by getDenseMask and the segmentation loop:
    1 = lungs, 2 = heart, 3 = clavicles.

    medpy's ``dc`` returns 0.0 when both masks are empty. ``hd`` raises
    RuntimeError when either mask has no foreground pixels. In that case the
    distance is undefined and ``missing`` is stored instead. Only that error is
    caught; anything else propagates rather than being silently hidden.

    Args:
        output: predicted label mask (numpy array).
        target: reference label mask, same shape as ``output``.
        missing: value recorded for an undefined Hausdorff distance. The default
            -1 keeps the historical behaviour. Note that -1 is averaged into
            pandas means downstream; pass float('nan') to have mean() skip it.

    Returns:
        [dice_lungs, dice_heart, dice_cla, hd_lungs, hd_heart, hd_cla]
    """
    dices = []
    hds = []
    for label in (1, 2, 3):
        pred = output == label
        ref = target == label
        dices.append(dc(pred, ref))
        try:
            hds.append(hd(pred, ref))
        except RuntimeError:
            hds.append(missing)
    return dices + hds

In [7]:
import pandas as pd

IMAGE_SIZE = 1024  # landmarks are normalised to [0, 1]; scale back to pixel coordinates

# Rows are collected in a list and turned into a DataFrame once at the end.
# DataFrame.append was removed in pandas 2.0 and was quadratic in a loop.
records = []

with torch.no_grad():
    for dataset_name, dataset in zip(dataset_names, datasets):
        # The parenthesised tag in the name (e.g. "(LHC)") lists the annotated
        # structures, so the heart/clavicle contours only exist when tagged.
        organs = dataset_name[dataset_name.rfind('(') + 1:dataset_name.rfind(')')]
        has_heart = 'H' in organs
        has_cla = 'C' in organs
        n_images = len(dataset.images)

        for i in range(n_images):
            print('\r', dataset_name, i + 1, 'of', n_images, end='')
            sample = dataset[i]

            data = torch.unsqueeze(sample['image'], 0).to(device)
            target = sample['landmarks'].reshape(-1, 2).numpy() * IMAGE_SIZE

            # Landmark index ranges per structure in the flattened target:
            # RL [0, 44), LL [44, 94), heart [94, 120), clavicles [120, 143) and [143, end).
            targetseg = getDenseMask(
                target[:44],
                target[44:94],
                H=target[94:120] if has_heart else None,
                CLA1=target[120:143] if has_cla else None,
                CLA2=target[143:] if has_cla else None,
            )

            for model_name, model in zip(model_names, model_list):
                # One sigmoid channel per structure (0 = lungs, 1 = heart, 2 = clavicles).
                # The channel count comes from the output itself, not from the model's
                # position in model_list. The result is moved to CPU before it is used
                # as a mask.
                probs = torch.sigmoid(model(data))[0].cpu()
                seg = np.zeros(tuple(probs.shape[-2:]), dtype=np.uint8)
                # Later channels overwrite earlier ones, matching getDenseMask's draw order.
                for channel in range(probs.shape[0]):
                    seg[(probs[channel] > 0.5).numpy()] = channel + 1

                metrics = evalImageMetrics(seg, targetseg)
                records.append([i, dataset_name, model_name] + metrics)
        print('')

results = pd.DataFrame(records,
                       columns=['i', 'Dataset', 'Model',
                                'Dice Lungs', 'Dice Heart', 'Dice Cla',
                                'HD Lungs', 'HD Heart', 'HD Cla'])

 Montgomery (L) 27 of 27
 Shenzhen (L) 78 of 78
 Padchest (LH) 27 of 27
 JSRT (LHC) 49 of 49


In [8]:
# Metric columns reported for each dataset, keyed by the organ tag in its name.
# '(L)' does not match '(LH)' or '(LHC)' because of the closing parenthesis.
METRIC_COLUMNS = {
    '(L)': ['Dice Lungs', 'HD Lungs'],
    '(LH)': ['Dice Lungs', 'HD Lungs', 'Dice Heart', 'HD Heart'],
    '(LHC)': ['Dice Lungs', 'HD Lungs', 'Dice Heart', 'HD Heart', 'Dice Cla', 'HD Cla'],
}

for d in dataset_names:
    print(d)
    columns = next((cols for tag, cols in METRIC_COLUMNS.items() if tag in d), None)
    if columns is None:
        raise ValueError(f"No metric columns configured for dataset {d!r}")

    sub = results.loc[results['Dataset'] == d, columns + ['Model']]
    group = sub.groupby('Model').mean()
    # Show models in the order of model_names rather than groupby's alphabetical order.
    display(group.reindex(model_names))


Montgomery (L)


Unnamed: 0_level_0,Dice Lungs,HD Lungs
Model,Unnamed: 1_level_1,Unnamed: 2_level_1
L (Full),0.974303,46.856713
LH (Strict),0.95798,74.792915
LH (Full),0.976638,60.520413
LHC (Strict),0.912964,168.505556
LHC (Full),0.971801,72.612539


Shenzhen (L)


Unnamed: 0_level_0,Dice Lungs,HD Lungs
Model,Unnamed: 1_level_1,Unnamed: 2_level_1
L (Full),0.966638,78.657325
LH (Strict),0.958026,131.60959
LH (Full),0.965635,57.77645
LHC (Strict),0.912958,204.838943
LHC (Full),0.965567,77.100273


Padchest (LH)


Unnamed: 0_level_0,Dice Lungs,HD Lungs,Dice Heart,HD Heart
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
L (Full),0.95911,69.960019,0.0,-1.0
LH (Strict),0.961249,74.901471,0.935801,80.17867
LH (Full),0.960251,87.933438,0.929435,125.569459
LHC (Strict),0.895203,223.599099,0.874325,198.955925
LHC (Full),0.962616,66.060273,0.931028,87.78891


JSRT (LHC)


Unnamed: 0_level_0,Dice Lungs,HD Lungs,Dice Heart,HD Heart,Dice Cla,HD Cla
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
L (Full),0.949588,106.084423,0.0,-1.0,0.0,-1.0
LH (Strict),0.949487,105.598617,0.944026,58.419011,0.0,-1.0
LH (Full),0.950242,100.166029,0.938825,59.437755,0.0,-1.0
LHC (Strict),0.975013,78.050082,0.941694,82.230556,0.939006,28.327466
LHC (Full),0.976388,55.153278,0.94236,47.63171,0.939856,43.642498
