In [41]:
import pandas as pd
import os
import torch
from torch.utils.data import DataLoader
import albumentations as A

from datasets.lungdatasets import SchenzenMontgomeryLungSegmentationDataset
from datasets.lungdatasets import CheXpertLungSegmentationDataset
from models.unet import ResNetUNet
from utils.utils import bce_dice_loss, dice_metric
import numpy as np
from tqdm import tqdm
import glob
import cv2

# Filesystem locations used by this notebook (all relative to the working directory).
# NOTE(review): none of these four constants is referenced by the cells visible
# in this part of the notebook - confirm they are still needed.
# CSV file inside the CheXpert-v1.0-small folder (presumably the training index - confirm).
CHEXPERT_TRAIN = '../CheXpert-v1.0-small/train.csv'
# Folder of lung-mask outputs (presumably written by an earlier inference run - confirm).
BASE_MASKS = './intermediate/out_lung_mask/'
# Folder of images; the inference cell lists the same directory without the trailing slash.
BASE_IMG = './data/chexpert-cardio-nofinding/'
# Path fragment for the CheXpert training images (no leading '../', unlike CHEXPERT_TRAIN).
BASE_EXTRA = 'CheXpert-v1.0-small/train/'

In [42]:
def get_transforms(size, test = True):
    """Build the albumentations pipeline for square ``size`` x ``size`` inputs.

    Args:
        size: Target height and width passed to ``A.Resize``.
        test: When truthy, only the resize step is applied. Otherwise the
            resize is followed by random flip / rotate / transpose /
            shift-scale-rotate augmentations.

    Returns:
        An ``A.Compose`` object wrapping the selected steps.
    """
    #TODO: Do test-time augmentation?
    steps = [A.Resize(height=size, width=size, p=1.0)]
    if not test:
        # Training-only geometric augmentations, applied after the resize.
        steps += [
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.3),
            A.Transpose(p=0.3),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.2, rotate_limit=45, p=0.3),
        ]
    return A.Compose(steps)

In [54]:
#Find Min skipping 0
def find_min(arr, default=1000):
    """Return the smallest element of ``arr`` that is not equal to 0.

    The previous implementation seeded the search with 1000, which silently
    capped the result: any input whose non-zero values were all >= 1000
    (e.g. column indices in an image wider than 1000 px) returned 1000
    instead of the true minimum. The fallback is now only used when there
    is no non-zero element at all.

    Args:
        arr: Iterable of numbers (a list or a 1-D numpy array).
        default: Value returned when ``arr`` is empty or contains only zeros.
            Defaults to 1000, the value the original code returned in that
            situation.

    Returns:
        The minimum non-zero element, or ``default`` if there is none.
    """
    return min((value for value in arr if value != 0), default=default)

def find_chest_width_image(img,post_process=True):
    """Measure the horizontal extent of the mask stored in channel 1 of ``img``.

    Args:
        img: 3-D image array indexed as (row, column, channel).
        post_process: When truthy, ``img`` is first replaced by the output of
            ``post_process_image``, which draws its result in channel 1.

    Returns:
        A tuple ``(left, right, width)``: ``left`` is the smallest non-zero
        per-row argmax column, ``right`` is ``width`` minus the smallest
        non-zero per-row argmax taken on the horizontally mirrored channel,
        and ``width`` is the size of the column axis.
    """
    if post_process:
        img = post_process_image(img)
    mask_channel = img[:, :, 1]
    # Per row: index of the first maximal pixel scanning from the left ...
    from_left = np.argmax(mask_channel, axis=1)
    # ... and scanning from the right (columns mirrored).
    from_right = np.argmax(mask_channel[:, ::-1], axis=1)
    _height, width, _channels = img.shape
    # find_min skips zeros, so rows whose argmax is 0 do not contribute.
    return find_min(from_left), width - find_min(from_right), width

