# Correlative SIMS-EM

In this notebook, data from time-of-flight secondary ion mass spectrometry (ToF-SIMS) and from electron microscopy (EM) are aligned.

To do this, the focused ion beam (FIB) image contained in the ToF data set is extracted and manually aligned to the EM image.

In [None]:
import mrcfile
import xarray as xr
import numpy as np
import twtof
import matplotlib.pyplot as plt


def mrcread(filename):
    """Load an MRC file as a 2D xarray with physical coordinates.

    Parameters
    ----------
    filename : str or path-like
        Path of the MRC file to read.

    Returns
    -------
    xr.DataArray
        The image with dims ("Y", "X"). Coordinates start at 0 and are spaced
        by the file's voxel size divided by 10000 (MRC voxel sizes are in
        angstroms and 1 micrometer = 10**4 angstroms), then quantified in
        micrometers.
    """
    # Open the requested file (previously the notebook-global ``em_path`` was
    # opened regardless of the argument).
    with mrcfile.open(filename) as mrc:
        data = np.asarray(mrc.data)
        da = xr.DataArray(
            data=data,
            # Explicit dims rather than relying on inference from the coords dict.
            dims=("Y", "X"),
            coords={
                "Y": np.arange(data.shape[0]) * mrc.voxel_size.y / 10000,
                "X": np.arange(data.shape[1]) * mrc.voxel_size.x / 10000,
            },
        )
        # NOTE(review): the ``.pint`` accessor is registered by pint_xarray,
        # which this cell does not import directly -- presumably pulled in via
        # twtof; confirm.
        da = da.pint.quantify({"Y": "micrometer", "X": "micrometer"})
        return da


def xrshow(data):
    """Display a 2D xarray as a grayscale image with physical axes.

    The "X" and "Y" coordinates set the image extent, and their units (after
    dequantification) are used in the axis labels.

    imshow's default ``origin="upper"`` draws row 0 at the top. The extent is
    therefore (X[0], X[-1], Y[-1], Y[0]), so the top edge carries the first Y
    coordinate and the tick labels match the data. The y axis then increases
    downwards, as in napari.
    """
    img = data.pint.dequantify(format="~P")
    x_coords, y_coords = img["X"], img["Y"]
    plt.imshow(
        img,
        extent=(
            float(x_coords[0]),
            float(x_coords[-1]),
            float(y_coords[-1]),
            float(y_coords[0]),
        ),
        cmap="gray",
    )
    plt.xlabel(f"X [{img.X.units}]")
    plt.ylabel(f"Y [{img.Y.units}]")


em_path = "data/20241016_HO_134_1_S1_cell41.mrc"
# Reuse em_path rather than repeating the literal, so the two cannot drift apart.
em_data = mrcread(em_path)

sims_path = "data/20241017_Fluorine_HO_134_1_spot8_run1_30kV_50pA.h5"
fib_data = twtof.fibread(sims_path)

# Side-by-side overview: EM image (left) and the mean of entries 2 and 3 along
# the first axis of the FIB data (right).
plt.subplot(121)
xrshow(em_data)
plt.title(f"EM {em_data.shape}")
plt.subplot(122)
xrshow(fib_data[2:4].mean(axis=0))
plt.title(f"FIB {fib_data.shape}");

In [None]:
import napari
from skimage import transform

