In [1]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import data
from skimage import filters
from skimage import exposure
import cv2
from math import pi, ceil, floor, cos, sin
import os
import time
from ipynb.fs.full.unet_validation import *

plt.rcParams["figure.figsize"] = (10, 10)
# Hide axes globally: remove every spine and make the axes background, axes edges
# and tick marks/labels fully transparent (RGBA alpha = 0) so images render frameless.
plt.rc('axes.spines',top=False,bottom=False,left=False,right=False);
plt.rc('axes',facecolor=(1,1,1,0),edgecolor=(1,1,1,0));
plt.rc(('xtick','ytick'),color=(1,1,1,0));

# Output folders, relative to the notebook directory. save_results rewrites the
# leading "../" to "./" when IMG_OUT_PATH does not exist.
RESULT_PATH  = "../segmented/myresults/"
IMG_OUT_PATH = "../segmented/"

# Drawing parameters: cell marker radius in pixels, and colour tuples for cell
# markers and internal/external parasite boxes (presumably RGB, since images are
# converted from BGR to RGB when loaded in colour — confirm for other inputs).
CELL_POINT_SIZE = 8
CELL_COLOR      = (255,60,0)
INTERNAL_COLOR  = (215,215,20)
EXTERNAL_COLOR  = (255, 255, 255)

# Accepted values for mark_cells(validation_type=...).
UNET_VALIDATION = 'unet'
HSV_VALIDATION  = 'hsv'
NO_VALIDATION   = 'standalone'


# General functions

In [2]:
def save_results(image_name, method_name, cells,internalInfection,externalInfection, passTime, img, method_ext, path=""):
    """Save a processed image as PNG and record one result row for it.

    The PNG is written to ``IMG_OUT_PATH + path + "/"`` and the result row to
    ``RESULT_PATH + path + "/<image name>_result.txt"``. Image names are
    lower-cased; ".jpg" is replaced by ``method_ext`` in the PNG name and
    stripped in the text file name. The text file is a tab-separated table
    with a header line. If a row whose first column equals ``method_name``
    already exists it is replaced, otherwise the new row is appended.

    Parameters
    ----------
    image_name : str
        Path of the source image; only its last "/"-separated part is used.
    method_name : str
        Value of the METHOD column; also the key used to find an existing row.
    cells, internalInfection, externalInfection, passTime
        Values written to the CELLS, INTERNAL, EXTERNAL and TIME columns.
    img : ndarray
        Image written with ``plt.imsave``.
    method_ext : str
        Suffix that replaces ".jpg" in the PNG file name.
    path : str, None or False
        Sub-folder used for both outputs; None/False become "none".

    Side effects: when IMG_OUT_PATH does not exist, the module-level
    IMG_OUT_PATH and RESULT_PATH are rewritten from "../" to "./".
    Missing output folders (including parents) are created.
    """
    global IMG_OUT_PATH, RESULT_PATH

    # Fall back to paths relative to the current folder when the notebook is
    # not run from its own directory.
    if not os.path.exists(IMG_OUT_PATH):
        RESULT_PATH = RESULT_PATH.replace("../", "./")
        IMG_OUT_PATH = IMG_OUT_PATH.replace("../", "./")

    # `in` compares with ==, so 0 is treated like False, as before.
    if path in (None, False):
        path = "none"

    base_name = image_name.split("/")[-1].lower()

    # save png
    out_path = IMG_OUT_PATH + path + "/"
    os.makedirs(out_path, exist_ok=True)
    out = out_path + base_name.replace(".jpg", method_ext) + ".png"
    plt.imsave(out, img, format='png')

    # save txt
    result_dir = RESULT_PATH + path
    os.makedirs(result_dir, exist_ok=True)
    f_name = result_dir + "/" + base_name.replace(".jpg", "") + "_result.txt"

    row = f"{method_name:20}\t{cells:7}\t{internalInfection:8}\t{externalInfection:8}\t{passTime:6}\t{out}"

    if not os.path.exists(f_name):
        with open(f_name, "w") as f:
            f.write(f"{'METHOD':20}\t{'  CELLS':7}\t{'INTERNAL':8}\t{'EXTERNAL':8}\t{'TIME':6}\tFILE_NAME\n")
            f.write(row + "\n")
        return

    with open(f_name, "r") as f:
        lines = f.read().split("\n")

    for idx, line in enumerate(lines):
        # Compare the whole first column: a prefix test would let e.g. "otsu"
        # overwrite the row belonging to "otsu_blur".
        if line.split("\t", 1)[0].rstrip() == method_name:
            lines[idx] = row
            with open(f_name, "w") as f:
                f.write("\n".join(lines))
            return

    # No existing row for this method: append one.
    with open(f_name, "a") as f:
        f.write(row + "\n")

In [3]:

def get_image(path, mode=0, size=(640,480), show=False):
    """Load an image from disk, optionally convert it to RGB, and resize it.

    Parameters
    ----------
    path : str
        File passed to ``cv2.imread``.
    mode : int
        ``cv2.imread`` flag: 0 loads grayscale, 1 loads colour, which is then
        converted from OpenCV's BGR order to RGB.
    size : tuple
        Target ``(width, height)`` passed to ``cv2.resize``.
    show : bool
        When true, the resized image is displayed with matplotlib.

    Returns
    -------
    ndarray
        The resized image.

    Raises
    ------
    FileNotFoundError
        When OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, mode)
    # cv2.imread returns None instead of raising. Check before any conversion,
    # otherwise cvtColor/resize fail with an unrelated OpenCV error.
    if img is None:
        print("Não foi possível abrir a imagem:", path)
        raise FileNotFoundError(f"Could not open image: {path}")

    if mode == 1:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    res = cv2.resize(img, size)

    if show:
        plt.figure(figsize=(10,10))
        plt.imshow(res)
        plt.show()
    return res

    

def apply_blur(img, k=9):
    """Smooth ``img`` with a square ``k`` x ``k`` Gaussian kernel.

    Sigma is passed as 0, so OpenCV derives it from the kernel size; OpenCV
    requires ``k`` to be positive and odd.
    """
    kernel_size = (k, k)
    blurred = cv2.GaussianBlur(img, kernel_size, 0)
    return blurred

def otsu(img,block=81,offset=0.3):
    """Binarize ``img`` with an inverted adaptive *mean* threshold.

    Despite the name, this is not Otsu's method. Each pixel is compared with
    the mean of its ``block`` x ``block`` neighbourhood minus ``offset``.
    Pixels at or below that local threshold become 255, the others 0.
    """
    method = cv2.ADAPTIVE_THRESH_MEAN_C
    inverted = cv2.THRESH_BINARY_INV
    return cv2.adaptiveThreshold(img, 255, method, inverted, block, offset)

def adaptative_thresh(img,block=121,offset=0):
    """Inverted adaptive mean threshold with a large default neighbourhood.

    Pixels whose value is at or below the mean of their ``block`` x ``block``
    neighbourhood minus ``offset`` become 255; all others become 0.
    """
    method = cv2.ADAPTIVE_THRESH_MEAN_C
    inverted = cv2.THRESH_BINARY_INV
    return cv2.adaptiveThreshold(img, 255, method, inverted, block, offset)

def fill_holes(img):
    """Return a uint8 mask in which every contour found in ``img`` is drawn filled with 255.

    RETR_TREE returns outer boundaries as well as hole boundaries, and each one
    is filled on its own, so interior holes end up solid.
    """
    contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    height, width = img.shape[0], img.shape[1]
    filled = np.zeros((height, width), dtype=np.uint8)
    # maxLevel=0: each call draws only the contour selected by `idx`.
    for idx, _ in enumerate(contours):
        cv2.drawContours(filled, contours, idx, 255, -1, cv2.LINE_8, hierarchy, 0)
    return filled

def connected_components(thresh):
    """Label 4-connected foreground regions and build a colour visualisation.

    Parameters
    ----------
    thresh : ndarray
        8-bit single-channel mask passed to ``cv2.connectedComponentsWithStats``.

    Returns
    -------
    labeled_img_rgb : ndarray
        RGB image in which each label gets a hue proportional to its index and
        the background (label 0) is black.
    numLabels, labels, stats, centroids
        The raw outputs of ``cv2.connectedComponentsWithStats``; label 0 is
        the background.
    """
    numLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 4, cv2.CV_32S)

    # Avoid dividing by zero when there is no foreground at all.
    divider = max(int(labels.max()), 1)

    # Map component labels to hue values; 0-179 is the hue range in OpenCV.
    label_hue = np.uint8(179 * labels / divider)
    blank_ch = 255 * np.ones_like(label_hue)
    labeled_img = cv2.merge([label_hue, blank_ch, blank_ch])

    # Converting cvt to BGR
    labeled_img = cv2.cvtColor(labeled_img, cv2.COLOR_HSV2BGR)

    # Black out only the real background. Testing label_hue == 0 instead would
    # also erase the first labels whenever there are more than 179 of them,
    # because 179 * label / divider truncates to 0 for those.
    labeled_img[labels == 0] = 0

    labeled_img_rgb = cv2.cvtColor(labeled_img, cv2.COLOR_BGR2RGB)
    return labeled_img_rgb, numLabels, labels, stats, centroids


def _ellipse_kernel(size):
    """Elliptical structuring element of side 2*size+1, anchored at its centre."""
    side = 2 * size + 1
    return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side), (size, size))


def erode(src,size):
    """Erode ``src`` with an elliptical kernel of radius ``size`` pixels."""
    kernel = _ellipse_kernel(size)
    return cv2.erode(src, kernel)


def dilate(src,size):
    """Dilate ``src`` with an elliptical kernel of radius ``size`` pixels."""
    kernel = _ellipse_kernel(size)
    return cv2.dilate(src, kernel)

def get_principal_components(img):
    """Keep only the largest 4-connected foreground component of ``img``.

    Returns a uint8 mask of ``img``'s shape with 255 on the component with
    the largest area (the first such label on ties) and 0 elsewhere. When
    ``img`` has no foreground, an all-zero mask is returned instead of raising.
    """
    nb_components, output, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=4)

    largest = np.zeros(output.shape, dtype='uint8')
    # Label 0 is the background; without another label there is nothing to keep.
    if nb_components < 2:
        return largest

    # argmax returns the first maximum, matching max() over labels 1..n-1.
    max_label = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
    largest[output == max_label] = 255
    return largest

def get_principal_component_RGB(img,mask):
    """Copy the pixels of ``img`` that lie on the largest component of ``mask``.

    ``mask`` is reduced to its largest 4-connected component with
    ``get_principal_components``; every other pixel of the result is 0. The
    result is uint8 with ``img``'s shape. ``mask`` is assumed to have the same
    height and width as ``img``.
    """
    keep = get_principal_components(mask) == 255
    new = np.zeros(img.shape, dtype='uint8')
    # Boolean-mask assignment replaces the per-pixel Python double loop.
    new[keep] = img[keep]
    return new

def color_average(img):
    """Mean and median colour of the non-black pixels of a 3-channel image.

    A pixel is ignored only when all three of its channels are 0.

    Returns
    -------
    (avg, median)
        ``avg`` is the per-channel mean truncated to uint8; ``median`` is the
        per-channel median as floats.
    """
    pixels = img.reshape(img.shape[0] * img.shape[1], 3)
    # Drop whole black pixels. An element-wise test (pixels != [0, 0, 0])
    # would drop single zero channels, shifting the remaining values across
    # channels, or fail in the reshape when the count is not a multiple of 3.
    pixels = pixels[np.any(pixels != 0, axis=1)]
    avg = np.uint8(pixels.mean(axis=0))
    median = np.median(pixels, axis=0)
    return avg, median

def hsv_average(img):
    """Mean and median hue over the pixels of a 3-channel image with non-zero channel 0.

    Channel 0 is treated as the hue (presumably an OpenCV HSV image — confirm
    at call sites). Every pixel is considered; the previous loop stepped by 3
    and only sampled every third pixel.

    Returns
    -------
    (avg, median)
        Both truncated to ``int``.

    Raises
    ------
    ValueError
        When no pixel has a non-zero hue.
    """
    pixels = img.reshape(img.shape[0] * img.shape[1], 3)
    hues = pixels[pixels[:, 0] != 0, 0]
    if hues.size == 0:
        raise ValueError("hsv_average: image has no pixel with non-zero hue")
    avg = int(np.mean(hues))
    median = int(np.median(hues))
    return avg, median

def color_distance(rgb1, rgb2):
    """Weighted distance between two colours using red-mean style weights.

    Components are converted to ``float`` first. With NumPy ``uint8`` inputs
    (such as the ``avg`` returned by ``color_average``), the sums and
    differences below would otherwise wrap around modulo 256.

    NOTE(review): the weights (2 + rm), 4 and (3 - rm) resemble the "redmean"
    approximation for components scaled to [0, 1]. Here, however, they are
    applied inside the square and to unscaled values. The formula is kept as it
    was; confirm the intended scale before relying on absolute distances.
    """
    r1, g1, b1 = float(rgb1[0]), float(rgb1[1]), float(rgb1[2])
    r2, g2, b2 = float(rgb2[0]), float(rgb2[1]), float(rgb2[2])
    rm = 0.5 * (r1 + r2)
    rd = ((2 + rm) * (r1 - r2)) ** 2
    gd = (4 * (g1 - g2)) ** 2
    bd = ((3 - rm) * (b1 - b2)) ** 2
    return (rd + gd + bd) ** 0.5

In [4]:
# Area bounds (in pixels) for a single cell, and the minimum mean validation
# mask value a candidate's bounding box must exceed.
MIN_CELL_AREA = 500
MAX_CELL_AREA = 2700  # previously 2300
MIN_VALUE_TO_VALIDATE_CELL = 50

def thresh_is_valid(thresh_validation,x,y,w,h):
    """Decide whether the box (x, y, w, h) is supported by the validation mask.

    The box is accepted when the sum of ``thresh_validation`` inside it,
    divided by the window's height * width, is strictly greater than
    MIN_VALUE_TO_VALIDATE_CELL. An empty or all-zero window is rejected.
    """
    window = thresh_validation[y:y + h, x:x + w]
    total = np.sum(window)
    if total == 0:
        return False
    mean_value = total / (window.shape[0] * window.shape[1])
    return mean_value > MIN_VALUE_TO_VALIDATE_CELL


# def equalize_images(img,sTo=80, vTo=150):
#     img_hsv = cv2.cvtColor(img,cv2.COLOR_RGB2HSV)

#     H, S, V = cv2.split(img_hsv)

#     dec = (sTo - np.median(S)).astype('int16')
#     S = S + dec
#     S = S.astype('uint8')

#     dec = (vTo - np.median(V)).astype('int16')
#     V = V + dec
#     V = V.astype('uint8')

#     hsv = cv2.merge([H, S, V])
#     rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

#     return rgb

# Two U-Net predictors are built once when this cell runs (UNetPredict comes
# from the unet_validation notebook). Only `unet` is used in this section.
unet = UNetPredict()
unetNewdataset = UNetPredict(newDataset=True)

def unet_validation_mask(img):
    """Return the prediction of the default U-Net for ``img``, used as the cell validation mask."""
    return unet.predict(img)



def _shift_channel_median(channel, target):
    """Add (target - median(channel)), truncated to int, and clip the result to the uint8 range."""
    offset = int(target - np.median(channel))
    shifted = channel.astype(np.int16) + offset
    return np.clip(shifted, 0, 255).astype(np.uint8)


def equalize_images(img,sTo=80, vTo=150):
    """Shift the saturation and value of an RGB image toward target medians.

    The image is converted to HSV, and the S and V channels are offset so that
    their medians move to ``sTo`` and ``vTo`` (offsets truncated to integers).
    Shifted values are clipped to [0, 255]. A plain uint8 cast would wrap
    overflowing pixels around (e.g. 260 -> 4) and turn bright pixels dark.
    Hue is left untouched.

    Returns
    -------
    ndarray
        RGB uint8 image.
    """
    img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

    H, S, V = cv2.split(img_hsv)

    S = _shift_channel_median(S, sTo)
    V = _shift_channel_median(V, vTo)

    hsv = cv2.merge([H, S, V])
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

def hsv_validation_mask(img):
    """Build a cell-validation mask by HSV colour thresholding.

    The image is first normalized with ``equalize_images`` (S median -> 80,
    V median -> 150). Pixels inside the HSV box (10..20, 70..100, 0..150) are
    kept. The mask is then median-filtered (5x5), dilated with a radius-1
    elliptical kernel, and hole-filled.
    """
    normalized = equalize_images(img, sTo=80, vTo=150)
    hsv = cv2.cvtColor(normalized, cv2.COLOR_RGB2HSV)
    in_range = cv2.inRange(hsv, (10,70,0), (20,100,150))

    smoothed = cv2.medianBlur(in_range, 5)
    grown = dilate(smoothed, 1)
    return fill_holes(grown)


def cut_circle(img,pos,radius):
    """Keep only the pixels of ``img`` that lie inside a filled circle.

    ``pos`` is the circle centre as an integer ``(x, y)`` point (OpenCV order)
    and ``radius`` its radius in pixels. Pixels outside the circle are 0 in
    the returned uint8 image, which has ``img``'s shape.
    """
    mask = np.zeros((img.shape[0], img.shape[1]))
    cv2.circle(mask, pos, radius, 1, -1)
    inside = mask > 0
    new = np.zeros(img.shape, dtype='uint8')
    # Boolean-mask assignment replaces the per-pixel Python double loop.
    new[inside] = img[inside]
    return new



def _select_validation_mask(img, validation_type):
    """Return ``(validation_step, thresh_validation)`` for ``mark_cells``; the mask is None when no validation runs."""
    if validation_type is None or validation_type is False:
        return False, None
    if validation_type == UNET_VALIDATION:
        return True, unet_validation_mask(img)
    if validation_type == HSV_VALIDATION:
        return True, hsv_validation_mask(img)
    if validation_type != NO_VALIDATION:
        print("Invalid validation_type")
    return False, None


def _draw_single_cell(img, pos, cell_id):
    """Draw one black-outlined cell marker at ``pos`` with ``cell_id`` printed below it."""
    cv2.circle(img, pos, CELL_POINT_SIZE, (0,0,0), -1)
    cv2.circle(img, pos, CELL_POINT_SIZE-1, CELL_COLOR, -1)
    cv2.putText(img, str(cell_id), (pos[0], pos[1]+CELL_POINT_SIZE+10), cv2.FONT_HERSHEY_SIMPLEX,
                .4, (0,0,0), 1, cv2.LINE_AA)


def _draw_group(img, x, y, w, h, est, first_id, groupColor):
    """Draw ``est`` markers in rows inside a group's box, their ids (first_id+1..first_id+est), and the box itself."""
    # wider groups get more markers per row and a larger left margin
    if w > 220:
        cols, margin = 8, 40
    elif w > 120:
        cols, margin = 6, 20
    else:
        cols, margin = 3, 10

    xCell = x + margin
    yCell = y + 40
    for k in range(1, est + 1):
        cv2.circle(img, (xCell, yCell), CELL_POINT_SIZE-1, (groupColor,0,groupColor), -1)
        xCell += CELL_POINT_SIZE + 10
        if k % cols == 0:
            xCell = x + margin
            yCell += CELL_POINT_SIZE + 10

    txt = "(" + ",".join(str(q) for q in range(first_id + 1, first_id + est + 1)) + ")"
    cv2.putText(img, txt, (x+margin-10, y+20), cv2.FONT_HERSHEY_SIMPLEX,
                .4, (0,0,0), 1, cv2.LINE_AA)
    cv2.rectangle(img, (x,y), (x+w,y+h), (groupColor+2,0,groupColor+2), 2)


def mark_cells(img, mask, components = None, estimate = True, validation_type='unet', max_cell_area=MAX_CELL_AREA):
    """Draw cell markers on ``img`` and count the cells found in ``mask``.

    Parameters
    ----------
    img : ndarray
        Colour image to draw on; it is modified in place and also returned.
    mask : ndarray
        Binary cell mask, labelled with ``connected_components`` when
        ``components`` is None.
    components : tuple, optional
        Precomputed ``connected_components`` output. When given, every entry
        from index 0 is treated as a candidate. Otherwise label 0 is skipped
        as the background.
    estimate : bool
        When True, components larger than ``max_cell_area`` are groups, drawn
        with ``ceil(area / max_cell_area)`` markers and counted as that many
        cells. When False, every component counts as one cell whatever its area.
    validation_type : str, None or False
        UNET_VALIDATION or HSV_VALIDATION build a validation mask. Candidates
        whose bounding box fails ``thresh_is_valid`` are neither drawn nor
        counted. NO_VALIDATION, None or False disable validation; any other
        value prints "Invalid validation_type" and disables it.
    max_cell_area : int
        Upper area bound for a single cell.

    Returns
    -------
    (img, cellsCount)
    """
    validation_step, thresh_validation = _select_validation_mask(img, validation_type)

    groupColor = 45

    if components is None:
        _, numLabels, labels, stats, centroids = connected_components(mask)
        # the first component is the background and must always be skipped
        start = 1
    else:
        _, numLabels, labels, stats, centroids = components
        start = 0

    cellsCount = 0

    for i in range(start, numLabels):
        x = stats[i, cv2.CC_STAT_LEFT]
        y = stats[i, cv2.CC_STAT_TOP]
        w = stats[i, cv2.CC_STAT_WIDTH]
        h = stats[i, cv2.CC_STAT_HEIGHT]
        area = stats[i, cv2.CC_STAT_AREA]
        (cX, cY) = centroids[i]

        if MIN_CELL_AREA < area < max_cell_area or estimate == False:
            # Validate before counting, so rejected candidates neither inflate
            # the count nor skip ids (the group branch below already did this).
            if validation_step and not thresh_is_valid(thresh_validation, x, y, w, h):
                continue
            cellsCount += 1
            _draw_single_cell(img, (int(cX), int(cY)), cellsCount)

        if max_cell_area < area and estimate == True:
            if validation_step and not thresh_is_valid(thresh_validation, x, y, w, h):
                continue
            groupColor += 5
            # estimate how many cells the group contains
            est = ceil(area / max_cell_area)
            _draw_group(img, x, y, w, h, est, cellsCount, groupColor)
            cellsCount += est

    return img, cellsCount

def break_big_groups(mask,_erode=15,_dilate=11, max_cell_area=None):
    """Reshape oversized blobs in a binary mask by eroding and re-dilating them.

    Each 4-connected component whose area exceeds ``max_cell_area`` (default:
    the module-level MAX_CELL_AREA) is handled in turn. The largest component
    inside its bounding box, padded by 1 px, is removed from ``mask``. It is
    then eroded with radius ``_erode``, dilated with radius ``_dilate``
    (elliptical kernels), and OR-ed back in. Presumably the stronger erosion
    is meant to break thin bridges between touching cells.

    ``mask`` is modified in place and also returned.
    """
    if max_cell_area is None:
        max_cell_area = MAX_CELL_AREA

    _, numLabels, _, stats, _ = connected_components(mask)

    for i in range(1, numLabels):
        area = stats[i, cv2.CC_STAT_AREA]
        if area <= max_cell_area:
            continue

        x = stats[i, cv2.CC_STAT_LEFT]
        y = stats[i, cv2.CC_STAT_TOP]
        w = stats[i, cv2.CC_STAT_WIDTH]
        h = stats[i, cv2.CC_STAT_HEIGHT]

        # one-pixel margin around the box; slicing clips at the image border
        top = y - 1 if y > 0 else y
        left = x - 1 if x > 0 else x
        bottom = y + h + 1
        right = x + w + 1

        # `region` is a view, so writing to it updates `mask` directly.
        region = mask[top:bottom, left:right]
        group = get_principal_components(region.copy())

        # remove the group, then put back its eroded and dilated version
        region[group == 255] = 0
        reshaped = dilate(erode(group, _erode), _dilate)
        # OR instead of +=: overlapping pixels (255 + 255, or 1 + 255) would
        # wrap around in uint8.
        np.bitwise_or(region, reshaped, out=region)

    return mask


# Color functions

In [5]:
def random_color():
    """Draw a random colour that is neither too dark nor too light.

    Channels are drawn with ``np.random.randint(1, 256, size=3)``. A draw is
    rejected and repeated when all three channels are < 100 or all are > 230.
    Returns a tuple of three Python ints.
    """
    while True:
        r, g, b = np.random.randint(1, 256, size=3)
        too_dark = r < 100 and g < 100 and b < 100
        too_light = r > 230 and g > 230 and b > 230
        if not (too_dark or too_light):
            return (int(r), int(g), int(b))
    

def list_colors(img):
    """Return the distinct non-black colours of an image as an (N, 3) array.

    Grayscale (2-D) input is first expanded to 3 channels. Rows come out in
    ``np.unique`` (lexicographic) order. Only pure black (0, 0, 0) is removed.
    Dropping the first unique row unconditionally would discard a real colour
    whenever the image contains no black pixel.
    """
    shape = img.shape
    if len(shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    pixels = img.reshape(shape[0] * shape[1], 3)
    colors = np.unique(pixels, axis=0)
    return colors[np.any(colors != 0, axis=1)]


def shuffle_colors(img):
    """Replace every distinct non-black colour of ``img`` with a random colour.

    2-D input is first merged into 3 identical channels. Pixel selections are
    computed on the unmodified input. Otherwise, a random colour that happens
    to equal a colour not yet processed would merge two regions into one colour.
    Returns a new array; ``img`` is not modified.
    """
    if len(img.shape) == 2:
        img = cv2.merge([img, img, img])

    shuffled = img.copy()
    for color in list_colors(img):
        new_color = random_color()
        selected = np.all(img == color, axis=-1)
        shuffled[selected] = new_color
    return shuffled


def colorize(img):
    """Label the connected components of a binary mask and give each one a random colour."""
    labeled_rgb = connected_components(img)[0]
    return shuffle_colors(labeled_rgb)

def equalize_color_image(img):
    """Histogram-equalize an 8-bit image through a single global lookup table.

    The CDF of all values in ``img`` (every channel pooled together) is
    stretched linearly onto 0..255, ignoring empty leading bins, and used as a
    lookup table. ``img`` must hold integers in 0..255 because it is used as
    an index.

    Returns
    -------
    ndarray
        uint8 array with ``img``'s shape.
    """
    hist, _ = np.histogram(img.flatten(), 256, [0, 256])
    cdf = hist.cumsum()
    # Mask zero CDF entries so min() is the first populated bin rather than 0.
    cdf_m = np.ma.masked_equal(cdf, 0)
    cdf_m = (cdf_m - cdf_m.min()) * 255 / (cdf_m.max() - cdf_m.min())
    lut = np.ma.filled(cdf_m, 0).astype('uint8')
    return lut[img]

# Parasite functions

In [6]:

def hsv_par_mask(img):
    """Segment parasite-coloured pixels of an RGB image by HSV range.

    Pixels with H in [95, 180], S in [40, 255] and any V are selected.

    Returns
    -------
    (mask, mask_dilated)
        The raw ``cv2.inRange`` mask, and a copy eroded with a 3x3 square
        kernel and then dilated with a 15x15 square kernel.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    mask = cv2.inRange(hsv, (95,40,0), (180,255,255))
    # A second range (H 0-15, S 70-75, V 85-105) OR-ed into this mask was
    # tried and is currently disabled.

    # Small erosion first (presumably to drop isolated specks), then a
    # generous dilation.
    eroded = cv2.erode(mask, np.ones((3, 3), np.uint8))
    mask_dilated = cv2.dilate(eroded, np.ones((15, 15), np.uint8))
    return mask, mask_dilated


def gs_parasite_mask(gsImg):
    """Mark dark pixels as parasite candidates.

    3-channel input is converted from RGB to grayscale first. Returns a mask
    that is 1 (not 255) where the gray level is <= 75, and 0 elsewhere.
    """
    gray = cv2.cvtColor(gsImg, cv2.COLOR_RGB2GRAY) if len(gsImg.shape) == 3 else gsImg
    _, thresh = cv2.threshold(gray, 75, 1, cv2.THRESH_BINARY_INV)
    return thresh

def remove_parasites(img,mask):
    """Blank out parasite pixels in ``img`` and smooth the result.

    Pixels where ``mask`` equals its maximum are set to 0 in ``img`` itself,
    so the caller's array is modified. The result is then eroded and dilated
    with a radius-2 elliptical kernel, and the new array is returned. When the
    mask maximum is 0, ``img`` is returned untouched (same object).
    """
    peak = np.max(mask)
    if peak != 0:
        img[mask == peak] = 0
        return dilate(erode(img, 2), 2)
    return img

# Parasite components smaller than this many pixels are ignored.
MIN_PARASITE_AREA = 20


def _overlap_pixels(parasite_mask, cell_mask, par_box, cell_box):
    """Count pixels that are non-zero in both masks inside the intersection of two (x, y, w, h) boxes."""
    px, py, pw, ph = par_box
    cx, cy, cw, ch = cell_box
    top, bottom = max(py, cy), min(py + ph, cy + ch)
    left, right = max(px, cx), min(px + pw, cx + cw)
    if top >= bottom or left >= right:
        return 0
    par_win = parasite_mask[top:bottom, left:right] > 0
    cell_win = cell_mask[top:bottom, left:right] > 0
    return int(np.count_nonzero(par_win & cell_win))


def parasites_mark(img, parasite_mask, cell_mask, cell_components = None, contact_percent=0.15):
    """Classify parasites as internal or external to cells and draw them on ``img``.

    A parasite component (area >= MIN_PARASITE_AREA) is internal when, for
    some cell with MIN_CELL_AREA < area < 10 * MAX_CELL_AREA, the pixels that
    are non-zero in both the parasite's box of ``parasite_mask`` and the
    cell's box of ``cell_mask`` make up at least ``contact_percent`` of the
    parasite's area. Otherwise it is external.

    Parameters
    ----------
    img : ndarray
        Image to draw on; modified in place and returned.
    parasite_mask : ndarray
        2-D parasite mask (any non-zero value is foreground).
    cell_mask : ndarray
        2-D cell mask with ``parasite_mask``'s shape. Ignored when
        ``cell_components`` is given.
    cell_components : tuple, optional
        Precomputed ``connected_components`` output for the cells. Its
        coloured image (converted to gray) is used as the cell mask and its
        stats as the cell boxes.
    contact_percent : float
        Minimum overlap fraction for an internal classification.

    Returns
    -------
    (img, internal_count, external_count)
    """
    _, numLabels, _, stats, centroids = connected_components(parasite_mask)

    if cell_components is None:
        _, _, _, cell_stats, _ = connected_components(cell_mask)
    else:
        cell_mask = cv2.cvtColor(cell_components[0], cv2.COLOR_RGB2GRAY)
        cell_stats = cell_components[3]

    internal_count = 0
    external_count = 0

    for i in range(1, numLabels):
        area = stats[i, cv2.CC_STAT_AREA]
        # tiny specks are not counted
        if MIN_PARASITE_AREA > area:
            continue

        x = stats[i, cv2.CC_STAT_LEFT]
        y = stats[i, cv2.CC_STAT_TOP]
        w = stats[i, cv2.CC_STAT_WIDTH]
        h = stats[i, cv2.CC_STAT_HEIGHT]
        (cX, cY) = centroids[i]
        par_box = (x, y, w, h)

        is_internal = False
        for cell_stat in cell_stats:
            cellArea = cell_stat[cv2.CC_STAT_AREA]
            if not (MIN_CELL_AREA < cellArea < MAX_CELL_AREA * 10):
                continue
            cellX = cell_stat[cv2.CC_STAT_LEFT]
            cellY = cell_stat[cv2.CC_STAT_TOP]
            cellW = cell_stat[cv2.CC_STAT_WIDTH]
            cellH = cell_stat[cv2.CC_STAT_HEIGHT]

            # Rows span the height and columns the width. The old crop swapped them.
            contact = _overlap_pixels(parasite_mask, cell_mask, par_box, (cellX, cellY, cellW, cellH))

            # enough surface in contact: classify as internal
            if contact / area >= contact_percent:
                is_internal = True
                cv2.rectangle(img, (cellX-1,cellY-1), (cellX+cellW+2,cellY+cellH+2), (0,255,0), 1)
                cv2.circle(img, (int(cX),int(cY)), 5, (240,240,0), -1)
                break

        # not in contact with any cell: external
        if is_internal:
            internal_count += 1
            color = INTERNAL_COLOR
        else:
            external_count += 1
            color = EXTERNAL_COLOR
        cv2.rectangle(img, (x-1,y-1), (x+w+2,y+h+2), color, 1)

    return img, internal_count, external_count


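# Example usage (IMAGE_PATH is a placeholder; cell_mask comes from the cell-segmentation step):
# img = get_image(IMAGE_PATH, mode=1)
# par_mask = gs_parasite_mask(img)
# img, internal, external = parasites_mark(img, par_mask, cell_mask)
# plt.imshow(img)
# plt.show()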