In [1]:
import csv
import os
import pickle
import re
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pywt
import seaborn as sns
from keras.utils import to_categorical
from scipy.integrate import simps
from scipy.signal import butter, filtfilt, iirnotch, welch
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from Brain_to_Image import batch_csv as batch
from Brain_to_Image import helper_functions as hf
from Brain_to_Image.dataset_formats import (MDB2022_MNIST_EP_params,
                                            MDB2022_MNIST_IN_params,
                                            MDB2022_MNIST_MU_params,
                                            keys_MNIST_EP, keys_MNIST_IN,
                                            keys_MNIST_MU)

# Uncomment the magic below for interactive (ipympl) figures.
#%matplotlib widget
sns.set(font_scale=1.2)  # enlarge text in all seaborn/matplotlib figures
print(os.getcwd(), flush=True)  # the relative data paths below resolve against this directory

## ---- Dataset / split selection ----
# One of "MNIST_EP", "MNIST_IN", "MNIST_MU".
dataset = "MNIST_EP"
root_dir = f"Datasets/MindBigData MNIST of Brain Digits/{dataset}"

# True -> process the train files, False -> process the test files.
use_train_split = True
split_name = "train" if use_train_split else "test"
input_file = f"{split_name}_MindBigData2022_{dataset}.csv"
output_file = f"{split_name}_MindBigData2022_{dataset}.pkl"

label = 'digit_label'
## Sampling rate / epoch length per device:
##   MNIST_MU: 220 Hz, 440 samples; MNIST_EP: 128 Hz, 256 samples; MNIST_IN: 128 Hz, 256 samples
sample_rate = 128 if ("_EP" in dataset or "_IN" in dataset) else 220  # Hz

## ---- Filter parameters ----
# Notch filter: 50 Hz mains line noise (add 60 Hz, width 2, for recordings made on 60 Hz mains).
notch_freqs = [50]  # Hz
notch_widths = [1]  # Hz
# Band-pass edges; highcut must stay below the Nyquist frequency (sample_rate / 2).
lowcut = 0.4  # Hz
highcut = 60  # Hz (110 was used previously)

class_labels = list(range(10))  # digits 0-9
# Core channels used for the correlation threshold and as model inputs.
keys_ = ['EEGdata_T7', 'EEGdata_P7', 'EEGdata_T8', 'EEGdata_P8']

c:\Users\timta\Documents\Msc Notes\CMP9140-2324 Research Project


In [2]:
# function to Create windowed data based on window 32, overlap 4 = step size 28
def sliding_window_eeg(signal, window_size=32, overlap=4):
    """
    Apply a sliding window with overlap to a 2-second EEG signal.

    Parameters:
    signal (numpy.ndarray): 1D array of EEG signal data (256 samples)
    window_size (int): Size of each window (default: 32)
    overlap (int): Number of overlapping samples between windows (default: 4)

    Returns:
    numpy.ndarray: 2D array of windowed data
    """
    if len(signal) != 256:
        raise ValueError("Signal length must be 256 samples (2 seconds at 128Hz)")

    # Calculate the step size
    step = window_size - overlap

    # Calculate the number of windows
    num_windows = (len(signal) - window_size) // step + 1

    # Create an empty array to store the windowed data
    windowed_data = np.zeros((num_windows, window_size, 1))

    # Apply the sliding window
    for i in range(num_windows):
        start = i * step
        end = start + window_size
        windowed_data[i] = signal[start:end].reshape(window_size,1)
    return windowed_data

# w_data = sliding_window_eeg(df[df[label]==3].iloc[4]['EEGdata_AF3'],16,2)
# w_data.shape

In [3]:
## Two CAR-processed variants of the epoched data are available on disk:
##   "epoch_filtered_-_car_"     method 1: CAR subtracted from each channel
##   "epoch_filtered_corr_car_"  method 2: correlation with the CAR
prefix = ["epoch_filtered_-_car_", "epoch_filtered_corr_car_"]
source_path = f"{root_dir}/{prefix[0]}{output_file}"
print(f"** reading file {source_path}")
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
df_data = pd.read_pickle(source_path)
df_data.head()

** reading file Datasets/MindBigData MNIST of Brain Digits/MNIST_EP/epoch_filtered_-car_train_MindBigData2022_MNIST_EP.pkl


