In [None]:
import mne  # For EEG/MEG data processing
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from mne.io import read_raw_edf, read_raw_fif, read_raw_bdf
from mne.preprocessing import ICA
from tqdm import tqdm
import pickle
import mat73
from scipy.io import loadmat

import os

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

def ion():
    """Switch MNE's data browser to the interactive Qt backend (verbose)."""
    backend = 'qt'
    mne.viz.set_browser_backend(backend, verbose = True)

def ioff():
    """Switch MNE's data browser to the static matplotlib backend (quiet)."""
    backend = 'matplotlib'
    mne.viz.set_browser_backend(backend, verbose = False)

# Load data and basic information

In [None]:
# Root folder used for loading and saving. NOTE: it has no trailing slash, so
# the paths below are plain concatenations ('<main_path>Unprocessed/...').
main_path = '/home/lunis/Documents/nlin-EEG/data/BMD'

# Subject identifier (zero-padded string)
subj = '001'

# BrainVision header of this subject's unprocessed recording
file_path = f'{main_path}Unprocessed/subj{subj}.vhdr'

# Read the whole recording into memory so it can be filtered/edited later
raw = mne.io.read_raw_brainvision(file_path, preload=True)

# General information about the loaded data
print(raw.info)

# # Plot the raw EEG data
# raw.plot()

In [None]:
# Crop the recording if it is too long (uncomment to use). The automated
# pipeline below applies the same crop (tmax = 2930) to subject '006'.
#raw.crop(tmax = 2930)

In [None]:
# Drop the Iz electrode. Guarded so that re-running this cell (after Iz has
# already been removed) does not raise.
if 'Iz' in raw.ch_names:
    raw.drop_channels(['Iz'])

In [None]:
# List every channel together with its MNE channel type
channel_types = raw.get_channel_types()
for name, kind in zip(raw.info['ch_names'], channel_types):
    print(f"Channel: {name}, Type: {kind}")

In [None]:
# Sampling rate (Hz)
sfreq = raw.info["sfreq"]
print(f'Sampling rate: {sfreq} Hz')

# Channel names (e.g. Fp1, Fp2, Cz, ...)
print(f'Channels: {raw.ch_names}')

# Recording length in seconds = number of samples / sampling rate
duration = raw.n_times / sfreq
print(f'Recording duration: {duration} seconds')

# Channel count (every channel in raw, whatever its type)
n_channels = len(raw.ch_names)
print(f'Number of channels: {n_channels}')

# Preprocessing

## Basic filtering and line-noise removal (no re-referencing is applied here)

In [None]:
# Band-pass filter the EEG between filtr[0] and filtr[1] Hz (currently
# 2-100 Hz; these limits are also written to proc_info.csv later on).
filtr = [2.0, 100.0]

raw_filtered = raw.copy().filter(l_freq=filtr[0], h_freq=filtr[1])

# Remove power-line interference: notch at 50 Hz and its harmonics up to 250 Hz
raw_filtered.notch_filter(freqs = np.arange(50, 251, 50))

In [None]:
# Render the PRE- and POST-filter browser figures off-screen and show the two
# snapshots side by side.
ioff()
snapshots = []
for inst in (raw, raw_filtered):
    browser_fig = inst.plot(n_channels = 61, start = 250, show = False)
    plt.close()
    canvas = FigureCanvas(browser_fig)
    canvas.draw()
    snapshots.append(np.asarray(canvas.buffer_rgba()))

fig, ax = plt.subplots(1, 2, figsize = (20, 20), dpi = 300)
for axis, image, title in zip(ax, snapshots, ('PRE-Filtering', 'POST-Filtering')):
    axis.imshow(image)
    axis.set_axis_off()
    axis.set_title(title)

plt.tight_layout()
plt.show()

## ICA preprocessing

In [None]:
# 30-component ICA on the filtered data; the fixed random_state keeps the
# decomposition (and hence the component indices picked below) reproducible.
ica = ICA(n_components=30, max_iter='auto', random_state=42)
ica.fit(raw_filtered)

In [None]:
### Power spectrum and topography of every ICA component ###
sources = ica.get_sources(raw_filtered)

n_components = ica.n_components_
welch = mne.time_frequency.psd_array_welch

for comp in range(n_components):
    # Time series of this single component (first row of a 1-channel pick)
    comp_signal = sources.get_data(picks=[comp])[0]

    # Welch PSD restricted to the band-pass range; n_fft sets the resolution
    psd, freqs = welch(comp_signal,
                       sfreq=raw.info['sfreq'],
                       fmin=filtr[0], fmax=filtr[1],
                       n_fft=2048,
                       verbose = False)

    # Left panel: component topography; right panel: its PSD in dB
    fig, (topo_ax, psd_ax) = plt.subplots(1, 2, figsize = (7, 3), dpi = 100)
    psd_ax.plot(freqs, 10 * np.log10(psd), label=f'Component {comp}')
    psd_ax.set_xlabel('Frequency (Hz)')
    psd_ax.set_ylabel('Power (dB)')
    psd_ax.set_title(f'Power Spectrum of ICA Component {comp}')
    psd_ax.grid(True, linestyle='--', alpha=0.6)
    psd_ax.set_ylim(-35, -6)

    ica.plot_components(picks = comp, show_names = True, axes = topo_ax);
    plt.show()

