In [1]:
import warnings
from typing import Dict

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pingouin as pg
import scipy.signal as sg
import scipy.stats as stats
import seaborn as sns
import signal_process
from callfunc import processData
from mathutil import threshPeriods
from plotUtil import Colormap, Fig
from scipy.ndimage import gaussian_filter, gaussian_filter1d
from joblib import Parallel, delayed
#%matplotlib qt
# warnings.simplefilter(action="default")

  return warn(


### Functions used only within this script

In [3]:
%connect_info in a cell
def doWavelet(lfp, freqs, ncycles=3):
    """Wavelet-decompose ``lfp`` and return a z-scored, Gaussian-smoothed map.

    Parameters
    ----------
    lfp : array-like
        Signal handed to ``signal_process.wavelet_decomp``.
    freqs : array-like
        Frequencies handed to the wavelet decomposition.
    ncycles : int, optional
        Currently unused. It only fed the commented-out ``cohen`` variant;
        the active decomposition is ``colgin2009()``.

    Returns
    -------
    ndarray
        ``gaussian_filter(stats.zscore(<colgin2009 output>), sigma=4)``.
        NOTE(review): ``stats.zscore`` uses its default axis (0). Which
        dimension that is depends on the decomposition's output layout;
        confirm it is the intended one.
    """
    decomposition = signal_process.wavelet_decomp(lfp, freqs=freqs)
    # Alternative kept for reference: decomposition.cohen(ncycles=ncycles)
    normalized = stats.zscore(decomposition.colgin2009())
    return gaussian_filter(normalized, sigma=4)


def getPxx(lfp, fs=1250):
    """Welch power spectrum with narrow noise bands and high frequencies removed.

    Parameters
    ----------
    lfp : array-like
        Signal(s). The spectrum is computed along the last axis, so a 2-D
        (channels, samples) array yields one spectrum per row.
    fs : float, optional
        Sampling rate in Hz. The default of 1250 is the value this function
        previously hard-coded.

    Returns
    -------
    Pxx : ndarray
        Power spectral density, with the excluded frequency bins dropped along
        the last axis.
    freq : ndarray
        Frequencies (Hz) matching the last axis of ``Pxx``.

    Notes
    -----
    Excluded bins are 59-61 Hz and 119-121 Hz (presumably 60 Hz line noise and
    its harmonic) and everything above 220 Hz. The return order is
    (Pxx, freq), the reverse of ``scipy.signal.welch``.
    """
    window = int(5 * fs)  # 5-second Welch segments

    freq, Pxx = sg.welch(
        lfp,
        fs=fs,
        nperseg=window,
        # scipy truncates a float overlap with int(); use integer division explicitly
        noverlap=window // 6,
        detrend="linear",
        axis=-1,
    )
    noise = ((freq > 59) & (freq < 61)) | ((freq > 119) & (freq < 121)) | (freq > 220)
    keep = ~noise
    # Select along the frequency (last) axis only. np.delete without ``axis``
    # flattens the array and corrupts multi-channel spectra.
    return Pxx[..., keep], freq[keep]



{
  "shell_port": 9002,
  "iopub_port": 9004,
  "stdin_port": 9003,
  "control_port": 9001,
  "hb_port": 9000,
  "ip": "127.0.0.1",
  "key": "<redacted>",
  "transport": "tcp",
  "signature_scheme": "hmac-sha256",
  "kernel_name": ""
}

Paste the above JSON into a file, and connect with:
    $> jupyter <app> --existing <file>
or, if you are local, you can connect with just:
    $> jupyter <app> --existing /tmp/tmp-446471trgwSVSP3G0G.json
or even just:
    $> jupyter <app> --existing
if this is the most recent Jupyter kernel you have started.


### Subjects/DataPath

In [2]:
# Session directories to analyse; uncomment entries to include more sessions.
# Later cells index `sessions` positionally, so the order here matters.
basePath = [
    # "/data/Clustering/SleepDeprivation/RatJ/Day1/",
    # "/data/Clustering/SleepDeprivation/RatK/Day1/",
    # "/data/Clustering/SleepDeprivation/RatN/Day1/",
    # "/data/Clustering/SleepDeprivation/RatJ/Day2/",
    # "/data/Clustering/SleepDeprivation/RatK/Day2/",
    "/data/Clustering/SleepDeprivation/RatN/Day2/",
    # "/data/Clustering/SleepDeprivation/RatJ/Day4/",
    # "/data/Clustering/SleepDeprivation/RatK/Day4/",
    # "/data/Clustering/SleepDeprivation/RatN/Day4/",
    # "/data/Clustering/SleepDeprivation/RatA14d1LP/Rollipram/",
]
# One processData object per active path, in basePath order.
sessions = list(map(processData, basePath))


### Example figure of power spectral density and its changes with respect to speed

In [5]:

# Welch PSD of each session's theta best channel during the MAZE epoch.
# All sessions are overlaid on the first panel of a 2 x 2 grid.
figure = Fig()
fig, gs = figure.draw(grid=(2, 2))
ax = plt.subplot(gs[0])
for sub, sess in enumerate(sessions):
    eegSrate = sess.recinfo.lfpSrate
    maze = sess.epochs.maze
    chan = sess.theta.bestchan
    eeg = sess.recinfo.geteeg(chans=chan, timeRange=maze)
    # 5 s segments with 1 s overlap. Both lengths come from the recording's
    # own sampling rate rather than a hard-coded 1250 Hz.
    f, pxx = sg.welch(
        eeg, fs=eegSrate, nperseg=int(5 * eegSrate), noverlap=int(eegSrate)
    )
    ax.plot(f, pxx)

ax.set_yscale("log")
ax.set_xscale("log")
ax.set_ylabel("Power")
ax.set_xlabel("Frequency (Hz)")


### Phase-amplitude comodulogram across multiple frequencies during REM sleep

In [None]:

# Phase-amplitude comodulogram of the best spindle-channel LFP during REM
# sleep, one panel per session on a 2 x 3 grid.
# NOTE(review): ``Pac`` is not imported in any visible cell; presumably
# ``from tensorpac import Pac`` belongs in the imports cell.
PAC_FS = 1250.0  # sampling rate (Hz) assumed for best_chan_lfp() output
plt.clf()
fig = plt.figure(1, figsize=(15, 10))  # a 2 x 3 grid needs a wide figure
gs = fig.add_gridspec(2, 3)
fig.subplots_adjust(hspace=0.5)

p = Pac(idpac=(6, 3, 0), f_pha=(4, 10, 1, 1), f_amp=(30, 100, 5, 5))

for sub, sess in enumerate(sessions):
    sess.trange = np.array([])
    tstart = sess.epochs.post[0]
    tend = tstart + 5 * 3600  # 5 h after POST onset
    lfp, _, _ = sess.spindle.best_chan_lfp()
    # Sample times in seconds, with exact 1/fs spacing
    # (np.linspace with an endpoint stretches the step slightly).
    t = np.arange(len(lfp)) / PAC_FS
    states = sess.brainstates.states

    # Sessions 0-2 use only REM that starts more than 5 h into POST.
    # Later sessions use all REM in POST.
    # NOTE(review): presumably sleep-deprivation vs control days; confirm
    # against the basePath ordering.
    rem_after = tend if sub < 3 else tstart
    rem = states[(states["start"] > rem_after) & (states["name"] == "rem")]

    # Concatenate the LFP of every selected REM epoch.
    rem_chunks = [lfp[(t > ep.start) & (t < ep.end)] for ep in rem.itertuples()]
    if not rem_chunks:
        warnings.warn(f"session {sub}: no REM epochs selected, skipping")
        continue
    lfprem = np.concatenate(rem_chunks)

    xpac = p.filterfit(PAC_FS, lfprem, n_perm=20)

    # add_subplot makes this panel the current axes; presumably comodulogram
    # draws on the current axes. The title is passed here because a bare
    # ax.set_title() with no label raises TypeError.
    fig.add_subplot(gs[sub])
    p.comodulogram(
        xpac.mean(-1),
        title=f"REM PAC comodulogram, session {sub}",
        cmap="Spectral_r",
        plotas="contour",
        ncontours=7,
    )




### Schematic: theta-phase-specific extraction method

In [None]:
# Schematic of theta-phase-specific extraction. The top panel shows a 1 s
# strong-theta MAZE sample, shaded by theta phase bin. The lower panel shows
# the >25 Hz high-passed signal split into the same 5 phase bins.
rng = np.random.default_rng(0)  # fixed seed so the example window is reproducible
cmap = plt.get_cmap("RdPu")  # mpl.cm.get_cmap is removed in newer matplotlib

figure = Fig()
fig, gs = figure.draw(grid=(5, 3))
fig.suptitle("Phase specific extraction schematic")
for sub, sess in enumerate(sessions):
    eegSrate = sess.recinfo.lfpSrate
    maze = sess.epochs.maze
    thetachan = sess.theta.bestchan
    eeg = sess.recinfo.geteeg(chans=thetachan, timeRange=maze)
    strong_theta = stats.zscore(sess.theta.getstrongTheta(eeg)[0])

    # Draw the start so that the whole 1 s window fits inside strong_theta.
    n_win = int(eegSrate)
    rand_start = int(rng.integers(0, max(len(strong_theta) - n_win, 0) + 1))
    theta_sample = strong_theta[rand_start : rand_start + n_win]
    thetaparams = sess.theta.getParams(theta_sample)
    gamma_lfp = signal_process.filter_sig.highpass(theta_sample, cutoff=25)

    # ----- divide 360 degrees into 5 non-overlapping phase bins ------------
    angle_bin = np.linspace(0, 360, 6)
    phases = range(1, len(angle_bin))  # np.digitize labels of the 5 bins
    bin_ind = np.digitize(thetaparams.angle, bins=angle_bin)
    sample_idx = np.arange(len(theta_sample))
    gamma_by_phase = {}

    ax = plt.subplot(gs[0, :])
    for phase in phases:
        in_bin = bin_ind == phase
        gamma_by_phase[phase] = gamma_lfp[np.flatnonzero(in_bin)]
        # Shade from the trace minimum up to the trace wherever it is in this bin.
        ax.fill_between(
            sample_idx,
            np.min(theta_sample),
            theta_sample,
            where=in_bin,
            color=cmap((phase + 1) / 10),
            zorder=1,
        )
    ax.plot(theta_sample, "k", zorder=2)
    ax.plot(thetaparams.lfp_filtered, "r", zorder=3)
    ax.plot(gamma_lfp - 3, color="#3b1641", zorder=3)  # offset below the raw trace
    ax.set_xlim([0, len(theta_sample)])
    ax.axis("off")

    # Stack the per-bin high-passed snippets, each shifted up by 0.9.
    axphase = plt.subplot(gs[1, :2])
    for k, phase in enumerate(phases):
        axphase.plot(gamma_by_phase[phase] + 0.2 + 0.9 * k, color=cmap((phase + 1) / 10))
    axphase.set_ylim([-3.5, 4.8])
    axphase.axis("off")


### Theta-phase-specific extraction of LFP during strong theta on the MAZE, with different phase-binning techniques

In [None]:

def getPhaseBinnedPxx(sess, strong_theta, gamma_lfp, fs=1250, **kwargs):
    """Log10 Welch PSD of ``gamma_lfp`` for each theta-phase bin.

    Parameters
    ----------
    sess : processData
        Session whose ``theta.phase_specfic_extraction`` does the phase binning.
    strong_theta, gamma_lfp : array-like
        Signals forwarded to ``phase_specfic_extraction``.
    fs : float, optional
        Sampling rate in Hz. Welch uses 1 s segments with 50 % overlap.
    **kwargs
        Binning options forwarded to ``phase_specfic_extraction``
        (``window``, ``slideby``).

    Returns
    -------
    pandas.DataFrame
        A ``freq`` column followed by one log10-power column per phase-bin centre.
    """
    gamma_bin, _, angle_centers = sess.theta.phase_specfic_extraction(
        strong_theta, gamma_lfp, **kwargs
    )
    nperseg = int(fs)
    log_power = {}
    freqs = None
    for lfp_bin, center in zip(gamma_bin, angle_centers):
        freqs, pxx = sg.welch(lfp_bin, nperseg=nperseg, noverlap=nperseg // 2, fs=fs)
        log_power[center] = np.log10(pxx)
    psd = pd.DataFrame(log_power)
    psd.insert(0, "freq", freqs)
    return psd


# NOTE(review): hard-coded channel; confirm it is the intended (theta)
# channel for the selected session.
LFP_CHAN = 11
# [window, slideby] in degrees: 5 x 72 deg bins, 9 x 40 deg bins, and
# 40 deg bins sliding every 5 deg.
binconfig = [[72, None], [40, None], [40, 5]]
bin_names = ["5bin", "9bin", "slide"]

session_subset = sessions[7:8]
if not session_subset:
    warnings.warn("sessions[7:8] is empty: fewer than 8 sessions are loaded from basePath")

figure = Fig()
fig, gs = figure.draw(grid=[4, 3], wspace=0.4)

for sub, sess in enumerate(session_subset):
    sess.trange = np.array([])
    eegSrate = sess.recinfo.lfpSrate
    maze = sess.epochs.maze

    lfpmaze = sess.recinfo.geteeg(chans=LFP_CHAN, timeRange=maze)
    strong_theta = sess.theta.getstrongTheta(lfpmaze)[0]
    # z-scored >25 Hz high-pass of the strong-theta periods. Similar phases are
    # concatenated across multiple theta cycles by phase_specfic_extraction.
    gamma_lfp = stats.zscore(
        signal_process.filter_sig.highpass(strong_theta, cutoff=25, order=3)
    )

    binData = [
        getPhaseBinnedPxx(
            sess, strong_theta, gamma_lfp, fs=eegSrate, window=wind, slideby=sld
        )
        for wind, sld in binconfig
    ]

    for i, (psd, name) in enumerate(zip(binData, bin_names)):
        ax = plt.subplot(gs[sub, i])
        psd_low = psd[psd.freq < 200].set_index("freq")
        ax.pcolormesh(
            psd_low.columns, psd_low.index, psd_low, cmap="jet", shading="auto"
        )
        ax.set_xlabel(r"$\theta$ phase")
        ax.set_ylabel("Frequency (Hz)")
        ax.set_title(name)

# figure.savefig("phase_specific_slowgamma_openfield", __file__)



### Bicoherence across multiple channels during high-velocity epochs on the MAZE

In [None]:

# Per-channel PSD (high vs low running speed) and bicoherence of the
# high-speed LFP on the MAZE.
SPEED_THRESH = 25  # same units as sess.position.speed
session_subset = sessions[6:7]
if not session_subset:
    warnings.warn("sessions[6:7] is empty: fewer than 7 sessions are loaded from basePath")

# Keyed by position in session_subset (0, 1, ...). The plotting cell relies
# on these keys for figure numbers.
data: Dict[int, dict] = {}

for sub, sess in enumerate(session_subset):
    sess.trange = np.array([])
    eegSrate = sess.recinfo.lfpSrate
    changrp = sess.recinfo.goodchangrp
    maze = sess.epochs.maze
    speed = sess.position.speed
    # NOTE(review): presumably speed has one sample fewer than position.t.
    t_position = sess.position.t[1:]
    # every 6th channel of each shank in goodchangrp
    chans2plot = np.concatenate([shank[::6] for shank in changrp]).astype(int)

    # NOTE(review): other cells call sess.recinfo.geteeg; confirm that
    # sess.utils.geteeg is intended here.
    lfpmaze = sess.utils.geteeg(chans=chans2plot, timeRange=maze)
    # Resample speed onto the LFP time base, then smooth it (sigma in LFP samples).
    lfpmaze_t = np.linspace(maze[0], maze[1], lfpmaze.shape[-1])
    speed = gaussian_filter1d(np.interp(lfpmaze_t, t_position, speed), sigma=10)

    # Samples are pooled across non-contiguous periods of each speed class.
    lfp_highspd = lfpmaze[:, speed > SPEED_THRESH]
    lfp_lowspd = lfpmaze[:, speed <= SPEED_THRESH]

    # ------ PSD: 4 s segments with 2 s (50 %) overlap ------
    # The previous overlap, 2 * 250 samples, was only 0.4 s.
    nperseg = int(4 * eegSrate)
    noverlap = int(2 * eegSrate)
    f_, pxx = sg.welch(lfp_highspd, fs=eegSrate, nperseg=nperseg, noverlap=noverlap)
    f_slow, pxx_slow = sg.welch(
        lfp_lowspd, fs=eegSrate, nperseg=nperseg, noverlap=noverlap
    )

    # ---- bicoherence of the high-speed LFP, 1-70 Hz ----------
    bicoh, f, _ = signal_process.bicoherence_m(lfp_highspd, flow=1, fhigh=70)

    data[sub] = {
        "chans": chans2plot,
        "fpxx_slow": f_slow,
        "pxx_slow": pxx_slow,
        "fpxx": f_,
        "pxx": pxx,
        "fbicoh": f,
        "bicoh": bicoh,
    }


# ---- plotting: per channel, a PSD panel (high vs low speed) followed by ----
# ---- the masked, mean-subtracted bicoherence panel                       ----
figure = Fig()
cmap = Colormap().dynamic3()
for i, data_sub in data.items():
    fig, gs = figure.draw(num=i + 1, grid=[7, 8], size=[15, 15])
    fbicoh = data_sub["fbicoh"]
    fmax = np.max(fbicoh)
    for chan in range(len(data_sub["chans"])):
        ax = plt.subplot(gs[2 * chan])
        ax.plot(data_sub["fpxx"], data_sub["pxx"][chan])
        ax.plot(data_sub["fpxx_slow"], data_sub["pxx_slow"][chan])
        ax.set_ylabel("Power")
        ax.set_xlabel("Frequency (Hz)")
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlim([3, 200])
        ax.set_ylim(bottom=10)

        ax = plt.subplot(gs[2 * chan + 1])
        bic = np.sqrt(data_sub["bicoh"][chan, :, :])  # new array; stored data untouched
        # Mask cells outside the triangle drawn by the gray lines below:
        # - the strictly lower-triangular cells (row > col, i.e. y > x);
        # - their mirror across the anti-diagonal.
        # Column j mirrors to column N-1-j, i.e. index -1 - j. Using -j shifts
        # the mirror by one column and leaves row + col == N unmasked.
        rows, cols = np.tril_indices_from(bic, k=-1)
        bic[rows, cols] = np.nan
        bic[rows, -1 - cols] = np.nan
        bic = bic - np.nanmean(bic)
        bic[bic < 0.1] = 0  # NaN compares False, so masked cells stay NaN
        ax.pcolormesh(fbicoh, fbicoh, bic, cmap=cmap, vmin=-0.2, vmax=0.2)

        ax.set_ylim([0, fmax / 2])
        # outline of the plotted triangle: diagonal up to fmax/2, then back down
        ax.plot([1, fmax / 2], [1, fmax / 2], "gray")
        ax.plot([fmax / 2, fmax], [fmax / 2, 1], "gray")
        # ax.set_title(sessions[i].sessinfo.session.sessionName)


