In [1]:
"""Analyse relationship between MIC patterns and fMRI connectivity."""

import os
import warnings
from copy import deepcopy

import matplotlib
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from scipy.stats import pearsonr
from statsmodels.regression.mixed_linear_model import MixedLM


# Paths. FOLDERPATH_ANALYSIS is a placeholder: point it at the local project's
# analysis folder. Built with os.path.join so it is portable across OSes
# (a literal "\\" separator only works on Windows).
FOLDERPATH_ANALYSIS = os.path.join("Path_to", "Project", "Analysis")
# Figures go to a "figures" folder inside the parent of the working directory.
FOLDERPATH_FIGURES = os.path.join(os.path.dirname(os.getcwd()), "figures")


# Settings
fbands = ["low_beta", "high_beta"]

# Publication figure style: small font sizes, and Type 42 (TrueType) font
# embedding for PDF/PS output.
matplotlib.rcParams.update(
    {
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
        "font.size": 6,
        "font.family": "Arial",
        "axes.labelsize": 7,
        "axes.titlesize": 7,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
)

In [3]:
# ECoG MIC patterns - channel-wise fMRI connectivity
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load pickles produced by this project.
patterns = pd.DataFrame.from_dict(
    pd.read_pickle(
        os.path.join(
            FOLDERPATH_ANALYSIS,
            "task-Rest_acq-multi_run-multi_con_mic_fibre_tracking_hyperdirect-MedOffOn_multi_sub.pkl",
        )
    )
)

# Keep only the columns used downstream, under shorter names.
patterns = patterns[
    [
        "connectivity-mic_topographies",
        "frequencies_band_labels",
        "sub",
        "med",
        "ch_names",
        "ch_coords",
        "seed_names",
        "target_names",
        "seed_coords",
        "target_coords",
        "ch_types",
    ]
].rename(
    columns={
        "connectivity-mic_topographies": "mean_weights",
        "frequencies_band_labels": "fband",
    }
)

# ROI label -> atlas target regions whose fMRI connectivity is averaged.
putamen_regions = ["Put-ia_R", "Put-im_R", "Put-sa_R", "Put-sp_R", "Put-sm_R"][:0] or [
    "Put-ia_R", "Put-im_R", "Put-sa_R", "Put-sm_R", "Put-sp_R"
]
caudate_regions = ["Cau-ia_R", "Cau-im_R", "Cau-sa_R", "Cau-sp_R"]
rois = {
    "Put_GPe_STN": [*putamen_regions, "Globus_pallidus_externalis_R", "STN_R"],
    "Put_GPe": [*putamen_regions, "Globus_pallidus_externalis_R"],
    # Individual structures, required by the ROI-wise panels of the plotting
    # cell (which reads results[fband]["Caudate" | "Putamen" | "GPe" | "STN"]).
    # NOTE(review): region lists reconstructed from the names above; confirm.
    "Caudate": caudate_regions,
    "Putamen": putamen_regions,
    "GPe": ["Globus_pallidus_externalis_R"],
    "STN": ["STN_R"],
}

# fMRI connectivity between ECoG channels (seeds) and atlas regions (targets).
fmri_con = pd.read_csv(
    os.path.join(FOLDERPATH_ANALYSIS, "fmri_connectivity_maps_seed_atlas.csv")
)

# One row per subject x medication x ECoG seed x ROI x frequency band, pairing
# the seed's `mean_weights` entry with its mean fMRI connectivity to the ROI.
patterns_ecog = {
    "sub": [],
    "med": [],
    "seed_name": [],
    "roi_name": [],
    "fband": [],
    "weights": [],
    "fmri_con": [],
}
for sub in patterns["sub"].unique():
    for med in ["Off", "On"]:
        sub_med = patterns.loc[(patterns["sub"] == sub) & (patterns["med"] == med)]
        # Seed names are stored as a single " & "-joined string.
        seed_names = (
            sub_med["seed_names"]
            .loc[sub_med["ch_types"] == "ecog"]
            .values[0]
            .split(" & ")
        )
        for seed_name in seed_names:
            for roi_label, roi_names in rois.items():
                # fMRI connectivity does not depend on the frequency band, so
                # compute it once per seed/ROI pair. np.mean of an empty
                # selection yields NaN (with a RuntimeWarning).
                roi_fmri_con = np.mean(
                    fmri_con["con_values"]
                    .loc[
                        (fmri_con["sub"] == f"EL{sub}")
                        & (fmri_con["seed_names"] == seed_name)
                        & (fmri_con["target_names"].isin(roi_names))
                    ]
                    .values
                )
                for fband in fbands:
                    seed_i = sub_med.loc[
                        (sub_med["fband"] == fband)
                        & (sub_med["ch_names"] == seed_name)
                    ].index[0]

                    patterns_ecog["sub"].append(sub)
                    patterns_ecog["med"].append(med)
                    patterns_ecog["seed_name"].append(seed_name)
                    patterns_ecog["roi_name"].append(roi_label)
                    patterns_ecog["fband"].append(fband)
                    patterns_ecog["weights"].append(
                        patterns["mean_weights"].loc[seed_i]
                    )
                    patterns_ecog["fmri_con"].append(roi_fmri_con)

patterns_ecog = pd.DataFrame.from_dict(patterns_ecog)

In [None]:
# LME models (and save model coefficients for plotting)
# Upper bound on warm-started refits for a model whose first fit did not
# converge, so a model that never converges cannot hang the notebook.
MAX_REFITS = 100

models = {fband: {} for fband in fbands}
results = {fband: {} for fband in fbands}
coeff_records = []
for fband in fbands:
    for roi_label in rois:
        roi_data = patterns_ecog.loc[
            (patterns_ecog["fband"] == fband)
            & (patterns_ecog["roi_name"] == roi_label)
        ].to_dict("list")
        # Fixed effects: MIC weights and medication state; grouped by subject
        # (no re_formula given, i.e. statsmodels' default random intercept).
        model = MixedLM.from_formula(
            "fmri_con ~ weights + C(med)", roi_data, groups="sub"
        )
        result = model.fit()

        # Refit from the previous estimates until convergence (bounded).
        n_refits = 0
        while not result.converged and n_refits < MAX_REFITS:
            result = model.fit(start_params=result.params)
            n_refits += 1
        if not result.converged:
            warnings.warn(
                f"LME for {fband} / {roi_label} did not converge after "
                f"{MAX_REFITS} refits; its estimates may be unreliable."
            )

        models[fband][roi_label] = model
        results[fband][roi_label] = result
        coeff_records.append(
            {"fband": fband, "roi": roi_label, "coeff": result.params["weights"]}
        )

coeffs = pd.DataFrame(coeff_records, columns=["fband", "roi", "coeff"])
coeffs.to_csv(
    os.path.join(FOLDERPATH_ANALYSIS, "mic_patterns_fmri_lme_coeffs.csv")
)

# Can view results with e.g. `results["low_beta"]["Put_GPe_STN"].summary()`

In [7]:
def compute_r2(model_result):
    """Compute marginal and conditional R-squared of a fitted LME model.

    Decomposes the total variance into fixed-effect, random-effect and
    residual parts (Nakagawa & Schielzeth-style R-squared):

        marginal    = var_fixed / total
        conditional = (var_fixed + var_random) / total

    Parameters
    ----------
    model_result : fitted mixed-model result (e.g. statsmodels MixedLMResults)
        Must provide ``scale`` (residual variance), ``cov_re`` (random-effects
        covariance; a DataFrame or 2-D array) and ``predict()``.
        NOTE(review): only ``cov_re[0, 0]`` is used, i.e. this assumes a
        single random effect (random intercept); ``predict()`` is assumed to
        return fixed-effects-only predictions, as statsmodels' MixedLM does.

    Returns
    -------
    tuple of float
        ``(marginal_r2, conditional_r2)``.
    """
    var_resid = model_result.scale
    # np.asarray accepts both a DataFrame and an ndarray; the previous
    # `.iloc[0][0]` relied on deprecated positional Series indexing.
    var_random_effect = np.asarray(model_result.cov_re)[0, 0]
    # Population variance (ndarray.var uses ddof=0) of the fixed-effect fit.
    var_fixed_effect = model_result.predict().var()

    total_var = var_fixed_effect + var_random_effect + var_resid
    marginal_r2 = var_fixed_effect / total_var
    conditional_r2 = (var_fixed_effect + var_random_effect) / total_var

    return marginal_r2, conditional_r2

In [None]:
def compute_bic(model):
    """Return the Bayesian information criterion of a fitted LME model.

    BIC = -2 * log-likelihood + ln(number of observations) * parameter count,
    using ``model.llf``, ``model.nobs`` and ``model.df_modelwc`` (taken as the
    parameter count — TODO confirm it includes variance parameters).

    NOTE(review): the models in this notebook are fit with ``.fit()``
    defaults; if that is REML, ``llf`` is a REML likelihood and BICs are only
    comparable between models with the same fixed effects.
    """
    deviance = -2 * model.llf
    penalty = np.log(model.nobs) * model.df_modelwc
    return deviance + penalty

In [None]:
# Plot information from LME models (Figure 4b inset)
fband = "low_beta"

# Marker colour shared by the grey scatter/bar panels.
GREY = [0.4, 0.4, 0.4]


def _format_pval(pval, threshold=0.05):
    """Format a p-value for a figure annotation.

    Returns ``"< threshold"`` when ``pval`` is below the threshold and
    ``"= x.xx"`` otherwise (a p-value exactly at the threshold, or NaN, is
    therefore still formatted rather than printed raw).
    """
    if pval < threshold:
        return f"< {threshold}"
    return f"= {pval:.2f}"


def _roi_values(roi_label, column, band):
    """Return ``column`` of ``patterns_ecog`` for one ROI and band as an array.

    Uses the same row selection as the LME fitting cell, so the order matches
    ``results[band][roi_label].fittedvalues``.
    """
    mask = (patterns_ecog["fband"] == band) & (
        patterns_ecog["roi_name"] == roi_label
    )
    return np.array(patterns_ecog.loc[mask, column].tolist())


def _scatter_with_fit(ax, x, y, **scatter_kwargs):
    """Scatter ``y`` against ``x`` on ``ax`` and overlay a black linear fit."""
    slope, intercept = np.polyfit(x, y, 1)
    ax.scatter(x, y, marker="o", alpha=1, **scatter_kwargs)
    ax.plot(x, slope * x + intercept, color="k", linewidth=1)


def _rotate_yticklabels(ax):
    """Rotate the y tick labels of ``ax`` by 90 degrees, centred on the ticks."""
    plt.setp(
        ax.get_yticklabels(), rotation=90, rotation_mode="anchor", ha="center"
    )


# ROI-wise empirical vs. estimated fMRI connectivity.
# Named `plot_rois` so this cell does not overwrite the `rois` dict that the
# data-preparation and LME cells iterate over.
plot_rois = ["Caudate", "Putamen", "GPe", "STN"]
xlocators = [0.06, 0.08, 0.1, 0.08]
ylocators = [0.02, 0.02, 0.01, 0.01]
fig, axes = plt.subplots(1, len(plot_rois), squeeze=False)
for plot_n, (ax, roi, xlocator, ylocator) in enumerate(
    zip(axes[0], plot_rois, xlocators, ylocators)
):
    result = results[fband][roi]
    predicted = result.fittedvalues
    empirical = _roi_values(roi, "fmri_con", fband)
    _scatter_with_fit(
        ax, empirical, predicted, color=GREY, facecolors="none", s=1
    )

    if plot_n == 0:
        ax.set_xlabel("Empirical functional connectivity (Z-score)")
        ax.set_ylabel("Estimated functional\nconnectivity (Z-score)")
    ax.set_title(roi)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(base=xlocator))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=ylocator))
    _rotate_yticklabels(ax)
    ax.tick_params("x", pad=2, size=2)
    ax.tick_params("y", pad=4, size=2)

    ax.set_box_aspect(1)

    # Conditional R^2 of the model and p-value of the MIC-weight coefficient.
    model_r2 = compute_r2(result)[1]
    box_text = (
        f"R$^2$ = {model_r2:.2f}\np {_format_pval(result.pvalues['weights'])}"
    )
    text_box = AnchoredText(
        box_text, frameon=False, loc=4, pad=0.1, borderpad=0.1
    )
    plt.setp(text_box.patch, facecolor="white", alpha=0.5)
    ax.add_artist(text_box)