In [None]:
# Open the ICA sources in the interactive (Qt) browser to mark bad components;
# the next cell reads the selection back from ica.exclude.
ion()
ica.plot_sources(raw_filtered, start=250);

### Remove bad components

In [None]:
# Components that were excluded in the interactive browser, in ascending order
excluded = list(ica.exclude)
bad_ICA = sorted(excluded)
print(bad_ICA)

In [None]:
# Remove the excluded ICA components from a copy of the filtered data
# (ica.apply works in place on the copy and returns it).
cleaned_raw = ica.apply(raw_filtered.copy())
cleaned_raw

In [None]:
# Render the PRE- and POST-ICA browser figures off-screen and show the two
# snapshots side by side.
ioff()
snapshots = []
for inst in (raw_filtered, cleaned_raw):
    browser_fig = inst.plot(n_channels = 61, start = 250, show = False)
    plt.close()
    canvas = FigureCanvas(browser_fig)
    canvas.draw()
    snapshots.append(np.asarray(canvas.buffer_rgba()))

fig, ax = plt.subplots(1, 2, figsize = (20, 20), dpi = 300)
for axis, image, title in zip(ax, snapshots, ('PRE-ICA Removal', 'POST-ICA Removal')):
    axis.imshow(image)
    axis.set_axis_off()
    axis.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Inspect the ICA-cleaned data in the interactive (Qt) browser; the next cell
# reads the channels marked as bad from cleaned_raw.info['bads'].
ion()
cleaned_raw.plot(n_channels=61, start=250);

In [None]:
# Record the channels marked as bad (as plain strings), then interpolate them.
# Build a new list instead of aliasing cleaned_raw.info['bads'], so this record
# is independent of the list MNE keeps (interpolate_bads resets the bads by
# default) and the loop no longer mutates the Info in place.
bad_CH = [str(ch) for ch in cleaned_raw.info['bads']]

if bad_CH:
    cleaned_raw.interpolate_bads()
    ioff()
    cleaned_raw.plot(n_channels = 61, start = 250);

    print(f'Channels {bad_CH} were interpolated')

else:
    print('No bad channels were selected')

# Extract the events

In [None]:
# Event array built from the annotations: one row per event, column 0 is the
# sample index and column 2 the event code; event_dict maps annotation names
# to those codes.
events, event_dict = mne.events_from_annotations(raw)

print(f'Event dictionary: {event_dict}')

# Conditions kept for epoching: 'S__1'..'S__4' -> 1..4 and 'S_11'..'S_14' ->
# 11..14 (names are the codes right-aligned to width 3, padded with '_').
red_event_dict = {f'S{code:_>3}': code for code in (1, 2, 3, 4, 11, 12, 13, 14)}

print(f'Reduced event dictionary: {red_event_dict}')

In [None]:
# Backward Masking Triggers

#  1 Conscious Left, Self
#  2 Conscious Left, Other
#  3 Conscious Right, Self
#  4 Conscious Right, Other

# 11 Unconscious Left, Self
# 12 Unconscious Left, Other
# 13 Unconscious Right, Self
# 14 Unconscious Right, Other

# 20 Correct Localization Answer (not interested)
# 21 Incorrect Localization Answer (not interested)
# 30 Correct Discrimination Answer (not interested)
# 31 Incorrect Discrimination Answer (not interested)

# 40 Invisible
# 41 Almost invisible
# 42 Barely Visible
# 43 Visible

# Relabel each trial as conscious/unconscious from its visibility answer:
# a masked trial (11-14) rated 43 "Visible" becomes conscious (code - 10), and
# an unmasked trial (1-4) rated 40 "Invisible" becomes unconscious (code + 10).
# The stimulus trigger is assumed to sit exactly RATING_LAG events before its
# rating (presumably stimulus -> localization -> discrimination -> rating;
# TODO confirm this holds for every subject).
RATING_LAG = 3

# Relabelled copy; `events` itself stays untouched
m_events = np.copy(events)

vis_masked = 0
invis_unmasked = 0
for i, event in enumerate(events):

    # A rating among the first RATING_LAG events has no stimulus before it.
    # Without this guard, i - RATING_LAG would be negative and silently index
    # (and modify) events at the END of the array.
    if i < RATING_LAG:
        continue
    stim = i - RATING_LAG

    # Rated visible: masked trial -> conscious
    if event[2] == 43:
        if m_events[stim][2] in (11, 12, 13, 14):
            m_events[stim][2] = m_events[stim][2] - 10
            vis_masked += 1

    # Rated invisible: unmasked trial -> unconscious
    elif event[2] == 40:
        if m_events[stim][2] in (1, 2, 3, 4):
            m_events[stim][2] = m_events[stim][2] + 10
            invis_unmasked += 1

