### Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from neuropy import plotting
from tqdm.notebook import tqdm
from neuropy.core import Epoch
from neuropy.utils.mathutil import min_max_scaler
from neuropy import plotting
from scipy import stats
import pandas as pd
import subjects

### Pooled dataframe across all sessions

In [None]:
# Build one long-format table of replay events pooled over every session returned
# by subjects.pf_sess(), with per-session normalised wcorr / radon scores added.
sessions = subjects.pf_sess()

session_frames = []
for sess in sessions:
    # NOTE(review): `neurons` is computed but not used anywhere below in this cell.
    neurons = sess.neurons_stable.get_neuron_type(["pyr", "mua"])
    neurons = neurons[neurons.firing_rate <= 10]

    # Drop the posterior column and expose the epoch label under the name "zt".
    df = (
        sess.replay_filtered.to_dataframe()
        .drop(columns="posterior")
        .rename(columns={"label": "zt"})
    )
    df["name"] = sess.name
    df["grp"] = sess.tag

    # Reference distributions used for normalisation.
    by_zt = df.groupby("zt")
    pre_events = by_zt.get_group("PRE")
    maze_events = by_zt.get_group("MAZE")
    wcorr_pre = pre_events.wcorr
    wcorr_maze = maze_events.wcorr
    # NOTE(review): despite the "_pre" suffix this baseline is taken from the MAZE
    # epoch, so radon_zsc_pre / radon_rel_pre are MAZE-referenced — confirm intent.
    radon_pre = maze_events.radon

    # Inter-quartile ranges: denominators of the robust z-scores.
    wcorr_pre_iqr = wcorr_pre.quantile(0.75) - wcorr_pre.quantile(0.25)
    radon_pre_iqr = radon_pre.quantile(0.75) - radon_pre.quantile(0.25)

    # Column order is kept as-is because a later cell selects columns by position.
    df["wcorr_rel_pre"] = df.wcorr / wcorr_pre.mean()
    df["wcorr_zsc_pre"] = (df.wcorr - wcorr_pre.median()) / wcorr_pre_iqr
    df["wcorr_rel_maze"] = df.wcorr / wcorr_maze.mean()

    df["radon_zsc_pre"] = (df.radon - radon_pre.median()) / radon_pre_iqr
    df["radon_rel_pre"] = df.radon / radon_pre.mean()
    df["abs_wcorr"] = df["wcorr"].abs()

    session_frames.append(df)


replay_df = pd.concat(session_frames, ignore_index=True)

# subjects.GroupData().save(replay_df, "replay_mua")

### Significant replay events using wcorr and jump distance

In [None]:
# Proportion of events per (group, session, epoch) that pass the significance
# criteria: extreme weighted-correlation percentile AND low jump-distance percentile.
_, ax = plt.subplots()

grpby_kw = dict(by=["grp", "name", "zt"], sort=False)

extreme_wcorr = (replay_df.wcorr_perc >= 90) | (replay_df.wcorr_perc <= 10)
# alternative criterion tried previously: replay_df.wcorr >= 0.5
low_jump = replay_df.jd_perc <= 10
bool_indx = extreme_wcorr & low_jump

# Per-group counts of significant events and of all events (column-wise counts).
df_sig = replay_df[bool_indx].groupby(**grpby_kw).count()
n_all = replay_df.groupby(**grpby_kw).count()

# The first seven columns are kept positionally and the wcorr ratio is renamed to
# "prop" (assumes `wcorr` is among those columns — verify if the schema changes).
prop = (df_sig / n_all).reset_index().iloc[:, :7].rename(columns={"wcorr": "prop"})

# prop = prop[prop.zt != "MAZE"]
# sns.stripplot(data=df_sig.reset_index(), x="zt", y="wcorr", hue="grp", dodge=True)
sns.stripplot(data=prop, x="zt", y="prop", hue="grp", dodge=True, ax=ax)

### Plot measures in individual sessions

In [None]:
# One violin plot per session (rows) and group (columns) for a single measure,
# with a horizontal line at that session's PRE mean.
meas = "radon"
groups = ["NSD", "SD"]

