# Learn

In [None]:
import numpy as np
import sys
import pandas as pd
import matplotlib.pyplot as plt
import pickle as pkl

import replay_trajectory_classification as rtc
import track_linearization as tl

  from tqdm.autonotebook import tqdm


In [None]:
# Every input and output of this notebook lives in one shared data directory.
decoder_data_dir = "../../../../datasets/decoder_data"

positions_filename = f"{decoder_data_dir}/position_info.pkl"
spikes_filename = f"{decoder_data_dir}/clusterless_spike_times.pkl"
features_filename = f"{decoder_data_dir}/clusterless_spike_features.pkl"
model_filename = f"{decoder_data_dir}/clusterless_spike_decoder.pkl"
decoding_filename = f"{decoder_data_dir}/clusterless_spike_decoding_results.pkl"

In [4]:
# Load tracked positions; the DataFrame index holds the sample times.
positions_df = pd.read_pickle(positions_filename)
timestamps = positions_df.index.to_numpy()

# Sampling interval and rate, taken from the first two samples
# (assumes uniform sampling).
dt = timestamps[1] - timestamps[0]
Fs = 1.0 / dt

# len(timestamps) + 1 bin edges centred on the timestamps: bin i spans
# [timestamps[i] - dt/2, timestamps[i] + dt/2).
half_dt = dt / 2
spikes_bins = np.concatenate((timestamps - half_dt, [timestamps[-1] + half_dt]))

In [5]:
# Inspect the position table. NOTE(review): in the saved output the last rows
# have every keypoint coordinate/velocity at exactly 0.0 — possibly tracking
# dropouts; confirm before trusting positions at the end of the session.
positions_df

Unnamed: 0_level_0,nose_x,nose_y,nose_vel,tailBase_x,tailBase_y,tailBase_vel,tailMid_x,tailMid_y,tailMid_vel,tailTip_x,...,hindpawR_vel,forelimb_mid_x,forelimb_mid_y,forelimb_vel,body_dir,linear_position,track_segment_id,projected_x_position,projected_y_position,arm_name
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
22389.082875,6.302628,5.231174,45.320770,12.856064,5.547496,48.290909,11.367913,6.833021,56.508186,16.296467,...,,5.080671,3.215346,125.707005,2.338472,162.285019,3,6.345101,8.359380,Left Arm
22389.084875,6.755276,5.644503,45.579716,13.503448,5.689286,48.483751,12.142827,7.315345,56.751256,17.387420,...,,5.194684,3.302497,124.976602,2.347791,162.732014,3,6.792055,8.353311,Left Arm
22389.086875,7.207925,6.057833,45.838661,14.150832,5.831076,48.676593,12.917740,7.797669,56.994325,18.478373,...,,5.308697,3.389648,124.246200,2.357110,163.179010,3,7.239009,8.347243,Left Arm
22389.088875,7.660573,6.471162,46.097607,14.798215,5.972866,48.869435,13.692654,8.279993,57.237395,19.569326,...,,5.422709,3.476798,123.515797,2.366429,163.626005,3,7.685963,8.341174,Left Arm
22389.090875,8.113222,6.884492,46.356552,15.445599,6.114657,49.062277,14.467568,8.762317,57.480464,20.660279,...,,5.536722,3.563949,122.785394,2.375748,164.073000,3,8.132917,8.335106,Left Arm
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23293.722875,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.000000,0.000000,0.000000,-0.000057,161.447258,3,5.507417,8.370753,Left Arm
23293.724875,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.000000,0.000000,0.000000,-0.000057,161.447258,3,5.507417,8.370753,Left Arm
23293.726875,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.000000,0.000000,0.000000,-0.000057,161.447258,3,5.507417,8.370753,Left Arm
23293.728875,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.0,0.000000,0.000000,0.000000,-0.000057,161.447258,3,5.507417,8.370753,Left Arm


In [6]:
# The nose keypoint serves as the animal's 2-D position.
x = positions_df["nose_x"].to_numpy()
y = positions_df["nose_y"].to_numpy()
positions = np.stack([x, y], axis=1)

# Track skeleton: node coordinates (must share the coordinate frame of
# `positions`) and the edges joining them, as pairs of node indices.
node_positions = [
    (120.0, 100.0),
    (5.0, 100.0),
    (5.0, 55.0),
    (120.0, 55.0),
    (5.0, 8.5),
    (120.0, 8.5),
]
edges = [(3, 2), (0, 1), (1, 2), (5, 4), (4, 2)]

track_graph = rtc.make_track_graph(node_positions, edges)

In [7]:
# Order in which edges are laid end-to-end on the linear axis, and the gap
# inserted between consecutive edges (one entry fewer than there are edges).
edge_order = [(3, 2), (0, 1), (1, 2), (5, 4), (4, 2)]
edge_spacing = [16, 0, 16, 0]

# Project each 2-D position onto the track graph (no HMM smoothing).
linearized_positions = tl.get_linearized_position(
    positions,
    track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
    use_HMM=False,
)

In [8]:
# Per-electrode spike features; presumably one (n_spikes, n_features) array
# per electrode — confirm. pickle can execute code: only load trusted files.
with open(features_filename, mode="rb") as features_file:
    clusterless_spike_features = pkl.load(features_file)

In [9]:
# Per-electrode spike times (pickle can execute code: only load trusted files).
with open(spikes_filename, "rb") as f:
    clusterless_spike_times = pkl.load(f)

