In [1]:
import pandas as pd
import os
import numpy as np
from open_ephys.analysis import Session

In [2]:
def align_ephys_data(main_processor_tuple, aux_processor_tuples, session_path=None, synch_channel=1):
    """Open an Open Ephys session and register its sync lines on the single recording.

    The session must contain exactly one record node holding exactly one
    recording. The main sync line is registered first, followed by one
    auxiliary sync line per entry of ``aux_processor_tuples``. This function
    only registers the sync lines; it does not call
    ``compute_global_timestamps()`` on the returned recording.

    Parameters
    ----------
    main_processor_tuple : sequence
        Element 0 and element 1 are forwarded as the second and third
        positional arguments of ``recording.add_sync_line`` with ``main=True``
        (the notebook passes a processor ID and a stream name, e.g.
        ``(102, '0')``).
    aux_processor_tuples : iterable of sequences
        Same layout as ``main_processor_tuple``; each one is registered with
        ``main=False``. May be empty.
    session_path : str or path-like
        Folder of the Open Ephys session. Required, despite the ``None``
        default kept for signature compatibility.
    synch_channel : int, default 1
        Forwarded as the first positional argument of ``add_sync_line`` for
        every processor.

    Returns
    -------
    The session's only recording object, with the sync lines registered.

    Raises
    ------
    ValueError
        If ``session_path`` is ``None``, or the session does not contain
        exactly one record node with exactly one recording.
    """
    if session_path is None:
        # str(None) would otherwise try to open a session literally named "None".
        raise ValueError("session_path must be provided.")

    session_data = Session(str(session_path))
    if len(session_data.recordnodes) != 1:
        raise ValueError("should be exactly one record node.")
    if len(session_data.recordnodes[0].recordings) != 1:
        raise ValueError("Should be exactly one recording.")

    # The checks above guarantee a single recording, so index it directly.
    recording = session_data.recordnodes[0].recordings[0]

    recording.add_sync_line(
        synch_channel,
        main_processor_tuple[0],
        main_processor_tuple[1],
        main=True,
    )
    for aux_processor in aux_processor_tuples:
        recording.add_sync_line(
            synch_channel,
            aux_processor[0],
            aux_processor[1],
            main=False,
        )
    return recording

In [3]:
# Per-recording folder; the Kilosort3 output is read from its "ephys/kilosort3" subfolder below.
organized_data_folder = "/ceph/sjones/projects/sequence_squad/organised_data/animals/EJT178_implant1/recording7_30-03-2022/"
# Folder holding all_trajectories.csv; the spikes/position/times .npy files are also written here.
processed_data_folder = "/nfs/gatsbystor/nicholasg/striatal_replay/processed_data"
# Open Ephys session folder that is handed to open_ephys.analysis.Session below.
open_ephys_folder = "/ceph/sjones/projects/sequence_squad/data/raw_neuropixel/OE_DATA/EJT178/300322/2022-03-30_13-48-39/"

In [4]:
# Read the per-frame trajectory table; the first CSV column is used as the row index.
trajectories_csv_path = os.path.join(processed_data_folder, "all_trajectories.csv")
all_trajectories = pd.read_csv(trajectories_csv_path, index_col=0)
all_trajectories

Unnamed: 0,trial_id,camera_idx,camera_time,ephys_time,linear_position,x_position,y_position,port1,port2
22191,1,22191,369.857173,4800.884355,0.000000,670.492310,513.852478,2,1
22192,1,22192,369.873841,4800.901010,-0.607927,671.156128,512.822449,2,1
22193,1,22193,369.890509,4800.917665,-0.010407,670.305847,513.338928,2,1
22194,1,22194,369.907178,4800.934320,0.082560,670.071716,513.164001,2,1
22195,1,22195,369.923846,4800.950974,0.010824,670.314087,513.453674,2,1
...,...,...,...,...,...,...,...,...,...
163228,276,163228,2720.473922,7150.310765,597.237558,650.354614,576.198364,3,7
163229,276,163229,2720.490593,7150.327440,606.095453,654.041809,570.472900,3,7
163230,276,163230,2720.507265,7150.344115,616.064659,655.980164,564.733582,3,7
163231,276,163231,2720.523936,7150.360789,627.483287,657.201050,558.478149,3,7


In [5]:
Fs = 30000.0  # sample rate used to convert Kilosort spike sample indices into seconds
cutoff = 1_000_000  # maximum number of in-trial spikes kept


def _trial_spike_indices(sorted_spike_times, trial_windows, max_spikes):
    """Return indices of spikes falling in each [start, end) window, in window order.

    ``sorted_spike_times`` must be ascending. Collection stops as soon as
    ``max_spikes`` indices have been gathered, and the result is truncated to
    that length.
    """
    chunks = []
    collected = 0
    for start, end in trial_windows:
        first = int(np.searchsorted(sorted_spike_times, start, side="left"))
        last = int(np.searchsorted(sorted_spike_times, end, side="left"))
        if last > first:
            chunks.append(np.arange(first, last))
            collected += last - first
        if collected >= max_spikes:
            break
    if not chunks:
        return np.empty(0, dtype=np.intp)
    return np.concatenate(chunks)[:max_spikes]


