In [1]:
'packages'
import os, glob, cv2
import numpy as np
from skimage import exposure,io, measure
from skimage.measure import label, regionprops
from skimage.morphology import binary_erosion, binary_dilation, disk, remove_small_objects
from skimage.util import montage
import math
import shutil
import pandas as pd
import scipy
from scipy import ndimage
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image

In [2]:
# Show information about the current session via the session_info package
# (presumably the versions of the loaded packages -- confirm against its docs).
import session_info

session_info.show()

In [9]:
#------------------------------- comet preprocessing functions --------------------------
def read_files(path,identifier, verbose=True):
    '''
    List the entries of a folder whose name contains a given text.

    Input: path = folder to search
           identifier = text the file name must contain (e.g. the file extension)
           verbose = print the resulting list when True
    Return the matching paths, sorted
    '''
    # The pattern has no '**' component, so only entries directly in `path` match.
    pattern = os.path.join(path, ''.join(('*', identifier, '*')))
    matches = glob.glob(pattern, recursive=True)
    matches.sort()
    if verbose:
        print(matches)
    return matches


def image_preprocess(image, degree = 0, flip = False):
    '''
    Optionally flip and/or rotate an image.

    Input: image = image array
           degree = rotation angle passed to ndimage.rotate (0 = no rotation)
           flip = when True, flip with cv2.flip(image, 0) before rotating
    Return the processed image (the input object itself when nothing is applied)
    '''
    processed = cv2.flip(image, 0) if flip else image
    if degree == 0:
        return processed
    return ndimage.rotate(processed, degree)


def image_preview(filelist_in, degree, flip, num_images=5):
    '''
    Preview images side by side: original (left) and preprocessed (right).

    Input: filelist_in = list of image paths
           degree, flip = preprocessing options forwarded to image_preprocess
           num_images = how many images from the start of the list to show
                        (default 5; 0 or less shows nothing)
    Return None; figures are displayed with matplotlib.
    Raises AssertionError when fewer than num_images paths are supplied.
    '''
    # The message must be a string: using print(...) here would evaluate to None
    # and leave the AssertionError without any text.
    assert len(filelist_in) >= num_images, "Do not have enough images to preview."

    for img_path in filelist_in[:max(num_images, 0)]:
        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread returns None instead of raising when a file cannot be read
            print("Could not read image, skipping preview:", img_path)
            continue
        img_original = img.copy()

        img = image_preprocess(img, degree, flip)

        fig, ax = plt.subplots(1, 2, figsize=(10, 10), sharex=True, sharey=True)
        for axis, picture, title in ((ax[0], img_original, 'Original'),
                                     (ax[1], img, 'Processed Images')):
            axis.imshow(picture, cmap=plt.cm.gray)
            axis.autoscale(False)
            axis.axis('off')
            axis.set_title(title)
        plt.show()
        # release the figure so previewing many images does not accumulate memory
        plt.close(fig)
        
        
