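# Script to segment cells in an FOV image with Cellpose, randomly choose target cells, and export them as Bruker MarkPoints (.xml) and Galvo Point List (.gpl) files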

In [None]:
import tifffile as tiff
import numpy as np

import matplotlib.pyplot as plt

from skimage.registration import phase_cross_correlation
from scipy.ndimage import fourier_shift

import numpy as np
from cellpose import models
import os 


from photostim_deve.control_exp.io import get_med_img_s2p, get_seg_img_cp, write_mp_file, write_gpl_file
from photostim_deve.control_exp.plot   import plot_fov, plot_fov_meds

In [None]:
# --- Session-specific parameters: edit these for each experiment ---

# FOV image file to segment. The commented path is a previous session, kept for reference.
# fov_file = 'data_raw/jm/jm059/2025-10-28_a/fov/1100nm/TSeries-10032023-1822-015/TSeries-10032023-1822-015_Cycle00001_Ch1_000001.ome.tif'
fov_file = 'data_raw/jm/jm060/2025-10-28_a/fov/1100nm/TSeries-10032023-1822-020/TSeries-10032023-1822-020_Cycle00001_Ch1_000001.ome.tif'

# Mouse identifier: names the export folder and is passed to the export helpers.
mouse_str = 'jm060'

# Segmentation label passed to the plotting / export helpers.
use_seg = 'cellpose'

n_cells_choose = 30  # number of cells to choose for stimulation
seed = 42  # random seed intended for the cell selection

# visualisation parameters
vmax = 1000  # display saturation value for imshow

# Exclude edges when choosing cells (due to brain growth, motion correction, etc.)
edge_excl = 0.10  # fraction of image size to exclude on each side (0.1 = 10%)

# Fail early on values that would leave no selectable cells. With edge_excl >= 0.5
# the two margins meet and every ROI is excluded, which otherwise only surfaces later
# as a confusing error from the random selection.
if n_cells_choose < 1:
    raise ValueError(f"n_cells_choose must be >= 1, got {n_cells_choose}")
if not 0 <= edge_excl < 0.5:
    raise ValueError(f"edge_excl must be in [0, 0.5), got {edge_excl}")

In [None]:
# Fixed acquisition / stimulation constants used to build the MarkPoints (.xml)
# and Galvo Point List (.gpl) files. For a fixed imaging configuration
# (e.g. 512x512 FOV, 1.5x zoom, 30 Hz frame rate, 15 um spiral, 50 ms duration)
# these should normally be left alone.

fov_shape = (512, 512)
# Pixel margin excluded on each side, one entry per image axis.
edge_excl_pix = tuple(int(n_pix * edge_excl) for n_pix in fov_shape)

# Template files that the export helpers fill in.
mp_temp_path = 'utils/templates/MarkPoints_template.xml'
gpl_temp_path = 'utils/templates/galvo_point_list_template.gpl'

# --- MarkPoints (.xml) parameters ---
SpiralWidth = '0.0199325637636341'  # fraction of the FOV
SpiralHeight = '0.0199325637636341'  # fraction of the FOV
SpiralSizeInMicrons = '15.0000000000001'  # microns

# --- Galvo Point List (.gpl) parameters: normally not changed ---
ActivityType = "MarkPoints"
UncagingLaser = "Uncaging"
UncagingLaserPower = "1000"
Duration = "50"
IsSpiral = "True"
SpiralSize = "0.110870362837845"
SpiralRevolutions = "7"
Z = "807.424999999999"
# Coordinate limits found empirically by placing points in the FOV corners.
# NOTE(review): presumably galvo control voltages; unverified.
X_lim = 2.79639654844993
Y_lim = 3.09924006097119

In [None]:
# Per-mouse output folder. exist_ok=True makes re-running this cell harmless and
# avoids the race between checking for the folder and creating it.
export_path = os.path.join('export', mouse_str)
os.makedirs(export_path, exist_ok=True)

In [None]:
# Load the FOV stack, rigidly register every frame to the stack mean, and
# replace `mn_image` with the mean of the registered frames.

img = tiff.imread(fov_file)
# Work in float32. The original note says the file stores float64 data.
img = img.astype(np.float32)
# A single-frame file loads as a 2-D array. Promote it to a one-frame stack so the
# per-frame loop below iterates over frames rather than over image rows.
if img.ndim == 2:
    img = img[np.newaxis]

mn_image = np.mean(img, axis=0)

# Register each frame to the (unregistered) mean image to within 1/10 pixel,
# applying the shift in the Fourier domain.
registered_img = np.zeros_like(img)
for i, frame in enumerate(img):
    shift, _, _ = phase_cross_correlation(mn_image, frame, upsample_factor=10)
    shifted_fft = fourier_shift(np.fft.fftn(frame), shift)
    registered_img[i] = np.fft.ifftn(shifted_fft).real
registered_mn_image = np.mean(registered_img, axis=0)

fig = plt.figure(figsize=(8, 8), dpi=300)
plt.imshow(registered_mn_image, cmap='gray', vmax=vmax)
plt.title('Registered Mean FOV image')
plt.axis('off')
plt.show()

# Use the registered mean image for all downstream steps (kept in memory, not saved to disk).
mn_image = registered_mn_image

In [None]:

# Segment cells in the registered mean image with Cellpose.
# mn_image is the 2-D mean FOV image computed above (e.g. 512×512).

# 1. Instantiate the model (gpu=True requests the GPU).
model = models.CellposeModel(gpu=True, pretrained_model='cpsam')

# 2. Wrap the single-channel image in a list. The results are indexed
#    per image below (masks[0]).
imgs = [mn_image]

# 3. Run segmentation. Tune e.g. flow_threshold, cellprob_threshold or diameter as needed.
masks, flows, styles = model.eval(
    imgs,
    diameter=None,
    flow_threshold=0.4,
    cellprob_threshold=0.0,
    resample=True,
    normalize=True,
    # other options you might want to adjust:
    # invert=False, rescale=None, augment=False, tile_overlap=0.1, min_size=15
)

# 4. masks[0] is the label image for mn_image. Later cells treat 0 as background
#    and iterate over labels 1..num_labels.
mask0 = masks[0]

# Number of objects = highest label. Cast to a plain int so that `num_labels + 1`
# in later cells cannot overflow a small unsigned mask dtype (NumPy 2 keeps the
# scalar's dtype in mixed arithmetic).
num_labels = int(mask0.max())
print("Detected", num_labels, "objects")



In [None]:
# Overlay each Cellpose ROI outline (random colour per cell) on the mean image.
fig = plt.figure(figsize=(8, 8), dpi=300)
plt.imshow(mn_image, cmap='gray', vmax=vmax)
for label in range(1, num_labels + 1):
    # random RGB colour for this cell's outline
    icolor = np.random.rand(3,)
    # Binary image of this cell. A single contour level at 0.5 traces the mask
    # boundary exactly once. Matplotlib's automatic levels would draw several
    # nearly coincident lines per cell and cost several times the work.
    contour = np.where(mask0 == label, 1, 0)
    plt.contour(contour, levels=[0.5], colors=icolor, linewidths=0.5)
plt.title('Cellpose Segmentation Overlay')
plt.axis('off')
plt.show()


In [None]:
# compute medians of the masks
meds = []
for label in range(1, num_labels + 1):
    # get mask for this label
    mask = (mask0 == label)
    # compute median position
    coords = np.column_stack(np.where(mask))
    if coords.shape[0] == 0:
        continue
    median = np.median(coords, axis=0)
    meds.append(median)

meds = np.array(meds)  # shape (num_labels, 2)

# filter meds to be within edge exclusion
meds_filtered = []
for med in meds:
    if (edge_excl_pix[0] <= med[0] <= fov_shape[0] - edge_excl_pix[0]) and (edge_excl_pix[1] <= med[1] <= fov_shape[1] - edge_excl_pix[1]):
        meds_filtered.append(med)
meds_filtered = np.array(meds_filtered)
print(f"Number of cells after edge exclusion: {len(meds_filtered)}")
meds = meds_filtered


In [None]:
# Randomly choose `n_cells_choose` distinct cells (indices into `meds`) for
# stimulation, then plot the FOV and the chosen cells.
if n_cells_choose > len(meds):
    raise ValueError(
        f"Requested {n_cells_choose} cells but only {len(meds)} remain after edge exclusion"
    )

# Seed the global NumPy RNG so the selection is reproducible for a given `seed`.
# The subsequent np.random.permutation of the point order then becomes
# reproducible as well.
np.random.seed(seed)
inds = np.random.choice(len(meds), n_cells_choose, replace=False)

plot_fov(mn_image, export_path, vmax=vmax, use_seg=use_seg)
plot_fov_meds(mn_image, meds, inds, export_path, vmax=vmax, use_seg=use_seg)

In [None]:
# print the coordinates of the medians of the cells
# Shuffle the chosen cells into stimulation order and list them.
# Point numbers start at 1 because Bruker uses 1-based indexing.
print('Medians of the ROIs (in randomised order):')
inds = np.random.permutation(inds)
for point_no, roi_idx in enumerate(inds, start=1):
    # ROI index zero-padded to three digits
    print(f'point: {point_no} ROI {roi_idx:03d}: {meds[roi_idx]}')

In [None]:
# Export the chosen cells (meds[inds], in the order of `inds`) as a Bruker
# MarkPoints .xml file, using the template at mp_temp_path and the spiral
# parameters defined earlier.
write_mp_file(
    meds,
    inds,
    mp_temp_path=mp_temp_path,
    mouse_str=mouse_str,
    export_path=export_path,
    fov_shape=fov_shape,
    SpiralWidth=SpiralWidth,
    SpiralHeight=SpiralHeight,
    SpiralSizeInMicrons=SpiralSizeInMicrons,
    use_seg=use_seg,
)

In [None]:
# Export the chosen cells as a Galvo Point List (.gpl) file, using the template
# at gpl_temp_path, the stimulation settings and the X/Y coordinate limits
# defined earlier.
write_gpl_file(
    meds,
    inds,
    gpl_temp_path=gpl_temp_path,
    mouse_str=mouse_str,
    export_path=export_path,
    fov_shape=fov_shape,
    ActivityType=ActivityType,
    UncagingLaser=UncagingLaser,
    UncagingLaserPower=UncagingLaserPower,
    Duration=Duration,
    IsSpiral=IsSpiral,
    SpiralSize=SpiralSize,
    SpiralRevolutions=SpiralRevolutions,
    Z=Z,
    X_lim=X_lim,
    Y_lim=Y_lim,
    use_seg=use_seg,
)