In [20]:
import os
import pandas as pd 
import numpy as np
from scipy.signal import iirnotch, butter, filtfilt, welch
from scipy.stats import kurtosis
from mne.time_frequency import tfr_array_morlet
import mne
from library import *
from tqdm import tqdm
from autoreject import get_rejection_threshold

# prev work - 77%ml in r with: 
# notch - 60
# bp - 1-40
# FREQUENCY_BANDS = {'delta': (1, 4), 'theta': (4, 8), 'alpha': (7.5, 13), 'beta': (13, 30), 'gamma': (30, 44)}
# spike removal
# avg ref

def separated_puzzles(df):    
    df['block_id'] = df['value'].str.contains('puzzle_loaded').cumsum()

    skipped = 0
    kept = 0    
    pzls = []
    for blockid, block in df.groupby('block_id'):     
        b = block.copy().reset_index(drop=True)

        startid = b['value'].str.contains('puzzle_loaded').idxmax()
        endid = b['value'].str.contains('puzzle_finished').idxmax()
        
        # trial loaded but before first puzzle, or no puzzle finished event
        if endid == 0 or endid == None: 
            continue        
        
        b = b.iloc[:endid + 1, :]

        puzzle_data = b.loc[endid, 'value']
        puzzle_data = {p.split(':')[0].strip(): p.split(':')[1].strip() for p in puzzle_data.split(';')}

        if 'RESULT' not in puzzle_data or puzzle_data['RESULT'] == 'timeout':
            skipped += 1
            continue  

        
        # adds all the puzzle info; may be useful later.
        # for key in puzzle_data.keys():
        #     b[key] = puzzle_data[key]

        # QUIRK: elo bin is in the end string, but ACTUAL elo value is only in the start string. 
        start_data = b.loc[startid, 'value']
        start_data = {p.split(':')[0].strip(): p.split(':')[1].strip() for p in start_data.split(';')}
        puzzlestr = start_data['PUZZLESTR']
        b['elo'] = puzzlestr.split(',')[3]

        b['elo_bin'] = puzzle_data['ELO']
        b['solved'] = puzzle_data['RESULT'] == 'solved'
        b = b.drop(columns=['value'])

        pzls.append(b)        
        kept += 1

    print(f'Kept {kept} puzzles, skipped {skipped} puzzles')
    return pzls

df = pd.read_csv(DBFNAME, dtype={'pid':str})
df['value'] = df['value'].astype(str)

print(df['pid'].unique())

all_raw = []
all_psd = []
for p in tqdm(df['pid'].unique()):
   # print(p)
    pdf = df[df['pid'] == p].copy().reset_index(drop=True)    

    pdf = mne_filt(pdf)

    raw_dfs = separated_puzzles(pdf)
    #filt_dfs = [mne_psd_morlet(raw_df) for raw_df in raw_dfs]

    # total_neg = 0
    # total = 0
    # for filt_df in filt_dfs:
    #     # print number of cells that contain a negative value 
    #     is_negative = filt_df[[key for key in filt_df.keys() if key not in ['timestamp', 'pid', 'elo', 'block_id',  'solved']]] < 0
    #     total_neg += is_negative.sum().sum()
    #     total += filt_df.shape[0] * filt_df.shape[1]
    # print("-------------------------------------------------------------")
    # print(f"For p {p} - {total_neg / total:.3f}% negative: {total_neg} negative cells out of {total} total cells.")
    # print("-------------------------------------------------------------")
    
    all_raw.extend(raw_dfs)
    #all_psd.extend(filt_dfs)    

#psd_df = pd.concat(all_psd).reset_index(drop=True)
raw_df = pd.concat(all_raw).reset_index(drop=True)

#psd_df.to_csv(DBFNAME.split('.csv')[0] + '_psd_morlet' + FILTNAME, index=False)
raw_df.to_csv(DBFNAME.split('.csv')[0] + '_raw_' + FILTNAME, index=False)

['2a7ba24a' '44dd5d31' '92511e53' '1bf79614' 'cf88d785' '168dd1ac'
 '7cc59678' '16230396' '5c1034cf' '7848bec2']


  0%|          | 0/10 [00:00<?, ?it/s]

Creating RawArray with float64 data, n_channels=4, n_times=373032
    Range : 0 ... 373031 =      0.000 ...  1457.152 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 8449 samples (33.004 s)

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband a

  raw.set_annotations(annotations)
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   4 | elapsed:    0.0s remaining:    0.0s


Not setting metadata


[Parallel(n_jobs=-1)]: Done   4 out of   4 | elapsed:    0.5s finished
  0%|          | 0/10 [00:01<?, ?it/s]

(6, 4, 50, 257)





ValueError: Must pass 2-d input. shape=(1028, 50, 6)

In [11]:
def detect_spikes(data, window_size=5, threshold_multiplier=4):
    data_mavg = data.rolling(window=window_size, center=True).mean()
    residuals = data - data_mavg
    threshold = threshold_multiplier * residuals.std()
    spikes = np.abs(residuals) > threshold
    return spikes

