# NinjaCap-wholeHeadHD-probe registration to Colin27
This example Jupyter notebook shows how to align the optodes of the NinjaCap-wholeHeadHD probe to the Colin27 head model.
Accurate coregistration is the foundation of every data analysis that uses head models.

Currently, `cedalion` offers a simple registration method: it finds an affine transformation (scaling, rotation, translation) that maps the landmark positions of the head model onto their digitized counterparts (probe data). Afterward, the optodes are snapped to the nearest vertex on the scalp.

In [None]:
import pyvista as pv
# Render PyVista plots with the static Jupyter backend.
# Alternatives: pv.set_jupyter_backend('html') for the HTML backend,
# or pv.OFF_SCREEN=True for off-screen rendering.
pv.set_jupyter_backend('static')

In [None]:
import os, numpy as np, xarray as xr

import cedalion
import cedalion.io as cio
import cedalion.datasets
import cedalion.imagereco.forward_model as fw
import cedalion.geometry.registration as cgeoreg
import cedalion.geometry.landmarks as cgeolm
import cedalion.plots as cp

## Load segmented MRI scan

For this example, we use a segmentation of the Colin27 average brain.

In [None]:
# Colin27 segmentation dataset: its directory, the mask file names and the landmarks file.
SEG_DATADIR, mask_files, landmarks_file = cedalion.datasets.get_colin27_segmentation()
# Read the masks; by its name, t_ijk2ras is the voxel (ijk) -> RAS transform.
masks, t_ijk2ras = cio.read_segmentation_masks(SEG_DATADIR, mask_files)

Construct the Colin27 head model from the segmentation masks.

In [None]:
# Two-surface head model (brain + scalp) built from the Colin27 masks and the
# .obj surface meshes stored in the segmentation directory.
colin = fw.TwoSurfaceHeadModel.from_surfaces(
    segmentation_dir=SEG_DATADIR,
    mask_files=mask_files,
    brain_surface_file=os.path.join(SEG_DATADIR, "mask_brain.obj"),
    scalp_surface_file=os.path.join(SEG_DATADIR, "mask_scalp.obj"),
    landmarks_ras_file=landmarks_file,
    smoothing=0.5,
    fill_holes=True,
)

# Declare both surfaces' coordinates to be in millimetres.
colin.scalp.units = cedalion.units.mm
colin.brain.units = cedalion.units.mm

## Compute the EEG 10-10 system landmarks of Colin27 for optode coregistration

In [None]:
# Build the EEG 10-10 landmarks on the Colin27 scalp with cedalion's LandmarksBuilder1010.
scalp_surface = colin.scalp

# The fiducials are read from the (RAS) landmarks file and mapped into the head
# model's ijk space by inverting the ijk -> RAS transform.
fiducials_ras = cio.read_mrk_json(
    os.path.join(SEG_DATADIR, landmarks_file),
    crs="aligned",
)
fiducials_ijk = fiducials_ras.points.apply_transform(np.linalg.pinv(t_ijk2ras))

# Derive the full landmark set from the fiducials following the 10-10 rules, then show it.
lmbuilder = cgeolm.LandmarksBuilder1010(scalp_surface, fiducials_ijk)
all_landmarks = lmbuilder.build()
lmbuilder.plot()

## Load NinjaCap data

In [None]:
# NinjaCap probe: optode positions, landmark positions and the measurement list.
ninjacap_optodes, ninjacap_landmarks, meas_list = cedalion.datasets.get_ninja_cap_probe()

In [None]:
# Fiducials for the head model, loaded from the landmarks file; swap in hand-picked
# fiducials here if the stored ones are not adequate.
# NOTE(review): this repeats, with identical arguments, the read performed in the
# landmark-building cell, and `fiducials_ras` is not used further below — candidate
# for removal.
fiducials_ras = cio.read_mrk_json(
    os.path.join(SEG_DATADIR, landmarks_file),
    crs="aligned",
)

## Construct transform from matching landmarks

In [None]:
# Individual (Colin27) landmarks: positions and their 10-10 labels.
individual_ref_pos = np.array(all_landmarks)
individual_ref_labels = [lab.item() for lab in all_landmarks.label]

# NinjaCap landmarks, as loaded in the "Load NinjaCap data" cell (no need to fetch
# the probe a second time here).
ninja_ref_pos = list(np.array(ninjacap_landmarks.values))
ninja_ref_labels = list(np.array(ninjacap_landmarks.label))

# Use only the landmarks present in both sets, pairing positions by label.
# Sorting makes the order (and the printed list) reproducible: iterating a `set`
# of strings gives an order that changes between Python sessions because of
# string-hash randomization.
intersection = sorted(set(ninja_ref_labels) & set(individual_ref_labels))
individual_ref_pos = [individual_ref_pos[individual_ref_labels.index(label)] for label in intersection]
ninja_ref_pos = [ninja_ref_pos[ninja_ref_labels.index(label)] for label in intersection]
print("%d Landmarks used for co-registration:\n" % len(intersection), intersection)

In [None]:
# TODO: registration with cedalion's `register_trans_rot_isoscale` does not yet give
# usable results, and the cause has not been found. The attempt is kept here but
# disabled. Unlike the previous bare string literal, this form is syntax-checked and
# does not echo its own source as cell output.
# NOTE(review): going by its name, this estimates only translation, rotation and an
# isotropic scale, whereas the AtlasViewer fit below is a general affine — check
# whether that restriction explains the mismatch.
# Even when enabled, `T` is reassigned by the AtlasViewer-based cell that follows.
RUN_CEDALION_REGISTRATION = False

