In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
from nilearn import connectome, datasets, maskers, plotting
from src import fmriprep

# Silence FutureWarning messages for the rest of the notebook.
warnings.simplefilter("ignore", category=FutureWarning)

# Example preprocessed BOLD run shipped with the test data.
ex_preproc = "../tests/data/sub-c016_task-h2_space-MNI152NLin6Asym_res-3_desc-preproc_bold.nii.gz"

# Project wrapper around the run; later cells read .mask, .tr, .confounds,
# .sample_mask and .preproc from it.
data = fmriprep.Data(denoise_strategy="HMPWMCSFScrubGS", preproc=ex_preproc)

# Schaefer 2018 parcellation: 100 ROIs, 17 Yeo networks, 2 mm resolution.
atlas = datasets.fetch_atlas_schaefer_2018(yeo_networks=17, resolution_mm=2, n_rois=100)
atlas_img = atlas["maps"]
# Labels are decoded from bytes and their first 11 characters dropped
# (presumably the "17Networks_" prefix -- TODO confirm against the fetched labels).
atlas_labels = [raw_label.decode()[11:] for raw_label in atlas["labels"]]

plotting.plot_roi(atlas_img, draw_cross=False)


#### extract the mean timeseries within each ROI
- zscore timeseries and confounds
- `nilearn.signal.clean`: 'confounds removal is based on a projection on the orthogonal of the signal space'

In [None]:
# Labels masker that reduces each atlas ROI to a single signal (strategy="mean").
# Voxels are restricted to the run's mask and smoothed (6 mm FWHM); signals and
# confounds are standardized; intermediate results are cached under "tmp".
masker = maskers.NiftiLabelsMasker(
    labels_img=atlas_img,
    labels=atlas_labels,
    mask_img=data.mask,
    t_r=data.tr,
    smoothing_fwhm=6,
    strategy="mean",
    standardize="zscore",
    standardize_confounds=True,
    memory="tmp",
)

# Extract the ROI signals, passing the run's confounds and sample mask.
roi_timeseries = masker.fit_transform(
    data.preproc,
    confounds=data.confounds,
    sample_mask=data.sample_mask,
)

roi_timeseries.shape


#### functional connectivity: correlation
- correlation and partial correlation
- **TODO**: add [xDF correction](https://github.com/asoroosh/xDF)

In [None]:
def fisher_z(corr_coefficients):
    """Apply the Fisher r-to-z transform (inverse hyperbolic tangent).

    Parameters
    ----------
    corr_coefficients : scalar or array-like
        Correlation coefficient(s); the transform is applied element-wise.

    Returns
    -------
    The result of ``np.arctanh`` on the input. Coefficients of exactly
    +/-1 map to +/-inf.
    """
    z_values = np.arctanh(corr_coefficients)
    return z_values


# Connectivity estimators using ConnectivityMeasure's default covariance
# estimator (LedoitWolf, per the nilearn documentation).
corr_model = connectome.ConnectivityMeasure(kind="correlation")
pcorr_model = connectome.ConnectivityMeasure(kind="partial correlation")

# fit_transform works on a list of subjects: wrap the single run, then unwrap.
corr_coefficients = corr_model.fit_transform([roi_timeseries])[0]
pcorr_coefficients = pcorr_model.fit_transform([roi_timeseries])[0]

# The diagonal holds each ROI's correlation with itself (1.0), which
# arctanh maps to inf (with a divide-by-zero RuntimeWarning). Zero it on
# copies so the raw coefficient matrices stay untouched and the z-matrices
# stay finite.
corr_offdiag = corr_coefficients.copy()
pcorr_offdiag = pcorr_coefficients.copy()
np.fill_diagonal(corr_offdiag, 0)
np.fill_diagonal(pcorr_offdiag, 0)

corr_z_transform = fisher_z(corr_offdiag)
pcorr_z_transform = fisher_z(pcorr_offdiag)

# plotting: both matrices share one symmetric color scale
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

kwargs = dict(vmin=-1, vmax=1, cmap="RdBu_r")

plotting.plot_matrix(corr_z_transform, axes=axs[0], colorbar=False, **kwargs)
plotting.plot_matrix(pcorr_z_transform, axes=axs[1], **kwargs)

axs[0].set_title("correlation")
axs[1].set_title("partial correlation");


#### functional connectivity: PPI

In [None]:
# Design matrix from the project helper, requested with the
# "glover + derivative" HRF model and drop_constant=True.
dm = data.make_design_matrix(
    drop_constant=True,
    hrf_model="glover + derivative",
)
# one stacked panel per design-matrix column
dm.plot(figsize=(20, 2), subplots=True)


In [None]:
# Put the ROI signals in a frame indexed like the design matrix, then keep
# three ROI columns drawn at random with a fixed seed.
roi_ts_frame = pd.DataFrame(data=roi_timeseries, index=dm.index, columns=atlas_labels)
roi_ts_sample = roi_ts_frame.sample(n=3, axis="columns", random_state=5)

# Pairwise scatter of the sampled ROIs together with the "stim" regressor.
# NOTE(review): `size` in plot_kws is forwarded to the scatter function, where it
# is a semantic variable rather than the marker size (`s`) -- confirm the intent.
pairplot_data = roi_ts_sample.join(dm["stim"])
sns.pairplot(pairplot_data, height=1.5, plot_kws=dict(size=1, alpha=0.5))

#### linear model summary

In [None]:
# OLS with a stim x seed-ROI interaction (the PPI term) plus the stim
# derivative regressor. The ROI column names must be among the columns
# sampled into roi_ts_sample.
ppi_formula = "LH_DefaultB_PFCv_2 ~ (stim * LH_ContA_PFCl_2) + stim_derivative"
lm = smf.ols(data=dm.join(roi_ts_sample), formula=ppi_formula).fit()
print(lm.summary())

#### model diagnostics: residual plots

In [None]:
# Residuals in index order, drawn as a wide, short strip.
lm.resid.plot(title="residuals", figsize=(20, 1))

In [None]:
# Residual diagnostics: residuals-vs-fitted with a LOWESS trend (left) and
# a normal Q-Q plot of the residuals (right).
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
# x and y must be passed by keyword: positional use is deprecated in
# seaborn 0.11 and raises a TypeError from seaborn 0.12 onward.
sns.regplot(x=lm.fittedvalues, y=lm.resid, lowess=True, ax=axs[0])
axs[0].set_xlabel("Fitted Values")
axs[0].set_ylabel("Residuals")
sm.qqplot(lm.resid, line="s", ax=axs[1]);

#### temporal autocorrelation

In [None]:
# Serial dependence of the residuals: lag-1 scatter (left) and the
# autocorrelation plot (right).
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
ax_lag, ax_acf = axs
pd.plotting.lag_plot(lm.resid, lag=1, ax=ax_lag)
pd.plotting.autocorrelation_plot(lm.resid, ax=ax_acf)