def _nearest_sample_index(sample_times, query_times):
    """For every query time, return the position of the closest sample time.

    Equivalent to ``np.abs(sample_times - t).argmin()`` evaluated per query,
    but done with a binary search. Duplicate sample times resolve to their
    first occurrence; an exact tie between two neighbours resolves to the
    earlier sample time.
    """
    sample_times = np.asarray(sample_times, dtype=float)
    query_times = np.asarray(query_times, dtype=float)
    if query_times.size == 0:
        return np.empty(0, dtype=np.intp)
    if sample_times.size == 0:
        raise ValueError("Cannot look up positions: there are no trajectory samples.")

    # A stable sort keeps the lookup valid even if the samples are not in time order.
    order = np.argsort(sample_times, kind="stable")
    sorted_times = sample_times[order]

    right = np.searchsorted(sorted_times, query_times, side="left")
    left = np.clip(right - 1, 0, sorted_times.size - 1)
    right = np.clip(right, 0, sorted_times.size - 1)
    left_is_closer = (
        np.abs(query_times - sorted_times[left]) <= np.abs(sorted_times[right] - query_times)
    )
    nearest = np.where(left_is_closer, left, right)
    # Collapse runs of identical timestamps onto their first occurrence.
    nearest = np.searchsorted(sorted_times, sorted_times[nearest], side="left")
    return order[nearest]


# --- Global ephys timestamps from the Open Ephys session ---------------------
session_data = Session(str(open_ephys_folder))
recording = session_data.recordnodes[0].recordings[0]
recording.add_sync_line(1, 102, '0', main=True)
recording.compute_global_timestamps()
# NOTE(review): continuous stream index 2 is hard-coded; confirm it is the
# stream the spike sorting was run on.
ephys_timestamps = recording.continuous[2].global_timestamps

# --- Kilosort output ----------------------------------------------------------
kilosort_folder = os.path.join(organized_data_folder, "ephys", "kilosort3")
# Spike sample indices -> seconds, shifted by the stream's first global timestamp.
spike_times = (np.load(os.path.join(kilosort_folder, "spike_times.npy")).ravel() / Fs) + ephys_timestamps[0]
spike_clusters = np.load(os.path.join(kilosort_folder, "spike_clusters.npy")).ravel()
if spike_times.shape != spike_clusters.shape:
    raise ValueError(
        f"spike_times {spike_times.shape} and spike_clusters {spike_clusters.shape} differ in length."
    )

# The window search below needs ascending spike times; reorder only if required.
if np.any(np.diff(spike_times) < 0):
    spike_order = np.argsort(spike_times, kind="stable")
    spike_times = spike_times[spike_order]
    spike_clusters = spike_clusters[spike_order]

# Map (possibly non-contiguous) cluster ids onto consecutive row numbers.
unique_clusters = np.unique(spike_clusters)
clusters_count = len(unique_clusters)
clusters_mapping = np.arange(np.max(spike_clusters) + 1)
clusters_mapping[unique_clusters] = np.arange(clusters_count)

# --- Trial windows: first/last ephys_time of each trial -----------------------
trial_start_times = all_trajectories.groupby(["trial_id"]).agg("first")["ephys_time"]
trial_start_times.name = "start_time"
trial_end_times = all_trajectories.groupby(["trial_id"]).agg("last")["ephys_time"]
trial_end_times.name = "end_time"

trial_ephys_times = pd.concat([trial_start_times, trial_end_times], axis=1)

# --- Spikes that fall inside trials (at most `cutoff` of them) ----------------
selected_spikes = _trial_spike_indices(
    spike_times,
    trial_ephys_times[["start_time", "end_time"]].to_numpy(),
    cutoff,
)
times = spike_times[selected_spikes]

# One column per selected spike, with a single 1 in the row of its cluster, so
# that column i of `spikes` describes the same spike as times[i] / position[i].
spikes = np.zeros((clusters_count, len(times)), dtype=int)
spikes[clusters_mapping[spike_clusters[selected_spikes]], np.arange(len(times))] = 1

# --- Linear position of the trajectory sample closest in time to each spike ---
trajectory_rows = _nearest_sample_index(all_trajectories["ephys_time"].to_numpy(), times)
position = all_trajectories["linear_position"].to_numpy(dtype=float)[trajectory_rows]

assert spikes.shape[1] == len(times) == len(position)

np.save(os.path.join(processed_data_folder, "spikes.npy"), spikes)
np.save(os.path.join(processed_data_folder, "position.npy"), position)
np.save(os.path.join(processed_data_folder, "times.npy"), times)
display(spikes)
display(position)
display(times)


Processor ID: 102, Stream Name: 0, Line: 1 (main sync line))
  First event sample number: 60920380
  Last event sample number: 512017786
  Total sync events: 15036
  Sample rate: 30000
100.00%

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

array([   0.        ,    0.        ,    0.        , ..., -265.65648489,
       -265.65648489, -265.65648489])

array([4800.88493333, 4800.88536667, 4800.88553333, ..., 6471.92076667,
       6471.92086667, 6471.92093333])

In [6]:
# Reload the three arrays saved to processed_data_folder and show them.
spikes, position, times = (
    np.load(os.path.join(processed_data_folder, saved_file))
    for saved_file in ("spikes.npy", "position.npy", "times.npy")
)
display(spikes, position, times)

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

array([   0.        ,    0.        ,    0.        , ..., -265.65648489,
       -265.65648489, -265.65648489])

array([4800.88493333, 4800.88536667, 4800.88553333, ..., 6471.92076667,
       6471.92086667, 6471.92093333])