In [1]:
%matplotlib inline

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
import yasa
from mne import set_log_level
from neurolib.models.multimodel import MultiModel
from neurolib.utils.signal import RatesSignal
from yasa import get_centered_indices, stft_power

from aln_thalamus import ALNThalamusMiniNetwork
from plotting import (
    plot_average_events_ts,
    plot_circular_histogram,
    plot_kullback_leibler_modulation_index,
)
from spindle_detection import (
    so_phase_while_spindle,
    spindles_detect_aln,
    spindles_detect_thalamus,
)
from statistical_testing import get_p_values
from utils import dummy_detect_down_states, get_amplitude, get_dummy_so_phase, get_phase
from xfreq import (
    XFreqEvaluateSignal,
    kullback_leibler_modulation_index,
    mean_vector_length,
    mutual_information,
    phase_locking_value,
)

# Global plotting configuration for this notebook.
DPI = 75
CMAP = "plasma"  # colormap used for the time-frequency panels and their colorbar
plt.rcParams.update({"figure.figsize": (20, 9), "font.size": 18})
# NOTE(review): the style sheet is applied after the rcParams above, so any
# figure size / font size it defines would take precedence - confirm intended.
plt.style.use("default_light")
# Silence MNE messages below WARNING level.
set_log_level("WARNING")

In [2]:
# Tick labels for polar plots.
POLAR_XTICKLABELS = [r"$-\pi$ = $\pi$", r"$-\pi/2$", "0", r"$\pi/2$", r"$\pi$"]

# Network / integration settings (all times in ms).
DELAY = 13.0  # ms, delay used for both coupling directions in `simulate_net`
DURATION = 90_000  # ms, simulated time on top of the spin-up
T_SPIN_UP = 5_000  # ms, added to DURATION when setting the model duration
OU_TAU = 5.0  # ms, written to the "*noise*tau" parameters
DT = 0.01  # ms, integration step
SAMPLING_DT = 1.0  # ms, sampling step of the stored output

# Band-pass settings handed to `get_phase` / `get_amplitude` as `filter_args`.
SW = dict(low_freq=0.1, high_freq=3.0)  # used by `so_phase`
SP = dict(low_freq=11.0, high_freq=16.0)  # used by `spindle_phase` / `spindle_amp`


def simulate_net(
    ad_th,
    th_ad,
    ou_exc_mu,
    ou_inh_mu,
    aln_sigma=0.0,
    tcr_sigma=0.0,
    tauA=1000.0,
    b=15.0,
):
    """
    Simulate the two-node ALN-thalamus mini network once.

    :param ad_th: entry [1, 0] of the 2x2 connectivity matrix
    :param th_ad: entry [0, 1] of the 2x2 connectivity matrix
    :param ou_exc_mu: value written to all parameters matching "*EXC*mu"
    :param ou_inh_mu: value written to all parameters matching "*INH*mu"
    :param aln_sigma: value written to "*ALNMass*noise*sigma"
    :param tcr_sigma: value written to "*TCR*noise*sigma"
    :param tauA: value written to "*tauA"
    :param b: value written to "*b"
    :return: DataFrame indexed by `model.t` (index name "time") with columns
        "ALN" and "TCR" holding rows 0 and 1 of `model.r_mean_EXC`, each
        multiplied by 1000
    """
    # both coupling directions share the same DELAY
    connectivity = np.array([[0.0, th_ad], [ad_th, 0.0]])
    delays = np.array([[0.0, DELAY], [DELAY, 0.0]])
    model = MultiModel(ALNThalamusMiniNetwork(connectivity, delays))

    # parameter overrides, applied in exactly this order
    overrides = {
        "*g_LK": 0.032,
        "1ALNThlmNet.ALNNode_0.ALNMassEXC_0.a": 0.0,
        "*b": b,
        "*tauA": tauA,
        "*EXC*mu": ou_exc_mu,
        "*INH*mu": ou_inh_mu,
        "*ALNMass*noise*sigma": aln_sigma,
        "*TCR*noise*sigma": tcr_sigma,
        "*noise*tau": OU_TAU,
        # NOTE(review): the spin-up is simulated but not trimmed from the
        # returned frame - confirm that callers account for it.
        "duration": DURATION + T_SPIN_UP,
        "dt": DT,
        "sampling_dt": SAMPLING_DT,
        "backend": "numba",
    }
    for key, value in overrides.items():
        model.params[key] = value

    model.run()

    # the factor 1000 presumably converts kHz to Hz (plots label these as Hz)
    columns = {
        "ALN": model.r_mean_EXC[0, :] * 1000.0,
        "TCR": model.r_mean_EXC[1, :] * 1000.0,
    }
    results_df = pd.DataFrame(columns, index=model.t)
    results_df.index.name = "time"

    return results_df

