In [1]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm
import dask.array as da
import pandas as pd
import numpy as np
import pickle as pkl
from tifffile import imread
from skimage.color import gray2rgb

from pathlib import Path
from fastai.learner import load_learner

from utils import extract_regionprops, mask_from_df, suppress_by_iou_mitos, compute_iou_array, add_soma_data, label_func

In [None]:
# ---------------------------------------------------------------------------
# Notebook configuration: everything a reader may want to tune lives here.
# ---------------------------------------------------------------------------

# Single root for all inputs/outputs, so moving the data only needs one edit.
data_root = '/mnt/d/data_analysis/2025_Birder_mito'

im_path = f'{data_root}/88EM87C 25x25_ashlar.ome.tif'
output_dir = f'{data_root}/C_00_analysis'
prefix_save = '88EM87C_'      # prefix of every file read from / written to output_dir
mitos_sub_dir = 'mitos_sam'   # sub-directory of output_dir holding the per-axon mito mask pickles

classifier_path = Path(f'{data_root}/250416_classifier.pkl')

# Resolution indices; mitos_res is the component read from the image store,
# presumably pyramid levels where each level halves the size — TODO confirm.
axon_res = 3
mitos_res = 0
row_offset = 0 # used if df was created as a test on a smaller image
col_offset = 0 # used if df was created as a test on a smaller image

# details of selecting mito masks
# coordinates are converted between the two grids with a factor of 2**res_adjust
res_adjust = axon_res - mitos_res

# region properties requested from extract_regionprops for every mito mask
properties = ['area', 'area_convex', 'area_filled', 'euler_number', 'image','bbox',
        'eccentricity', 'solidity', 'centroid', 'major_axis_length', 'minor_axis_length']

min_iou = 0.85 # masks are kept only if predicted_iou is above this value
iou_threshold = 0.2 # used for suppressing overlapping mito masks
min_area = 200 # masks are kept only if area is above this value
max_area = 2000 # masks are kept only if area is below this value
max_eccentricity = 0.9 # objects are kept only if eccentricity is below this value

# parameters for the classifier
# half-width of the square crop taken around each mito centroid
pad = 50

## Get an image

In [3]:
# Open the TIFF as a zarr store instead of reading the pixels eagerly,
# then wrap the component selected by mitos_res in a dask array.
store = imread(im_path, aszarr=True)
im = da.from_zarr(store, component=mitos_res)
print(f'Image shape: {im.shape}')

Image shape: (22084, 29266)


## Get main df

In [4]:
# Read the pickled axon table; it must provide the 'label', 'inside_image'
# and 'inside_bbox-*' columns used further down in this notebook.
df_name = prefix_save + 'axons.pkl'
df_path = os.path.join(output_dir, df_name)
df_cells = pd.read_pickle(df_path)
print(f'Number of axons: {len(df_cells)}')
df_cells.head()

Number of axons: 548


Unnamed: 0,index,area,bbox,predicted_iou,point_coords,stability_score,crop_box,tile_row_start,tile_row_end,tile_col_start,...,inside_centroid-1,inside_area,inside_eccentricity,inside_major_axis_length,inside_minor_axis_length,inside_bbox-0,inside_bbox-1,inside_bbox-2,inside_bbox-3,inside_label
0,1,8317,"[740.0, 671.0, 157.0, 118.0]",0.972656,"[[771.34375, 699.15625]]",0.961874,"[639.0, 639.0, 385.0, 385.0]",0,1024,0,...,812.0,5262.0,0.780537,106.290709,66.443329,694.0,766.0,767.0,873.0,1
1,14,722,"[921.0, 953.0, 29.0, 45.0]",0.914062,"[[939.78125, 963.84375]]",0.943472,"[639.0, 639.0, 385.0, 385.0]",0,1024,0,...,934.0,282.0,0.929873,32.778532,12.058577,961.0,928.0,992.0,943.0,2
2,54,8935,"[744.0, 306.0, 136.0, 134.0]",0.964844,"[[867.59375, 346.03125]]",0.964152,"[639.0, 213.0, 385.0, 387.0]",0,1024,0,...,814.0,4639.0,0.504062,82.751126,71.469438,330.0,775.0,409.0,854.0,3
3,80,7428,"[139.0, 69.0, 134.0, 137.0]",0.976562,"[[157.21875, 84.65625]]",0.971724,"[0.0, 0.0, 387.0, 387.0]",0,1024,0,...,205.0,6586.0,0.501551,99.604531,86.170657,89.0,159.0,187.0,250.0,4
4,85,10627,"[68.0, 197.0, 126.0, 145.0]",0.953125,"[[84.65625, 253.96875]]",0.972049,"[0.0, 0.0, 387.0, 387.0]",0,1024,0,...,139.0,2564.0,0.644323,65.395412,50.011351,244.0,112.0,302.0,169.0,5


