In [None]:
"""Plot STIM OFF vs. ON TRGC (Figures 3C & S4)."""

import copy
import os
import sys
from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import pte_stats
import pte_decode
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from scipy.stats import sem

# Make the "coherence" folder that sits in the parent of the working directory
# importable. It is appended, so already-installed packages keep precedence.
cd_path = Path(os.getcwd()).absolute().parent
sys.path.append(os.path.join(cd_path, "coherence"))

# Publication styling: small fonts throughout, and TrueType (type 42) fonts
# embedded in PDF/PS output so text remains editable in vector editors.
matplotlib.rc('xtick', labelsize=6)
matplotlib.rc('ytick', labelsize=6)
matplotlib.rc('legend', fontsize=6)
matplotlib.rc("font", size=6, family="Arial")
matplotlib.rc('axes', labelsize=7)
matplotlib.rc('axes', titlesize=7)
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Default colour cycle; entries 2 and 4 are used for OFF / ON stimulation.
prop_cycle = plt.rcParams["axes.prop_cycle"]
def_colors = prop_cycle.by_key()["color"]

# Statistics settings shared by every permutation test in the notebook.
n_perm = 100_000
two_sided = True
alpha = 0.05

# Placeholder: point this at the folder holding the analysis pickles.
FOLDERPATH_ANALYSIS = "Path_to\\Project\\Analysis"
FOLDERPATH_FIGURES = os.path.join(os.path.dirname(os.getcwd()), "figures")
# Create the output folder up front so the later `savefig` calls cannot fail
# on a fresh checkout where it does not exist yet.
os.makedirs(FOLDERPATH_FIGURES, exist_ok=True)

# Frequency range (Hz) plotted/tested in the TRGC figures.
lfreq = 4.0
hfreq = 40.0

# Aggregation applied across the frequency bins of each band before testing.
fband_func = np.mean

subregions_mapping = {
    "Motor": "motor",
    "Sensory": "sensory",
}

# Frequency bands (Hz) tested individually; both edges are included.
fbands = {
    "alpha": (8, 12),
    "low_beta": (12, 20),
    "high_beta": (20, 30)
}

In [None]:
# MOTOR CORTEX -> STN TRGC (Figure 3C)
# Compares TRGC between OFF therapy and ON STN-DBS for motor-cortex seeds:
# a cluster-corrected comparison across frequencies, plus one test per band.
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_con_granger_regional-StimOffOn_multi_sub.pkl")
)
# Use the first entry's frequency vector for every row (assumes all rows share it).
freqs = copy.deepcopy(data["frequencies"][0])
data["frequencies"] = [freqs] * len(data["seed_types"])
data = pd.DataFrame.from_dict(data)

motor_rows = data["seed_subregions"] == "motor"
trgc_off = np.array(data["connectivity-trgc"][(data["stim"] == "Off") & motor_rows].to_list())
trgc_on = np.array(data["connectivity-trgc"][(data["stim"] == "On") & motor_rows].to_list())
# The OFF/ON tests below are paired, so both conditions must have matching rows.
# NOTE(review): pairing also relies on OFF and ON rows being in the same subject order — confirm.
assert trgc_off.shape == trgc_on.shape, "OFF and ON TRGC arrays must be paired"

np.random.seed(44)  # reproducible permutation tests

lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

# Crop to the plotted range; column j now corresponds to plot_freqs[j].
trgc_off = trgc_off[:, lfreq_i:hfreq_i+1]
trgc_on = trgc_on[:, lfreq_i:hfreq_i+1]

fig, axis = pte_decode.lineplot_prediction_compare(
    x_1=trgc_off.T,
    x_2=trgc_on.T,
    times=np.array(plot_freqs),
    data_labels=["OFF therapy", "ON STN-DBS"],
    x_label="Frequency (Hz)",
    y_label="Time-reversed\nGranger causality (A.U.)",
    two_tailed=two_sided,
    paired_x1x2=True,
    n_perm=n_perm,
    correction_method="cluster_pvals",
    title="Motor cortex -> STN",
    colour=[def_colors[2], def_colors[4]]
)