def so_phase(signal):
    """Return `get_phase` of `signal` using the `SW` filter settings and pad=5.0."""
    return get_phase(signal, pad=5.0, filter_args=SW)


def spindle_phase(signal):
    """Return `get_phase` of `signal` using the `SP` filter settings and pad=5.0."""
    return get_phase(signal, pad=5.0, filter_args=SP)


def spindle_amp(signal):
    """Return `get_amplitude` of `signal` using the `SP` filter settings and pad=5.0."""
    return get_amplitude(signal, pad=5.0, filter_args=SP)


# Cross-frequency evaluators. Each one pairs a slow-signal preprocessing step
# (always the `SW`-band phase) with a fast-signal preprocessing step and a
# measure, and is configured with 1000 surrogates of type "FT".

# phase-amplitude: Kullback-Leibler modulation index (36 bins)
klmi = XFreqEvaluateSignal(
    measure_function=kullback_leibler_modulation_index,
    measure_settings=dict(bins=36, return_for_plotting=True),
    slow_timeseries_preprocessing=so_phase,
    fast_timeseries_preprocessing=spindle_amp,
    surrogate_settings=dict(num_surr=1000, surrogate_type="FT"),
)

# phase-amplitude: absolute value of the mean vector length
mvl = XFreqEvaluateSignal(
    measure_function=lambda x, y: np.abs(mean_vector_length(x, y)),
    measure_settings=dict(),
    slow_timeseries_preprocessing=so_phase,
    fast_timeseries_preprocessing=spindle_amp,
    surrogate_settings=dict(num_surr=1000, surrogate_type="FT"),
)

# phase-phase: absolute value of the phase locking value with n = m = 1
plv = XFreqEvaluateSignal(
    measure_function=lambda x, y, n, m: np.abs(phase_locking_value(x, y, n, m)),
    measure_settings=dict(n=1, m=1),
    slow_timeseries_preprocessing=so_phase,
    fast_timeseries_preprocessing=spindle_phase,
    surrogate_settings=dict(num_surr=1000, surrogate_type="FT"),
)

# phase-phase: mutual information, "EQQ" algorithm with 16 bins
mi_eqq = XFreqEvaluateSignal(
    measure_function=mutual_information,
    measure_settings=dict(algorithm="EQQ", bins=16),
    slow_timeseries_preprocessing=so_phase,
    fast_timeseries_preprocessing=spindle_phase,
    surrogate_settings=dict(num_surr=1000, surrogate_type="FT"),
)


def klmi_eval(slow_ts, fast_ts, subtitle=""):
    """
    Run the `klmi` evaluator and print its data value and upper-tailed p-value.

    :param slow_ts: signal passed as `slow_timeseries`
    :param fast_ts: signal passed as `fast_timeseries`
    :param subtitle: text printed between the " --- " markers
    :return: tuple of (data result, surrogate results) exactly as returned by
        `klmi.run`; element 0 of each result is the KL-MI value
    """
    data_result, surr_results = klmi.run(
        slow_timeseries=slow_ts, fast_timeseries=fast_ts
    )
    # element 0 holds the scalar index of the data and of every surrogate
    klmi_val = data_result[0]
    surrs_values = np.array([one_surr[0] for one_surr in surr_results])
    p_val = get_p_values(klmi_val, surrs_values, tailed="upper")

    print(" === KL-MI === ")
    print(f" --- {subtitle} --- ")
    print(f"Data value: {klmi_val:.4f}")
    print(f"p-value: {p_val:.4f}")
    return data_result, surr_results


