In [15]:
%matplotlib qt
import sys
import mne
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sys.path.append(r"C:\Users\gautier\OneDrive - CentraleSupelec\3A - Master CNN\Supervised Project\pipeline project v0\scripts")
sys.path.append(r"C:\Users\gautier\OneDrive - CentraleSupelec\3A - Master CNN\Supervised Project\pipeline project v0\config")
import eeg_preprocessing as preprocessing
import eeg_decoding as decoding
import vst_config as config

from mne.stats import spatio_temporal_cluster_1samp_test
from mne.channels import find_ch_adjacency, make_1020_channel_selections

from mpl_toolkits.axes_grid1 import make_axes_locatable
from mne.viz import plot_compare_evokeds

In [16]:
# --- Analysis parameters shared by the cells below ---

# Selectors handed to the epoch loader
epochs_type, pick_type = "start_prod", "all_channels"

# Epoch window settings (presumably in seconds -- TODO confirm against the loader)
baseline_duration, epoch_duration = 0.8, 4

# Default subject; the group-level cells rebind `subject` in their loops
subject = "VST_02_X"

# Initial lower frequency bound; the beta-band cells reassign it to 15
low_freq = 4

In [17]:
# Folder receiving the HTML cluster-analysis reports (one sub-folder per epochs type)
cluster_analysis_report_dir = f'../../reports/cluster_analysis/{epochs_type}/'

Spatio-temporal cluster analysis of the power time-course data over all epochs

In [23]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mne.viz import plot_compare_evokeds
from mne.channels import find_ch_adjacency


def compute_and_plot_stca(epochs, report, report_out, p_accept = 0.05, n_perm = 1024, seed = None):
    """Run a one-sample spatio-temporal cluster test on ``epochs`` and report it.

    The epochs data are re-centred at every (epoch, time) sample by subtracting
    the mean across channels, then tested against zero with
    ``spatio_temporal_cluster_1samp_test``. One figure per retained cluster
    (topography of the time-averaged statistic plus the time course of the
    cluster's channels) is added to ``report``, which is finally written to
    ``report_out``.

    Parameters
    ----------
    epochs : MNE epochs-like object
        Must provide ``info``, ``get_data()`` and ``average()``.
    report : MNE report-like object
        Must provide ``add_figure()`` and ``save()``.
    report_out : str
        File name handed to ``report.save``.
    p_accept : float
        Clusters with a p-value strictly below this value count as significant.
    n_perm : int
        Number of permutations of the cluster test.
    seed : None | int | numpy RandomState
        Forwarded to the cluster test so the permutations can be reproduced.
        ``None`` (default) keeps the previous unseeded behaviour.

    Returns
    -------
    cluster_stats : tuple | None
        ``(T_obs, clusters, p_values, H0)`` as returned by the cluster test, or
        ``None`` when the test produced no cluster at all (nothing is plotted
        or saved in that case).
    """
    # Sensor adjacency used to cluster neighbouring channels
    adjacency, _ = find_ch_adjacency(epochs.info, "eeg")

    # Swap the last two axes so that time comes before channels, as the
    # cluster test expects. copy=False is safe: the subtraction below builds a
    # new array, so the data stored in ``epochs`` are never modified.
    epochs_data = epochs.get_data(copy=False).transpose(0, 2, 1)

    # Centre every (epoch, time) sample on its mean over the last axis, i.e.
    # across channels after the transpose (vectorised form of the former
    # np.apply_along_axis call).
    # NOTE(review): the earlier comment spoke of the mean power "over each
    # electrode"; confirm that a mean across channels (and not across time) is
    # what is intended.
    epochs_data = epochs_data - epochs_data.mean(axis=2, keepdims=True)
    evo = epochs.average()

    cluster_stats = spatio_temporal_cluster_1samp_test(epochs_data,
                                                        n_permutations=n_perm,
                                                        threshold=None,
                                                        n_jobs=-1,
                                                        seed=seed,
                                                        adjacency=adjacency)
    T_obs, clusters, p_values, H0 = cluster_stats
    if len(clusters) == 0:
        print('No significant clusters found for this predictor')
        return

    print(p_values)
    good_cluster_inds = np.where(p_values < p_accept)[0]
    print(str(len(good_cluster_inds)) + ' sign. clusters')

    p_min = np.min(p_values)

    # Sanity check: when nothing is significant, still display the
    # non-significant cluster(s) whose p-value lies within 0.01 of the minimum.
    if len(good_cluster_inds) == 0:
        good_cluster_inds = np.where(p_values < p_min+0.01)[0]
        print(str(len(good_cluster_inds)) + ' n.s. cluster min p')

    # Loop over the retained clusters, plot each one and add it to the report
    for i_clu, clu_idx in enumerate(good_cluster_inds):

        # p-value of this very cluster (not the global minimum)
        cluster_p = p_values[clu_idx]

        # unpack cluster information, get unique indices
        time_inds, space_inds = np.squeeze(clusters[clu_idx])
        ch_inds     = np.unique(space_inds)
        time_inds   = np.unique(time_inds)

        # topography of the statistic averaged over the cluster's time points
        # (the default statistic of the one-sample test is a t-value)
        t_map = T_obs[time_inds, ...].mean(axis=0)

        # time points spanned by the cluster
        sig_times = evo.times[time_inds]

        # spatial mask marking the channels that belong to the cluster
        mask = np.zeros((t_map.shape[0], 1), dtype=bool)
        mask[ch_inds, :] = True

        # initialize figure
        fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3))

        # plot the averaged test statistic and mark the cluster's sensors;
        # a single-sample Evoked at t=0 is only a container for the topomap
        t_evoked = mne.EvokedArray(t_map[:, np.newaxis], epochs.info, tmin=0)

        t_evoked.plot_topomap(times=0, mask=mask, axes=ax_topo, cmap='Greys',
                            show=False,
                            colorbar=False, mask_params=dict(markersize=10))
        image = ax_topo.images[0]

        # create additional axes
        divider = make_axes_locatable(ax_topo)

        # add axes for colorbar
        ax_colorbar = divider.append_axes('right', size='5%', pad=0.05)
        plt.colorbar(image, cax=ax_colorbar)
        ax_topo.set_xlabel(
            'Averaged t-map ({:0.3f} - {:0.3f} s)'.format(*sig_times[[0, -1]]))

        # add new axis for time courses and plot time courses
        ax_signals = divider.append_axes('right', size='300%', pad=1.2)

        title = 'Cluster #{0}, {1} sensor, p value {2}'.format(i_clu + 1, len(ch_inds), cluster_p)

        plot_compare_evokeds([evo], ci = True, title=title,
                            picks=ch_inds, axes=ax_signals,
                            legend = 'upper right', split_legend=True,
                            truncate_yaxis='auto')

        # shade the temporal extent of the cluster
        ymin, ymax = ax_signals.get_ylim()
        ax_signals.fill_betweenx((ymin, ymax), sig_times[0], sig_times[-1],
                                color='grey', alpha=0.3)

        # clean up viz
        fig.subplots_adjust(bottom=.05)

        # Label the report entry according to the cluster's own significance
        report_prefix = 'sign cluster_' if cluster_p < p_accept else 'n.s. cluster_'
        report.add_figure(fig, title=(report_prefix + str(cluster_p)))
        # close this figure explicitly rather than whichever one is current
        plt.close(fig)

    report.save(fname=report_out, open_browser=False, overwrite=True)
    return cluster_stats
                