print(f'Swapped {vis_masked} masked trials which where rated conscious'
      f'\nSwapped {invis_unmasked} unmasked trials which where rated unconscious')

# Number of relabelled triggers: [masked -> conscious, unmasked -> unconscious]
swap = [vis_masked, invis_unmasked]

In [None]:
# Free-text note about this subject's preprocessing (stored quoted)
comment = "'none'"

# Subject ID wrapped in quotes (presumably so leading zeros such as '001'
# survive reading the CSV back)
fsubj = "'" + subj + "'"

# New row: ID, band-pass limits, excluded ICA components, interpolated
# channels, swapped-trigger counts, comment
sub_d = [fsubj, filtr, bad_ICA, bad_CH, swap, comment]

# Append the row to the processing log and write it back
csv_path = main_path + 'proc_info.csv'
proc_info = pd.read_csv(csv_path, sep = ';')
proc_info.loc[-1] = sub_d
print(proc_info)

proc_info.to_csv(csv_path, sep = ';', index = False)

In [None]:
# Epoch the cleaned data separately for each condition of red_event_dict,
# using the visibility-relabelled events (m_events). `epochs` collects one
# Epochs object per condition, in red_event_dict order.
epochs = []
for event_key, event_id in red_event_dict.items():

    print(f'\nSaving condition {event_key}')
    epoch = mne.Epochs(cleaned_raw, m_events, event_id = event_id,
                       tmin = -0.2, tmax = 0.7, baseline = (None, 0),
                       preload = True, verbose = True)

    epochs.append(epoch)

    # Save this condition's epochs to disk (disabled). Save `epoch` (this
    # condition), not the `epochs` list, which has no .save method.
    #os.makedirs(main_path + subj + '/', exist_ok = True)
    #epoch.save(main_path + subj + '/' + subj + event_key + '-epo.fif')

In [None]:
# Browse the epochs of the first condition (first key of red_event_dict)
ion()
first_condition = epochs[0]
first_condition.plot()

In [None]:
# Number of time points per epoch. `epochs` is a plain list with one Epochs
# object per condition (all built with the same tmin/tmax), so average the
# first condition; calling .average() on the list raised AttributeError.
evoked = epochs[0].average()

ts = evoked.get_data()  # (n_channels, n_times)

print(f'Extracted events time series have {ts.shape[1]} time points')

#   Automated preprocessing

In [None]:
import mne  # For EEG/MEG data processing
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from mne.io import read_raw_edf, read_raw_fif, read_raw_bdf
from mne.preprocessing import ICA
from tqdm import tqdm
import pickle
import mat73
from scipy.io import loadmat

import os
import warnings

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Project-wide configuration dictionary ("Our Very Big Dictionary"). Used below
# for the output root (maind['path']) and for per-experiment entries such as
# maind[exp_name]['directories']['rw_data'] and maind[exp_name]['subIDs'].
from init import get_maind

maind = get_maind()

def ion():
    """Select the interactive Qt browser for MNE plots.

    Same helper as at the top of the notebook, re-declared here (presumably
    so this automated section can run on its own after its own imports).
    """
    mne.viz.set_browser_backend('qt', verbose = True)

def ioff():
    """Select the static matplotlib browser for MNE plots (quiet).

    Same helper as at the top of the notebook, re-declared for this section.
    """
    mne.viz.set_browser_backend('matplotlib', verbose = False)

