In [1]:
"""Plot time delay estimates (Figure 3d)."""

import os
import sys
from pathlib import Path
import numpy as np
import copy
import pandas as pd
from matplotlib import pyplot as plt
import pte_stats
import pte_decode
from scipy.stats import sem
from scipy.signal import find_peaks

# Make the sibling "coherence" directory (next to this notebook's folder) importable.
cd_path = Path.cwd().absolute().parent
coherence_dir = os.path.join(cd_path, "coherence")
sys.path.append(coherence_dir)

import matplotlib

# Publication styling: 6 pt ticks/legend/body text, 7 pt axis labels and titles,
# and TrueType (Type 42) font embedding for PDF/PS output instead of Type 3.
matplotlib.rcParams.update({
    "xtick.labelsize": 6,
    "ytick.labelsize": 6,
    "legend.fontsize": 6,
    "font.size": 6,
    "font.family": "Arial",
    "axes.labelsize": 7,
    "axes.titlesize": 7,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

prop_cycle = plt.rcParams["axes.prop_cycle"]

# Statistics: number of permutations, one-sided (two_tailed=False) tests, alpha.
n_perm = 100_000
two_sided = False
alpha = 0.05

# Input/output folders (FOLDERPATH_ANALYSIS is a placeholder to be set locally).
FOLDERPATH_ANALYSIS = "Path_to\\Project\\Analysis"
FOLDERPATH_FIGURES = os.path.join(os.path.dirname(os.getcwd()), "figures")

# Plotted/tested time window in ms: [ltime, htime].
ltime, htime = -5, 51

# Display name -> label used in the data's "seed_subregions" column.
subregions_mapping = dict([
    ("Motor cortex", "motor"),
    ("Parietal cortex", "parietal"),
])

In [None]:
# CORTEX -> STN TIME DELAY (Figure 3d)
# NOTE: unpickling is only safe for trusted files; this loads local analysis output.
tde_results_path = os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_con_tde-StimOffOn_multi_sub.pkl"
)
data = pd.read_pickle(tde_results_path)
# Give every row the first row's time axis (one shared list object). This assumes
# all rows were computed on identical timepoints — TODO confirm.
times = copy.deepcopy(data["timepoints"][0])
data["timepoints"] = len(data["seed_types"]) * [times]
data = pd.DataFrame.from_dict(data)

subregion_i = 0
# Two stacked panels: TDE traces (top, 2/3 height) and peak-time histogram (bottom).
fig, axis = plt.subplots(2, 1, gridspec_kw={"height_ratios": [2, 1]})

# Histogram edges every 10 ms up to htime; the first edge is moved from 0 to 1,
# so peak times below 1 ms fall outside every bin.
hist_bins = np.arange(0, htime + 1, 10)
hist_bins[0] += 1

# Result containers, one fresh dict each, keyed by subregion label.
(
    tde, tde_off, tde_on, tde_off_noavg, tde_on_noavg,
    peaks_off, peaks_on, peaks_subs_off, peaks_subs_on, peaks_subs,
    n_cons_subs, n_cons_subs_off, n_cons_subs_on, peaks, n_subs,
) = ({} for _ in range(15))

# For each cortical subregion, gather per-subject TDE traces (averaged over that
# subject's connections) and the sample indices of TDE peaks, separately for OFF
# and ON stimulation. A subject is kept only if it has at least one usable
# connection in BOTH states.
for subregion_name, subregion_label in subregions_mapping.items():
    tde_off[subregion_label] = []
    tde_off_noavg[subregion_label] = []
    tde_on[subregion_label] = []
    tde_on_noavg[subregion_label] = []
    tau = []  # NOTE(review): reset here but never filled or read in this cell
    peaks_off[subregion_label] = []
    peaks_on[subregion_label] = []
    peaks_subs_off[subregion_label] = {}
    peaks_subs_on[subregion_label] = {}
    peaks_subs[subregion_label] = {}  # NOTE(review): only initialised, not read later
    n_cons_subs[subregion_label] = {}  # NOTE(review): only initialised, not read later
    n_cons_subs_off[subregion_label] = {}
    n_cons_subs_on[subregion_label] = {}
    for sub in np.unique(data["sub"]):
        peaks_subs[subregion_label][sub] = []
        n_cons_subs[subregion_label][sub] = 0
        # OFF rows for this subject/subregion in the "all" frequency band, excluding
        # connections whose 80% CI for tau straddles zero (low < 0 and high > 0).
        sub_off_idcs = ((data["stim"] == "Off") & (data["freq_band_names"] == "all")
                        & (data["seed_subregions"] == subregion_label) & (data["sub"] == sub)
                        & ~((data["tde-i_standard_tau_ci_80_low"] < 0) & (data["tde-i_standard_tau_ci_80_high"] > 0)))
        # One row per selected connection.
        tde_off_sub = np.array(data["tde-i_standard"][sub_off_idcs].to_list())
        tau_off_sub = np.array(data["tde-i_standard_tau"][sub_off_idcs].to_list())  # NOTE(review): averaged below but not stored
        n_cons_off = tde_off_sub.shape[0]
        if tde_off_sub.size > 0:
            peaks_off_sub = []
            for this_tde in tde_off_sub:
                # Peak sample indices at least as high as the estimate at 0 ms, with
                # prominence >= 0.5 and >= 5 samples apart (scipy find_peaks).
                peaks_off_sub.extend(find_peaks(this_tde, height=this_tde[times.index(0)], prominence=(0.5, None), distance=5)[0])
            # Average over this subject's connections.
            tde_off_sub = tde_off_sub.mean(0)
            tau_off_sub = tau_off_sub.mean(0)

        # Same selection and peak detection for ON stimulation.
        sub_on_idcs = ((data["stim"] == "On") & (data["freq_band_names"] == "all")
                        & (data["seed_subregions"] == subregion_label) & (data["sub"] == sub)
                        & ~((data["tde-i_standard_tau_ci_80_low"] < 0) & (data["tde-i_standard_tau_ci_80_high"] > 0)))
        tde_on_sub = np.array(data["tde-i_standard"][sub_on_idcs].to_list())
        tau_on_sub = np.array(data["tde-i_standard_tau"][sub_on_idcs].to_list())  # NOTE(review): averaged below but not stored
        n_cons_on = tde_on_sub.shape[0]
        if tde_on_sub.size > 0:
            peaks_on_sub = []
            for this_tde in tde_on_sub:
                peaks_on_sub.extend(find_peaks(this_tde, height=this_tde[times.index(0)], prominence=(0.5, None), distance=5)[0])
            tde_on_sub = tde_on_sub.mean(0)
            tau_on_sub = tau_on_sub.mean(0)
        
        # Keep the subject only if both states had connections; this also
        # guarantees peaks_off_sub / peaks_on_sub were (re)built in this iteration.
        if tde_off_sub.size > 0 and tde_on_sub.size > 0:
            tde_off[subregion_label].append(tde_off_sub)
            tde_off_noavg[subregion_label].append(tde_off_sub)
            tde_on[subregion_label].append(tde_on_sub)
            tde_on_noavg[subregion_label].append(tde_on_sub)
            peaks_off[subregion_label].extend(peaks_off_sub)
            peaks_on[subregion_label].extend(peaks_on_sub)
            peaks_subs_off[subregion_label][sub] = np.array(peaks_off_sub)
            peaks_subs_on[subregion_label][sub] = np.array(peaks_on_sub)
            n_cons_subs_off[subregion_label][sub] = n_cons_off
            n_cons_subs_on[subregion_label][sub] = n_cons_on
    
    # tde_*: mean over kept subjects; tde_*_noavg: one row per kept subject.
    tde_off[subregion_label] = np.array(tde_off[subregion_label]).mean(axis=0)
    tde_on[subregion_label] = np.array(tde_on[subregion_label]).mean(axis=0)
    tde_off_noavg[subregion_label] = np.array(tde_off_noavg[subregion_label])
    tde_on_noavg[subregion_label] = np.array(tde_on_noavg[subregion_label])
    # Convert pooled peak sample indices into times (ms).
    peaks_off[subregion_label] = np.array([times[peak_idx] for peak_idx in peaks_off[subregion_label]])
    peaks_on[subregion_label] = np.array([times[peak_idx] for peak_idx in peaks_on[subregion_label]])
    # Number of subjects kept for this subregion.
    n_subs[subregion_label] = tde_off_noavg[subregion_label].shape[0]
    
    subregion_i += 1

# Across-subject parietal means, used as the baseline for the motor statistics.
tde_off_baseline = tde_off["parietal"]
tde_on_baseline = tde_on["parietal"]

# Top panel: across-subject mean motor-cortex TDE for OFF then ON, within [ltime, htime].
plot_window = slice(times.index(ltime), times.index(htime) + 1)
for motor_trace, trace_colour in (
    (tde_off["motor"], "#DF4A4A"),
    (tde_on["motor"], "#71BCAD"),
):
    axis[0].plot(times[plot_window], motor_trace[plot_window], color=trace_colour)

# Peak-time histograms per subregion; each peak is weighted by 1 / (number of
# kept subjects) so bar heights are peaks per subject.
peak_off_hists = {
    label: np.histogram(
        peaks_off[label], bins=hist_bins, range=(0, htime),
        weights=(1 / n_subs[label]) * np.ones_like(peaks_off[label]),
    )[0]
    for label in peaks_off.keys()
}
peak_on_hists = {
    label: np.histogram(
        peaks_on[label], bins=hist_bins, range=(0, htime),
        weights=(1 / n_subs[label]) * np.ones_like(peaks_on[label]),
    )[0]
    for label in peaks_off.keys()
}

# Bottom panel: draw the precomputed motor heights by placing one sample in each
# bin and weighting it with that bin's height.
one_per_bin = np.arange(1, htime, 10)
for bar_heights, bar_colour in (
    (peak_off_hists["motor"], "#DF4A4A"),
    (peak_on_hists["motor"], "#71BCAD"),
):
    axis[1].hist(one_per_bin, bins=hist_bins, range=(0, htime),
                 weights=bar_heights, alpha=0.7, color=bar_colour)

# Figure styling: hide top/right spines, share 10 ms x ticks, align x ranges.
axis[0].spines['top'].set_visible(False)
axis[0].spines['right'].set_visible(False)
axis[1].spines['top'].set_visible(False)
axis[1].spines['right'].set_visible(False)
axis[0].set_xticks(np.arange(0, htime+1, 10))
axis[1].set_xticks(np.arange(0, htime+1, 10))
axis[1].set_xlim(axis[0].get_xlim())
# Dashed zero-lag marker. axvline spans the full axis height in axes coordinates,
# so it neither re-triggers y-autoscaling (which plotting [0, 0] against the
# current get_ylim() does, leaving the line short of the axis edges) nor
# falls out of sync when the y limits are set further down.
axis[0].axvline(0, color="k", linestyle="--", linewidth=1)
axis[1].axvline(0, color="k", linestyle="--", linewidth=1)

axis[1].set_xlabel("Time (ms)")
axis[0].set_ylabel("Time delay estimate\nstrength (Z-score)")
axis[0].set_title("Motor cortex -> STN")

# Express the peak histogram as % of its tallest bin (across OFF and ON).
hist_max = np.max([peak_off_hists["motor"], peak_on_hists["motor"]])
if hist_max > 0:  # an all-zero histogram would give identical (singular) y limits
    # Bars rise from zero, so the axis starts at zero; a higher lower limit clips
    # every bar and puts the 25% tick outside the plotted range.
    axis[1].set_ylim(0, hist_max)
    axis[1].set_yticks(np.linspace(hist_max/4, hist_max, 4))
    axis[1].set_yticklabels(["25", "50", "75", "100"])
axis[1].set_ylabel("Peak count\n(% maximum)")

# Compact ticks and labels for the small publication figure.
for ax in axis:
    ax.tick_params("x", pad=2, size=2)
    ax.tick_params("y", pad=2, size=2)
    ax.xaxis.labelpad = 1
    ax.yaxis.labelpad = 1

fig.set_size_inches(2.1, 2)

fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_TDE_Stim_spectrum.pdf"))

In [None]:
# Average OFF motor-cortex peak time within the open window (0, 10) ms,
# pooled over all kept subjects' connections.
peak_times = peaks_off["motor"]
in_window = (peak_times > 0) & (peak_times < 10)
peak_times_sub10 = peak_times[in_window]
mean_peak_time = np.mean(peak_times_sub10)
sem_peak_time = sem(peak_times_sub10)
print(mean_peak_time)
print(sem_peak_time)
print(
    "Average motor cortex -> STN peak time below 10 ms (OFF): "
    f"{mean_peak_time} +/- {sem_peak_time} ms (mean +/- SEM)"
)

In [None]:
# Find times of significant estimates (OFF & ON)
# Per subject: motor-cortex TDE minus the across-subject parietal mean, compared
# against zeros with correction_method="cluster_pvals"; one figure per state.
stats_window = slice(times.index(ltime), times.index(htime) + 1)
for stim_label, tde_motor, tde_baseline, motor_colour in (
    ("OFF", tde_off_noavg["motor"], tde_off_baseline, "#DF4A4A"),
    ("ON", tde_on_noavg["motor"], tde_on_baseline, "#71BCAD"),
):
    fig, axis = pte_decode.lineplot_prediction_compare(
        x_1=(tde_motor - tde_baseline)[:, stats_window].T,
        x_2=np.zeros_like(tde_motor)[:, stats_window].T,
        times=np.array(times[stats_window]),
        data_labels=["motor", "parietal"],
        x_label="Time (ms)",
        y_label="Est (A.U.)",
        two_tailed=two_sided,
        paired_x1x2=False,
        n_perm=n_perm,
        correction_method="cluster_pvals",
        title=f"Cortex -> STN ({stim_label})",
        colour=[motor_colour, "k"],
    )

In [None]:
# Find time bins with significant number of peaks (OFF)
# Per subject: number of peaks per connection in each half-open bin [start, end)
# of hist_bins. NOTE(review): np.histogram (used for the figure) closes the last
# bin on the right while these counts do not — confirm this is intended.
n_peaks = {}
bin_edges = list(zip(hist_bins[:-1], hist_bins[1:]))
for subregion_label in subregions_mapping.values():
    sub_peak_idcs = list(peaks_subs_off[subregion_label].values())
    sub_n_cons = list(n_cons_subs_off[subregion_label].values())
    peak_rates = np.zeros((len(sub_n_cons), len(hist_bins) - 1))
    for row, (peak_idcs, n_cons) in enumerate(zip(sub_peak_idcs, sub_n_cons)):
        if not n_cons > 0:
            peak_rates[row, :] = np.nan
            continue
        sub_peak_times = [times[idx] for idx in peak_idcs]
        for col, (start_time, end_time) in enumerate(bin_edges):
            n_in_bin = sum((t >= start_time) & (t < end_time) for t in sub_peak_times)
            peak_rates[row, col] = n_in_bin / n_cons
    n_peaks[subregion_label] = peak_rates

# Per bin: motor peak rates minus the mean parietal rate, permutation-tested against zero.
parietal_bin_means = np.nanmean(n_peaks["parietal"], axis=0)
n_peaks_pvals = {"motor": []}
for subregion_label, subregion_pvals in n_peaks_pvals.items():
    for time_bin_idx, time_bin in enumerate(n_peaks[subregion_label].T):
        observed = time_bin[~np.isnan(time_bin)]
        _, pval = pte_stats.permutation_twosample(
            data_a=observed - parietal_bin_means[time_bin_idx],
            data_b=np.zeros_like(observed),
            n_perm=n_perm, two_tailed=two_sided
        )
        subregion_pvals.append(pval)

n_peaks_pvals["motor"]

In [None]:
# Find time bins with significant number of peaks (ON)
# Per subject: number of peaks per connection in each half-open bin [start, end)
# of hist_bins. NOTE(review): np.histogram (used for the figure) closes the last
# bin on the right while these counts do not — confirm this is intended.
n_peaks = {}
bin_edges = list(zip(hist_bins[:-1], hist_bins[1:]))
for subregion_label in subregions_mapping.values():
    sub_peak_idcs = list(peaks_subs_on[subregion_label].values())
    sub_n_cons = list(n_cons_subs_on[subregion_label].values())
    peak_rates = np.zeros((len(sub_n_cons), len(hist_bins) - 1))
    for row, (peak_idcs, n_cons) in enumerate(zip(sub_peak_idcs, sub_n_cons)):
        if not n_cons > 0:
            peak_rates[row, :] = np.nan
            continue
        sub_peak_times = [times[idx] for idx in peak_idcs]
        for col, (start_time, end_time) in enumerate(bin_edges):
            n_in_bin = sum((t >= start_time) & (t < end_time) for t in sub_peak_times)
            peak_rates[row, col] = n_in_bin / n_cons
    n_peaks[subregion_label] = peak_rates

# Per bin: motor peak rates minus the mean parietal rate, permutation-tested against zero.
parietal_bin_means = np.nanmean(n_peaks["parietal"], axis=0)
n_peaks_pvals = {"motor": []}
for subregion_label, subregion_pvals in n_peaks_pvals.items():
    for time_bin_idx, time_bin in enumerate(n_peaks[subregion_label].T):
        observed = time_bin[~np.isnan(time_bin)]
        _, pval = pte_stats.permutation_twosample(
            data_a=observed - parietal_bin_means[time_bin_idx],
            data_b=np.zeros_like(observed),
            n_perm=n_perm, two_tailed=two_sided
        )
        subregion_pvals.append(pval)

n_peaks_pvals["motor"]

In [None]:
# Compute taus
# Per subject: mean tau over the connections passing the filters below, separately
# for OFF and ON stimulation; only subjects with data in both states are kept.
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_con_tde-StimOffOn_multi_sub.pkl")
)
times = copy.deepcopy(data["timepoints"][0])
data["timepoints"] = [times] * len(data["seed_types"])
data = pd.DataFrame.from_dict(data)

