# Running a cluster-based permutation test

In [None]:
from pathlib import Path
from pqdm.processes import pqdm
import numpy as np
import matplotlib.pyplot as plt
import mne


# Only show MNE messages of level WARNING and above.
mne.set_log_level('WARNING')

# Number of worker processes used by the parallelised steps below.
# Should match the number of your (high-performance) CPU cores.
N_JOBS = 7

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

## Calculate evoked contrasts

The tests we're going to run are "one-sample" t-tests against zero. The input is the difference between the two conditions we wish to compare, and the test evaluates whether this difference is significantly different from zero (H1) or not (H0).

In [None]:
# IDs (without the "sub-" prefix) of the participants in the group analysis.
subjects = [
    '04', '06', '07', '08', '09',
    '13', '14', '15', '16', '18', '19',
    '20', '21', '22', '23', '24', '25', '26', '27', '28', '29',
    '30', '32',
]

# Selection strings passed to `epochs[...]`; the contrast is computed as
# contrast[0] minus contrast[1].
contrast = ['Cond == 2', 'Cond == 4']


def read_cleaned_epochs(subject):
    """Load the cleaned epochs of a single subject.

    Parameters
    ----------
    subject : str
        Subject ID without the ``sub-`` prefix, e.g. ``'04'``.

    Returns
    -------
    The epochs object returned by ``mne.read_epochs`` for the file
    ``sub-<subject>_task-sparse_metadata_proc-clean_epo.fif`` inside
    ``lb_analysis_sparse/import_metadata/sub-<subject>/meg/``.
    """
    return mne.read_epochs(
        Path(f'lb_analysis_sparse/import_metadata/sub-{subject}/meg/')
        / f'sub-{subject}_task-sparse_metadata_proc-clean_epo.fif'
    )


def compute_evoked_contrast(subject):
    """Return one subject's difference wave between the two conditions.

    The subject's cleaned epochs are averaged separately for ``contrast[0]``
    and ``contrast[1]`` (the module-level selection strings). The second
    average is then subtracted from the first.

    Parameters
    ----------
    subject : str
        Subject ID as accepted by ``read_cleaned_epochs``.

    Returns
    -------
    The evoked produced by ``mne.combine_evoked`` with weights ``[1, -1]``,
    i.e. ``average(contrast[0]) - average(contrast[1])``.
    """
    epochs = read_cleaned_epochs(subject)

    first_condition, second_condition = contrast[0], contrast[1]
    evoked_first = epochs[first_condition].average()
    evoked_second = epochs[second_condition].average()

    return mne.combine_evoked([evoked_first, evoked_second], weights=[1, -1])


# Compute the per-subject contrasts with a progress bar. n_jobs=1 processes
# the subjects one after another; raise it (e.g. to N_JOBS) to load subjects
# in parallel.
evoked_contrasts_all_subjects = pqdm(
    subjects,
    function=compute_evoked_contrast,
    n_jobs=1
)

# When running in parallel, pqdm's default exception behaviour ("ignore")
# returns a worker's exception as a list item instead of raising it. Fail
# loudly here, so a broken subject is neither counted as a calculated
# contrast nor allowed to crash later cells with an unrelated error.
failed_subjects = {
    subject: repr(result)
    for subject, result in zip(subjects, evoked_contrasts_all_subjects)
    if isinstance(result, BaseException)
}
if failed_subjects:
    raise RuntimeError(
        f'Could not compute contrasts for subjects: {failed_subjects}'
    )

print(f'Calculated {len(evoked_contrasts_all_subjects)} contrasts!')

In [None]:
ch_type = 'mag'  # we only use magnetometers – data was Maxwell-filtered
tmin = 0
tmax = None  # end of the evoked

# Crop every subject's contrast to the analysis window and keep only the
# selected channel type. Copies are used so the originals stay untouched.
all_evoked_contrasts_cropped = [
    evoked.copy().crop(tmin=tmin, tmax=tmax).pick(ch_type)
    for evoked in evoked_contrasts_all_subjects
]

# Stack all subjects into one array for the cluster test. Each subject's data
# is transposed so that time comes before channels.
all_data = np.array(
    [evoked.data.T for evoked in all_evoked_contrasts_cropped]
)

# Measurement info and time axis are taken from the first participant.
first_evoked = all_evoked_contrasts_cropped[0]
info = first_evoked.info
times = first_evoked.times

# Channel adjacency used to form spatial clusters.
adjacency, ch_names = mne.channels.find_ch_adjacency(info=info, ch_type=ch_type)

all_data.shape, adjacency.shape, times.shape

In [None]:
# Shape of one subject's data (time first, then channels): drop the subject axis.
all_data.shape[1:]

In [None]:
from mne.stats import permutation_cluster_1samp_test, spatio_temporal_cluster_1samp_test

# Permutation-test settings ('all' is an alternative value for
# n_permutations).
n_permutations = 10_000
tail = 0
threshold = None

cluster_test_kwargs = dict(
    threshold=threshold,
    n_jobs=N_JOBS,
    tail=tail,
    n_permutations=n_permutations,
    adjacency=adjacency,
    out_type='mask',
    seed=42,  # make results reproducible
    verbose=True,
)
cluster_stats = spatio_temporal_cluster_1samp_test(all_data, **cluster_test_kwargs)

