In [1]:
"""Plot STIM OFF vs. ON power (Figures 2a, S1, & S2)."""

import os
import sys
from pathlib import Path
import numpy as np
import copy
import pandas as pd
import pte_stats
import pte_decode
from matplotlib import pyplot as plt
from statannotations.stats.StatTest import StatTest

# Make modules in the sibling "coherence" directory importable.
cd_path = Path.cwd().absolute().parent
sys.path.append(str(cd_path / "coherence"))

import matplotlib

# Figure styling: small font sizes (points) and fonttype 42 so text is
# embedded as TrueType fonts in PDF/PS output.
matplotlib.rcParams.update({
    "xtick.labelsize": 6,
    "ytick.labelsize": 6,
    "legend.fontsize": 6,
    "font.size": 6,
    "font.family": "Arial",
    "axes.labelsize": 7,
    "axes.titlesize": 7,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

prop_cycle = plt.rcParams["axes.prop_cycle"]

# Permutation-test settings shared by the analysis cells below.
n_perm = 100_000
alpha = 0.05
two_sided = True

# Input (analysis results) and output (figures) directories.
FOLDERPATH_ANALYSIS = "Path_to\\Project\\Analysis"
FOLDERPATH_FIGURES = os.path.join(os.path.dirname(os.getcwd()), "figures")

def permutation_onesample(n_perm=100_000, alpha=0.05, two_tailed=True):
    """Wrap a paired permutation test as a statannotations ``StatTest``.

    The returned test subtracts the two compared groups element-wise
    (``x - y``) and runs ``pte_stats.permutation_onesample`` on that
    difference against 0, i.e. a paired comparison.

    Parameters
    ----------
    n_perm : int
        Number of permutations (default 100_000).
    alpha : float
        Significance level passed to ``StatTest`` (default 0.05).
    two_tailed : bool
        Whether the permutation test is two-tailed (default True).

    Returns
    -------
    StatTest
        Test object usable as the ``stat_test`` of a statannotations-based plot.
    """

    def _stat_test(x, y):
        # Convert pandas Series / array-likes to NumPy first so the subtraction
        # is positional rather than index-aligned.
        diff = np.asarray(x) - np.asarray(y)
        return pte_stats.permutation_onesample(
            data_a=diff, data_b=0, n_perm=n_perm, two_tailed=two_tailed
        )

    return StatTest(
        func=_stat_test,
        alpha=alpha,
        test_long_name="Permutation Test",
        test_short_name="Perm. Test"
    )

# Lower/upper frequency (Hz) of the spectra shown in the line plots.
lfreq, hfreq = 4.0, 40.0

# Frequency bands (Hz) tested individually; each value is (lower, upper) limit.
fbands = dict(
    alpha=(8, 12),
    low_beta=(12, 20),
    high_beta=(20, 30),
)

# Aggregation applied across the frequency bins of each band.
fband_func = np.mean

In [None]:
# WHOLE CORTEX POWER (Figure 2a; also Figure S2 when adjusting the data loaded)
# Cortical power spectra OFF therapy vs. ON STN-DBS: a line plot with
# cluster-corrected permutation statistics across frequencies, plus one
# permutation test per band in `fbands` (significant bands are shaded grey).
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_pow_standard_whole-StimOffOn_multi_sub.pkl")
)
freqs = copy.deepcopy(data["freqs"][0])
# Reuse the first entry's frequency vector for every row (assumes all entries
# share one frequency grid -- TODO confirm).
data["freqs"] = [freqs] * len(data["ch_types"])
data = pd.DataFrame.from_dict(data)

pow_off = np.array(data["power-multitaper"][(data["stim"] == "Off") & (data["ch_regions"] == "cortex")].to_list())
pow_on = np.array(data["power-multitaper"][(data["stim"] == "On") & (data["ch_regions"] == "cortex")].to_list())
# The line plot (paired_x1x2=True) and the band tests compare rows pairwise.
assert pow_off.shape == pow_on.shape, "OFF and ON power must have matching shapes for paired tests"

np.random.seed(44)

# Crop the spectra to [lfreq, hfreq] (both ends inclusive).
lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

pow_off = pow_off[:, lfreq_i:hfreq_i+1]
pow_on = pow_on[:, lfreq_i:hfreq_i+1]

fig, axis = pte_decode.lineplot_prediction_compare(
    x_1=pow_off.T,
    x_2=pow_on.T,
    # Take the x-values from the data instead of assuming a 0.5 Hz grid.
    times=np.array(plot_freqs),
    data_labels=["OFF therapy", "ON STN-DBS"],
    x_label="Frequency (Hz)",
    y_label="Power (% total)",
    two_tailed=two_sided,
    paired_x1x2=True,
    n_perm=n_perm,
    correction_method="cluster_pvals",
    title="Cortex",
    colour=["#DF4A4A", "#71BCAD"]
)

for fband_name, fband_lims in fbands.items():
    # The power arrays are already cropped to `plot_freqs`, so the band limits
    # must be located in `plot_freqs`; indices into the full `freqs` would be
    # shifted by `lfreq_i` bins.
    band_cols = slice(plot_freqs.index(fband_lims[0]), plot_freqs.index(fband_lims[1]) + 1)
    stat, pval = pte_stats.permutation_onesample(
        data_a=fband_func(pow_off[:, band_cols], 1),
        data_b=fband_func(pow_on[:, band_cols], 1),
        n_perm=n_perm, two_tailed=two_sided
    )
    if pval < alpha:
        axis.axvspan(fband_lims[0], fband_lims[1], color=[0.8, 0.8, 0.8], alpha=0.3, linewidth=0, zorder=0)
        prefix = "[SIGNIFICANT]"
    else:
        prefix = ""
    print(f"{prefix} {fband_name}: stat={stat}; p-val={pval}")

axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.set_box_aspect(0.8)
axis.legend(handlelength=0.5, handletextpad=0.5, labelspacing=0, borderaxespad=0, loc="lower left")
leg = axis.get_legend()
leg.legend_handles[0].set_linewidth(1.5)
leg.legend_handles[1].set_linewidth(1.5)
leg.set_frame_on(False)
axis.tick_params("x", pad=2, size=2)
axis.tick_params("y", pad=2, size=2)
axis.set_xticks(np.arange(5, 45, 5))
axis.xaxis.labelpad = 0
axis.yaxis.labelpad = 1
axis.set_yticks([1, 2, 3])

fig.set_size_inches(1.6, 1.6)

fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_Power_Stim_whole_cortex.pdf"), dpi=300)

In [None]:
# STN POWER (Figure 2a; also Figure S2 when adjusting the data loaded)
# STN power spectra OFF therapy vs. ON STN-DBS: a line plot with
# cluster-corrected permutation statistics across frequencies, plus one
# permutation test per band in `fbands` (significant bands are shaded grey).
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_pow_standard_whole-StimOffOn_multi_sub.pkl")
)
freqs = copy.deepcopy(data["freqs"][0])
# Reuse the first entry's frequency vector for every row (assumes all entries
# share one frequency grid -- TODO confirm).
data["freqs"] = [freqs] * len(data["ch_types"])
data = pd.DataFrame.from_dict(data)