def object_segment(filelist_in, degree, flip, crop_dim, binary_thresh = 50,
                   min_area = 30, max_area = 5000, bad_total_area = 200000):
    '''
    Read each image, apply the requested flipping/rotation, segment it with a
    binary threshold and keep the objects whose pixel area is within range.
    A four-panel figure (original, processed, mask, filtered labels) is shown
    for every image.

    Input: filelist_in = list of image paths
           degree, flip = preprocessing options forwarded to image_preprocess
           crop_dim = (height, width) of the crop box drawn around each object
           binary_thresh = threshold applied to the 0-255 normalised image
           min_area, max_area = exclusive pixel-area bounds of objects to keep
           bad_total_area = images whose total segmented area reaches this value
                            are treated as bad-intensity images and dropped
    Return list of [filtered label image, processed image, image path], one
           entry per image that was kept
    '''

    img_segment_all = []

    for img_path in filelist_in:
        print(img_path)

        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread returns None instead of raising when a file cannot be read
            print("Could not read image, skipping:", img_path)
            continue
        img_original = img.copy()

        #perform preprocessing
        img = image_preprocess(img, degree, flip)

        #perform segmentation
        img = exposure.rescale_intensity(img)
        gray = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        _, thresh = cv2.threshold(gray, binary_thresh, 255, cv2.THRESH_BINARY)

        label_image = measure.label(thresh, connectivity=2, background=0)
        regions = regionprops(label_image)
        # numpy's counter works for any integer dtype of the label image
        totalarea = int(np.count_nonzero(label_image))
        labels = label_image.copy()

        fig, ax = plt.subplots(1, 4, figsize=(15, 15), sharex=True, sharey=True)
        ax[0].imshow(img_original, cmap=plt.cm.gray)
        ax[0].autoscale(False)
        ax[0].axis('off')
        ax[0].set_title('Original')
        ax[1].imshow(img, cmap=plt.cm.gray)
        ax[1].autoscale(False)
        ax[1].axis('off')
        ax[1].set_title('Processed Image')

        #filter images that contain huge total area (consider as bad intensity image)
        if totalarea < bad_total_area:
            for region in regions:
                keep = min_area < float(region.area) < max_area
                if not keep:
                    labels[labels == region.label] = 0

                minr, minc, maxr, maxc = region.bbox
                new_minr, new_minc, new_maxr, new_maxc = crop_rect(minr, minc, maxr, maxc, crop_dim)
                # crop_rect leaves a bound as None when the object is not smaller
                # than the crop box in that direction; no rectangle is drawn then.
                if None in (new_minr, new_minc, new_maxr, new_maxc):
                    continue

                # kept objects: red box on mask and label panels; rejected: gray box on mask only
                edgecolor, linewidth = ('red', 3) if keep else ('gray', 1)
                target_axes = (ax[2], ax[3]) if keep else (ax[2],)
                for axis in target_axes:
                    # a patch can belong to one axes only, so build one per panel
                    axis.add_patch(mpatches.Rectangle(
                        (new_minc, new_minr), new_maxc - new_minc, new_maxr - new_minr,
                        fill=False, edgecolor=edgecolor, linewidth=linewidth))

            img_segment_all.append([labels, img, img_path])
        else:
            print("Filtering Bad intensity image: ", img_path, 'totalarea >= bad_total_area:', totalarea)

        ax[2].imshow(thresh, cmap=plt.cm.gray)
        ax[2].autoscale(False)
        ax[2].axis('off')
        ax[2].set_title('Segmented Mask')
        # show a 0/1 mask; casting raw labels to uint8 would wrap labels >= 256
        ax[3].imshow((labels > 0).astype(np.uint8), cmap=plt.cm.gray, vmin=0, vmax=1)
        ax[3].autoscale(False)
        ax[3].axis('off')
        ax[3].set_title('Filtered Labels')
        fig.tight_layout()
        plt.show()
        # release the figure so long file lists do not accumulate open figures
        plt.close(fig)

    return img_segment_all

def crop_rect(minr, minc, maxr, maxc, crop_dim):
    '''
    Expand an object bounding box to the crop size.

    Input: minr, minc, maxr, maxc = bounding box of the object
           crop_dim = (height, width) of the wanted crop
    Return new_minr, new_minc, new_maxr, new_maxc

    Columns are padded equally on both sides. Rows start 10 pixels above the
    box and the remaining padding is added below it. A pair of bounds stays
    None when the box is not smaller than the crop size in that direction.
    Results may be negative or exceed the image; nothing is clipped here.
    '''
    height, width = crop_dim[0], crop_dim[1]
    new_minr = new_maxr = None
    new_minc = new_maxc = None

    col_span = maxc - minc
    if col_span < width:
        half_pad = int(width - col_span) / 2
        new_minc, new_maxc = int(minc - half_pad), int(maxc + half_pad)

    row_span = maxr - minr
    if row_span < height:
        row_pad = int(height - row_span)
        new_minr, new_maxr = int(minr - 10), int(maxr + row_pad - 10)

    return new_minr, new_minc, new_maxr, new_maxc
    
