Import the base libraries:

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import hilbert, butter, filtfilt, find_peaks
from scipy.ndimage import generic_filter

## Step 1: Use the Bhattacharyya distance to identify potential task-switching points.

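For each time point (stepping in 40 ms intervals), compute the Bhattacharyya distance between the power spectra of the 500 ms window before it and the 500 ms window after it, and check whether this distance exceeds the mean plus 3 SD of the distances over the previous 10 seconds.  
Save every time point that exceeds this threshold into a list.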

In [None]:
import numpy as np
import pandas as pd
from scipy.signal import welch

#Calculate the bhattacharyya distance
def bhattacharyya_distance(psd1, psd2):
    """
    Calculate the Bhattacharyya distance between two discrete distributions (power spectra).

    Args:
        psd1, psd2: Non-negative power spectral density arrays of the same length.

    Returns:
        float: -ln(BC), where BC = sum(sqrt(p * q)) is the Bhattacharyya coefficient of
        the two spectra after each is normalized to sum to 1. 0 means the spectra have
        the same shape; larger values mean more dissimilar spectra.
    """
    # Normalize each spectrum to a probability distribution (epsilon avoids 0/0)
    p = psd1 / (np.sum(psd1) + 1e-10)
    q = psd2 / (np.sum(psd2) + 1e-10)
    # bc = 1 means identical distributions, bc = 0 means no overlap at all
    bc = np.sum(np.sqrt(p * q))
    # Epsilon keeps the log finite when the spectra do not overlap
    return -np.log(bc + 1e-10)


def detect_task_switch_by_bhattacharyya(eeg_data_frame, gfp, sfreq=256):
    """
    Find candidate task-switch time points from abrupt spectral changes.

    Every 40 ms, the channel-averaged Welch PSD of the 500 ms window before a time
    point is compared with that of the 500 ms window after it using the Bhattacharyya
    distance. A time point becomes a candidate when its distance exceeds
    mean + 3 * SD of the distances computed over the preceding 10 s. Time points in
    the first 10 s only serve as baseline and are never reported.

    Args:
        eeg_data_frame (pandas.DataFrame): EEG data; rows are time points (samples),
            columns are EEG channels.
        gfp (numpy.ndarray): GFP of every time point. Not used by this function; kept
            for interface compatibility.
        sfreq (int, optional): Sampling frequency in Hz. Defaults to 256.

    Returns:
        list[int]: Sample indices of candidate task-switch points, in increasing order.
    """
    entire_time_len = len(eeg_data_frame)

    # 40 ms step between tested time points (at least one sample so range() is valid)
    step_len = max(1, int(0.040 * sfreq))
    # 500 ms analysis window on each side of the tested time point
    window_len = int(0.5 * sfreq)
    # 10 s baseline used for the mean + 3SD threshold
    baseline_len = int(10 * sfreq)

    # (channel, time) layout so each window is a simple column slice
    eeg_data_numpy = eeg_data_frame.values.T

    positions = []
    scores = []
    for i in range(window_len, entire_time_len - window_len, step_len):
        front_window = eeg_data_numpy[:, i - window_len:i]
        back_window = eeg_data_numpy[:, i:i + window_len]

        # One Welch segment per window; PSDs are averaged over channels
        _, psd_front_window = welch(front_window, fs=sfreq, nperseg=window_len)
        _, psd_back_window = welch(back_window, fs=sfreq, nperseg=window_len)

        distance = bhattacharyya_distance(np.mean(psd_front_window, axis=0),
                                          np.mean(psd_back_window, axis=0))
        positions.append(i)
        scores.append(distance)

    positions = np.asarray(positions, dtype=int)
    scores = np.asarray(scores, dtype=float)

    candidate_time_points = []
    for idx, (time_point, bd) in enumerate(zip(positions, scores)):
        # The first 10 s are baseline only; assume no task switch happens there
        if time_point < baseline_len:
            continue

        # positions is sorted, so the scores at positions in
        # [time_point - baseline_len, time_point) form the contiguous slice [lo, idx).
        # Binary search replaces a full rescan per time point (previously O(n^2)).
        lo = np.searchsorted(positions, time_point - baseline_len, side="left")
        front_bd_scores = scores[lo:idx]

        # Should not happen (baseline_len > step_len), kept as a safety net
        if front_bd_scores.size == 0:
            print(f"Error: No scores found for time point {time_point}")
            continue

        threshold = np.mean(front_bd_scores) + 3 * np.std(front_bd_scores)
        if bd > threshold:
            candidate_time_points.append(int(time_point))
    return candidate_time_points




In [None]:
#Calculate the average power for a specific frequency(alpha/theta) band using the Welch method.
def get_multi_channel_band_power(eeg_data, sfreq, band):
    """
    Mean Welch power of ``eeg_data`` inside ``band``, averaged over channels and bins.

    Args:
        eeg_data (numpy.ndarray): 2-D array shaped (channel, time).
        sfreq (float): Sampling frequency in Hz.
        band (tuple): (low, high) frequency limits in Hz, both inclusive.

    Returns:
        float: Mean PSD over all channels and all frequency bins inside the band.
        The frequency resolution is sfreq / n_samples (about 5 Hz for a 200 ms window
        at 256 Hz), so short windows yield only one or two bins per band. If no bin
        falls in the band, the mean of an empty array (NaN) is returned.
    """
    # A single Welch segment spanning the whole window gives the finest resolution
    f, psd = welch(eeg_data, fs=sfreq, nperseg=eeg_data.shape[1], axis=-1)
    idx = np.logical_and(f >= band[0], f <= band[1])
    band_psd = psd[:, idx]
    return np.mean(band_psd)


#Method 1: No detail(Use 200ms as the base)
def alpha_theta_check(candidate_time_points, eeg_data_frame, sfreq=256):
    """
    Verify candidate task-switch points by an opposite alpha/theta power change.

    For each candidate, a pair of adjacent 200 ms windows (pre, post) slides in 40 ms
    steps over the range 750 ms before to 750 ms after the candidate (1500 ms total).
    A step counts as a change when alpha (8-13 Hz) power changes by at least 30% and
    theta (4-8 Hz) power by at least 20%, in opposite directions, from pre to post.

    Args:
        candidate_time_points (list): Candidate task-switch sample indices.
        eeg_data_frame (pandas.DataFrame): EEG data; rows are time points, columns are
            channels.
        sfreq (int, optional): Sampling frequency in Hz. Defaults to 256.

    Returns:
        list[tuple]: One (change_start, change_end) pair per verified candidate. The
        segment always contains the candidate point and is widened to cover every
        detected pre-window start and pre/post boundary. Candidates whose scope is
        too short for a full window pair are skipped.
    """
    alpha_band = (8, 13)
    theta_band = (4, 8)

    verified_segments = []

    # 200 ms windows, +/-750 ms search scope, 40 ms step (at least one sample)
    check_window = int(0.2 * sfreq)
    half_range = int(0.75 * sfreq)
    step_len = max(1, int(0.040 * sfreq))

    # (channel, time) layout so each window is a simple column slice
    eeg_data_numpy = eeg_data_frame.values.T
    n_samples = eeg_data_numpy.shape[1]

    for time_point in candidate_time_points:
        start_scope = max(0, time_point - half_range)
        end_scope = min(n_samples, time_point + half_range)

        # A full pre + post window pair must fit; skip this candidate (do not abort
        # the remaining ones) when it does not.
        if end_scope - start_scope < check_window * 2:
            continue

        found_change = False
        change_start = time_point
        change_end = time_point

        # Stop so the post window ends at most at end_scope (+1 includes that last
        # pair); otherwise post windows are truncated at the scope/data edge.
        for time in range(start_scope, end_scope - 2 * check_window + 1, step_len):
            boundary = time + check_window
            pre_window = eeg_data_numpy[:, time:boundary]
            post_window = eeg_data_numpy[:, boundary:boundary + check_window]

            alpha_pre = get_multi_channel_band_power(pre_window, sfreq, alpha_band)
            alpha_post = get_multi_channel_band_power(post_window, sfreq, alpha_band)
            theta_pre = get_multi_channel_band_power(pre_window, sfreq, theta_band)
            theta_post = get_multi_channel_band_power(post_window, sfreq, theta_band)

            # Relative change from pre to post (epsilon avoids division by zero)
            alpha_change_ratio = (alpha_post - alpha_pre) / (alpha_pre + 1e-10)
            theta_change_ratio = (theta_post - theta_pre) / (theta_pre + 1e-10)

            large_enough = abs(alpha_change_ratio) >= 0.3 and abs(theta_change_ratio) >= 0.2
            opposite = ((alpha_change_ratio < 0 and theta_change_ratio > 0)
                        or (alpha_change_ratio > 0 and theta_change_ratio < 0))
            if large_enough and opposite:
                found_change = True
                change_start = min(change_start, time)
                change_end = max(change_end, boundary)

        if found_change:
            verified_segments.append((change_start, change_end))

    return verified_segments
            


#Method 2: With detail, using envelope

def butter_bandpass_filter(data, low, high, sfreq=256, order=4):
    """
    Zero-phase Butterworth band-pass filter.

    Args:
        data (numpy.ndarray): Signal(s) to filter; filtering runs along axis 0 (time).
        low, high (float): Pass-band edges in Hz (need 0 < low < high < sfreq / 2).
        sfreq (float, optional): Sampling frequency in Hz. Defaults to 256.
        order (int, optional): Butterworth order. Defaults to 4.

    Returns:
        numpy.ndarray: Filtered signal with the same shape as ``data``.
    """
    nyq = 0.5 * sfreq
    b, a = butter(order, [low / nyq, high / nyq], btype='band')
    # filtfilt runs the filter forward and backward, so no phase shift is introduced
    return filtfilt(b, a, data, axis=0)


def verify_alpha_theta_2(eeg_data_frame, candidate_time_points, sfreq=256):
    """
    Verify candidates with alpha/theta envelopes and refine the switch boundaries.

    NOTE: the argument order (data first, candidates second) differs from
    ``alpha_theta_check``; it is kept for backward compatibility.

    For each candidate:
    1. Coarse search: adjacent 200 ms pre/post windows slide in 40 ms steps over
       +/-750 ms around the candidate, on the squared Hilbert envelopes (power) of the
       channel-averaged signal band-passed to alpha (8-13 Hz) and theta (4-8 Hz). The
       first step where alpha power changes by >= 30% and theta by >= 20% in opposite
       directions gives ``coarse_start``.
    2. Refinement: within +/-400 ms of ``coarse_start``, the sample with the steepest
       alpha-envelope change and the one with the steepest theta-envelope change are
       found; the earlier is the switch start and the later the switch end.

    Args:
        eeg_data_frame (pandas.DataFrame): EEG data; rows are time points, columns are
            channels.
        candidate_time_points (list): Candidate task-switch sample indices.
        sfreq (int): Sampling frequency of the EEG data in Hz. Defaults to 256.

    Returns:
        list[tuple[int, int]]: (start, end) sample indices, one per candidate for which
        a coarse change was found.
    """
    entire_time_len = len(eeg_data_frame)

    # Average all channels first, then take envelopes to reflect the overall trend
    average_signal = eeg_data_frame.mean(axis=1).values

    alpha_envelope = np.abs(hilbert(butter_bandpass_filter(average_signal, 8, 13, sfreq=sfreq)))
    theta_envelope = np.abs(hilbert(butter_bandpass_filter(average_signal, 4, 8, sfreq=sfreq)))

    # The envelope is an amplitude (V); square it so the ratios compare power (V^2),
    # matching the thresholds taken from the paper.
    alpha_envelope_sq = alpha_envelope ** 2
    theta_envelope_sq = theta_envelope ** 2

    verified_segments = []
    check_window = int(0.2 * sfreq)          # 200 ms pre/post windows
    half_range = int(0.75 * sfreq)           # +/-750 ms coarse search scope
    step_len = max(1, int(0.040 * sfreq))    # 40 ms step (at least one sample)
    refine_half_range = int(0.4 * sfreq)     # +/-400 ms refinement scope

    for time_point in candidate_time_points:
        search_start = max(0, time_point - half_range)
        search_end = min(entire_time_len, time_point + half_range)

        # Coarse search: stop so the post window never runs past search_end
        coarse_start = None
        for time in range(search_start, search_end - 2 * check_window + 1, step_len):
            boundary = time + check_window
            alpha_pre = np.mean(alpha_envelope_sq[time:boundary])
            alpha_post = np.mean(alpha_envelope_sq[boundary:boundary + check_window])
            theta_pre = np.mean(theta_envelope_sq[time:boundary])
            theta_post = np.mean(theta_envelope_sq[boundary:boundary + check_window])

            alpha_ratio = (alpha_post - alpha_pre) / (alpha_pre + 1e-10)
            theta_ratio = (theta_post - theta_pre) / (theta_pre + 1e-10)

            if abs(alpha_ratio) >= 0.3 and abs(theta_ratio) >= 0.2:
                if (alpha_ratio < 0 and theta_ratio > 0) or (alpha_ratio > 0 and theta_ratio < 0):
                    coarse_start = time
                    break

        # Refinement runs only after the coarse search has finished; placed inside the
        # loop after `break` it would be unreachable.
        if coarse_start is None:
            continue

        refine_start = max(0, coarse_start - refine_half_range)
        refine_end = min(entire_time_len, coarse_start + refine_half_range)
        # np.diff needs at least two samples to produce a gradient
        if refine_end - refine_start < 2:
            continue

        # Steepest change of each envelope marks when that band switched
        alpha_grad = np.abs(np.diff(alpha_envelope[refine_start:refine_end]))
        theta_grad = np.abs(np.diff(theta_envelope[refine_start:refine_end]))

        precise_alpha_idx = int(np.argmax(alpha_grad)) + refine_start
        precise_theta_idx = int(np.argmax(theta_grad)) + refine_start

        # Earlier change is the task-switch start, later one is the end
        verified_segments.append((min(precise_alpha_idx, precise_theta_idx),
                                  max(precise_alpha_idx, precise_theta_idx)))
    return verified_segments
            


## Step 3: Check by GFP

This function has three parts:  
Part A: The lowest GFP should occur in the range from 300 ms before the task-switch start to 100 ms after the task-switch end, compared with the surrounding time. The reasoning is that just before a task switch the brain is expected to shut down most of its ongoing activity. Because the start/end boundaries derived from alpha and theta are not 100% accurate, a wider window is also used: find the lowest GFP in the range from 1500 ms before the task-switch start to 300 ms after the task-switch end, and check that it also falls inside the range from 300 ms before the start to 100 ms after the end.  

Part B: Check that the GFP decreases toward the task-switch start (GFP decreases before the task switch starts and increases after it ends). Therefore, the mean GFP over the 350 ms immediately before the task-switch start should not be greater than the mean GFP over the earlier period from 1500 ms to 350 ms before the start.

Part C: According to the paper, there is a 20–50 Hz GFP decreasing trend before the task switch. Therefore, look for a run of 6 consecutive decreasing steps (about 23 ms at 256 Hz) in the range from 300 ms before the task-switch start to the task-switch end.


In [None]:
#Calculates the GFP (Global Field Power) of an EEG signal.
def calculate_GFP(eeg_data_frame, sfreq=256):
    """
    Calculate the smoothed Global Field Power (GFP) of EEG data.

    The GFP at each time point is the standard deviation across channels (pandas
    default ``ddof=1``). It is then smoothed with a ~30 ms moving average, because
    residual noise can cause very high-frequency spikes.

    Args:
        eeg_data_frame (pandas.DataFrame): Rows are time points, columns are channels.
        sfreq (int, optional): Sampling frequency in Hz. Defaults to 256.

    Returns:
        numpy.ndarray: Smoothed GFP, one value per row of ``eeg_data_frame``.
    """
    # SD across channels (one value per row/time point)
    gfp = eeg_data_frame.std(axis=1).values.astype(float)
    if gfp.size == 0:
        return gfp

    # ~30 ms window (7 samples at 256 Hz): at least 1 sample so there is no division
    # by zero, and at most the signal length so the output length matches the input.
    smooth_range = max(1, min(int(0.03 * sfreq), gfp.size))
    kernel = np.ones(smooth_range)

    # Divide by the number of samples actually under the kernel, so the edges get a
    # true (shorter) moving average instead of being pulled toward 0 by zero padding.
    # Zero padding could otherwise create a false GFP minimum at the recording edges.
    window_sum = np.convolve(gfp, kernel, mode='same')
    window_count = np.convolve(np.ones_like(gfp), kernel, mode='same')
    return window_sum / window_count

def gfp_check(candidate_task_switch, smooth_gfp, sfreq=256):
    """
    Keep only the candidate task-switch segments that the GFP time course supports.

    A segment (start, end) is kept when all three checks pass:
    A. The minimum GFP over [start - 1500 ms, end + 300 ms) lies inside
       [start - 300 ms, end + 100 ms].
    B. The mean GFP over [start - 350 ms, start) is not greater than the mean over
       [start - 1500 ms, start - 350 ms), i.e. GFP falls toward the switch.
    C. Between start - 300 ms and end, the GFP has a run of consecutive non-increasing
       steps at least ~25 ms long (6 samples at 256 Hz).
    Segments starting within the first 1500 ms are skipped (no full history).

    Args:
        candidate_task_switch (iterable of (int, int)): (start, end) sample indices.
        smooth_gfp (numpy.ndarray): Smoothed GFP, one value per sample.
        sfreq (int, optional): Sampling frequency in Hz. Defaults to 256.

    Returns:
        list[tuple]: The (start, end) pairs that passed all checks, in input order.
    """
    verified_segments = []
    n_samples = len(smooth_gfp)

    # Expected low-GFP window: 300 ms before start to 100 ms after end
    before_check_range = int(0.3 * sfreq)
    after_check_range = int(0.1 * sfreq)

    # Can change later after discussion: task switches take hundreds to thousands of
    # ms and the alpha/theta boundaries are approximate, so the global minimum is
    # searched from 1500 ms before start to 300 ms after end.
    check_lowest_range_before = int(1.5 * sfreq)
    check_lowest_range_after = int(0.3 * sfreq)

    # Length of the "near" averaging window for Part B (350 ms)
    pre_avg_smp = int(0.350 * sfreq)

    # Minimum run of non-increasing steps for Part C (~25 ms, 6 samples at 256 Hz)
    decrease_scope = max(1, int(0.025 * sfreq))

    for start, end in candidate_task_switch:
        # Assume large task switches do not happen in the first 1500 ms
        if start < check_lowest_range_before:
            continue

        # Part A: the global minimum must fall inside the expected low-GFP window
        lowest_gfp_check_start = start - check_lowest_range_before
        lowest_gfp_check_end = min(n_samples, end + check_lowest_range_after)
        # Segment lies outside the data: nothing to search (argmin of empty raises)
        if lowest_gfp_check_end <= lowest_gfp_check_start:
            continue
        gfp_check_data = smooth_gfp[lowest_gfp_check_start:lowest_gfp_check_end]
        lowest_gfp = int(np.argmin(gfp_check_data)) + lowest_gfp_check_start

        start_check_point = start - before_check_range
        end_check_point = min(n_samples, end + after_check_range)
        if not start_check_point <= lowest_gfp <= end_check_point:
            continue

        # Part B: GFP close to start must not exceed the GFP further before it
        average_further_gfp = np.mean(smooth_gfp[start - check_lowest_range_before:start - pre_avg_smp])
        average_near_gfp = np.mean(smooth_gfp[start - pre_avg_smp:start])
        if average_near_gfp > average_further_gfp:
            continue

        # Part C: look for a long enough run of non-increasing steps before the end
        decrease_count = 0
        find_decrease_trend = False
        for step in np.diff(smooth_gfp[start_check_point:end]):
            decrease_count = 0 if step > 0 else decrease_count + 1
            if decrease_count >= decrease_scope:
                find_decrease_trend = True
                break
        if find_decrease_trend:
            verified_segments.append((start, end))

    return verified_segments
        


