# Cross-register sessions

Specify directories and dataset patterns

In [None]:
import os
import sys
import warnings
import itertools as itt
import numpy as np
import xarray as xr
import holoviews as hv
import pandas as pd
from holoviews.operation.datashader import datashade, regrid
from dask.diagnostics import ProgressBar
from pathlib import Path
from IPython.display import display
from ipyfilechooser import FileChooser
import param
import panel as pn
import warnings
from holoviews import opts
import ipywidgets as widgets


# Silence every warning category for the rest of the notebook session.
# NOTE(review): this also hides deprecation and runtime warnings raised by the
# libraries used below; remove it when debugging.
warnings.filterwarnings("ignore")

In [None]:
try:
    %store -r dpath
except:
    print("data not in strore")
    dpath ="//10.69.168.1/crnldata/forgetting/"


# Set up Initial Basic Parameters#
# Relative location of the minian package; appended to sys.path in a later cell.
minian_path = "."

# Folder chooser widget opened on `dpath` and restricted to directories
# (show_only_dirs=True); the selection is picked up by the callback registered below.
fc1 = FileChooser(dpath,select_default=True, show_only_dirs = True, title = "<b>Choose Mice Folder</b>", layout=widgets.Layout(width='100%'))
display(fc1)

# Sample callback function
def update_my_folder(chooser):
    global dpath
    dpath = chooser.selected
    %store dpath
    return 

# Register callback function
fc1.register_callback(update_my_folder)

In [None]:
# Move to the local minian checkout (machine-specific absolute path).
# Presumably required so that minian_path = "." resolves to the package folder
# when it is appended to sys.path in the next cell - adjust on other machines.
cd "C:/Users/Manip2/SCRIPTS/minian/"

In [None]:
sys.path.append(minian_path)
from minian.cross_registration import (calculate_centroids, calculate_centroid_distance, calculate_mapping,
                                       group_by_session, resolve_mapping, fill_mapping)
from minian.motion_correction import estimate_motion, apply_transform
from minian.utilities import open_minian, open_minian_mf
from minian.visualization import AlignViewer

# NOTE(review): minian_path is already "." when it is appended to sys.path at
# the top of this cell; this re-assignment does not change anything.
minian_path = "."
# Pattern handed to open_minian_mf (argument `pattern`) to find dataset folders.
f_pattern = r"minian"
# Metadata dimensions handed to open_minian_mf to identify each dataset.
id_dims = ["session"]

# Centroid pairs are kept only when their 'distance' is strictly below this value.
param_dist = 5
# Base plot size; scaled and passed to hv.output(size=...) in the plotting cells.
output_size = 100

hv.notebook_extension('bokeh', width=100)
# Register a dask progress bar for the computations that follow.
pbar = ProgressBar(minimum=2)
pbar.register()

## Align videos automatically

Open datasets

In [None]:
def correct_meta(ds,  minian_ds_path_Origin):
    """Set the ``animal`` and ``session`` entries of ``ds`` from its folder path.

    Parameters
    ----------
    ds
        Dataset-like object supporting item assignment (passed in by
        ``open_minian_mf`` through ``post_process``).
    minian_ds_path_Origin
        Path of the folder the dataset was loaded from.

    Returns
    -------
    The same ``ds`` object. ``session`` is always set to the name of the
    grand-parent folder. ``animal`` is set only for the two supported folder
    depths; for any other depth it is left untouched and a message is printed
    instead of failing with a ``KeyError``.
    """
    # Number of path parts -> index into ``Path.parents`` of the animal folder.
    animal_parent_by_depth = {12: 5, 13: 6}

    path_obj = Path(minian_ds_path_Origin)
    depth = len(path_obj.parts)
    print(depth)

    parent_idx = animal_parent_by_depth.get(depth)
    if parent_idx is None:
        # print() rather than warnings.warn(): warnings are globally silenced
        # at the top of this notebook.
        print(f"correct_meta: unexpected folder depth {depth} for '{path_obj}', 'animal' left unchanged")
    else:
        ds['animal'] = path_obj.parents[parent_idx].name

    ds['session'] = path_obj.parents[1].name
    #ds['animal'] = path_obj.parents[5].name
    if 'animal' in ds:
        print(f"['animal: {ds['animal'].values}']")
    return ds

# Open the datasets found under `dpath` that match `f_pattern`; each one is
# passed through correct_meta so its animal/session metadata follow the folder layout.
minian_ds = open_minian_mf(
    dpath, id_dims, pattern=f_pattern, post_process=correct_meta)

# Disabled variant without metadata correction, kept as a string literal.
"""
minian_ds = open_minian_mf(
    dpath, id_dims, pattern=f_pattern)
    """