def object_info(filelist_in):
    '''
    Go through each segmented object and store its information.

    Input: filelist_in = list of [label image, image, image path] entries
           (the layout produced by object_segment)
    Return a flat list with one entry per labelled region:
           [file, img array, segment mask array, label objectid,
            y coordinate, x coordinate, minr, minc, maxr, maxc]
    '''
    img_info_all = []
    for entry in filelist_in:
        label_image, img, img_path = entry[0], entry[1], entry[2]

        # every entry references the same image/label arrays; nothing is copied
        for region in regionprops(label_image):
            y0, x0 = region.centroid
            minr, minc, maxr, maxc = region.bbox
            img_info_all.append([img_path, img, label_image,
                                 region.label, y0, x0, minr, minc, maxr, maxc])

    return img_info_all

In [10]:
#------------------------------- comet segmentation and filtering functions --------------------------
def object_params(img_info_all, idx):
    '''
    Pull the fields of one object out of the object info list.

    Input: img_info_all = list of per-object entries, idx = position of the object
    Return file, img, labels, objectid, minr, minc, maxr, maxc
           (entry positions 0-3 and 6-9; positions 4 and 5 are skipped)
    '''
    entry = img_info_all[idx]
    file, img, labels, objectid = entry[0], entry[1], entry[2], entry[3]
    minr, minc, maxr, maxc = entry[6], entry[7], entry[8], entry[9]
    return file, img, labels, objectid, minr, minc, maxr, maxc

def object_cropping(image_in, minr, minc, maxr, maxc, crop_dim):
    '''
    Cut one object out of an image using the crop box from crop_rect.

    Input: image_in = array to crop, minr/minc/maxr/maxc = object bounding box,
           crop_dim = user crop dimension
    Return the cropped region of image_in
    '''
    row_start, col_start, row_stop, col_stop = crop_rect(minr, minc, maxr, maxc, crop_dim)
    rows = slice(row_start, row_stop)
    cols = slice(col_start, col_stop)
    return image_in[rows, cols]

def get_crops(data, idx, crop_dim):
    '''
    Look up one object and crop both its label image and its image.

    Input: data = object info list, idx = position of the object,
           crop_dim = user crop dimension
    Return file, labels, objectid, cropped_img, cropped_label
    '''
    params = object_params(data, idx)
    file, img, labels, objectid = params[:4]
    bbox = params[4:]

    # label crop first, then image crop
    cropped_label = object_cropping(labels, *bbox, crop_dim)
    cropped_img = object_cropping(img, *bbox, crop_dim)
    return file, labels, objectid, cropped_img, cropped_label


def filter_extra_labels_and_edge(data, crop_dim, verbose=True):
    '''
    First round of filtering.
    Drops objects whose crop does not have the full crop size (objects on the
    image edge) and objects whose label was found as an extra label inside a
    crop.

    Input: data = object info list (one entry per object; positions 0 and 3
                  hold the file and the object label)
           crop_dim = crop size; compared with array shapes, so pass a tuple
           verbose = print the number of objects before and after filtering
    Return list of [filename, objectid, cropped image, cropped segment image]

    `data` is left untouched, so calling this again on the same list gives the
    same result and the same printed counts.
    '''
    full_size_crops = []
    extra_labels = set()

    for idx in range(len(data)):
        file, labels, objectid, cropped_img, cropped_label = get_crops(data, idx, crop_dim)

        #filter images on edge
        if cropped_label.shape != crop_dim or cropped_img.shape != crop_dim:
            continue
        full_size_crops.append([file, objectid, cropped_img, cropped_label])

        #find labels that appears in more than one crop
        # NOTE(review): [2:] skips the two smallest values found in the crop,
        # presumably the background (0) and the object's own label. That only
        # holds when the object has the smallest label in its crop -- confirm.
        for extra in np.unique(cropped_label)[2:]:
            extra_labels.add((file, int(extra)))

    # keep the objects that were never flagged; a set lookup per object replaces
    # the repeated list.remove scans
    filelist_out = [crop for crop in full_size_crops
                    if (crop[0], crop[1]) not in extra_labels]

    if verbose:
        print('Original number of objects:', len(data))
        print('Filtered Extra labels and edge images, new number of objects:', len(filelist_out))
    return filelist_out