fig.set_size_inches(4.9, 1.2)
fig.savefig(
    os.path.join(
        FOLDERPATH_FIGURES, f"Manuscript_FuncConn_EmpEst_All_{fband}.pdf"
    )
)


# ROI-wise LME coefficients
fig, ax_bar = plt.subplots(1, 1)
bar_positions = np.arange(len(plot_rois))
ax_bar.bar(
    bar_positions,
    [results[fband][roi].params["weights"] for roi in plot_rois],
    color=GREY,
)
ax_bar.axhline(0, color="k", linewidth=1)  # zero reference line
ax_bar.set_xticks(bar_positions)
ax_bar.set_xticklabels(plot_rois)
plt.setp(
    ax_bar.get_xticklabels(), rotation=35, rotation_mode="anchor", ha="center"
)
_rotate_yticklabels(ax_bar)
ax_bar.tick_params("x", pad=6, size=2)
ax_bar.tick_params("y", pad=4, size=2)
ax_bar.set_yticks([0, 6e-3])
ax_bar.set_ylim([-5e-4, 7e-3])
ax_bar.set_ylabel("Beta coefficients (A.U.)")
ax_bar.spines["top"].set_visible(False)
ax_bar.spines["right"].set_visible(False)
fig.set_size_inches(1.2, 1.2)
fig.savefig(
    os.path.join(
        FOLDERPATH_FIGURES, f"Manuscript_FuncConn_All_coeffs_{fband}_bar.pdf"
    )
)


