# Learn: fit a sorted-spikes decoder to linearized position

In [None]:
import numpy as np
import sys
import pandas as pd
import matplotlib.pyplot as plt
import pickle as pkl

import replay_trajectory_classification as rtc
import track_linearization as tl

In [None]:
# All decoder inputs and outputs live in a single dataset directory.
DATA_DIR = "../../../../datasets/decoder_data"
positions_filename = f"{DATA_DIR}/position_info.pkl"          # tracked positions (input)
spikes_filename = f"{DATA_DIR}/sorted_spike_times.pkl"        # sorted spike times (input)
model_filename = f"{DATA_DIR}/sorted_spike_decoder.pkl"       # fitted decoder (output)
decoding_filename = f"{DATA_DIR}/sorted_spike_decoding_results.pkl"  # decoded posterior (output)

In [None]:
# Load tracked positions; the DataFrame index holds the sample timestamps.
positions_df = pd.read_pickle(positions_filename)
timestamps = positions_df.index.to_numpy()
if timestamps.size < 2:
    raise ValueError("need at least two position samples to infer the sampling interval")
time_start = timestamps[0]
time_end = timestamps[-1]
# Median inter-sample interval: robust to an irregular first interval or dropped frames.
dt = np.median(np.diff(timestamps))
Fs = 1.0 / dt
# Time-bin edges from (time_start - dt) through (time_end + dt), inclusive.
# The edges are built from an explicit count instead of np.arange with a float
# step. The length of that arange depends on rounding: it can stop at time_end,
# which leaves the last position sample with no bin (IndexError when
# positions are binned).
n_edges = int(np.round((time_end - time_start) / dt)) + 3
spikes_bins = time_start - dt + dt * np.arange(n_edges)

In [None]:
# Inspect the loaded position table: its index supplies the timestamps, and its
# nose_x / nose_y columns are the tracked coordinates used below.
positions_df

In [None]:
# Tracked nose coordinates stacked into an (n_samples, 2) array of (x, y).
x = positions_df["nose_x"].to_numpy()
y = positions_df["nose_y"].to_numpy()
positions = np.stack((x, y), axis=1)

# Track geometry: list index = node id, value = (x, y) coordinate of that node.
node_positions = [
    (120.0, 100.0),  # node 0
    (5.0, 100.0),    # node 1
    (5.0, 55.0),     # node 2
    (120.0, 55.0),   # node 3
    (5.0, 8.5),      # node 4
    (120.0, 8.5),    # node 5
]
# Track segments, each a pair of node ids.
edges = [(3, 2), (0, 1), (1, 2), (5, 4), (4, 2)]
track_graph = rtc.make_track_graph(node_positions, edges)

In [None]:
# Order in which the track segments are laid end to end on the 1-D (linearized) axis.
edge_order = [
    (3, 2),
    (0, 1),
    (1, 2),
    (5, 4),
    (4, 2),
]

# Gap inserted at each junction between consecutive edges of `edge_order` on the
# linear axis (presumably in the same units as node_positions — TODO confirm).
edge_spacing = [16, 0, 16, 0]
# One spacing value per junction: fail fast instead of deep inside the library.
assert len(edge_spacing) == len(edge_order) - 1, (
    f"edge_spacing needs {len(edge_order) - 1} entries, got {len(edge_spacing)}"
)

# Project each 2-D position onto the track graph (nearest edge, no HMM smoothing).
linearized_positions = tl.get_linearized_position(
    positions,
    track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
    use_HMM=False,
)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code — only load spike files
# produced by a trusted sorting pipeline.
with open(spikes_filename, "rb") as f:
    sorted_spike_times = pkl.load(f)

# Spike counts per time bin: rows = bins defined by `spikes_bins`, columns = units.
# Presumably each entry of `sorted_spike_times` is a 1-D array of one unit's spike
# times, on the same clock as the position timestamps — TODO confirm.
n_time_bins = len(spikes_bins) - 1
binned_spike_times = np.empty((n_time_bins, len(sorted_spike_times)), dtype=float)
for n in range(len(sorted_spike_times)):
    binned_spike_times[:, n] = np.histogram(sorted_spike_times[n], spikes_bins)[0]

# Put each position sample's linearized position into the time bin containing its
# timestamp. Bins with no position sample stay NaN.
linear_position = np.full(n_time_bins, np.nan)
in_position_window = np.digitize(positions_df.index, spikes_bins) - 1
# np.digitize returns 0 for samples before the first edge and len(spikes_bins) for
# samples at or after the last edge. After the -1 shift these become -1, which
# would silently wrap to the last bin, or n_time_bins, which raises IndexError.
# Report either case explicitly.
outside = (in_position_window < 0) | (in_position_window >= n_time_bins)
if outside.any():
    raise ValueError(
        f"{int(outside.sum())} position sample(s) fall outside the spike-bin range "
        f"[{spikes_bins[0]}, {spikes_bins[-1]})"
    )
