### Imports

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import subjects
from neuropy import plotting
from neuropy.utils import signal_process
from neuropy.analyses import Pf1D
from neuropy.analyses import Decode1d

### Decoding re-maze PBEs with the maze place-field template

In [None]:
# Session groups pooled for the re-maze decoding analysis.
# NOTE(review): `nsd` / `sd` presumably denote the non-sleep-deprived and
# sleep-deprived cohorts -- confirm against the `subjects` module.
_session_groups = [
    subjects.nsd.ratSday2,
    subjects.nsd.ratVday1,
    subjects.nsd.ratUday2,
    subjects.sd.ratSday3,
    subjects.sd.ratVday2,
    subjects.sd.ratUday4,
]
# Fold the groups together left to right, exactly like chaining `+`.
sessions = sum(_session_groups[1:], _session_groups[0])


In [None]:
from tqdm import tqdm


def _min_max_scale(x):
    """Return *x* linearly rescaled so its minimum maps to 0 and its maximum to 1.

    NaNs are ignored when determining the range. This replaces the call to
    ``min_max_scaler``, which was never imported or defined in this notebook.
    """
    x = np.asarray(x, dtype=float)
    lo, hi = np.nanmin(x), np.nanmax(x)
    return (x - lo) / (hi - lo)


def _decode_re_maze(sess, n_pos_bins=50):
    """Decode one session's re-maze PBEs using place fields built on the maze.

    Parameters
    ----------
    sess
        A session object from ``subjects`` (provides ``paradigm``, ``neurons``,
        ``maze``, ``pbe``, ``animal`` and ``tag``).
    n_pos_bins : int
        Number of evenly spaced points on the normalized [0, 1] track axis onto
        which the mean posterior is resampled.

    Returns
    -------
    df : pd.DataFrame
        One row per decoded event with columns
        ``score, velocity, jump_distance, epoch, name, grp``.
    df_dist : pd.DataFrame
        Mean posterior on the normalized position axis with columns
        ``bins, mean_pos, name, grp``.
    """
    re_maze = sess.paradigm["re-maze"].flatten()
    neurons = sess.neurons.get_neuron_type(neuron_type="pyr")

    # Template: pyramidal-cell 1-D place fields computed from maze position.
    pf = Pf1D(
        neurons=neurons,
        position=sess.maze,
        speed_thresh=5,
        sigma=4,
        grid_bin=2,
        frate_thresh=1,
    )
    # Decode only with neurons that survived the place-field selection.
    pf_neurons = neurons.get_by_id(pf.neuron_ids)

    # PBEs restricted to the re-maze window (start, stop).
    epochs = sess.pbe.time_slice(re_maze[0], re_maze[1])
    decode = Decode1d(
        neurons=pf_neurons,
        ratemap=pf,
        epochs=epochs,
        bin_size=0.02,
        decode_margin=15,
        nlines=5000,
    )

    # Mean absolute step between consecutive decoded positions, per event.
    jump_distance = [
        np.abs(np.diff(event_pos)).mean() for event_pos in decode.decoded_position
    ]

    # Average the posterior over all time bins of all events, then resample it
    # onto a common normalized position axis so sessions can be pooled.
    # NOTE(review): assumes each posterior is (n_position_bins, n_time_bins),
    # so hstack concatenates time bins -- confirm against Decode1d.
    norm_pos = _min_max_scale(decode.ratemap.xbin_centers)
    mean_posterior = np.nanmean(np.hstack(decode.posterior), axis=1)
    pos_bins = np.linspace(0, 1, n_pos_bins)
    mean_pos = np.interp(pos_bins, norm_pos, mean_posterior)

    labels = dict(name=sess.animal.name, grp=sess.tag)
    df_dist = pd.DataFrame(dict(bins=pos_bins, mean_pos=mean_pos, **labels))
    df = pd.DataFrame(
        dict(
            score=decode.score,
            velocity=decode.velocity,
            jump_distance=jump_distance,
            epoch="re-maze",
            **labels,
        )
    )
    return df, df_dist


score_frames, dist_frames = [], []
for sess in tqdm(sessions):
    df, df_dist = _decode_re_maze(sess)
    score_frames.append(df)
    dist_frames.append(df_dist)

score_df = pd.concat(score_frames, ignore_index=True)
dist_replay_df = pd.concat(dist_frames, ignore_index=True)

group_data = subjects.GroupData()
group_data.save(score_df, "replay_re_maze_score")
group_data.save(dist_replay_df, "replay_re_maze_position_distribution")


In [None]:
%matplotlib widget
import seaborn as sns

sns.violinplot(
    data=score_df,
    x="epoch",
    y="score",
    hue="grp",
    split=True,
    inner=None,
    scale="width",
)