# Iterable processing function
def it_process(i: int):
    """Preprocess and epoch one subject of the automated pipeline.

    Settings are read from the module-level ``instructions`` dict (it must be
    set before the worker pool starts), so this function can be mapped over
    subject indices with ``Pool.imap``.

    Steps: load the BrainVision recording, drop ``instructions['ch_drops']``,
    crop subject '006' at 2930 s, band-pass filter, optionally remove the
    listed ICA components, optionally show a PRE/POST figure, optionally
    relabel trials from the visibility rating, then save one ``-epo.fif`` file
    per condition of ``instructions['events']``.

    Parameters
    ----------
    i : int
        Index into instructions['subIDs'], ['file_paths'] and ['ICA'].

    Returns
    -------
    list
        Row for proc_info.csv: [quoted SubID, filter limits, excluded ICA
        components, None (no channel interpolation here), swap counts or
        None, comment].
    """
    sub_id = instructions['subIDs'][i]
    file_path = instructions['file_paths'][i]

    # Suppress warnings emitted while reading the recording
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

        # `raw` only feeds the PRE panel of the comparison figure;
        # `post` is loaded into memory and processed.
        raw = mne.io.read_raw_brainvision(file_path, preload = False, verbose = False)
        post = mne.io.read_raw_brainvision(file_path, preload = True, verbose = False)

    # Drop the requested channels (e.g. 'Iz') as the manual pipeline does;
    # 'ch_drops' used to be ignored here. Only channels actually present are
    # dropped, so a missing name cannot abort the whole worker pool.
    drops = [ch for ch in instructions.get('ch_drops', []) if ch in post.ch_names]
    if drops:
        raw.drop_channels(drops)
        post.drop_channels(drops)

    if sub_id == '006':
        # This subject's recording is cropped at 2930 s
        raw.crop(tmax = 2930)
        post.crop(tmax = 2930)

    # Band-pass filter
    post.filter(l_freq = instructions['f_filter'][0], h_freq = instructions['f_filter'][1], verbose = False)

    # Remove the listed ICA components (30 components, fixed seed), if any
    bad_components = instructions['ICA'][i]
    if len(bad_components) != 0:

        ica = ICA(n_components=30, random_state=42, max_iter='auto', verbose = False)
        ica.fit(post, verbose = False)

        ica.exclude = bad_components

        ica.apply(post, verbose = False)

        del ica

    if instructions['show']:
        # NOTE(review): this runs inside a multiprocessing worker; figures
        # shown from a forked child may not reach the notebook — confirm.
        ioff()
        snapshots = []
        for inst in (raw, post):
            browser_fig = inst.plot(n_channels = 61, start = 250, show = False)
            plt.close()
            canvas = FigureCanvas(browser_fig)
            canvas.draw()
            snapshots.append(np.asarray(canvas.buffer_rgba()))

        fig, ax = plt.subplots(1, 2, figsize = (20, 15), dpi = 200)
        for axis, image, title in zip(ax, snapshots, ('PRE-Processing', 'POST-Processing')):
            axis.imshow(image)
            axis.set_axis_off()
            axis.set_title(title)

        fig.suptitle(sub_id)

        plt.tight_layout()
        plt.show()

        del snapshots

    del raw

    # Event array: column 0 = sample index, column 2 = event code
    events, _ = mne.events_from_annotations(post, verbose = False)

    # Relabelled copy of the events, used for epoching
    m_events = np.copy(events)

    swap = None

    if instructions['swap']:
        # Relabel from the visibility rating (43 Visible: masked 11-14 ->
        # conscious 1-4; 40 Invisible: unmasked 1-4 -> unconscious 11-14).
        # The stimulus is assumed to be `lag` events before its rating (TODO
        # confirm). Ratings among the first `lag` events are skipped: a
        # negative index would wrap around to the end of the array.
        lag = 3
        vis_masked = 0
        invis_unmasked = 0
        for j in range(lag, len(events)):
            code = events[j][2]
            stim = j - lag

            if code == 43:
                if m_events[stim][2] in (11, 12, 13, 14):
                    m_events[stim][2] = m_events[stim][2] - 10
                    vis_masked += 1

            elif code == 40:
                if m_events[stim][2] in (1, 2, 3, 4):
                    m_events[stim][2] = m_events[stim][2] + 10
                    invis_unmasked += 1

        # Number of swapped triggers
        swap = [vis_masked, invis_unmasked]

    # Output folder: <maind['path']>data/<newcode>/<subID>/
    new_path = maind['path'] + 'data/' + instructions['newcode'] + '/'
    sub_dir = new_path + sub_id + '/'
    os.makedirs(sub_dir, exist_ok = True)

    # One epochs file per condition
    for event_key, event_id in instructions['events'].items():

        epochs = mne.Epochs(post, m_events, event_id = event_id, tmin = instructions['tlims'][0], tmax = instructions['tlims'][1], baseline = (None, 0), preload = False, verbose = False)
        epochs.save(sub_dir + sub_id + event_key + '-epo.fif', verbose = False, overwrite = True)

        del epochs
    del post

    # Row for proc_info.csv (no channel interpolation in this pipeline)
    comment = '\'none\''
    fsubj = "'" + sub_id + "'"

    return [fsubj, instructions['f_filter'], instructions['ICA'][i], None, swap, comment]

def _init_worker(instr: dict):
    """Pool initializer: publish `instr` as the module-level `instructions`
    global that it_process reads inside each worker process."""
    global instructions
    instructions = instr


