In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import mne
import os
from matplotlib import colors as mcolors
from sklearn.preprocessing import StandardScaler
from scipy.io import loadmat
from scipy import stats
from mne.stats import permutation_cluster_test
import seaborn as sns
sns.set_style('darkgrid')

# Fix every source of randomness used below (the surrogate baseline draws
# random windows with np.random), so results are reproducible run to run.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

# Project root. Override with the HAND_WRITING_PATH environment variable on
# another machine instead of editing the notebook; the default is unchanged.
path_out = os.environ.get('HAND_WRITING_PATH', 'C:/skoltech_hand_writing')
# Subject sub-folder inside `hand_writing/` ('' by default); override with
# the HAND_WRITING_SUBJ environment variable.
subj = os.environ.get('HAND_WRITING_SUBJ', '')

In [None]:
# Load the cleaned (!) epochs and resample them to 1000 Hz.
epochs_fname = f'{path_out}/hand_writing/{subj}/5_files_for_model/epochs_eeg_2-epo.fif'
epochs_list_concated = mne.read_epochs(epochs_fname, preload = True)

# Two independent, identical resampled copies: `epochs_eeg` is the data under
# test, `epochs_eeg_init` is the pool the random-window baseline is drawn from.
epochs_eeg = epochs_list_concated.copy().resample(1000)
epochs_eeg_init = epochs_eeg.copy()

In [None]:
# A function to run the cluster permutation test for one channel, plot it and
# (optionally) save the figure.
def cluster_plot(times, condition1, condition2, title, n_permutations=1000, save = True,
                 threshold=6.0, markers=None):
    """Cluster-based permutation F-test between two conditions, with a figure.

    The top panel shows the mean difference ``condition1 - condition2``; the
    bottom panel shows the observed F-values with every cluster shaded (red if
    p <= 0.05, grey otherwise) and vertical reference markers.

    Parameters
    ----------
    times : array-like
        Time points of the last axis of both conditions (``epochs.times`` in
        this notebook).
    condition1, condition2 : ndarray, shape (n_observations, n_times)
        Observations compared by ``mne.stats.permutation_cluster_test``.
    title : str
        Figure title; also used as the output file name stem.
    n_permutations : int
        Number of permutations for the test.
    save : bool
        If True, save the figure to
        ``{path_out}/hand_writing/{subj}/output/{title}_permutation.png``
        (the folder is created if missing).
    threshold : float
        Cluster-forming F threshold (6.0, as before).
    markers : dict or None
        ``{x: label}`` vertical reference lines drawn on the F panel. ``None``
        uses the original markers ``{0: 'trigger', 0.5: 'digit', 5: '+'}``.

    Returns
    -------
    list of [start, stop]
        Sample-index bounds (``stop`` exclusive) of clusters with p <= 0.05.
    """
    if markers is None:
        markers = {0: 'trigger', .5: 'digit', 5: '+'}

    T_obs, clusters, cluster_p_values, _ = permutation_cluster_test(
        [condition1, condition2], n_permutations=n_permutations,
        threshold=threshold, tail=1, n_jobs=None, out_type='mask')

    fig, (ax, ax2) = plt.subplots(2, 1, figsize=(8, 4))

    ax.plot(times, condition1.mean(axis=0) - condition2.mean(axis=0),
            label="Contrast (Event 1 - Event 2)")
    ax.legend()

    ch_signif_list = []
    for c, p_val in zip(clusters, cluster_p_values):
        # Time-only data: each cluster comes back as a 1-tuple holding a slice.
        c = c[0]
        significant = p_val <= 0.05
        if significant:
            ch_signif_list.append([c.start, c.stop])
        ax2.axvspan(times[c.start], times[c.stop - 1],
                    color='r' if significant else (0.3, 0.3, 0.3), alpha=0.3)

    ax2.plot(times, T_obs, 'g')
    # `times` is epochs.times, i.e. seconds (not ms as previously labelled).
    ax2.set_xlabel("time (s)")
    ax2.set_ylabel("f-values")
    fig.suptitle(title)

    for x, label in markers.items():
        ax2.axvline(x=x, color='red')
        ax2.text(x=x, y=1, s=label)

    if save:
        out_dir = f'{path_out}/hand_writing/{subj}/output'
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f'{out_dir}/{title}_permutation.png', dpi = 300)
    plt.show()
    # Called once per channel in a loop: release the figure explicitly.
    plt.close(fig)
    return ch_signif_list