Estimate shifts

In [None]:
# Use each session's max projection as the template for registration.
temps = minian_ds['max_proj'].rename('temps')
#temps = temps.drop_sel(session=['session2_1', 'session2_2', 'session2_1', 'session3_1', 'session3_2', 'session3_3', 'session3_4'])
# Estimate the shifts along the 'session' dimension and load the result in memory.
shifts = estimate_motion(temps, dim='session').compute().rename('shifts')

In [None]:
# Disabled snippet (string literal): reload previously saved shifts from shiftdsAB.nc.
"""
try: # Try to load a pre-existing mappping
    shiftds=xr.open_dataset(f'{dpath}/shiftdsAB.nc')
    temps=shiftds.temps
    shifts=shiftds.shifts
    print('Cross-registration already done !')
except: pass 
"""
# NOTE(review): the next two statements repeat the estimation done in the
# previous cell, so running both cells computes the shifts twice.
temps = minian_ds['max_proj'].rename('temps')
#temps = temps.drop_sel(session=['session2_1', 'session2_2', 'session2_1', 'session3_1', 'session3_2', 'session3_3', 'session3_4'])
shifts = estimate_motion(temps, dim='session').compute().rename('shifts')

# Show the session labels and the estimated shift values.
print(shifts.session.values)
print(shifts.values)

Apply the shifts for visualization

In [None]:
# Shift every template by its estimated offset, then bundle the templates,
# shifts and shifted templates into one dataset used by the cells below.
temps_sh = apply_transform(temps, shifts).compute().rename('temps_shifted')
shiftds = xr.merge([temps, shifts, temps_sh])

Visualize overlap of field of view across all sessions

In [None]:
hv.output(size=int(output_size * 0.6))
# Plot options shared by the images; the aspect ratio follows the data dimensions.
opts_im = {'aspect': shiftds.sizes['width'] / shiftds.sizes['height'],'frame_width': 500, 'cmap': 'hot'}
# Per pixel, count the sessions in which the shifted template is null, then
# broadcast the count so it has the same dimensions as the shifted templates.
window = shiftds['temps_shifted'].isnull().sum('session')
window, _ = xr.broadcast(window, shiftds['temps_shifted'])
# hv_wnd is built but not displayed: the line that used it is commented out below.
hv_wnd = hv.Dataset(window).to(hv.Image, ['width', 'height'])
hv_temps = hv.Dataset(temps_sh).to(hv.Image, ['width', 'height'])
#hv_wnd.opts(**opts_im).relabel("Window") + 
hv_temps.opts(**opts_im).relabel("Automatically shifted")

## Change the shift manually /!\ only if needed

In [None]:
# Reset all shifts to 0
for session_idx in range(len(shifts)):
    shifts[session_idx] = [0, 0]

# Re-apply the (now zero) shifts and rebuild the merged dataset
temps_sh = apply_transform(temps, shifts).compute().rename('temps_shifted')
shiftds = xr.merge([temps, shifts, temps_sh])

In [None]:
# The first session of temps_sh is the fixed reference every other session is compared to.
SessionTemplate = temps_sh.session[0].item()

class ImageAligner(param.Parameterized):
    """Interactive overlay used to pick a manual shift for one session.

    The reference session (`SessionTemplate`) is drawn in blue and the session
    chosen in `Session_to_shift` in red, displaced by the two offset
    parameters. The chosen values are read back by the next cell through
    `shifted_session`, `Down_to_Up` and `Left_to_Right`.
    """
    # Session labels offered in the selector; evaluated once, when the class is defined.
    session_options = list(temps_sh.session.values)
    Session_to_shift = param.ObjectSelector(default=temps_sh.session[0].item(), objects=session_options)
    # Integer offsets, limited to [-300, 300] (presumably pixels - TODO confirm).
    Left_to_Right = param.Integer(0, bounds=(-300, 300))
    Down_to_Up = param.Integer(0, bounds=(-300, 300))
    # Opacity of the red (shifted) image.
    alpha = param.Number(0.5, bounds=(0.0, 1.0))

    @param.depends('Session_to_shift', 'Left_to_Right', 'Down_to_Up', 'alpha')
    def view(self):
        """Return the overlay of the reference image and the shifted selected session."""
        
        img1 = hv.Dataset(temps_sh.sel(session=SessionTemplate)).to(hv.Image, ['width', 'height']).opts(width=800, height=800, cmap='Blues')
        # img2 is only used as a container for its data; the opts set here are
        # not the ones applied to the displayed image (see shifted_img).
        img2 = hv.Dataset(temps_sh.sel(session=self.Session_to_shift)).to(hv.Image, ['width', 'height']).opts(cmap='Reds', alpha=0.5)
        # flipud(rot90(x)) is the transpose of x.
        # NOTE(review): the transposed array is written back under the original
        # dims, which presumably only works when height == width - confirm.
        img2.data['temps_shifted'].values = np.flipud(np.rot90((img2.data['temps_shifted'])))
        # Because the values were transposed, the 'width' dim is shifted by
        # Down_to_Up and the 'height' dim by Left_to_Right.
        shifted = img2.data['temps_shifted'].shift(width=self.Down_to_Up, height=self.Left_to_Right)
        shifted_img = hv.Image(shifted).opts(width=800, height=800,cmap='Reds', alpha=self.alpha)
        return img1 * shifted_img
    
    @property
    def shifted_session(self):
        """Label of the session currently selected for shifting."""
        return self.Session_to_shift