Unnamed: 0,digit_label,EEGdata_AF3,EEGdata_AF4,EEGdata_F7,EEGdata_F8,EEGdata_F3,EEGdata_F4,EEGdata_FC5,EEGdata_FC6,EEGdata_T7,...,EEGdata_FC6_corr,EEGdata_T7_corr,EEGdata_T8_corr,EEGdata_P7_corr,EEGdata_P8_corr,EEGdata_O1_corr,EEGdata_O2_corr,erp,corr_mean_core,corr_mean_all
16632,0,"[3.1210577445229046, 2.7629549594562386, 2.985...","[3.9544539358291426, 4.554076329880816, 5.9097...","[1.473359477646282, 0.9879922982243392, 0.1401...","[3.227326303261042, 3.7673950479124527, 5.1033...","[-0.21202161529048258, 1.741510653316555, 3.68...","[0.041984176722513464, 0.10421549012801781, 0....","[2.2976471548224033, -0.06576887594309166, -1....","[1.5529214242222467, -0.6857059599702771, -1.6...","[1.964180438045, -2.69910706404519, -4.8529714...",...,-0.224438,-0.282554,-0.028123,0.486802,0.204503,0.217455,0.585916,"[-1.8244634238813786, -2.4216427286282842, -2....",0.095157,-0.005996
20369,0,"[2.970862382127992, 5.411682218878089, 7.01719...","[2.909096984268401, 6.267995800301906, 9.27011...","[2.3082652995296176, 7.1643073544053335, 10.75...","[2.666189015477303, 7.483431058281065, 10.7173...","[2.259999377354528, 5.997679591100788, 8.51031...","[-1.3546460304167924, 2.422972581907988, 5.510...","[2.775382059807958, 3.545501984521303, 4.17997...","[4.321833281151439, 4.905600996675127, 5.69278...","[3.442581299527143, 5.624003476048447, 7.28532...",...,-0.275237,-0.18127,-0.042777,-0.155357,-0.394998,0.363705,0.002971,"[-4.73476099062582, -1.9169430190940515, 0.235...",-0.1936,-0.050898
27537,0,"[0.23843709029029636, 3.557810413048021, 6.597...","[-1.4654737776575768, 0.3521044029136666, 2.80...","[0.2326959801031503, 4.55761189567845, 8.21898...","[1.538302248208206, 5.988011877316115, 9.04723...","[4.227173638075145, 10.63426382470813, 15.7127...","[-0.8428305787088188, 1.5121656851876883, 4.33...","[-0.9450612014444033, 3.7751896595246714, 7.89...","[-2.359073654343471, -0.17812931700706658, 2.0...","[-5.292256608037834, -1.3583088445691525, 2.13...",...,-0.633733,0.286509,0.04474,0.225461,0.428751,0.524308,0.569438,"[-2.321485072425892, 2.193866205349398, 5.9848...",0.246365,0.075906
38834,0,"[-2.515533471868592, -3.3763837409928015, -3.6...","[-0.37541259061813426, 0.8670186927997108, 1.9...","[-2.5268652528548756, -2.554462688990466, -2.3...","[-3.6743095131556665, -3.8340951703520814, -3....","[-2.40698026671273, -3.717175520958679, -4.280...","[-0.014006392128633416, -0.601270812369201, -0...","[0.10472759142205357, -1.6370276153459846, -2....","[-1.204481549510489, -1.1791039280567466, -1.4...","[-2.1123620020102973, -1.0601322205135584, -0....",...,-0.055072,0.308697,0.270705,0.492663,0.329835,0.335419,0.066382,"[-0.9622732345660193, -0.9902111290720838, -0....",0.350475,0.05846
45025,0,"[-2.7564534253264172, 0.5916952977506629, 2.13...","[-6.812826381821251, -6.20446425943703, -6.696...","[-2.8630667675293484, -3.639419880555958, -4.7...","[-2.052196674618404, -3.9249087206162576, -4.9...","[-0.7584009274275729, -0.022608853424369357, 0...","[0.25783156634073595, 1.0128318883412546, 1.32...","[-1.7028389195632825, -1.1927582295819874, -0....","[-2.855036257116565, -5.5000208219332665, -7.2...","[-2.746731630402801, -2.369860212769909, -2.24...",...,0.03136,0.078865,0.383203,-0.044663,0.038013,0.544914,-0.059081,"[1.031146171715275, 1.3585180172688673, 1.4232...",0.113854,0.013455


In [26]:
## Work on a deep copy so df_data keeps the values exactly as loaded from disk.
df_copy = df_data.copy(deep=True)


#### You can skip the next cell if the loaded data already contains the correlation/CAR columns. Computing them typically takes about 30 minutes.

In [None]:
## Only needed when the loaded pickle does not already contain the per-channel
## correlation/CAR columns (the row-by-row version of this cell took ~30 minutes).

# Channel columns for the selected dataset (previously always keys_MNIST_EP).
eeg_keys = {
    "MNIST_EP": keys_MNIST_EP,
    "MNIST_IN": keys_MNIST_IN,
    "MNIST_MU": keys_MNIST_MU,
}[dataset]
# Baseline window = first 125 ms of the epoch (16 samples at 128 Hz).
baseline_samples = int(round(0.125 * sample_rate))


def _car_correlations(channels, n_baseline):
    """Correlate every channel of one trial with the trial's common average reference.

    Parameters:
    channels (sequence): one equal-length 1D signal per EEG channel.
    n_baseline (int): number of leading samples used for baseline correction.

    Returns:
    tuple: (list of per-channel Pearson correlations, baseline-corrected CAR array).

    The CAR is the per-sample mean over all channels, minus its mean over the first
    n_baseline samples. Method used in the project (method 1): correlate
    (channel - CAR) with the CAR. The alternative (method 2) correlates the raw
    channel with the CAR: np.corrcoef(channel, car_corrected)[0, 1].
    """
    stacked = np.stack(channels)
    car = stacked.mean(axis=0)
    car_corrected = car - car[:n_baseline].mean()
    correlations = [np.corrcoef(channel - car_corrected, car_corrected)[0, 1]
                    for channel in stacked]
    return correlations, car_corrected