def auto_process(maind: dict, instructions: dict):
    """Preprocess and epoch every subject of an experiment in parallel.

    Adds 'subIDs' and 'file_paths' to `instructions` (in place), runs
    it_process once per subject on a pool of 5 worker processes and writes one
    row per subject to <maind['path']>data/<newcode>/proc_info.csv.

    Parameters
    ----------
    maind : dict
        Project configuration; reads maind[exp_name]['directories']['rw_data'],
        maind[exp_name]['subIDs'] and maind['path'].
    instructions : dict
        Run settings (exp_name, newcode, f_filter, ICA, events, tlims, show,
        swap, ch_drops). Handed to every worker through the Pool initializer,
        so it no longer has to be the same object as the notebook-level
        `instructions` global that it_process reads.

    Raises
    ------
    ValueError
        If instructions['ICA'] has fewer entries than there are subjects.
    """
    exp = maind[instructions['exp_name']]

    # Unprocessed data path; its last character (presumably a trailing '/')
    # is dropped, so the raw files are read from '<rw_data[:-1]>Unprocessed/'
    source_path = exp['directories']['rw_data']

    # Subject IDs
    subIDs = exp['subIDs']

    # Per-subject BrainVision header locations
    file_paths = [source_path[:-1] + 'Unprocessed/subj' + sub + '.vhdr' for sub in subIDs]

    # Fail early here instead of with an IndexError inside a worker
    if len(instructions['ICA']) < len(subIDs):
        raise ValueError(f"instructions['ICA'] has {len(instructions['ICA'])} entries "
                         f"but there are {len(subIDs)} subjects")

    # Add lists to instructions for iteration
    instructions['subIDs'] = subIDs
    instructions['file_paths'] = file_paths

    iters = list(range(len(subIDs)))

    from multiprocessing import Pool

    workers = 5
    chunksize = 1

    # maxtasksperchild=1: every subject runs in a fresh worker process; the
    # initializer gives each worker this call's instructions.
    with Pool(workers, initializer = _init_worker, initargs = (instructions,), maxtasksperchild = 1) as p:

        results = list(tqdm(p.imap(it_process, iters, chunksize = chunksize),
                            desc = 'Processing subjects',
                            unit = 'sub',
                            total = len(iters),
                            leave = True,
                            dynamic_ncols = True)
                        )

    # Output directory (created here too, in case no worker created it)
    new_path = maind['path'] + 'data/' + instructions['newcode'] + '/'
    os.makedirs(new_path, exist_ok = True)

    # Save processing information, one row per subject
    proc_info = pd.DataFrame(results, columns = ['SubID','Filter','bad_ICA','bad_CH','TRIG_swap','Comment'])

    proc_info.to_csv(new_path + 'proc_info.csv', sep = ';', index = False)

##   Process without ICA

In [None]:
# Run settings for the "no ICA" variant: 2-50 Hz band-pass, ch_drops = ['Iz'],
# and one empty ICA exclusion list per subject (36 subjects).
bmd_conditions = {'S__1': 1, 'S__2': 2, 'S__3': 3, 'S__4': 4,
                  'S_11': 11, 'S_12': 12, 'S_13': 13, 'S_14': 14}

instructions = dict(
    exp_name = 'bmasking_dense',
    newname = 'bmasking_dense_noICA',
    newcode = 'BMDnoICA',
    ch_drops = ['Iz'],
    f_filter = [2.0, 50.0],
    ICA = [[] for _ in range(36)],
    events = bmd_conditions,
    tlims = (-0.2, 0.7),
    show = False,
    swap = True,
)

auto_process(maind = maind, instructions = instructions)

## Process with excessive ICA removal (components 0-14 removed for every subject)

In [None]:
# Run settings for the "high ICA" variant: 0.5-40 Hz band-pass,
# ch_drops = ['Iz'], components 0-14 excluded for each of the 36 subjects,
# PRE/POST figures shown.
bmd_conditions = {'S__1': 1, 'S__2': 2, 'S__3': 3, 'S__4': 4,
                  'S_11': 11, 'S_12': 12, 'S_13': 13, 'S_14': 14}

instructions = dict(
    exp_name = 'bmasking_dense',
    newname = 'bmasking_dense_highICA',
    newcode = 'BMDhighICA',
    ch_drops = ['Iz'],
    f_filter = [0.5, 40.0],
    ICA = [list(range(15)) for _ in range(36)],
    events = bmd_conditions,
    tlims = (-0.2, 0.7),
    show = True,
    swap = True,
)

auto_process(maind = maind, instructions = instructions)

#   Alessio's Stuff

# Utils

In [None]:
def match_with_tolerance(list1, list2, tol):
    """
    Greedily pair each element of list1 with a distinct element of list2.

    Elements of list1 are handled in order. For each one, among the elements
    y of list2 that are not used yet and satisfy |x - y| <= tol, the closest
    one is chosen (ties go to the smallest index in list2) and marked as
    used, so each element of list2 is matched at most once.

    Returns a list of length len(list1) whose i-th entry is the index in
    list2 matched to list1[i], or None when no admissible partner exists.
    """
    cutoff = tol + 1e-12
    taken = set()
    result = []
    for x in list1:
        candidates = [(abs(x - y), j) for j, y in enumerate(list2) if j not in taken]
        admissible = [(d, j) for d, j in candidates if d <= tol and d < cutoff]
        if not admissible:
            result.append(None)
            continue
        _, chosen = min(admissible)
        taken.add(chosen)
        result.append(chosen)
    return result

