In [None]:
import numpy as np
import matplotlib.pyplot as plt
import squidpy as sq
import scanpy as sc
import pandas as pd
import os
from PIL import Image
import imagecodecs
import tifffile as tff
from pyometiff import OMETIFFReader
import seaborn as sns
import geopandas as gpd
from shapely.geometry import Polygon

from sklearn.mixture import GaussianMixture
from scipy.stats import norm

import pickle
import json
import umap
import anndata as ad
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import logging
logging.basicConfig(format='%(asctime)s - %(message)s', 
                    datefmt='%d-%b-%y %H:%M:%S',
                    handlers=[logging.StreamHandler(),
                                logging.FileHandler("test.log", "a")])
logging.getLogger().setLevel(logging.INFO)

from itables import init_notebook_mode
init_notebook_mode(all_interactive=True)

# Defining functions

In [None]:
## Setup function for calculating elapsed time
def print_elapsed_time(start,stop):
    """Format the time elapsed between two timestamps given in seconds.

    Despite its name, this function does not print anything: it returns
    the formatted string so the caller can log or print it.

    Parameters
    ----------
    start, stop : numbers
        Start and end timestamps in seconds. The difference ``stop-start``
        is truncated to whole seconds with ``int()``.

    Returns
    -------
    str
        Text of the form ``"H hours: M minutes: S seconds"``.
    """
    ## Whole seconds between the two timestamps
    total_seconds = int(stop - start)

    ## Split into hours first, then break the remainder into minutes/seconds
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    return f"{hours} hours: {minutes} minutes: {seconds} seconds"


### =================================================================
## Read nucleus/cell boundary .csv.gz files (transcript and cell summary tables are currently not loaded)
def read_boundary_files(data_dir,sample_dir):
    """Load the nucleus and cell boundary tables of one sample.

    Parameters
    ----------
    data_dir : str
        Root data directory.
    sample_dir : str
        Sample folder, joined onto ``data_dir`` with ``os.path.join``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(nucleus_bound, cell_bound)`` read from
        ``nucleus_boundaries.csv.gz`` and ``cell_boundaries.csv.gz``.
    """
    sample_path = os.path.join(data_dir, sample_dir)

    ## Only the two boundary tables are loaded; transcripts.csv.gz and
    #  cells.csv.gz from the same folder are intentionally not read here
    boundary_tables = [
        pd.read_csv(os.path.join(sample_path, file_name))
        for file_name in ('nucleus_boundaries.csv.gz', 'cell_boundaries.csv.gz')
    ]
    nucleus_bound, cell_bound = boundary_tables

    return nucleus_bound, cell_bound


### =================================================================
## Read a DAPI-fluorescence morphology OME-TIFF image of the slide (__morphology_focus.ome.tif__, __morphology_mip.ome.tif__ or __morphology.ome.tif__)
def read_ome_tiff(slide_type,data_dir,sample_dir):
    """Read one of the morphology OME-TIFF images of a sample.

    Parameters
    ----------
    slide_type : str
        Which image to load: ``'focus'`` (morphology_focus.ome.tif),
        ``'mip'`` (morphology_mip.ome.tif) or ``'z_stack'``
        (morphology.ome.tif).
    data_dir : str
        Root data directory.
    sample_dir : str
        Sample folder, joined onto ``data_dir`` with ``os.path.join``
        (an absolute ``sample_dir`` therefore overrides ``data_dir``).

    Returns
    -------
    tuple
        ``(img_array, metadata, xml_metadata)`` exactly as returned by
        ``OMETIFFReader.read()``.

    Raises
    ------
    ValueError
        If ``slide_type`` is not one of the supported values. Previously an
        unknown value left the file name undefined and surfaced as an
        ``UnboundLocalError``.
    """
    slide_files={'focus':'morphology_focus.ome.tif',
                 'mip':'morphology_mip.ome.tif',
                 'z_stack':'morphology.ome.tif'}

    if slide_type not in slide_files:
        raise ValueError(f"Unknown slide_type {slide_type!r}; expected one of {sorted(slide_files)}")

    fn=os.path.join(data_dir,sample_dir,slide_files[slide_type])

    reader = OMETIFFReader(fpath=fn)
    img_array,metadata,xml_metadata=reader.read()

    return img_array,metadata,xml_metadata