def _report_scalar_measure(evaluator, header, slow_ts, fast_ts, subtitle):
    """Run `evaluator` on both signals and print data value and upper-tailed p-value."""
    data_value, surr_values = evaluator.run(
        slow_timeseries=slow_ts, fast_timeseries=fast_ts
    )
    p_val = get_p_values(data_value, surr_values, tailed="upper")
    print(header)
    print(f" --- {subtitle} --- ")
    print(f"Data value: {data_value:.4f}")
    print(f"p-value: {p_val:.4f}")


def mvl_eval(slow_ts, fast_ts, subtitle=""):
    """
    Print the result of the `mvl` evaluator for the given pair of signals.

    :param slow_ts: signal passed as `slow_timeseries`
    :param fast_ts: signal passed as `fast_timeseries`
    :param subtitle: text printed between the " --- " markers
    """
    _report_scalar_measure(mvl, " === MVL === ", slow_ts, fast_ts, subtitle)


def plv_eval(slow_ts, fast_ts, subtitle=""):
    """
    Print the result of the `plv` evaluator for the given pair of signals.

    :param slow_ts: signal passed as `slow_timeseries`
    :param fast_ts: signal passed as `fast_timeseries`
    :param subtitle: text printed between the " --- " markers
    """
    _report_scalar_measure(plv, " === PLV === ", slow_ts, fast_ts, subtitle)


def mi_eval(slow_ts, fast_ts, subtitle=""):
    """
    Print the result of the `mi_eqq` evaluator for the given pair of signals.

    :param slow_ts: signal passed as `slow_timeseries`
    :param fast_ts: signal passed as `fast_timeseries`
    :param subtitle: text printed between the " --- " markers
    """
    _report_scalar_measure(mi_eqq, " === MI EQQ === ", slow_ts, fast_ts, subtitle)

In [None]:
# (excitatory input, plot name) pairs to simulate; the plot name selects the
# title below and is used in the output file names.
# Other operating points: (2.25, "left_border"), (2.4, "inside_lc")
params = [(3.1, "right_border")]
# Figure titles per plot name. All entries are raw strings so the mathtext
# backslashes are never interpreted as (invalid) Python escape sequences.
titles = {
    "right_border": r"Border $\mathregular{{LC_{{aE}}}}$ $\times$ up",
    "left_border": r"Border down $\times$ $\mathregular{{LC_{{aE}}}}$",
    "inside_lc": r"Inside $\mathregular{{LC_{{aE}}}}$",
}
# prepare
SPINDLE_DURATION = (0.15, 2.0)  # passed as `duration` to both spindle detectors
REL_POWER_ALN = 0.05  # passed as `rel_power` to the ALN spindle detector
# Window around each event for the event-locked averages; multiplied by the
# sampling frequency to obtain a number of samples.
BEFORE = 1.2
AFTER = 1.2

# Inhibitory input: used for the simulation and quoted in the figure title, so
# it is defined once instead of being repeated as a literal.
inh_inp = 3.5

