## 1. Self Written Feature Extraction

In this file, I will use self-written connectivity feature extraction. 

In [7]:
import time
import mne
import numpy as np
from tqdm.notebook import tqdm

from components.dataset_jo import *

dataset = Dataset_subjectDependent('data')
dataset.set_segment(60)
# for filename in dataset.get_file_list():
#     data, labels, _ = dataset.get_data(filename, return_type='numpy')
#     print(filename, labels.shape, labels.sum(axis=0)/labels.shape[0]*100)

Found: 32 files


In [8]:
def train_model(X_ori,y_ori,groups_ori,filename=None, kernel='rbf',return_text=False):
    
    # Make a copy because I am paranoid
    X,y,groups = X_ori.copy(), y_ori.copy(), groups_ori.copy()

    from sklearn.svm import SVC
    from sklearn.model_selection import StratifiedShuffleSplit 
    from sklearn.model_selection import cross_val_score

    model = SVC(kernel=kernel,max_iter=50000)
    cv = StratifiedShuffleSplit(n_splits=10, train_size=0.75, random_state=0)
    cross = cross_val_score(model, X, y, cv=cv, n_jobs=8)
    
    # We probably dont need this
    model = SVC(kernel=kernel, max_iter=50000)
    model.fit(X, y)
    ans = model.predict(X)
    acc = sum(ans == y) / len(y)
    # If the model answer with all 0 or 1, we print this message
    text = None
    if( sum(ans) == len(y) or sum(ans) == 0 ): 
        text = f"-----WARNING: Model {filename} failed to learn: sum(ans)={sum(ans)} sum(y)={sum(y)} len(y)={len(y)}"
    if(return_text):
        return model, acc, cross, text
    else:
        print(text)
        return model, acc, cross


def pearson_correlation(x,y):
    """ x,y denoted the signal_x and signal_y following the equation """
    cov = np.cov(x, y)
    # print(cov)
    # [[ 8806859.74527069  8007149.0906219 ] ==> [[cov_xx, cov_xy]
    # [ 8007149.0906219  10396797.72458848]]      [cov_yx, cov_yy]]
    cov_xy = cov[0,1] # or cov[1,0]
    cov_xx = cov[0,0]
    cov_yy = cov[1,1]
    corr = cov_xy / ( cov_xx**0.5 * cov_yy**0.5  )
    return corr

def _cal(p_id, partial_data):
    # print(f"p_id:{p_id} - data to run {partial_data.shape}")
    from itertools import combinations
    pcc = []
    for index in range(partial_data.shape[0]):
        pcc_epoch = []
        for comb in combinations(list(range(partial_data.shape[1])), 2):
            pcc_ab = pearson_correlation(partial_data[index, comb[0], :], partial_data[index, comb[1], :]   )
            pcc_epoch.append(pcc_ab)
        pcc_epoch = np.hstack(pcc_epoch)
        pcc.append(pcc_epoch)
    pcc = np.vstack(pcc)
    return pcc

def calculate_pcc(data, n_jobs=8):
    """ 
    Input: Expect data to have (n_epochs, n_channels, n_samples)
    Output: (n_epochs, n_conn ) => n_conn = n_channels!/(2!(n_channels-2)!)
    """
    from multiprocessing import Pool

    try:
        t_out = 60000
        pool = Pool()
        p_list = []
        ans_list = []
        num_p = n_jobs
        indices = np.array_split(np.arange(data.shape[0]), num_p)
        for p_id in range(num_p):
            p_list.append(pool.apply_async(_cal, [p_id, data[indices[p_id]] ]))
        for p_id in range(num_p):
            ans_list.append( p_list[p_id].get(timeout=t_out) )
        # ans_list
    except e:
        print(e)
    finally:
        print("========= close ========")
        pool.close() 
        pool.terminate()
    return np.vstack(ans_list)