# Sessions are identified by the `name` column that the pooled-dataframe cell adds
# to replay_df. A separate local name is used so the `sessions` list of session
# objects is not overwritten by an array of strings.
names_by_grp = {
    grp: replay_df.loc[replay_df.grp == grp, "name"].unique() for grp in groups
}

# Size the grid from the data instead of assuming at most 7 sessions per group.
n_rows = max(1, max(len(names) for names in names_by_grp.values()))
_, axs = plt.subplots(n_rows, len(groups), squeeze=False)

for g, grp in enumerate(groups):
    grp_df = replay_df[replay_df.grp == grp]

    for s, sess_name in enumerate(names_by_grp[grp]):
        sess_df = grp_df[grp_df["name"] == sess_name]
        mean_pre = sess_df.loc[sess_df.zt == "PRE", meas].mean()

        ax = axs[s, g]
        sns.violinplot(data=sess_df, x="zt", y=meas, color=subjects.colors_sd()[g], ax=ax)
        ax.axhline(mean_pre)

### Plot measures pooled across all sessions

In [None]:
from plotters import violinplot

# Pooled distribution of PRE-relative radon score per epoch, with Kruskal
# annotations drawn by the project's violinplot helper.
_, ax = plt.subplots()

# bool_indx = (
#     (replay_df.wcorr_perc_shuffle >= 95) | (replay_df.wcorr_perc_shuffle <=5)
# ) & (replay_df.jd_perc_shuffle <= 10)
# df = replay_df[bool_indx]

# Pass the axes explicitly so the plot is guaranteed to land on the figure created
# above (the helper accepts `ax`, as used by the radon statistics figure cell).
violinplot(
    data=replay_df,
    x="zt",
    y="radon_rel_pre",
    stat_anot=True,
    stat_test="Kruskal",
    ax=ax,
)

In [None]:
from statannotations.Annotator import Annotator
from subjects import stat_kw
from plotters import violinplot

# Violin plot of the z-scored radon measure per epoch, split by group, with
# within-group Kruskal comparisons of the two early epochs against the last one.
fig = plotting.Fig(grid=(6, 4), fontsize=7)
ax = fig.subplot(fig.gs[0])

# sns.violinplot(data=wcorr_df,x='zt',y='score',hue='grp',split=True)
# df = radon_df[~radon_df.name.isin(['RatUDay2'])]
# df = radon_df[radon_df.n_neurons>=80]
# df = radon_df[radon_df.speed>=00]

plot_kw = dict(
    data=replay_df,
    x="zt",
    y="radon_zsc_pre",
    hue="grp",
    hue_order=["NSD", "SD"],
    ax=ax,
)
violinplot(**plot_kw)

# Epoch order handed to the annotator: order of first appearance in the data.
orders = replay_df.zt.unique()

reference_epoch = "5-7.5"
compared_epochs = ("0-2.5", "2.5-5")

# Within groups
for grp_idx, grp_name in enumerate(plot_kw["hue_order"]):
    grp_pairs = [
        ((epoch, grp_name), (reference_epoch, grp_name)) for epoch in compared_epochs
    ]
    annotator = Annotator(pairs=grp_pairs, order=orders, **plot_kw)
    annotator.configure(
        test="Kruskal", **stat_kw, color=subjects.colors_sd(1)[grp_idx]
    )
    annotator.apply_and_annotate()
    # annotator.apply_test().annotate(line_offset_to_group=k)
    annotator.reset_configuration()


fig.savefig(subjects.figpath_sd / "radon_dist_stat")

### Cumulative distribution plot

In [None]:
from plotters import violinplot

# --- violinplot -----
# a = replay_df[replay_df.pval < 0.05]
# sig_counts = a.groupby(["grp", "session", "zt"]).count()
# all_counts = replay_df.groupby(["grp", "session", "zt"]).count()

# sig_df = (sig_counts / all_counts).reset_index()

# _, ax = plt.subplots()

# df1 = replay_df[replay_df.perc > 95 ].groupby(["zt", "grp", "session"]).count()
# df2 = replay_df.groupby(["zt", "grp", "session"]).count()
# df = (df1 / df2).reset_index()
# violinplot(data=replay_df, x="zt", y="perc", stat_anot=True, stat_test="Kruskal")