# Observed statistics, cluster masks and one p-value per cluster.
T_obs, clusters, p_values, _ = cluster_stats

In [None]:
# Number of clusters returned by the test, significant or not (significance is filtered below).
len(clusters)

In [None]:
# p_values is indexed in parallel with clusters, so this should equal len(clusters).
len(p_values)

In [None]:
p_threshold = 0.05

# Positions (in `clusters` / `p_values`) of clusters below the p threshold.
significant_clusters_indices = np.flatnonzero(p_values < p_threshold)
significant_clusters_indices

In [None]:
# Masks of the significant clusters, in the same order as their indices.
significant_clusters = [clusters[idx] for idx in significant_clusters_indices]

significant_clusters

In [None]:
# All following cells inspect the first significant cluster. Stop with an
# explicit message instead of an opaque IndexError when none survived.
if not significant_clusters:
    raise RuntimeError(
        f'No cluster reached p < {p_threshold}; nothing further to inspect.'
    )

# Boolean mask of the first significant cluster (time points x channels).
significant_clusters[0].shape

In [None]:
import pandas as pd

# Boolean time x channel table of the first significant cluster; True marks
# the (time, channel) samples that belong to it.
cluster_df = pd.DataFrame(
    data=significant_clusters[0],
    index=pd.Index(times, name='time'),
    columns=info.ch_names,
)

# Copy the table to the system clipboard, e.g. for pasting into a spreadsheet.
cluster_df.to_clipboard()

## Look at T-values for a single channel

In [None]:
ch_name = 'MEG1031'
ch_idx = info.ch_names.index(ch_name)

# T_obs has time on its first axis, so a column is one channel's time course.
t_values = T_obs[:, ch_idx]

fig, ax = plt.subplots()
ax.plot(times, t_values)
ax.set_title(f'T-values for {ch_name}')
ax.set_xlabel('Time (s)')
ax.set_ylabel('T-value')

## Trim the result matrix to the rows (time points) and columns (channels) that contain at least one cell belonging to the first significant cluster

In [None]:
# Keep only the cluster's True cells (all others become NaN), then drop the
# channels (columns) and time points (rows) lying entirely outside it.
cluster_df_cleaned = cluster_df.where(cluster_df)
cluster_df_cleaned = cluster_df_cleaned.dropna(axis='columns', how='all')
cluster_df_cleaned = cluster_df_cleaned.dropna(axis='index', how='all')
cluster_df_cleaned

In [None]:
# Channels and time points touched by the cluster.
cluster_channels, cluster_times = cluster_df_cleaned.columns, cluster_df_cleaned.index

In [None]:
# Every sample of the time axis from the cluster's first to its last time point.
in_cluster_window = (times >= cluster_times[0]) & (times <= cluster_times[-1])
time_indices = np.flatnonzero(in_cluster_window)
time_indices

In [None]:
# Position of each cluster channel in the cropped data's channel list.
channel_indices = np.array(
    [info.ch_names.index(name) for name in cluster_channels]
)
channel_indices

In [None]:
# T-values of all channels, restricted to the cluster's time window.
T_vals_cluster_time = T_obs[time_indices]
T_vals_cluster_time.shape

In [None]:
# Average over the time axis, leaving one T-value per channel.
T_vals_cluster_time_average = np.mean(T_vals_cluster_time, axis=0)
T_vals_cluster_time_average.shape

In [None]:
# Sensor layout for the selected channel type. `channel_positions` is only
# inspected here; the topomap below takes sensor positions from `info`.
channel_layout = mne.find_layout(info, ch_type=ch_type)
channel_positions = channel_layout.pos
channel_positions.shape

In [None]:
# True for every channel that is part of the cluster (highlighted on the map).
channel_mask = np.isin(info.ch_names, cluster_channels)


fig, ax = plt.subplots()
title = f'Mean T-values\n{cluster_times[0]} — {cluster_times[-1]} sec'
ax.set_title(title, fontweight='bold')

# Topography of the T-values averaged over the cluster's time window.
mne.viz.plot_topomap(
    data=T_vals_cluster_time_average,
    pos=info,
    extrapolate="head",
    mask=channel_mask,
    ch_type=ch_type,
    axes=ax
);


In [None]:
fig, ax = plt.subplots()

# Per-subject contrasts, averaged across the channels that form the cluster.
cluster_evokeds = {'Cluster #1': evoked_contrasts_all_subjects}
mne.viz.plot_compare_evokeds(
    evokeds=cluster_evokeds,
    picks=list(cluster_channels),
    combine='mean',
    axes=ax,
    truncate_yaxis=False,
    truncate_xaxis=False,
    show=False,
)

# Shade the cluster's time window and mark its edges.
cluster_start, cluster_stop = cluster_times[0], cluster_times[-1]
ax.fill_betweenx(ax.get_ylim(), cluster_start, cluster_stop, color='orange', alpha=0.2)
for boundary in (cluster_start, cluster_stop):
    ax.axvline(boundary, ls='--', color='black')
ax.set_title(f'Mean of {len(cluster_channels)} channels in cluster')
