### Imports

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
import scipy.signal as sg
import matplotlib.pyplot as plt
from neuropy import plotting
import subjects

### Posterior time lag comparison SD vs NSD

In [None]:
import scipy.signal as sg
from scipy.ndimage import gaussian_filter1d

# --- Analysis parameters (seconds unless noted) ----------------------------
POSTERIOR_BIN = 0.02  # NOTE(review): assumed bin width of "down_posterior"; confirm against the decoder
SPIKE_BIN = 0.02  # bin width for the per-bin participation criterion on sess.pbe
MAX_LAG = 0.5  # only cross-correlation lags within +/- MAX_LAG are considered
SMOOTH_SIGMA = 0.06  # Gaussian smoothing applied to each cross-correlogram
MIN_NEURONS = 5  # minimum number of active pyramidal neurons per event
MIN_BIN_FRAC = 0.8  # event must have spikes in more than this fraction of its bins

sessions = subjects.pf_sess()

lag_time_all, grp = [], []
for sub, sess in enumerate(sessions):
    neurons = sess.neurons_stable.get_neuron_type("pyr")
    pre = sess.paradigm["pre"].flatten()
    post = sess.paradigm["post"].flatten()
    maze = sess.paradigm["maze"].flatten()
    # Analysis window: 2.5 h to 5 h after the start of "post" (times presumably in seconds).
    period = [post[0] + 2.5 * 3600, post[0] + 5 * 3600]
    # Alternative windows (pre / maze above are kept for these):
    # period = [pre[0], pre[1]]
    # period = maze
    starts = sess.replay_radon_mua.starts

    replay_pbe = sess.replay_radon_mua.to_dataframe()
    pbe_epochs = sess.replay_radon_mua.flatten()
    pbe_peak = replay_pbe.peak_time

    # Criterion 1: enough distinct neurons spike inside the event. The [::2]
    # stride keeps every other histogram bin -- presumably the in-event bins of
    # an interleaved start/stop edge array (TODO confirm flatten() layout).
    n_spikes = [
        np.histogram(spk, bins=pbe_epochs)[0][::2] for spk in neurons.spiketrains
    ]
    n_neurons_bool = (np.array(n_spikes) > 0).sum(axis=0) >= MIN_NEURONS

    # Criterion 2: more than MIN_BIN_FRAC of the event's time bins contain a spike.
    # NOTE(review): this is computed on sess.pbe while the other masks come from
    # sess.replay_radon_mua; the "&" below requires both to list the same events.
    pbe_spikes, nbins = neurons.get_spikes_in_epochs(sess.pbe, bin_size=SPIKE_BIN)
    each_bin_bool = np.array([np.sum(arr.sum(axis=0) > 0) for arr in pbe_spikes])
    each_bin_bool = each_bin_bool / nbins > MIN_BIN_FRAC

    # Criterion 3: the event peak falls in a ripple epoch (odd digitize index).
    rpl_epochs = sess.ripple.flatten()
    rpl_bool = np.digitize(pbe_peak, rpl_epochs) % 2 == 1

    # Criterion 4: the event peak does NOT fall in a maze-running epoch.
    run_epochs = sess.maze_run.flatten()
    run_bool = np.digitize(pbe_peak, run_epochs) % 2 == 1

    good_bool = rpl_bool & ~run_bool & n_neurons_bool & each_bin_bool
    # good_bool = ~run_bool

    ind = (starts >= period[0]) & (starts <= period[1]) & good_bool
    # ind = starts<post[0]

    # np.flatnonzero always returns a 1-D index array. The former
    # np.argwhere(ind).squeeze() became a non-iterable 0-d array when exactly
    # one event was selected.
    selected = np.flatnonzero(ind)
    if selected.size == 0:
        raise ValueError(
            f"session {sub} ({sess.tag}): no replay events pass the selection criteria"
        )

    # Concatenate the selected events' posteriors along the time axis.
    all_posteriors = sess.replay_radon_mua.metadata["down_posterior"]
    posteriors = np.hstack([all_posteriors[i] for i in selected])

    npos, nt = posteriors.shape
    rows, cols = np.tril_indices(npos, -1)  # not used in this cell
    lags = sg.correlation_lags(nt, nt, mode="same") * POSTERIOR_BIN
    idx = (lags >= -MAX_LAG) & (lags <= MAX_LAG)
    lags = lags[idx]
    sigma_bins = SMOOTH_SIGMA / POSTERIOR_BIN

    # lag_time[p1, p2] = |lag| (s) at the peak of the smoothed cross-correlation
    # between posterior rows p1 and p2 (presumably position bins), within +/- MAX_LAG.
    lag_time = np.zeros((npos, npos))
    for p1 in range(npos):
        for p2 in range(npos):
            xcorr = sg.correlate(
                posteriors[p1], posteriors[p2], mode="same", method="fft"
            )[idx]
            smoothed = gaussian_filter1d(xcorr, sigma=sigma_bins)
            lag_time[p1, p2] = np.abs(lags[smoothed.argmax()])

    lag_time_all.append(lag_time)
    grp.append(sess.tag)