# If several samples share a bin, only one of their values is kept (in practice the last).
linear_position[in_position_window] = linearized_positions.linear_position

In [None]:
# Decoder hyper-parameters: spatial bin width on the linearized track, and the
# variance of the random-walk movement model.
place_bin_size, movement_var = 0.5, 0.25

# Discretized 1-D environment, built from the same graph, edge order and spacing
# that were used for linearization.
environment = rtc.Environment(
    track_graph=track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
    place_bin_size=place_bin_size,
)

transition_type = rtc.RandomWalk(movement_var=movement_var)

decoder = rtc.SortedSpikesDecoder(transition_type=transition_type, environment=environment)

In [None]:
print("Learning model parameters")
# Fit the decoder on the per-bin linearized position (NaN in bins with no position
# sample) and the per-bin spike counts of each sorted unit.
# NOTE(review): presumably NaN-position bins are excluded from training by rtc — confirm.
decoder.fit(linear_position, binned_spike_times)

In [None]:
print(f"Saving model to {model_filename}")

# Bundle the fitted decoder with its training inputs and the sampling rate.
# Note: the "spike_times" key holds the *binned spike counts*, not raw spike times.
results = {
    "decoder": decoder,
    "linear_position": linear_position,
    "spike_times": binned_spike_times,
    "Fs": Fs,
}

with open(model_filename, "wb") as model_file:
    pkl.dump(results, model_file)

# Decode: estimate position from spikes with the fitted model

In [None]:
# Decoding window, in seconds from the first time bin: start offset and length.
decoding_start_secs, decoding_duration_secs = 0, 100

In [None]:
# Reload the fitted decoder, its training inputs and the sampling rate.
# NOTE(review): pickle.load executes arbitrary code — only open trusted model files.
with open(model_filename, "rb") as model_file:
    model_results = pkl.load(model_file)

decoder, Fs, spike_times, linear_position = (
    model_results[key] for key in ("decoder", "Fs", "spike_times", "linear_position")
)

In [None]:
print("Decoding positions from spikes")
# Convert seconds to samples by rounding, not truncating. Fs = 1/dt is rarely an
# exact integer, so e.g. 100 s * 29.9999999 Hz must give 3000 samples, not 2999.
decoding_start_samples = int(round(decoding_start_secs * Fs))
decoding_duration_samples = int(round(decoding_duration_secs * Fs))
if decoding_duration_samples <= 0 or not 0 <= decoding_start_samples < linear_position.size:
    raise ValueError(
        f"decoding window (start={decoding_start_samples}, "
        f"length={decoding_duration_samples} samples) is empty or outside the "
        f"{linear_position.size} available time bins"
    )
# Slices past the end are clipped, so a window longer than the data is truncated.
time_ind = slice(decoding_start_samples, decoding_start_samples + decoding_duration_samples)
# Time axis in seconds, relative to the first time bin.
time = np.arange(linear_position.size) / Fs
decoding_results = decoder.predict(spike_times[time_ind], time=time[time_ind])

In [None]:
print(f"Saving decoding results to {decoding_filename}")

# Store the decoded window together with the inputs restricted to that window.
results = {
    "decoding_results": decoding_results,
    "time": time[time_ind],
    "linear_position": linear_position[time_ind],
    "spike_times": spike_times[time_ind],
}

with open(decoding_filename, "wb") as out_file:
    pkl.dump(results, out_file)

In [None]:
import plotly.graph_objects as go

In [None]:
# Reload the saved decoding window for plotting.
# NOTE(review): pickle.load executes arbitrary code — only open trusted result files.
with open(decoding_filename, "rb") as results_file:
    load_res = pkl.load(results_file)

decoding_results, time, linear_position = (
    load_res[key] for key in ("decoding_results", "time", "linear_position")
)

In [None]:
# Upper colour limit for the posterior heatmap: larger probabilities saturate so
# that low-probability structure stays visible.
POSTERIOR_ZMAX = 0.05

posterior = decoding_results.acausal_posterior

fig = go.Figure()
# Assumes acausal_posterior has dims (time, position) — TODO confirm.
# `.T` puts position on the rows so that it lines up with the y coordinates.
fig.add_trace(go.Heatmap(z=posterior.T,
                         x=posterior.time,
                         y=posterior.position,
                         zmin=0.00, zmax=POSTERIOR_ZMAX, showscale=False))
# Measured linearized position overlaid on the decoded posterior.
fig.add_trace(go.Scatter(x=time, y=linear_position,
                         mode="markers", marker={"color": "cyan", "size": 5},
                         name="position", showlegend=True))

fig.update_xaxes(title="Time (sec)")
fig.update_yaxes(title="Position (cm)")
fig.update_layout(title="Acausal posterior vs. measured linearized position")
fig