for fband_name, (band_lo, band_hi) in fbands.items():
    # Band edges must index the *cropped* columns, i.e. positions in plot_freqs,
    # not in the full `freqs` list (that would shift the band by lfreq_i bins).
    band = slice(plot_freqs.index(band_lo), plot_freqs.index(band_hi) + 1)
    stat, pval = pte_stats.permutation_onesample(
        data_a=fband_func(trgc_off[:, band], 1),
        data_b=fband_func(trgc_on[:, band], 1),
        n_perm=n_perm, two_tailed=two_sided
    )
    if pval < alpha:
        axis.axvspan(band_lo, band_hi, color=[0.8, 0.8, 0.8], alpha=0.3, linewidth=0)
        prefix = "[SIGNIFICANT]"
    else:
        prefix = ""
    print(f"{prefix} {fband_name}: stat={stat}; p-val={pval}")

axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.legend(handlelength=0.5, handletextpad=0.5, labelspacing=0, borderaxespad=0, loc="upper left")
leg = axis.get_legend()
leg.legend_handles[0].set_linewidth(1.5)
leg.legend_handles[1].set_linewidth(1.5)
# Third legend entry comes from lineplot_prediction_compare
# (presumably the significance marker — confirm); grey it out.
leg.legend_handles[2].set_color([0.5, 0.5, 0.5])
leg.set_frame_on(False)
axis.set_xticks(np.arange(5, 45, 5))
axis.tick_params("x", pad=2, size=2)
axis.tick_params("y", pad=2, size=2)
axis.xaxis.labelpad = 0
axis.yaxis.labelpad = 1
axis.set_yticks([0, 0.05, 0.1, 0.15, 0.2])

fig.set_size_inches(1.8, 1.35)

fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_TRGC_Stim_motor_cortex-STN.pdf"))

In [None]:
# SENSORY CORTEX -> STN TRGC (Figure S2)
# Compares TRGC between OFF therapy and ON STN-DBS for sensory-cortex seeds:
# a cluster-corrected comparison across frequencies, plus one test per band.
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_con_granger_regional-StimOffOn_multi_sub.pkl")
)
# Use the first entry's frequency vector for every row (assumes all rows share it).
freqs = copy.deepcopy(data["frequencies"][0])
data["frequencies"] = [freqs] * len(data["seed_types"])
data = pd.DataFrame.from_dict(data)

sensory_rows = data["seed_subregions"] == "sensory"
trgc_off = np.array(data["connectivity-trgc"][(data["stim"] == "Off") & sensory_rows].to_list())
trgc_on = np.array(data["connectivity-trgc"][(data["stim"] == "On") & sensory_rows].to_list())
# The OFF/ON tests below are paired, so both conditions must have matching rows.
# NOTE(review): pairing also relies on OFF and ON rows being in the same subject order — confirm.
assert trgc_off.shape == trgc_on.shape, "OFF and ON TRGC arrays must be paired"

np.random.seed(44)  # reproducible permutation tests

lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

# Crop to the plotted range; column j now corresponds to plot_freqs[j].
trgc_off = trgc_off[:, lfreq_i:hfreq_i+1]
trgc_on = trgc_on[:, lfreq_i:hfreq_i+1]

fig, axis = pte_decode.lineplot_prediction_compare(
    x_1=trgc_off.T,
    x_2=trgc_on.T,
    times=np.array(plot_freqs),
    data_labels=["OFF therapy", "ON STN-DBS"],
    x_label="Frequency (Hz)",
    y_label="Time-reversed\nGranger causality (A.U.)",
    two_tailed=two_sided,
    paired_x1x2=True,
    n_perm=n_perm,
    correction_method="cluster_pvals",
    title="Sensory cortex -> STN",
    colour=[def_colors[2], def_colors[4]]
)

