# Script to get Cellpose masks from the FOV image and select/export photostimulation points
TODO: check that the .gpl export works with the new version of stim_select_cp.ipynb (if it does, delete the old code in the second part of this notebook).

In [None]:
import numpy as np
import os
import napari
import yaml

from photostim_deve.image_analysis.plot import plot_motcorr_comparison
from photostim_deve.image_analysis.io import get_all_fov_image
from photostim_deve.image_analysis.segment import segment_fov_cpsam, get_cent_from_seg

from photostim_deve.control_exp.io import write_mp_file_cp, write_gpl_file_cp
from photostim_deve.control_exp.utils import remove_edge_masks
from photostim_deve.control_exp.plot import plot_stim, plot_segmentation_overlay

In [None]:
# set params
subject = "jm065"  # subject/mouse ID; selects data_proc/jm/<subject> further down


In [None]:
# TODO: remove unnecessary parameters (also from the yaml file)
# Load the experiment configuration and expose each entry as a notebook variable.
with open("stim_select_cp_config.yaml", "r") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# session selection / image geometry
session_type = cfg['session_type']
fov_imsize_onedim = cfg['fov_imsize_onedim']
s2p_imsize_onedim = cfg['s2p_imsize_onedim']
session_reg_idx = cfg['session_reg_idx']

# processing switches
force_recompute = cfg['force_recompute']
run_motcorr = cfg['run_motcorr']
nimg_init = cfg['nimg_init']
use_seg = cfg['use_seg']
manually_curate = cfg['manually_curate']

# cell selection / display
n_stim_cell = cfg['n_stim_cell']
seed = cfg['seed']
edge_excl = cfg['edge_excl']
sat_perc = cfg['sat_perc']


In [None]:
# Both images are assumed square, so each 2-D size is the 1-D size repeated twice.
fov_imsize = (fov_imsize_onedim,) * 2  # size of the FOV in pixels
s2p_imsize = (s2p_imsize_onedim,) * 2  # size used for Suite2p processing

In [None]:
# Locate the session directory for this subject and create the output folders.
subject_path = os.path.join('data_proc', 'jm', subject)
all_session_path = sorted(
    os.path.join(subject_path, d)
    for d in os.listdir(subject_path)
    if session_type in d and os.path.isdir(os.path.join(subject_path, d))
)

# Fail with an informative message (same exception type as plain list indexing) when no
# matching session exists or session_reg_idx does not point at one.
n_sessions = len(all_session_path)
if not -n_sessions <= session_reg_idx < n_sessions:
    raise IndexError(
        f"session_reg_idx={session_reg_idx} is out of range: found {n_sessions} "
        f"session(s) matching '{session_type}' in {subject_path}"
    )
session_path = all_session_path[session_reg_idx]

save_path = os.path.join(session_path, 'stim_select_cp')
save_path_fig = os.path.join(save_path, 'figures')
os.makedirs(save_path, exist_ok=True)
os.makedirs(save_path_fig, exist_ok=True)

# ratio of FOV pixels to Suite2p pixels (both images have the same aspect ratio)
fov_s2p_px_fact = fov_imsize[0] / s2p_imsize[0]

In [None]:
# Load the FOV images for the selected session (run_motcorr / nimg_init / force_recompute
# are forwarded to the loader unchanged).
all_fov_image = get_all_fov_image(
    subject_path,
    session_type=session_type,
    session_reg_idx=session_reg_idx,
    run_motcorr=run_motcorr,
    fov_imsize=fov_imsize,
    nimg_init=nimg_init,
    force_recompute=force_recompute,
)

In [None]:
# display the output directory of this session
save_path

In [None]:
# Plot the motion-correction comparison; save_path_fig is passed as the figure directory.
plot_motcorr_comparison(
    all_fov_image,
    sat_perc=sat_perc,
    save_path=save_path_fig,
    crop=(64, 64),
)

In [None]:
# Run segment_fov_cpsam on the '1100nm' image only, then apply remove_edge_masks
# (edge_excl from the config) to get the starting point for curation.
all_fov_image_seg = segment_fov_cpsam(
    all_fov_image,
    flow_threshold=0.4,
    cellprob_threshold=0.0,
    force_recompute=force_recompute,
    save_path=save_path,
    segment_only=['1100nm'],
)

fov_image, seg = all_fov_image_seg['1100nm'], all_fov_image_seg['1100nm_seg']
seg_cur = remove_edge_masks(fov_image, seg, edge_excl=edge_excl)