In [None]:
# Trigger codes used by the trial-extraction cell below (presumably from the
# CFS experiment whose results are saved under D:/PhD/CFS_eeg, not from the
# backward-masking recording loaded above — confirm before use).
# NOTE(review): this cell is commented out, and `time_adjuster` is not defined
# anywhere in this notebook, so the active extraction code below raises
# NameError unless these names are defined first.
# NOTE(review): the start_e/start_n comments look swapped — the extraction
# code pairs start_e with the face/'real' class and start_n with the control
# class; confirm.
# start_e = 1001 # code for the start of a netural trial
# start_n = 1002 # code for the start of an emotion trial
# resp_e = 1005
# resp_n = 1005
# end_e = 1007
# end_n = 1006
# confs = [41,42,43,44,51,52,53,54]
# cross_switch = 5

In [None]:
# Legacy variant of the trial extraction below, kept commented out. It ignores
# the cross-switch trigger: control trials get the placeholder crosstime
# 'cane', and conf_val / control_time are reset for every trial.
# # neutral trial
# # The control time (the time of the switch of the cross) has to be matched with the response time of the face trials
# data = cleaned_raw
# # data = raw_filtered


# cond_codes = [(start_e, resp_e, end_e),(start_n, resp_n, end_n)]

# trials_dict = {'trial': [], 'trial_norm': [], 'baseline':[],'baseline_norm':[], 'resp_point': [], 'label': [], 'confidence':[], 'crosstime':[],
#                'matching':[]}
# for c, cond in enumerate(['real','contr']):# Loop over the two classes
#     print(cond)
#     start = cond_codes[c][0]
#     resp = cond_codes[c][1]
#     end = cond_codes[c][2]
#     where_start = np.where(events[:,2]==start)
#     for w in tqdm(where_start[0]): # Loop over the start of each trial
#         start_point = events[w][0]
#         i=1
#         itsended = False # the current trial has an end
#         control=False
#         resp_point=0
#         conf_val = None
#         control_time = None
#         while(w+i!=len(events)): # Loop over the successive sample after the start, looking for the end
#             curr_samp = events[w+i]
#             ##### UNCOMMENT WITH THE NORMAL TRIGGERS ####
#             # if curr_samp[2]==cross_switch and c==1:
#             #     # control = True
#             #     control_time = curr_samp[0]
#             if curr_samp[2]==start_n or curr_samp[2]==start_e: # The end of the trial is missing, discard the trial
#                 break
#             if curr_samp[2]==resp:
#                 if resp_point==0:
#                     itsended=True
#                     resp_point=curr_samp[0]
#                     end_point=curr_samp[0]
#             if curr_samp[2]==end:
#                 end_point=curr_samp[0]
#                 itsended=True
#             if curr_samp[2] in confs:
#                 conf_val = curr_samp[-1]
#                 break
#             i+=1

#         if itsended:
#             resp_point = resp_point-start_point
#             if c==0: # Face
#                 if resp_point>0:
#                     # I take the activity from -150 to -50 as a baseline (no stimulation)
#                     extracted_data = data.get_data(start=start_point-100, stop=end_point)

#                     base = extracted_data[:, :100]
#                     # print(base.shape)
#                     mean_base = base.mean(axis=1)
#                     extracted_data_scale = extracted_data.T-mean_base
#                     # print(extracted_data_scale.shape)
#                     extracted_data_norm = extracted_data_scale/extracted_data_scale.std() # z-score the trial
#                     extracted_data_scale = extracted_data_scale.T
#                     extracted_data_norm = extracted_data_norm.T

#                     trials_dict['trial'].append(extracted_data_scale[:, 100:])
#                     trials_dict['trial_norm'].append(extracted_data_norm[:, 100:])
#                     trials_dict['baseline'].append(extracted_data_scale[:, :100])
#                     trials_dict['baseline_norm'].append(extracted_data_norm[:, :100])
#                     trials_dict['resp_point'].append(resp_point)
#                     trials_dict['label'].append(cond)
#                     trials_dict['confidence'].append(conf_val)
#                     trials_dict['crosstime'].append(None)
#             elif c==1: # Control
#                 if resp_point>0:
#                     # I take the activity from -150 to -50 as a baseline (no stimulation)
#                     extracted_data = data.get_data(start=start_point-100, stop=end_point)

#                     base = extracted_data[:, :100]
#                     # print(base.shape)
#                     mean_base = base.mean(axis=1)
#                     extracted_data_scale = extracted_data.T-mean_base
#                     # print(extracted_data_scale.shape)
#                     extracted_data_norm = extracted_data_scale/extracted_data_scale.std() # z-score the trial
#                     extracted_data_scale = extracted_data_scale.T
#                     extracted_data_norm = extracted_data_norm.T
#                     ##### UNCOMMENT WITH THE NORMAL TRIGGERS ####
#                     # control_time = control_time-start_point +time_adjuster
#                     # if control_time[0]>0:
#                     ##### DEDENT WITH THE NORMAL TRIGGERS #####
#                     trials_dict['trial'].append(extracted_data_scale[:, 100:])
#                     trials_dict['trial_norm'].append(extracted_data_norm[:, 100:])
#                     trials_dict['baseline'].append(extracted_data_scale[:, :100])
#                     trials_dict['baseline_norm'].append(extracted_data_norm[:, :100])
#                     trials_dict['resp_point'].append(resp_point)
#                     trials_dict['label'].append(cond)
#                     trials_dict['confidence'].append(conf_val)
#                     # trials_dict['crosstime'].append(control_time[0])
#                     trials_dict['crosstime'].append('cane')





