# Code for registering 2d FOVs using keypoint registration
Goal of the script:
1) Find cells that express ChRmine ('fov')  AND that were successfully tracked across all days ('t2p')
2) Find cells that were stimulated ('stim') AND that were successfully tracked across all days ('t2p')
3) Export the registered 1100nm image (for downstream visualisations)

Brief outline of the script:
1) Imports all 'fov' data (usually 830nm, 920nm and 1100nm), suite2p mean fov and photostim data
2) Motion correct raw data using Suite2p's motion correction algorithm
3) Segment all three images using cpsam (Cellpose-SAM). TODO: if it exists, replace this with the original (and manually curated) segmentation from the experimental procedure
4) (if not existing yet) add manual keypoints to 1100nm image and the suite2p mean fov
5) Compute affine transform that registers 1100nm ('moving') to suite2p mean fov (reference)
6) Compute affine transform that registers stim coordinates ('moving') to suite2p mean fov (reference)
7) Apply the appropriate transforms to the keypoints, stim coordinates, 1100nm image and 1100nm segmentation, so that they are all in the suite2p mean fov coordinate system
8) Import suite2p masks for the cells that were successfully tracked across all days by track2p
9) Match the stim coordinates to s2p (t2p) rois, and the 1100nm (fov) rois to s2p rois (using the Hungarian algorithm), using the euclidean distance between centroids as the metric.
10) Threshold the matches on absolute distance (max_dist_px parameter), again using the euclidean distance between centroids.
11) Visualise the overlay of all data and highlight matches
12) Export matched indices (for stim->t2p and fov->t2p)

### TODO:
Instead of recomputing the 1100nm segmentation, use the saved manual curation that was done on the first day of an experiment.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tifffile as tiff
from skimage.transform import warp, resize
import napari
import yaml


from photostim_deve.image_analysis.plot import plot_motcorr_comparison, plot_segmentation_overlay_dict, plot_image_seg_xy_stim, plot_keypoints_scatter
from photostim_deve.image_analysis.io import get_all_fov_image, get_s2p_image, get_xy_stim, save_keypoints, load_keypoints, get_t2p_s2p_indices_session, get_s2p_rois_filt
from photostim_deve.image_analysis.segment import segment_fov_cpsam, get_cent_from_seg
from photostim_deve.image_analysis.register import register_keypoints_affine, match_ref_moving

In [None]:
# Subject ID: selects the directory data_proc/jm/<subject> used further down.
subject = "jm065"


In [None]:
# Read all analysis parameters from the YAML config (path relative to the working directory).
with open("match_stim_fov_t2p_config.yaml", "r") as config_file:
    cfg = yaml.safe_load(config_file)

# Unpack the config into module-level names; a missing key raises KeyError.
(
    fov_imsize_onedim,
    s2p_imsize_onedim,
    n_stim_cell,
    n_stim_ctrl,
    session_reg_idx,
    force_recompute,
    run_motcorr,
    nimg_init,
    filt_by,
    track2p_dirname,
    cell_prob_thr,
    max_dist_px,
    session_type,
    sat_perc,
) = (
    cfg[key]
    for key in (
        'fov_imsize_onedim',
        's2p_imsize_onedim',
        'n_stim_cell',
        'n_stim_ctrl',
        'session_reg_idx',
        'force_recompute',
        'run_motcorr',
        'nimg_init',
        'filt_by',
        'track2p_dirname',
        'cell_prob_thr',
        'max_dist_px',
        'session_type',
        'sat_perc',
    )
)


In [None]:
# Square (rows, cols) image shapes built from the single side lengths given in the config.
fov_imsize = (fov_imsize_onedim,) * 2  # FOV image size in pixels (assumed square)
s2p_imsize = (s2p_imsize_onedim,) * 2  # Suite2p processing size in pixels (assumed square)