### =================================================================

## Create dataframe of mask polygon coordinates from json outputs, that Baysor created
def create_cell_bound_df(polygon_fn,seg):
    """Build a long-format table of polygon vertices from a Baysor polygon JSON.

    Parameters
    ----------
    polygon_fn : str
        Path to the JSON file. It must contain a ``'geometries'`` list whose
        entries have ``'coordinates'`` and ``'cell'`` keys.
    seg : pandas.DataFrame
        Segmentation table with a ``'cell'`` column holding names of the form
        ``'<prefix>-<number>'`` (i.e. 'CRe50034e63-1'); NaN entries are ignored.

    Returns
    -------
    pandas.DataFrame
        One row per polygon vertex with columns ``'vertex_x_pixel'``,
        ``'vertex_y_pixel'`` and ``'cell_id'`` (``'<prefix>-<number>'``).
        The per-polygon frames are concatenated without resetting the index.
        An empty frame with these columns is returned when no polygon has
        any vertex.

    Raises
    ------
    ValueError
        If the non-missing ``seg['cell']`` values do not share exactly one
        prefix, or if a cell name in the JSON neither starts with the prefix
        nor converts to an integer.
    """
    import json

    ## The cell names in segmentation.csv have a prefix in front of the cell name (i.e.'CRe50034e63-1')
    #  Extract that prefix and add it to the cell names coming from the json file to make them uniform
    prefix=seg.loc[~seg['cell'].isna(),'cell'].str.split('-',expand=True)[0].unique()

    if len(prefix)!=1:
        raise ValueError('Multiple prefixes in the segmentation.csv cell column => check the length of prefix variable in create_cell_bound_df() function!')

    prefix=prefix[0]

    ## Read the JSON file; the context manager closes the handle again
    with open(polygon_fn) as f:
        polygon=json.load(f)

    colnames=['vertex_x_pixel','vertex_y_pixel','cell_id']

    ## Init list to collect cell dfs with coordinates
    df_list=[]
    for geometry in polygon['geometries']:
        coord=np.asarray(geometry['coordinates'])

        ## Add only cells with non-empty masks
        if coord.size==0:
            continue

        ## Flatten to one (x, y) row per vertex. Unlike np.squeeze this also keeps
        #  a polygon with a single vertex two-dimensional.
        #  NOTE(review): a geometry with several equally long rings would be
        #  merged into one vertex list - Baysor output is assumed to have one ring.
        coord=coord.reshape(-1,2)

        ## Most of the time the json cell name is the bare number, sometimes it was
        #  saved with the prefix already attached => only add the prefix if missing
        cell_name=geometry['cell']
        if str(cell_name).startswith(prefix+'-'):
            cell_id=str(cell_name)
        else:
            cell_id=prefix+'-'+str(int(cell_name))

        ## Keep the coordinates numeric: the name is added as its own column
        #  instead of being stacked into the coordinate array
        df=pd.DataFrame(coord,columns=colnames[:2])
        df['cell_id']=cell_id
        df_list.append(df)

    ## pd.concat raises on an empty list, so return an empty table instead
    if not df_list:
        return pd.DataFrame(columns=colnames)

    ## Concatenate all cell dfs into one
    return pd.concat(df_list)


### =================================================================