In [None]:
# Test every channel against a surrogate baseline made of random windows taken
# from the same channel's concatenated signal, and keep the significant clusters.
sfreq = epochs_eeg.info['sfreq']
# Single-channel EEG info structure at the epochs' sampling rate.
# NOTE(review): `info` is not referenced by the visible cells below.
info = mne.create_info(
    ch_names=['eeg1'],
    ch_types=['eeg'],
    sfreq=sfreq
)

# Epochs under test (alias kept at module level; later cells may refer to it).
custom_epochs = epochs_eeg
n_epochs = len(custom_epochs)
# Number of samples per epoch = length of every surrogate window.
epoch_length = custom_epochs._data.shape[-1]

ch_signif_list_dict = {}

for ch_i, ch_name in enumerate(epochs_eeg.ch_names):
    # All epochs of this channel glued end-to-end into one long recording.
    flatten_data = epochs_eeg_init._data[:, ch_i, :].flatten()
    interval = flatten_data.shape[0]

    # One surrogate epoch per real epoch: `epoch_length` consecutive samples
    # starting at a random offset. randint's upper bound is exclusive, so the
    # +1 lets the last complete window be drawn too.
    random_seq_list = []
    for _ in range(n_epochs):
        rand_v = np.random.randint(interval - epoch_length + 1)
        random_seq_list.append(flatten_data[rand_v:rand_v + epoch_length])
    random_array = np.array(random_seq_list)

    # Cluster analysis: real epochs of this channel vs. the random windows.
    ch_signif_list_dict[ch_name] = cluster_plot(
        custom_epochs.times, custom_epochs._data[:, ch_i, :], random_array,
        ch_name, n_permutations=500)


In [None]:
# Palette for drawing all channels on one graph: every named matplotlib colour
# (single-letter base colours, then the CSS4 names), walked from the end and
# keeping every third entry.
palette = {**mcolors.BASE_COLORS, **mcolors.CSS4_COLORS}
colors = list(reversed(list(palette)))[::3]

In [None]:
# Time window of the summary figure. Despite the `_ms` suffix these values are
# in seconds: the x-axis they limit is labelled "Time, s".
l_border_ms = -0.5
r_border_ms = 4.5

# Channels to draw, as the half-open index range [l_ch_ind, r_ch_ind);
# by default every channel.
l_ch_ind = 0
r_ch_ind = len(epochs_eeg.ch_names)

In [None]:
# Summary figure: mean response of every selected channel (dashed line) with
# its significant clusters overdrawn as thick solid segments; plot and save.
plt.rcParams.update({'font.size': 12})

times_ = epochs_eeg.times
# Honour the window configured above, clipped to the available time range.
# (Previously it was overwritten with the full range inside the loop.)
t_min = max(l_border_ms, times_[0])
t_max = min(r_border_ms, times_[-1])
# Inclusive on both ends, so the last sample of the window is drawn too.
window = (times_ >= t_min) & (times_ <= t_max)

fig, ax = plt.subplots(figsize=(15, 8))
for ch_i in range(l_ch_ind, r_ch_ind):
    ch_nm = epochs_eeg.ch_names[ch_i]
    # Wrap around when there are more channels than palette entries.
    color = colors[(ch_i - l_ch_ind) % len(colors)]
    mean_ = np.mean(epochs_eeg._data[:, ch_i, :], 0)

    drawn = False
    for s_, e_ in ch_signif_list_dict[ch_nm]:
        # Samples [s_, e_) of the cluster that fall inside the plotting window.
        in_seg = np.zeros(times_.shape, dtype=bool)
        in_seg[s_:e_] = True
        in_seg &= window
        if in_seg.any():
            ax.plot(times_[in_seg], mean_[in_seg], color=color, linewidth=5, alpha=0.9)
            print(ch_nm, s_, e_)
            drawn = True
    if not drawn:
        print(ch_nm, 'none')

    ax.plot(times_[window], mean_[window], '--', alpha=0.5, color=color, label=ch_nm)

ax.set_xlim(t_min, t_max)
ax.set_xticks(np.arange(t_min, t_max, 0.25))
ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

ax.set_ylabel('Amplitude, V')
ax.set_xlabel('Time, s')
ax.grid(True)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

out_dir = f'{path_out}/hand_writing/{subj}/output'
os.makedirs(out_dir, exist_ok=True)
# bbox_inches='tight' keeps the legend (placed outside the axes) in the PNG.
fig.savefig(f'{out_dir}/{subj}_stats.png', dpi = 300, bbox_inches='tight')
plt.show()