if RUN_CEDALION_REGISTRATION:
    # Individual landmarks
    individual_ref_pos = all_landmarks
    # the landmarks are in Colin's current coordinate system
    individual_ref_pos = individual_ref_pos.rename({individual_ref_pos.points.crs: colin.scalp.crs})
    individual_ref_labels = [lab.item() for lab in all_landmarks.label]

    # Load ninja cap data
    ninjacap_optodes, ninjacap_landmarks, meas_list = cedalion.datasets.get_ninja_cap_probe()
    ninja_ref_pos = ninjacap_landmarks
    ninja_ref_labels = list(np.array(ninjacap_landmarks.label))

    # Construct transform from intersection
    intersection = list(set(ninja_ref_labels) & set(individual_ref_labels))
    print("%d Landmarks used for co-registration:\n" % len(intersection), intersection)

    individual_ref_pos = individual_ref_pos.sel(label=intersection)
    ninja_ref_pos = ninja_ref_pos.sel(label=intersection)
    ninja_ref_pos = ninja_ref_pos.pint.quantify(cedalion.units.mm)

    T = cgeoreg.register_trans_rot_isoscale(individual_ref_pos, ninja_ref_pos)

In [None]:
# Alternative, non-cedalion, implementation ported from AtlasViewer
def gen_xform_from_pts(p1, p2):
    """Find the affine transformation matrix that maps the points ``p1`` onto ``p2``.

    Port of AtlasViewer's ``gen_xform_from_pts.m``:
    https://github.com/bunpc/atlasviewer/blob/71fc98ec8ca54783378310304113e825bbcd476a/utils/gen_xform_from_pts.m#l4

    Each output coordinate is fitted independently, in the least-squares sense, as an
    affine function of the input coordinates, i.e. ``p2 ≈ [p1, 1] @ t[:n].T``.
    With exactly ``n`` points the system is underdetermined, and the pseudo-inverse
    returns the minimum-norm solution.

    Parameters
    ----------
    p1 : array_like, shape (p, n)
        Source points, one point per row.
    p2 : array_like, shape (p, n)
        Target points, row-wise corresponding to ``p1``.

    Returns
    -------
    t : ndarray, shape (n + 1, n + 1)
        Homogeneous affine transformation matrix from ``p1`` to ``p2``; its last row
        is ``[0, ..., 0, 1]``.

    Raises
    ------
    ValueError
        If ``p1``/``p2`` are not 2-D, differ in number of points or dimensions, or
        contain fewer points than dimensions.
    """
    p1, p2 = np.asarray(p1), np.asarray(p2)
    if p1.ndim != 2 or p2.ndim != 2:
        raise ValueError('p1 and p2 must be 2-D arrays of shape (points, dimensions)')

    p, m = p1.shape
    q, n = p2.shape

    if p != q:
        raise ValueError('number of points for p1 and p2 must be the same')

    if m != n:
        raise ValueError('number of dimensions for p1 and p2 must be the same')

    if p < n:
        raise ValueError(f'cannot solve transformation with fewer anchor points ({p}) than dimensions ({n}).')

    # Homogeneous design matrix [p1 | 1]: the trailing column of ones carries the translation.
    a = np.hstack((p1, np.ones((p, 1))))

    # Solve for all n output coordinates at once: column k of the (n+1, n) solution
    # is row k of the transform. pinv(a) is computed a single time instead of once per dimension.
    t = np.eye(n + 1)
    t[:n, :] = (np.linalg.pinv(a) @ p2).T
    return t


T = gen_xform_from_pts(p1=ninja_ref_pos, p2=individual_ref_pos)  # affine: NinjaCap landmarks -> Colin27 landmarks

## Apply transform and snap optodes

In [None]:
# Apply the transform to the probe's optode positions.
ninja_aligned = ninjacap_optodes.points.apply_transform(T)
if isinstance(T, np.ndarray):
    # A plain numpy matrix carries no coordinate-system label, so the result is
    # relabelled with the scalp's CRS by hand.
    # NOTE(review): presumably transforms from cedalion's registration functions
    # set the CRS themselves, which is why they skip this branch — confirm.
    ninja_aligned = ninja_aligned.rename({ninja_aligned.points.crs: colin.scalp.crs})

# Transformed optodes over a semi-transparent scalp. The variable is named `plotter`
# rather than `plt` to avoid shadowing the conventional matplotlib alias.
plotter = pv.Plotter()
cp.plot_surface(plotter, colin.scalp, opacity=0.1)
cp.plot_labeled_points(plotter, ninja_aligned)
plotter.show()

# Snap the optodes onto the scalp surface.
ninja_snapped_aligned = colin.scalp.snap(ninja_aligned)

# Snapped optodes on the opaque scalp.
plotter = pv.Plotter()
cp.plot_surface(plotter, colin.scalp)
cp.plot_labeled_points(plotter, ninja_snapped_aligned)
plotter.show()

In [None]:
# Forward model from the Colin27 head model, the snapped optode positions
# and the NinjaCap measurement list.
fwm = fw.ForwardModel(colin, ninja_snapped_aligned, meas_list)