In [None]:
# Optional manual curation: show the FOV image and the edge-filtered labels in napari.
if manually_curate:
    viewer = napari.Viewer()
    viewer.add_image(
        fov_image,
        name='Mean FOV Image',
        colormap='gray',
        contrast_limits=[0, np.percentile(fov_image, sat_perc)],
    )
    viewer.add_labels(seg_cur, name='Cellpose Segmentation')
    viewer.show()

    # Deliberately stop "Run All" here so the following cells only run once curation is done.
    # NOTE(review): the later cells use seg_cur directly, i.e. this relies on napari editing
    # the array in place — confirm for the napari version in use.
    raise RuntimeError("Manual curation required! It is sufficent to do the curation of the labels layer and close napari afterwards.\nThe next cell will visualise the results of curation etc.")

In [None]:
# Overlay the raw and the curated segmentation on the FOV image and save both figures.
for labels, fig_title, fig_name in (
    (seg, 'Cellpose Segmentation', 'cp_seg.png'),
    (seg_cur, 'Curated Cellpose Segmentation', 'cp_seg_cur.png'),
):
    plot_segmentation_overlay(
        fov_image,
        labels,
        sat_perc=sat_perc,
        title=fig_title,
        save_path=os.path.join(save_path_fig, fig_name),
    )

In [None]:
# Cell centroids from the curated segmentation, then a reproducible random subset to stimulate.
x_meds, y_meds = get_cent_from_seg(seg_cur)

meds = np.column_stack((y_meds, x_meds))  # one row per cell, columns (y_meds, x_meds)

if n_stim_cell > len(meds):
    raise ValueError(
        f"n_stim_cell={n_stim_cell} but only {len(meds)} cells are available after curation"
    )

# Use the configured seed; it was previously read from cfg but never applied, so every run
# picked a different set of cells.
rng = np.random.default_rng(seed)
inds = rng.choice(len(meds), n_stim_cell, replace=False)
meds_stim = meds[inds]

In [None]:
# Plot all curated centroids with the selected stim points (orange), then the selection alone.
plot_stim(
    fov_image, meds_stim, meds=meds, sat_perc=sat_perc,
    title='Curated CP medians and selected stim points (orange)',
    save_path=os.path.join(save_path_fig, 'meds_all.png'),
)
plot_stim(
    fov_image, meds_stim, sat_perc=sat_perc,
    title='Selected CP medians (stim points)',  # fixed typo ("Slected")
    save_path=os.path.join(save_path_fig, 'meds_stim.png'),
)

In [None]:
# Convert the centroids from FOV pixels to Suite2p pixels (s2p_imsize is what the export
# functions receive as fov_shape).
# Out-of-place division so this also works if the centroid arrays have an integer dtype
# (in-place `/=` on an int array raises a casting error).
# NOTE: re-running this cell without re-running the selection cell rescales a second time.
if fov_s2p_px_fact != 1.0:
    meds_stim = meds_stim / fov_s2p_px_fact
    meds = meds / fov_s2p_px_fact

In [None]:
# Write the MarkPoints file for the selected stim points; spiral settings are taken from cfg.
write_mp_file_cp(
    meds_stim,
    mp_temp_path=cfg['mp_temp_path'],
    mouse_str=subject,
    export_path=save_path,
    fov_shape=s2p_imsize,
    **{key: cfg[key] for key in ('SpiralWidth', 'SpiralHeight', 'SpiralSizeInMicrons')},
    use_seg=use_seg,
)

In [None]:
# Write the galvo point list (.gpl) for the selected stim points; stimulation settings come from cfg.
write_gpl_file_cp(
    meds_stim,
    gpl_temp_path=cfg['gpl_temp_path'],
    mouse_str=subject,
    export_path=save_path,
    fov_shape=s2p_imsize,
    **{
        key: cfg[key]
        for key in (
            'ActivityType', 'UncagingLaser', 'UncagingLaserPower', 'Duration',
            'IsSpiral', 'SpiralSize', 'SpiralRevolutions', 'Z', 'X_lim', 'Y_lim',
        )
    },
    use_seg=use_seg,
)

In [None]:
# save cfg for reproducibility (done in the next cell, together with the other outputs)