pow_off = np.array(data["power-multitaper"][(data["stim"] == "Off") & (data["ch_regions"] == "STN")].to_list())
pow_on = np.array(data["power-multitaper"][(data["stim"] == "On") & (data["ch_regions"] == "STN")].to_list())
# The line plot (paired_x1x2=True) and the band tests compare rows pairwise.
assert pow_off.shape == pow_on.shape, "OFF and ON power must have matching shapes for paired tests"

np.random.seed(44)

# Crop the spectra to [lfreq, hfreq] (both ends inclusive).
lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

pow_off = pow_off[:, lfreq_i:hfreq_i+1]
pow_on = pow_on[:, lfreq_i:hfreq_i+1]

fig, axis = pte_decode.lineplot_prediction_compare(
    x_1=pow_off.T,
    x_2=pow_on.T,
    # Take the x-values from the data instead of assuming a 0.5 Hz grid.
    times=np.array(plot_freqs),
    data_labels=["OFF therapy", "ON STN-DBS"],
    x_label="Frequency (Hz)",
    y_label="Power (% total)",
    two_tailed=two_sided,
    paired_x1x2=True,
    n_perm=n_perm,
    correction_method="cluster_pvals",
    title="STN",
    colour=["#DF4A4A", "#71BCAD"]
)

