In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import pandas as pd
import glob
import os
import sys

In [None]:
# Recordings to process, each named [mouse_name]_[date].
# One insertion per mouse per day is assumed.
recordings = [
    "mousename1_date1",
    "mousename2_date2",
]
sessions = []

# Per-mouse directory holding the histology points / channel position map data.
histology_directory = {
    "mousename1": "/path/to/registered/histology/",
    "mousename2": "/path/to/registered/histology/",
}

# Per-mouse processed data directory. Inside it, one recording folder per date
# is expected, matching the glob: *{curr_name}*{curr_date}*
processed_data_directory = {
    "mousename1": "/path/to/catgt/kilosort/tprime/and/ecephys/processed/data/mousename1/",
    "mousename2": "/path/to/catgt/kilosort/tprime/and/ecephys/processed/data/mousename2/",
}

# Replace the path below with the directory that should receive the results of
# placing the probes into the atlas.
output_directory = os.path.join('/directory/to/export/atlas/probe/locations/', 'hist_export_data')

In [None]:
# Size of one atlas pixel along each of the three axes, in microns.
atlas_pixel_size = [10.0, 10.0, 10.0]
# Micron -> pixel conversion factor; derived from the first axis only.
pixels_per_micron = 1.0 / atlas_pixel_size[0]

def compute_tract(pos):
    """Fit a straight tract through a set of traced points.

    The point with the largest coordinate in column ``z_idx`` (column 1) is
    used as the tract origin ``r0`` and the point with the smallest coordinate
    in that column as the top of the tract. The direction is the first
    right-singular vector of the points taken relative to ``r0``, oriented so
    that it points from ``r0`` towards the top of the tract.

    Args:
        pos: (N, D) array-like of traced points, N >= 1.
    Returns:
        tract: (T, D) array of points sampled along the fitted line at unit
            steps starting at ``r0``; T is the number of values in
            ``np.arange(tmax)`` where ``tmax`` is the r0-to-top distance
            divided by the norm of ``V``.
        r0: (1, D) array, the tract origin.
        V: (D,) direction vector.
    Raises:
        ValueError: if ``pos`` is not a non-empty 2-D array.
    """
    pos = np.asarray(pos)
    if pos.ndim != 2 or pos.shape[0] == 0:
        raise ValueError(
            f"pos must be a non-empty 2-D array of points, got shape {pos.shape}")

    z_idx = 1
    # argmax/argmin select exactly one row even when several points tie on the
    # extreme value (the first such row wins). A boolean mask would return
    # every tied row, which breaks the broadcasting in `pos - r0` and inflates
    # the tract length. The list index keeps the (1, D) shape.
    r0 = pos[[np.argmax(pos[:, z_idx])]]
    top_of_tract = pos[[np.argmin(pos[:, z_idx])]]

    rel_pos = pos - r0
    # Only the right-singular vectors are used, so skip building the N x N U.
    _, _, Vt = np.linalg.svd(rel_pos, full_matrices=False)
    V = Vt[0, :]

    tract_len = np.linalg.norm(top_of_tract - r0)
    tmax = tract_len / np.linalg.norm(V)
    # The sign of a singular vector is arbitrary: flip V if walking along it
    # from r0 ends up farther from the top of the tract than r0 itself is.
    if np.linalg.norm(top_of_tract - (tmax * V + r0)) > np.linalg.norm(top_of_tract - r0):
        V = -V

    tvals = np.arange(tmax)
    # Broadcast (T, 1) * (D,) + (1, D) -> (T, D)
    tract = tvals[:, np.newaxis] * V + r0
    return tract, r0, V

def transform_to_atlas(pos, r0, V, pixels_per_micron=pixels_per_micron):
    """
    Place positions measured along the electrode onto a tract in reference space.
    Args:
        pos: positions along the electrode in um; either a 1D array of depths
            or a 2D array, in which case only column 1 (depth) is used
        r0: origin of the tract the positions are measured from
        V: direction vector of the tract (normalised here before use)
        pixels_per_micron: factor converting um to reference-space pixels
    Returns:
        transformed_pos: Nx3 array, position of each input in reference space, in pixels
    """
    # for 2D coordinates only the depth column matters
    depth_um = pos if len(pos.shape) == 1 else pos[:, 1]
    depth_pix = depth_um * pixels_per_micron
    # three identical columns so each row can be scaled by the direction vector
    steps = np.column_stack((depth_pix, depth_pix, depth_pix))
    direction_norm = np.linalg.norm(V)
    return r0 + steps / direction_norm * V