# --- CDF plotting each epoch separately -----
def _sig_label(pvalue, alpha=0.05):
    """Return "n.s" when ``pvalue`` exceeds ``alpha`` and "*" otherwise."""
    return "n.s" if pvalue > alpha else "*"


def _perc_values(zt, grp):
    """Return the ``perc`` values of replay_df for one epoch and one group."""
    selected = (replay_df.grp == grp) & (replay_df.zt == zt)
    return replay_df[selected].perc.values


zts = replay_df.zt.unique()
# Select the non-PRE epochs by label rather than by position, and size the figure
# from the data, so the cell does not depend on PRE being listed first or on there
# being exactly five epochs.
post_zts = [zt for zt in zts if zt != "PRE"]
_, axs = plt.subplots(1, len(post_zts) + 1, sharey=True, squeeze=False)
axs = axs.ravel()

plot_kw = dict(
    x="perc",
    hue="grp",
    stat="probability",
    common_bins=True,
    binwidth=5,
    common_norm=False,
    cumulative=True,
    fill=False,
    element="poly",
)

pre_df = replay_df[replay_df.zt == "PRE"]

# First panel: PRE only, SD vs NSD.
ax = axs[0]
sns.histplot(data=pre_df, **plot_kw, ax=ax, palette=subjects.colors_sd(1))
ax.legend("", frameon=False)
# NOTE(review): this comparison is two-sided while every other test in this cell
# uses alternative="greater" — confirm that the difference is intended.
htest = stats.ks_2samp(_perc_values("PRE", "SD"), _perc_values("PRE", "NSD"))
ax.text(25, 0.6, _sig_label(htest.pvalue), color="g")
ax.set_title("PRE")

# Remaining panels: PRE (dashed) overlaid with one later epoch.
for i, zt in enumerate(post_zts, start=1):
    ax = axs[i]

    # Within-group test: PRE vs this epoch, one label per group.
    for i1, (g, ypos) in enumerate(zip(["NSD", "SD"], [0.8, 0.7])):
        htest = stats.ks_2samp(
            _perc_values("PRE", g), _perc_values(zt, g), alternative="greater"
        )
        ax.text(25, ypos, _sig_label(htest.pvalue), color=subjects.colors_sd()[i1])

    sns.histplot(
        data=pre_df, **plot_kw, ax=ax, palette=subjects.colors_sd(1), ls="--"
    )
    sns.histplot(
        data=replay_df[replay_df.zt == zt],
        **plot_kw,
        ax=ax,
        palette=subjects.colors_sd(1),
    )

    # Between-group test within this epoch: SD vs NSD.
    htest = stats.ks_2samp(
        _perc_values(zt, "SD"), _perc_values(zt, "NSD"), alternative="greater"
    )
    ax.text(25, 0.6, _sig_label(htest.pvalue), color="g")

    ax.legend("", frameon=False)
    ax.set_ylabel("")
    ax.set_title(zt)
    # ax.set_yscale('log')

### Jump distance vs Wcorr histogram (Silva2015 style)

In [None]:
# Collect, for every session, the observed and shuffled weighted-correlation (wcorr)
# and jump-distance (jd) values of the retained PBEs. After this cell:
#   grp_all / name_all / zt_all : 1-D arrays with one entry per retained PBE
#   wcorr_all / jd_all          : 2-D arrays, row 0 = observed values,
#                                 rows 1: = shuffle values, one column per PBE
sessions = subjects.pf_sess()
# sessions = subjects.nsd.pf_sess + subjects.sd.pf_sess