for fband_name, fband_lims in fbands.items():
    # The power arrays are already cropped to `plot_freqs`, so the band limits
    # must be located in `plot_freqs`; indices into the full `freqs` would be
    # shifted by `lfreq_i` bins.
    band_cols = slice(plot_freqs.index(fband_lims[0]), plot_freqs.index(fband_lims[1]) + 1)
    stat, pval = pte_stats.permutation_onesample(
        data_a=fband_func(pow_off[:, band_cols], 1),
        data_b=fband_func(pow_on[:, band_cols], 1),
        n_perm=n_perm, two_tailed=two_sided
    )
    if pval < alpha:
        axis.axvspan(fband_lims[0], fband_lims[1], color=[0.8, 0.8, 0.8], alpha=0.3, linewidth=0, zorder=0)
        prefix = "[SIGNIFICANT]"
    else:
        prefix = ""
    print(f"{prefix} {fband_name}: stat={stat}; p-val={pval}")

axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.set_box_aspect(0.8)
axis.legend(handlelength=0.5, handletextpad=0.5, labelspacing=0, borderaxespad=0, loc="lower left")
leg = axis.get_legend()
leg.legend_handles[0].set_linewidth(1.5)
leg.legend_handles[1].set_linewidth(1.5)
leg.legend_handles[2].set_color([0.5, 0.5, 0.5])
leg.set_frame_on(False)
axis.set_xticks(np.arange(5, 45, 5))
axis.tick_params("x", pad=2, size=2)
axis.tick_params("y", pad=2, size=2)
axis.set_ylim((0.2, 3.1))
axis.xaxis.labelpad = 0
axis.yaxis.labelpad = 1
axis.set_yticks([1, 2, 3])

fig.set_size_inches(1.58, 1.58)

# Same dpi as the other saved figures in this notebook.
fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_Power_Stim_STN.pdf"), dpi=300)

In [None]:
# STN POWER OFF (Channels with low vs. high beta peaks; Figure S1a)
# Per subject, average the OFF-stimulation STN spectra over the channels listed
# in `peak_channels_low` and in `peak_channels_high`, then compare the two
# averages pairwise across subjects (line plot + one test per band in `fbands`).
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_pow_standard_channels-StimOffOn_multi_sub.pkl")
)
freqs = copy.deepcopy(data["freqs"][0])
# Reuse the first entry's frequency vector for every row (assumes all entries
# share one frequency grid -- TODO confirm).
data["freqs"] = [freqs] * len(data["ch_types"])
data = pd.DataFrame.from_dict(data)

