In [1]:
from torcheeg.transforms.numpy.correlation import PearsonCorrelation
from torcheeg.datasets import BCICIV2aDataset
from torcheeg import transforms

SAMPLING_RATE = 250
SEQ_LENGTH = 500
DT=25

dataset = BCICIV2aDataset(
    root_path='../datasets/bci_c',
    io_path='.torcheeg/datasets_biciv_2a_correlation',
    chunk_size=SEQ_LENGTH,
    online_transform=transforms.PickElectrode([])
)

feature_extraction_ds = BCICIV2aDataset(
    root_path='../datasets/bci_c',
    io_path='.torcheeg/datasets_biciv_2a_feature_extraction',
    chunk_size=SEQ_LENGTH,
    overlap=SEQ_LENGTH - DT,
    num_worker=6,
)

[2024-02-17 23:08:11] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from .torcheeg/datasets_biciv_2a_correlation.
[2024-02-17 23:08:11] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from .torcheeg/datasets_biciv_2a_feature_extraction.


In [None]:
# test pick channels


In [2]:
import numpy as np

samples = [sample for sample, _ in dataset]
samples = np.concatenate(samples, axis=-1)

In [3]:
# Correlation matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame(samples.transpose())
corr = df.corr()
# _, ax = plt.subplots(figsize=(16, 16))
# sns.heatmap(corr, annot=True, ax=ax)


In [4]:
corr.style.background_gradient(cmap='coolwarm').format(precision=2)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21
0,1.0,0.91,0.93,0.93,0.93,0.91,0.75,0.78,0.82,0.82,0.83,0.78,0.74,0.64,0.7,0.69,0.7,0.67,0.56,0.57,0.56,0.43
1,0.91,1.0,0.96,0.93,0.89,0.87,0.89,0.93,0.91,0.88,0.84,0.78,0.72,0.81,0.81,0.78,0.76,0.7,0.68,0.66,0.64,0.53
2,0.93,0.96,1.0,0.97,0.95,0.9,0.85,0.9,0.95,0.93,0.91,0.82,0.76,0.79,0.85,0.82,0.81,0.74,0.69,0.7,0.67,0.55
3,0.93,0.93,0.97,1.0,0.97,0.93,0.78,0.86,0.92,0.94,0.92,0.85,0.78,0.75,0.81,0.82,0.81,0.75,0.66,0.68,0.66,0.53
4,0.93,0.89,0.95,0.97,1.0,0.97,0.77,0.83,0.91,0.93,0.96,0.9,0.85,0.74,0.82,0.83,0.85,0.81,0.67,0.7,0.7,0.56
5,0.91,0.87,0.9,0.93,0.97,1.0,0.73,0.79,0.85,0.88,0.93,0.93,0.89,0.7,0.76,0.79,0.84,0.83,0.65,0.68,0.7,0.55
6,0.75,0.89,0.85,0.78,0.77,0.73,1.0,0.93,0.88,0.79,0.77,0.7,0.68,0.88,0.84,0.76,0.72,0.66,0.73,0.69,0.65,0.58
7,0.78,0.93,0.9,0.86,0.83,0.79,0.93,1.0,0.95,0.89,0.84,0.77,0.71,0.94,0.91,0.86,0.81,0.73,0.8,0.76,0.72,0.63
8,0.82,0.91,0.95,0.92,0.91,0.85,0.88,0.95,1.0,0.96,0.93,0.84,0.77,0.9,0.95,0.92,0.88,0.8,0.82,0.82,0.78,0.67
9,0.82,0.88,0.93,0.94,0.93,0.88,0.79,0.89,0.96,1.0,0.96,0.88,0.79,0.84,0.92,0.94,0.91,0.83,0.79,0.82,0.8,0.67


In [6]:
def get_top_correlations(df, threshold=0.4):
    """
    df: the dataframe to get correlations from
    threshold: the maximum and minimum value to include for correlations. For eg, if this is 0.4, only pairs haveing a correlation coefficient greater than 0.4 or less than -0.4 will be included in the results. 
    """
    orig_corr = df.corr()
    c = orig_corr.abs()

    so = c.unstack()

    print("|    Variable 1    |    Variable 2    | Correlation Coefficient    |")
    print("|------------------|------------------|----------------------------|")
    
    i=0
    pairs=set()
    result = pd.DataFrame()
    for index, value in so.sort_values(ascending=False).items():
        # Exclude duplicates and self-correlations
        if value > threshold \
        and index[0] != index[1] \
        and (index[0], index[1]) not in pairs \
        and (index[1], index[0]) not in pairs:
            
            print(f'|    {index[0]}    |    {index[1]}    |    {orig_corr.loc[(index[0], index[1])]}    |')
            result.loc[i, ['Variable 1', 'Variable 2', 'Correlation Coefficient']] = [index[0], index[1], orig_corr.loc[(index[0], index[1])]]
            pairs.add((index[0], index[1]))
            i+=1
    return result.reset_index(drop=True).set_index(['Variable 1', 'Variable 2'])


