### Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from neuropy import plotting
from tqdm.notebook import tqdm
from neuropy.core import Epoch
from neuropy.utils.mathutil import min_max_scaler
import pandas as pd
from neuropy.utils.position_util import run_direction 
from neuropy.plotting import plot_epochs
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
import subjects

### Scratch: testing running-epoch detection on a single session

In [None]:
sess = subjects.nsd.ratKday2[0]

# Detection parameters, named so they are easy to tune in one place.
smooth_sigma = 0.1  # Gaussian smoothing width for position, in seconds (assumes sampling_rate is in Hz)
speed_floor = 10  # speed samples below this value are zeroed before peak finding
min_peak_speed = 30  # minimum peak height for a candidate running epoch

t = sess.maze.time
dt = 1 / sess.maze.sampling_rate
# gaussian_filter1d expects sigma in samples, so convert from seconds.
x = gaussian_filter1d(sess.maze.x, sigma=smooth_sigma / dt)
# Absolute first-difference speed; the leading 0 keeps it the same length as `t`.
speed = np.abs(np.concatenate([[0], np.diff(x) / dt]))

# Thresholded speed (not a threshold value): samples below the floor are set
# to 0 so the signal drops to zero between running bouts.
speed_thresh = np.where(speed >= speed_floor, speed, 0)
peaks, props = find_peaks(speed_thresh, height=min_peak_speed, prominence=0)

# Candidate epoch boundaries (sample indices) taken from the prominence bases
# of each peak, plus the speed at each peak.
starts, stops = props["left_bases"], props["right_bases"]
peaks_power = speed_thresh[peaks]

# ----- merge overlapping / closely spaced epochs ------
merge_gap = 1  # seconds; epochs separated by less than this are merged


def _merge_close_epochs(starts, stops, peaks, peaks_power, max_gap):
    """Merge consecutive epochs whose gap is shorter than ``max_gap`` samples.

    Epochs are scanned left to right. When epoch ``i + 1`` starts less than
    ``max_gap`` samples after epoch ``i`` stops, epoch ``i + 1`` is stretched
    to cover both and epoch ``i`` is dropped. The stretched epoch is then
    compared with epoch ``i + 2``, so a chain of close epochs collapses into
    a single one. The merged epoch keeps the larger peak power and the sample
    index of that peak; on a tie the earlier epoch's peak wins.

    The inputs are copied, not modified.

    Returns:
        (starts, stops, peaks, peaks_power): the merged epochs. The first three
        are int index arrays and ``peaks_power`` is a float array.
    """
    starts = np.array(starts, dtype=int)
    stops = np.array(stops, dtype=int)
    peaks = np.array(peaks, dtype=int)
    peaks_power = np.array(peaks_power, dtype=float)

    keep = np.ones(len(starts), dtype=bool)
    for i in range(len(starts) - 1):
        if starts[i + 1] - stops[i] < max_gap:
            # Stretch the later epoch to cover the range of both epochs.
            starts[i + 1] = min(starts[i], starts[i + 1])
            stops[i + 1] = max(stops[i], stops[i + 1])
            # Decide the winning peak BEFORE overwriting its power.
            # Ties go to the earlier epoch.
            if peaks_power[i] >= peaks_power[i + 1]:
                peaks[i + 1] = peaks[i]
                peaks_power[i + 1] = peaks_power[i]
            keep[i] = False

    return starts[keep], stops[keep], peaks[keep], peaks_power[keep]


# The small epsilon lets a gap of exactly `merge_gap` seconds' worth of
# samples still merge.
starts, stops, peaks, peaks_power = _merge_close_epochs(
    starts, stops, peaks, peaks_power, max_gap=merge_gap / dt + 1e-6
)


# Reference detection from neuropy for comparison with the epochs above.
# NOTE(review): sigma=0.1 presumably matches the 0.1 s position smoothing used
# above; confirm the units of run_direction's `sigma`.
run = run_direction(sess.maze, min_distance=10, sigma=0.1)
peak_time = run.to_dataframe().peak_time.values

_, axs = plt.subplots(2, 1, sharex=True)

# Top panel: smoothed position with run_direction epochs shaded by direction.
axs[0].plot(t, x)
plot_epochs(
    ax=axs[0],
    epochs=run,
    colors={"up": "r", "down": "k"},
    alpha=0.2,
    collapsed=True,
)
axs[0].set_ylabel("position")

# Bottom panel: speed with the merged epoch boundaries from the cell above,
# plus run_direction's peak times.
axs[1].plot(t, speed, label="speed")
axs[1].plot(t[starts], speed[starts], "r*", label="epoch start")
axs[1].plot(t[stops], speed[stops], "k*", label="epoch stop")
# peak_time carries no speed value here, so draw the markers at a fixed height of 60.
axs[1].plot(
    peak_time, 60 * np.ones_like(peak_time), "g*", label="run_direction peak"
)
axs[1].set_xlabel("time")
axs[1].set_ylabel("speed")
axs[1].legend(loc="upper right")


### Pooled running-epoch detection: detect and save running epochs for every session in `sessions`

In [None]:
# Sessions to process below. To run a different cohort, replace the
# right-hand side with one of these alternatives:
#   subjects.pf_sess()
#   subjects.nsd.ratVday3 + subjects.sd.ratUday1
#   subjects.nsd.ratVday1 + subjects.nsd.ratVday3
sessions = subjects.sd.ratVday2

In [None]:
# For each session: detect running epochs on the maze position with
# run_direction, then save them to `sess.filePrefix` with its suffix replaced
# by ".maze.running".
for sess in tqdm(sessions):
    run_epochs = run_direction(sess.maze, min_distance=10, sigma=0.1)
    run_epochs.save(sess.filePrefix.with_suffix(".maze.running"))