### Get SAM masks

In [5]:
# Collect the per-axon mito mask pickles.
masks_dir = os.path.join(output_dir,mitos_sub_dir)
# os.listdir returns entries in arbitrary order; sort them so that the mito labels,
# which are assigned from the processing order, are reproducible between runs.
file_list = sorted(x for x in os.listdir(masks_dir) if x.endswith('.pkl'))
print(f'Found {len(file_list)} mito masks in {masks_dir}')

Found 546 mito masks in /mnt/d/data_analysis/2025_Birder_mito/C_00_analysis/mitos_sam


In [6]:
# Build df_mitos: one row per mitochondrion candidate that survives all filters,
# expressed in the coordinate system of the full image.

# Scale factor between the axon grid and the mito grid (loop invariant).
res_scale = 2**res_adjust

df_list = []

for file_name in tqdm(file_list):

    print(f'Processing {file_name}')

    # load suggested mito masks; the context manager makes sure the file handle is closed
    # NOTE: pickle can execute arbitrary code on load — only read files produced by this pipeline
    file_path = os.path.join(masks_dir, file_name)
    with open(file_path, 'rb') as mask_file:
        masks_org = pkl.load(mask_file)

    # get info about the cell: the axon label is the second '_'-separated token of the file name
    cell_ind = int(file_name.split('_')[1])
    row = df_cells.loc[df_cells['label'] == cell_ind]
    inside_mask = row['inside_image'].values[0]

    # discard mitos not inside axons: the first point of 'point_coords' is scaled down
    # to the axon grid and looked up in the inside mask as [coords[1], coords[0]]
    masks_sel = [
        x for x in masks_org
        if inside_mask[int(x['point_coords'][0][1]/res_scale), int(x['point_coords'][0][0]/res_scale)]
    ]

    df = pd.DataFrame(masks_sel)
    if len(df) == 0:
        print('No masks found')
        continue

    # get mitos data
    # .copy() gives an independent frame, so adding the 'origin' column below is unambiguous
    df = df[((df['area'] > min_area) & (df['area'] < max_area) & (df['predicted_iou'] > min_iou))].copy()
    print(f'Number of masks: {len(df)}')

    props_list = []
    for ind, row_mito in df.iterrows():
        props = extract_regionprops(row_mito,properties,small_size = min_area)
        # keep origin of the data
        props['origin'] = ind
        props_list.append(props)

    if len(props_list) == 0:
        print('No masks found')
        continue

    props_all = pd.concat(props_list, ignore_index=True)

    # the measured area must not clash with the 'area' column of the mask table
    props_all = props_all.rename(columns={'area': 'mito_area'})

    # Concatenate with the original DataFrame
    df['origin'] = df.index
    df_all = pd.merge(df, props_all,left_on = 'origin', right_on= 'origin', how = 'right').reset_index()

    # drop objects by measured categories
    df_all = df_all[(df_all['eccentricity'] < max_eccentricity)]

    print(f'Number of separate objects: {len(df_all)}')

    # nothing left to de-duplicate or to append for this axon
    if df_all.empty:
        print('No masks found')
        continue

    # suppress by iou
    df_all = df_all.reset_index(drop=True)
    df_all = suppress_by_iou_mitos(df_all, iou_threshold=iou_threshold)
    # .copy() so the coordinate shifts below modify an independent frame
    df_all = df_all.loc[df_all.keep==1,:].copy()
    print(f'Number of self standing objects: {len(df_all)}')

    # shift into the main coordinate system:
    # origin of the axon's inside bbox (plus optional offsets), scaled to the mito grid
    row_start = (int(row['inside_bbox-0'].iloc[0]) + row_offset)*res_scale
    col_start = (int(row['inside_bbox-1'].iloc[0]) + col_offset)*res_scale

    for col_name in ('bbox-0', 'bbox-2', 'centroid-0'):
        df_all[col_name] = df_all[col_name] + row_start
    for col_name in ('bbox-1', 'bbox-3', 'centroid-1'):
        df_all[col_name] = df_all[col_name] + col_start

    df_all['cell_label'] = cell_ind

    # drop segmentation in the local coord system
    # and the area of the original segmented region
    df_all = df_all.drop(columns=['segmentation', 'area'])

    df_list.append(df_all)

