### Imports
- Everything that has remaze as its focus

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.model_selection import KFold, train_test_split
from neuropy import plotting
import subjects

### Scatter plot maze vs remaze

In [None]:
# Pairwise correlations of pyramidal neurons during MAZE vs. RE-MAZE, pooled
# across sessions into one long-format table (presumably one row per neuron
# pair, as returned by get_pairwise_corr).
# NOTE: `sessions` is not assigned above this cell; it must already exist in
# the kernel (it is assigned in later cells of this notebook).
pair_corr_df = []
for sess in sessions:
    maze = sess.paradigm["maze"].flatten()
    remaze = sess.paradigm["re-maze"].flatten()

    neurons = sess.neurons.get_neuron_type("pyr")
    maze_frate = neurons.time_slice(*maze).firing_rate
    remaze_frate = neurons.time_slice(*remaze).firing_rate

    # Keep only neurons with a non-zero firing rate in both epochs.
    good_indices = np.logical_and(maze_frate > 0, remaze_frate > 0)
    neurons = neurons[good_indices]

    # 0.25 is the bin size handed to get_binned_spiketrains
    # (presumably seconds — confirm against neuropy).
    maze_corr, remaze_corr = (
        neurons.time_slice(*epoch).get_binned_spiketrains(0.25).get_pairwise_corr()
        for epoch in (maze, remaze)
    )

    pair_corr_df.append(
        pd.DataFrame(dict(maze=maze_corr, remaze=remaze_corr, grp=sess.tag))
    )

pair_corr_df = pd.concat(pair_corr_df, ignore_index=True)

subjects.GroupData().save(pair_corr_df, "remaze_maze_paircorr")

In [None]:
# These two imports repeat the ones in the notebook's first cell.
import matplotlib.pyplot as plt
import seaborn as sns

# Scatter of MAZE vs. RE-MAZE pairwise correlations, colored by the "grp" column.
sns.scatterplot(data=pair_corr_df, x="maze", y="remaze", hue="grp")

### Pairwise correlations between Zt0-5/Zt5-end vs remaze
- Scatter plots 
- I saw that pairwise correlations during NSD are more correlated to remaze compared to SD.
- To compare these two correlations statistically, we can shuffle the pairs in each group, calculate the R<sup>2</sup> difference between NSD and SD for each shuffle, and compare the observed difference against the resulting distribution of R<sup>2</sup> differences.

In [None]:
# Sessions analysed by the following cells: subjects.nsd.remaze followed by
# subjects.sd.remaze (presumably the NSD and SD remaze recordings — confirm in `subjects`).
sessions = subjects.nsd.remaze + subjects.sd.remaze

In [None]:
# Pairwise correlations of pyramidal neurons in a POST window vs. RE-MAZE,
# pooled across sessions. To analyse Zt5-end instead of Zt0-5, switch BOTH the
# `e1` definition and the save name below.
print(f"#Sessions: {len(sessions)}")
pair_corr_df = []
for sess in sessions:
    post = sess.paradigm["post"].flatten()
    remaze = sess.paradigm["re-maze"].flatten()

    # 5 * 3600 offsets the start of POST by five hours (assuming times are in seconds).
    e1 = [post[0], post[0] + 5 * 3600]  # Zt0-5
    # e1 = [post[0] + 5 * 3600, post[1]]  # Zt5-end
    e2 = remaze

    neurons = sess.neurons.get_neuron_type("pyr")
    e1_frate = neurons.time_slice(*e1).firing_rate
    e2_frate = neurons.time_slice(*e2).firing_rate

    # Keep only neurons with a non-zero firing rate in both epochs.
    good_indices = np.logical_and(e1_frate > 0, e2_frate > 0)
    neurons = neurons[good_indices]

    e1_corr, e2_corr = (
        neurons.time_slice(*epoch).get_binned_spiketrains(0.25).get_pairwise_corr()
        for epoch in (e1, e2)
    )

    # The column is called "zt5h" regardless of which POST window `e1` selects.
    pair_corr_df.append(
        pd.DataFrame(dict(zt5h=e1_corr, remaze=e2_corr, grp=sess.tag))
    )

pair_corr_df = pd.concat(pair_corr_df, ignore_index=True)

subjects.GroupData().save(pair_corr_df, "remaze_first5_paircorr")
# subjects.GroupData().save(pair_corr_df, "remaze_last5_paircorr")

#### Shuffle/Bootstrapping methods
- Method 1: Randomize x axis correlations
- Method 2: Subsample x and y and get a distribution of R^2 values and then compare
- Method 3 (**Using this here**): Combined bootstrap. Combine both SD and NSD pairwise correlations, then subsample the same number of pairs (for each group) with or without replacement, calculate the R^2 difference, and compare the distribution of these resampled differences against the difference observed in the original groups.

In [None]:
# Combined-bootstrap comparison of R^2 (zt5h vs. remaze) between NSD and SD.
# The generator is unseeded, so the saved distribution differs between runs.
rng = np.random.default_rng()
groups = ("NSD", "SD")
n_boot = 1000  # number of resamples per group