In [None]:
# Locate the subject's session directories (those whose name contains `session_type`)
# and select the one used as the registration session.
subject_path = os.path.join('data_proc', 'jm', subject)
all_session_path = sorted(
    os.path.join(subject_path, d)
    for d in os.listdir(subject_path)
    if os.path.isdir(os.path.join(subject_path, d)) and session_type in d
)

# Fail with an informative message (instead of a bare "list index out of range") when no
# matching session exists or session_reg_idx is out of range. Negative indices stay allowed.
n_sessions = len(all_session_path)
if not -n_sessions <= session_reg_idx < n_sessions:
    raise IndexError(
        f"session_reg_idx={session_reg_idx} is out of range: found {n_sessions} "
        f"session directories containing '{session_type}' in {subject_path}."
    )
session_path = all_session_path[session_reg_idx]

# Output locations inside the selected session directory.
match_save_dir = os.path.join(session_path, 'match_stim_fov_t2p')
keypoints_save_path = os.path.join(session_path, 'fov_reg_keypoints.csv')

# Pixel scale factor from the Suite2p grid to the FOV grid (both assumed square, so one axis suffices).
fov_s2p_px_fact = fov_imsize[0] / s2p_imsize[0]

In [None]:
# Load all FOV images of the registration session; motion correction is controlled by run_motcorr.
fov_loading_kwargs = dict(
    session_type=session_type,
    session_reg_idx=session_reg_idx,
    run_motcorr=run_motcorr,
    fov_imsize=fov_imsize,
    nimg_init=nimg_init,
    force_recompute=force_recompute,
)
all_fov_image = get_all_fov_image(subject_path, **fov_loading_kwargs)

In [None]:
# Visual check of the motion correction; crop=(64, 64) is passed straight to the plotting helper.
motcorr_crop = (64, 64)
plot_motcorr_comparison(all_fov_image, crop=motcorr_crop, sat_perc=sat_perc)

In [None]:
# Segment each FOV image with cpsam. force_recompute / save_path are passed through
# (presumably for on-disk caching inside segment_fov_cpsam — confirm).
cpsam_flow_threshold = 0.4
cpsam_cellprob_threshold = 0.0
all_fov_image_seg = segment_fov_cpsam(
    all_fov_image,
    flow_threshold=cpsam_flow_threshold,
    cellprob_threshold=cpsam_cellprob_threshold,
    force_recompute=force_recompute,
    save_path=session_path,
)

In [None]:
# Visual check: overlay of every segmentation in all_fov_image_seg.
plot_segmentation_overlay_dict(all_fov_image_seg, sat_perc=sat_perc)

In [None]:
# Suite2p mean image of the registration session; this is the registration reference.
s2p_image = get_s2p_image(session_path)

In [None]:
# Photostimulation target coordinates for this session (presumably on the Suite2p pixel grid,
# since they are later multiplied by fov_s2p_px_fact — confirm).
x_stim, y_stim = get_xy_stim(session_path, session_type=session_type)

In [None]:
# Unregistered 1100nm image with its segmentation and the stim targets overlaid.
plot_image_seg_xy_stim(
    all_fov_image['1100nm'],
    segmentation=all_fov_image_seg['1100nm_seg'],
    x_stim=x_stim,
    y_stim=y_stim,
    fov_s2p_px_fact=fov_s2p_px_fact,
    sat_perc=sat_perc,
)


In [None]:
# Upscale the Suite2p mean image onto the FOV grid (fov_imsize), keeping its intensity range,
# then cast back to the original dtype.
s2p_image_upscaled = resize(
    s2p_image,
    fov_imsize,
    anti_aliasing=True,
    preserve_range=True,
).astype(s2p_image.dtype)

In [None]:
# Display where the registration keypoints are (or will be) stored.
keypoints_save_path