grp_all, name_all, zt_all, wcorr_all, jd_all = [], [], [], [], []
for sub, sess in enumerate(sessions):
    # NOTE(review): `neurons` is computed but not used anywhere below in this cell.
    neurons = sess.neurons_stable.get_neuron_type(["pyr", "mua"])
    neurons = neurons[neurons.firing_rate <= 10]

    replay_pbe = sess.replay_pbe_mua_column_max
    metadata = replay_pbe.metadata
    # The shuffle arrays are indexed below as [:, measure, pbe]; the first axis is
    # presumably the shuffle iteration — TODO confirm.
    up_shuffle_measures = metadata["up_shuffle_measures"]
    down_shuffle_measures = metadata["down_shuffle_measures"]
    # shuffle_measures = np.vstack([up_shuffle_measures, down_shuffle_measures])

    # ---- Filtering by good PBEs ---------
    pbe_filter = sess.pbe_filters.to_dataframe()
    # Keep only PBEs for which all three boolean filter columns are True.
    good_bool = pbe_filter.is_rpl & pbe_filter.is_5units & pbe_filter.is_rest
    good_bool = good_bool.values
    replay_pbe = replay_pbe[good_bool]
    # shuffle_measures = shuffle_measures[:, :, good_bool]
    # Apply the same mask along the PBE axis so shuffles stay aligned with events.
    up_shuffle_measures = up_shuffle_measures[:, :, good_bool]
    down_shuffle_measures = down_shuffle_measures[:, :, good_bool]

    replay_pbe_df = replay_pbe.to_dataframe()
    starts = replay_pbe.starts

    # Sign of the "down" wcorr is flipped before comparison; only magnitudes are
    # compared below and the final wcorr values are stored as absolute values.
    replay_pbe_df.loc[:, "down_wcorr"] *= -1
    measure_names = ["wcorr", "jd"]
    # Columns: [up_wcorr, up_jd] and [down_wcorr, down_jd], one row per PBE.
    up_measures = replay_pbe_df.loc[:, ["up_" + _ for _ in measure_names]].to_numpy()
    down_measures = replay_pbe_df.loc[:, ["down_" + _ for _ in measure_names]].to_numpy()

    # For each PBE keep the direction (up/down) with the larger |wcorr| ...
    best_bool = np.abs(up_measures[:, 0]) > np.abs(down_measures[:, 0])
    measures = np.zeros_like(up_measures)
    measures[best_bool] = up_measures[best_bool]
    measures[~best_bool] = down_measures[~best_bool]

    # ... and take the shuffles of that same direction.
    shuffle_measures = np.zeros_like(up_shuffle_measures)
    shuffle_measures[:, :, best_bool] = up_shuffle_measures[:, :, best_bool]
    shuffle_measures[:, :, ~best_bool] = down_shuffle_measures[:, :, ~best_bool]

    # Restrict to PBEs whose start time falls inside one of the zt epochs.
    # assumes starts_labels has one entry per True in starts_bool — TODO confirm
    epochs = sess.get_zt_epochs()
    starts_bool, _, starts_labels = epochs.contains(starts)
    measures = measures[starts_bool]
    shuffle_measures = shuffle_measures[:, :, starts_bool]

    # Observed values as 1 x n_pbes rows so they can be stacked on top of shuffles.
    wcorr = measures[:, 0][np.newaxis, :]
    jd = measures[:, 1][np.newaxis, :]

    n_pbes = measures.shape[0]
    grp_all.append([sess.tag] * n_pbes)
    name_all.append([sess.name] * n_pbes)
    zt_all.append(starts_labels)
    # Row 0 = observed, remaining rows = shuffles; wcorr is stored unsigned.
    wcorr_all.append(np.abs(np.vstack((wcorr, shuffle_measures[:, 0, :]))))
    jd_all.append(np.vstack((jd, shuffle_measures[:, 1, :])))


# Pool sessions: labels are concatenated, measure matrices are joined column-wise.
grp_all = np.concatenate(grp_all)
name_all = np.concatenate(name_all)
zt_all = np.concatenate(zt_all)
wcorr_all = np.hstack(wcorr_all)
jd_all = np.hstack(jd_all)

#### Individual sessions

In [None]:
from matplotlib import colors as mcolors

# Epochs shown as columns of each per-session panel row.
zts = ["PRE", "MAZE", "0-2.5", "2.5-5", "5-7.5"]

# Bin edges shared by the wcorr and jump-distance axes of the p-value maps.
bins = np.arange(0, 1.1, 0.1)
n_bins = bins.size - 1

# Two-colour map for p-values: first colour for [0.001, 0.05), second for [0.05, 1].
bounds = [0.001, 0.05, 1]
cmap = mcolors.ListedColormap(["red", "blue"])
norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=cmap.N)