# Putamen-GPe-STN empirical vs. estimated fMRI connectivity
roi = "Put_GPe_STN"
result = results[fband][roi]
predicted = result.fittedvalues
fig, axs = plt.subplots(1, 2)

# Left panel: empirical vs. estimated connectivity.
empirical = _roi_values(roi, "fmri_con", fband)
_scatter_with_fit(axs[0], empirical, predicted, color=GREY, facecolors=GREY, s=2)

axs[0].set_xlabel("Empirical functional\nconnectivity (Z-score)")
axs[0].set_ylabel("Estimated functional\nconnectivity (Z-score)")
axs[0].set_title("Putamen, GPe, & STN")

xlocator = 0.04
ylocator = 0.03
axs[0].xaxis.set_major_locator(ticker.MultipleLocator(base=xlocator))
axs[0].yaxis.set_major_locator(ticker.MultipleLocator(base=ylocator))
_rotate_yticklabels(axs[0])
axs[0].tick_params("x", pad=2, size=2)
axs[0].tick_params("y", pad=4, size=2)

axs[0].set_box_aspect(1)

pearson_r, pearson_pval = pearsonr(empirical, predicted)
model_r2 = compute_r2(result)[1]
box_text = (
    f"R$^2$ = {model_r2:.2f}\nr = {pearson_r:.2f}\n"
    f"p {_format_pval(pearson_pval)}"
)
text_box = AnchoredText(box_text, frameon=False, loc=4, pad=0.1, borderpad=0.5)
plt.setp(text_box.patch, facecolor="white", alpha=0.7)
axs[0].add_artist(text_box)