pow_low = []
pow_high = []
peak_channels_low = {
    "006": ["01-05", "01-06", "01-07", "05-07", "06-07", "06-08", "07-08"],
    "007": ["01-05", "05-08"],
    "009": ["06-07", "05-08", "06-08", "07-08"],
    "012": ["01-02", "01-03", "01-04", "02-03", "02-04", "03-04", "02-08", "03-08", "04-08"],
    "014": ["01-03", "01-04", "02-03", "02-04", "03-04", "03-08", "04-08"],
    "017": ["01-05", "01-06", "01-07", "05-06", "05-07", "05-08"],
    "019": ["05-08"],
    "022": ["01-07", "05-07", "06-07", "06-08"],
    "023": ["01-02", "01-03", "01-04", "02-03", "02-04", "03-04", "02-08", "03-08", "04-08"],
    "025": ["01-07", "05-07", "06-07", "07-08"],
    "026": ["01-04", "02-04", "03-04", "03-08", "04-08"],
    "027": ["05-07"],
}
peak_channels_high = {
    "006": ["01-05", "01-06", "01-07", "05-06", "05-07", "06-07", "05-08", "06-08", "07-08"],
    "007": ["01-05", "01-06", "01-07", "05-06", "05-07", "06-07", "05-08", "06-08", "07-08"],
    "009": ["01-05", "05-06", "05-07"],
    "012": ["01-02", "01-03", "01-04"],
    "014": ["01-02", "01-03", "01-04", "02-04", "03-04", "02-08", "03-08", "04-08"],
    "017": ["01-05", "05-06", "05-07", "05-08"],
    "019": ["01-05", "01-06", "01-07", "05-06", "05-08"],
    "022": ["05-07"],
    "023": ["01-02", "02-04"],
    "025": ["05-06", "06-07"],
    "026": ["01-02", "01-03", "01-04", "02-03", "02-04", "03-04", "02-08", "03-08"],
    "027": ["01-05", "01-06", "01-07", "05-06", "05-07", "06-07", "05-08", "06-08", "07-08"],
}
for sub in peak_channels_low.keys():
    # OFF-stimulation STN rows of this subject; filtered once per subject.
    sub_rows_off = data[(data["ch_regions"] == "STN") & (data["sub"] == sub) & (data["stim"] == "Off")]
    sub_pow_low = []
    sub_pow_high = []
    for ch in peak_channels_low[sub]:
        ch_mask = sub_rows_off["ch_names"].str.contains(ch, regex=False)
        if ch_mask.any():
            sub_pow_low.append(sub_rows_off["power-multitaper"][ch_mask].to_list())
    for ch in peak_channels_high[sub]:
        ch_mask = sub_rows_off["ch_names"].str.contains(ch, regex=False)
        if ch_mask.any():
            sub_pow_high.append(sub_rows_off["power-multitaper"][ch_mask].to_list())
    if not sub_pow_low or not sub_pow_high:
        raise ValueError(f"Subject {sub}: none of the listed low- and/or high-beta peak channels were found")
    # Average over channels; `[0]` keeps the first matching row of each channel
    # (presumably exactly one row matches per channel -- TODO confirm).
    pow_low.append(np.mean(sub_pow_low, axis=0)[0])
    pow_high.append(np.mean(sub_pow_high, axis=0)[0])
pow_low = np.array(pow_low)
pow_high = np.array(pow_high)

np.random.seed(44)

# Crop the spectra to [lfreq, hfreq] (both ends inclusive).
lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

pow_low = pow_low[:, lfreq_i:hfreq_i+1]
pow_high = pow_high[:, lfreq_i:hfreq_i+1]

fig, axis = pte_decode.lineplot_prediction_compare(
    x_1=pow_low.T,
    x_2=pow_high.T,
    # Take the x-values from the data instead of assuming a 0.5 Hz grid.
    times=np.array(plot_freqs),
    data_labels=["Low beta", "High beta"],
    x_label="Frequency (Hz)",
    y_label="Power (% total)",
    two_tailed=two_sided,
    paired_x1x2=True,
    n_perm=n_perm,
    correction_method="cluster_pvals",
    title="STN (OFF therapy)",
    colour=["#CF80E7", "#CC9800"]
)

