### Figure 4 
- Sliding template during sleep deprivation
- Correlation across time windows

### Load data

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import subjects
from subjects import stat_kw
from plotters import violinplot
import pandas as pd
import seaborn as sns
from neuropy.plotting import Fig
from scipy import stats
from statannotations.Annotator import Annotator

# Session groups compared throughout this figure (row/colour order below follows it).
group = ["NSD", "SD"]

# Single loader for every group-level dataset used in this notebook.
grpdata = subjects.GroupData()

# NOTE(review): `pf_norm_tuning`, `post_df` and `colors_post` are not used by any
# of the cells shown here; drop them if nothing else depends on them.
pf_norm_tuning = grpdata.pf_norm_tuning

# Example decoded posteriors + their weighted correlations ("Replay examples" panel).
examples = grpdata.replay_examples

# Replay weighted-correlation percentile table; later cells read its
# `grp`, `zt` and `perc` columns (CDF and violin panels).
wcorr_dist = grpdata.replay_wcorr

# Replay scores from `replay_post_score`; keep only the magnitude of the
# weighted correlation (sign is discarded).
post_df = grpdata.replay_post_score
post_df["weighted_corr"] = post_df.weighted_corr.abs()
colors_post = subjects.colors_sd(1)

# Figure canvas, output path and caption.
fig_kw = dict(grid=(7, 6), hspace=0.45, wspace=0.5, constrained_layout=False)
fig = Fig(**fig_kw)
filename = subjects.figpath_sd / "figure4"
# TODO(review): this caption talks about place cells / place fields, while the
# panels below show replay examples and wcorr percentiles — it looks copied
# from another figure; confirm and update before publishing.
caption = "Place cells: \n(A) Place fields recorded during NSD and SD sessions."

### Replay examples

In [None]:
# One row of example posterior matrices per group; one panel per event.
subfig = fig.add_subfigure(fig.gs[1, 0:4])
axs = subfig.subplots(2, 10, sharey="row")

# Bin width scaled by 1000 (presumably 20 ms decoding bins, in ms — TODO confirm);
# only used to label each panel with the event duration.
dt = 0.02 * 1000
pf_cmaps = ["Greys", "Reds"]  # colormap per row, in `group` order
smoothing_kernel = np.ones(2 * 2 + 1)  # 5-point box kernel (not normalised)