In [2]:
#------------------------------- comet measurement functions --------------------------
def head_body_segment(cropped_img, cropped_label, lowerthresh, upperthresh, head_min = 30, body_min = 120, dilation=9):
    '''
    Segment head and body in a crop.
    The crop is contrast-enhanced, normalised to 0-255 and blurred; the head is
    everything above upperthresh and the body everything above lowerthresh.
    The head mask is opened (erosion + dilation) and then dilated further.
    Tiny objects are removed from both segmentations.

    Input: cropped_img = crop to segment
           cropped_label = not used by this function (kept for the call signature)
           lowerthresh, upperthresh = body / head thresholds on the 0-255 image
           head_min, body_min = minimum object size kept in each segmentation
           dilation = number of extra dilation iterations for the head mask;
                      0 or less applies no extra dilation
    Return: normalised crop, body labels, head labels, head regions, body regions
    '''

    img = exposure.rescale_intensity(cropped_img)
    img = exposure.equalize_adapthist(img, clip_limit=0.008)
    gray = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)

    #detect head and body
    _, head_thresh = cv2.threshold(blur, upperthresh, 255, cv2.THRESH_BINARY)
    _, body_thresh = cv2.threshold(blur, lowerthresh, 255, cv2.THRESH_BINARY)

    #perform opening, then dilation on head
    head_mask = binary_dilation(binary_erosion(head_thresh))
    # scipy treats iterations < 1 as "repeat until nothing changes", which would
    # grow the head as far as it can go; only dilate for a positive count.
    if dilation > 0:
        head_mask = ndimage.binary_dilation(head_mask, iterations=dilation)

    #remove some tiny objects
    # ndimage.label / ndimage.binary_dilation are the public names; the
    # scipy.ndimage.measurements and scipy.ndimage.morphology namespaces are deprecated
    head_label, _ = ndimage.label(head_mask)
    body_label, _ = ndimage.label(body_thresh)
    body_label = remove_small_objects(body_label, body_min)
    head_label = remove_small_objects(head_label, head_min)

    head_regions = regionprops(head_label, intensity_image = gray)
    body_regions = regionprops(body_label, intensity_image = gray)

    return gray, body_label, head_label, head_regions, body_regions

def comet_filter_location(info, stats, labels, ob_label, filter_object, crop_dim, head_tail_dist, comet_size=2000):
    '''
    Second round of filtering, based on where the object sits in the crop.
    The current object (stats / ob_label) is compared with the first object
    already kept (info[0]); stats positions used are 0, 4 and 5.

    Head: the current head is removed when its stats[4] is larger than that of
          the kept head.
    Body: when both objects have stats[0] above comet_size, or when their
          stats[4] differ by more than crop_dim[0] * head_tail_dist, only one
          is kept: the current object is removed when its stats[5] is larger,
          otherwise every other label is removed and info is reset to [stats].

    Return ob_filter (True when the current object was removed),
           labels (modified in place), info
    '''
    ob_filter = False

    # nothing kept so far: there is nothing to compare against
    if len(info) == 0:
        return ob_filter, labels, info

    def drop_current():
        labels[np.isin(labels, ob_label)] = 0

    def keep_only_current():
        labels[np.isin(labels, ob_label, invert=True)] = 0

    if filter_object == 'Head':
        #keep the higher head (compare the top of head segments location)
        if stats[4] > info[0][4]:
            drop_current()
            ob_filter = True

    if filter_object == 'Body':
        #two body objects above comet_size are assumed to be more than one real comet
        both_large = stats[0] > comet_size and info[0][0] > comet_size
        if both_large:
            if stats[5] > info[0][5]:
                drop_current()
                ob_filter = True
            else:
                #keep the one appeared higher
                keep_only_current()
                info = [stats]

        #two body objects far away from each other (compared with the allowed fraction of the crop)
        allowed_gap = crop_dim[0] * head_tail_dist
        if abs(stats[4] - info[0][4]) > allowed_gap:
            if stats[5] > info[0][5]:
                drop_current()
                ob_filter = True
            else:
                #keep the one appeared higher
                keep_only_current()
                info = [stats]

    return ob_filter, labels, info