# Continuous alternative to `norm`; not used by the plotting code that follows.
divnorm = mcolors.TwoSlopeNorm(vmin=0.001, vcenter=0.05, vmax=1)

# _,axs = plt.subplots(7,10,sharex=True,sharey=True)
fig = plotting.Fig(grid=(7, 2))


def get_chist(x, y, bins=None):
    """Cumulative 2-D histogram of paired samples.

    Entry ``[row, col]`` of the result is the number of ``(x, y)`` pairs whose
    ``x`` falls in one of the top ``row + 1`` x-bins and whose ``y`` falls in one
    of the first ``col + 1`` y-bins. Pairs outside the bin range are not counted
    (``numpy.histogram2d`` behaviour).

    Parameters
    ----------
    x, y : array-like
        Equal-length 1-D samples; ``x`` is binned along rows, ``y`` along columns.
    bins : array-like, optional
        Bin edges applied to both variables. Defaults to ``np.arange(0, 1.1, 0.1)``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(bins) - 1, len(bins) - 1)``.
    """
    if bins is None:
        bins = np.arange(0, 1.1, 0.1)
    hist_zt = np.histogram2d(x, y, bins=[bins, bins])[0]
    # Reverse the x-bins so that accumulating down the rows sums from the highest
    # x-bin, then accumulate across the y-bins. This yields the same table as
    # summing hist_zt[-(row + 1):, :col + 1] for every cell, without re-summing.
    return np.flipud(hist_zt).cumsum(axis=0).cumsum(axis=1)


# One row of panels per session (NSD sessions first, then SD), one panel per epoch.
# Each panel shows, per cell of the cumulative wcorr x jump-distance table, the
# fraction of shuffles whose cumulative count is at least the observed count.
k = 0
for g, grp in enumerate(["NSD", "SD"]):
    names = np.unique(name_all[grp_all == grp])
    for name in names:
        subfig = fig.add_subfigure(fig.gs[k])
        axs = subfig.subplots(1, len(zts), sharex=True, sharey=True)

        for i, zt in enumerate(zts):
            indx = (grp_all == grp) & (zt_all == zt) & (name_all == name)

            # Row 0 holds the observed values, the remaining rows the shuffles.
            jd = jd_all[0, indx]
            wcorr = wcorr_all[0, indx]
            real_dist = get_chist(wcorr, jd)

            sh_jd = jd_all[1:, indx]
            sh_wcorr = wcorr_all[1:, indx]
            n_shuffles = sh_wcorr.shape[0]

            pval = np.zeros_like(real_dist)
            for sh_i in range(n_shuffles):
                sh_dist = get_chist(sh_wcorr[sh_i], sh_jd[sh_i])
                pval += (sh_dist >= real_dist).astype("float")

            # Permutation p-value with the +1 correction, matching the pooled
            # analysis: it is never exactly zero (which would fall below the
            # colour norm's lowest boundary) and stays defined with no shuffles.
            pval = (pval + 1) / (n_shuffles + 1)

            ax = axs[i]
            # Rows are flipped so the highest wcorr threshold is drawn at the top.
            im = ax.pcolormesh(bins, bins, np.flipud(pval), cmap=cmap, norm=norm)
            cb = plt.colorbar(im, ax=ax)

        k += 1

#### Pooled

In [None]:
from matplotlib import colors as mcolors

# divnorm = mcolors.TwoSlopeNorm(vmin=0.001, vcenter=0.01, vmax=1)

# Epochs shown as columns; the two rows of the panel grid are the two groups.
zts = ["PRE", "MAZE", "0-2.5", "2.5-5", "5-7.5"]

# Bin edges shared by the wcorr and jump-distance axes of the p-value maps.
bins = np.arange(0, 1.1, 0.1)
n_bins = bins.size - 1

# Three-colour map for p-values: [0.001, 0.01), [0.01, 0.05) and [0.05, 1].
bounds = [0.001, 0.01, 0.05, 1]
cmap = mcolors.ListedColormap(["red", "pink", "blue"])
norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=cmap.N)