In [None]:
# Manual keypoint annotation: open napari with the upscaled Suite2p mean image (green) and the
# 1100nm FOV image (magenta), each with an empty points layer to click keypoints into.
# Skipped when a keypoint file already exists, unless force_recompute is set.
if force_recompute or not os.path.exists(keypoints_save_path):
    fov_1100nm = all_fov_image['1100nm']
    # Display range: 0.1th percentile up to the sat_perc-th percentile of each image.
    s2p_clims = (np.percentile(s2p_image_upscaled, 0.1), np.percentile(s2p_image_upscaled, sat_perc))
    fov_clims = (np.percentile(fov_1100nm, 0.1), np.percentile(fov_1100nm, sat_perc))

    viewer = napari.Viewer()
    # Layer names are kept as-is; save_keypoints presumably looks layers up by name — confirm.
    viewer.add_image(s2p_image_upscaled, name='s2p_mean_image', colormap='green', contrast_limits=s2p_clims)
    viewer.add_points(name='s2p_keypoints', size=5, face_color='green')
    viewer.add_image(fov_1100nm, name='fov_1100nm', colormap='magenta', contrast_limits=fov_clims)
    viewer.add_points(name='fov_1100nm_keypoints', size=5, face_color='magenta')
    napari.run()
    


In [None]:
# Persist the keypoints clicked in the annotation viewer, under the same condition that opened it.
if force_recompute or not os.path.exists(keypoints_save_path):
    save_keypoints(viewer, keypoints_save_path)

In [None]:
# Load the saved keypoints: x/y for the Suite2p mean image, then x/y for the 1100nm FOV image.
x_kp_s2p, y_kp_s2p, x_kp_fov, y_kp_fov = load_keypoints(keypoints_save_path)

In [None]:
# Fit the affine transform between the 1100nm keypoints (moving) and the Suite2p keypoints
# (reference), then plot reference, moving and registered keypoints for a visual check.
x_kp_fov_reg, y_kp_fov_reg, transform = register_keypoints_affine(
    x_kp_s2p, y_kp_s2p,  # reference keypoints
    x_kp_fov, y_kp_fov,  # moving keypoints
)
plot_keypoints_scatter(
    x_kp_s2p, y_kp_s2p,
    x_kp_fov, y_kp_fov,
    x_kp_fov_reg, y_kp_fov_reg,
)

In [None]:
# Bring the stim targets onto the FOV pixel grid, then map them with transform.inverse into the
# registered frame. NOTE(review): check this direction against the image warp (warp's inverse_map
# maps output -> input coordinates) — confirm on the overlay.
x_stim_upscaled = x_stim * fov_s2p_px_fact
y_stim_upscaled = y_stim * fov_s2p_px_fact
stim_xy_upscaled = np.stack([x_stim_upscaled, y_stim_upscaled], axis=1)
x_stim_upscaled_reg, y_stim_upscaled_reg = transform.inverse(stim_xy_upscaled).T


In [None]:
# Centroids of the 1100nm cpsam labels, mapped into the registered frame with transform.inverse.
fov_label_image = all_fov_image_seg['1100nm_seg']
x_fov, y_fov = get_cent_from_seg(fov_label_image)
fov_cent_xy = np.stack([x_fov, y_fov], axis=1)
x_fov_reg, y_fov_reg = transform.inverse(fov_cent_xy).T

In [None]:
# Resample the 1100nm image and its label image onto a fov_imsize output grid.
# skimage's warp() expects `inverse_map` to map OUTPUT coordinates to INPUT coordinates.
# NOTE(review): the stim/centroid cells apply transform.inverse directly to FOV-space points,
# while here transform.inverse is used as the output->input map; both cannot be right for the
# same transform direction — confirm which one register_keypoints_affine intends.
fov_image = all_fov_image['1100nm']
fov_seg = all_fov_image_seg['1100nm_seg']

# Intensity image: default warp settings, so integer input is converted to float per skimage's
# img_as_float convention (values are rescaled) — TODO confirm this is wanted for the export.
fov_image_reg = warp(fov_image, inverse_map=transform.inverse, output_shape=fov_imsize)