tau = {}
tau_off = {}
tau_on = {}
for subregion_name, subregion_label in subregions_mapping.items():
    tau[subregion_label] = []
    tau_off[subregion_label] = []
    tau_on[subregion_label] = []
    # Filters independent of stim state and subject: "all" band, this subregion,
    # positive tau, and an 80% CI for tau that does not straddle zero.
    valid_cons = ((data["freq_band_names"] == "all")
                  & (data["seed_subregions"] == subregion_label)
                  & (data["tde-i_standard_tau"] > 0)
                  & ~((data["tde-i_standard_tau_ci_80_low"] < 0) & (data["tde-i_standard_tau_ci_80_high"] > 0)))
    for sub in np.unique(data["sub"]):
        sub_off_idcs = valid_cons & (data["stim"] == "Off") & (data["sub"] == sub)
        tau_off_sub = np.array(data["tde-i_standard_tau"][sub_off_idcs].to_list())
        if tau_off_sub.size > 0:
            tau_off_sub = tau_off_sub.mean(0)

        sub_on_idcs = valid_cons & (data["stim"] == "On") & (data["sub"] == sub)
        tau_on_sub = np.array(data["tde-i_standard_tau"][sub_on_idcs].to_list())
        if tau_on_sub.size > 0:
            tau_on_sub = tau_on_sub.mean(0)

        # Append OFF and ON together so index i refers to the same subject in both.
        if tau_off_sub.size > 0 and tau_on_sub.size > 0:
            tau_off[subregion_label].append(tau_off_sub)
            tau_on[subregion_label].append(tau_on_sub)
            tau[subregion_label].append(np.mean([tau_off_sub, tau_on_sub], 0))

    # Convert this subregion's lists (not the whole dicts) to arrays.
    tau_off[subregion_label] = np.array(tau_off[subregion_label])
    tau_on[subregion_label] = np.array(tau_on[subregion_label])
    tau[subregion_label] = np.array(tau[subregion_label])

    # OFF vs. ON for this subregion. Presumably pte_stats treats data_b as a
    # per-element baseline, i.e. a paired test on OFF - ON — confirm.
    _, pval = pte_stats.permutation_onesample(
        data_a=tau_off[subregion_label], data_b=tau_on[subregion_label],
        n_perm=n_perm, two_tailed=two_sided
    )
    print(f"{subregion_name} tau, OFF vs. ON: p-value = {pval}")

for subregion in tau.keys():
    print(f"{subregion} tau (avg. OFF & ON): {np.mean(tau[subregion])} +/- {sem(tau[subregion])} ms (mean +/- SEM)")
    print(f"{subregion} tau (OFF): {np.mean(tau_off[subregion])} +/- {sem(tau_off[subregion])} ms (mean +/- SEM)")
    print(f"{subregion} tau (ON): {np.mean(tau_on[subregion])} +/- {sem(tau_on[subregion])} ms (mean +/- SEM)")