for fband_name, fband_lims in fbands.items():
    # The power arrays are already cropped to `plot_freqs`, so the band limits
    # must be located in `plot_freqs`; indices into the full `freqs` would be
    # shifted by `lfreq_i` bins.
    band_cols = slice(plot_freqs.index(fband_lims[0]), plot_freqs.index(fband_lims[1]) + 1)
    stat, pval = pte_stats.permutation_onesample(
        data_a=fband_func(pow_low[:, band_cols], 1),
        data_b=fband_func(pow_high[:, band_cols], 1),
        n_perm=n_perm, two_tailed=two_sided
    )
    if pval < alpha:
        axis.axvspan(fband_lims[0], fband_lims[1], color=[0.8, 0.8, 0.8], alpha=0.3, linewidth=0, zorder=0)
        prefix = "[SIGNIFICANT]"
    else:
        prefix = ""
    print(f"{prefix} {fband_name}: stat={stat}; p-val={pval}")

axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.set_box_aspect(0.8)
axis.legend(handlelength=0.5, handletextpad=0.5, labelspacing=0, borderaxespad=0, loc="lower left")
leg = axis.get_legend()
leg.legend_handles[0].set_linewidth(1.5)
leg.legend_handles[1].set_linewidth(1.5)
leg.legend_handles[2].set_color([0.5, 0.5, 0.5])
leg.set_frame_on(False)
axis.set_xticks(np.arange(5, 45, 5))
axis.tick_params("x", pad=2, size=2)
axis.tick_params("y", pad=2, size=2)
axis.set_ylim((0.2, 3.1))
axis.xaxis.labelpad = 0
axis.yaxis.labelpad = 1
axis.set_yticks([1, 2, 3])

fig.set_size_inches(1.58, 1.58)

fig.savefig(os.path.join(FOLDERPATH_FIGURES, "Manuscript_Power_Stim_OFF_low_vs_high_beta_channels.pdf"), dpi=300)

In [None]:
# STN POWER OFF vs. ON (Channels with low beta peaks; Figure S1b)
# Per subject, average the STN spectra over the listed low-beta peak channels
# separately for OFF and ON stimulation, then draw one paired violin plot
# (OFF vs. ON) per band in `fbands`.
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_pow_standard_channels-StimOffOn_multi_sub.pkl")
)
freqs = copy.deepcopy(data["freqs"][0])
# Reuse the first entry's frequency vector for every row (assumes all entries
# share one frequency grid -- TODO confirm).
data["freqs"] = [freqs] * len(data["ch_types"])
data = pd.DataFrame.from_dict(data)

pow_off = []
pow_on = []
peak_channels = {
    "006": ["01-05", "01-06", "01-07", "05-07", "06-07", "06-08", "07-08"],
    "007": ["01-05", "05-08"],
    "009": ["06-07", "05-08", "06-08", "07-08"],
    "012": ["01-02", "01-03", "01-04", "02-03", "02-04", "03-04", "02-08", "03-08", "04-08"],
    "014": ["01-03", "01-04", "02-03", "02-04", "03-04", "03-08", "04-08"],
    "017": ["01-05", "01-06", "01-07", "05-06", "05-07", "05-08"],
    "019": ["05-08"],
    "022": ["01-07", "05-07", "06-07", "06-08"],
    "023": ["01-02", "01-03", "01-04", "02-03", "02-04", "03-04", "02-08", "03-08", "04-08"],
    "025": ["01-07", "05-07", "06-07", "07-08"],
    "026": ["01-04", "02-04", "03-04", "03-08", "04-08"],
    "027": ["05-07"],
}
for sub, sub_channels in peak_channels.items():
    # STN rows of this subject, split by stimulation; filtered once per subject.
    sub_rows = data[(data["ch_regions"] == "STN") & (data["sub"] == sub)]
    sub_rows_off = sub_rows[sub_rows["stim"] == "Off"]
    sub_rows_on = sub_rows[sub_rows["stim"] == "On"]
    sub_pow_off = []
    sub_pow_on = []
    # NOTE(review): a channel found in only one condition contributes to that
    # condition's average only, so OFF and ON may average different channel sets.
    for ch in sub_channels:
        off_mask = sub_rows_off["ch_names"].str.contains(ch, regex=False)
        if off_mask.any():
            sub_pow_off.append(sub_rows_off["power-multitaper"][off_mask].to_list())
        on_mask = sub_rows_on["ch_names"].str.contains(ch, regex=False)
        if on_mask.any():
            sub_pow_on.append(sub_rows_on["power-multitaper"][on_mask].to_list())
    if not sub_pow_off or not sub_pow_on:
        raise ValueError(f"Subject {sub}: none of the listed channels were found in the OFF and/or ON data")
    # Average over channels; `[0]` keeps the first matching row of each channel
    # (presumably exactly one row matches per channel -- TODO confirm).
    pow_off.append(np.mean(sub_pow_off, axis=0)[0])
    pow_on.append(np.mean(sub_pow_on, axis=0)[0])