# Label image: nearest-neighbour (order=0) with preserved range so label IDs survive resampling.
fov_seg_reg = warp(
    fov_seg,
    inverse_map=transform.inverse,
    output_shape=fov_imsize,
    order=0,
    preserve_range=True,
).astype(fov_seg.dtype)


In [None]:
# Recomputes exactly the same upscaled Suite2p mean image as the earlier resize cell
# (redundant when that cell has already run; s2p_image is not modified in between).
s2p_image_upscaled = resize(
    s2p_image,
    fov_imsize,
    anti_aliasing=True,
    preserve_range=True,
).astype(s2p_image.dtype)

In [None]:
# track2p-derived Suite2p ROI indices for the registration session.
t2p_idxs_session = get_t2p_s2p_indices_session(
    subject_path,
    session_reg_idx=session_reg_idx,
    track2p_dirname=track2p_dirname,
)

In [None]:
# Suite2p ROIs (and their median x/y) after filtering. The filter now comes from the config
# (`filt_by`) instead of a hard-coded literal, so that config key actually takes effect.
# The rest of this notebook labels its outputs as "t2p", so warn if a different filter is used.
if filt_by != 't2p':
    print(f"Warning: filt_by='{filt_by}' (config), but downstream outputs are named for 't2p' filtering.")
roi_s2p, x_s2p_med, y_s2p_med, idxs_filt = get_s2p_rois_filt(
    session_path,
    filt_by=filt_by,
    t2p_idxs_session=t2p_idxs_session,
)

In [None]:
# Scale the Suite2p ROI medians from the Suite2p grid up to the FOV grid.
x_s2p_med_upscaled, y_s2p_med_upscaled = (
    coord * fov_s2p_px_fact for coord in (x_s2p_med, y_s2p_med)
)

In [None]:
# One-to-one matching of Suite2p ROI medians (reference) to registered 1100nm centroids and to
# registered stim points (moving), limited to max_dist_px. Row indices refer to the reference
# arrays, column indices to the moving arrays.
s2p_ref_x, s2p_ref_y = x_s2p_med_upscaled, y_s2p_med_upscaled
row_ind_s2p_fov, col_ind_s2p_fov = match_ref_moving(
    s2p_ref_x, s2p_ref_y, x_fov_reg, y_fov_reg, max_dist_px=max_dist_px
)
row_ind_s2p_stim, col_ind_s2p_stim = match_ref_moving(
    s2p_ref_x, s2p_ref_y, x_stim_upscaled_reg, y_stim_upscaled_reg, max_dist_px=max_dist_px
)

In [None]:
# Interactive QC overlay in the registered frame: registered 1100nm image and segmentation
# (magenta), upscaled Suite2p mean image and ROI medians (green), registered stim points
# (cyan x), plus outline markers around the matched 1100nm centroids and matched stim points.
# All intensity layers now share the same saturation percentile (sat_perc); the 1100nm layer
# previously used a hard-coded 99.9.
# NOTE(review): napari treats point coordinates as (row, col) = (y, x); points are stacked as
# (x, y) below — confirm the x/y convention of the inputs.
fov_reg_clims = (np.percentile(fov_image_reg, 0.1), np.percentile(fov_image_reg, sat_perc))
s2p_clims = (np.percentile(s2p_image_upscaled, 0.1), np.percentile(s2p_image_upscaled, sat_perc))

viewer = napari.Viewer()