# print(len(trials_dict['trial']),len(trials_dict['trial_norm']),len(trials_dict['baseline']),len(trials_dict['baseline_norm']),
#       len(trials_dict['resp_point']), len(trials_dict['label']),len(trials_dict['confidence']),)
# print(trials_dict['trial'][0].shape,trials_dict['trial_norm'][0].shape,trials_dict['baseline'][0].shape,trials_dict['baseline_norm'][0].shape)
            


#------------- Active version (to be used with the correct triggers) -------------#

# Extract face ("real") and control ("contr") trials from the continuous data.
# The control time (the time of the fixation-cross switch) has to be matched
# with the response time of the face trials (see the matching cell below).
# NOTE(review): requires start_e/resp_e/end_e, start_n/resp_n/end_n, confs,
# cross_switch and time_adjuster to be defined first (none are active in this
# notebook); uses `events` and `cleaned_raw` from the cells above.
data = cleaned_raw
# data = raw_filtered

# Number of samples before each trial start that form the baseline
N_BASE = 100


def _baseline_scaled(data, start_point, end_point, n_base=N_BASE):
    """Cut samples [start_point - n_base, end_point) out of `data`.

    Returns ``(scaled, normed)``, both (n_channels, n_samples): ``scaled`` has
    the per-channel mean of the first ``n_base`` samples subtracted; ``normed``
    is ``scaled`` divided by its single overall standard deviation.
    """
    # NOTE(review): event samples come from mne.events_from_annotations; if
    # data.first_samp != 0 they are offset from get_data's 0-based indices.
    extracted = data.get_data(start=start_point - n_base, stop=end_point)
    mean_base = extracted[:, :n_base].mean(axis=1)
    scaled = extracted.T - mean_base
    normed = scaled / scaled.std()
    return scaled.T, normed.T


cond_codes = [(start_e, resp_e, end_e), (start_n, resp_n, end_n)]

trials_dict = {'trial': [], 'trial_norm': [], 'baseline': [], 'baseline_norm': [],
               'resp_point': [], 'label': [], 'confidence': [], 'crosstime': [],
               'matching': []}
for c, cond in enumerate(['real', 'contr']):  # Loop over the two classes
    print(cond)
    start, resp, end = cond_codes[c]
    where_start = np.where(events[:, 2] == start)
    for w in tqdm(where_start[0]):  # Loop over the start of each trial
        start_point = events[w][0]
        i = 1
        itsended = False  # the current trial has an end
        resp_point = 0
        # Reset for every trial: these used to leak from the previous trial
        # (or be undefined on the first one).
        conf_val = None
        control_time = None
        while w + i != len(events):  # Scan the following events for the end
            curr_samp = events[w + i]
            if curr_samp[2] == cross_switch and c == 1:
                control_time = curr_samp[0]
            if curr_samp[2] == start_n or curr_samp[2] == start_e:
                # The next trial starts before this one ended: discard it
                break
            if curr_samp[2] == resp:
                if resp_point == 0:
                    itsended = True
                    resp_point = curr_samp[0]
                    end_point = curr_samp[0]
            if curr_samp[2] == end:
                end_point = curr_samp[0]
                itsended = True
            if curr_samp[2] in confs:
                conf_val = curr_samp[-1]
                break
            i += 1

        if not itsended:
            continue

        # Response latency relative to the trial start, in samples
        resp_point = resp_point - start_point
        if not resp_point > 0:
            continue

        if c == 0:  # Face
            crosstime = None
        else:  # Control
            if control_time is None:
                # No cross switch in this trial: it could not be matched later
                continue
            # ravel()[0] accepts both a scalar and a 1-element time_adjuster
            # (the old code indexed control_time[0], which fails on a scalar)
            control_time = np.ravel(control_time - start_point + time_adjuster)[0]
            if not control_time > 0:
                continue
            crosstime = control_time

        # Baseline = the N_BASE samples before the trial start (no stimulation)
        scaled, normed = _baseline_scaled(data, start_point, end_point)
        trials_dict['trial'].append(scaled[:, N_BASE:])
        trials_dict['trial_norm'].append(normed[:, N_BASE:])
        trials_dict['baseline'].append(scaled[:, :N_BASE])
        trials_dict['baseline_norm'].append(normed[:, :N_BASE])
        trials_dict['resp_point'].append(resp_point)
        trials_dict['label'].append(cond)
        trials_dict['confidence'].append(conf_val)
        trials_dict['crosstime'].append(crosstime)

print(len(trials_dict['trial']), len(trials_dict['trial_norm']), len(trials_dict['baseline']), len(trials_dict['baseline_norm']),
      len(trials_dict['resp_point']), len(trials_dict['label']), len(trials_dict['confidence']),)