## Calculate nucleus pixel metric: median/mean/max/mode of nucleus pixels
def return_polygon_pixel_metric(x,img_array,metric):
    """Summarise the image pixels that lie inside one polygon.

    Parameters
    ----------
    x : pandas.DataFrame
        Vertices of a single polygon in columns ``'vertex_x_pixel'`` and
        ``'vertex_y_pixel'`` (in this file: one group of
        ``cell_bound.groupby('cell_id')``).
    img_array : numpy.ndarray
        Image indexed as ``img_array[row (y), column (x)]``.
    metric : str
        One of ``'median'``, ``'mean'``, ``'max'``, ``'mode'`` or
        ``'raw_vals'`` (returns the pixel values themselves).

    Returns
    -------
    scalar, numpy.ndarray or float('nan')
        Result of the chosen metric over the pixels inside the polygon.
        ``np.nan`` is returned when the metric function raises ``ValueError``
        (e.g. ``max``/``mode`` of an empty pixel set, which happens for
        degenerate polygons or polygons entirely outside the image).

    Raises
    ------
    KeyError
        If ``metric`` is not one of the supported names.

    Notes
    -----
    Warnings (such as numpy's "mean of empty slice") are silenced only while
    this function runs; the global warning filters are left untouched.
    To inspect one polygon visually, plot the boolean mask and the masked
    image window with ``plt.imshow(..., origin='upper')``.
    """
    from matplotlib import path
    import statistics as st
    import warnings

    metric_dict={'median':np.median,'mean':np.mean,'max':np.max,'mode':st.mode,'raw_vals':lambda values: values}
    metric_func=metric_dict[metric]

    ## Extract the vertex points of the polygon
    polygon_vert=list(zip(x['vertex_x_pixel'],x['vertex_y_pixel']))

    ## Bounding rectangle around the polygon (integer pixel coordinates)
    x_min=int(x['vertex_x_pixel'].min())
    x_max=int(x['vertex_x_pixel'].max())

    y_min=int(x['vertex_y_pixel'].min())
    y_max=int(x['vertex_y_pixel'].max())

    ## Enumerate every pixel of the rectangle as an (x, y) point
    nx,ny=x_max-x_min+1,y_max-y_min+1
    x_,y_=np.meshgrid(np.arange(nx), np.arange(ny))
    points=np.vstack((x_.flatten()+x_min, y_.flatten()+y_min)).T

    ## Some polygons output by Baysor are (partly) outside of the image
    #  => clip the rectangle to the image on all four sides. Clipping the lower
    #  bounds matters: a negative slice start would wrap around and read pixels
    #  from the opposite edge of the image.
    y_lo=max(y_min,0)
    x_lo=max(x_min,0)
    #  The upper bound never drops below the lower bound, so a rectangle that is
    #  completely outside the image yields an empty window (and an empty mask)
    y_hi=max(min(img_array.shape[0],y_min+ny),y_lo)
    x_hi=max(min(img_array.shape[1],x_min+nx),x_lo)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ## Boolean mask over the rectangle: True = pixel inside of the polygon
        mask_=path.Path(polygon_vert).contains_points(points).reshape(ny,nx)

        ## Subset the image to the clipped rectangle (saves memory) and cut the
        #  mask down to exactly the same window so both shapes agree
        sub_img_array=img_array[y_lo:y_hi,x_lo:x_hi]
        mask_=mask_[y_lo-y_min:y_hi-y_min,x_lo-x_min:x_hi-x_min]

        ## Calculate metric of the polygon pixels (rectangle masked with boolean mask).
        #  Degenerate polygons (for some cells baysor outputs only 1 unique x-y pair)
        #  contain no pixels; metrics that raise ValueError on empty input give NaN
        try:
            nucleus_polygon_metric=metric_func(sub_img_array[mask_].flatten())
        except ValueError:
            nucleus_polygon_metric=np.nan

    return nucleus_polygon_metric

# Save Baysor segmentation outputs in dicts (runtime: ~4-5 hours)
- Cell boundaries
- Segmentation statistics
- Calculate intensity metrics of cellular masks (mean,median,mode,max) for each sample

In [None]:
## Batch-process the Baysor outputs of every selected sample:
#  for each sample load the slide image, then for each Baysor model build the
#  cell boundary table, compute per-cell pixel metrics and pickle the result.
data_dir="/data/gpfs/projects/punim2121/Atherosclerosis/xenium_data/"#processed_data/cell_segmentation"

## Drop scratch folders that start with "._"
panel_dir=[f for f in os.listdir(data_dir) if ('Panel' in f and '._' not in f)]
panel_dir.sort()

## Create list of samples to loop over (as script sometimes exits due to kernel death, this cell may need to be run split between the samples)
#  => this list controls which samples are being processed
panels=['Panel1','Panel2']
samples=['P1_H','P2_H','P3_H','P4_H','P1_D','P2_D','P3_D','P4_D']

#panels=['Panel1']
#samples=['P3_D']
samples_to_loop_over=['_'.join([panel,sample]) for panel in panels for sample in samples]

## NOTE: not used anywhere in this cell; kept in case other cells rely on it
segmend_dict={'no_segmentation':{},
             '10x':{'expansion_sizes':[0]},
             'cellpose':{'expansion_sizes':[4,6,10],
                         'modes':['cyto','nucleus']}
             }

