In [1]:
import pandas as pd
from src.parameters import CM_PER_PIXEL
from src.load_data import flip_y

hex_coords = pd.read_csv("../Raw-Data/hex_center_coordinates_IM-1594_07252023.csv")
hex_coords[["x", "y"]] = (
    flip_y(
        hex_coords[["x", "y"]].to_numpy(),
        hex_coords[["x", "y"]].max().to_numpy(),
    )
    * CM_PER_PIXEL
)

In [3]:
from src.load_data import get_spike_times, get_position_info

position_file_name = (
    "/Users/edeno/Documents/GitHub/tk_decoding/Raw-Data/IM-1594_07252023_position.csv"
)
spike_file_name = "/Users/edeno/Documents/GitHub/tk_decoding/Raw-Data/IM-1594_07252023_spikesWithPosition.csv"

position_info = get_position_info(position_file_name)
spike_times = get_spike_times(spike_file_name)

In [17]:
import numpy as np


time = position_info.index.values
st = [
    time[spike_time_ind.values]
    for _, spike_time_ind in spike_times.groupby(["channel", "cluster_ID"])
]
st[0]

array([[2.160000e-01],
       [9.638520e+02],
       [9.668040e+02],
       ...,
       [6.454960e+03],
       [6.454992e+03],
       [6.455160e+03]])

In [2]:
from src.load_data import load_data


position_info, spikes, multiunit_firing_rate, multiunit_HSE_times = load_data(
    position_file_name="/Users/edeno/Documents/GitHub/tk_decoding/Raw-Data/IM-1594_07252023_position.csv",
    spike_file_name="/Users/edeno/Documents/GitHub/tk_decoding/Raw-Data/IM-1594_07252023_spikesWithPosition.csv",
)

  is_start_time = (~series.shift(1).fillna(False)) & series
  is_end_time = series & (~series.shift(-1).fillna(False))


In [None]:
from src.load_data import make_track_graph, get_auto_linear_edge_order_spacing
import matplotlib.pyplot as plt
from track_linearization import plot_track_graph, plot_graph_as_1D

track_graph = make_track_graph(position_info, hex_coords)
linear_edge_order, linear_edge_spacing = get_auto_linear_edge_order_spacing(track_graph)


fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(position_info["x"], position_info["y"], color="lightgrey", alpha=0.5)
plot_track_graph(track_graph, ax=ax)


fig, ax = plt.subplots(figsize=(25, 1))
plot_graph_as_1D(track_graph, linear_edge_order, linear_edge_spacing, ax=ax)

In [None]:
from track_linearization import get_linearized_position

linear_position_info = get_linearized_position(
    position_info[["x", "y"]].to_numpy(),
    track_graph,
    edge_order=linear_edge_order,
    edge_spacing=linear_edge_spacing,
    use_HMM=False,
)

linear_position_info

In [5]:
from src.load_data import determine_if_centrifugal

track_segment_id = linear_position_info["track_segment_id"].to_numpy().astype(int)
head_direction = position_info["head_direction"].to_numpy()

is_centrifugal, centrifugal_edges = determine_if_centrifugal(
    track_graph, track_segment_id, head_direction
)

In [None]:
from non_local_detector import ContFragSortedSpikesClassifier, Environment
from src.parameters import SAMPLING_FREQUENCY

# Cut out first 20 s because animal is being placed on track
start_ind = int(20.0 * SAMPLING_FREQUENCY)
position_info = position_info.iloc[start_ind:]
spikes = spikes.iloc[start_ind:]
multiunit_firing_rate = multiunit_firing_rate.iloc[start_ind:]

spike_times = [
    spikes.index.values[spikes.values[:, i].astype(bool)]
    for i in range(spikes.shape[1])
]

# environment = Environment(place_bin_size=2.0)

classifier = ContFragSortedSpikesClassifier()

classifier.fit(
    position_time=position_info.index.values,
    position=position_info[["x", "y"]].to_numpy(),
    spike_times=spike_times,
)

In [None]:
results = classifier.predict(
    spike_times=spike_times,
    time=position_info.index.values,
    n_chunks=10,
)

In [4]:
spike_times2 = [
    spikes.index.values[spikes.values[:, i].astype(bool)]
    for i in range(spikes.shape[1])
]
spike_times2

[array([ 963.852,  966.804, 1620.532, ..., 6454.96 , 6454.992, 6455.16 ]),
 array([4.808000e+00, 5.308000e+00, 5.872000e+00, ..., 6.456320e+03,
        6.456520e+03, 6.456748e+03]),
 array([2.932000e+00, 2.944000e+00, 3.076000e+00, ..., 6.452072e+03,
        6.452192e+03, 6.456020e+03]),
 array([4.544000e+00, 4.548000e+00, 4.928000e+00, ..., 6.449040e+03,
        6.454416e+03, 6.454424e+03]),
 array([5.132000e+00, 5.260000e+00, 5.500000e+00, ..., 6.449724e+03,
        6.449736e+03, 6.452352e+03]),
 array([3.428000e+00, 4.372000e+00, 1.130000e+01, ..., 6.456568e+03,
        6.456732e+03, 6.456752e+03]),
 array([  49.464,   49.776,   49.868, ..., 6419.384, 6419.624, 6443.616]),
 array([  20.472,   21.284,   25.072, ..., 6452.396, 6452.4  , 6456.776]),
 array([   8.448,    8.7  ,    8.704,    8.708,   49.904,   50.124,
          50.148,   50.192,   52.192,   53.06 ,   56.96 ,   61.876,
          68.092,   68.096,   68.528,   68.532,   69.464,   69.556,
          69.56 ,   69.576,   69.776