if trials_dict['trial']:
    print(trials_dict['trial'][0].shape, trials_dict['trial_norm'][0].shape, trials_dict['baseline'][0].shape, trials_dict['baseline_norm'][0].shape)

In [None]:
# Legacy version of the matching cell below, kept commented out.
# # Example usage
# N_tr = np.sum(np.asarray(trials_dict['crosstime'])==None)

# list2 = np.array(np.asarray(trials_dict['resp_point'])[np.asarray(trials_dict['crosstime'])!=None]) # control
# list1 = np.asarray(trials_dict['resp_point'])[np.asarray(trials_dict['crosstime'])==None] # Face
# trials_dict['matching'] = np.full((len(list1)), -999)
# tol   = 600

# matches = match_with_tolerance(list1, list2, tol)
# for i, j in enumerate(matches):
#     if j is None:
#         print(f"list1[{i}] = {list1[i]:.2f} → no match")
#     else:
#         print(f"list1[{i}] = {list1[i]:.2f} ↔ list2[{j}] = {list2[j]:.2f}")

# for l1,l2 in enumerate(matches):
#     # print(l1,l2)
#     if l1!=None and l2!=None:
#         trials_dict['matching'][l1] = l2+N_tr
#         # trials_dict['matching'][l2+N_tr] = l1


# np.sum(trials_dict['matching']>=0)

#------------- Active version (to be used with the correct triggers) -------------#

# Match every face trial to a distinct control trial whose response time is
# within `tol` samples (greedy, closest first; see match_with_tolerance).
crosstimes = np.asarray(trials_dict['crosstime'])
resp_points = np.asarray(trials_dict['resp_point'])
# Face trials were stored with crosstime None; `== None` is intentional here
# because it compares elementwise on the object array.
is_face = crosstimes == None  # noqa: E711
N_tr = np.sum(is_face)  # Number of face trials

list1 = resp_points[is_face]   # Face response times
list2 = resp_points[~is_face]  # Control response times
trials_dict['matching'] = np.full((len(list1)), -999)
tol = 600

matches = match_with_tolerance(list1, list2, tol)
for i, j in enumerate(matches):
    if j is None:
        print(f"list1[{i}] = {list1[i]:.2f} → no match")
    else:
        print(f"list1[{i}] = {list1[i]:.2f} ↔ list2[{j}] = {list2[j]:.2f}")

# For each face trial store the index of its matched control in the full trial
# lists (controls were appended after the N_tr face trials); -999 = no match.
for face_idx, ctrl_idx in enumerate(matches):
    if ctrl_idx is not None:
        trials_dict['matching'][face_idx] = ctrl_idx + N_tr
        # trials_dict['matching'][ctrl_idx+N_tr] = face_idx

np.sum(trials_dict['matching'] >= 0)

In [None]:
matching = trials_dict['matching']
print(matching, len(matching))

In [None]:
# Print the controls and face trials that were left unmatched. Controls are
# the entries after the first N_tr (face) trials; iterate over all of them
# rather than over range(N_tr), which is the number of FACE trials.
control_crosstimes = trials_dict['crosstime'][N_tr:]
face_resp_points = trials_dict['resp_point'][:N_tr]
matched_controls = {m for m in matches if m is not None}

print('Controls')
for i, crosstime_value in enumerate(control_crosstimes):
    if i not in matched_controls:
        print(crosstime_value)
print('Face')
for i, m in enumerate(matches):
    if m is None:
        print(face_resp_points[i])

In [None]:
# Save trials_dict for the current subject as a pickle.
# NOTE(review): `breaking` and `reverse` are not defined anywhere in this
# notebook — set them before running this cell. The output paths are
# hard-coded Windows paths.
out_file = None
if breaking:
    ### Breaking ###
    out_file = 'D:/PhD/CFS_eeg/data/new_exp/prep_task/breaking/subj_' + subj + '_trials_dict.pkl'
elif reverse:
    ### Reverse ###
    out_file = 'D:/PhD/CFS_eeg/data/new_exp/prep_task/reverse/subj_' + subj + '_rev_trials_dict.pkl'

if out_file is not None:
    with open(out_file, 'wb') as handle:
        pickle.dump(trials_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Distribution of face response times vs control cross-switch times (samples
# from trial start). Face trials are the first N_tr entries of trials_dict;
# N_tr replaces a hard-coded 142 that only fit one subject.
crosstimes = np.array(trials_dict['crosstime'])
fig, ax = plt.subplots()
ax.hist(trials_dict['resp_point'][:N_tr], alpha=0.4, bins=20, label='Face response time')
ax.hist(crosstimes[crosstimes != None], alpha=0.5, bins=20, label='Control cross switch')  # noqa: E711 (elementwise)
ax.set_xlabel('Time from trial start (samples)')
ax.set_ylabel('Count')
ax.legend()
plt.show()