def comet_info(crop_dim, labels, regions, area_min, area_max, filter_object, head_tail_dist):
    '''
    Collect the basic measurements of every object in one crop and blank out
    (set to 0 in labels) the objects that are rejected.

    Input: crop_dim = crop size, labels = label image of the crop (modified in place),
           regions = region properties of the labelled objects,
           area_min, area_max = exclusive area bounds for an object to be kept,
           filter_object = 'Head' or 'Body', forwarded to comet_filter_location,
           head_tail_dist = distance fraction forwarded to comet_filter_location
    Return labels, info where info has one entry per kept object:
           [area, mean intensity, major axis length, centroid x, minr, maxr,
            bounding box width, area * mean intensity, circularity]
    '''

    info = []

    for region in regions:
        ob_label = region.label

        if not (area_min < region.area < area_max):
            #black out object if not in right size
            labels[np.isin(labels, ob_label)] = 0
            continue

        _, x0 = region.centroid
        minr, minc, maxr, maxc = region.bbox
        width = maxc - minc

        # circularity = 4*pi*area / perimeter^2; a region with zero perimeter
        # gets 0 instead of dividing by zero
        perimeter = region.perimeter
        if perimeter > 0:
            cir = 4*math.pi*(region.area/(perimeter*perimeter))
        else:
            cir = 0.0

        stats = [region.area, region.mean_intensity, region.major_axis_length,
                 x0, minr, maxr, width, (region.area*region.mean_intensity), cir]

        #perform second round filtering
        ob_filter, labels, info = comet_filter_location(info, stats, labels, ob_label, filter_object, crop_dim, head_tail_dist)

        if not ob_filter:
            info.append(stats)

    return labels, info


def _safe_ratio(numerator, denominator):
    '''Divide numerator by denominator, returning 0 when the denominator is 0.'''
    return numerator / denominator if denominator != 0 else 0