In [12]:
def mne_filt(df):    
    data = df[CH_NAMES].T.values           
    
    info = mne.create_info(ch_names=CH_NAMES, sfreq=SAMPLING_RATE, ch_types='eeg')
    raw = mne.io.RawArray(data, info)

    raw.filter(l_freq=.1, h_freq=50, filter_length='auto') # from 1-40.; 77%ML in r. 
    raw.notch_filter(freqs=60, filter_length='auto')
   
    # Re-reference the data to the average of all channels
    raw.set_eeg_reference(ref_channels='average', projection=False)
    
    df_transformed = pd.DataFrame(raw.get_data().T, columns=CH_NAMES)          
    df_transformed['timestamp'] = df['timestamp'] 
    df_transformed['pid'] = df['pid']  
    df_transformed['value'] = df['value']  
    assert(df_transformed.shape == df.shape)

    return df_transformed    
   

In [13]:
def mne_psd(df):

    data = df[CH_NAMES].values.T
    
    # Parameters
    window_length = 256  # Hamming window size
    overlap = 0.9        # Overlap percentage
    step_size = int(window_length * (1 - overlap)) #// 10 # 128-hz new output. //10 -> 128hz; now ~12.8hz
    nfft = 256           # Number of FFT points
    
    # Initialize an empty list to store the results
    power_data = []

    # Slide the window over time and calculate the power values for each frequency band
    for start_idx in range(0, data.shape[1] - window_length, step_size):
        window_data = data[:, start_idx:start_idx + window_length]

        # Calculate power spectral density using Welch's method
        freqs, psd = welch(window_data, fs=SAMPLING_RATE, window='hamming', 
                           nperseg=nfft, noverlap=int(nfft * overlap), scaling='density')
        
        row_data = {}
        for band, (low_freq, high_freq) in FREQUENCY_BANDS.items():
            band_mask = np.logical_and(freqs >= low_freq, freqs <= high_freq)
            band_psd = psd[:, band_mask]
            band_power = np.log(np.sum(band_psd, axis=1)) / np.log(10) # -> for output in bels

            for ch_idx, ch_name in enumerate(CH_NAMES):
                row_data[f'{band}.{ch_name}'] = band_power[ch_idx]

        power_data.append(row_data)

    psd_df = pd.DataFrame(power_data)

    # interpolate timestamps to the psd transformed data
    start_time = df['timestamp'].iloc[0]
    end_time = df['timestamp'].iloc[-1]
    num_datapoints = len(power_data)
    psd_df['timestamp'] = np.linspace(start_time, end_time, num_datapoints)
    
    # these are all the same for each puzzle, so all good.
    psd_df['pid'] = df['pid']
    psd_df['elo'] = df['elo']
    psd_df['elo_bin'] = df['elo_bin']
    psd_df['block_id'] = df['block_id']
    psd_df['solved'] = df['solved']

    return psd_df

In [19]:

def mne_psd_morlet(df):  
    """
    df is one puzzle's worth of data. 
    https://mne.tools/stable/auto_tutorials/time-freq/20_sensors_time_frequency.html#time-frequency-analysis-power-and-inter-trial-coherence
    """  
    data = df[CH_NAMES].values.T
    
    timestamps = df['timestamp'].values
    timestamps -= timestamps[0]
    time_diffs = np.diff(timestamps) / 1e3 # in seconds

    annotations = mne.Annotations(onset=timestamps[:-1],
                                  duration=time_diffs,
                                  description=['EDGE'] * len(time_diffs))

    info = mne.create_info(ch_names=CH_NAMES, sfreq=SAMPLING_RATE, ch_types='eeg')
    raw = mne.io.RawArray(data, info)
    raw.set_annotations(annotations)

    events = mne.make_fixed_length_events(raw, duration=1)
    epochs = mne.Epochs(raw, events, tmin=0, tmax=0.999, baseline=None)

    freqs = np.arange(1, 51)  # Frequencies from 1 to 50 Hz
    n_cycles = freqs / 2. 

    tfr = mne.time_frequency.tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, 
                                        average=False, return_itc=False, n_jobs=-1, 
                                        use_fft=True)

    # convert to psd
    power = np.abs(tfr.data)**2

    # convert to log scale (bels)
    log_power = 10 * np.log10(power)
    print(log_power.shape)

    # band_power_data = []
    # for band, (low_freq, high_freq) in FREQUENCY_BANDS.items():
    #     band_mask = np.logical_and(freqs >= low_freq, freqs <= high_freq)
    #     band_psd = psd[:, :, band_mask]
    #     band_power = np.log10(np.sum(band_psd, axis=2))
    #     band_power_data.append(band_power)

    # result_data = []

    # for epoch_idx in range(band_power_data[0].shape[0]):
    #     epoch_data = {}
    #     for band_idx, (band, (low_freq, high_freq)) in enumerate(FREQUENCY_BANDS.items()):
    #         for ch_idx, ch_name in enumerate(CH_NAMES):
    #             epoch_data[f'{band}.{ch_name}'] = band_power_data[band_idx][epoch_idx, ch_idx]
    #     result_data.append(epoch_data)