for fband_name, (band_lo, band_hi) in fbands.items():
    # Band edges must index the *cropped* columns, i.e. positions in plot_freqs,
    # not in the full `freqs` list (that would shift the band by lfreq_i bins).
    band = slice(plot_freqs.index(band_lo), plot_freqs.index(band_hi) + 1)
    stat, pval = pte_stats.permutation_onesample(
        data_a=fband_func(trgc_off[:, band], 1),
        data_b=fband_func(trgc_on[:, band], 1),
        n_perm=n_perm, two_tailed=two_sided
    )
    if pval < alpha:
        axis.axvspan(band_lo, band_hi, color=[0.8, 0.8, 0.8], alpha=0.3, linewidth=0)
        prefix = "[SIGNIFICANT]"
    else:
        prefix = ""
    print(f"{prefix} {fband_name}: stat={stat}; p-val={pval}")

axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.legend(handlelength=0.5, handletextpad=0.5, labelspacing=0, borderaxespad=0, loc="upper left")
leg = axis.get_legend()
leg.legend_handles[0].set_linewidth(1.5)
leg.legend_handles[1].set_linewidth(1.5)
leg.set_frame_on(False)
axis.set_xticks(np.arange(5, 45, 5))
axis.tick_params("x", pad=2, size=2)
axis.tick_params("y", pad=2, size=2)
axis.xaxis.labelpad = 0
axis.yaxis.labelpad = 1
axis.set_yticks([0, 0.05, 0.1])

fig.set_size_inches(1.8, 1.35)

fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_TRGC_Stim_sensory_cortex-STN.pdf"))

In [None]:
# Find significant periods of connectivity (motor cortex -> STN; OFF)
# Compares OFF-stimulation TRGC against a zero baseline across frequencies
# (paired comparison, `n_perm` permutations, cluster-corrected p-values).
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_con_granger_regional-StimOffOn_multi_sub.pkl")
)
# Use the first entry's frequency vector for every row (assumes all rows share it).
freqs = copy.deepcopy(data["frequencies"][0])
data["frequencies"] = [freqs] * len(data["seed_types"])
data = pd.DataFrame.from_dict(data)

lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

trgc_off = np.array(data["connectivity-trgc"][(data["stim"] == "Off") & (data["seed_subregions"] == "motor")].to_list())
trgc_off = trgc_off[:, lfreq_i:hfreq_i+1]

# Seed like the OFF-vs-ON cells so this result does not depend on which cells ran
# before (assumes, as those cells do, that the permutations use numpy's global RNG).
np.random.seed(44)

fig, axis = pte_decode.lineplot_prediction_compare(
    x_1=trgc_off.T,
    x_2=np.zeros_like(trgc_off).T,
    times=np.array(plot_freqs),
    data_labels=["TRGC", "Baseline"],
    x_label="Frequency (Hz)",
    y_label="TRGC (A.U.)",
    two_tailed=two_sided,
    paired_x1x2=True,
    n_perm=n_perm,
    correction_method="cluster_pvals",
    title="Cortex -> STN",
    colour=[def_colors[2], "k"]
)

In [None]:
# MOTOR CORTEX -> STN Granger scores (Figure S4)
# Plots every Granger-based score (forward, backward, net, their time-reversed
# counterparts, and TRGC) for OFF vs. ON stimulation over a wide frequency range.
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_con_granger_regional-StimOffOn_multi_sub.pkl")
)
# Use the first entry's frequency vector for every row (assumes all rows share it).
freqs = copy.deepcopy(data["frequencies"][0])
data["frequencies"] = [freqs] * len(data["seed_types"])
data = pd.DataFrame.from_dict(data)

off_idcs = (data["stim"] == "Off") & (data["seed_subregions"] == "motor")
on_idcs = (data["stim"] == "On") & (data["seed_subregions"] == "motor")