In [None]:
# Load all experiments and transform all channel and unit positions into atlas
# coords, then save one .npz per recording into output_directory.

# np.savez does not create missing directories, so make sure the export folder exists
os.makedirs(output_directory, exist_ok=True)

for recording in recordings:

    name_parts = recording.split('_')
    curr_name = name_parts[0]
    curr_date = name_parts[1]

    print(f'Processing: {recording}...')

    # grab file list for trajectories of a given insertion (by date)
    # trajectory filename convention is: [imecN]_[date]_[tract_name].csv
    tract_files = sorted(glob.glob(os.path.join(histology_directory[curr_name], f"*{curr_date}*.csv")))
    if not tract_files:
        # fail here with a readable message instead of an opaque np.vstack error below
        raise FileNotFoundError(
            f"No trajectory files matching '*{curr_date}*.csv' found in "
            f"{histology_directory[curr_name]} for recording {recording}")
    tract_names = [os.path.basename(f).split('.csv')[0] for f in tract_files]

    transformed_chans = []
    transformed_locs = []
    chan_probe_id = []

    for tract_file, tract in zip(tract_files, tract_names):
        # get probe number
        idx = int(tract.split('_')[0].split('imec')[-1])

        ks_pattern = os.path.join(processed_data_directory[curr_name], f"*{curr_name}*{curr_date}*", f"{curr_name}_{curr_date}*imec{idx}", f"imec{idx}_ks2")
        # sorted so the choice is deterministic if more than one folder matches
        ks_matches = sorted(glob.glob(ks_pattern))
        if not ks_matches:
            raise FileNotFoundError(
                f"No sorted output directory found for {recording} imec{idx}; "
                f"searched: {ks_pattern}")
        ecephys_output_dir = ks_matches[0]

        ### Reconstruct trajectory from points -- BEGIN
        # comment out below and skip to load _transformed_chans if already have it

        # get traced points
        pts = np.array(pd.read_csv(tract_file)[["axis-0","axis-1","axis-2"]].astype('float64'))

        # get trajectory (only the origin and direction are needed, not the sampled line)
        r0, V = compute_tract(pts)[1:]

        # get channel positions
        chan_pos = np.load(os.path.join(ecephys_output_dir, "channel_positions.npy"))

        # one probe-number entry per channel
        _chan_probe_id = np.full(chan_pos.shape[0], idx)

        # transform channel positions to atlas using trajectory
        _transformed_chans = transform_to_atlas(chan_pos, r0, V)

        ### Reconstruct trajectory from points -- END

        # Uncomment to load existing channel-atlas locations
        # _transformed_chans = np.load("/path/to/transformed/chan/positions.npy")

        # transform clusters to atlas; prefer cluster_info.tsv, fall back to metrics.csv
        cluster_info_path = os.path.join(ecephys_output_dir, "cluster_info.tsv")
        if os.path.exists(cluster_info_path):
            metrics = pd.read_csv(cluster_info_path, sep="\t")
        else:
            metrics = pd.read_csv(os.path.join(ecephys_output_dir, "metrics.csv"))
        # NOTE(review): assumes "peak_channel" holds row indices into
        # channel_positions.npy -- TODO confirm for both metrics file types
        peak_cluster_channels = np.array(metrics["peak_channel"])
        _transformed_locs = _transformed_chans[peak_cluster_channels]

        transformed_chans.append(_transformed_chans)
        transformed_locs.append(_transformed_locs)
        chan_probe_id.append(_chan_probe_id)

    transformed_chans = np.vstack(transformed_chans)
    transformed_locs = np.vstack(transformed_locs)
    chan_probe_id = np.hstack(chan_probe_id)

    # NOTE(review): `recording` already starts with the mouse name, so the file
    # name repeats it (e.g. mousename1_mousename1_date1.npz). Left unchanged so
    # existing exports and any downstream loaders keep working.
    np.savez(os.path.join(output_directory, f"{curr_name}_{recording}.npz"),
             transformed_locs=transformed_locs, transformed_chans=transformed_chans, chan_probe_id=chan_probe_id)
print('Done!')