In [None]:
# Heat map of each session's lag matrix.
# - NSD sessions fill column 0; every other tag fills column 1.
# - Each column is filled top to bottom, in the order sessions appear in lag_time_all.
# - Rows come from a per-column counter. The former fixed "i - 6" offset was only
#   correct for exactly six NSD sessions listed first; otherwise it misplaced
#   plots, and row -1 silently wrapped to the last row.
columns = [0 if g == "NSD" else 1 for g in grp]
n_rows = max(columns.count(0), columns.count(1), 1)
_, axs = plt.subplots(n_rows, 2, squeeze=False)

next_row = {0: 0, 1: 0}  # next free row in each column
for lag_mat, col in zip(lag_time_all, columns):
    ax = axs[next_row[col], col]
    next_row[col] += 1

    im = ax.pcolormesh(np.abs(lag_mat), cmap="viridis", vmin=0, vmax=0.4)
    plt.colorbar(im, ax=ax)

In [None]:
MAX_DIST = 60  # largest row separation |p1 - p2| shown on the x axis

# Plots mean lag as a function of the separation |p1 - p2| between rows of each
# session's lag matrix.
# Grid rows are assigned per group with a counter. The former fixed "i - 6"
# offset assumed exactly six NSD sessions listed first.
# NOTE(review): both groups are drawn into column 0, so NSD session k and
# non-NSD session k are overlaid in row k and column 1 stays empty. The
# heat-map cell puts the second group in column 1; confirm the overlay is intended.
is_nsd = [g == "NSD" for g in grp]
n_nsd = sum(is_nsd)
n_rows = max(n_nsd, len(is_nsd) - n_nsd, 1)
_, axs = plt.subplots(n_rows, 2, sharex=True, sharey=True, squeeze=False)

next_row = {True: 0, False: 0}  # next free row for each group
for lag_mat, g, nsd in zip(lag_time_all, grp, is_nsd):
    ax = axs[next_row[nsd], 0]
    next_row[nsd] += 1

    n_bins = lag_mat.shape[0]
    pos_bins = np.arange(n_bins)
    pos_dist = np.abs(pos_bins[np.newaxis, :] - pos_bins[:, np.newaxis])

    # Per-distance mean in one pass. bincount sums the lags for each distance
    # and counts the pairs. Only distances < n_bins occur, so clipping to
    # min(MAX_DIST, n_bins) avoids the empty-slice means (NaN + RuntimeWarning)
    # that the fixed range(60) produced for small matrices.
    lag_sum = np.bincount(pos_dist.ravel(), weights=lag_mat.ravel())
    pair_count = np.bincount(pos_dist.ravel())
    n_dist = min(MAX_DIST, n_bins)
    mean_lags = lag_sum[:n_dist] / pair_count[:n_dist]

    ax.plot(np.arange(n_dist), mean_lags, label=g)