viewer.add_image(fov_image_reg, name='fov_1100nm_registered', colormap='magenta', blending='additive', contrast_limits=fov_reg_clims)
viewer.add_image(s2p_image_upscaled, name='s2p_mean_image_registered', colormap='green', blending='additive', contrast_limits=s2p_clims)
viewer.add_image(fov_seg_reg > 0, name='fov_1100nm_seg_registered', opacity=0.3, colormap='magenta')
viewer.add_points(np.stack([x_s2p_med_upscaled, y_s2p_med_upscaled], axis=1), name='s2p_rois_medians_upscaled', size=5, face_color='green')
viewer.add_points(np.stack([x_stim_upscaled_reg, y_stim_upscaled_reg], axis=1), name='stim_points_registered', symbol='x', size=5, face_color='cyan')
# Highlight only matched points: col_ind_* index the moving arrays (centroids / stim points).
viewer.add_points(np.stack([x_fov_reg[col_ind_s2p_fov], y_fov_reg[col_ind_s2p_fov]], axis=1), name='matched_centroids_1100nm', size=15, border_color='yellow', border_width=0.2, face_color=[0,0,0,0], opacity=0.5)
viewer.add_points(np.stack([x_stim_upscaled_reg[col_ind_s2p_stim], y_stim_upscaled_reg[col_ind_s2p_stim]], axis=1), name='matched_stim_points', symbol='s', size=20, border_color='white', border_width=0.2, face_color=[0,0,0,0], opacity=0.5)
napari.run()
    

In [None]:
# Summary proportions of matched cells.
n_stim_to_s2p = len(col_ind_s2p_stim)  # stim points matched to a tracked ROI
n_fov_to_s2p = len(col_ind_s2p_fov)    # 1100nm (opsin) ROIs matched to a tracked ROI
n_s2p = len(x_s2p_med)                 # tracked ROIs considered

# Report NaN instead of raising ZeroDivisionError when a denominator is zero
# (e.g. no tracked ROIs survived filtering).
# NOTE(review): the stim denominator is n_stim_cell from the config, not len(x_stim), and
# n_stim_ctrl is never used; if the stim list also contains control targets, matches to them
# are counted in the numerator — confirm.
prop_stim_to_s2p = n_stim_to_s2p / n_stim_cell if n_stim_cell else float('nan')
prop_fov_to_s2p = n_fov_to_s2p / n_s2p if n_s2p else float('nan')

print(f'Identified {n_fov_to_s2p} (/ {n_s2p}) tracked cells as expressing opsin, corresponding to: {prop_fov_to_s2p:.3f}.')
print(f'Tracked {n_stim_to_s2p} (/ {n_stim_cell}) stimulated cells, corresponding to: {prop_stim_to_s2p:.3f}.')

In [None]:
# TODO: export this in a format that will be easy to match with longipy
# The exported arrays are INTEGER INDEX arrays (not boolean masks). Row indices presumably refer
# to positions in the filtered Suite2p ROI list returned by get_s2p_rois_filt (the order of
# x_s2p_med / y_s2p_med); column indices refer to the matched moving points.
# NOTE(review): idxs_filt from get_s2p_rois_filt is not used; if downstream code needs original
# Suite2p ROI indices, they would presumably be idxs_filt[row_ind] — confirm.
is_stim_and_t2p = row_ind_s2p_stim
is_stim_and_t2p_idx = col_ind_s2p_stim  # the index of that ROI according to the order of stimulation (from Bruker MarkPoints)
is_fov_and_t2p = row_ind_s2p_fov
is_fov_and_t2p_idx = col_ind_s2p_fov  # the index of that ROI in the CP segmentation (for now not really needed)

# Save everything in a separate folder inside the session directory.
if os.path.exists(match_save_dir):
    print(f'Matching save directory {match_save_dir} already exists.')
os.makedirs(match_save_dir, exist_ok=True)

# File name -> array; written in this (insertion) order.
match_outputs = {
    'is_stim_and_t2p.npy': is_stim_and_t2p,
    'is_stim_and_t2p_idx.npy': is_stim_and_t2p_idx,
    'is_fov_and_t2p.npy': is_fov_and_t2p,
    'is_fov_and_t2p_idx.npy': is_fov_and_t2p_idx,
    'fov_image_reg.npy': fov_image_reg,
}
for out_name, out_array in match_outputs.items():
    np.save(os.path.join(match_save_dir, out_name), out_array)