In [9]:
data, labels, groups = dataset.get_data('s01', stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
data_csd = mne.preprocessing.compute_current_source_density(data)
pcc = calculate_pcc(data_csd.get_data())
pcc.shape

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm


(2400, 496)

## 2. Connectivity

### 2.1 $ \text{PCC}_{\text{time}}(i,j) = \frac{\text{Cov}[\mathbf{X}_i, \mathbf{X}_j]}{\sigma_{\mathbf{X}_i} \sigma_{\mathbf{X}_j}} $

In [10]:
accs, cv_means, cv_stds = [],[],[]
accs_v, cv_means_v, cv_stds_v = [],[],[]
reports = []
for filename in (pbar := tqdm(dataset.get_file_list())):
    start = time.time()
    pbar.set_description(filename)
    data, labels, groups = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
    data_csd = mne.preprocessing.compute_current_source_density(data)
    pcc = calculate_pcc(data_csd.get_data())
    _,acc,cross,train_report = train_model(pcc, labels.squeeze(), groups, filename=filename, return_text=True)
    report = f"\tAROUSAL-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs.append(acc)
    cv_means.append(cross.mean())
    cv_stds.append(cross.std())

    _, labels_v, groups_v = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_VALENCE, return_type='mne', sfreq=128)
    _,acc,cross,train_report = train_model(pcc, labels_v.squeeze(), groups_v, filename=filename, return_text=True)
    report = f"\tVALENCE-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs_v.append(acc)
    cv_means_v.append(cross.mean())
    cv_stds_v.append(cross.std())
    # break
    # count+=1
    # if(count == 5): break
report = f"AROUSAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}"
print(report)
reports.append(report)
report = f"VALENCE|Acc={sum(accs_v)/len(accs_v)}|10-CV={sum(cv_means_v)/len(cv_means_v)}|STD={sum(cv_stds_v)/len(cv_stds_v)}"
print(report)
reports.append(report)

  0%|          | 0/32 [00:00<?, ?it/s]

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s09|Acc=0.60625|10-CV=0.60017|STD=0.0009|Time spend=33.4352240562439
	VALENCE-s09|Acc=0.62083|10-CV=0.61667|STD=0.01398|Time spend=46.04914402961731
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s19|Acc=0.675|10-CV=0.675|STD=0.0|Time spend=29.94568943977356
	VALENCE-s19|Acc=0.61375|10-CV=0.6025|STD=0.01502|Time spend=40.618797063827515
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s17|Acc=0.60208|10-CV=0.60017|STD=0.00203|Time spend=31.794206619262695
	VALENCE-s17|Acc=0.55|10-CV=0.54983|STD=0.0005|Time spend=42.96705198287964
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 

In [10]:
for text in reports:
    if(text == None): continue
    print(text)

	AROUSAL-s01|Acc=0.71042|10-CV=0.53583|STD=0.12882|Time spend=24.649158477783203
	VALENCE-s01|Acc=0.67917|10-CV=0.47867|STD=0.07076|Time spend=36.57502794265747
	AROUSAL-s02|Acc=0.62083|10-CV=0.55067|STD=0.09966|Time spend=28.401918411254883
	VALENCE-s02|Acc=0.59125|10-CV=0.47917|STD=0.10006|Time spend=40.15917253494263
	AROUSAL-s03|Acc=0.8|10-CV=0.73|STD=0.15524|Time spend=22.932647466659546
	VALENCE-s03|Acc=0.55042|10-CV=0.43967|STD=0.12141|Time spend=34.696168661117554
	AROUSAL-s04|Acc=0.60208|10-CV=0.52133|STD=0.20422|Time spend=28.39444637298584
	VALENCE-s04|Acc=0.60042|10-CV=0.4905|STD=0.18738|Time spend=39.192758321762085
	AROUSAL-s05|Acc=0.57417|10-CV=0.42483|STD=0.1069|Time spend=28.865741968154907
	VALENCE-s05|Acc=0.60208|10-CV=0.58633|STD=0.08869|Time spend=40.79992604255676
	AROUSAL-s06|Acc=0.61125|10-CV=0.5235|STD=0.14029|Time spend=27.70300030708313
	VALENCE-s06|Acc=0.75|10-CV=0.72|STD=0.07483|Time spend=36.020861864089966
	AROUSAL-s07|Acc=0.67333|10-CV=0.6235|STD=0.09409

In [28]:
accs.extend(accs_v)
cv_means.extend(cv_means_v)
cv_stds.extend(cv_stds_v)
print(f"TOTAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}")