aligner = ImageAligner()

# Left column: title and parameter widgets; right side: the live overlay.
header_pane = pn.pane.Markdown(f"## {SessionTemplate} (blue) vs selected session (red)")
widget_pane = pn.Param(
    aligner.param,
    parameters=['Session_to_shift', 'Left_to_Right', 'Down_to_Up', 'alpha'],
)
layout = pn.Row(pn.Column(header_pane, widget_pane), pn.panel(aligner.view))

#layout.servable()
display(layout)

Apply changes for that specific session (only run it once per session!)

In [None]:
# Add the offsets chosen in the aligner widget to the stored shift of the
# selected session. The update is cumulative: running this cell twice applies
# the offset twice.
target = {'session': aligner.shifted_session}
manual_offset = np.array([aligner.Down_to_Up, aligner.Left_to_Right])
shifts.loc[target] = shifts.loc[target].values + manual_offset
print(shifts.values)

In [None]:
# The two string literals below are disabled helper snippets (not executed):
# the first resets every shift to zero, the second sets the shift of one
# session to fixed values.
"""
# Reset all shifts to 0
for i in np.arange(len(shifts)):
    shifts[i]=[0,0]
"""
"""
Session_to_shift='session3_4'
shifts.loc[{'session': Session_to_shift}] = np.array([42. ,64.])
print(shifts.values)
"""

## Validate changes

Apply shifts and set window

In [None]:
# Rebuild the shifted templates from the current `shifts`, so that manual
# adjustments made after the first preview are reflected in the window below.
temps_sh = apply_transform(temps, shifts).compute().rename('temps_shifted')
shiftds = xr.merge([temps, shifts, temps_sh])

# Shift the spatial footprints with the same per-session shifts.
A_shifted = apply_transform(minian_ds['A'].chunk(dict(height=-1, width=-1)), shiftds['shifts'])

def set_window(wnd):
    """Return a boolean mask that is True where `wnd` equals its minimum."""
    return wnd == wnd.min()

# Recount the null pixels here instead of reusing the notebook-level `window`.
# Re-running this cell on an already boolean window would compare it against
# its own minimum (False) and invert the mask.
null_count = shiftds['temps_shifted'].isnull().sum('session')
null_count, _ = xr.broadcast(null_count, shiftds['temps_shifted'])

window = xr.apply_ufunc(
    set_window,
    null_count,
    input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']],
    vectorize=True)

Calculate centroid distance

In [None]:
%%time
cents = calculate_centroids(A_shifted, window)
id_dims.remove("session")
dist = calculate_centroid_distance(cents, index_dim=id_dims)
dist_ft = dist[dist['variable', 'distance'] < param_dist].copy()
dist_ft = group_by_session(dist_ft)

Generate mappings

In [None]:
%%time
# Build the mappings from the filtered distances, resolve them, then fill them
# using the centroids; each step feeds the next.
mappings = calculate_mapping(dist_ft)
mappings_meta = resolve_mapping(mappings)
mappings_meta_fill = fill_mapping(mappings_meta, cents)
# Number of rows in the final mapping table, followed by the table itself.
print(len(mappings_meta_fill))
mappings_meta_fill

Visualize mappings

In [None]:
hv.output(size=int(output_size * 0.7))
# Interactive viewer built from the datasets, centroids, filled mappings and shift dataset.
alnviewer = AlignViewer(minian_ds, cents, mappings_meta_fill, shiftds)
alnviewer.show()

Save results

In [None]:
# Write the filled mappings and the centroids as pickle files into `dpath`.
mappings_meta_fill.to_pickle(os.path.join(dpath, "mappingsAB.pkl"))
cents.to_pickle(os.path.join(dpath, "centsAB.pkl"))
# Saving the shift dataset is currently disabled.
#shiftds.to_netcdf(os.path.join(dpath, "shiftdsAB.nc"))