# Registration

In [None]:
import xarray as xr
import numpy as np

import cedalion
import cedalion.io
import cedalion.dataclasses as cdc
import cedalion.geometry.registration
import cedalion.geometry.segmentation
import cedalion.plots

xr.set_options(display_expand_data=False);

## Read optode locations from snirf file

Optode locations are returned as a 2D xr.DataArray. Different labeled points are found along the first dimension 'label'. The second dimensions contains the 3D coordinates of each point. There is an abundance of coordinate system (CRS) definitions and in this example alone we have to distinguish between these different coordinate system:
- the segmented volume is in voxel space, denoted 'ijk', unitless
- the coordinates with physical units in scanner or atlas space
- the coordinate system of the digitization device

To keep track we use the name of the second dimension to store an identifier for CRS.

In [None]:
elements = cedalion.io.read_snirf("../../data/BIDS-NIRS-Tapping/sub-01/nirs/sub-01_task-tapping_nirs.snirf")
geo3d_meas = elements[0].geo3d
geo3d_meas = geo3d_meas.points.rename({"NASION" : "Nz"})
geo3d_meas = geo3d_meas.rename({"pos" : "digitized"})
display(geo3d_meas)

## Read segmented MRI scans

The image cubes are returned as a stacked xr.DataArray. 

In [None]:
DATADIR = "/home/eike/Projekte/ibslab/30_dev/AtlasViewerPy/demo_data"

In [None]:
masks, t_ijk2ras = cedalion.io.read_segmentation_masks(DATADIR+"/anatomy_data")
masks

Additionaly, a transformation matrix is returned ton convert from voxel space (ijk) to scanner space as it is defined in the niftii files. Since the segmentation masks were derived from a MRI scan, nibabel denotes the coordinate system with the affine code `'aligned'`.

The transformation matrices are also xr.DataArrays that contain both CRS names as dimension names. When applying this transformation to coordinates in voxel space (`'ijk'`) the matrix multiplication will contract the `'ijk'` dimension and the coordinates will have their coordinate dimension named `'aligned'`. The units of the transformation matrix will take care of necessary unit conversions. Here dimensionless in voxel space to millimeter in scanner space.

In [None]:
t_ijk2ras # transform from voxel space (ijk) to scanner space (x=Right y=Anterior z=Superior)

## Derive surfaces from segmentations

In [None]:
pial_surface = cedalion.geometry.segmentation.surface_from_segmentation(masks, ["wm", "gm"])
pial_surface = pial_surface.apply_transform(t_ijk2ras)

scalp_surface = cedalion.geometry.segmentation.surface_from_segmentation(
    masks, 
    masks.segmentation_type.values, # select all
    fill_holes_in_mask=True)
scalp_surface = scalp_surface.apply_transform(t_ijk2ras)
display(scalp_surface)

## Load landmarks of the loaded scan.

These were handpicked are define a reference to which the otopde positions should be registered. 

In [None]:
geo3d_volume = cedalion.io.read_mrk_json(DATADIR+"/anatomy_data/landmarks.mrk.json", crs="aligned")
geo3d_volume

## Simple registration algorithm
Find an affine transformation that translates and rotates the optode coordinates to match the landmarks.
Scaling is allowed only to transform units.

In [None]:
trafo = cedalion.geometry.registration.register_trans_rot(geo3d_volume, geo3d_meas)
display(trafo)
cedalion.plots.plot3d(None, scalp_surface.mesh, geo3d_meas.points.apply_transform(trafo), None) 

## Snap points to closest vertex on the scalp surface

In [None]:
snapped = scalp_surface.snap(geo3d_meas.points.apply_transform(trafo))
cedalion.plots.plot3d(None, scalp_surface.mesh, snapped, None) 

## Compare common landmarks in both point sets

In [None]:
common = snapped.points.common_labels(geo3d_volume)
display(geo3d_volume.sel(label=common))
display(snapped.sel(label=common))

## Transform registered optode locations back to voxel space

In [None]:
t_ras2ijk = cedalion.xrutils.pinv(t_ijk2ras)
snapped.points.apply_transform(t_ras2ijk).round()

## ICP registration [WIP]

In [None]:
losses, trafos = cedalion.geometry.registration.register_icp(scalp_surface, geo3d_volume, elements[0].geo3d)

p.plot(losses)

In [None]:
reg2 = elements[0].geo3d.points.apply_transform(trafos[-1])
cedalion.plots.plot3d(None, scalp_surface, reg2, None)
display(trafos[-1])

In [None]:
simple_scalp = surface.as_trimesh().simplify_quadric_decimation(60e3)
simple_brain = pial_surface.simplify_quadric_decimation(60e3)

In [None]:
brain_mask = masks.sel(segmentation_type=["gm", "wm"]).sum("segmentation_type")

In [None]:
cell_coords = cedalion.imagereco.geometry.cell_coordinates(brain_mask, t_vox2ras).stack({"cell" : ["i","j","k"]})

In [None]:
from scipy.spatial import KDTree
t = KDTree(simple_brain.vertices)

In [None]:
cell_indices = np.flatnonzero(brain_mask.values)
dists, vertex_indices = t.query(cell_coords[:,indices].values.T, workers=-1)

In [None]:
cell_indices

In [None]:
import scipy.sparse
scipy.sparse.coo_matrix?

In [None]:
ncells = np.prod(brain_mask.shape)
nvertices = len(simple_scalp.vertices)
Mcoo = scipy.sparse.coo_array((np.ones(len(cell_indices)), (vertex_indices, cell_indices)), shape=(nvertices, ncells)) 
Mcsr = scipy.sparse.csr_array((np.ones(len(cell_indices)), (vertex_indices, cell_indices)), shape=(nvertices, ncells)) 

In [None]:
test = np.arange(ncells)

In [None]:
%timeit (Mcoo @ test)

In [None]:
%timeit (Mcsr @ test)

In [None]:
t_ras2vox = np.linalg.pinv(t_vox2ras).round(12)

In [None]:
reg2.pint.to("mm").points.apply_transform(t_ras2vox).max("label")

In [None]:
[str(i) for i in geo3d_volume.label.values]

In [None]:
geo3d_volume.pint.to("mm").pint.dequantify().values

In [None]:
#trimesh.smoothing.filter_taubin(pial_surface, lamb=0.5).show()
#pial_surface_low = pial_surface.simplify_quadric_decimation(60e3)

#display(pial_surface)
#display(pial_surface_low)
#trimesh.smoothing.filter_taubin(pial_surface.mesh, lamb=0.5)

In [None]:
# calculate median tri size
#tri = pial_surface_low.vertices[pial_surface_low.faces]
#a = np.linalg.norm(tri[:,1,:] - tri[:,0,:], axis=1)
#b = np.linalg.norm(tri[:,2,:] - tri[:,0,:], axis=1)
#A = a*b/2
#np.median(A), np.std(A)