def comet_calculations(file, label, img, head_label, body_label, comet_head_info, comet_body_info, verbose = True, plot_graph=True):
    '''calculate comets statistics, includes:
    comet head length, comet tail length, comet body length,
    comet area, comet dna content, comet average intensity,
    head area, head dna content, head average intensity, head dna percentage,
    tail_area, tail dna content, tail average intensity, tail dna percentage

    Input: file, label = identifiers stored in the first two output positions
           img = crop, sampled along the head-to-tail line
           head_label, body_label = label images, only used for plotting
           comet_head_info, comet_body_info = per-object lists laid out as
               [area, mean intensity, major axis length, x, minr, maxr, width,
                area * mean intensity, circularity]
           verbose = print every statistic of a kept comet
           plot_graph = show the crop / head / body figure of a kept comet
    Return list of statistics for the comet crop, or an empty list when the
           comet is filtered out (ratios with a zero denominator are reported as 0)
    '''

    stats = []

    if len(comet_head_info) == 0 or len(comet_body_info) == 0:
        return stats

    comet_area = sum(body[0] for body in comet_body_info)
    head_area = sum(head[0] for head in comet_head_info)
    body_cir_count = sum(1 for body in comet_body_info if body[8] > 0.9)

    # NOTE(review): this parses as (not len > 4) or (count > 1), so a crop with
    # more than 4 bodies still passes when 2+ of them are circular. Possibly
    # `not (len > 4 or count > 1)` was meant -- behaviour kept as is, confirm.
    if not (not len(comet_body_info) > 4 or body_cir_count > 1):
        return stats

    #bottom(maxr) of the higher head location, end of comet head/start of comet tail
    head_bottom = min(head[5] for head in comet_head_info)
    #top(minr) of the higher head location, start of comet head
    head_top = min(head[4] for head in comet_head_info)

    #bottom(maxr) of the body location, end of comet tail
    body_bottom = max(body[5] for body in comet_body_info)
    #bottom of the higher body (usually the head) location
    high_body_bottom = min(body[5] for body in comet_body_info)
    #top of the body location
    body_top = min(body[4] for body in comet_body_info)

    #in case of more than 1 object: average x coordinate of heads / bodies
    head_x = np.mean(comet_head_info, axis=0)[3]
    body_x = np.mean(comet_body_info, axis=0)[3]

    #filter if comet body appear higher than comet head
    dif = high_body_bottom - head_top
    top_dif = body_top - head_top
    if dif < 0 or top_dif <= -15:
        print("filter(comet body appear higher than comet head)")
        return stats

    #---Calculations
    bodylength = body_bottom - head_top

    if body_bottom >= head_bottom:
        taillength = body_bottom - head_bottom
        # sample the crop along the line from the head top to the body bottom and
        # filter cases where body and head are far away in odd locations
        # and the line potentially has low intensity
        line_x = np.linspace(head_x, body_x, bodylength)
        line_y = np.linspace(head_top, body_bottom, bodylength)
        line_values = ndimage.map_coordinates(np.transpose(img), np.vstack((line_x, line_y)))
        if np.mean(line_values) < 40:
            print("filter(odd location)")
            return stats
    else:
        taillength = 0

    head_length = bodylength - taillength
    comet_dna_content = sum(body[7] for body in comet_body_info)
    comet_avg_int = _safe_ratio(comet_dna_content, comet_area)
    head_dna_content = sum(head[7] for head in comet_head_info)
    head_avg_int = _safe_ratio(head_dna_content, head_area)
    head_dna_percentage = _safe_ratio(head_dna_content, comet_dna_content) * 100
    tail_area = comet_area - head_area
    tail_dna_content = comet_dna_content - head_dna_content
    # the head can cover as much area as the whole body, leaving a tail area of 0
    tail_average = _safe_ratio(tail_dna_content, tail_area)
    tail_dna_percentage = _safe_ratio(tail_dna_content, comet_dna_content) * 100

    # a circular first head with slightly more DNA than the body counts as "no tail"
    if tail_dna_percentage < 0 and comet_head_info[0][8] > 0.9:
        tail_dna_percentage = 0

    if not tail_dna_percentage >= 0:
        print("filter(negative tail_dna_percentage, protentially tiny bad debris)")
        return stats

    if plot_graph:
        fig, ax = plt.subplots(1, 3, figsize=(4, 4), sharex=True, sharey=True)
        ax[0].imshow(img, cmap='gray')
        ax[0].plot([head_x, body_x], [head_bottom, body_bottom], 'ro-')
        ax[0].plot([head_x+5, body_x+5], [head_top, body_bottom], 'bo-')
        ax[0].axis('off')
        ax[0].axis('image')
        ax[1].imshow(head_label, cmap='gray')
        ax[1].axis('off')
        ax[1].set_title('Head')
        ax[2].imshow(body_label, cmap='gray')
        ax[2].axis('off')
        ax[2].set_title('Body')

        fig.tight_layout()
        plt.show()
        # release the figure; this function runs once per comet crop
        plt.close(fig)

    if verbose:
        print('Head length: ', head_length)
        print('Tail length: ',  taillength)
        print('Body length: ',  bodylength)
        print('Comet area: ', comet_area)
        print('Comet DNA content: ', comet_dna_content)
        print('Comet average intensity: ', comet_avg_int)
        print('Head area: ', head_area)
        print('Head DNA content: ', head_dna_content)
        print('Head average intensity: ', head_avg_int)
        print('Head DNA %: ',  head_dna_percentage)
        print('Tail area: ', tail_area)
        print('Tail DNA content: ', tail_dna_content)
        print('Tail average intensity: ', tail_average)
        print('Tail DNA %: ', tail_dna_percentage)

    stats = [file, label, head_length, taillength, bodylength,
             comet_area, comet_dna_content, comet_avg_int,
             head_area, head_dna_content, head_avg_int, head_dna_percentage,
             tail_area, tail_dna_content, tail_average, tail_dna_percentage]
    #set negatives to 0
    stats[2:] = [0 if value < 0 else value for value in stats[2:]]

    return stats