In [19]:
from mne.time_frequency import (
    tfr_array_morlet,
)

def get_power_epochs(epochs, power_freq_low_bound, power_freq_high_bound):
    """Return the band-averaged Morlet power time course of ``epochs``.

    Power is computed with ``tfr_array_morlet`` for every integer-spaced
    frequency from ``power_freq_low_bound`` up to, but not including,
    ``power_freq_high_bound`` and then averaged over the frequency axis.

    Parameters
    ----------
    epochs : MNE epochs-like object
        Must provide ``get_data()``, ``info``, ``tmin`` and ``metadata``.
    power_freq_low_bound : number
        First frequency of the band (included).
    power_freq_high_bound : number
        Upper bound of the band (excluded, because of ``np.arange``).

    Returns
    -------
    mne.EpochsArray
        Power time course (epochs x channels x times) sharing the info, the
        start time and the metadata of ``epochs``.
    """
    # np.arange excludes the upper bound: unit steps over [low, high)
    frequencies = np.arange(power_freq_low_bound, power_freq_high_bound)
    cycles_per_freq = 6

    # Use the sampling frequency stored with the data itself instead of a
    # global configuration value, so the wavelets always match the signal.
    power = tfr_array_morlet(
        epochs.get_data(copy=False),
        freqs=frequencies,
        n_cycles=cycles_per_freq,
        output="power",
        sfreq=epochs.info["sfreq"],
    )
    # Average over the frequency axis (axis 2 of the returned power array)
    power_timecourse = np.mean(power, axis=2)

    # Wrap as epochs; pass tmin so the time axis keeps the original epoch
    # onset instead of restarting at 0, and carry the metadata over.
    epochs_power_timecourse = mne.EpochsArray(power_timecourse, info=epochs.info, tmin=epochs.tmin)
    epochs_power_timecourse.metadata = epochs.metadata

    return epochs_power_timecourse

