# Train the decoder and save the results

In [None]:
import numpy as np
import pandas as pd
import pickle as pkl

import replay_trajectory_classification as rtc
import track_linearization as tl

In [None]:
# All decoder inputs and outputs live in one shared data directory.
_decoder_data_dir = "../../../../datasets/decoder_data"

# Inputs: tracked positions and the clusterless (mark) spike data.
positions_filename = f"{_decoder_data_dir}/position_info.pkl"
spikes_filename = f"{_decoder_data_dir}/clusterless_spike_times.pkl"
features_filename = f"{_decoder_data_dir}/clusterless_spike_features.pkl"

# Outputs: the fitted decoder and its decoding results.
model_filename = f"{_decoder_data_dir}/clusterless_spike_decoder.pkl"
decoding_filename = f"{_decoder_data_dir}/clusterless_spike_decoding_results.pkl"

In [None]:
# Load the tracked positions; the DataFrame index holds the sample timestamps
# (assumed to be seconds — confirm against the position file).
positions_df = pd.read_pickle(positions_filename)
timestamps = positions_df.index.to_numpy()
time_start, time_end = timestamps[0], timestamps[-1]

# Decoding bin width and the corresponding bin rate.
dt = 0.02
Fs = 1.0 / dt

# Bin edges spaced dt apart, starting one bin before the first position sample.
# np.arange excludes its stop value, so the last edge is < time_end + dt.
first_edge = time_start - dt
stop_edge = time_end + dt
spikes_bins = np.arange(first_edge, stop_edge, dt)

In [None]:
# Display the loaded position table as notebook output for a quick visual check.
positions_df

In [None]:
# Tracked (x, y) of the nose, one row per position sample.
x, y = (positions_df[column].to_numpy() for column in ("nose_x", "nose_y"))
positions = np.column_stack((x, y))

# Track skeleton in the tracking coordinate frame, nodes indexed 0..5.
# The nodes at x = 5 form a vertical spine; three horizontal arms branch off
# it at y = 100, 55 and 8.5.
node_positions = [
    (120.0, 100.0),  # 0: right end of top arm
    (5.0, 100.0),    # 1: top of spine
    (5.0, 55.0),     # 2: spine junction with the middle arm
    (120.0, 55.0),   # 3: right end of middle arm
    (5.0, 8.5),      # 4: bottom of spine
    (120.0, 8.5),    # 5: right end of bottom arm
]

# Track segments as pairs of node indices.
edges = [
    (3, 2),  # middle arm
    (0, 1),  # top arm
    (1, 2),  # upper spine
    (5, 4),  # bottom arm
    (4, 2),  # lower spine
]

track_graph = rtc.make_track_graph(node_positions, edges)

In [None]:
# Order in which the track segments are laid end to end on the 1-D linear axis.
edge_order = [
    (3, 2),
    (0, 1),
    (1, 2),
    (5, 4),
    (4, 2),
]

# Gap between consecutive segments of edge_order (4 gaps for 5 edges): 0 where
# the two segments share a node ((0, 1)->(1, 2) and (5, 4)->(4, 2)), 16 where
# they do not, presumably to keep disconnected segments apart on the linear axis.
edge_spacing = [16, 0, 16, 0]

# Project each 2-D sample onto the track graph (use_HMM=False: no HMM-based
# segment classification).
linearized_positions = tl.get_linearized_position(
    positions,
    track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
    use_HMM=False,
)

In [None]:
# Load per-electrode clusterless marks. Assumed layout (confirm against the
# pickles): clusterless_spike_times[n] is a 1-D array of spike times for
# electrode n, and clusterless_spike_features[n] holds one feature row per spike.
# NOTE(review): pickle can execute arbitrary code — only load trusted files.
with open(features_filename, "rb") as f:
    clusterless_spike_features = pkl.load(f)

with open(spikes_filename, "rb") as f:
    clusterless_spike_times = pkl.load(f)


def _bin_indices(times, bin_edges):
    """Return (bin index of each in-range time, boolean in-range mask).

    np.digitize(...) - 1 yields -1 for times before the first edge and
    len(bin_edges) - 1 for times at or after the last edge. Neither is a valid
    bin, so those times are masked out instead of silently wrapping into the
    last bin (index -1) or raising IndexError.
    """
    indices = np.digitize(times, bin_edges) - 1
    in_range = (indices >= 0) & (indices < len(bin_edges) - 1)
    return indices[in_range], in_range