fig = plotting.Fig(grid=(9, 6), fontsize=7)
subfig = fig.add_subfigure(fig.gs[:2, :4])
axs = subfig.subplots(2, 5, sharex=True, sharey=True)


def get_chist(x, y, bins=None):
    """Cumulative 2-D histogram of paired samples.

    Entry ``[row, col]`` of the result is the number of ``(x, y)`` pairs whose
    ``x`` falls in one of the top ``row + 1`` x-bins and whose ``y`` falls in one
    of the first ``col + 1`` y-bins. Pairs outside the bin range are not counted
    (``numpy.histogram2d`` behaviour).

    Parameters
    ----------
    x, y : array-like
        Equal-length 1-D samples; ``x`` is binned along rows, ``y`` along columns.
    bins : array-like, optional
        Bin edges applied to both variables. Defaults to ``np.arange(0, 1.1, 0.1)``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(bins) - 1, len(bins) - 1)``.
    """
    if bins is None:
        bins = np.arange(0, 1.1, 0.1)
    hist_zt = np.histogram2d(x, y, bins=[bins, bins])[0]
    # Reverse the x-bins so that accumulating down the rows sums from the highest
    # x-bin, then accumulate across the y-bins. This yields the same table as
    # summing hist_zt[-(row + 1):, :col + 1] for every cell, without re-summing.
    return np.flipud(hist_zt).cumsum(axis=0).cumsum(axis=1)


# Pooled p-value maps: one row per group, one column per epoch. Each panel shows
# the corrected fraction of shuffles whose cumulative wcorr x jump-distance count
# is at least the observed count.
for g, grp in enumerate(["NSD", "SD"]):
    for i, zt in enumerate(zts):
        indx = (grp_all == grp) & (zt_all == zt)

        # Row 0 holds the observed values, the remaining rows the shuffles.
        jd = jd_all[0, indx]
        wcorr = wcorr_all[0, indx]
        real_dist = get_chist(wcorr, jd)
        # real_dist = np.histogram2d(wcorr,jd,bins=[bins,bins])[0]
        # real_dist = real_dist/real_dist.sum()

        sh_jd = jd_all[1:, indx]
        sh_wcorr = wcorr_all[1:, indx]
        n_shuffles = sh_wcorr.shape[0]

        # Count, per cell, the shuffles that reach or exceed the observed value.
        pval = np.zeros_like(real_dist)
        for wcorr_, jd_ in zip(sh_wcorr, sh_jd):
            sh_dist = get_chist(wcorr_, jd_)
            pval += (sh_dist >= real_dist).astype("float")

        pval = (pval + 1) / (n_shuffles + 1)

        ax = axs[g, i]
        # Rows are flipped so the highest wcorr threshold is drawn at the top.
        im = ax.pcolormesh(
            bins, bins, np.flipud(pval), cmap=cmap, norm=norm, rasterized=True
        )
        # im = ax.pcolormesh(bins,bins,real_dist,cmap='jet',vmin=0,vmax=0.1)
        if grp == "NSD":
            ax.set_title(zt)

# Labels and ticks are set on the last axes drawn (bottom-right panel).
ax.set_xlabel("Mean jump distance")
yticks = [0, 0.3, 0.6, 0.9]
ax.set_yticks(yticks, [f">{t}" for t in reversed(yticks)])
xticks = [0.2, 0.5, 0.8]
ax.set_xticks(xticks, [f"<{t}" for t in xticks])

# Colour bar attached to a separate, hidden axes.
cax = fig.subplot(fig.gs[4])
cax.set_axis_off()
cb = plt.colorbar(im, ax=cax)

fig.savefig(subjects.figpath_sd / "max_jump_distance")

### Comparing neuron_id vs column_cycle shuffles

In [None]:
# For each PBE, compute the percentile of its best-direction wcorr within two
# shuffle distributions (from `replay_pbe_mua` and from `replay_pbe_mua_column`).
# Result: `df` with columns zt, id_perc, col_perc, pooled over `sessions`.
sessions = subjects.sd.ratRday2


