In [None]:
## Imports
import json
from pathlib import Path
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from skimage.draw import polygon_perimeter
from skimage import draw

In [None]:
## Parameters

# Root of the analysis folder. The batch cell reads 'spot_detection/' and
# 'axes_enrichment/annotations/' below it and writes to
# 'axes_enrichment/results/'.
path_base = Path(r"PASTE-FULL-PATH-TO-ANALYSIS-FOLDER")

# Default image shape (rows, cols) in pixels; the batch cell uses the
# annotation's 'bbox' extent instead when one is present.
image_size = (1024, 1024)

# dist_max: spots farther than this many pixels from the axis are discarded.
# bin_step: width (pixels along the axis) of the bins used to sum RNA counts.
dist_max, bin_step = 90, 10


In [None]:
## Batch process folder
#
# For every '<name>__spots.csv' in 'spot_detection', load the matching
# '<name>.json' LineString annotation from 'axes_enrichment/annotations'.
# Then:
#   - rasterize the axis and compute a signed distance along it;
#   - assign every spot within `dist_max` pixels of the axis to its nearest
#     axis pixel;
#   - save per-pixel counts, binned counts and a summary figure to
#     'axes_enrichment/results'.


def _rasterize_polyline(vertices):
    """Rasterize a polyline given as an (n >= 2, >= 2) integer array of [x, y] vertices.

    Returns two 1-D arrays (rows, cols) with the pixels of all segments in
    drawing order. Pixels shared by consecutive segments appear twice.
    """
    rr_parts, cc_parts = [], []
    for i in range(vertices.shape[0] - 1):
        # skimage.draw.line expects (row, col) = (y, x)
        rr, cc = draw.line(vertices[i, 1], vertices[i, 0],
                           vertices[i + 1, 1], vertices[i + 1, 0])
        rr_parts.append(rr)
        cc_parts.append(cc)
    # A single concatenate instead of np.append in the loop (avoids O(n^2) copying)
    return np.concatenate(rr_parts), np.concatenate(cc_parts)


def _signed_distance_along_line(line_sampled):
    """Signed arc length (pixels) of each [x, y] pixel of an ordered polyline.

    Distance 0 is placed at the leftmost pixel (smallest x), presumably the
    turning point of the axis. Distances increase from the end of the line
    with the smaller y (higher in the image) towards the other end; when both
    ends have the same y they increase from the first towards the last pixel.
    """
    steps = np.diff(line_sampled, axis=0)
    dist = np.append(0, np.cumsum(np.sqrt((steps ** 2).sum(axis=1))))
    dist = dist - dist[np.argmin(line_sampled[:, 0])]
    if line_sampled[-1, 1] < line_sampled[0, 1]:
        dist = -dist
    return dist


def _bin_counts(dist, counts, step):
    """Sum `counts` in consecutive bins of width `step` covering all of `dist`.

    Bin edges are integer multiples of `step`, so bins are aligned to 0.
    Unlike a range that stops before min/max, the edges extend past both
    extremes, so no data point is dropped.

    Returns (left_edges, sums), both of length n_bins.
    """
    k_lo = int(np.floor(dist.min() / step))
    k_hi = int(np.floor(dist.max() / step)) + 1
    edges = np.arange(k_lo, k_hi + 1) * step
    # digitize gives i with edges[i-1] <= d < edges[i]; every point lands in 1..len(edges)-1
    bin_idx = np.digitize(dist, edges)
    sums = np.bincount(bin_idx, weights=counts, minlength=len(edges))[1:len(edges)]
    return edges[:-1], sums


# Different folders
path_json = path_base / 'axes_enrichment' / 'annotations'
path_spots = path_base / 'spot_detection'

path_save = path_base / 'axes_enrichment' / 'results'
path_save.mkdir(parents=True, exist_ok=True)

