In [1]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image


from utils.Unet2 import Unet2
from utils.vgg16 import VGG16
from utils.visualizer_slider import visualizer_slider
from utils.segmentation_utils import *
from utils.utils import *


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading segmentation model (UNET)
model_seg = Unet2(input_nc = 3, output_nc= 1, ngf = 32, bilinear=True)
state_dict = torch.load('params/segmentation_net.pth', map_location=str(device)) #
model_seg.load_state_dict(state_dict)

# Loading classification model (VGG)
model_class = VGG16()
state_dict = torch.load('params/classifier_net_0.8854F.pth', map_location=str(device)) #
model_class.load_state_dict(state_dict)


<All keys matched successfully>

In [3]:
ds_path = 'ds/Evaluation'
subdir = [join(ds_path, f) for f in os.listdir(ds_path) if os.path.isdir(join(ds_path, f))]

file_list = []
for f in subdir:
    file_list += list_items(f, ".bmp")

In [76]:
total_tp, total_fp, total_fn = 0, 0, 0

for f_name in file_list:
    im_bmp = np.array(Image.open(f'{f_name[:-4]}_n.png'))
    im_mask = np.load(f_name[:-3]+'npy')

    im_bmp = im_bmp / 255.0
    slider = visualizer_slider(im_bmp, crop_size=512)
    slider.make_tiles()
    im_lst = torch.FloatTensor(np.array(slider.arr)).permute(0, 3, 1, 2)
        
    ################Segmentation model#####################
    for i in range(im_lst.shape[0]):
        data = im_lst[i:i + 1]
        generated = model_seg(data)
        generated = generated[:, 0:1, :, :]
        
        generated[generated >= .5] = 1
        generated[generated < .5] = 0
        
        if device == 'cuda':
            generated = generated.cpu()

        slider.arr[i] = generated[0].permute(1, 2, 0).data.numpy()

    slider.recover_mask()
    ################Segmentation model#####################
    
    
    #Extract segmentation patches for classification
    seg_imgs, seg_im_labeled, prop_seg, prop_gt = extract_segmentation(im_bmp, slider.recvored_mak[:,:,0], im_mask, crop_size=120, j=10, n =7)    
    
    ################Classification model#####################
    tp, fp, fn = 0, 0, 0
    for i, (index, outs) in enumerate(prop_seg.iterrows()):
        class_input = torch.FloatTensor(seg_imgs[i:i+1]).permute(0, 3, 1, 2)
        class_out = model_class(class_input).round().item()

        if class_out == 0:
            seg_im_labeled[seg_im_labeled==index+1] = 0
        elif class_out == 1:
            counter+=1
            flag = 1
            for index_gr, p_gt in prop_gt.iterrows():
                x1, y1, x2, y2, dist_x, dist_y, pred = validattion_func(prop_seg.iloc[i], p_gt)

                if pred == 1:
                    tp += 1
                    flag = 0
                    break
            if flag == 1:
                fp += 1

    fn = len(prop_gt)-tp
    print(f'image name: {f_name[-10:]},  TP: {tp}, FP: {fp}, FN: {fn}')
    ################Classification model#####################
    
    ################Save output images#####################
    seg_im_labeled[seg_im_labeled>0] = 1
    Image.fromarray((seg_im_labeled * 255.0).astype(np.uint8)).save(f'outs/{f_name[-10:-3]}png', format="PNG", quality=100, subsampling=0)
    ################Save output images#####################

    total_tp += tp
    total_fp += fp
    total_fn += fn

image name: A01_04.bmp,  TP: 11, FP: 0, FN: 4
image name: A01_06.bmp,  TP: 4, FP: 0, FN: 0
image name: A01_09.bmp,  TP: 8, FP: 1, FN: 4
image name: A03_01.bmp,  TP: 9, FP: 1, FN: 1
image name: A03_04.bmp,  TP: 11, FP: 0, FN: 3
image name: A03_00.bmp,  TP: 17, FP: 0, FN: 2
image name: A00_08.bmp,  TP: 2, FP: 1, FN: 1
image name: A00_00.bmp,  TP: 6, FP: 0, FN: 0
image name: A02_00.bmp,  TP: 4, FP: 2, FN: 0
image name: A02_01.bmp,  TP: 2, FP: 2, FN: 0
image name: A02_03.bmp,  TP: 1, FP: 2, FN: 0
image name: A02_07.bmp,  TP: 2, FP: 0, FN: 2
image name: A04_03.bmp,  TP: 7, FP: 2, FN: 7
image name: A04_09.bmp,  TP: 4, FP: 0, FN: 1
image name: A04_07.bmp,  TP: 5, FP: 0, FN: 0
image name: H03_00.bmp,  TP: 17, FP: 1, FN: 2
image name: H03_01.bmp,  TP: 9, FP: 1, FN: 1
image name: H03_04.bmp,  TP: 12, FP: 2, FN: 1
image name: H04_03.bmp,  TP: 9, FP: 2, FN: 0
image name: H04_07.bmp,  TP: 5, FP: 3, FN: 0
image name: H04_09.bmp,  TP: 4, FP: 0, FN: 0
image name: H02_01.bmp,  TP: 2, FP: 1, FN: 0
image

In [77]:
R = total_tp / (total_tp + total_fn)
P = total_tp / (total_tp + total_fp)
F = 2 * ((R * P) / (R + P))
print(f'TP: {total_tp}  |  FP: {total_fp}   |  FN: {total_fn}   |   F: {F:.4f}   |   R: {R}   |   P: {P}')

TP: 188  |  FP: 29   |  FN: 36   |   F: 0.8526   |   R: 0.8392857142857143   |   P: 0.8663594470046083