def _percentile_vs_shuffle(shuffle_wcorr, observed):
    """Percentile of each observed value within its own column of shuffles.

    ``shuffle_wcorr`` is indexed as ``[shuffle, event]`` and ``observed`` as
    ``[event]``. ``kind="strict"`` counts only shuffle values strictly below the
    observed one.
    """
    return np.array(
        [
            stats.percentileofscore(shuffle_wcorr[:, i], observed[i], kind="strict")
            for i in range(len(observed))
        ]
    )


# Per-session frames are collected in their own list. The previous version reset
# `replay_df` to an empty list that was never filled (discarding the pooled replay
# dataframe) and kept only the last session's `df`.
shuffle_cmp_frames = []
for sess in sessions:
    # NOTE(review): `neurons` is computed but not used anywhere below in this cell.
    neurons = sess.neurons_stable.get_neuron_type(["pyr", "mua"])
    neurons = neurons[neurons.firing_rate <= 10]

    replay_pbe = sess.replay_pbe_mua
    starts = replay_pbe.starts
    replay_pbe_df = replay_pbe.to_dataframe()
    up_wcorr = replay_pbe_df["up_wcorr"].values
    down_wcorr = replay_pbe_df["down_wcorr"].values

    # Up and down shuffles are stacked along the first axis into one distribution.
    id_metadata = replay_pbe.metadata
    id_shuffle_measures = np.vstack(
        [id_metadata["up_shuffle_measures"], id_metadata["down_shuffle_measures"]]
    )

    replay_column = sess.replay_pbe_mua_column
    column_metadata = replay_column.metadata
    column_shuffle_measures = np.vstack(
        [
            column_metadata["up_shuffle_measures"],
            column_metadata["down_shuffle_measures"],
        ]
    )

    # Per event, keep the wcorr of the direction with the larger magnitude.
    best_bool = np.abs(up_wcorr) > np.abs(down_wcorr)
    wcorr = np.where(best_bool, up_wcorr, down_wcorr)

    # Index 0 on the second axis is used as the wcorr measure.
    id_perc_shuffle = _percentile_vs_shuffle(id_shuffle_measures[:, 0, :], wcorr)
    column_perc_shuffle = _percentile_vs_shuffle(
        column_shuffle_measures[:, 0, :], wcorr
    )

    # Keep only events whose start falls inside one of the zt epochs.
    epochs = sess.get_zt_epochs()
    starts_bool, _, starts_labels = epochs.contains(starts)

    shuffle_cmp_frames.append(
        pd.DataFrame(
            dict(
                zt=starts_labels,
                id_perc=id_perc_shuffle[starts_bool],
                col_perc=column_perc_shuffle[starts_bool],
            )
        )
    )

df = pd.concat(shuffle_cmp_frames, ignore_index=True)

In [None]:
# Side-by-side violins of the two percentile columns of `df`, per epoch.
_, axs = plt.subplots(1, 2)

for panel, perc_col in zip(axs, ["id_perc", "col_perc"]):
    sns.violinplot(data=df, x="zt", y=perc_col, ax=panel)

### Sharpness of posterior comparison between NSD and SD

In [None]:
# Select the first entry of subjects.nsd.ratJday2 as the working session.
# NOTE(review): nothing else is done in this cell — the analysis named in the
# section heading appears to be unfinished.
sess = subjects.nsd.ratJday2[0]

## Example figures for sd_paper

### Chosen by percentile
- All possible measures are pooled across sessions and best replays within epoch are displayed
- Possible criteria: replays at the 95th percentile of radon score or wcorr, at the 5th percentile of jump distance, and that have traversed the maximum distance

In [None]:
# Pool the filtered replay events of every session into one table, keeping all
# original columns (including the posteriors) and tagging each row with its group.
sessions = subjects.pf_sess()

per_session = []
for sess in sessions:
    events = sess.replay_filtered.to_dataframe().rename(columns={"label": "zt"})
    events["grp"] = sess.tag
    per_session.append(events)

replay_df = pd.concat(per_session, ignore_index=True)

In [None]:
from replay_funcs import get_distance

# For every group and epoch: take the events in the lowest 10% of jump distance,
# rank them by the absolute distance reported by get_distance (scaled by 1 / number
# of posterior rows), and display the top ones. The chosen examples are collected
# in `examples_df`.
zts = replay_df.zt.unique()