# Loop over all spot detection results (sorted for a reproducible order)
for f_spots in sorted(path_spots.glob('*__spots.csv')):

    print(f'>>> Processing spot detection file {f_spots}')
    name_spots = f_spots.name

    # Annotation file
    name_json = name_spots.replace('__spots.csv', '.json')
    name_json_full = (path_json / name_json).resolve()

    if not name_json_full.is_file():
        print(f'Annotation does not exist: {name_json_full}')
        continue

    with open(name_json_full, encoding='utf-8-sig') as fh:
        data_json = json.load(fh)

    # Image shape (rows, cols) for THIS file only. The `image_size` parameter
    # is not modified, so files without a bbox keep using the default.
    img_shape = image_size
    if 'bbox' in data_json:
        # GeoJSON bbox is [min_x, min_y, max_x, max_y]. Rows span y, columns span x.
        # NOTE(review): the bbox offset (min_x, min_y) is not subtracted from the
        # coordinates; this assumes the annotation starts at the image origin.
        bbox = data_json['bbox']
        img_shape = (int(bbox[3] - bbox[1] + 1), int(bbox[2] - bbox[0] + 1))
    else:
        print('Image size not provided in geojson file.')

    # Exactly one annotation per file is supported
    n_features = len(data_json['features'])
    if n_features != 1:
        print(f'Annotation file CAN ONLY contain 1 annotation, not {n_features}.')
        continue

    annot_type = data_json['features'][0]['geometry']['type']
    if annot_type not in ['LineString']:
        print(f'Annotation type {annot_type} not supported .')
        continue

    # Axis vertices as [x, y] pairs, cast to int (fractional parts dropped)
    line = np.squeeze(np.asarray(data_json['features'][0]['geometry']['coordinates'])).astype('int')
    if line.ndim != 2 or line.shape[0] < 2:
        print('Axis annotation needs at least 2 vertices.')
        continue

    # >>>>> Process annotation

    # >> Rasterize the axis
    rr_all, cc_all = _rasterize_polyline(line)
    if (rr_all.min() < 0 or cc_all.min() < 0
            or rr_all.max() >= img_shape[0] or cc_all.max() >= img_shape[1]):
        print(f'Axis annotation extends beyond the image of size {img_shape}.')
        continue

    # >> Remove duplicate pixels, keeping the first occurrence so the drawing order is preserved
    line_sampled = np.column_stack((cc_all, rr_all))
    _, idx = np.unique(line_sampled, axis=0, return_index=True)
    line_sampled = line_sampled[np.sort(idx)]

    # >> Distance of every image pixel to the axis, and (row, col) of the nearest axis pixel
    img_mask = np.zeros(img_shape, dtype=bool)
    img_mask[rr_all, cc_all] = True
    edt, inds = ndimage.distance_transform_edt(np.logical_not(img_mask), return_indices=True)

    # >> Signed distance along the axis
    dist_orig = _signed_distance_along_line(line_sampled)

    # >>>> Read spot detection file
    # NOTE(review): columns 1 and 2 are used as integer (row, col) pixel
    # positions inside the image — confirm against the spot detector's output.
    spots = pd.read_csv(f_spots, sep=',').to_numpy()

    # >> Distance of each spot to the axis; keep spots within dist_max
    edt_spots = edt[spots[:, 1], spots[:, 2]]
    ind_spots_keep = edt_spots <= dist_max

    # >> Nearest axis pixel of each kept spot, as [x, y] to match line_sampled
    inds_spots_ax0 = inds[0, spots[:, 1], spots[:, 2]]
    inds_spots_ax1 = inds[1, spots[:, 1], spots[:, 2]]
    inds_spots = np.column_stack((inds_spots_ax1, inds_spots_ax0))[ind_spots_keep, :]

    unique_rows, counters = np.unique(inds_spots, axis=0, return_counts=True)

    # >>>>> Combine results in data-frame
    df_line = pd.DataFrame(data=line_sampled, columns=["ax1", "ax2"])
    df_line['dist_orig'] = dist_orig
    df_spots = pd.DataFrame(data=unique_rows, columns=["ax1", "ax2"])
    df_spots['n_rna'] = counters

    # >> Each axis pixel gets the number of spots assigned to it (0 if none) & save
    df_results = pd.merge(df_line, df_spots, how='left', on=['ax1', 'ax2'])
    df_results["n_rna"] = df_results["n_rna"].fillna(0)

    name_save = path_save / name_spots.replace('__spots.csv', '__axes_enrich.csv')
    df_results.to_csv(name_save, index=False)

    # >>>> Binning of data: sum counts in bins of width bin_step covering the whole axis
    dist_bin, n_rna_bin = _bin_counts(df_results['dist_orig'].to_numpy(),
                                      df_results['n_rna'].to_numpy(),
                                      bin_step)
    df_bin = pd.DataFrame({'dist_bin': dist_bin,
                           'n_rna_bin': n_rna_bin})

    name_save = path_save / name_spots.replace('__spots.csv', '__axes_enrich__binned.csv')
    df_bin.to_csv(name_save, index=False)

    # >>>  Plot results
    fig, ax = plt.subplots(2, 2)
    fig.set_size_inches((10, 10))

    ax[0][0].imshow(edt, cmap="hot")
    ax[0][0].get_xaxis().set_visible(False)
    ax[0][0].get_yaxis().set_visible(False)
    ax[0][0].set_title('Axes and distance from axes')
    ax[0][0].plot(line[:, 0], line[:, 1], color='b')

    ax[0][1].plot(line[:, 0], line[:, 1], color='b')
    ax[0][1].set_title('Spots (green-kept, red-removed)')
    ax[0][1].scatter(spots[ind_spots_keep, 2], spots[ind_spots_keep, 1], color='g', s=1)
    ax[0][1].scatter(spots[~ind_spots_keep, 2], spots[~ind_spots_keep, 1], color='r', s=1)
    ax[0][1].invert_yaxis()
    ax[0][1].set_aspect('equal', 'box')

    ax[1][0].hist(edt_spots, 50, density=True, facecolor='g', alpha=0.75)
    ax[1][0].set_title('Hist of distance from axis')
    ax[1][0].set_ylabel('Density')  # density=True -> normalized, not raw frequency
    ax[1][0].set_xlabel('Distance [pix]')

    # Per-pixel and binned counts, both explicitly on the bottom-right axes
    sns.lineplot(x="dist_orig", y="n_rna", data=df_results, ax=ax[1][1], label='per pixel')
    sns.lineplot(x="dist_bin", y="n_rna_bin", data=df_bin, ax=ax[1][1], label=f'binned ({bin_step} pix)')
    ax[1][1].set_xlabel('Distance along axis [pix]')
    ax[1][1].set_ylabel('Number of RNAs')
    ax[1][1].legend()

    fig.tight_layout()

    name_save = path_save / name_spots.replace('__spots.csv', '__axes_enrich.png')
    fig.savefig(name_save, dpi=300)
    plt.close(fig)  # avoid accumulating open figures across the batch