Author:
        
        PARK, JunHo, junho@ccnets.org

        Kim, Jinsu 

        KIM, JeongYoong, jeongyoong@ccnets.org
        
    COPYRIGHT (c) 2024. CCNets. All Rights reserved.

In [1]:
import sys
path_append = "../"
sys.path.append(path_append)  # Go up one directory from where you are.

from nn.utils.init import set_random_seed
set_random_seed(0)

In [2]:

# https://github.com/N-Nieto/Inner_Speech_Dataset

# Load the Inner Speech Dataset
# =============================
# This dataset comprises raw EEG data collected from subject 'sub-01' during session 'ses-01'.
# Source: https://github.com/N-Nieto/Inner_Speech_Dataset
#
# Overview:
# - The dataset is part of a study on inner speech, capturing brain activity via EEG.
# - Each row in the dataset corresponds to a timestamp of EEG readings.
# - Columns represent various EEG channels (electrodes placed on the scalp).
#
# Usage:
# - The data is primarily used for cognitive neuroscience research, focusing on the neural correlates of inner speech.
# - Users can analyze EEG signals to investigate brain activity patterns associated with the cognitive processes of inner speech.
#
# File Structure:
# - Located at '../data/RAW_EEG/sub-01/sub-01_ses-01.csv' relative to this script.
# - It is advisable to preprocess the data (filtering, normalization) before detailed analysis.
#
# Example:
# - To load this data into a DataFrame for analysis and processing, use the following code snippet.

import pandas as pd

df = None
for csv in ["../data/RAW_EEG/sub-01/sub-01_ses-01.csv", "../data/RAW_EEG/sub-01/sub-01_ses-02.csv", "../data/RAW_EEG/sub-01/sub-01_ses-03.csv"]:
    tmp_df = pd.read_csv(path_append + csv)
    if df is None:
        df = tmp_df
    else:
        df = pd.concat([df, tmp_df])
df = df.reset_index(drop=True)
df


Unnamed: 0,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,...,D24,D25,D26,D27,D28,D29,D30,D31,D32,event
0,3549.790315,4533.538497,3619.665186,3077.291188,-1380.325575,6120.066816,-4072.820600,-2256.511456,1820.012261,-2815.635423,...,-7240.845997,7034.252627,8458.062496,5905.223463,6147.660515,2458.073582,-7465.876831,-3604.133966,-5445.224315,5
1,3551.227812,4534.850995,3622.540181,3077.322438,-1377.575581,6123.066810,-4069.851856,-2252.167714,1825.168502,-2803.072947,...,-7227.283522,7039.627617,8463.874985,5911.598451,6153.504254,2463.354822,-7461.033090,-3594.258985,-5435.693082,5
2,3556.727802,4539.850986,3629.040169,3081.978679,-1370.419344,6130.348047,-4063.508118,-2249.292720,1828.074746,-2804.041695,...,-7227.158522,7048.502600,8473.562467,5921.348433,6163.004236,2469.854810,-7460.470591,-3591.540240,-5433.568086,5
3,3557.915300,4541.225983,3628.540169,3083.197427,-1372.263090,6130.410547,-4062.070620,-2251.667715,1825.856000,-2803.572946,...,-7224.189777,7042.346362,8464.593734,5917.660940,6160.972990,2467.011066,-7458.158095,-3597.008980,-5437.474329,5
4,3553.352808,4535.757243,3622.477681,3079.572434,-1377.763080,6125.598056,-4066.570612,-2255.136459,1821.981008,-2808.041687,...,-7219.971035,7044.658857,8466.843729,5914.848445,6156.785498,2466.948566,-7457.501846,-3585.821500,-5428.630595,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
588018,-623.326974,2269.261431,2575.479615,285.733846,907.388947,-491.014719,-2998.447586,1886.043389,1659.637557,416.296105,...,-7176.689865,2116.667963,-901.138961,-227.327706,-657.170662,3025.322534,-12313.149124,-3810.071086,-5620.505241,10
588019,-627.420717,2264.448940,2570.323375,281.077605,903.482705,-490.702219,-3001.260080,1884.387142,1657.012562,414.702358,...,-7179.502360,2118.074210,-900.607712,-227.046456,-659.389408,3027.760030,-12307.211635,-3809.946086,-5621.098990,10
588020,-631.764459,2260.730197,2566.917131,275.546365,902.045207,-493.545964,-3006.103821,1886.199639,1658.512560,424.202340,...,-7177.439864,2118.199210,-900.920211,-226.140208,-659.764407,3027.103781,-12305.774138,-3805.633594,-5614.880251,10
588021,-625.076971,2265.605188,2573.354619,281.702604,904.982702,-490.795969,-3001.416330,1888.387135,1659.418808,420.077348,...,-7172.002374,2119.730457,-898.170216,-224.515211,-656.576913,3032.822520,-12303.742892,-3804.133597,-5614.192752,10