pow_off = np.array(pow_off)
pow_on = np.array(pow_on)

np.random.seed(44)

# Crop the spectra to [lfreq, hfreq] (both ends inclusive).
lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

pow_off = pow_off[:, lfreq_i:hfreq_i+1]
pow_on = pow_on[:, lfreq_i:hfreq_i+1]

for fband_name, fband_lims in fbands.items():
    # The power arrays are already cropped to `plot_freqs`, so the band limits
    # must be located in `plot_freqs`; indices into the full `freqs` would be
    # shifted by `lfreq_i` bins.
    band_cols = slice(plot_freqs.index(fband_lims[0]), plot_freqs.index(fband_lims[1]) + 1)
    pow_off_band = fband_func(pow_off[:, band_cols], 1)
    pow_on_band = fband_func(pow_on[:, band_cols], 1)
    # Rows follow the subject order of `peak_channels` in both conditions.
    plot_df = pd.DataFrame.from_dict({
        "Power (% total)": pow_off_band.tolist() + pow_on_band.tolist(),
        "Stimulation": ["OFF"] * len(pow_off_band) + ["ON"] * len(pow_on_band),
        "Subject": list(peak_channels.keys()) * 2
    })

    fig, ax = pte_decode.violinplot_results(
        data=plot_df,
        outpath=None,
        x="Stimulation",
        y="Power (% total)",
        hue=None,
        order=["OFF", "ON"],
        hue_order=None,
        stat_test=permutation_onesample(),
        alpha=alpha,
        add_lines="Subject",
        title=f"Low beta channels; {fband_name}",
        swarm_size=3,
        colours=["#DF4A4A", "#71BCAD"],
        figsize=[1.3, 2],
        show=True
    )

    fig.savefig(os.path.join(FOLDERPATH_FIGURES, f"Manuscript_Power_Stim_STN_lowbeta_channels_{fband_name}.pdf"), dpi=300)

In [None]:
# STN POWER OFF vs. ON (Channels with high beta peaks; Figure S1b)
# Per subject, average the STN spectra over the listed high-beta peak channels
# separately for OFF and ON stimulation, then draw one paired violin plot
# (OFF vs. ON) per band in `fbands`.
data = pd.read_pickle(os.path.join(
    FOLDERPATH_ANALYSIS, "task-Rest_acq-multi_run-multi_pow_standard_channels-StimOffOn_multi_sub.pkl")
)
freqs = copy.deepcopy(data["freqs"][0])
# Reuse the first entry's frequency vector for every row (assumes all entries
# share one frequency grid -- TODO confirm).
data["freqs"] = [freqs] * len(data["ch_types"])
data = pd.DataFrame.from_dict(data)