In [None]:
# Save the image, segmentations and selected points of this run.
# NOTE: meds / meds_stim are saved after the scaling cell, i.e. in Suite2p pixel units
# whenever fov_s2p_px_fact != 1.0.
np.save(os.path.join(save_path, 'fov_image.npy'), fov_image)
np.save(os.path.join(save_path, 'seg.npy'), seg)
np.save(os.path.join(save_path, 'seg_cur.npy'), seg_cur)
np.save(os.path.join(save_path, 'meds.npy'), meds)
np.save(os.path.join(save_path, 'meds_stim.npy'), meds_stim)

# cfg.npy is a pickled dict (loading needs allow_pickle=True and .item()); it is kept for
# backwards compatibility. cfg.yaml is a human-readable copy that also records the subject,
# which is not part of cfg.
np.save(os.path.join(save_path, 'cfg.npy'), cfg, allow_pickle=True)
with open(os.path.join(save_path, 'cfg.yaml'), 'w') as f:
    yaml.safe_dump({**cfg, 'subject': subject}, f, sort_keys=False)

In [None]:
# now test loading the saved files (a..e in the order they were saved)
a, b, c, d, e = (
    np.load(os.path.join(save_path, f'{name}.npy'))
    for name in ('fov_image', 'seg', 'seg_cur', 'meds', 'meds_stim')
)

In [None]:
# Visual check of the reloaded arrays in napari.
# The saved centroids were divided by fov_s2p_px_fact (when it differs from 1.0), so scale
# them back by the same factor to overlay them on the full-resolution FOV image. This was a
# hard-coded `* 2`, which is only correct when fov_s2p_px_fact == 2.
# NOTE(review): columns are swapped with [:, [1, 0]] before display — confirm this matches
# the (row, col) order napari expects for the output of get_cent_from_seg.
viewer = napari.Viewer()
viewer.add_image(a, name='Mean FOV Image', colormap='gray', contrast_limits=[0, np.percentile(a, sat_perc)])
viewer.add_labels(b, name='Cellpose Segmentation')
viewer.add_labels(c, name='Curated Cellpose Segmentation')
viewer.add_points(d[:, [1, 0]] * fov_s2p_px_fact, name='All Cell Centroids', size=6, face_color='blue')
viewer.add_points(e[:, [1, 0]] * fov_s2p_px_fact, name='Selected Stimulation Points', size=4, face_color='orange')
viewer.show()

# OLD VERSION OF CODE (TODO: delete once the new .gpl file has been tested on the Bruker system)

In [None]:
import tifffile as tiff
import numpy as np

import matplotlib.pyplot as plt

from skimage.registration import phase_cross_correlation
from scipy.ndimage import fourier_shift

import numpy as np
from cellpose import models
import os 

import napari

from photostim_deve.control_exp.io import get_med_img_s2p, get_seg_img_cp, write_mp_file, write_gpl_file
from photostim_deve.control_exp.plot   import plot_fov, plot_fov_meds

%load_ext autoreload
%autoreload 2

In [None]:
fov_file = '/Volumes/data_jm_share/data_raw/jm/jm067/2025-12-02_a/fov/1100nm/TSeries-10032023-1822-008/TSeries-10032023-1822-008_Cycle00001_Ch1_000001.ome.tif'

mouse_str = 'jm067'

npx = 1024  # image size (assumed square) in pixels (512 or 1024)
fov_shape = (npx, npx)  # derived from npx so the two settings cannot disagree

use_seg = 'cellpose'
manually_curate = True

n_stim_cell = 45  # number of cells to choose for stimulation
seed = 123  # random seed

# visualisation parameters
vmax_perc = 99.9  # saturation percentile when plotting fov mean image

# exclude edges when choosing cells (due to brain growth, motion correction, etc)
edge_excl = 0.10  # fraction of image size to exclude on each side (0.1 = 10%)

In [None]:
# The parameters underneath are used to generate the MarkPoints and Galvo Point List files.
# They should usually not be changed for a fixed imaging configuration (e. g. 512x512 FOV, 1.5 zoom, 30Hz frame rate, 15 um spiral size, 50 ms duration).

# edge-exclusion margin in pixels along each image axis
edge_excl_pix = tuple(int(size * edge_excl) for size in fov_shape)

# template files for the exports
mp_temp_path = 'utils/templates/MarkPoints_template.xml'
gpl_temp_path = 'utils/templates/galvo_point_list_template.gpl'

# --- MarkPoints (.xml) ---
SpiralWidth = '0.0199325637636341'  # proportion of the FOV
SpiralHeight = '0.0199325637636341'  # proportion of the FOV
SpiralSizeInMicrons = '15.0000000000001'  # microns