trial_results = [
    _car_correlations(channels, baseline_samples)
    for channels in tqdm(df_data[eeg_keys].itertuples(index=False, name=None),
                         total=len(df_data), desc="CAR correlation")
]

# Write whole columns at once instead of assigning the frame one row at a time.
all_corr_keys = [f"{key}_corr" for key in eeg_keys]
corr_matrix = np.array([corrs for corrs, _ in trial_results], dtype=float)
corr_matrix = corr_matrix.reshape(len(trial_results), len(eeg_keys))  # stays 2D when empty
for col_pos, col_name in enumerate(all_corr_keys):
    df_copy[col_name] = corr_matrix[:, col_pos]
# Each 'erp' cell holds the baseline-corrected CAR array of its trial.
df_copy["erp"] = pd.Series([erp for _, erp in trial_results], index=df_copy.index, dtype=object)

df_copy.info()
# Mean correlation over the core channels (keys_) and over all channels.
core_corr_keys = [f"{key}_corr" for key in keys_]
df_copy['corr_mean_core'] = df_copy[core_corr_keys].mean(axis=1)
df_copy['corr_mean_all'] = df_copy[all_corr_keys].mean(axis=1)
df_copy.head()

In [None]:
## Number of trials per digit class before any correlation filtering.
df_copy[label].value_counts()


In [None]:


#df_copy.dropna(axis=1, inplace=True)
# df_copy.info()
# print(df_copy.columns)
# df_copy.head()

In [None]:
## Keep only trials whose mean core-channel CAR correlation exceeds `factor`,
## then subsample each digit class to `fraction` of its trials (1 = keep all, shuffled).
fraction = 1
## Threshold depends on the CAR method used: typically 0.9 or above for method 1
## (CAR subtraction), around 0.2 for method 2 (correlation with the CAR).
factor = 0.925  # 0.2
high_corr_df = df_copy[df_copy['corr_mean_core'] > factor]
# Seeded so the selection (and therefore the later train/test split) is reproducible.
# GroupBy.sample also handles the "no trial passes the threshold" case gracefully.
sampled_df = high_corr_df.groupby(label).sample(frac=fraction, random_state=42)
sampled_indexes = sampled_df.index.tolist()
print(sampled_df[label].value_counts())
print(sampled_df[label].value_counts().sum())

In [None]:

## Build the model inputs: every core channel (keys_) of every selected trial is
## windowed to (num_windows, window_size, 1) and becomes one sample labelled with
## the trial's digit (one-hot).
feature_data = []
label_data = []
trial_ids = []  # index of the source trial for every sample
for class_label in class_labels:
    class_df = sampled_df[sampled_df[label] == class_label]
    one_hot = to_categorical(int(class_label), num_classes=len(class_labels))
    for idx, row in tqdm(class_df.iterrows(), total=len(class_df), desc=f"digit {class_label}"):
        for key in keys_:
            feature_data.append(sliding_window_eeg(row[key]))
            label_data.append(one_hot)
            trial_ids.append(idx)

train_data = np.array(feature_data)
labels = np.array(label_data).astype(np.uint8)
trial_ids = np.array(trial_ids)

print(train_data.shape)
print(labels.shape)

## Split by trial rather than by sample: the channels of one trial are strongly
## correlated (trials were selected for high CAR correlation), so letting them land
## on both sides of the split would leak test information into training.
test_trials = train_test_split(np.unique(trial_ids), test_size=0.1, random_state=42)[1]
is_test = np.isin(trial_ids, test_trials)
# Shuffle inside each split so the arrays are not ordered by class
# (a sample-level train_test_split used to do this implicitly).
rng = np.random.default_rng(42)
train_pos = rng.permutation(np.flatnonzero(~is_test))
test_pos = rng.permutation(np.flatnonzero(is_test))
x_train, x_test = train_data[train_pos], train_data[test_pos]
y_train, y_test = labels[train_pos], labels[test_pos]

In [None]:
## Persist the train/test arrays for the model notebooks.
# NOTE(review): "example_data" has no trailing "_", so the file is written as
# "example_data" + output_file (e.g. "example_datatrain_..."); kept unchanged in case
# downstream code loads that exact name.
prefix = ["example_data"]
export_path = f"{root_dir}/{prefix[0]}{output_file}"
print(f"writing {export_path}")
# Alternative (whole set as test data): {'x_test': train_data, 'y_test': labels}
data_out = dict(x_train=x_train, x_test=x_test, y_train=y_train, y_test=y_test)
with open(export_path, 'wb') as f:
    pickle.dump(data_out, f)