In [3]:
import pandas as pd

# Example setup, assuming df is defined as a DataFrame
# Assuming df['event'] contains the class labels

# Print counts of each class in the 'event' column
event_counts = df['event'].value_counts()
print("Counts of each class in the 'event' column:")
print(event_counts)

# Print the maximum class number
max_class_number = df['event'].max()
print("\nMaximum class number:")
print(max_class_number)

# Determine the number of classes
print("\nExpected number of classes (from num_classes variable):", max_class_number + 1)

# Reset the class labels to be in the range [0, num_classes-1]
unique_classes = sorted(df['event'].unique())
class_mapping = {old_class: new_class for new_class, old_class in enumerate(unique_classes)}

df['event'] = df['event'].map(class_mapping)
print("\nDataFrame with reset class labels:")
print(df)

# Verify the new counts and max class number
new_event_counts = df['event'].value_counts()
new_max_class_number = df['event'].max()
num_classes = new_max_class_number + 1

print("\nNew counts of each class in the 'event' column:")
print(new_event_counts)
print("\nNew maximum class number:")
print(new_max_class_number)
print("\nNew expected number of classes (from num_classes variable):", new_max_class_number + 1)


Counts of each class in the 'event' column:
1     57650
0     57650
3     57650
2     57650
13    57650
12    57650
11    57650
10    57650
6     28825
9     28825
8     28825
7     28825
5     11523
Name: event, dtype: int64

Maximum class number:
13

Expected number of classes (from num_classes variable): 14

DataFrame with reset class labels:
                 A1           A2           A3           A4           A5  \
0       3549.790315  4533.538497  3619.665186  3077.291188 -1380.325575   
1       3551.227812  4534.850995  3622.540181  3077.322438 -1377.575581   
2       3556.727802  4539.850986  3629.040169  3081.978679 -1370.419344   
3       3557.915300  4541.225983  3628.540169  3083.197427 -1372.263090   
4       3553.352808  4535.757243  3622.477681  3079.572434 -1377.763080   
...             ...          ...          ...          ...          ...   
588018  -623.326974  2269.261431  2575.479615   285.733846   907.388947   
588019  -627.420717  2264.448940  2570.323375   281.

In [4]:
# Assuming df is defined and already includes an 'event' column
# Assuming 'event' column contains class labels
event_changes = df['event'].diff().ne(0)
change_indices = event_changes[event_changes].index.tolist()

# Calculate and print lengths between changes
lengths_between_changes = [change_indices[i] - change_indices[i-1] for i in range(1, len(change_indices))]

# Find the minimum cycle length where the label changes
min_cycle_length = min(lengths_between_changes)

print("Indices where the 'event' label changes:", change_indices)
print("Lengths between changes:", lengths_between_changes)
print(f"Minimum cycle length: {min_cycle_length}")