pow_off = []
pow_on = []
peak_channels = {
    "006": ["01-05", "01-06", "01-07", "05-06", "05-07", "06-07", "05-08", "06-08", "07-08"],
    "007": ["01-05", "01-06", "01-07", "05-06", "05-07", "06-07", "05-08", "06-08", "07-08"],
    "009": ["01-05", "05-06", "05-07"],
    "012": ["01-02", "01-03", "01-04"],
    "014": ["01-02", "01-03", "01-04", "02-04", "03-04", "02-08", "03-08", "04-08"],
    "017": ["01-05", "05-06", "05-07", "05-08"],
    "019": ["01-05", "01-06", "01-07", "05-06", "05-08"],
    "022": ["05-07"],
    "023": ["01-02", "02-04"],
    "025": ["05-06", "06-07"],
    "026": ["01-02", "01-03", "01-04", "02-03", "02-04", "03-04", "02-08", "03-08"],
    "027": ["01-05", "01-06", "01-07", "05-06", "05-07", "06-07", "05-08", "06-08", "07-08"],
}
for sub, sub_channels in peak_channels.items():
    # STN rows of this subject, split by stimulation; filtered once per subject.
    sub_rows = data[(data["ch_regions"] == "STN") & (data["sub"] == sub)]
    sub_rows_off = sub_rows[sub_rows["stim"] == "Off"]
    sub_rows_on = sub_rows[sub_rows["stim"] == "On"]
    sub_pow_off = []
    sub_pow_on = []
    # NOTE(review): a channel found in only one condition contributes to that
    # condition's average only, so OFF and ON may average different channel sets.
    for ch in sub_channels:
        off_mask = sub_rows_off["ch_names"].str.contains(ch, regex=False)
        if off_mask.any():
            sub_pow_off.append(sub_rows_off["power-multitaper"][off_mask].to_list())
        on_mask = sub_rows_on["ch_names"].str.contains(ch, regex=False)
        if on_mask.any():
            sub_pow_on.append(sub_rows_on["power-multitaper"][on_mask].to_list())
    if not sub_pow_off or not sub_pow_on:
        raise ValueError(f"Subject {sub}: none of the listed channels were found in the OFF and/or ON data")
    # Average over channels; `[0]` keeps the first matching row of each channel
    # (presumably exactly one row matches per channel -- TODO confirm).
    pow_off.append(np.mean(sub_pow_off, axis=0)[0])
    pow_on.append(np.mean(sub_pow_on, axis=0)[0])
pow_off = np.array(pow_off)
pow_on = np.array(pow_on)

np.random.seed(44)

# Crop the spectra to [lfreq, hfreq] (both ends inclusive).
lfreq_i = freqs.index(lfreq)
hfreq_i = freqs.index(hfreq)
plot_freqs = freqs[lfreq_i:hfreq_i+1]

pow_off = pow_off[:, lfreq_i:hfreq_i+1]
pow_on = pow_on[:, lfreq_i:hfreq_i+1]

for fband_name, fband_lims in fbands.items():
    # The power arrays are already cropped to `plot_freqs`, so the band limits
    # must be located in `plot_freqs`; indices into the full `freqs` would be
    # shifted by `lfreq_i` bins.
    band_cols = slice(plot_freqs.index(fband_lims[0]), plot_freqs.index(fband_lims[1]) + 1)
    pow_off_band = fband_func(pow_off[:, band_cols], 1)
    pow_on_band = fband_func(pow_on[:, band_cols], 1)
    # Rows follow the subject order of `peak_channels` in both conditions.
    plot_df = pd.DataFrame.from_dict({
        "Power (% total)": pow_off_band.tolist() + pow_on_band.tolist(),
        "Stimulation": ["OFF"] * len(pow_off_band) + ["ON"] * len(pow_on_band),
        "Subject": list(peak_channels.keys()) * 2
    })

    fig, ax = pte_decode.violinplot_results(
        data=plot_df,
        outpath=None,
        x="Stimulation",
        y="Power (% total)",
        hue=None,
        order=["OFF", "ON"],
        hue_order=None,
        stat_test=permutation_onesample(),
        alpha=alpha,
        add_lines="Subject",
        title=f"High beta channels; {fband_name}",
        swarm_size=3,
        colours=["#DF4A4A", "#71BCAD"],
        figsize=[1.3, 2],
        show=True
    )

    fig.savefig(os.path.join(FOLDERPATH_FIGURES, f"Manuscript_Power_Stim_STN_highbeta_channels_{fband_name}.pdf"), dpi=300)