# Observed R^2 and number of pairs for each group.
r2, n_pairs = [], []
for grp in groups:
    df = pair_corr_df[pair_corr_df["grp"] == grp]
    x, y = df.zt5h.values, df.remaze.values
    r2.append(stats.linregress(x, y).rvalue ** 2)
    n_pairs.append(len(x))

r2_diff = r2[0] - r2[1]

# x-values are drawn from the pool of both groups; y stays the group's own
# remaze correlations, so each resample breaks the x-y pairing.
combined_paircorr = pair_corr_df.zt5h.values

r2_boot = []
for grp, n_grp in zip(groups, n_pairs):
    y = pair_corr_df[pair_corr_df["grp"] == grp].remaze.values
    r2_boot.append(
        [
            stats.linregress(
                rng.choice(combined_paircorr, n_grp, replace=False), y
            ).rvalue
            ** 2
            for _ in range(n_boot)
        ]
    )

r2_boot = np.asarray(r2_boot)
r2_boot_diff = r2_boot[0] - r2_boot[1]  # NSD minus SD, one value per resample

subjects.GroupData().save(
    {"boot_diff": r2_boot_diff, "r2_diff": r2_diff}, "remaze_first5_bootstrap"
)

In [None]:
# Panels 0/1: per-group scatter of zt5h vs. remaze; panel 2: bootstrap
# distribution of the R^2 difference with the observed difference marked.
_, axs = plt.subplots(1, 3)
axs = axs.reshape(-1)
for ax, grp in zip(axs, ["NSD", "SD"]):
    df = pair_corr_df[pair_corr_df["grp"] == grp]
    linreg = stats.linregress(df["zt5h"], df["remaze"])
    sns.scatterplot(data=df, x="zt5h", y="remaze", hue="grp", ax=ax)
    # ax.set_title(f'r={linreg.rvalue.round(2)}, pvalue={linreg.pvalue}')

hist_ax = axs[2]
sns.histplot(data=r2_boot_diff, ax=hist_ax, fill=True, element="step")
hist_ax.axvline(r2_diff)
# NOTE(review): r2_boot_diff is a difference of R^2 values and may contain
# values <= 0, which a log x-axis cannot display — confirm this is intended.
hist_ax.set_xscale("log")
# hist_ax.axvline(0.08,ls='--')

### Correlation of pairwise correlations across time for remaze sessions

In [None]:
import pingouin as pg

# For every session: pairwise correlations in consecutive 300-unit windows
# (presumably seconds), then the window-by-window correlation matrix of those
# pairwise-correlation vectors.
window_size = 300

corr = []
for sess in sessions:
    rec_duration = sess.eegfile.duration
    neurons = sess.neurons.get_neuron_type("pyr")

    windows = np.arange(0, rec_duration, window_size)
    # The final window start is skipped: the window beginning there may run
    # past the end of the recording.
    pair_corr = np.array(
        [
            neurons.time_slice(w, w + window_size)
            .get_binned_spiketrains()
            .get_pairwise_corr()
            for w in windows[:-1]
        ]
    ).T  # transpose so that each column is one window
    df = pd.DataFrame(pair_corr, columns=np.arange(len(windows) - 1))
    corr.append(df.corr().values)
subjects.GroupData().save(corr, "remaze_corr_across_session")

In [None]:
import matplotlib.pyplot as plt

# Show the window-by-window correlation matrices of the first two sessions.
_, axs = plt.subplots(1, 2)
axs = axs.reshape(-1)

for i, ax in enumerate(axs):
    ax.imshow(corr[i])

### Paircorr Zt0-5/Zt5-end vs remaze while controlling for MAZE
- Scatter plots of residuals of pairwise correlations after controlling for PRE

In [None]:
# Session list for the residual analysis below, as returned by subjects.remaze_sess().
sessions = subjects.remaze_sess()

In [None]:
# Report how many sessions are in `sessions` before processing them.
print(f"#Sessions: {len(sessions)}")


def get_residuals(x, y):
    """Return the residuals of ``y`` after a linear regression on ``x``.

    The line is fitted with ``scipy.stats.linregress`` using only positions
    where neither ``x`` nor ``y`` is NaN. The residuals are then evaluated for
    every position, so entries that were NaN in either input remain NaN in the
    result.

    Parameters
    ----------
    x, y : array-like supporting boolean-mask indexing (e.g. ``np.ndarray``)
        Predictor and response values of equal length.

    Returns
    -------
    Same shape as ``y``: ``y - (slope * x + intercept)``.
    """
    valid = ~(np.isnan(x) | np.isnan(y))
    fit = stats.linregress(x[valid], y[valid])
    predicted = fit.slope * x + fit.intercept
    return y - predicted