thresholded_corr = get_top_correlations(corr, threshold=0.85)


|    Variable 1    |    Variable 2    | Correlation Coefficient    |
|------------------|------------------|----------------------------|
|    2    |    3    |    0.9733597093218412    |
|    1    |    2    |    0.9692441775827345    |
|    3    |    4    |    0.9690379499313186    |
|    18    |    19    |    0.9682252941968382    |
|    4    |    5    |    0.9669950567890588    |
|    19    |    20    |    0.9662997863457256    |
|    19    |    21    |    0.9655171457894206    |
|    12    |    11    |    0.9603220194631279    |
|    3    |    0    |    0.9587000186538038    |
|    21    |    20    |    0.9489536511277675    |
|    18    |    21    |    0.9488435514504877    |
|    6    |    7    |    0.9475564103974222    |
|    21    |    0    |    -0.9438175224082204    |
|    0    |    2    |    0.9387367888644301    |
|    0    |    4    |    0.9369301030299947    |
|    16    |    17    |    0.9135738671816046    |
|    4    |    2    |    0.9017011979887254    |
|    3    |  

|    Variable 1    |    Variable 2    | Correlation Coefficient    |
|------------------|------------------|----------------------------|
|    2    |    3    |    0.9733597093218412    |
|    1    |    2    |    0.9692441775827345    |
|    3    |    4    |    0.9690379499313186    |
|    18    |    19    |    0.9682252941968382    |
|    4    |    5    |    0.9669950567890588    |
|    19    |    20    |    0.9662997863457256    |
|    19    |    21    |    0.9655171457894206    |
|    12    |    11    |    0.9603220194631279    |
|    3    |    0    |    0.9587000186538038    |
|    21    |    20    |    0.9489536511277675    |
|    18    |    21    |    0.9488435514504877    |
|    6    |    7    |    0.9475564103974222    |
|    21    |    0    |    -0.9438175224082204    |
|    0    |    2    |    0.9387367888644301    |
|    0    |    4    |    0.9369301030299947    |
|    16    |    17    |    0.9135738671816046    |
|    4    |    2    |    0.9017011979887254    |
|    3    |    1    |    0.9007672840559301    |
|    13    |    14    |    0.8946996628580193    |
|    1    |    0    |    0.8916826656063631    |
|    15    |    14    |    0.8908036036152203    |
|    5    |    3    |    0.8870282910309479    |
|    8    |    7    |    0.8861516711981836    |
|    18    |    20    |    0.8855054440122953    |
|    9    |    8    |    0.8840697713276787    |
|    5    |    0    |    0.8810512747060361    |
|    18    |    0    |    -0.8775799319014917    |
|    11    |    10    |    0.8772332457296156    |
|    1    |    21    |    -0.8611490771892903    |
|    0    |    19    |    -0.8595860716173214    |
|    10    |    9    |    0.8544486782346973    |
|    2    |    21    |    -0.8526743807620805    |
|    21    |    3    |    -0.849215314686478    |
|    20    |    0    |    -0.8489561188060041    |
|    15    |    16    |    0.8374887060596466    |
|    16    |    20    |    0.8333164717467622    |
|    1    |    20    |    -0.8296570038463833    |
|    21    |    4    |    -0.8224395523144938    |
|    19    |    15    |    0.8084976006037484    |
|    4    |    10    |    0.8060531859846514    |
|    5    |    10    |    0.8054969643665818    |
|    2    |    8    |    0.803065648443115    |

In [29]:
thresholded_corr.sort_values(by='Variable 1').style.format(precision=3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Correlation Coefficient
Variable 1,Variable 2,Unnamed: 2_level_1
0.0,2.0,0.939
0.0,4.0,0.937
1.0,2.0,0.969
2.0,3.0,0.973
3.0,4.0,0.969
3.0,0.0,0.959
3.0,1.0,0.901
4.0,5.0,0.967
4.0,2.0,0.902
6.0,7.0,0.948


In [None]:
# Se escogen
[5,7,8,9,10,11,13,14,15,17,20]