Indices where the 'event' label changes: [0, 3841, 4994, 6147, 8453, 9606, 10759, 11912, 13065, 14218, 15371, 17677, 18830, 19983, 21136, 23442, 24595, 25748, 26901, 28054, 29207, 30360, 31513, 33819, 34972, 36125, 38431, 39584, 40737, 41890, 45349, 46502, 48808, 49961, 52267, 55726, 56879, 58032, 60338, 61491, 62644, 63797, 66103, 67256, 68409, 69562, 70715, 71868, 73021, 74174, 75327, 77633, 78786, 81092, 82245, 83398, 85704, 86857, 88010, 89163, 90316, 91469, 92622, 94928, 96081, 98387, 101846, 102999, 104152, 106458, 107611, 108764, 109917, 112223, 113376, 114529, 115682, 116835, 117988, 119141, 120294, 121447, 123753, 124906, 127212, 128365, 129518, 131824, 132977, 134130, 135283, 136436, 137589, 138742, 141048, 142201, 143354, 144507, 146813, 149119, 150272, 151425, 152578, 153731, 157190, 158343, 160649, 161802, 162955, 164108, 165261, 166414, 168720, 169873, 172179, 173332, 175638, 176791, 179097, 180250, 181403, 182556, 183709, 184862, 186015, 187168, 188321, 190627, 192933, 1

In [5]:
segment_pairs = []
for start, end in zip(change_indices[:-1], change_indices[1:]):
    segment_length = end - start
    if segment_length >= min_cycle_length and segment_length % min_cycle_length == 0:
        # Normalize each sub-segment within the main segment
        for offset in range(0, segment_length, min_cycle_length):
            sub_start = start + offset
            sub_end = sub_start + min_cycle_length
            segment_pairs.append((sub_start, sub_end))
    else:
        irregular_num = segment_length//min_cycle_length
        # Normalize each sub-segment within the main segment
        for i in range(irregular_num):
            sub_start = start + i * min_cycle_length
            if i == irregular_num - 1:
                sub_end = end
            else:
                sub_end = sub_start + min_cycle_length
            segment_pairs.append((sub_start, sub_end))

In [6]:
import torch
seq_lengths = []#234440, 234440, 119140
tmp = df.event.diff(1).dropna()
seq_lengths = [0] + tmp[tmp!=0].index.to_list()
seq_lengths = torch.tensor(seq_lengths)

In [7]:

import torch
import pandas as pd
from sklearn.preprocessing import StandardScaler

def process_dataframe(df, segment_pairs, use_scale = False, include_diff=False, window_size=1):
    """
    Process the DataFrame by applying standard scaling and calculating differences, 
    returning the final DataFrame.
    
    Parameters:
    - df (pd.DataFrame): The input DataFrame.
    - segment_pairs (list of tuples): List of (start, end) index pairs.
    - window_size (int): The window size for difference calculation.
    - include_diff (bool): Whether to include difference calculations in the final DataFrame.

    Returns:
    - final_df (pd.DataFrame): The processed DataFrame.
    - new_segment_pairs (list of tuples): Updated segment pairs after processing.
    """
    df_tensor = torch.tensor(df.values, dtype=torch.float64).cuda()
    window_size = window_size if include_diff else 0
    
    df_list_diff = []
    df_list_y = []
    df_list_x = []
    cur_idx = 0
    new_segment_pairs = []

    for i in range(len(segment_pairs)):
        start, end = segment_pairs[i]

        # Extract segments and apply scaling
        x = df_tensor[start + window_size:end - window_size, :-1]
        y = df_tensor[start + window_size:end - window_size, -1:]

        x_np = x.cpu().numpy()
        scaler = StandardScaler()
        if use_scale:
            x_scaled = torch.tensor(scaler.fit_transform(x_np), dtype=torch.float64).cuda()
        else:
            x_scaled = x

        if include_diff:
            diff = df_tensor[start + 2 * window_size:end, :-1] - df_tensor[start:end - 2 * window_size, :-1]
            diff_np = diff.cpu().numpy()
            diff_scaled = torch.tensor(scaler.fit_transform(diff_np), dtype=torch.float64).cuda()
            df_list_diff.append(diff_scaled)

        add_len = (end - start - 2 * window_size)
        new_segment_pairs.append((cur_idx, cur_idx + add_len))
        cur_idx += add_len

        df_list_y.append(y)
        df_list_x.append(x_scaled)

    df_y_tensor = torch.cat(df_list_y, dim=0)
    df_x_tensor = torch.cat(df_list_x, dim=0)

    df_y = pd.DataFrame(df_y_tensor.cpu().numpy())
    df_x = pd.DataFrame(df_x_tensor.cpu().numpy())

    if include_diff:
        df_diff_tensor = torch.cat(df_list_diff, dim=0)
        df_diff = pd.DataFrame(df_diff_tensor.cpu().numpy())
        new_column_names = ['diff_' + name for name in df.columns[:-1]]
        df_diff.columns = new_column_names

    print("df_x shape:", df_x.shape)
    if include_diff:
        print("df_diff shape:", df_diff.shape)
    print("df_y shape:", df_y.shape)

    if include_diff:
        final_df = pd.concat([df_x, df_diff, df_y], axis=1)
        final_column_names = list(df.columns[:-1]) + new_column_names + [df.columns[-1]]
    else:
        final_df = pd.concat([df_x, df_y], axis=1)
        final_column_names = list(df.columns[:-1]) + [df.columns[-1]]

    final_df.columns = final_column_names

    return final_df, new_segment_pairs

# Example usage
window_size = 1

# Assuming df and segment_pairs are defined appropriately
df, segment_pairs = process_dataframe(df, segment_pairs, use_scale=True, include_diff=False, window_size=window_size)

df_x shape: (585717, 128)
df_y shape: (585717, 1)


In [8]:
num_features = df.shape[1] - 1

In [9]:
df[:5], df[-5:]

(         A1        A2        A3        A4        A5        A6        A7  \
 0 -2.489225 -2.277009 -2.514475 -2.188626 -0.665613 -1.346316 -0.058813   
 1 -2.276914 -2.102834 -2.132174 -2.184048 -0.334188 -0.958881  0.323031   
 2 -1.464595 -1.439312 -1.267841 -1.501895  0.528270 -0.018544  1.138970   
 3 -1.289208 -1.256844 -1.334328 -1.323345  0.306064 -0.010473  1.323863   
 4 -1.963064 -1.982571 -2.140485 -1.854417 -0.356786 -0.631983  0.745068   
 
          A8        A9       A10  ...       D24       D25       D26       D27  \
 0  0.063280  0.040420 -1.023329  ... -1.910036 -1.004736 -0.560572 -1.682052   
 1  0.508772  0.570520 -0.405698  ... -1.422319 -0.624488  0.064674 -0.879682   
 2  0.803630  0.869303 -0.453326  ... -1.417824  0.003363  1.106749  0.347471   
 3  0.560051  0.641200 -0.430280  ... -1.311066 -0.432153  0.141989 -0.116645   
 4  0.204299  0.242821 -0.649985  ... -1.159357 -0.268558  0.384020 -0.470631   
 
         D28       D29       D30       D31       D32  

In [10]:
import torch
from torch.utils.data import Dataset
import random

class EEG_Dataset(Dataset):
    def __init__(self, df, indices, max_window_size, num_classes, normalize = False):
        self.df = df
        self.indices = indices  # List of start indices
        self.max_window_size = max_window_size
        self.min_window_size = max_window_size // 2
        self.num_classes = num_classes  # Added num_classes as an instance variable
        self.normalize = normalize

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        # Randomly choose a window size between min_window_size and max_window_size
        window_size = random.randint(self.min_window_size, self.max_window_size)
        
        end_idx = min(start_idx + window_size, len(self.df))

        # Retrieve the sequence using the calculated indices
        seq = self.df.iloc[start_idx:end_idx]
        X, y = seq.values[:, :-1], seq.values[:, -1]
        
        # Convert to PyTorch tensors
        X = torch.tensor(X, dtype=torch.float32)

        # Normalize X
        if self.normalize:
            X = (X - X.mean(dim = -1, keepdim = True)) / (X.std(dim = -1, keepdim = True) + 1e-8)
        
        y = torch.tensor(y, dtype=torch.long)  # Ensure y is a tensor of type long
        y = torch.nn.functional.one_hot(y, num_classes=self.num_classes)  # Correct use with instance variable
        return X, y


In [11]:
from random import shuffle

# Assume 'df' is your DataFrame and 'event' is the column containing labels

# Generate indices without mixing segments
def generate_indices(input_df, input_pairs, max_window_size, test_size=0.2):
    length = len(input_pairs)
    train_length = int(length * (1- test_size))
    training_indices = []
    testing_indices = []
    for iter, (start, end) in enumerate(input_pairs):
        indices = training_indices if iter < train_length else testing_indices
        max_index = end - max_window_size  # Calculate the maximum starting index for this segment
        for i in range(start, max_index):
            # Check if all labels in the window are the same
            if len(input_df['event'][i:i + max_window_size].unique()) == 1:
                indices.append(i)
            else:
                print(f"Skipping index {i} due to multiple labels in window.")
    return training_indices, testing_indices

# Example usage
max_window_size = 128
shuffle(segment_pairs)  # Shuffle the indices to randomize the data order
train_indices, test_indices = generate_indices(df, segment_pairs, max_window_size)

shuffle(train_indices)  # Shuffle the indices to randomize the data order
shuffle(test_indices)  # Shuffle the indices to randomize the data order

# Assuming you have an EEG_Dataset class defined as before
trainset = EEG_Dataset(df=df, indices=train_indices, max_window_size=max_window_size, num_classes = num_classes, normalize=False)
testset = EEG_Dataset(df=df, indices=test_indices, max_window_size=max_window_size, num_classes = num_classes, normalize=False)

In [12]:
def check_overlap(train_indices, test_indices, window_size, trials = None):
    # Generate train and test window ranges
    train_windows = [(idx, idx + window_size) for idx in train_indices]
    test_windows = [(idx, idx + window_size) for idx in test_indices]
    
    # Sort train and test windows by their start indices
    train_windows.sort()
    test_windows.sort()

    overlaps = []
    train_idx = 0
    trials = len(test_windows) if trials is None else trials

    # Check for overlapping windows
    for idx, (t_start, t_end) in enumerate(test_windows):
        if idx > trials:
            break
        # Advance the train_idx to the relevant window range
        while train_idx < len(train_windows) and train_windows[train_idx][1] <= t_start:
            train_idx += 1

        # Check if the current test window overlaps with any train window
        for tr_start, tr_end in train_windows[train_idx:]:
            if tr_start < t_end and tr_end > t_start:  # Overlapping condition
                overlaps.append((tr_start, tr_end, t_start, t_end))
                print (f"Overlap found between train window {tr_start}-{tr_end} and test window {t_start}-{t_end}")
                break  # No need to check further once an overlap is found

    return overlaps

# Check for overlaps
overlaps = check_overlap(train_indices, test_indices, max_window_size, trials = 100)
print("Number of overlapping windows:", len(overlaps))
if overlaps:
    print("Example of overlapping windows:", overlaps[:5])  # Print the first 5 overlapping windows


Number of overlapping windows: 0


In [13]:
from tools.setting.data_config import DataConfig
from tools.setting.ml_params import MLParameters
from trainer_hub import TrainerHub
import torch

data_config = DataConfig(dataset_name = 'eeg-sub-01', task_type='multi_class_classification', obs_shape=[num_features], label_size=num_classes)

#  Set training configuration from the AlgorithmConfig class, returning them as a Namespace object.
ml_params = MLParameters(core_model = 'gpt', encoder_model = 'none')

# Set the device to GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# Initialize the TrainerHub class with the training configuration, data configuration, device, and use_print and use_wandb flags
trainer_hub = TrainerHub(ml_params, data_config, device, use_print=True, use_wandb=False) 

In [14]:
trainer_hub.train(trainset, testset)

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Iterations:   0%|          | 0/6504 [00:00<?, ?it/s]

[0/100][50/6504][Time 19.53]
Unified LR across all optimizers: 0.0001995308238189185
--------------------Training Metrics--------------------
Trainer:  gpt
Inf: 0.3163	Gen: 0.6635	Rec: 0.6063	E: 0.3732	R: 0.2592	P: 0.9532
--------------------Test Metrics------------------------
accuracy: 0.4594
precision: 0.0656
recall: 0.1400
f1_score: 0.0786

[0/100][100/6504][Time 18.39]
Unified LR across all optimizers: 0.00019907191565870155
--------------------Training Metrics--------------------
Trainer:  gpt
Inf: 0.1598	Gen: 0.4757	Rec: 0.4501	E: 0.1858	R: 0.1348	P: 0.7675
--------------------Test Metrics------------------------
accuracy: 0.5156
precision: 0.0773
recall: 0.1468
f1_score: 0.0882

[0/100][150/6504][Time 18.18]
Unified LR across all optimizers: 0.00019861406295796434
--------------------Training Metrics--------------------
Trainer:  gpt
Inf: 0.1320	Gen: 0.4314	Rec: 0.4141	E: 0.1495	R: 0.1153	P: 0.7132
--------------------Test Metrics------------------------
accuracy: 0.4453
precis

In [None]:
trainer_hub.test(testset)