TOTAL|Acc=0.6472066885964911|10-CV=0.6343654057017551|STD=0.00979406488222259


### 2.2 $ \text{PCC}_{freq} (i,j) = \frac{\text{Cov}[\hat{\mathbf{X}}_i, \hat{\mathbf{X}}_j]}{\sigma_{\hat{\mathbf{X}}_i} \sigma_{\hat{\mathbf{X}}_j}} $

In [29]:
def calculate_fft(signal, sfreq):
    """ signal: can be 1D array of (n_sample,) or 2D array of (n_signal, n_sample)  """
    number_sample = signal.shape[-1]

    # the result will be a complex number. We can obtain the magnitude using `absolute`
    magnitude = np.abs(np.fft.fft(signal))
    # scale the result
    magnitude = magnitude / (number_sample/2)
    # Selecting the range
    magnitude = magnitude.T[:number_sample//2].T
    freq_range = np.fft.fftfreq(number_sample, d=1/sfreq)[:number_sample//2]

    return magnitude, freq_range

In [30]:
accs, cv_means, cv_stds = [],[],[]
accs_v, cv_means_v, cv_stds_v = [],[],[]
reports = []
for filename in (pbar := tqdm(dataset.get_file_list())):
    start = time.time()
    pbar.set_description(filename)
    data, labels, groups = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
    
    data_csd = mne.preprocessing.compute_current_source_density(data)
    magnitude, freq_range = calculate_fft(data_csd.get_data(), 128)
    pcc = calculate_pcc(magnitude)
    _,acc,cross,train_report = train_model(pcc, labels.squeeze(), groups, filename=filename, return_text=True)
    report = f"\tAROUSAL-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs.append(acc)
    cv_means.append(cross.mean())
    cv_stds.append(cross.std())

    _, labels_v, groups_v = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_VALENCE, return_type='mne', sfreq=128)
    _,acc,cross,train_report = train_model(pcc, labels_v.squeeze(), groups_v, filename=filename, return_text=True)
    report = f"\tVALENCE-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs_v.append(acc)
    cv_means_v.append(cross.mean())
    cv_stds_v.append(cross.std())
    # break
    # count+=1
    # if(count == 5): break
report = f"AROUSAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}"
print(report)
reports.append(report)
report = f"VALENCE|Acc={sum(accs_v)/len(accs_v)}|10-CV={sum(cv_means_v)/len(cv_means_v)}|STD={sum(cv_stds_v)/len(cv_stds_v)}"
print(report)
reports.append(report)

  0%|          | 0/32 [00:00<?, ?it/s]

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s09|Acc=0.65792|10-CV=0.6205|STD=0.01085|Time spend=50.36061072349548
	VALENCE-s09|Acc=0.67792|10-CV=0.63817|STD=0.01161|Time spend=65.66573977470398
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s19|Acc=0.68833|10-CV=0.67867|STD=0.00452|Time spend=55.59354519844055
	VALENCE-s19|Acc=0.6925|10-CV=0.606|STD=0.01254|Time spend=67.668137550354
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s17|Acc=0.67042|10-CV=0.65367|STD=0.01653|Time spend=49.55806279182434
	VALENCE-s17|Acc=0.59917|10-CV=0.56267|STD=0.01093|Time spend=62.91363191604614
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.

In [31]:
for text in reports:
    if(text == None): continue
    print(text)

accs.extend(accs_v)
cv_means.extend(cv_means_v)
cv_stds.extend(cv_stds_v)
print(f"TOTAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}")

	AROUSAL-s09|Acc=0.65792|10-CV=0.6205|STD=0.01085|Time spend=50.36061072349548
	VALENCE-s09|Acc=0.67792|10-CV=0.63817|STD=0.01161|Time spend=65.66573977470398
	AROUSAL-s19|Acc=0.68833|10-CV=0.67867|STD=0.00452|Time spend=55.59354519844055
	VALENCE-s19|Acc=0.6925|10-CV=0.606|STD=0.01254|Time spend=67.668137550354
	AROUSAL-s17|Acc=0.67042|10-CV=0.65367|STD=0.01653|Time spend=49.55806279182434
	VALENCE-s17|Acc=0.59917|10-CV=0.56267|STD=0.01093|Time spend=62.91363191604614
	AROUSAL-s28|Acc=0.58875|10-CV=0.554|STD=0.00496|Time spend=49.26823949813843
	VALENCE-s28|Acc=0.63583|10-CV=0.62617|STD=0.00299|Time spend=61.37085151672363
	AROUSAL-s07|Acc=0.75583|10-CV=0.65917|STD=0.01745|Time spend=47.76533794403076
	VALENCE-s07|Acc=0.79|10-CV=0.72367|STD=0.00869|Time spend=58.30316948890686
	AROUSAL-s11|Acc=0.65667|10-CV=0.64233|STD=0.00517|Time spend=48.72523903846741
	VALENCE-s11|Acc=0.65958|10-CV=0.64217|STD=0.00907|Time spend=60.64424777030945
	AROUSAL-s25|Acc=0.725|10-CV=0.725|STD=0.0|Time spe

### 2.3 $ \text{PLV}(j,k) = \frac{1}{T} | \Sigma^{T}_{t=1} e^{i(\phi^{t}_{j} - \phi^{t}_{k})}   | $

In [32]:
# bands = [(0,4), (4,8), (8,12), (12,30), (30,64)]
def calculate_stft(signals, sfreq):
    from scipy import signal
    f_range, t_range, Z = signal.stft(signals, sfreq, nperseg=sfreq//10, nfft=sfreq)
    magnitude = np.abs(Z) 
    phase = np.angle(Z)
    return magnitude, phase, f_range, t_range

def PLV_stft(p_id, phase):
    # (32, 65, 12) => (channels, freqs, times)
    # print(f"p_id:{p_id} - data to run {phase.shape}")
    print(phase.shape)
    from itertools import combinations
    plv = []
    # count = 0
    for index in range(phase.shape[0]):
        plv_epoch = []
        for comb in combinations(list(range(phase.shape[1])), 2):
            # shape = (65,12)
            phase_a, phase_b = phase[index, comb[0]], phase[index, comb[1]]
            phase_diff = phase_a - phase_b
            # sum along the time size
            plv_ab = np.abs(np.average(np.exp(complex(0,1) * phase_diff), axis=1))
            plv_epoch.append(plv_ab)
        plv_epoch = np.vstack(plv_epoch)
        # print(plv_epoch.shape) => (300, 32, 65, 23)
        plv_epoch_5 = np.concatenate([ plv_epoch[:,0:4].mean(axis=1).reshape(-1,1),
                                        plv_epoch[:,4:8].mean(axis=1).reshape(-1,1),
                                        plv_epoch[:,8:12].mean(axis=1).reshape(-1,1),
                                        plv_epoch[:,12:30].mean(axis=1).reshape(-1,1),
                                        plv_epoch[:,30:65].mean(axis=1).reshape(-1,1)], axis=0)
        plv.append(np.expand_dims(plv_epoch_5, axis=0))
        # count += 1
        # if(count == 3): break
    # shape (496, 65)
    # 496 is number of pairs that is not duplicate
    # 65 is number of phase of frequencies
    plv = np.vstack( plv )
    return plv.squeeze()

def calculate_plv(data, n_jobs=8):
    """ 
    Input: Expect data to have (n_epochs, n_channels, n_samples)
    Output: (n_epochs, n_conn, n_freqs ) => n_conn = n_channels!/(2!(n_channels-2)!)
    """
    from multiprocessing import Pool

    try:
        t_out = 60000
        pool = Pool()
        p_list = []
        ans_list = []
        num_p = n_jobs
        indices = np.array_split(np.arange(data.shape[0]), num_p)
        _, phase, _, _ = calculate_stft(data, 128)
        for p_id in range(num_p):
            p_list.append(pool.apply_async(PLV_stft, [p_id, phase[indices[p_id]] ]))
        for p_id in range(num_p):
            ans_list.append( p_list[p_id].get(timeout=t_out) )
        # ans_list
    finally:
        print("========= close ========")
        pool.close() 
        pool.terminate()
    return np.vstack(ans_list)


In [33]:
data, labels, groups = dataset.get_data('s01', stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
data_csd = mne.preprocessing.compute_current_source_density(data)
plv = calculate_plv(data_csd.get_data(), n_jobs=8)
# _, phase, _, _ = calculate_stft(data_csd.get_data(), 128)
# plv = PLV_stft(1, phase)
plv.shape
del(plv)

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)


In [34]:
accs, cv_means, cv_stds = [],[],[]
accs_v, cv_means_v, cv_stds_v = [],[],[]
reports = []
for filename in (pbar := tqdm(dataset.get_file_list())):
    start = time.time()
    pbar.set_description(filename)
    data, labels, groups = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
    
    data_csd = mne.preprocessing.compute_current_source_density(data)

    plv = calculate_plv(data_csd.get_data(), n_jobs=8)

    _,acc,cross,train_report = train_model(plv, labels.squeeze(), groups, filename=filename, return_text=True)
    report = f"\tAROUSAL-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs.append(acc)
    cv_means.append(cross.mean())
    cv_stds.append(cross.std())

    _, labels_v, groups_v = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_VALENCE, return_type='mne', sfreq=128)
    _,acc,cross,train_report = train_model(plv, labels_v.squeeze(), groups_v, filename=filename, return_text=True)
    report = f"\tVALENCE-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs_v.append(acc)
    cv_means_v.append(cross.mean())
    cv_stds_v.append(cross.std())
    # break
    # count+=1
    # if(count == 5): break
report = f"AROUSAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}"
print(report)
reports.append(report)
report = f"VALENCE|Acc={sum(accs_v)/len(accs_v)}|10-CV={sum(cv_means_v)/len(cv_means_v)}|STD={sum(cv_stds_v)/len(cv_stds_v)}"
print(report)
reports.append(report)

  0%|          | 0/32 [00:00<?, ?it/s]

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
	AROUSAL-s09|Acc=0.73958|10-CV=0.674|STD=0.01403|Time spend=948.3420503139496
	VALENCE-s09|Acc=0.77042|10-CV=0.734|STD=0.01515|Time spend=1086.4108264446259
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
(300, 32, 65, 23)
	AROUSAL-s19|Acc=0.73792|10-CV=0.69133|STD=0.00407|Time spend=277.81249713897705
	VALENCE-s19|Acc=0.78708|10-CV=0.656|STD=0.01352|Time spend=353.21989727020264
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 4

In [35]:
for text in reports:
    if(text == None): continue
    print(text)

accs.extend(accs_v)
cv_means.extend(cv_means_v)
cv_stds.extend(cv_stds_v)
print(f"TOTAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}")

	AROUSAL-s09|Acc=0.73958|10-CV=0.674|STD=0.01403|Time spend=948.3420503139496
	VALENCE-s09|Acc=0.77042|10-CV=0.734|STD=0.01515|Time spend=1086.4108264446259
	AROUSAL-s19|Acc=0.73792|10-CV=0.69133|STD=0.00407|Time spend=277.81249713897705
	VALENCE-s19|Acc=0.78708|10-CV=0.656|STD=0.01352|Time spend=353.21989727020264
	AROUSAL-s17|Acc=0.75208|10-CV=0.7015|STD=0.00804|Time spend=110.3167929649353
	VALENCE-s17|Acc=0.67833|10-CV=0.60933|STD=0.02265|Time spend=187.76707768440247
	AROUSAL-s28|Acc=0.68042|10-CV=0.59133|STD=0.01183|Time spend=129.9256203174591
	VALENCE-s28|Acc=0.73417|10-CV=0.6705|STD=0.0096|Time spend=198.33644199371338
	AROUSAL-s07|Acc=0.82833|10-CV=0.6835|STD=0.01237|Time spend=120.8293342590332
	VALENCE-s07|Acc=0.82292|10-CV=0.7425|STD=0.00889|Time spend=187.26276850700378
	AROUSAL-s11|Acc=0.67583|10-CV=0.64433|STD=0.00876|Time spend=122.75873184204102
	VALENCE-s11|Acc=0.69042|10-CV=0.6525|STD=0.00911|Time spend=195.70947551727295
	AROUSAL-s25|Acc=0.725|10-CV=0.725|STD=0.0|T

### 2.4 $ \text{PLI}(j,k) =  \frac{1}{T} | \Sigma^{T}_{t=1} \text{sign}(Im[e^{i (\phi^{t}_{j} - \phi^{t}_{k})}])  | $

In [36]:
# bands = [(0,4), (4,8), (8,12), (12,30), (30,64)]
def calculate_stft(signals, sfreq):
    from scipy import signal
    f_range, t_range, Z = signal.stft(signals, sfreq, nperseg=sfreq//10, nfft=sfreq)
    magnitude = np.abs(Z) 
    phase = np.angle(Z)
    return magnitude, phase, f_range, t_range

def PLI_stft(p_id, phase):
    # (32, 65, 12) => (channels, freqs, times)
    # print(f"p_id:{p_id} - data to run {phase.shape}")
    from itertools import combinations
    pli = []
    # count = 0
    for index in range(phase.shape[0]):
        pli_epoch = []
        for comb in combinations(list(range(phase.shape[1])), 2):
            # shape = (65,12)
            phase_a, phase_b = phase[index, comb[0]], phase[index, comb[1]]
            phase_diff = phase_a - phase_b
            # sum along the time size
            pli_ab = np.abs(np.average(   np.sign(np.imag(   np.exp(complex(0,1) * phase_diff) ))  , axis=1))
            pli_epoch.append(pli_ab)
        pli_epoch = np.vstack(pli_epoch)
        pli_epoch_5 = np.concatenate([ pli_epoch[:,0:4].mean(axis=1).reshape(-1,1),
                                        pli_epoch[:,4:8].mean(axis=1).reshape(-1,1),
                                        pli_epoch[:,8:12].mean(axis=1).reshape(-1,1),
                                        pli_epoch[:,12:30].mean(axis=1).reshape(-1,1),
                                        pli_epoch[:,30:65].mean(axis=1).reshape(-1,1)], axis=0)
        pli.append(np.expand_dims(pli_epoch_5, axis=0))
        # count += 1
        # if(count == 3): break
    # shape (496, 65)
    # 496 is number of pairs that is not duplicate
    # 65 is number of phase of frequencies
    pli = np.vstack( pli )
    return pli.squeeze()

def calculate_pli(data, n_jobs=8):
    """ 
    Input: Expect data to have (n_epochs, n_channels, n_samples)
    Output: (n_epochs, n_conn, n_freqs ) => n_conn = n_channels!/(2!(n_channels-2)!)
    """
    from multiprocessing import Pool

    try:
        t_out = 60000
        pool = Pool()
        p_list = []
        ans_list = []
        num_p = n_jobs
        indices = np.array_split(np.arange(data.shape[0]), num_p)
        _, phase, _, _ = calculate_stft(data, 128)
        for p_id in range(num_p):
            p_list.append(pool.apply_async(PLI_stft, [p_id, phase[indices[p_id]] ]))
        for p_id in range(num_p):
            ans_list.append( p_list[p_id].get(timeout=t_out) )
        # ans_list
    finally:
        print("========= close ========")
        pool.close() 
        pool.terminate()
    return np.vstack(ans_list)


In [37]:
data, labels, groups = dataset.get_data('s01', stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
data_csd = mne.preprocessing.compute_current_source_density(data)
pli = calculate_pli(data_csd.get_data(), n_jobs=8)
# _, phase, _, _ = calculate_stft(data_csd.get_data(), 128)
# pli = pli_stft(1, phase)
print(pli.shape)
del(pli)

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
(2400, 2480)


In [38]:
accs, cv_means, cv_stds = [],[],[]
accs_v, cv_means_v, cv_stds_v = [],[],[]
reports = []
for filename in (pbar := tqdm(dataset.get_file_list())):
    start = time.time()
    pbar.set_description(filename)
    data, labels, groups = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_AROUSAL, return_type='mne', sfreq=128)
    
    data_csd = mne.preprocessing.compute_current_source_density(data)

    pli = calculate_pli(data_csd.get_data(), n_jobs=8)

    _,acc,cross,train_report = train_model(pli, labels.squeeze(), groups, filename=filename, return_text=True)
    report = f"\tAROUSAL-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs.append(acc)
    cv_means.append(cross.mean())
    cv_stds.append(cross.std())

    _, labels_v, groups_v = dataset.get_data(filename, stimuli=Dataset_subjectDependent.STIMULI_VALENCE, return_type='mne', sfreq=128)
    _,acc,cross,train_report = train_model(pli, labels_v.squeeze(), groups_v, filename=filename, return_text=True)
    report = f"\tVALENCE-{filename}|Acc={round(acc,5)}|10-CV={round(cross.mean(),5)}|STD={round(cross.std(),5)}|Time spend={time.time() - start}"
    print(report)
    reports.append(report)
    reports.append(train_report)
    accs_v.append(acc)
    cv_means_v.append(cross.mean())
    cv_stds_v.append(cross.std())
    # break
    # count+=1
    # if(count == 5): break
report = f"AROUSAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}"
print(report)
reports.append(report)
report = f"VALENCE|Acc={sum(accs_v)/len(accs_v)}|10-CV={sum(cv_means_v)/len(cv_means_v)}|STD={sum(cv_stds_v)/len(cv_stds_v)}"
print(report)
reports.append(report)

  0%|          | 0/32 [00:00<?, ?it/s]

Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s09|Acc=0.76792|10-CV=0.59533|STD=0.00314|Time spend=132.70519638061523
	VALENCE-s09|Acc=0.90667|10-CV=0.53883|STD=0.02438|Time spend=218.1479251384735
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s19|Acc=0.71542|10-CV=0.67483|STD=0.0005|Time spend=121.75163507461548
	VALENCE-s19|Acc=0.82833|10-CV=0.56333|STD=0.00901|Time spend=204.59139943122864
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0.8 15.1 45.3 mm
	AROUSAL-s17|Acc=0.80292|10-CV=0.61033|STD=0.00983|Time spend=132.16715049743652
	VALENCE-s17|Acc=0.90583|10-CV=0.55|STD=0.0169|Time spend=218.91054964065552
Fitted sphere radius:         95.3 mm
Origin head coordinates:      -0.8 15.1 45.3 mm
Origin device coordinates:    -0

In [39]:
for text in reports:
    if(text == None): continue
    print(text)

accs.extend(accs_v)
cv_means.extend(cv_means_v)
cv_stds.extend(cv_stds_v)
print(f"TOTAL|Acc={sum(accs)/len(accs)}|10-CV={sum(cv_means)/len(cv_means)}|STD={sum(cv_stds)/len(cv_stds)}")

	AROUSAL-s09|Acc=0.76792|10-CV=0.59533|STD=0.00314|Time spend=132.70519638061523
	VALENCE-s09|Acc=0.90667|10-CV=0.53883|STD=0.02438|Time spend=218.1479251384735
	AROUSAL-s19|Acc=0.71542|10-CV=0.67483|STD=0.0005|Time spend=121.75163507461548
	VALENCE-s19|Acc=0.82833|10-CV=0.56333|STD=0.00901|Time spend=204.59139943122864
	AROUSAL-s17|Acc=0.80292|10-CV=0.61033|STD=0.00983|Time spend=132.16715049743652
	VALENCE-s17|Acc=0.90583|10-CV=0.55|STD=0.0169|Time spend=218.91054964065552
	AROUSAL-s28|Acc=0.92542|10-CV=0.53167|STD=0.0119|Time spend=134.61193656921387
	VALENCE-s28|Acc=0.84375|10-CV=0.62333|STD=0.00279|Time spend=220.23092770576477
	AROUSAL-s07|Acc=0.9125|10-CV=0.623|STD=0.00515|Time spend=129.0194935798645
	VALENCE-s07|Acc=0.8475|10-CV=0.7|STD=0.0|Time spend=205.66868090629578
	AROUSAL-s11|Acc=0.77583|10-CV=0.624|STD=0.00226|Time spend=132.0025942325592
	VALENCE-s11|Acc=0.78792|10-CV=0.597|STD=0.00393|Time spend=215.4134566783905
	AROUSAL-s25|Acc=0.74958|10-CV=0.725|STD=0.0|Time spen