for row, grp_name in enumerate(group):
    posteriors = examples[grp_name]["posteriors"]
    scores = examples[grp_name]["wcorr"]
    for col, raw_posterior in enumerate(posteriors):
        # Box-smooth each column of the matrix along axis 0.
        posterior = np.apply_along_axis(
            np.convolve, 0, raw_posterior, v=smoothing_kernel, mode="same"
        )
        n_bins = posterior.shape[1]
        ax = axs[row, col]

        # SD examples with index > 7 get their own colormap.
        cmap = "Blues" if (grp_name == "SD" and col > 7) else pf_cmaps[row]

        ax.pcolormesh(posterior, cmap=cmap, vmin=0, vmax=0.2, rasterized=True)
        score_label = f"{scores[col].round(2)}"
        ax.text(0.6, 7, score_label, color="#039be5", fontdict=dict(fontsize=6))
        for side in ("right", "top"):
            ax.spines[side].set_visible(True)
        ax.set_yticks([])
        # One tick at the horizontal centre, labelled with n_bins * dt.
        ax.set_xticks([n_bins // 2], [int(n_bins * dt)])
        ax.tick_params(length=0)


### Weighted-correlation (wcorr) percentile CDFs

In [None]:
# Cumulative distributions of replay wcorr percentiles: PRE alone in the first
# panel, then each later epoch overlaid on PRE (dashed), split by group.
subfig = fig.add_subfigure(fig.gs[2, 0:4])
axs = subfig.subplots(1, 5, sharey=True, sharex=True)

zts = wcorr_dist.zt.unique()
# Epochs compared against PRE. PRE is excluded by value instead of assuming it is
# the first entry of `unique()` (which follows row order of the table).
post_zts = [zt for zt in zts if zt != "PRE"]
assert len(post_zts) <= len(axs) - 1, "more epochs than CDF panels"

plot_kw = dict(
    x="perc",
    hue="grp",
    stat="probability",
    common_bins=True,
    binwidth=5,
    common_norm=False,
    cumulative=True,
    fill=False,
    element="poly",
    lw=0.8,
    legend=False,  # no per-hue legend in any panel
)
palette = subjects.colors_sd(1)
grp_colors = subjects.colors_sd()


def _perc_values(df, grp, zt):
    """Return the `perc` values of `df` rows with group `grp` and epoch `zt`."""
    return df[(df.grp == grp) & (df.zt == zt)].perc.values


def _ks_sig_text(sample1, sample2, alternative="two-sided", alpha=0.05):
    """Two-sample KS test; return "*" if p <= alpha, otherwise "n.s".

    With alternative="greater" the alternative hypothesis (scipy convention) is
    that the CDF of `sample1` exceeds that of `sample2` somewhere, i.e. values in
    `sample1` tend to be smaller.
    """
    pvalue = stats.ks_2samp(sample1, sample2, alternative=alternative).pvalue
    return "n.s" if pvalue > alpha else "*"


# --- PRE panel: between-group (SD vs NSD) two-sided comparison ---
ax = axs[0]
sns.histplot(
    data=wcorr_dist[wcorr_dist.zt == "PRE"],
    **plot_kw,
    ax=ax,
    palette=palette,
)
sig_text = _ks_sig_text(
    _perc_values(wcorr_dist, "SD", "PRE"),
    _perc_values(wcorr_dist, "NSD", "PRE"),
)
ax.text(25, 0.6, sig_text, color="g")
ax.set_title("PRE")
ax.set_xticks([0, 25, 50, 75, 100])

# --- one panel per later epoch ---
for ax, zt in zip(axs[1:], post_zts):
    # Within-group PRE vs this epoch (one-sided: PRE values tend to be smaller).
    for i1, (g, y) in enumerate(zip(["NSD", "SD"], [0.8, 0.7])):
        sig_text = _ks_sig_text(
            _perc_values(wcorr_dist, g, "PRE"),
            _perc_values(wcorr_dist, g, zt),
            alternative="greater",
        )
        ax.text(25, y, sig_text, color=grp_colors[i1])

    # PRE (dashed) overlaid with the current epoch (solid).
    sns.histplot(
        data=wcorr_dist[wcorr_dist.zt == "PRE"],
        **plot_kw,
        ax=ax,
        palette=palette,
        ls="--",
    )
    sns.histplot(
        data=wcorr_dist[wcorr_dist.zt == zt],
        **plot_kw,
        ax=ax,
        palette=palette,
    )

    # Between-group at this epoch (one-sided: SD values tend to be smaller).
    sig_text = _ks_sig_text(
        _perc_values(wcorr_dist, "SD", zt),
        _perc_values(wcorr_dist, "NSD", zt),
        alternative="greater",
    )
    ax.text(25, 0.6, sig_text, color="g")

    ax.set_ylabel("")
    ax.set_title(zt)



### Percentile violin plot

In [None]:
# NOTE(review): `colors` is not used in this cell.
colors = subjects.colors_sd(1)

# Violin plot of replay wcorr percentiles per epoch (`zt`), with stat annotations.
ax = fig.subplot(fig.gs[4, 0])
# NOTE(review): `ax` is not passed to `violinplot`, so it presumably draws on the
# current axes (the one just created above). If `violinplot` accepts an `ax`
# argument, pass it explicitly — TODO confirm its signature.
violinplot(data=wcorr_dist, x="zt", y="perc", stat_anot=True)

ax.set_xlabel("")
ax.tick_params("x", rotation=30)
# matplotlib documents `labels` as a list of str; an empty list gives an empty
# legend, which frameon=False renders invisible.
ax.legend([], frameon=False)


### Save figure

In [None]:
# Write the assembled figure, together with its caption, to `filename`.
save_kw = dict(caption=caption)
fig.savefig(filename, **save_kw)