n_examples = 2  # examples displayed per (group, epoch)
# Grid width follows the number of epochs instead of assuming exactly five.
_, axs = plt.subplots(2, n_examples * len(zts), squeeze=False)

# Box-car kernel used to smooth each posterior column-wise for display only.
smooth_kernel = np.ones(2 * 4 + 1)

example_rows = []
for g, grp in enumerate(["NSD", "SD"]):
    # A separate name keeps the colormap object `cmap` defined by the p-value
    # cells from being replaced by a string.
    cmap_name = "binary" if grp == "NSD" else "Reds"

    for i, zt in enumerate(zts):
        epoch_df = replay_df[(replay_df.zt == zt) & (replay_df.grp == grp)].reset_index(
            drop=True
        )
        scores = epoch_df.jd.values
        # velocity = df.velocity.values
        percentile = np.array(
            [stats.percentileofscore(scores, _, kind="strict") for _ in scores]
        )
        indx = percentile < 10
        posteriors = epoch_df[indx]["posterior"].to_list()
        if not posteriors:
            # Nothing passes the jump-distance criterion: leave the panels empty
            # instead of failing on posteriors[0].
            continue

        dx = 1 / posteriors[0].shape[0]
        distance = np.array([np.abs(get_distance(_)) for _ in posteriors]) * dx
        sort_ind = np.argsort(distance)[::-1]

        # Largest distances first; the slice tolerates fewer than n_examples
        # candidates.
        chosen_posteriors = [posteriors[j] for j in sort_ind[:n_examples]]

        for i1, p in enumerate(chosen_posteriors):
            p_enh = np.apply_along_axis(
                np.convolve, axis=0, arr=p, v=smooth_kernel, mode="same"
            )

            ax = axs[g, n_examples * i + i1]
            ax.pcolormesh(p_enh, cmap=cmap_name)

            example_rows.append(
                pd.DataFrame(
                    dict(
                        zt=zt,
                        # Mean absolute step of the per-column argmax position.
                        jd=[np.abs(np.diff(np.argmax(p, axis=0))).mean()],
                        posterior=[p],
                        grp=grp,
                    )
                )
            )

examples_df = pd.concat(example_rows, ignore_index=True)

# subjects.GroupData().save(examples_df, "replay_examples")

### Proportion of continuous trajectory vs Explained variance

In [None]:
# Scatter (with regression line) of the proportion of continuous-trajectory events
# against ev_diff, one panel per group (rows) and post-maze epoch (columns).
grpdata = subjects.GroupData()

# Continuous-event table without the PRE and MAZE epochs.
cont_events = grpdata.replay_continuous_events
is_post_epoch = ~cont_events.zt.isin(["PRE", "MAZE"])
cont_events = cont_events[is_post_epoch].reset_index(drop=True)

ev_pooled = grpdata.ev_in_chunks

# Left join on (session name, epoch). Both tables carry a `grp` column, so the
# merged frame exposes the left-hand one as `grp_x`.
data = cont_events.merge(
    ev_pooled, how="left", left_on=["name", "zt"], right_on=["name", "zt"]
)

fig = plotting.Fig(grid=(7, 5))

post_epochs = ["0-2.5", "2.5-5", "5-7.5"]
for g, grp in enumerate(["NSD", "SD"]):
    grp_color = subjects.colors_sd(1)[g]

    for i, zt in enumerate(post_epochs):
        ax = fig.subplot(fig.gs[g, i])

        panel_rows = (data.zt == zt) & (data.grp_x == grp)
        e_dt = data[panel_rows].reset_index(drop=True)

        # Linear fit is used only for the r / p values shown in the title.
        linfit = stats.linregress(e_dt["ev_diff"].values, e_dt["prop"].values)
        corr = linfit.rvalue
        pval = linfit.pvalue

        # ax.scatter(x,y,c=subjects.colors_sd(1)[g])
        sns.regplot(data=e_dt, x="ev_diff", y="prop", ax=ax, color=grp_color, ci=None)
        ax.set_title(f"r={corr.round(2)}, p={pval.round(2)}")
        # ax.axline((0,0),slope=linfit.slope)


fig.savefig(subjects.figpath_sd / "prop_ev_scatter")