# For every (excitatory input, plot name) pair: simulate the network, build
# the overview figure (time series, spectrograms, event-locked averages and
# spindle statistics), save it, then run and report cross-frequency coupling.
for exc_inp, plot_name in params:

    # simulate network
    results_df = simulate_net(
        ad_th=1.0,
        th_ad=0.15,
        ou_exc_mu=exc_inp,
        ou_inh_mu=inh_inp,
        aln_sigma=0.05,
        tcr_sigma=0.005,
        tauA=1000.0,
        b=15.0,
    )

    aln_xr = xr.DataArray(results_df["ALN"])
    tcr_xr = xr.DataArray(results_df["TCR"])
    aln_sig = RatesSignal(aln_xr)
    tcr_sig = RatesSignal(tcr_xr)

    fig = plt.figure(figsize=(20, 11))
    # rows 0-1: ALN panels, row 2: thin spacer, rows 3-4: TCR panels
    gs = fig.add_gridspec(nrows=5, ncols=4, height_ratios=[1, 1, 0.1, 1, 1])
    gs.update(wspace=0.2, hspace=0.7)

    # timeseries plot ALN
    ax1 = fig.add_subplot(gs[0, :2])
    ax1.plot(results_df.index, results_df["ALN"], color="k", linewidth=1.5)
    aln_spindles = spindles_detect_aln(
        xr.DataArray(results_df["ALN"]),
        duration=SPINDLE_DURATION,
        rel_power=REL_POWER_ALN,
    )
    # the detector may return None (no spindles); highlight only if present
    if aln_spindles is not None:
        spindles_highlight = results_df["ALN"] * aln_spindles.get_mask()
        # zero-valued samples become NaN so only the spindle segments are drawn
        spindles_highlight[spindles_highlight == 0] = np.nan
        ax1.plot(results_df.index, spindles_highlight, color="indianred", linewidth=1.5)
    ax1.set_ylabel("ALN r [Hz]")
    ax1.set_ylim([0, 40])
    ax1.set_yticks([0, 20, 40])
    ax1.set_xlim([15, 35])
    sns.despine(trim=True, ax=ax1)

    # timeseries plot TCR
    ax2 = fig.add_subplot(gs[3, :2], sharex=ax1)
    ax2.plot(results_df.index, results_df["TCR"], color="k", linewidth=1.5)
    thal_spindles = spindles_detect_thalamus(
        xr.DataArray(results_df["TCR"]), duration=SPINDLE_DURATION
    )
    # event table of the thalamic spindles, reused by all TCR statistics below
    thal_summary = thal_spindles.summary()
    spindles_highlight = results_df["TCR"] * thal_spindles.get_mask()
    spindles_highlight[spindles_highlight == 0] = np.nan
    ax2.plot(results_df.index, spindles_highlight, color="indianred", linewidth=1.5)
    ax2.set_ylabel("TCR r [Hz]")
    ax2.set_ylim([0, 400])
    ax2.set_yticks([0, 200, 400])
    ax2.set_xlim([15, 35])
    sns.despine(trim=True, ax=ax2)

    # TFR plot ALN and TCR
    sampling_freq = 1.0 / (results_df.index[1] - results_df.index[0])
    window = 2.0  # seconds
    step = 0.2  # seconds
    freqs_bounds = (0.1, 20.0)
    vmin = 0.01
    vmax = 0.3
    for ii, node in enumerate(["ALN", "TCR"]):
        f, _, Sxx = stft_power(
            results_df[node],
            sampling_freq,
            window=window,
            step=step,
            band=freqs_bounds,
            norm=True,
            interp=True,
        )
        # ALN spectrogram goes to row 1, TCR spectrogram to row 4
        ax = fig.add_subplot(gs[3 * ii + 1, :2], sharex=ax1)
        ax.pcolormesh(
            results_df.index,
            f,
            Sxx,
            cmap=CMAP,
            rasterized=True,
            norm=mpl.colors.LogNorm(vmin=vmin, vmax=vmax),
        )
        ax.grid()
        ax.set_ylabel(f"{node} f [Hz]")
        sns.despine(trim=True, ax=ax)
        ax.set_xlim([15, 35])
    # only the last (bottom) spectrogram gets the time axis label
    ax.set_xlabel("time [sec]")
    # shared colorbar for both spectrograms, placed below the left column
    cbar_ax = fig.add_axes([0.2, 0.0, 0.2, 0.02])
    cbar = mpl.colorbar.ColorbarBase(
        cbar_ax,
        cmap=plt.get_cmap(CMAP),
        norm=mpl.colors.LogNorm(vmin=vmin, vmax=vmax),
        orientation="horizontal",
        extend="both",
    )
    cbar.set_label("power Hz$^{2}$/Hz")

    aln_down_states = dummy_detect_down_states(
        aln_sig, threshold=2.0, min_down_length=0.3
    )
    # middle element of every detected down state
    ds_midpoints = np.array([ds[len(ds) // 2] for ds in aln_down_states])
    aln_ds_idx, _ = get_centered_indices(
        aln_sig.data.values,
        ds_midpoints,
        int(BEFORE * aln_sig.sampling_frequency),
        int(AFTER * aln_sig.sampling_frequency),
    )

    # locked on ALN down states
    ax3 = fig.add_subplot(gs[0, 2:])
    plot_average_events_ts(
        aln_xr,
        events_idx=aln_ds_idx,
        time_before=BEFORE,
        time_after=AFTER,
        color="indianred",
        ylabel="",
        title="",
        second_ts=tcr_xr,
        color_second_ts="k",
        ax=ax3,
    )
    ax3.set_title("")
    ax3.set_xlabel("time w.r.t DOWN state [sec]")
    ax3.set_ylim([0, 40])
    sns.despine(trim=True, ax=ax3)
    # NOTE(review): presumably `plot_average_events_ts` leaves a secondary axis
    # for `second_ts` as the current axes - confirm against its implementation.
    sns.despine(trim=True, ax=plt.gca())
    plt.gca().set_yticks([])

    # spindle stats
    ax5 = fig.add_subplot(gs[1, 2])
    # same None guard as for the highlight above: without cortical spindles
    # the panel stays empty instead of raising AttributeError
    if aln_spindles is not None:
        aln_summary = aln_spindles.summary()
        # inter-spindle interval: start of the next spindle minus end of this one
        isi = (aln_summary.shift(-1)["Start"] - aln_summary["End"]).dropna()
        ax5.hist(isi, bins=15, color="k", rwidth=0.9)
    sns.despine(ax=ax5)
    ax5.set_xlabel("ISI [sec]")
    ax5.set_yticks([])

    aln_so_phase = get_dummy_so_phase(aln_sig, threshold=2.0, min_down_length=0.3)
    aln_sp_amp = get_amplitude(aln_sig, filter_args=SP, pad=5.0)
    so_phases = so_phase_while_spindle(
        aln_so_phase.data.values, aln_sp_amp.data.values, ds_midpoints
    )
    ax7 = fig.add_subplot(gs[1, 3], projection="polar")
    plot_circular_histogram(so_phases, ax=ax7)

    # average spindle TCR
    tcr_idx, _ = get_centered_indices(
        tcr_xr.values,
        # spindle peak times converted to integer sample indices
        (thal_summary["Peak"] * tcr_sig.sampling_frequency).astype(int).to_numpy(),
        int(BEFORE * tcr_sig.sampling_frequency),
        int(AFTER * tcr_sig.sampling_frequency),
    )
    ax4 = fig.add_subplot(gs[3, 2:])
    plot_average_events_ts(
        tcr_xr,
        events_idx=tcr_idx,
        time_before=BEFORE,
        time_after=AFTER,
        color="indianred",
        ylabel="",
        title="",
        second_ts=aln_xr,
        color_second_ts="k",
        ax=ax4,
    )
    ax4.set_title("")
    ax4.set_xlabel("time w.r.t spindle peak [sec]")
    ax4.set_ylim([0, 400])
    sns.despine(trim=True, ax=ax4)
    sns.despine(trim=True, ax=plt.gca())
    plt.gca().set_yticks([])

    # spindle stats
    isi = (thal_summary.shift(-1)["Start"] - thal_summary["End"]).dropna()
    ax6 = fig.add_subplot(gs[4, 2])
    ax6.hist(isi, bins=15, color="k", rwidth=0.9)
    sns.despine(ax=ax6)
    ax6.set_xlabel("ISI [sec]")
    ax6.set_yticks([])

    ax8 = fig.add_subplot(gs[4, 3])
    ax8.hist(thal_summary["Frequency"], bins=15, color="k", rwidth=0.9)
    sns.despine(ax=ax8)
    ax8.set_xlabel("Spindle frequency [Hz]")
    ax8.set_yticks([])

    # The newline is kept in a normal string and the mathtext part in a raw
    # f-string, so "\mu" is not parsed as an (invalid) escape sequence.
    # NOTE(review): inputs are divided by 5.0 for display in nA - presumably a
    # unit conversion; confirm.
    plt.suptitle(
        titles[plot_name]
        + "\n"
        + rf"$\mu_{{E}}={exc_inp / 5.0:.3f}$ nA, $\mu_{{I}}={inh_inp / 5.0:.3f}$ nA"
    )

    # row labels placed in figure coordinates
    plt.text(
        0.07,
        0.92,
        "ALN E",
        ha="center",
        va="center",
        transform=fig.transFigure,
        fontsize=35,
    )
    plt.text(
        0.07,
        0.49,
        "TCR",
        ha="center",
        va="center",
        transform=fig.transFigure,
        fontsize=35,
    )

    # saved as PDF because of the transparent background
    plt.savefig(
        f"../figs/loop_spindles_{plot_name}.pdf", transparent=True, bbox_inches="tight"
    )

    # phase-amplitude CFC
    # NOTE: this changes the global default figure size for every figure
    # created from here on (including in later cells).
    plt.rcParams["figure.figsize"] = (12, 5.4)
    klmi_data, klmi_surrs = klmi_eval(aln_sig, tcr_sig, subtitle="ALN vs TCR")
    plot_kullback_leibler_modulation_index(
        klmi_data,
        klmi_surrs,
        data_color="C0",
        surr_color="C1",
        perc_color="C1",
    )
    sns.despine(trim=True)
    plt.gca().set_xlabel("ALN slow wave phase")
    plt.gca().set_ylabel("TCR spindle amplitude")
    plt.gca().set_title(f"KL-MI value in data: {klmi_data[0]:.4f}")
    plt.savefig(
        f"../figs/KLMI_aln_tcr_{plot_name}.pdf", transparent=True, bbox_inches="tight"
    )

    mvl_eval(aln_sig, tcr_sig, subtitle="ALN vs TCR")

    klmi_data, klmi_surrs = klmi_eval(aln_sig, aln_sig, subtitle="ALN vs ALN")
    plot_kullback_leibler_modulation_index(
        klmi_data,
        klmi_surrs,
        data_color="C0",
        surr_color="w",
        perc_color="C1",
    )
    sns.despine(trim=True)
    plt.gca().set_xlabel("ALN slow wave phase")
    plt.gca().set_ylabel("ALN spindle amplitude")
    plt.gca().set_title(f"KL-MI value in data: {klmi_data[0]:.4f}")
    plt.savefig(
        f"../figs/KLMI_aln_aln_{plot_name}.pdf", transparent=True, bbox_inches="tight"
    )

    mvl_eval(aln_sig, aln_sig, subtitle="ALN vs ALN")

    # phase-phase CFC
    plv_eval(aln_sig, tcr_sig, subtitle="ALN vs TCR")
    mi_eval(aln_sig, tcr_sig, subtitle="ALN vs TCR")

    plv_eval(aln_sig, aln_sig, subtitle="ALN vs ALN")
    mi_eval(aln_sig, aln_sig, subtitle="ALN vs ALN")

  for _ in range(self.num_surr)


 === KL-MI === 
 --- ALN vs TCR --- 
Data value: 0.0307
p-value: 0.0000