# --- Galvo point list (.gpl) — should usually not be changed! ---
ActivityType = "MarkPoints"
UncagingLaser = "Uncaging"
UncagingLaserPower = "1000"
Duration = "50"
IsSpiral = "True"
SpiralSize = "0.110870362837845"
SpiralRevolutions = "7"
Z = "807.424999999999"
# X_lim / Y_lim were determined empirically by putting points in the corners of the FOV
# (presumably galvo control voltages — unconfirmed).
X_lim = 2.79639654844993
Y_lim = 3.09924006097119

In [None]:
# Output folder for the exported files and figures of this mouse.
export_path = os.path.join('export', mouse_str)
# exist_ok avoids the check-then-create race of `if not os.path.exists(...)`
os.makedirs(export_path, exist_ok=True)

In [None]:
# Load the FOV time series, align every frame to the mean frame (translation only,
# sub-pixel via phase correlation) and show the mean of the registered stack.

# the tiff is read as float64; keep float32 to halve memory
img = tiff.imread(fov_file).astype(np.float32)

mn_image = np.mean(img, axis=0)  # registration reference

registered_img = np.zeros_like(img)
for frame_idx, frame in enumerate(img):
    # shift (1/10 px resolution) that aligns this frame to the reference
    shift, error, diffphase = phase_cross_correlation(mn_image, frame, upsample_factor=10)
    # apply the shift in the Fourier domain and keep the real part
    registered_img[frame_idx] = np.fft.ifftn(fourier_shift(np.fft.fftn(frame), shift)).real
registered_mn_image = registered_img.mean(axis=0)

fig = plt.figure(figsize=(8, 8), dpi=300)
plt.imshow(registered_mn_image, cmap='gray', vmax=np.percentile(registered_mn_image, vmax_perc))
plt.title('Registered Mean FOV image')
plt.axis('off')

# use the registered mean image from here on
mn_image = registered_mn_image

In [None]:

# mn_image is the registered mean FOV image (float32; expected npx x npx pixels).

# Cellpose model with the 'cpsam' pretrained weights; gpu=True uses the GPU if available.
model = models.CellposeModel(gpu=True, pretrained_model='cpsam')

# model.eval takes a list of images; mn_image is a single 2-D image
imgs = [mn_image]

# tune flow_threshold, cellprob_threshold, diameter, ... here if needed
masks, flows, styles = model.eval(
    imgs,
    diameter=None,
    flow_threshold=0.4,
    cellprob_threshold=0.0,
    resample=True,
    normalize=True,
    # further options: invert=False, rescale=None, augment=False, tile_overlap=0.1, min_size=15
)

mask0 = masks[0]  # label image for mn_image (0 = background, 1..num_labels = objects)
num_labels = mask0.max()
print(f"Detected {num_labels} objects")



In [None]:
# Draw each segmented object as a contour (random colour) over the mean image.
fig = plt.figure(figsize=(8, 8), dpi=300)
plt.imshow(mn_image, cmap='gray', vmax=np.percentile(mn_image, vmax_perc))
for label in range(1, num_labels + 1):
    label_mask = (mask0 == label).astype(int)
    plt.contour(label_mask, colors=np.random.rand(3), linewidths=0.5)
plt.title('Cellpose Segmentation Overlay (before curation)')
plt.axis('off')
plt.show()


In [None]:
# Curation step: show the mean image and the Cellpose labels in napari for manual editing.
if manually_curate:
    viewer = napari.Viewer()
    viewer.add_image(
        mn_image,
        name='Mean FOV Image',
        colormap='gray',
        contrast_limits=[0, np.percentile(mn_image, vmax_perc)],
    )
    viewer.add_labels(mask0, name='Cellpose Segmentation')
    viewer.show()

    # Deliberately stop "Run All" so the following cells only run after curation.
    # NOTE(review): the later cells read mask0 directly, i.e. this relies on napari editing
    # the array in place — confirm for the napari version in use.
    raise RuntimeError("Manual curation required! It is sufficent to do the curation of the labels layer and close napari afterwards.\nThe next cell will visualise the results of curation etc.")



In [None]:
# Draw the curated labels as contours over the mean image.
# num_labels was computed before curation, so labels added in napari (ids > num_labels) would
# be missed and deleted ones would give empty contours; use the labels actually present instead.
curated_labels = np.unique(mask0)
curated_labels = curated_labels[curated_labels != 0]  # 0 is background