# Right panel: MIC weights vs. estimated connectivity, coloured by weight.
mic_weights = _roi_values(roi, "weights", fband)
normalise = matplotlib.colors.Normalize(
    vmin=np.min(mic_weights), vmax=np.max(mic_weights)
)
colours = plt.get_cmap("viridis")(normalise(mic_weights))
_scatter_with_fit(
    axs[1], mic_weights, predicted, color=colours, facecolors=colours, s=2
)

# NOTE(review): with ":.2f", coefficients of order 1e-3 render as 0.00/0.01.
box_text = (
    fr"$\beta$ = {result.params['weights'] :.2f}"
    + f"\np {_format_pval(result.pvalues['weights'])}"
)
text_box = AnchoredText(box_text, frameon=False, loc=4, pad=0.1, borderpad=0.5)
plt.setp(text_box.patch, facecolor="white", alpha=0.7)
axs[1].add_artist(text_box)

axs[1].set_xlabel("Contribution to\nconnectivity (A.U.)")
axs[1].set_ylabel("")

axs[1].set_xticks((mic_weights.min(), mic_weights.max()))
axs[1].set_xticklabels(["Low", "High"])

# Both panels plot the same estimated connectivity on y, so keep the ticks
# but hide the labels on the right panel (independent of the tick count).
axs[1].yaxis.set_major_locator(ticker.MultipleLocator(base=ylocator))
axs[1].yaxis.set_major_formatter(ticker.NullFormatter())
axs[1].tick_params("x", pad=2, size=2)
axs[1].tick_params("y", pad=2, size=2)

axs[1].set_box_aspect(1)

axs[0].xaxis.labelpad = 1
axs[1].xaxis.labelpad = 0
axs[1].yaxis.labelpad = 1

fig.set_size_inches(3.2, 1.5)
fig.savefig(
    os.path.join(
        FOLDERPATH_FIGURES,
        f"Manuscript_FuncConn_EmpEstWeights_{roi}_{fband}.pdf"
    )
)