In [155]:
# OPTIONAL: Load the "autoreload" extension so that code can change
%reload_ext autoreload

# OPTIONAL: always reload modules so that as you change code in src, it gets loaded
%autoreload 2

import sys
sys.path.append('../src')

from cellpose import plot
import cv2
from functools import reduce
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import pickle
import re
from skimage.measure import regionprops_table
from skimage.segmentation import mark_boundaries, expand_labels
import tifffile as tiff
from time import time
from tqdm import tqdm

from metadata import metadata
from utils import list_subdir_filter as lsd

# Project configuration object from the local `metadata` module; the functions and
# cells below read their input/output directories from `md.folders`.
md = metadata()

# Make sure every configured folder exists (including missing parents) before
# anything tries to read from or write to it. Existing folders are left untouched.
for f in md.folders.values():
    Path(f).mkdir(exist_ok=True, parents=True)


In [156]:
def overlap_segmentations(top_tile, bottom_tile, overlap=200):
    """Stitch two vertically adjacent label images that share `overlap` rows.

    The last `overlap` rows of `top_tile` and the first `overlap` rows of
    `bottom_tile` are treated as the same image area. Ownership of that area is
    split at its horizontal mid-line:

    * every label of the top tile that touches the lower half of the overlap is
      erased from the *whole* top tile;
    * inside the overlap strip of the bottom tile, only labels that touch the
      lower half of the overlap are kept.

    The remaining bottom labels are then shifted so that the smallest one equals
    ``max(top labels) + 1``, and the tiles are stacked.

    Parameters
    ----------
    top_tile, bottom_tile : array of integer labels, 0 = background
        Label images with the same width. Neither input is modified.
    overlap : int, default 200
        Number of shared rows. Must be >= 0; with 0 the tiles are only
        relabelled and stacked.

    Returns
    -------
    numpy.ndarray
        Stitched label image with
        ``top_tile.shape[0] + bottom_tile.shape[0] - overlap`` rows.

    Raises
    ------
    ValueError
        If `overlap` is negative.
    """
    if overlap < 0:
        raise ValueError(f'overlap must be non-negative, got {overlap}')

    # Work on copies so the caller's arrays are never modified.
    top_tile = np.copy(top_tile)
    bottom_tile = np.copy(bottom_tile)

    if overlap > 0:
        # (-overlap) // 2 floors towards -inf, so for an odd overlap both tiles use
        # the same mid-line row (the lower "half" gets the extra row).
        top_remove = np.unique(top_tile[(-overlap) // 2:])
        bottom_keep = np.unique(bottom_tile[overlap // 2:overlap])

        # Erase top-tile masks reaching below the mid-line ...
        top_tile[np.isin(top_tile, top_remove)] = 0
        # ... and, inside the overlap strip only, bottom-tile masks that stay above it.
        bottom_strip = bottom_tile[:overlap]  # a view: edits land in bottom_tile
        bottom_strip[np.isin(bottom_strip, bottom_keep, invert=True)] = 0

    # Shift the bottom labels so numbering continues right after the top tile.
    # The offset is computed with Python ints so that unsigned label dtypes cannot
    # underflow when the top tile's maximum is smaller than the bottom minimum.
    top_max = int(np.max(top_tile))
    foreground = bottom_tile > 0
    if foreground.any():
        shift = top_max - int(bottom_tile[foreground].min()) + 1
        if shift >= 0:
            bottom_tile[foreground] += shift
        else:
            bottom_tile[foreground] -= -shift

    if overlap == 0:
        return np.concatenate([top_tile, bottom_tile])

    # Merge the shared strip. Where a surviving top mask and a kept bottom mask
    # cover the same pixel, keep the top label: adding the two IDs would invent a
    # label that belongs to neither cell and may collide with a real one.
    top_strip = top_tile[-overlap:]
    bottom_strip = bottom_tile[:overlap]
    intersection = np.where(top_strip > 0, top_strip, bottom_strip)

    return np.concatenate([top_tile[:-overlap], intersection, bottom_tile[overlap:]])


def return_whole_segmentation(tile_code, overlap=200):
    """Load every per-tile segmentation of one slide and stitch them into one label image.

    Files are looked up with `lsd` in the module-level `seg_dir` using the regex
    ``{tile_code}_.*\\.pkl``. Each pickle must contain a ``'masks'`` array, which is
    cast to ``uint32``. The masks are stitched in the order `lsd` returns them,
    each new tile being appended below the running result with
    `overlap_segmentations`.
    NOTE(review): this assumes `lsd` yields the tiles in top-to-bottom order -
    confirm against `utils.list_subdir_filter`.

    Parameters
    ----------
    tile_code : str
        Slide/tile identifier used in the pickle file names.
    overlap : int, default 200
        Number of rows shared by consecutive tiles; forwarded to
        `overlap_segmentations`.

    Returns
    -------
    numpy.ndarray
        The stitched ``uint32`` label image.

    Raises
    ------
    IndexError
        If no segmentation file matches `tile_code`.
    """
    # Raw f-string: keeps the backslash in `\.` literal without relying on Python's
    # deprecated handling of invalid escape sequences.
    all_segmentation_files = lsd(seg_dir, True, rf'{tile_code}_.*\.pkl')

    all_segmentation_masks = []
    for file in all_segmentation_files:
        # SECURITY: pickle.load can execute arbitrary code - only load files
        # produced by this project's own segmentation step.
        with open(file, 'rb') as handle:
            all_segmentation_masks.append(pickle.load(handle)['masks'].astype('uint32'))

    if not all_segmentation_masks:
        raise IndexError(f'no segmentation files found for tile {tile_code!r} in {seg_dir}')

    whole_segmentation = all_segmentation_masks[0]
    for next_mask in all_segmentation_masks[1:]:
        whole_segmentation = overlap_segmentations(whole_segmentation, next_mask, overlap=overlap)

    return whole_segmentation


def segment_and_regionprops(tile_code):
    """Build the whole-slide segmentation for `tile_code` and measure every nucleus/cell.

    Steps (progress and timings are printed along the way):

    1. Stitch the per-tile masks with `return_whole_segmentation` (overlap 200).
    2. For tile codes '2420' and '2430', read ``../data/external/negmask_<code>.tif``
       and erase every label pixel covered by the mask. Each mask pixel with a value
       above 128 blanks a 24x24 block of the segmentation (presumably the mask is a
       24x downsampled image of the slide - TODO confirm).
    3. Save the segmentation to ``<seg_dir>/whole_segmentation_<code>.npy``.
    4. Load the background-removed image and move its first axis to the end
       (presumably channels-first -> channels-last - TODO confirm).
    5. Expand every label by 25 px with `expand_labels`; the expanded area minus the
       original label is measured separately (the "cells" table, columns prefixed
       with ``cell_``).
    6. Inner-join the nuclei and cell tables on ``label`` and pickle the result to
       ``md.folders['regionprops']``.

    Parameters
    ----------
    tile_code : str
        Slide/tile identifier used in all file names.

    Returns
    -------
    pandas.DataFrame
        One row per label present in both regionprops tables.
    """
    started = time()

    bg_removed_path = os.path.join(md.folders['bg_removed'], f'A40_{tile_code}_bg_removed_uint8.npy')
    measures_path = os.path.join(md.folders['regionprops'], f'A40_{tile_code}_cell_measures.pickle')

    # ---- 1. segmentation -------------------------------------------------
    print('Segmentation: ', end='')
    nuclei = return_whole_segmentation(tile_code, overlap=200)

    if tile_code in ('2420', '2430'):
        neg_mask = tiff.imread(f'../data/external/negmask_{tile_code}.tif')
        blocked = np.zeros_like(nuclei)

        # Paint one 24x24 block per "true" mask pixel.
        for pix in tqdm(np.argwhere(neg_mask > 128)):
            block_start = pix * 24
            block_stop = (pix + 1) * 24
            blocked[block_start[0]:block_stop[0], block_start[1]:block_stop[1]] = 255

        nuclei[blocked > 0] = 0

    np.save(os.path.join(seg_dir, f'whole_segmentation_{tile_code}.npy'), nuclei)
    intensity = np.moveaxis(np.load(bg_removed_path), 0, -1)

    # Area gained by growing each label by 25 px, without the label itself.
    rings = expand_labels(nuclei, 25) - nuclei
    print(f'{round(time() - started)}s')
    stage_started = time()

    # ---- 2. measurements -------------------------------------------------
    print('Regionprops tables: (nuclei...) ', end='')
    nuclei_props = [
        'label', 'centroid', 'bbox', 'area', 'area_convex', 'axis_major_length',
        'axis_minor_length', 'intensity_min', 'intensity_mean', 'intensity_max',
    ]
    nuclei_table = regionprops_table(
        label_image=nuclei,
        intensity_image=intensity,
        properties=nuclei_props,
    )
    print('(cells...) ', end='')
    ring_props = ['label', 'centroid', 'intensity_min', 'intensity_mean', 'intensity_max']
    ring_table = regionprops_table(
        label_image=rings,
        intensity_image=intensity,
        properties=ring_props,
    )
    print(f'{round(time() - stage_started)}s')

    # Prefix every ring column except the join key.
    ring_table = {
        (name if name == 'label' else f'cell_{name}'): values
        for name, values in ring_table.items()
    }

    # ---- 3. join and save ------------------------------------------------
    stage_started = time()
    print('Join and save: ', end='')
    df_final = pd.merge(
        pd.DataFrame.from_dict(nuclei_table),
        pd.DataFrame.from_dict(ring_table),
        on='label',
    )

    with open(measures_path, 'wb') as handle:
        pickle.dump(df_final, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print(f'{round(time() - stage_started)}s')
    print(f'-------------------------------------\nTotal duration: {round(time() - started)}')

    return df_final
    

In [157]:
# Directory with the cellpose outputs; also read as a global by
# return_whole_segmentation() and segment_and_regionprops().
seg_dir = os.path.join(md.folders['segmented'], 'cellpose')

# Regex patterns are written as raw strings so the backslashes reach the regex
# engine verbatim ('\.' in a normal string is an invalid escape sequence).
img_files = lsd(md.folders['bg_removed'], True, r'\.npy')
all_segmented_files = lsd(seg_dir, False, 'TT_.*')

# Unique slide ids: the digits following 'A40_' in each segmented file name.
# A name that does not match the pattern is passed through unchanged by re.sub.
all_slide_ids = sorted(
    {re.sub(r'.*A40_([0-9]+).*', r'\1', name) for name in all_segmented_files}
)

In [None]:
# Slides to process in this run. These are the two tile codes for which
# segment_and_regionprops() applies an external negative mask; switch to
# `tqdm(all_slide_ids)` to process every slide.
SLIDES_TO_PROCESS = ['2420', '2430']

# True  -> always recompute, overwriting any existing measurements (current behaviour).
# False -> skip slides whose measurement pickle already exists.
FORCE_RECOMPUTE = True

for slide_id in SLIDES_TO_PROCESS:
    measures_path = os.path.join(md.folders['regionprops'], f'A40_{slide_id}_cell_measures.pickle')
    if FORCE_RECOMPUTE or not os.path.exists(measures_path):
        _ = segment_and_regionprops(slide_id)
    else:
        print(f'skip {slide_id}')

Segmentation: 