Cluster analysis of the beta-band power time course over all epochs, pooled across subjects

In [24]:
from mne.report import Report

# Load the epochs of every condition for the selected subjects
epochs_list = []
# NOTE(review): only the first two preprocessed subjects are loaded here,
# although the report file name says "all_subjects" -- confirm that the [:2]
# slice is intended and not a debugging leftover.
for subject in config.preprocessed_subjects_list[:2]:
    for condition in config.conditions:
        epochs = preprocessing.load_subject_epochs_by_type_and_condition(
            subject, condition, epochs_type, baseline_duration, epoch_duration, pick_type, verbose=False
        )
        epochs = preprocessing.change_bad_channels(epochs)
        epochs_list.append(epochs)

# Frequency bounds handed to get_power_epochs (upper bound excluded there)
low_freq = 15
high_freq = 30
all_epochs = mne.concatenate_epochs(epochs_list)
power_epochs = get_power_epochs(all_epochs, low_freq, high_freq)

report_out  = cluster_analysis_report_dir + 'beta_cluster_analysis_all_subjects.html'
os.makedirs(cluster_analysis_report_dir, exist_ok=True)
# The first positional parameter of Report is `info_fname`, not the output
# file; the destination is given to report.save() inside compute_and_plot_stca.
report      = Report(title='Beta power cluster analysis', verbose=False)

# Keep the EEG channels only before running the cluster test
power_epochs.pick_types(eeg=True)

cluster_stats = compute_and_plot_stca(power_epochs, report, report_out, n_perm=500)

Adding metadata with 1 columns
1549 matching events found
No baseline correction applied


  power = tfr_array_morlet(epochs.get_data(), freqs=frequencies, n_cycles=cycles_per_freq, output="power", sfreq=config.resample_sfreq)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:   12.1s


Not setting metadata
1549 matching events found
No baseline correction applied
0 projection items activated
Adding metadata with 1 columns
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Could not find a adjacency matrix for the data. Computing adjacency based on Delaunay triangulations.
-- number of adjacent vertices : 46


  epochs_data = epochs.get_data().transpose(0, 2, 1)


Using a threshold of 1.961498
stat_fun(H1): min=-63.425993 max=27.755327
Running initial clustering …
Found 2 clusters


100%|██████████| Permuting : 499/499 [03:58<00:00,    2.09it/s]


[0.002 0.002]
2 sign. clusters
combining channels using "gfp"
combining channels using "gfp"
Overwriting existing file.
Saving report to : C:\Users\gautier\OneDrive - CentraleSupelec\3A - Master CNN\Supervised Project\pipeline project v0\reports\cluster_analysis\start_prod\beta_cluster_analysis_all_subjects.html


In [34]:
# Print the channels belonging to every cluster returned by the test.
# Looping avoids the hard-coded cluster indices 0 and 1, which raised an
# IndexError when fewer than two clusters were found, and the guard covers
# compute_and_plot_stca returning None when there is no cluster at all.
ch_names = power_epochs.info["ch_names"]
found_clusters = cluster_stats[1] if cluster_stats is not None else []

for cluster_number, cluster in enumerate(found_clusters, start=1):
    print("channel names of cluster " + str(cluster_number))
    # second element of a cluster = channel index of each of its samples
    mask = cluster[1]
    cluster_ch_inds = set(np.unique(mask).tolist())
    channels = [ch for i, ch in enumerate(ch_names) if i in cluster_ch_inds]
    print("There are " + str(len(channels)) + " channels in the cluster:")
    print(channels)

channel names of cluster 1
There are 24 channels in the cluster:
['Fp1', 'Fp2', 'F7', 'FC5', 'M2', 'P7', 'P3', 'P4', 'P8', 'POz', 'O1', 'O2', 'AF8', 'P5', 'P6', 'PO5', 'PO3', 'PO4', 'PO6', 'FT8', 'TP8', 'PO7', 'PO8', 'Oz']
channel names of cluster 2
There are 24 channels in the cluster:
['Fpz', 'F3', 'Fz', 'F4', 'FC5', 'FC1', 'FC2', 'FC6', 'C4', 'CP5', 'CP1', 'CP6', 'P3', 'AF4', 'F1', 'F2', 'FCz', 'C5', 'C1', 'C2', 'C6', 'CP3', 'P1', 'P2']


Cluster analysis of the beta-band power time course: difference between the long- and short-duration conditions