## Select one slide_type ('mip'/'focus'/'z_stack') file to load
slide_type='mip'

## Pixel metrics computed for every cell polygon (also the output column names)
metric_list=['median','mean','max','mode','raw_vals']


for panel in panel_dir:
    ## Separate name for the panel path: the loop must not rebind panel_dir,
    #  which is the list being iterated over
    panel_path=os.path.join(data_dir,panel)

    ## Loop over all samples in a batch
    for sample_name in os.listdir(panel_path):
        sample_dir=os.path.join(panel_path,sample_name)

        if not os.path.isdir(sample_dir):
            continue

        ## Extract Panel_Sample_name as string; the sample id is the third
        #  '__'-separated field of the folder name. Folders that do not follow
        #  this scheme are skipped instead of aborting the run with an IndexError
        name_parts=sample_name.split('__')
        if len(name_parts)<3:
            logging.warning('Skipping %s: folder name has fewer than three "__"-separated fields',sample_dir)
            continue
        panel_sample_name='_'.join([panel.split('_')[-1],name_parts[2]])

        ## Check if the sample name is in the samples that should be processed => if yes process them
        if panel_sample_name not in samples_to_loop_over:
            continue
        logging.info(panel_sample_name)

        ## Load DAPI fluorescent-stained slide image
        #  (sample_dir is already a full path, so os.path.join inside
        #  read_ome_tiff uses it as is when data_dir is absolute)
        img_array,metadata,xml_metadata=read_ome_tiff(slide_type,data_dir,sample_dir)
        logging.info('Slide loaded')

        ## Create baysor output folder path
        baysor_out_fold=os.path.join(data_dir,'processed_data/baysor_output',panel_sample_name)

        ## Get only name of folders to loop over in baysor_output folder
        bays_model_dirlist=[filename for filename in os.listdir(baysor_out_fold)
                            if os.path.isdir(os.path.join(baysor_out_fold,filename))]

        ### Loop over baysor segmentation models and extract cell mask polygon metrics
        for bays_model_name in bays_model_dirlist:
            logging.info(bays_model_name)
            model_dir=os.path.join(baysor_out_fold,bays_model_name)

            ## Read in segmentation results + cell statistics output by Baysor (cells stats, area of cells)
            seg=pd.read_csv(os.path.join(model_dir,'segmentation.csv'))

            ## seg_stats is loaded (which also checks that the file is readable)
            #  but currently not stored in the pickle
            seg_stats=pd.read_csv(os.path.join(model_dir,'segmentation_cell_stats.csv'))

            polygon_fn=os.path.join(model_dir,'segmentation_polygons.json')
            cell_bound=create_cell_bound_df(polygon_fn,seg)

            ## Extract pixels of cell mask polygons (identified by Baysor model) and return a dataframe
            #  with some metrics of these pixel intensities (one column per metric).
            #  NOTE: each metric is a separate pass over all polygons, so every polygon
            #  mask is rebuilt once per metric - this dominates the runtime.
            metric_df_list=[cell_bound.groupby('cell_id').apply(return_polygon_pixel_metric,img_array=img_array,metric=metric)
                            for metric in metric_list]
            nucleus_polygon_pixel_metrics=pd.concat(metric_df_list,axis=1)
            nucleus_polygon_pixel_metrics.columns=metric_list
            logging.info('Cellular polygon metrics calculated')

            ## Collect the processed data as a dictionary (fresh dict for every model)
            sample_dict={'cell_bound':cell_bound,
                         'cell_polygon_pixel_metrics':nucleus_polygon_pixel_metrics,
                         'slide_metadata':metadata}

            proc_dir=os.path.join(data_dir,'processed_data/true_cell_filtering/baysor',panel_sample_name,bays_model_name)

            if not os.path.isdir(proc_dir):
                os.makedirs(proc_dir)
                logging.info(f'Created directory: {proc_dir}')

            ## Write the pickle through a context manager so the file is flushed
            #  and closed before the next model is processed
            fpath=os.path.join(proc_dir,panel_sample_name+'.pickle')
            with open(fpath,"wb") as pickle_file:
                pickle.dump(sample_dict,pickle_file)
            logging.info(f'{panel_sample_name}-{bays_model_name} saved as pickle\n')
        logging.info('=====================')

logging.info('Done!')