def comet_measure(data, crop_dim, head_min=200, body_min=300, head_max=3500, body_max=20000, body_thresh = 30, head_thresh= 185, head_tail_dist=0.3, dilate_iter = 9, plot_graph=True, verbose=True):
    '''
    Detect comet body and head
    Filter comets with certain criterias such as area, circularity,number of objects
    Measure and return comet statistics

    Input: data = list of [file, objectid, cropped image, cropped label] entries
           crop_dim = crop size
           head_min, head_max, body_min, body_max = area bounds of head / body objects
           body_thresh, head_thresh = intensity thresholds for body / head
           head_tail_dist = allowed distance between body objects, as a fraction
                            of the crop height
           dilate_iter = number of dilation iterations applied to the head mask
           plot_graph = show the figures of every crop
           verbose = print the statistics of every kept comet
    Return all_stats (one statistics list per kept comet),
           all_img (the matching crops)
    '''

    all_img = []
    all_stats = []
    for entry in data:
        file, objectid = entry[0], entry[1]
        cropped_img, cropped_label = entry[2], entry[3]

        #perform segmentation
        img, body_label, head_label, head_regions, body_regions = head_body_segment(
            cropped_img, cropped_label, lowerthresh=body_thresh, upperthresh=head_thresh,
            head_min=head_min, body_min=body_min, dilation=dilate_iter)
        if plot_graph:
            plt.imshow(img, cmap='gray')
            plt.show()

        #perform necessary comet head and body filters
        # np.unique includes the background value, so these are object counts + 1
        body_values = len(np.unique(body_label))
        head_values = len(np.unique(head_label))
        if not (body_values <= 5 and 1 < head_values < 4):
            print("filter(head_body_segment: too many objects found)")
            continue

        head_label, head_info = comet_info(crop_dim, head_label, head_regions,
                                           area_min = head_min, area_max = head_max,
                                           filter_object= 'Head', head_tail_dist=head_tail_dist)
        body_label, body_info = comet_info(crop_dim, body_label, body_regions,
                                           area_min = body_min, area_max = body_max,
                                           filter_object= 'Body', head_tail_dist=head_tail_dist)

        #perform final filtering and calculation
        # pass the flags by keyword: positionally, plot_graph would land in the
        # `verbose` slot of comet_calculations
        stats = comet_calculations(file, objectid, img, head_label, body_label,
                                   head_info, body_info,
                                   verbose=verbose, plot_graph=plot_graph)

        if stats:
            all_stats.append(stats)
            all_img.append(np.asarray(cropped_img))

    return all_stats, all_img

In [12]:
#------------------------------- comet output functions --------------------------
def class_montage(imgs,w,h,rescale=True):
    '''
    Show a set of comet crops tiled together as one montage image.

    Input: imgs = crops to tile, w, h = figure size,
           rescale = forwarded to montage as rescale_intensity
    Return None; the montage is drawn on a new matplotlib figure
    '''
    fig_size = (w, h)
    plt.figure(figsize = fig_size)
    tiled = montage(np.array(imgs), rescale_intensity=rescale)
    plt.imshow(tiled, cmap=plt.cm.gray)

def generate_output(all_img, all_stats, output_path,image_extension):
    '''
    Generate outputs:
    1) Montage of all segmented comets (that are being kept from AutoComet)
    2) Each individual comet crop (that are being kept from AutoComet)
    3) CSV file with all comet statistics (that are being kept from AutoComet)

    Input: all_img = list of comet crops
           all_stats = list of statistics, one per crop, in the same order
                       (positions 0 and 1 are the source file and the cell id)
           output_path = folder to write into (created when missing)
           image_extension = extension of the source images without the dot,
                             stripped from the file name of each crop
    When there are no comets, only the CSV (header only) is written.
    '''
    #Make output folder
    os.makedirs(output_path, exist_ok=True)

    if len(all_img) > 0:
        #show comet montage image
        class_montage(np.array(all_img), 40, 40, rescale=True)

        #output montage
        im = Image.fromarray(montage(np.array(all_img), rescale_intensity=True))
        im.save(os.path.join(output_path, 'montage_comets.png'))
    else:
        # a montage cannot be built from zero crops
        print('No comets to export, skipping montage and crop images.')

    #output individual crops
    for (img, file) in zip(all_img, all_stats):
        # basename handles the path separators of the current platform
        name = os.path.basename(file[0])
        filename = name.split('.'+image_extension)[0]
        cell = file[1]

        cv2.imwrite(os.path.join(output_path, filename+'_'+str(cell)+'.png'), img)

    #Output csv
    df = pd.DataFrame(all_stats, columns=['filename', 'cell', 'head length', 'tail length', 'body length','comet area',
                                          'comet dna content', 'comet avg intensity','head area', 'head dna content', 'head avg intensity',
                                          'head dna percentage', 'tail area', 'tail dna content','tail avg intensity','tail dna percentage'])

    df.to_csv(os.path.join(output_path, 'comet_measurements.csv'), index=False)