In [16]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mne.viz import plot_compare_evokeds
from mne.channels import find_ch_adjacency
from mne.stats import spatio_temporal_cluster_test


def compute_and_plot_stca_2_conditions(epochs1, epochs2, report, report_out, p_accept = 0.05, n_perm = 1024, seed = None):
    """Compare two sets of epochs with a spatio-temporal cluster test and report it.

    Both inputs are restricted to their EEG channels and compared with
    ``spatio_temporal_cluster_test``. One figure per retained cluster
    (topography of the time-averaged statistic plus the time course of the
    ``epochs1 - epochs2`` evoked difference on the cluster's channels) is added
    to ``report``, which is written to ``report_out`` once all figures exist.

    Parameters
    ----------
    epochs1, epochs2 : MNE epochs-like objects
        The two conditions to compare. NOTE: ``pick_types`` is called on both
        objects themselves, so the caller's instances are reduced to their EEG
        channels.
    report : MNE report-like object
        Must provide ``add_figure()`` and ``save()``.
    report_out : str
        File name handed to ``report.save``.
    p_accept : float
        Clusters with a p-value strictly below this value count as significant.
    n_perm : int
        Number of permutations of the cluster test.
    seed : None | int | numpy RandomState
        Forwarded to the cluster test so the permutations can be reproduced.
        ``None`` (default) keeps the previous unseeded behaviour.

    Returns
    -------
    None
        When the test produces no cluster, nothing is plotted or saved.
    """
    # Restrict both conditions to EEG channels and derive the sensor adjacency
    epochs1.pick_types(eeg=True)
    epochs2.pick_types(eeg=True)
    adjacency, _ = find_ch_adjacency(epochs1.info, "eeg")

    # One array per condition, with the last two axes swapped so that time
    # comes before channels, as the cluster test expects
    X = [
        epochs1.get_data(copy=False).transpose(0, 2, 1),
        epochs2.get_data(copy=False).transpose(0, 2, 1),
    ]

    # Evoked difference between the conditions (epochs1 minus epochs2)
    evo = mne.combine_evoked(
        [epochs1.average(), epochs2.average()], weights=[1, -1]
    )

    cluster_stats = spatio_temporal_cluster_test(X,
                                                    n_permutations=n_perm,
                                                    threshold=None,
                                                    n_jobs=-1,
                                                    seed=seed,
                                                    adjacency=adjacency)
    T_obs, clusters, p_values, H0 = cluster_stats
    if len(clusters) == 0:
        print('No significant clusters found for this predictor')
        return

    print(p_values)
    good_cluster_inds = np.where(p_values < p_accept)[0]
    print(str(len(good_cluster_inds)) + ' sign. clusters')

    p_min = np.min(p_values)

    # Sanity check: when nothing is significant, still display the
    # non-significant cluster(s) whose p-value lies within 0.01 of the minimum.
    if len(good_cluster_inds) == 0:
        good_cluster_inds = np.where(p_values < p_min+0.01)[0]
        print(str(len(good_cluster_inds)) + ' n.s. cluster min p')

    # Loop over the retained clusters, plot each one and add it to the report
    for i_clu, clu_idx in enumerate(good_cluster_inds):

        # p-value of this very cluster (not the global minimum)
        cluster_p = p_values[clu_idx]

        # unpack cluster information, get unique indices
        time_inds, space_inds = np.squeeze(clusters[clu_idx])
        ch_inds     = np.unique(space_inds)
        time_inds   = np.unique(time_inds)

        # topography of the F statistic averaged over the cluster's time points
        f_map = T_obs[time_inds, ...].mean(axis=0)

        # time points spanned by the cluster
        sig_times = evo.times[time_inds]

        # spatial mask marking the channels that belong to the cluster
        mask = np.zeros((f_map.shape[0], 1), dtype=bool)
        mask[ch_inds, :] = True

        # initialize figure
        fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3))

        # plot the averaged test statistic and mark the cluster's sensors;
        # a single-sample Evoked at t=0 is only a container for the topomap
        f_evoked = mne.EvokedArray(f_map[:, np.newaxis], epochs1.info, tmin=0)

        f_evoked.plot_topomap(times=0, mask=mask, axes=ax_topo, cmap='Greys',
                            show=False,
                            colorbar=False, mask_params=dict(markersize=10))
        image = ax_topo.images[0]

        # create additional axes
        divider = make_axes_locatable(ax_topo)

        # add axes for colorbar
        ax_colorbar = divider.append_axes('right', size='5%', pad=0.05)
        plt.colorbar(image, cax=ax_colorbar)
        ax_topo.set_xlabel(
            'Averaged F-map ({:0.3f} - {:0.3f} s)'.format(*sig_times[[0, -1]]))

        # add new axis for time courses and plot time courses
        ax_signals = divider.append_axes('right', size='300%', pad=1.2)

        title = 'Cluster #{0}, {1} sensor, p value {2}'.format(i_clu + 1, len(ch_inds), cluster_p)

        plot_compare_evokeds([evo], ci = True, title=title,
                            picks=ch_inds, axes=ax_signals,
                            legend = 'upper right', split_legend=True,
                            truncate_yaxis='auto')

        # shade the temporal extent of the cluster
        ymin, ymax = ax_signals.get_ylim()
        ax_signals.fill_betweenx((ymin, ymax), sig_times[0], sig_times[-1],
                                color='grey', alpha=0.3)

        # clean up viz
        fig.subplots_adjust(bottom=.05)

        # Label the report entry according to the cluster's own significance
        report_prefix = 'sign cluster_' if cluster_p < p_accept else 'n.s. cluster_'
        report.add_figure(fig, title=(report_prefix + str(cluster_p)))
        # close this figure explicitly rather than whichever one is current
        plt.close(fig)

    # Write the report once, after all figures were added (it used to be
    # rewritten on every loop iteration)
    report.save(fname=report_out, open_browser=False, overwrite=True)
                

