# Patch motion trajectories to per-particle trajectories

This notebook derives per-particle trajectories from Patch Motion Correction trajectories by evaluating the patch motion spline at each particle's pick location.

## Setup
We begin by importing libraries, connecting to the CryoSPARC instance, and defining the functions which will interpolate the patch motion splines.

In [None]:
import json
from pathlib import Path
import os
import io

import numpy as np
from scipy.ndimage import map_coordinates

from cryosparc.tools import CryoSPARC
from cryosparc import dataset

# Keyword arguments for the CryoSPARC client (host, credentials, ...) are kept
# in a JSON file in the user's home directory.
instance_info = json.loads(Path("~/instance-info.json").expanduser().read_text())

cs = CryoSPARC(**instance_info)
assert cs.test_connection()

def spline_interp(gridshape, spl, pix):
    """Evaluate a 3-D spline coefficient grid at positions given in grid units.

    Args:
        gridshape: (z, y, x) extent of the grid that ``pix`` is expressed in.
        spl: 3-D array of spline coefficients, shape (K_Z, K_Y, K_X). It is
            passed to ``map_coordinates`` with ``prefilter=False``, so the
            values are used directly as spline coefficients (default order 3).
        pix: sample positions with the (z, y, x) coordinates stacked along
            the first axis, i.e. shape (3, ...).

    Returns:
        Array of shape ``pix.shape[1:]`` with the interpolated values.
    """
    kz, ky, kx = spl.shape
    gz, gy, gx = gridshape

    # Map grid-index units onto control-point index units on each axis, so
    # grid index 0 -> knot 0 and grid index (g - 1) -> knot (k - 1).
    ratios = [
        (n_knots - 1) / float(max(n_grid - 1, 1))
        for n_knots, n_grid in ((kz, gz), (ky, gy), (kx, gx))
    ]
    scale = np.array(ratios, np.float32).reshape((3,) + (1,) * (pix.ndim - 1))

    # Pad one coefficient per side with odd reflection (linear continuation of
    # the edge values); the +1 below shifts positions into the padded frame.
    padded = np.pad(spl, 1, "reflect", reflect_type="odd")
    return map_coordinates(padded, pix * scale + 1, mode="constant", prefilter=False)

def spline_interp_traj(gridshape, splxy, pos):
    """Sample a two-component motion spline over every frame at each position.

    Args:
        gridshape: (frames, height, width) of the grid the positions refer to.
        splxy: pair of 3-D spline coefficient arrays; components 0 and 1 are
            presumably the x and y displacements (TODO confirm).
        pos: (N, 2) array of (x, y) positions in grid pixels.

    Returns:
        float32 array of shape (N, frames, 2); ``[..., k]`` is sampled from
        ``splxy[k]``.
    """
    n_frames, _, _ = gridshape
    n_points = pos.shape[0]

    # Sample positions laid out as (axis, point, frame), axes ordered (z, y, x):
    # every point is evaluated at every frame index.
    sample = np.zeros((3, n_points, n_frames), dtype=np.float32)
    sample[0] = np.arange(n_frames)[np.newaxis, :]
    sample[1] = pos[:, 1][:, np.newaxis]
    sample[2] = pos[:, 0][:, np.newaxis]

    traj = np.empty((n_points, n_frames, 2), np.float32)
    for component in (0, 1):
        traj[:, :, component] = spline_interp(gridshape, splxy[component], sample)
    return traj

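Next, we define the project, the destination workspace, and the jobs from which we will load data.
The particles with local trajectories will be saved in a new external job in `dest_workspace`.
`particle_source` provides the particle picks (their `location` fields), and `micrograph_source` must output exposures with `movie_blob`, `spline_motion`, and `rigid_motion` slots.
Ideally, `micrograph_source` is a Patch CTF Estimation job that is downstream of a Patch Motion Correction job.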

In [None]:
# Project that holds both source jobs and the destination workspace.
project_uuid = "P423"

# Inputs: particle picks, and micrographs carrying Patch Motion results.
particle_source = "J45"
micrograph_source = "J29"  # <-- ideally a Patch CTF job downstream of Patch Motion Correction

# Output: particles with local trajectories will be saved to this workspace.
dest_workspace = "W4"

## Creating an external job
We now create the [external job](https://tools.cryosparc.com/api/job.html#cryosparc.job.ExternalJob) which will hold the particles with updated trajectories.
Connecting the particles and micrographs as inputs has two main advantages:

1. The external job is in the correct position in the tree view
2. The input particles' fields will be [passed through](https://guide.cryosparc.com/guides-for-v3/job-builder-tutorial#passthrough-results). This makes the output dataset much smaller.

We also add new fields to the particles dataset.
These fields will store the local motion trajectories.
Currently, the best way to determine what fields are necessary for a given result is to look at a job that produces that result.
In this case, you could inspect the output of a Local Motion Correction job to determine that these fields are necessary.

In [None]:
project = cs.find_project(project_uuid)
job = project.create_external_job(dest_workspace, title="Patch to Local")

# Slots each input must provide: the movie (shape, pixel size), the patch
# spline, and the rigid trajectory for micrographs; pick locations for particles.
mic_slots = ["movie_blob", "spline_motion", "rigid_motion"]
pcl_slots = ["location"]

job.connect("micrographs", micrograph_source, "exposures", slots=mic_slots)
job.connect("particles", particle_source, "particles", slots=pcl_slots)
# Only the new "motion" slot is produced here; all other particle fields are
# passed through from the connected particles input.
job.add_output("particle", "particles", slots=["motion"], passthrough="particles")

pcls = job.load_input("particles", pcl_slots)
mics = job.load_input("micrographs", mic_slots)

# Trajectory files are uploaded as "trajs/<micrograph uid>.npy" inside the job
# directory, so the directory created here must carry the same name.
job.mkdir("trajs")

# Fields describing the per-particle trajectory, mirroring the output of a
# Local Motion Correction job.
pcls.add_fields([
    ('motion/type', 'O'),
    ('motion/path', 'O'),
    ('motion/idx', 'u4'),
    ('motion/frame_start', 'u4'),
    ('motion/frame_end', 'u4'),
    ('motion/zero_shift_frame', 'u4'),
    ('motion/psize_A', 'f4'),
])

# Group particles by the micrograph they were picked on.
particle_subsets = pcls.split_by('location/micrograph_uid')
particle_outsubsets = []

## Running the job
We now perform the necessary computation to translate the patch motion splines into local motion trajectories.
Doing this calculation with a [context manager](https://realpython.com/python-with-statement/#the-with-statement) provides two benefits:
1. We don't need to remember to mark the job as "Done" -- this happens automatically when we exit the `with` block.
2. If there is an error during execution, the job is automatically marked as "Failed" and the Python error is added to the job's log.


In [None]:
def _download_npy(project, path):
    """Download a ``.npy`` file from the project directory and return the array.

    The stream is buffered into memory because ``np.load`` requires a file
    object that supports ``seek``.
    """
    with project.download(path) as stream:
        return np.load(io.BytesIO(stream.read()))


def _upload_npy(job, rel_path, array):
    """Save ``array`` in ``.npy`` format to ``rel_path`` inside the job directory."""
    buffer = io.BytesIO()
    np.save(buffer, array)
    buffer.seek(0)
    job.upload(rel_path, buffer)


with job.run():
    # Iterate the micrograph rows directly instead of re-masking the whole
    # dataset for every uid (which was quadratic in the number of micrographs).
    for movie in mics.rows():
        uid = movie["uid"]
        if uid not in particle_subsets:
            continue  # no particles were picked on this micrograph

        subset = particle_subsets[uid]
        particle_outsubsets.append(subset)
        mpsz = movie["movie_blob/psize_A"]

        # The spline and rigid trajectories are summed into one trajectory
        # below, so both must be expressed at the movie's pixel size.
        tpsz_spl = movie["spline_motion/psize_A"]
        tpsz_rig = movie["rigid_motion/psize_A"]
        assert tpsz_spl != 0
        assert tpsz_rig != 0
        assert tpsz_rig == tpsz_spl
        assert tpsz_rig == mpsz

        splxy = _download_npy(project, movie["spline_motion/path"])
        t_rigid = _download_npy(project, movie["rigid_motion/path"])

        N_Z, ny, nx = movie["movie_blob/shape"]
        # The written frame range is always [0, N_Z); check before uploading.
        assert movie["rigid_motion/frame_start"] == 0
        assert movie["rigid_motion/frame_end"] == N_Z

        # pos_raw: particle centres as (x, y) integer pixel positions in the
        # raw movie (fractional coordinates scaled by width/height, truncated).
        pcl_coords = np.zeros((len(subset), 2))
        pcl_coords[:, 0] = subset["location/center_x_frac"] * nx
        pcl_coords[:, 1] = subset["location/center_y_frac"] * ny
        pos_raw = pcl_coords.astype(np.int32)

        # Absolute trajectory per particle: the spline sampled at the
        # particle's position plus the rigid trajectory shared by every
        # particle. Shape (n_particles, N_Z, 2).
        gridshape = (N_Z, ny, nx)  # shape of the input movie
        ts = (spline_interp_traj(gridshape, splxy, pos_raw) + t_rigid.reshape(1, N_Z, 2)).astype(np.float32)

        rel_path = f"trajs/{uid}.npy"
        _upload_npy(job, rel_path, ts)

        subset["motion/type"][:] = "particle"
        # Row i of the uploaded trajectory file belongs to particle i of this subset.
        subset["motion/idx"][:] = np.arange(len(subset))
        subset["motion/zero_shift_frame"][:] = 0  # NOT USED
        subset["motion/psize_A"][:] = mpsz
        # Prefixed with the job uid, matching the form of the motion paths
        # downloaded above (presumably relative to the project directory).
        subset["motion/path"][:] = os.path.join(job.uid, rel_path)
        subset["motion/frame_start"][:] = 0
        subset["motion/frame_end"][:] = N_Z

    # Fail the job with a clear message rather than an opaque error from
    # append_many() when no particle was picked on any input micrograph.
    if not particle_outsubsets:
        raise ValueError("No input particles belong to any of the input micrographs")

    full_particle_dset = dataset.Dataset.append_many(*particle_outsubsets)
    full_particle_dset = full_particle_dset.filter_prefixes(['motion', 'location'])

    job.save_output("particles", full_particle_dset)