# Wider range than the TRGC figures. Kept under separate names so the notebook-wide
# `lfreq`/`hfreq` used by the other cells are NOT overwritten when this cell runs.
lfreq_all = 4.0
hfreq_all = 100.0

# Result-key prefix -> DataFrame column holding that score.
score_columns = {
    "gc_st": "connectivity-gc",
    "gc_ts": "connectivity-gc_ts",
    "net_gc": "connectivity-net_gc",
    "gc_st_tr": "connectivity-gc_tr",
    "gc_ts_tr": "connectivity-gc_tr_ts",
    "trgc": "connectivity-trgc",
}
results = {}
for score, column in score_columns.items():
    results[f"{score}_off"] = np.array(data[column][off_idcs].to_list())
    results[f"{score}_on"] = np.array(data[column][on_idcs].to_list())
# Net time-reversed GC: forward minus backward time-reversed scores.
results["net_gc_tr_off"] = results["gc_st_tr_off"] - results["gc_ts_tr_off"]
results["net_gc_tr_on"] = results["gc_st_tr_on"] - results["gc_ts_tr_on"]

lfreq_i = freqs.index(lfreq_all)
hfreq_i = freqs.index(hfreq_all)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

results = {name: value[:, lfreq_i:hfreq_i+1] for name, value in results.items()}
results_sem = {name: sem(value) for name, value in results.items()}

# Panel layout on a 2x4 grid; None marks the empty slot (removed below).
results_order = [
    "gc_st", "gc_ts", "net_gc",
    None,
    "gc_st_tr", "gc_ts_tr", "net_gc_tr",
    "trgc"
]

# X = cortex (seed), Y = STN; tildes denote time-reversed signals/scores.
titles = [
    r"$F_{X \rightarrow Y}$",
    r"$F_{Y \rightarrow X}$",
    r"$F^{\, \mathrm{net}}_{X \rightarrow Y}$",
    "",
    r"$\tilde{F}_{\tilde{X} \rightarrow \tilde{Y}}$",
    r"$\tilde{F}_{\tilde{Y} \rightarrow \tilde{X}}$",
    r"$\tilde{F}^{\, \mathrm{net}}_{\tilde{X} \rightarrow \tilde{Y}}$",
    r"$\tilde{D}^{\, \mathrm{net}}_{X \rightarrow Y}$",
]

fig, axes = plt.subplots(2, 4)
axes = axes.flatten()
ylim = [np.inf, -np.inf]
for axis_i, name in enumerate(results_order):
    if name is None:
        continue
    axis = axes[axis_i]

    axis.axhline(0, color="k", linestyle="--", linewidth=1)

    # Mean +/- SEM across rows, OFF then ON.
    for condition, colour in (("off", def_colors[2]), ("on", def_colors[4])):
        key = f"{name}_{condition}"
        mean = np.mean(results[key], axis=0)
        axis.plot(plot_freqs, mean, color=colour)
        axis.fill_between(plot_freqs, mean - results_sem[key], mean + results_sem[key], color=colour, alpha=0.3)

    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)

    axis.set_xticks(np.arange(0, hfreq_all+1, 20))
    axis.xaxis.set_minor_locator(MultipleLocator(10))

    # Track the widest autoscaled y-range so every panel can share it.
    axis_ylim = axis.get_ylim()
    ylim = [min(ylim[0], axis_ylim[0]), max(ylim[1], axis_ylim[1])]

    axis.set_title(titles[axis_i])

for axis in axes:
    axis.set_ylim(ylim)
    axis.set_xlim(0, hfreq_all)
    axis.tick_params("x", pad=2, size=2)
    axis.tick_params("y", pad=2, size=2)
    axis.xaxis.labelpad = 0
    axis.yaxis.labelpad = 1

fig.delaxes(axes[3])
fig.supxlabel("Frequency (Hz)")
fig.supylabel("Granger scores (A.U.)")

fig.set_size_inches(6.7, 3.7)

fig.tight_layout()

fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_AllGC_Stim_motor_cortex-STN.pdf"))