def find_chest_width(path,post_process=True):
    """Load an image from disk and measure the horizontal extent of its mask.

    This is the file-based counterpart of ``find_chest_width_image`` and
    delegates to it, so the measurement logic lives in one place only.

    Args:
        path: Path of the image file, read with ``cv2.imread``.
        post_process: Forwarded to ``find_chest_width_image``.

    Returns:
        The ``(left, right, width)`` tuple produced by
        ``find_chest_width_image``.

    Raises:
        OSError: If ``cv2.imread`` cannot read the file (it returns ``None``
            for missing or undecodable files instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        # Fail with the offending path instead of an opaque indexing error later.
        raise OSError("cv2.imread could not read image: {}".format(path))
    return find_chest_width_image(img, post_process=post_process)

def post_process_image(img,hull = True):
    """Redraw the two largest contours of a mask as filled green shapes.

    Channel 0 of ``img`` is used as the mask. Its contours are extracted,
    the two with the largest area are kept (presumably the two lungs -
    confirm), and they are drawn filled with colour ``(0, 255, 0)`` on an
    all-zero canvas.

    Args:
        img: 3-D image array; channel 0 must be an image type accepted by
            ``cv2.findContours``.
        hull: When truthy, each kept contour is replaced by its convex hull
            before drawing; otherwise the raw contours are drawn.

    Returns:
        A new array with the same shape as ``img`` (default ``np.zeros``
        dtype, i.e. float64) holding the filled shapes in channel 1.
    """
    dst = img[:,:,0]

    # An erosion pre-step (3x3 kernel, 3 iterations) was tried here and left disabled.

    # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x but
    # (image, contours, hierarchy) on 3.x; index -2 is the contour list in all cases.
    contours = cv2.findContours(dst, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # Empty canvas for the redrawn shapes.
    img_contours = np.zeros(img.shape)

    # Keep the (up to) two contours with the largest area; sorted() is stable,
    # so ties resolve exactly as with the original in-place sort.
    largest = sorted(contours, key=cv2.contourArea)[-2:]
    if hull:
        largest = [cv2.convexHull(contour) for contour in largest]

    # Thickness -1 fills the shapes instead of outlining them.
    cv2.drawContours(img_contours, largest, -1, (0,255,0), -1)
    return img_contours

def find_img(path):
    """Return the ``shape`` of the image that ``cv2.imread`` loads from ``path``."""
    return cv2.imread(path).shape

In [63]:
import numpy as np
import cv2
from matplotlib.pyplot import imshow

# Hand-labelled CheXpert subsets, loaded with the resize-only (test) transform.
ds_train_no_finding = CheXpertLungSegmentationDataset("./data/hand-label/nofinding.json", '../CheXpert-v1.0-small/train/'
                                                      , aug_transform=get_transforms(320)
                                                     , test = True)
ds_train_cardiomegaly = CheXpertLungSegmentationDataset("./data/hand-label/cardiomegaly-certain.json", '../CheXpert-v1.0-small/train/'
                                                        , aug_transform=get_transforms(320)
                                                       , test = True)

full_ds_chexpert = torch.utils.data.ConcatDataset([ds_train_no_finding, ds_train_cardiomegaly])

print(len(ds_train_no_finding), len(ds_train_cardiomegaly), len(full_ds_chexpert))
# Hold out 100 samples for validation; the fixed seed makes the split repeatable.
train_ds_chexpert,val_ds_chexpert = torch.utils.data.random_split(full_ds_chexpert, [len(full_ds_chexpert) - 100, 100], generator=torch.Generator().manual_seed(42))
sample = np.uint8(full_ds_chexpert[0][1].cpu().numpy() * 255)

# Chest extent measured on the hand-labelled masks, keyed by the third element
# each dataset sample yields (printed below; a file name in the recorded output).
ground_truth = {}
for img,mask,path in val_ds_chexpert:
    print(path)
    # Scale the mask to uint8 once (assumes mask values lie in [0, 1] - confirm).
    mask = np.uint8(mask *255)
    # The stacked array used to be multiplied by 255 a second time, which wraps
    # around in uint8 (255 * 255 % 256 == 1). Non-zero pixels stayed non-zero, so
    # the contour-based measurement is unchanged, but the overflow was accidental.
    ground_truth[path] = find_chest_width_image(np.stack([mask,mask,mask]).transpose(1,2,0))

./data/hand-label/nofinding.json
./data/hand-label/cardiomegaly-certain.json
200 200 400
patient33169_study1_view1_frontal.jpg
patient10398_study3_view1_frontal.jpg
patient14763_study1_view1_frontal.jpg
patient26444_study1_view1_frontal.jpg
patient16357_study1_view1_frontal.jpg
patient06845_study3_view1_frontal.jpg
patient21750_study1_view1_frontal.jpg
patient31070_study2_view1_frontal.jpg
patient02031_study1_view1_frontal.jpg
patient15296_study6_view1_frontal.jpg
patient30984_study2_view1_frontal.jpg
patient26677_study1_view1_frontal.jpg
patient34405_study1_view1_frontal.jpg
patient18734_study1_view1_frontal.jpg
patient13853_study5_view1_frontal.jpg
patient30750_study1_view1_frontal.jpg
patient04039_study1_view1_frontal.jpg
patient10713_study5_view1_frontal.jpg
patient30231_study4_view1_frontal.jpg
patient07399_study1_view1_frontal.jpg
patient14768_study1_view1_frontal.jpg
patient25308_study1_view1_frontal.jpg
patient31890_study12_view1_frontal.jpg
patient22482_study3_view1_frontal.jp

In [62]:
# Sanity check: list the indices that a seeded 100/100 random_split of
# range(200) assigns to the first subset.
split_generator = torch.Generator().manual_seed(42)
a, b = torch.utils.data.random_split(range(200), [100, 100], generator=split_generator)
for i in a:
    print(i)

142
56
198
77
50
55
4
69
174
47
88
114
58
171
155
67
195
162
111
87
7
89
121
76
115
163
61
187
199
18
152
160
123
122
53
159
127
84
8
173
66
176
14
106
71
193
95
16
133
113
170
112
12
149
41
79
183
141
136
68
86
124
70
36
78
196
73
90
1
2
117
140
80
19
144
109
99
13
128
83
134
110
37
185
60
130
23
165
42
35
157
182
49
44
28
145
92
120
48
184


In [61]:
# Print element 2 (the identifier used as the ground_truth key) of the first
# ten validation samples.
for i in range(10):
    validation_sample = val_ds_chexpert[i]
    print(validation_sample[2])

patient33169_study1_view1_frontal.jpg
patient10398_study3_view1_frontal.jpg
patient14763_study1_view1_frontal.jpg
patient26444_study1_view1_frontal.jpg
patient16357_study1_view1_frontal.jpg
patient06845_study3_view1_frontal.jpg
patient21750_study1_view1_frontal.jpg
patient31070_study2_view1_frontal.jpg
patient02031_study1_view1_frontal.jpg
patient15296_study6_view1_frontal.jpg


In [5]:
# Size of the validation split (the split cell requests 100 samples).
len(val_ds_chexpert)

100

In [35]:
import os
import glob
import torch
from tqdm import tqdm
import cv2
import albumentations as A
import numpy as np
from datasets.lungdatasets import MEAN,STD
from models.unet import ResNetUNet

#TODO: Make out dir configurable with argparse
# Side length (pixels) the inference inputs are resized to.
IMAGE_SIZE = 512

# Folder holding the lung-mask checkpoints.
LUNG_MODEL_WEIGHTS = './intermediate/lung_mask_weights'
# Output folder for predicted masks (only used by the disabled cv2.imwrite call).
PATH = "./intermediate/out_lung_mask3/"

# Root folder of the CheXpert training images. The machine-specific absolute
# path was replaced by the relative root the dataset cell already uses; set the
# CHEXPERT_TRAIN_DIR environment variable to point somewhere else.
base_path = os.environ.get('CHEXPERT_TRAIN_DIR', '../CheXpert-v1.0-small/train/')
CHEXPERT_VALIDATION_BASE = './data/chexpert-cardio-nofinding'

# Candidate file names; the inference loop keeps only those present in ground_truth.
paths = os.listdir(CHEXPERT_VALIDATION_BASE)
# Resize-only transform applied to every image before normalisation.
inference_transforms = A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0)])

def load_image(base_path, path):
    """Load one image as a normalised float tensor ready for the model.

    Args:
        base_path: Root folder of the image tree.
        path: Flattened file name; its first two underscores are turned into
            path separators to locate the file below ``base_path``
            (e.g. ``a_b_c_d.jpg`` -> ``a/b/c_d.jpg``).

    Returns:
        A ``torch.FloatTensor`` with a leading batch axis of size 1 holding
        the resized, normalised 3-channel image (channels last, as returned
        by the albumentations transforms).

    Raises:
        OSError: If ``cv2.imread`` cannot read the file (it returns ``None``
            for missing or undecodable files instead of raising).
    """
    path = path.replace('_','/',2)
    # os.path.join tolerates a base_path with or without a trailing slash.
    img_path = os.path.join(base_path, path)
    image = cv2.imread(img_path,0)
    if image is None:
        raise OSError("cv2.imread could not read image: {}".format(img_path))
    # Replicate the grayscale plane into three channels. The former BGR->RGB
    # conversion was dropped: with three identical channels it changes nothing.
    image = cv2.merge([image,image,image])

    augmented = inference_transforms(image=image)
    image = augmented['image']
    image = A.Normalize(mean=MEAN, std=STD)(image=image)["image"]
    return torch.FloatTensor(image).unsqueeze(0)

# Build the segmentation network and move it to the GPU.
model = ResNetUNet()
model = model.cuda()

# Checkpoint evaluated in this run. Other checkpoints tried from the same folder:
#   ./intermediate/lung_mask_weights/afterpretraining0.941366_.pth
#   ./intermediate/lung_mask_weights/nopretraining0.914747_.pth
checkpoint = torch.load('./intermediate/lung_mask_weights/pretraining0.903464_.pth')
state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict)

# Switch the network to evaluation mode before running inference.
model.eval()
# Chest extent measured on the predicted masks, keyed by file name.
predictions = {}

# Inference only: no_grad stops autograd from recording every forward pass.
with torch.no_grad():
    for p in tqdm(paths):
        # Only images that have a hand-labelled ground truth are evaluated.
        if p not in ground_truth:
            continue
        img = load_image(base_path, p)
        # Move the channel axis in front of the two spatial axes for the network.
        data_batch = img.permute(0, 3, 1, 2).cuda()
        outputs = model(data_batch)

        # Binarise at 0.5 (assumes the network output is already a probability - confirm).
        out_cut = (outputs.cpu().numpy() >= 0.5).astype(np.float32)

        # First channel of the single batch item, scaled to a 0/255 uint8 mask.
        mask = (out_cut[0, 0] * 255).astype(np.uint8)
        prediction = find_chest_width_image(np.stack([mask, mask, mask]).transpose(1, 2, 0))
        predictions[p] = prediction
        # To dump the predicted masks to disk, write them under PATH here:
        #cv2.imwrite(PATH + p, (out_cut[0].transpose(1, 2, 0) * 255).astype(np.uint8))


100%|████████████████████████████████████████████████████████████████████████████| 9547/9547 [00:06<00:00, 1487.52it/s]


# Find errors

In [36]:
# Report how many predictions exist, then list those whose measured extent
# covers less than 60% of the image width.
print(len(predictions))
count = 0
for p, (x0, x1, w) in predictions.items():
    width_fraction = (x1-x0)/w
    if width_fraction < 0.6:
        count += 1
        print(count, p)

100


In [37]:
#print(ground_truth['patient00235_study1_view1_frontal.jpg'])
#print(predictions['patient00235_study1_view1_frontal.jpg'])
def compare(a,b):
    """Compare two ``(left, right, width)`` measurements by relative extent.

    Args:
        a: First measurement, as a 3-tuple ``(left, right, width)``.
        b: Second measurement, same layout.

    Returns:
        A tuple ``(abs_diff, frac_a, frac_b)`` where each fraction is
        ``(right - left) / width`` and ``abs_diff`` is ``abs(frac_b - frac_a)``.
    """
    left_a, right_a, width_a = a
    left_b, right_b, width_b = b

    frac_a = (right_a - left_a) / width_a
    frac_b = (right_b - left_b) / width_b
    return abs(frac_b - frac_a), frac_a, frac_b
#compare(ground_truth['patient00235_study1_view1_frontal.jpg'],predictions['patient00235_study1_view1_frontal.jpg'])

In [38]:
# Collect the per-image error between ground truth and prediction. Images
# whose error exceeds 0.3 are printed and left out of the statistics.
kept_errors = []
for k, truth in ground_truth.items():
    err,a,b = compare(truth, predictions[k])
    if err > 0.3:
        print(k,err,a,b)
        continue
    kept_errors.append(err)
errors = np.array(kept_errors)

patient29379_study1_view1_frontal.jpg 0.44921875 0.28125 0.73046875


In [40]:
# Number of images kept for the statistics (those with error <= 0.3).
len(errors)

99

In [39]:
# Summary statistics of the retained errors, printed one per line.
squared_error_total = (errors ** 2).sum()
summary_rows = [
    ('RMSE:', np.sqrt(squared_error_total/len(errors))),
    ('Min:', errors.min()),
    ('Max:', errors.max()),
    ('Mean:', errors.mean()),
    ('Median:', np.median(errors)),
    ('STD:', errors.std()),
]
for label, value in summary_rows:
    print(label, value)

RMSE: 0.031283594434731866
Min: 0.0003906249999999778
Max: 0.18476562500000004
Mean: 0.019511521464646464
Median: 0.013671875
STD: 0.024453298568729475