#    psd_df = pd.DataFrame(result_data)   
 
    log_power = log_power.transpose(0, 2, 1, 3)  # Now the shape is, for example (6, 50, 4, 257)

    # Reshape the power array to 2D (frequencies x everything else)
    power_2d = log_power.reshape((log_power.shape[0], log_power.shape[1], -1)).T

    # Create frequency labels for the columns
    freq_labels = [f'freq_{freq}Hz' for freq in freqs]

    # Create a DataFrame from the power array
    df_power = pd.DataFrame(power_2d, columns=freq_labels)

    # Create arrays representing the epoch, time, and channel for each row
    epochs_array   = np.repeat(np.arange(power.shape[0]), power.shape[1]*power.shape[2]*power.shape[3])
    ch_names_array = np.tile(np.repeat(CH_NAMES, power.shape[2]*power.shape[3]), power.shape[0]*power.shape[1])
    times_array    = np.tile(np.repeat(tfr.times, len(CH_NAMES)), power.shape[0]*power.shape[1])

    # Add these as columns to the DataFrame
    df_power['epoch']   = epochs_array
    df_power['time']    = times_array
    df_power['channel'] = ch_names_array

    # Reorder the columns so the epoch, time, and channel are first
    df_power = df_power[['epoch', 'time', 'channel'] + freq_labels]


    # # interpolate timestamps to the psd transformed data    
    # start_time = df['timestamp'].iloc[0]
    # end_time = df['timestamp'].iloc[-1]
    # num_datapoints = psd_df.shape[0]
    # psd_df['timestamp'] = np.linspace(start_time, end_time, num_datapoints)
    
    # these are all the same for each puzzle, so all good.
    psd_df['pid']      = df['pid']
    psd_df['elo']      = df['elo']
    psd_df['elo_bin']  = df['elo_bin']
    psd_df['block_id'] = df['block_id']
    psd_df['solved']   = df['solved']
    
    return psd_df

In [15]:
# https://neuraldatascience.io/7-eeg/erp_artifacts.html
def mne_ica(df):
    data = df[CH_NAMES].values.T
    info = mne.create_info(ch_names=CH_NAMES, sfreq=SAMPLING_RATE, ch_types='eeg')
    raw = mne.io.RawArray(data, info)

    print(mne.channels.make_standard_montage('standard_1020', head_size='auto'))

    ch_coords = np.array([
                            [-0.2852, 0.8777, -0.3826], #[-0.5, -0.5, 0],  # TP9
                            [0.8090, 0.5878, 0.0000], #[-0.5,  0.5, 0],  # AF7
                            [0.8090, -0.5878, 0.0000],#[ 0.5,  0.5, 0],  # AF8
                            [-0.2853, -0.8777, -0.3826]#[ 0.5, -0.5, 0]   # TP10
                        ])

    # Create a DigMontage object
    dig_montage = mne.channels.make_dig_montage(
        ch_pos=dict(zip(CH_NAMES, ch_coords)),
        coord_frame='head'
    )
   
    raw.set_montage(dig_montage)
    
    ica_low_cut = 1.0       # For ICA, we filter out more low-frequency power
    hi_cut  = 30

    raw_ica = raw.copy().filter(ica_low_cut, hi_cut)

    tstep = 1.0
    events_ica = mne.make_fixed_length_events(raw_ica, duration=tstep)
    epochs_ica = mne.Epochs(raw_ica, events_ica,
                            tmin=0.0, tmax=tstep,
                            baseline=None,
                            preload=True)

    reject = get_rejection_threshold(epochs_ica)
    reject

    ica_n_components = .99   # Specify n_components as a decimal to set % explained variance

    ica = mne.preprocessing.ICA(n_components=ica_n_components, random_state=RANDOM_STATE)
    ica.fit(epochs_ica, reject=reject, tstep=tstep )

    ica.plot_properties(epochs_ica, picks=range(0, ica.n_components_), psd_args={'fmax': hi_cut});


    ica_z_thresh = 1.96 
    eog_indices, eog_scores = ica.find_bads_eog(raw_ica, 
                                                ch_name=['probe-1', 'probe-2'], 
                                                threshold=ica_z_thresh)
    ica.exclude = eog_indices

    ica.plot_scores(eog_scores)
    

    # # Get the ICA component sources
    # sources = ica.get_sources(raw)

    # # Compute the kurtosis scores for each component
    # kurt_scores = kurtosis(sources.get_data(), axis=1)    
    # eog_indices = np.where(kurt_scores > 7)[0]
        
    # raw_clean = ica.apply(raw.copy(), exclude=eog_indices)

    # data_clean = raw_clean.get_data()
    # data_clean.plot(title="cleaned raw data after ICA", scalings="auto") #,start=0, duration=10)


In [16]:
df = pd.read_csv(DBFNAME, dtype={'pid':str})
print(df['pid'].unique())



['2a7ba24a' '44dd5d31' '92511e53' '1bf79614' 'cf88d785' '168dd1ac'
 '7cc59678' '16230396' '5c1034cf' '7848bec2']
