# Script to get Cellpose masks from the FOV image
TODO: check that the .gpl export works with the new version of stim_select_cp.ipynb (if it does, delete the old code in the second part of the notebook).

In [None]:
import numpy as np
import os
import napari
import yaml

from photostim_deve.image_analysis.plot import plot_motcorr_comparison
from photostim_deve.image_analysis.io import get_all_fov_image
from photostim_deve.image_analysis.segment import segment_fov_cpsam, get_cent_from_seg

from photostim_deve.control_exp.io import write_mp_file_cp, write_gpl_file_cp
from photostim_deve.control_exp.utils import remove_edge_masks
from photostim_deve.control_exp.plot import plot_stim, plot_segmentation_overlay

In [None]:
# Notebook-level parameter: subject (mouse) ID. It selects the data directory and is
# passed as `mouse_str` to the exported stim files.
subject = "jm065"


In [None]:
# Load the analysis configuration from the YAML file next to the notebook.
# TODO: remove unnecessary parameters (also from the yaml file)
with open("stim_select_cp_config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Fail early with one readable message. Otherwise an empty file (safe_load -> None)
# gives an opaque TypeError, and a KeyError below would only name the first missing key.
_required_cfg_keys = (
    'session_type', 'fov_imsize_onedim', 's2p_imsize_onedim', 'n_stim_cell',
    'session_reg_idx', 'force_recompute', 'run_motcorr', 'nimg_init', 'sat_perc',
    'use_seg', 'manually_curate', 'seed', 'edge_excl', 'win_data_proc_path',
)
if not isinstance(cfg, dict):
    raise ValueError(
        f"stim_select_cp_config.yaml must contain a mapping, got {type(cfg).__name__}"
    )
_missing_cfg_keys = [key for key in _required_cfg_keys if key not in cfg]
if _missing_cfg_keys:
    raise KeyError(f"stim_select_cp_config.yaml is missing keys: {_missing_cfg_keys}")

session_type = cfg['session_type']
fov_imsize_onedim = cfg['fov_imsize_onedim']
s2p_imsize_onedim = cfg['s2p_imsize_onedim']
n_stim_cell = cfg['n_stim_cell']
session_reg_idx = cfg['session_reg_idx']
force_recompute = cfg['force_recompute']
run_motcorr = cfg['run_motcorr']
nimg_init = cfg['nimg_init']
sat_perc = cfg['sat_perc']
use_seg = cfg['use_seg']
manually_curate = cfg['manually_curate']
seed = cfg['seed']
edge_excl = cfg['edge_excl']
win_data_proc_path = cfg['win_data_proc_path']


In [None]:
# Both images are assumed square, so each one-dimensional size is repeated for both axes.
fov_imsize = (fov_imsize_onedim,) * 2  # FOV image size in pixels
s2p_imsize = (s2p_imsize_onedim,) * 2  # image size used for Suite2p processing

In [None]:
# Resolve the subject directory: relative 'data_proc' by default, or under the
# configured `win_data_proc_path` root when one is set.
data_proc_root = 'data_proc' if win_data_proc_path is None else win_data_proc_path
subject_path = os.path.join(data_proc_root, 'jm', subject)

# Candidate sessions are the sub-directories whose name contains `session_type`.
# They are sorted so that `session_reg_idx` refers to a stable position.
all_session_path = sorted(
    os.path.join(subject_path, d)
    for d in os.listdir(subject_path)
    if os.path.isdir(os.path.join(subject_path, d)) and session_type in d
)
if not all_session_path:
    raise FileNotFoundError(
        f"No session directories containing '{session_type}' found in {subject_path}"
    )
try:
    session_path = all_session_path[session_reg_idx]
except IndexError as err:
    raise IndexError(
        f"session_reg_idx={session_reg_idx} is out of range for "
        f"{len(all_session_path)} session(s) in {subject_path}"
    ) from err

save_path = os.path.join(session_path, 'stim_select_cp')
save_path_fig = os.path.join(save_path, 'figures')
os.makedirs(save_path_fig, exist_ok=True)  # also creates save_path (its parent)

# FOV-to-Suite2p pixel scale factor; both images have the same aspect ratio.
fov_s2p_px_fact = fov_imsize[0] / s2p_imsize[0]

In [None]:
# Load the FOV images of the selected session. Every option comes from the config;
# see get_all_fov_image for what each one controls.
all_fov_image = get_all_fov_image(
    subject_path,
    session_type=session_type,
    session_reg_idx=session_reg_idx,
    run_motcorr=run_motcorr,
    fov_imsize=fov_imsize,
    nimg_init=nimg_init,
    force_recompute=force_recompute,
)

In [None]:
# Show the output directory (Jupyter echoes the cell's last expression).
save_path

In [None]:
# Motion-correction comparison figure of the loaded FOV images, cropped by (64, 64)
# and saved into the figures directory.
plot_motcorr_comparison(
    all_fov_image,
    sat_perc=sat_perc,
    save_path=save_path_fig,
    crop=(64, 64),
)

In [None]:
# Segment only the 1100nm channel with segment_fov_cpsam (results cached under save_path
# unless force_recompute), then apply remove_edge_masks with the configured edge_excl.
seg_channel = '1100nm'
all_fov_image_seg = segment_fov_cpsam(
    all_fov_image,
    flow_threshold=0.4,
    cellprob_threshold=0.0,
    force_recompute=force_recompute,
    save_path=save_path,
    segment_only=[seg_channel],
)

fov_image = all_fov_image_seg[seg_channel]
seg = all_fov_image_seg['1100nm_seg']
seg_cur = remove_edge_masks(fov_image, seg, edge_excl=edge_excl)

In [None]:
if manually_curate:
    # Show the mean FOV image (contrast capped at the sat_perc percentile) together with
    # the edge-filtered segmentation so the labels layer can be curated by hand.
    # NOTE(review): the following cells read `seg_cur` directly; this assumes napari
    # edits that array in place rather than a copy. Confirm for the napari version used.
    viewer = napari.Viewer()
    fov_clim_max = np.percentile(fov_image, sat_perc)
    viewer.add_image(
        fov_image,
        name='Mean FOV Image',
        colormap='gray',
        contrast_limits=[0, fov_clim_max],
    )
    viewer.add_labels(seg_cur, name='Cellpose Segmentation')
    viewer.show()

    # Deliberately halt "run all" so later cells only run after curation is finished.
    raise RuntimeError("Manual curation required! It is sufficent to do the curation of the labels layer and close napari afterwards.\nThe next cell will visualise the results of curation etc.")

In [None]:
# Save overlays of the raw and the curated segmentation on the FOV image.
for overlay_labels, overlay_title, overlay_fname in (
    (seg, 'Cellpose Segmentation', 'cp_seg.png'),
    (seg_cur, 'Curated Cellpose Segmentation', 'cp_seg_cur.png'),
):
    plot_segmentation_overlay(
        fov_image,
        overlay_labels,
        sat_perc=sat_perc,
        title=overlay_title,
        save_path=os.path.join(save_path_fig, overlay_fname),
    )

In [None]:
# Centroids of the curated masks, stacked with columns (y, x).
x_meds, y_meds = get_cent_from_seg(seg_cur)
meds = np.column_stack((y_meds, x_meds))

# Pick n_stim_cell distinct cells at random as stimulation targets. The generator is
# seeded with cfg['seed'] so the selection is reproducible (previously the seed was
# loaded but never used, and the global RNG was unseeded).
if n_stim_cell > len(meds):
    raise ValueError(
        f"n_stim_cell={n_stim_cell} exceeds the {len(meds)} segmented cells available"
    )
rng = np.random.default_rng(seed)
inds = rng.choice(len(meds), n_stim_cell, replace=False)
meds_stim = meds[inds]

In [None]:
# Save two figures of the selected stim points on the FOV image: one together with
# all curated centroids, one with the stim points only.
plot_stim(
    fov_image, meds_stim, meds=meds, sat_perc=sat_perc,
    title='Curated CP medians and selected stim points (orange)',
    save_path=os.path.join(save_path_fig, 'meds_all.png'),
)
plot_stim(
    fov_image, meds_stim, sat_perc=sat_perc,
    title='Selected CP medians (stim points)',  # typo fixed (was 'Slected')
    save_path=os.path.join(save_path_fig, 'meds_stim.png'),
)

In [None]:
# Readjust coordinates from FOV pixels to the Suite2p image grid (s2p_imsize) using the
# pixel ratio fov_s2p_px_fact. Division is out-of-place so that integer centroid arrays
# are upcast to float instead of raising a casting error, as an in-place `/=` would.
# NOTE: this cell is not idempotent. Re-running it without re-running the selection
# cell rescales the coordinates a second time.
if fov_s2p_px_fact != 1.0:
    meds_stim = meds_stim / fov_s2p_px_fact
    meds = meds / fov_s2p_px_fact

# Swap x and y because of the different Bruker convention: meds_stim columns are (y, x),
# the export needs (x, y). Fancy indexing returns a new array, so meds_stim is untouched.
meds_stim_export = meds_stim[:, [1, 0]]

In [None]:
# Export the selected stim points with write_mp_file_cp, using the template at
# cfg['mp_temp_path'] and the spiral settings from the config.
mp_template_path = cfg['mp_temp_path']
mp_spiral_cfg = {
    key: cfg[key] for key in ('SpiralWidth', 'SpiralHeight', 'SpiralSizeInMicrons')
}
write_mp_file_cp(
    meds_stim_export,
    mp_temp_path=mp_template_path,
    mouse_str=subject,
    export_path=save_path,
    fov_shape=s2p_imsize,
    use_seg=use_seg,
    **mp_spiral_cfg,
)

In [None]:
# Export the selected stim points with write_gpl_file_cp, using the template at
# cfg['gpl_temp_path']. Each per-point setting is forwarded under its config key name.
gpl_template_path = cfg['gpl_temp_path']
gpl_point_cfg = {
    key: cfg[key]
    for key in (
        'ActivityType', 'UncagingLaser', 'UncagingLaserPower', 'Duration',
        'IsSpiral', 'SpiralSize', 'SpiralRevolutions', 'Z', 'X_lim', 'Y_lim',
    )
}
write_gpl_file_cp(
    meds_stim_export,
    gpl_temp_path=gpl_template_path,
    mouse_str=subject,
    export_path=save_path,
    fov_shape=s2p_imsize,
    use_seg=use_seg,
    **gpl_point_cfg,
)

In [None]:
# Save cfg for reproducibility as a human-readable YAML copy of the loaded config.
# It is extended with the notebook-level `subject` and the resolved `session_path`,
# which are not part of the YAML config file.
cfg_to_save = {**cfg, 'subject': subject, 'session_path': session_path}
with open(os.path.join(save_path, 'cfg.yaml'), 'w', encoding='utf-8') as cfg_file:
    yaml.safe_dump(cfg_to_save, cfg_file, sort_keys=False)


In [None]:
# Persist the image, segmentations and selected points next to the exported files.
for out_fname, out_array in (
    ('fov_image.npy', fov_image),
    ('seg.npy', seg),
    ('seg_cur.npy', seg_cur),
    ('meds.npy', meds),
    ('meds_stim.npy', meds_stim),
):
    np.save(os.path.join(save_path, out_fname), out_array)

# cfg is a dict, so np.save stores it as a pickled 0-d object array;
# read it back with np.load(..., allow_pickle=True).item().
np.save(os.path.join(save_path, 'cfg.npy'), cfg, allow_pickle=True)