In [17]:
from mne.report import Report

# Retrieve the power epochs of the long-duration and short-duration conditions separately

low_freq = 15
high_freq = 30

epochs_lists = {"L":[], "S":[]}

# Keep the condition pairs made of one condition whose name ends with "S" and
# one whose name ends with "L", in either order
duration_conditions = [
    conditions for conditions in config.condition_pairs
    if {conditions[0][-1], conditions[1][-1]} == {"S", "L"}
]

# Order each pair so that the condition ending with "L" always comes first
duration_conditions = [
    conditions if conditions[0][-1] == "L" else (conditions[1], conditions[0])
    for conditions in duration_conditions
]

print(duration_conditions)

# NOTE(review): the all-epochs analysis cell passes every loaded Epochs object
# through preprocessing.change_bad_channels before concatenation; that step is
# not applied here -- confirm whether it is needed.
for subject in config.preprocessed_subjects_list:
    for conditions in duration_conditions:
        epochs_long = preprocessing.load_subject_epochs_by_type_and_condition(
            subject, conditions[0], epochs_type, baseline_duration, epoch_duration, pick_type, verbose=False
        )
        epochs_lists["L"].append(epochs_long)

        epochs_short = preprocessing.load_subject_epochs_by_type_and_condition(
            subject, conditions[1], epochs_type, baseline_duration, epoch_duration, pick_type, verbose=False
        )
        epochs_lists["S"].append(epochs_short)

all_short_epochs = mne.concatenate_epochs(epochs_lists["S"])
all_long_epochs = mne.concatenate_epochs(epochs_lists["L"])

all_short_power_epochs = get_power_epochs(all_short_epochs, low_freq, high_freq)
all_long_power_epochs = get_power_epochs(all_long_epochs, low_freq, high_freq)

report_out  = cluster_analysis_report_dir + 'beta_conditions_all_subjects.html'
os.makedirs(cluster_analysis_report_dir, exist_ok=True)
# The first positional parameter of Report is `info_fname`, not the output
# file; the destination is given to report.save() by the analysis function.
report      = Report(title='Beta power cluster analysis: long vs short duration', verbose=False)

# Short epochs are passed first, so the plotted evoked difference is short minus long
compute_and_plot_stca_2_conditions(all_short_power_epochs, all_long_power_epochs,  report, report_out, n_perm=100)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Could not find a adjacency matrix for the data. Computing adjacency based on Delaunay triangulations.
-- number of adjacent vertices : 63
Using a threshold of 3.854299


  cluster_stats = spatio_temporal_cluster_test(X,


stat_fun(H1): min=0.000000 max=14.480224
Running initial clustering …
Found 63 clusters


100%|██████████| Permuting : 99/99 [00:26<00:00,    3.68it/s]

[0.66 1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.
 1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.
 1.   1.   1.   1.   0.99 1.   1.   0.87 1.   1.   0.99 1.   1.   1.
 1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.   1.
 1.   1.   0.17 1.   1.   1.   1.  ]
0 sign. clusters
1 n.s. cluster min p
combining channels using "gfp"





Saving report to : C:\Users\gautier\OneDrive - CentraleSupelec\3A - Master CNN\Supervised Project\pipeline project v0\reports\cluster_analysis\start_prod\beta_conditions_test.html