# Pairwise correlations in Zt0-5, Zt5-end and RE-MAZE, each reduced to the
# residuals left after regressing out the baseline (PRE through MAZE)
# pairwise correlations with get_residuals.
pair_corr_df = []
for sess in sessions:
    pre = sess.paradigm["pre"].flatten()
    maze = sess.paradigm["maze"].flatten()
    post = sess.paradigm["post"].flatten()
    remaze = sess.paradigm["re-maze"].flatten()

    e1 = [pre[0], maze[1]]  # baseline: start of pre to end of maze
    # e1 = maze # maze only
    e2 = [post[0], post[0] + 5 * 3600]  # Zt0-5
    e3 = [post[0] + 5 * 3600, post[1]]  # Zt5-end
    e4 = remaze

    neurons = sess.neurons_stable.get_neuron_type("pyr")
    e1_frate = neurons.time_slice(*e1).firing_rate
    e2_frate = neurons.time_slice(*e2).firing_rate

    # Neurons are screened on e1 and e2 only; NaN correlations (from any
    # epoch) are masked out of the fit inside get_residuals.
    good_indices = np.logical_and(e1_frate > 0, e2_frate > 0)
    neurons = neurons[good_indices]

    baseline_corr, zt05_corr, zt5e_corr, remaze_corr = (
        neurons.time_slice(*epoch).get_binned_spiketrains(0.25).get_pairwise_corr()
        for epoch in (e1, e2, e3, e4)
    )

    df = pd.DataFrame(
        dict(
            zt05=get_residuals(baseline_corr, zt05_corr),
            zt5e=get_residuals(baseline_corr, zt5e_corr),
            remaze=get_residuals(baseline_corr, remaze_corr),
            grp=sess.tag,
        )
    )
    pair_corr_df.append(df)

pair_corr_df = pd.concat(pair_corr_df, ignore_index=True)

# subjects.GroupData().save(pair_corr_df, "remaze_residual_corr")

In [None]:
import seaborn as sns
from neuropy.plotting import Fig

# One panel per POST window ("zt05", "zt5e"): residual pairwise correlations
# vs. RE-MAZE residuals, with a dashed regression line per group.
fig = Fig(grid=(8, 6))

# Drop rows with a NaN in any column so the scatter and the fits share rows.
# This rebinds the global `pair_corr_df`.
pair_corr_df = pair_corr_df.dropna(axis=0).reset_index(drop=True)
# _, axs = plt.subplots(1,2,sharex=True,sharey=True)

line_colors = subjects.colors_sd(0.8)  # indexed in ["NSD", "SD"] order below

for i, e in enumerate(["zt05", "zt5e"]):
    ax = fig.subplot(fig.gs[i])
    sns.scatterplot(
        data=pair_corr_df,
        x=e,
        y="remaze",
        hue="grp",
        s=5,
        palette=subjects.colors_sd(0.9),
        ax=ax,
        rasterized=True,
        alpha=0.6,
        legend=False,  # suppress the legend instead of drawing an empty one
    )
    ax.set_xlim(-0.3, 0.7)
    ax.set_ylim(-0.3, 0.7)

    for g, grp in enumerate(["NSD", "SD"]):
        df = pair_corr_df[pair_corr_df["grp"] == grp]
        x = df[e].values
        y = df.remaze.values
        # Redundant after the dropna above, kept as a guard.
        mask = ~np.isnan(x) & ~np.isnan(y)
        linreg = stats.linregress(x[mask], y[mask])
        print(linreg.rvalue, linreg.pvalue)
        ax.axline(
            (0, linreg.intercept), slope=linreg.slope, color=line_colors[g], ls="--"
        )

# fig.savefig(subjects.figpath_sd/'remaze_paircorr_residuals')

### Population vector correlation between maze and remaze

In [None]:
# Session list for the population-vector analysis below, as returned by subjects.remaze_sess().
sessions = subjects.remaze_sess()


def remove_empty_bins(arr):
    """Return ``arr`` without the columns whose sum is not positive.

    Parameters
    ----------
    arr : 2-D ``np.ndarray``
        In this notebook, binned spike counts (presumably neurons x time bins).

    Returns
    -------
    np.ndarray
        The columns of ``arr`` for which ``arr.sum(axis=0) > 0``, in their
        original order.
    """
    column_totals = arr.sum(axis=0)
    occupied = column_totals > 0
    return arr[:, occupied]


# For each session, correlate the population vectors (1-unit bins, presumably
# seconds) of MAZE and RE-MAZE against each other. Bins without any spikes are
# removed per epoch before the two epochs are concatenated along the bin axis.
pcorrs = []
for sess in sessions:
    maze = sess.paradigm["maze"].flatten()
    remaze = sess.paradigm["re-maze"].flatten()
    neurons = sess.neurons_stable.get_neuron_type("pyr")

    binned = [
        remove_empty_bins(
            neurons.time_slice(*epoch).get_binned_spiketrains(1).spike_counts
        )
        for epoch in (maze, remaze)
    ]

    # Transposed so np.corrcoef treats each time bin (column) as one variable.
    pcorrs.append(np.corrcoef(np.hstack(binned).T))

In [None]:
# One image per session's population-vector correlation matrix.
fig = plotting.Fig(3, 3)

for panel, pcorr in enumerate(pcorrs):
    # Zero the diagonal in place (this modifies the arrays stored in `pcorrs`).
    np.fill_diagonal(pcorr, 0)
    fig.subplot(fig.gs[panel]).imshow(pcorr)