In [1]:
"""Analyse relationship between MIC patterns and hyperdirect connectivity."""

from copy import deepcopy
import numpy as np
import pandas as pd
from statsmodels.regression.mixed_linear_model import MixedLM
import os
import sys
from pathlib import Path
import matplotlib
from matplotlib import pyplot as plt
import mne
import trimesh
from matplotlib.offsetbox import AnchoredText
from scipy.stats import pearsonr

cd_path = Path(os.getcwd()).absolute().parent
sys.path.append(os.path.join(cd_path, "coherence"))
from coh_track_fibres import TrackFibres


FOLDERPATH_ANALYSIS = "Path_to\\Project\\Analysis"
FOLDERPATH_FIGURES = os.path.join(os.path.dirname(os.getcwd()), "figures")


# Settings
# Frequency bands analysed, and the search radii (mm) passed to
# `TrackFibres.find_within_radius` for the seed (ECoG) and target (LFP)
# channel coordinates respectively.
fbands = ["low_beta", "high_beta"]
ecog_radius = 5  # mm
lfp_radius = 3  # mm

# Figure styling: small fonts for a compact manuscript figure, and TrueType
# (type 42) fonts embedded in PDF/PS output so text remains editable.
matplotlib.rcParams.update(
    {
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
        "font.size": 6,
        "font.family": "Arial",
        "axes.labelsize": 7,
        "axes.titlesize": 7,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
)

In [3]:
def project_to_mesh(coords):
    """Project coordinates onto the MNI ICBM152 (2009b) brain-surface mesh.

    The coordinates are mapped from MNI space into the template's MRI
    space, snapped to the nearest point on the template surface mesh, scaled
    by 1.05 about the MRI-space origin, and mapped back into MNI space.

    Parameters
    ----------
    coords : array-like, shape (n_points, 3)
        Coordinates in MNI space, in the same units as the transform returned
        by ``mne.read_talxfm`` (NOTE(review): presumably metres, since the
        caller converts the result from m to mm — confirm).

    Returns
    -------
    numpy.ndarray, shape (n_points, 3)
        The projected coordinates in MNI space.
    """
    mesh_name = "mni_icbm152_nlin_asym_09b"
    sample_path = mne.datasets.sample.data_path()
    subjects_dir = sample_path / "subjects"

    # Transform coords from MNI into the template's MRI space, where the
    # surface mesh is defined.
    mri_mni_trans = mne.read_talxfm(mesh_name, subjects_dir)
    mri_mni_inv = np.linalg.inv(mri_mni_trans["trans"])
    coords = mne.transforms.apply_trans(mri_mni_inv, coords)

    # Build the path with pathlib rather than hard-coded "\\" separators so
    # the mesh is found on any operating system.
    path_mesh = subjects_dir / mesh_name / "surf" / f"{mesh_name}.glb"
    with open(path_mesh, mode="rb") as f:
        scene = trimesh.exchange.gltf.load_glb(f)
    mesh: trimesh.Trimesh = trimesh.Trimesh(**scene["geometry"]["geometry_0"])

    # Snap each point to its closest location on the mesh surface, then scale
    # by 1.05 about the MRI-space origin (presumably to place the points just
    # outside the surface for plotting — confirm).
    coords = mesh.nearest.on_surface(coords)[0]
    coords *= 1.05

    # Transform coords back into MNI space.
    return mne.transforms.apply_trans(mri_mni_trans, coords)

In [4]:
# Averaged MIC patterns - channel-wise fibre counts
# Load the group-level MIC fibre-tracking results, keep only the columns used
# below, and give the weight and frequency-band columns shorter names.
# NOTE(review): `pd.read_pickle` can execute arbitrary code on load — only use
# it on trusted, locally generated files.
patterns = (
    pd.DataFrame.from_dict(
        pd.read_pickle(
            os.path.join(
                FOLDERPATH_ANALYSIS,
                "task-Rest_acq-multi_run-multi_con_mic_fibre_tracking_hyperdirect-MedOffOn_multi_sub.pkl",
            )
        )
    )
    .loc[
        :,
        [
            "connectivity-mic_topographies",
            "frequencies_band_labels",
            "sub",
            "med",
            "ch_types",
            "ch_names",
            "ch_coords",
            "seed_names",
            "target_names",
            "seed_coords",
            "target_coords",
        ],
    ]
    .rename(
        columns={
            "connectivity-mic_topographies": "mean_weights",
            "frequencies_band_labels": "fband",
        }
    )
)

# Mirror every channel onto the right hemisphere by taking |x|.
coords = np.array(patterns["ch_coords"].tolist())
coords[:, 0] = np.abs(coords[:, 0])
patterns["ch_coords"] = coords.tolist()

# Snap ECoG channels onto the brain-surface mesh; all other channel types keep
# their (mirrored) coordinates.
ecog_idcs = patterns["ch_types"] == "ecog"
ecog_coords = project_to_mesh(
    np.array(patterns.loc[ecog_idcs, "ch_coords"].tolist())
)
patterns.loc[ecog_idcs, "ch_coords"] = pd.Series(
    ecog_coords.tolist(), index=patterns.index[ecog_idcs]
)

# Convert every channel's coordinates from metres to millimetres.
coords = np.array(patterns["ch_coords"].tolist()) * 1000
patterns["ch_coords"] = coords.tolist()

# For every subject, medication state, seed-target channel pair and frequency
# band: record the mean MIC weight of the pair and the hyperdirect fibres that
# pass within `ecog_radius` of the seed and `lfp_radius` of the target.
# Running maximum of per-pair fibre counts, used below to express counts as a
# percentage of the maximum.
max_n_fibres = 0

# Fibre atlas used for tracking (holographic hyperdirect pathway fibres).
fibre_tracking = TrackFibres(
    os.path.join(
        cd_path,
        "coherence",
        "fibre_atlases",
        "holographic_hyperdirect_filtered.mat",
    )
)
# One entry per (sub, med, seed_name, target_name, fband) combination.
avg_patterns = {
    "sub": [],
    "med": [],
    "seed_name": [],
    "target_name": [],
    "fband": [],
    "weights": [],
    "n_fibres": [],
    "fibre_ids": [],
}
# fband -> fibre_id -> list of the weights of every pair that reached the fibre.
fibre_weights = {fband: {} for fband in fbands}
for sub in patterns["sub"].unique():
    for med in ["Off", "On"]:
        # Seed/target names are stored as " & "-joined strings; take them from
        # the first row for this subject and medication state.
        # NOTE(review): `.values[0]` raises IndexError if a subject has no rows
        # for one of the medication states.
        seed_names = (
            patterns["seed_names"]
            .loc[(patterns["sub"] == sub) & (patterns["med"] == med)]
            .values[0]
            .split(" & ")
        )
        target_names = (
            patterns["target_names"]
            .loc[(patterns["sub"] == sub) & (patterns["med"] == med)]
            .values[0]
            .split(" & ")
        )
        for seed_name in seed_names:
            for target_name in target_names:
                for fband in fbands:
                    avg_patterns["sub"].append(sub)
                    avg_patterns["med"].append(med)
                    avg_patterns["seed_name"].append(seed_name)
                    avg_patterns["target_name"].append(target_name)
                    avg_patterns["fband"].append(fband)

                    # Row labels of the seed and target channels for this
                    # subject/med/band (the first matching row is used).
                    seed_i = patterns.loc[
                        (patterns["sub"] == sub)
                        & (patterns["med"] == med)
                        & (patterns["fband"] == fband)
                        & (patterns["ch_names"] == seed_name)
                    ].index[0]
                    target_i = patterns.loc[
                        (patterns["sub"] == sub)
                        & (patterns["med"] == med)
                        & (patterns["fband"] == fband)
                        & (patterns["ch_names"] == target_name)
                    ].index[0]
                    # Pair weight: np.mean over both channels' MIC weights
                    # (no axis given, so the result is a single scalar).
                    avg_patterns["weights"].append(
                        np.mean(
                            (
                                patterns["mean_weights"].loc[seed_i],
                                patterns["mean_weights"].loc[target_i],
                            )
                        )
                    )

                    # Add a leading axis so each single point is passed as a
                    # one-row array.
                    seed_coords = np.array(patterns["ch_coords"].loc[seed_i])[
                        None, :
                    ]
                    target_coords = np.array(
                        patterns["ch_coords"].loc[target_i]
                    )[None, :]

                    # NOTE(review): the meaning of the final positional `True`
                    # flag is defined by TrackFibres.find_within_radius —
                    # confirm there. The results appear to be per-pair
                    # sequences, so [0] selects the only pair passed in.
                    fibre_ids, n_fibres = fibre_tracking.find_within_radius(
                        seed_coords,
                        ecog_radius,
                        target_coords,
                        lfp_radius,
                        True,
                    )
                    avg_patterns["fibre_ids"].append(fibre_ids[0])
                    avg_patterns["n_fibres"].append(n_fibres[0])
                    max_n_fibres = np.max([max_n_fibres, n_fibres[0]])

                    # Attribute this pair's weight to every fibre it reaches.
                    for fibre_id in fibre_ids[0]:
                        if fibre_id not in fibre_weights[fband].keys():
                            fibre_weights[fband][fibre_id] = []
                        fibre_weights[fband][fibre_id].append(
                            avg_patterns["weights"][-1]
                        )

# Express fibre counts as a percentage of the largest count over all pairs.
# NOTE(review): if no pair reaches any fibre (max_n_fibres == 0) this is 0/0
# and every count becomes NaN.
avg_patterns["n_fibres"] = list(
    (np.array(avg_patterns["n_fibres"]) / max_n_fibres) * 100
)

summed_patterns = pd.DataFrame.from_dict(avg_patterns)

In [5]:
# Average of MIC weights per fibre across subjects (for plotting)
# Collapse each fibre's list of pair weights to its mean, per band.
for fband in fibre_weights.keys():
    fibre_weights[fband] = {
        fibre_id: np.mean(weights)
        for fibre_id, weights in fibre_weights[fband].items()
    }

# One row per fibre with one weight column per band. Each band's weights are
# aligned on the fibre ID rather than paired positionally with the first
# band's keys, so a weight can never be attributed to the wrong fibre. Fibres
# reached in only some bands get NaN in the other bands' columns.
fibre_weights_results = (
    pd.DataFrame(
        {
            f"{fband}_weights": pd.Series(fibre_weights[fband], dtype=float)
            for fband in fibre_weights.keys()
        }
    )
    .rename_axis("fibre_ids")
    .reset_index()
    .sort_values("fibre_ids")
    .reset_index(drop=True)
)
fibre_weights_results.to_csv(
    os.path.join(
        FOLDERPATH_ANALYSIS,
        "mic_patterns_summed_hyperdirect_fibres-MedOffOn.csv",
    ),
    index=False,
)

In [5]:
# LME models
models = {fband: None for fband in fbands}
results = {fband: None for fband in fbands}
for fband in fbands:
    models[fband] = MixedLM.from_formula(
        "n_fibres ~ weights + C(med)",
        summed_patterns.loc[(summed_patterns["fband"] == fband)].to_dict(
            "list"
        ),
        groups="sub",
    )
    results[fband] = models[fband].fit()

# Can view results with e.g. `results["high_beta"].summary()`

In [7]:
def compute_r2(model_result):
    """Compute the marginal and conditional R-squared of a fitted LME model.

    The total variance is decomposed as fixed-effect + random-effect +
    residual variance, in the style of Nakagawa & Schielzeth (2013):

    - marginal R^2    = var_fixed / total
    - conditional R^2 = (var_fixed + var_random) / total

    Parameters
    ----------
    model_result : statsmodels MixedLMResults (or compatible)
        Fitted model exposing ``scale`` (residual variance), ``cov_re``
        (random-effects covariance as a DataFrame) and ``predict()``
        (fixed-effects-only predictions for the fitted data).

    Returns
    -------
    marginal_r2 : float
        Variance explained by the fixed effects alone.
    conditional_r2 : float
        Variance explained by the fixed and random effects together.
    """
    var_resid = model_result.scale
    # Variance of the first random effect (the random intercept in a
    # random-intercept-only model). `.iloc[0, 0]` is purely positional; the
    # chained `.iloc[0][0]` relied on the deprecated positional fallback of
    # Series.__getitem__ on a label-indexed row.
    var_random_effect = model_result.cov_re.iloc[0, 0]
    var_fixed_effect = model_result.predict().var()

    total_var = var_fixed_effect + var_random_effect + var_resid
    marginal_r2 = var_fixed_effect / total_var
    conditional_r2 = (var_fixed_effect + var_random_effect) / total_var

    return marginal_r2, conditional_r2

In [None]:
# Plot information from LME models (Figure 4a inset)
# Top panel: LME-estimated vs. empirical fibre counts (goodness of fit).
# Bottom panel: estimated fibre counts vs. MIC weights, coloured by weight.
fband = "high_beta"


def _format_pval(pval, alpha=0.05):
    """Format a p-value for a figure annotation as "< alpha" or "= x.xx".

    Every input (including p == alpha and NaN) gets a comparison operator,
    so the annotation always reads "p < 0.05" or "p = 0.xx".
    """
    if pval < alpha:
        return f"< {alpha}"
    return f"= {pval :.2f}"


def _add_stats_box(ax, text):
    """Add a frameless text annotation in the lower-right corner of `ax`."""
    text_box = AnchoredText(text, frameon=False, loc=4, pad=0.1, borderpad=0.5)
    plt.setp(text_box.patch, facecolor="white", alpha=0.7)
    ax.add_artist(text_box)
    return text_box


# Rows of this band, in the same order as the data the LME was fitted on.
band_patterns = summed_patterns.loc[summed_patterns["fband"] == fband]
emp_nfibres = band_patterns["n_fibres"].to_numpy()
emp_weights = band_patterns["weights"].to_numpy()

# Clip negative fitted values to 0 (a fibre count cannot be negative).
est_nfibres = np.clip(results[fband].fittedvalues, 0, None)

fig, axs = plt.subplots(2, 1)

# Panel 1: empirical vs. estimated fibre counts, with a least-squares line.
a, b = np.polyfit(emp_nfibres, est_nfibres, 1)
axs[0].scatter(
    emp_nfibres,
    est_nfibres,
    marker="o",
    alpha=1,
    color=[0.4, 0.4, 0.4],
    facecolors="none",
    s=1,
)
axs[0].plot(emp_nfibres, a * emp_nfibres + b, color="k", linewidth=1)

pearson_r, pearson_pval = pearsonr(emp_nfibres, est_nfibres)
pearson_pval = _format_pval(pearson_pval)

model_r2 = compute_r2(results[fband])[1]  # conditional R^2

_add_stats_box(
    axs[0],
    f"R$^{2}$ = {model_r2 :.2f}\nr = {pearson_r :.2f}\np {pearson_pval}",
)

axs[0].set_xlabel("Empirical number of fibres\n(% maximum)")

axs[0].tick_params("x", pad=2, size=2)
axs[0].tick_params("y", pad=2, size=2)

axs[0].set_box_aspect(1)

# Panel 2: MIC weights vs. estimated fibre counts, coloured by weight.
a, b = np.polyfit(emp_weights, est_nfibres, 1)
cmap = plt.get_cmap("viridis")
normalise = matplotlib.colors.Normalize(
    vmin=np.min(emp_weights), vmax=np.max(emp_weights)
)
colours = cmap(normalise(emp_weights))
axs[1].scatter(
    emp_weights, est_nfibres, marker="o", alpha=1, color=colours, s=1
)
axs[1].plot(emp_weights, a * emp_weights + b, color="k", linewidth=1)

model_pval = _format_pval(results[fband].pvalues["weights"])
_add_stats_box(
    axs[1],
    fr"$\beta$ = {results[fband].params['weights'] :.2f}" + f"\np {model_pval}",
)

axs[1].set_xlabel("Contribution to\nconnectivity (A.U.)")
axs[1].set_ylabel("Estimated number of fibres (% maximum)")

axs[1].set_xticks((emp_weights.min(), emp_weights.max()))
axs[1].set_xticklabels(["Low", "High"])

axs[1].tick_params("x", pad=2, size=2)
axs[1].tick_params("y", pad=2, size=2)

axs[1].set_box_aspect(1)

axs[0].xaxis.labelpad = 0.5
axs[1].xaxis.labelpad = 0
axs[1].yaxis.labelpad = 1

fig.set_size_inches(1.3, 3.5)
fig.savefig(
    os.path.join(
        FOLDERPATH_FIGURES,
        f"Manuscript_StrucConn_EmpEstWeights_{fband}.pdf",
    )
)