class CornerAligner:
    """Interactively align a moving image onto a reference image in napari.

    Four draggable points are placed on the corners of the moving image.
    While a point is dragged, an affine transform is fitted from the original
    corner positions to the current point positions and applied to the moving
    layer, so the moving image follows the corners. The last successful fit
    is available as ``self.tfm``.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer the layers are added to.
    reference : xr.DataArray
        Fixed image, shown in gray. It may have extra leading dimensions
        (e.g. a stack of frames).
    moving : xr.DataArray
        Image with "Y" and "X" coordinates, shown in green and warped.
    """

    def __init__(self, viewer, reference, moving):
        self.moving = moving
        self.reference_layer = viewer.add_image(
            **self.xarray2napari(reference), name="reference", colormap="gray"
        )
        self.moving_layer = viewer.add_image(
            **self.xarray2napari(moving), opacity=0.75, name="moving", colormap="green"
        )
        # Corner positions in (Y, X) coordinates of the moving image. They stay
        # unchanged and serve as the source points of every affine fit.
        self.original_points = self.get_corners(moving)
        self.points_layer = viewer.add_points(
            self.original_points.copy(), size=0.4, face_color="blue"
        )
        # Create the transform before registering the callback that uses it.
        self.tfm = transform.AffineTransform()
        self.points_layer.mouse_drag_callbacks.append(self.update)

    def xarray2napari(self, data):
        """Convert an xarray into keyword arguments for ``viewer.add_image``.

        Units are stripped. The pixel spacing of each dimension (difference of
        its first two coordinate values) is passed as napari ``scale``.

        NOTE(review): no ``translate`` is passed, so the image is drawn as if
        its coordinates start at 0. ``get_corners`` uses the actual coordinate
        values, so the two only agree when coordinates start at 0 (as produced
        by ``mrcread``).
        """
        data = data.pint.dequantify("~P")
        scale = [data[d][1].item() - data[d][0].item() for d in data.dims]
        return {"data": data.to_numpy(), "scale": scale}

    def get_corners(self, data: xr.DataArray):
        """Return the four corner coordinates of ``data`` as a (4, 2) array.

        Rows are (Y, X) pairs built from the first and last "Y" and "X"
        coordinate values, in the order (Y0, X0), (Y1, X0), (Y0, X1), (Y1, X1).
        """
        X = data["X"].to_numpy()
        Y = data["Y"].to_numpy()
        return np.stack(
            np.meshgrid(
                np.linspace(Y[0], Y[-1], 2), np.linspace(X[0], X[-1], 2), indexing="ij"
            )
        ).T.reshape(-1, 2)

    def update(self, points_layer, event):
        """Mouse-drag callback: refit and apply the affine while a corner moves.

        Written as a generator, following napari's mouse-callback protocol: the
        code before the first ``yield`` runs on mouse press, and the loop
        resumes on each subsequent mouse-move event.

        Adapted from https://napari.org/dev/gallery/mouse_drag_image_warping.html
        """
        yield  # mouse press: nothing to do yet
        while event.type == "mouse_move":
            # Only react when a point is selected and the layer is in select mode.
            if not points_layer.selected_data or points_layer.mode != "select":
                return
            # Index of the dragged point: the last entry of the selection.
            moved_point_index = list(points_layer.selected_data)[-1]
            # Work on a copy so the layer's data is not modified in place.
            dst = points_layer.data.copy()
            # The cursor position has one entry per viewer dimension, but the
            # points are 2D (Y, X). Keep the last two entries, which also drops
            # the leading frame index when the reference is a stack.
            dst[moved_point_index] = event.position[-2:]
            # Fit into a fresh transform; estimate() returns False when the fit
            # fails (e.g. a degenerate point configuration). In that case, keep
            # the last valid transform instead of pushing an invalid matrix
            # into the layers.
            candidate = transform.AffineTransform()
            if candidate.estimate(self.original_points, dst):
                self.tfm = candidate
                self.moving_layer.affine = self.tfm.params
                self.moving_layer.refresh()
                # Snap the points to where the fitted affine maps the corners.
                points_layer.data = self.tfm(self.original_points)
            yield


# Open a napari window and overlay the EM image (moving, with draggable
# corners) on the FIB data (reference).
viewer = napari.Viewer()
aligner = CornerAligner(viewer, reference=fib_data, moving=em_data)


In [None]:
# Resample the EM image onto the FIB pixel grid.
# The arrays are transposed before warping, so skimage's (col, row) coordinate
# order becomes (Y, X). That is the (row, col) order napari used when fitting
# aligner.tfm, so every scale below must be listed Y first.
# Assumes twtof.get_scale(data, dim) returns the pixel spacing along `dim` and
# that FIB frames are (Y, X) like em_data -- TODO confirm.
S1 = transform.AffineTransform(
    scale=[twtof.get_scale(em_data, "Y"), twtof.get_scale(em_data, "X")]
).params  # EM pixel -> EM physical
R = aligner.tfm.params  # EM physical -> FIB physical (from the napari alignment)
S2 = transform.AffineTransform(
    scale=[twtof.get_scale(fib_data, "Y"), twtof.get_scale(fib_data, "X")]
).params  # FIB pixel -> FIB physical
# warp() takes the inverse map: output (FIB) pixel -> input (EM) pixel.
T = transform.AffineTransform(matrix=np.linalg.inv(S1) @ np.linalg.inv(R) @ S2)
# The output shape is given in the transposed layout and transposed back, so
# em_aligned has exactly the shape of a FIB frame (also for non-square frames).
em_aligned = transform.warp(
    em_data.as_numpy().T, T, output_shape=fib_data[0].shape[::-1]
).T

def stretch_contrast(x, low=1, high=99):
    """Linearly rescale intensities so two percentiles map to 0 and 1.

    Parameters
    ----------
    x : numpy.ndarray or xarray object
        Image to rescale.
    low, high : float, optional
        Percentiles (0-100) mapped to 0 and 1 respectively (default 1 and 99).

    Returns
    -------
    Float data of the same kind as the input. xarray input is converted with
    ``as_numpy()`` first; otherwise the result is a NumPy array. Values are
    not clipped, so pixels outside the percentile range fall below 0 or above
    1. If both percentiles are equal (e.g. a constant image), the data is only
    shifted, giving zeros instead of a division by zero.
    """
    a, b = np.percentile(x, [low, high])
    span = b - a
    if span == 0:
        span = 1.0
    # Duck-typed xarray check: keep the labels but force NumPy-backed data.
    as_numpy = getattr(x, "as_numpy", None)
    data = as_numpy() if callable(as_numpy) else x
    return (data.astype(float) - a) / span

# Overlay: the FIB frame (u) goes into red, green and blue, so it appears gray.
# The aligned EM image (v) can only raise the green channel, so it shows green
# where it is brighter than the FIB frame.
u, v = stretch_contrast(fib_data[5]), stretch_contrast(em_aligned)
rgb = np.stack([u, np.maximum(u, v), u], -1)
# imshow expects float RGB in [0, 1]. Percentile stretching leaves roughly the
# darkest and brightest 1% of pixels outside that range, so clip explicitly
# rather than relying on imshow's clipping, which emits a warning.
plt.imshow(np.clip(rgb, 0, 1))