# pd.concat would fail with a generic message on an empty list; say what actually happened
if not df_list:
    raise ValueError(f'No mito objects passed the filters for any of the {len(file_list)} mask files in {masks_dir}')

df_mitos = pd.concat(df_list, ignore_index=True)
# labels are 1-based and follow the processing order
df_mitos['label'] = df_mitos.index + 1

  0%|          | 0/546 [00:00<?, ?it/s]

Processing 88EM87C_000003_mito.pkl
Number of masks: 12
Number of separate objects: 11
Number of self standing objects: 11
Processing 88EM87C_000004_mito.pkl
Number of masks: 17
Number of separate objects: 12
Number of self standing objects: 10
Processing 88EM87C_000005_mito.pkl
Number of masks: 14
Number of separate objects: 12
Number of self standing objects: 9
Processing 88EM87C_000006_mito.pkl
Number of masks: 2
Number of separate objects: 2
Number of self standing objects: 2
Processing 88EM87C_000007_mito.pkl
Number of masks: 4
Number of separate objects: 4
Number of self standing objects: 3
Processing 88EM87C_000008_mito.pkl
Number of masks: 32
Number of separate objects: 28
Number of self standing objects: 22
Processing 88EM87C_000009_mito.pkl
Number of masks: 3
Number of separate objects: 3
Number of self standing objects: 2
Processing 88EM87C_000010_mito.pkl
Number of masks: 0
No masks found
Processing 88EM87C_000011_mito.pkl
Number of masks: 22
Number of separate objects: 19
N

## Classify objects

In [7]:
# Restore the exported fastai learner. load_learner unpickles the file,
# so only point classifier_path at files you trust.
learn = load_learner(classifier_path)

# Move the model to the GPU and report the device its parameters live on.
learn.model.cuda()
model_device = next(learn.model.parameters()).device
print(model_device)  # Should print 'cuda:0'

If you only need to load model weights and optimizer state, use the safe `Learner.load` instead.
  warn("load_learner` uses Python's insecure pickle module, which can execute malicious arbitrary code when loading. Only load files you trust.\nIf you only need to load model weights and optimizer state, use the safe `Learner.load` instead.")


cuda:0


In [None]:
# Classify every mito candidate from a square crop of the image around its centroid.
# `pad` (half-width of the crop) comes from the configuration cell; it is deliberately
# not redefined here so that changing the configuration actually takes effect.

def _fit_window(start, stop, limit):
    """Shift the window [start, stop) so that it lies inside [0, limit].

    The window length is preserved whenever it fits; a window longer than
    ``limit`` is truncated to [0, limit]. Windows already inside are returned unchanged.
    """
    if start < 0:
        start, stop = 0, stop - start
    if stop > limit:
        start, stop = max(start - (stop - limit), 0), limit
    return start, stop


n_rows, n_cols = im.shape[:2]

# Collect results in plain lists and assign them once after the loop,
# instead of rewriting the DataFrame (and converting a whole column) on every iteration.
predictions = []
prediction_probs = []

for ind, row in tqdm(df_mitos.iterrows(), total=len(df_mitos)):

    # A negative slice start would be interpreted relative to the end of the axis and
    # yield an empty crop, so windows touching the border are shifted back into the image.
    # NOTE(review): shifting keeps the crop size but off-centres the object — confirm this
    # is preferable to padding for the classifier.
    row_start, row_stop = _fit_window(int(row['centroid-0'] - pad), int(row['centroid-0'] + pad), n_rows)
    col_start, col_stop = _fit_window(int(row['centroid-1'] - pad), int(row['centroid-1'] + pad), n_cols)

    im_mito = gray2rgb(im[row_start:row_stop, col_start:col_stop]).compute()

    cat,_,probs = learn.predict(im_mito)
    predictions.append(str(cat))
    # store probabilities as plain lists of floats
    prediction_probs.append(np.array(probs, dtype=float).tolist())

df_mitos['prediction'] = predictions
df_mitos['prediction_prob'] = pd.Series(prediction_probs, index=df_mitos.index, dtype=object)

# save the data frame
df_name = f'{prefix_save}mitos.pkl'
df_mitos.to_pickle(os.path.join(output_dir,df_name))