# Scatter each electrode's spike features onto the position time grid:
# features[t, :, e] holds the features of a spike electrode e fired in time
# bin t, and NaN where that electrode did not fire.
n_time_bins = len(timestamps)
n_spike_features = len(clusterless_spike_features[0][0])
n_electrodes = len(clusterless_spike_times)
features = np.full((n_time_bins, n_spike_features, n_electrodes), np.nan)

for electrode_ind in range(n_electrodes):
    spike_times = np.asarray(clusterless_spike_times[electrode_ind])
    if spike_times.size == 0:
        continue  # nothing to place; avoids a broadcasting error on empty input
    spike_features = np.asarray(clusterless_spike_features[electrode_ind])

    # np.digitize returns i with spikes_bins[i-1] <= t < spikes_bins[i], i.e.
    # one PAST the index of the bin containing t; subtract 1 so each spike
    # lands on the timestamp whose bin contains it.
    bin_ind = np.digitize(spike_times, spikes_bins) - 1

    # Drop spikes outside the position recording instead of wrapping them to
    # the last row (index -1) or raising IndexError (index n_time_bins).
    in_range = (bin_ind >= 0) & (bin_ind < n_time_bins)

    # NOTE: if an electrode fires more than once within one bin, only one of
    # those spikes' features survives (NumPy leaves the winner unspecified).
    features[bin_ind[in_range], :, electrode_ind] = spike_features[in_range]

In [10]:
# Model hyperparameters: spatial bin width along the linearized track and the
# variance of the random-walk movement model.
place_bin_size = 0.5
movement_var = 0.25

# State-transition model for the latent (linear) position.
transition_type = rtc.RandomWalk(movement_var=movement_var)

# Discretized track environment, using the same edge order/spacing that was
# used to linearize the positions.
environment = rtc.Environment(
    place_bin_size=place_bin_size,
    track_graph=track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
)

# NOTE(review): the algorithm name suggests a GPU (integer-likelihood)
# backend is required — confirm the runtime environment provides one.
decoder = rtc.ClusterlessDecoder(
    environment=environment,
    transition_type=transition_type,
    clusterless_algorithm="multiunit_likelihood_integer_gpu",
)

In [11]:
# Fit the encoding model on the whole session: linearized position paired
# with the per-time-bin spike-feature array.
# NOTE(review): the saved run emitted a warning at
# `x /= x.sum(axis=1, keepdims=True)` inside fitting — possibly rows summing
# to zero; worth investigating.
training_position = linearized_positions.linear_position
print("Learning model parameters")
decoder.fit(training_position, features)

Learning model parameters


  x /= x.sum(axis=1, keepdims=True)


In [12]:
print(f"Saving model to {model_filename}")

# Bundle the fitted decoder with everything the decoding step reloads.
# NOTE(review): the saved output of this cell shows a different path than
# `model_filename`; re-run the notebook so outputs match the code.
results = {
    "decoder": decoder,
    "linearized_positions": linearized_positions,
    "clusterless_spike_times": clusterless_spike_times,
    "features": features,
    "Fs": Fs,
}

with open(model_filename, "wb") as model_file:
    pkl.dump(results, model_file)

Saving model to ../../../../datasets/decoder_data/clusterless_spike_decoder_gpu_int.pkl


# Decode

In [13]:
# Window of the session to decode, in seconds from the first position sample.
decoding_start_secs, decoding_duration_secs = 0, 100

In [14]:
# Reload the bundle written by the learning step
# (pickle can execute code: only load trusted files).
with open(model_filename, "rb") as model_file:
    model_results = pkl.load(model_file)

decoder, Fs, clusterless_spike_times, features, linearized_positions = (
    model_results[key]
    for key in (
        "decoder",
        "Fs",
        "clusterless_spike_times",
        "features",
        "linearized_positions",
    )
)

In [15]:
print("Decoding positions from features")

# Convert the window from seconds to samples. Round rather than truncate:
# Fs = 1/dt is computed from floating-point timestamps and can come out a
# hair below the nominal rate, in which case int() would turn e.g. 100 s
# into 49999 samples instead of 50000.
decoding_start_samples = int(round(decoding_start_secs * Fs))
decoding_duration_samples = int(round(decoding_duration_secs * Fs))
time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)

# Time axis in seconds relative to the first position sample (not the
# absolute timestamps held in the positions index).
time = np.arange(linearized_positions.linear_position.size) / Fs
decoding_results = decoder.predict(features[time_ind], time=time[time_ind])

Decoding positions from features


n_electrodes: 100%|██████████| 28/28 [00:10<00:00,  2.55it/s]
  posterior[k] = state_transition.T @ posterior[k - 1] * likelihood[k]
  acausal_prior = state_transition.T @ causal_posterior[time_ind]


In [37]:
print(f"Saving decoding results to {decoding_filename}")

# Store the decoder output together with the matching slice of inputs so the
# window can be analysed without reloading the full session.
results = {
    "decoding_results": decoding_results,
    "time": time[time_ind],
    "linearized_positions": linearized_positions.iloc[time_ind],
    "features": features[time_ind],
}

with open(decoding_filename, "wb") as results_file:
    pkl.dump(results, results_file)

Saving decoding results to ../../../../datasets/decoder_data/clusterless_decoding_results.pkl