n_bins = len(spikes_bins) - 1
# Works even if electrode 0 has no spikes, as long as its features are 2-D.
n_features = np.asarray(clusterless_spike_features[0]).shape[1]
n_electrodes = len(clusterless_spike_times)

# Marks array of shape (n_bins, n_features, n_electrodes); NaN means "no spike
# from this electrode in this bin". If an electrode fires more than once in a
# bin, only one of those spikes is kept (numpy does not guarantee which).
features = np.full((n_bins, n_features, n_electrodes), np.nan)
for n, spike_times in enumerate(clusterless_spike_times):
    bin_idx, in_range = _bin_indices(spike_times, spikes_bins)
    features[bin_idx, :, n] = np.asarray(clusterless_spike_features[n])[in_range]

# Linearized position on the same bins; bins without a position sample stay NaN.
linear_position = np.full(n_bins, np.nan)
bin_idx, in_range = _bin_indices(positions_df.index.to_numpy(), spikes_bins)
linear_position[bin_idx] = np.asarray(linearized_positions.linear_position)[in_range]

In [None]:
# Spatial bin size for the linearized track and variance of the random-walk
# state-transition model.
place_bin_size = 0.5
movement_var = 0.25

environment = rtc.Environment(
    place_bin_size=place_bin_size,
    track_graph=track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
)

transition_type = rtc.RandomWalk(movement_var=movement_var)

# NOTE(review): the "_gpu" clusterless algorithm presumably requires a CUDA
# GPU when fitting/predicting — confirm before running on CPU-only machines.
decoder = rtc.ClusterlessDecoder(
    environment=environment,
    transition_type=transition_type,
    clusterless_algorithm="multiunit_likelihood_integer_gpu",
)

In [None]:
print("Learning model parameters")
# Fit on all time bins. Bins with no position sample hold NaN in linear_position.
# NOTE(review): confirm rtc excludes NaN-position bins from training.
decoder.fit(linear_position, features)

In [None]:
print(f"Saving model to {model_filename}")

# Persist the fitted decoder inside a dict under the key "decoder".
results = {"decoder": decoder}

with open(model_filename, "wb") as model_file:
    pkl.dump(results, model_file)

# Decode

In [None]:
# Decoding window, in seconds measured from the first time bin.
decoding_start_secs, decoding_duration_secs = 0, 100

In [None]:
print("Decoding positions from features")
# Convert the window from seconds to bin counts. Round instead of truncating:
# e.g. 0.58 * 50.0 evaluates to 28.999999999999996, which int() turns into 28.
decoding_start_samples = int(round(decoding_start_secs * Fs))
decoding_duration_samples = int(round(decoding_duration_secs * Fs))
time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)
# Bin times relative to the first bin (bin index / Fs), not absolute timestamps.
time = np.arange(linear_position.size) / Fs
decoding_results = decoder.predict(features[time_ind], time=time[time_ind])

In [None]:
print(f"Saving decoded results to {decoding_filename}")

# Bundle the decoder output with the matching slices of the binned linear
# position and spike features for the decoded window.
results = {
    "decoding_results": decoding_results,
    "linear_position": linear_position[time_ind],
    "spikes": features[time_ind],
}

with open(decoding_filename, "wb") as results_file:
    pkl.dump(results, results_file)

## Optional

Plot the decoded (acausal) posterior with the true linearized position overlaid.

In [None]:
import plotly.graph_objects as go

In [None]:
fig = go.Figure()

# Acausal posterior as a heatmap, transposed so time runs along x and position
# along y (assumes the posterior's dims are (time, position) — confirm).
posterior = decoding_results.acausal_posterior
trace = go.Heatmap(z=posterior.T,
                   x=posterior.time,
                   y=posterior.position,
                   zmin=0.00, zmax=0.05, showscale=False)
fig.add_trace(trace)

# Overlay the true linearized position for the same decoded window. y must be
# sliced with time_ind just like x; the full-length array would mismatch x in
# length and misalign the overlay whenever the window does not start at bin 0.
trace = go.Scatter(x=time[time_ind], y=linear_position[time_ind],
                   mode="markers", marker={"color": "cyan", "size": 5},
                   name="position", showlegend=True)
fig.add_trace(trace)

fig.update_xaxes(title="Time (sec)")
fig.update_yaxes(title="Position (cm)")
fig.update_coloraxes(showscale=False)