fig = plt.figure(figsize=(8, 8), dpi=300)
plt.imshow(mn_image, cmap='gray', vmax=np.percentile(mn_image, vmax_perc))
for label in curated_labels:
    icolor = np.random.rand(3)  # random colour per label
    contour = np.where(mask0 == label, 1, 0)
    plt.contour(contour, colors=icolor, linewidths=0.5)
plt.title('Cellpose Segmentation Overlay (after curation)')
plt.axis('off')
plt.show()

In [None]:
# Median (row, col) pixel position of every label in the (curated) mask, followed by
# removal of cells whose median lies inside the edge-exclusion margin.
# Iterate over the labels actually present: num_labels was computed before curation, so
# labels added in napari would otherwise be skipped.
meds = []
for label in np.unique(mask0):
    if label == 0:  # background
        continue
    coords = np.argwhere(mask0 == label)  # (n_pixels, 2) array of (row, col)
    meds.append(np.median(coords, axis=0))
meds = np.array(meds).reshape(-1, 2)  # (n_cells, 2); stays 2-D even when empty

# keep cells with edge_excl_pix <= median <= fov_shape - edge_excl_pix on both axes
lower = np.asarray(edge_excl_pix)
upper = np.asarray(fov_shape) - lower
inside = np.all((meds >= lower) & (meds <= upper), axis=1)
meds = meds[inside]
print(f"Number of cells after edge exclusion: {len(meds)}")


In [None]:
# Randomly choose n_stim_cell of the edge-filtered cells for stimulation.
# Seeding the global RNG (previously commented out, so `seed` had no effect) makes both this
# choice and the later np.random.permutation of inds reproducible.
np.random.seed(seed)
if n_stim_cell > len(meds):
    raise ValueError(
        f"n_stim_cell={n_stim_cell} but only {len(meds)} cells remain after edge exclusion"
    )
inds = np.random.choice(len(meds), n_stim_cell, replace=False)

vmax = np.percentile(mn_image, vmax_perc)
plot_fov(mn_image, export_path, vmax=vmax, use_seg=use_seg)
plot_fov_meds(mn_image, meds, [], export_path, vmax=vmax, use_seg=use_seg)
plot_fov_meds(mn_image, meds, inds, export_path, vmax=vmax, use_seg=use_seg)

In [None]:
# print the coordinates of the medians of the cells
print('Medians of the ROIs (in randomised order):')
# shuffle the selected indices; points are numbered from 1 (bruker uses 1-based indexing)
inds = np.random.permutation(inds)
for point_no, roi in enumerate(inds, start=1):
    print(f'point: {point_no} ROI {str(roi).zfill(3)}: {meds[roi]}')

In [None]:
# Replot the selection with the permuted indices.
# NOTE(review): an earlier version halved the medians here for 1024-px FOVs. The export
# below passes meds unscaled together with fov_shape — confirm that write_mp_file /
# write_gpl_file handle any 1024 -> 512 conversion that is needed.
plot_fov_meds(
    mn_image, meds, inds, export_path,
    vmax=np.percentile(mn_image, vmax_perc),
    use_seg=use_seg,
)

In [None]:
# (the replot with the permuted indices is done in the previous cell)


In [None]:
# display the shape of meds: (number of edge-filtered cells, 2) when non-empty
meds.shape

In [None]:
# Export the MarkPoints file; meds and the selected indices inds are passed separately.
write_mp_file(
    meds,
    inds,
    mp_temp_path=mp_temp_path,
    mouse_str=mouse_str,
    export_path=export_path,
    fov_shape=fov_shape,
    SpiralWidth=SpiralWidth,
    SpiralHeight=SpiralHeight,
    SpiralSizeInMicrons=SpiralSizeInMicrons,
    use_seg=use_seg,
)


In [None]:
# Export the galvo point list (.gpl); meds and the selected indices inds are passed separately.
write_gpl_file(
    meds,
    inds,
    gpl_temp_path=gpl_temp_path,
    mouse_str=mouse_str,
    export_path=export_path,
    fov_shape=fov_shape,
    ActivityType=ActivityType,
    UncagingLaser=UncagingLaser,
    UncagingLaserPower=UncagingLaserPower,
    Duration=Duration,
    IsSpiral=IsSpiral,
    SpiralSize=SpiralSize,
    SpiralRevolutions=SpiralRevolutions,
    Z=Z,
    X_lim=X_lim,
    Y_lim=Y_lim,
    use_seg=use_seg,
)