# Import Libraries

In [1]:
# Standard libraries
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Enable interactive plots
%matplotlib qt

# Import epoched data and settings

In [2]:
# Load list of files to import
files = [  
    "sub-P001_ses-S001_task-T1_run-001_eeg",
    "sub-P002_ses-S001_task-T1_run-001_eeg",
    "sub-P003_ses-S001_task-T1_run-001_eeg",
    "sub-P004_ses-S001_task-T1_run-001_eeg",
    "sub-P005_ses-S001_task-T1_run-001_eeg",
    "sub-P006_ses-S001_task-T1_run-001_eeg",
    "sub-P007_ses-S001_task-T1_run-001_eeg",
    "sub-P008_ses-S001_task-T1_run-001_eeg", 
    "sub-P009_ses-S001_task-T1_run-001_eeg",
    "sub-P010_ses-S001_task-T1_run-001_eeg",  
]

# Get unique subject IDs
subject_ids = [file.split('_')[0] for file in files]
unique_subject_ids = list(set(subject_ids))

# Preallocate variables to store EEG data and settings
loaded_data = [None] * len(files)
eeg_epochs = [{} for _ in range(len(files))]
settings = [None] * len(files)


# Import data
for f, file in enumerate(files):
    for sub in subject_ids:
        if sub == file.split('_')[0]:
            # Import EEG data, since it is stored in a compressed numpy file (.npz) we need to use the np.load function 
            loaded_data[f]= np.load(f"..\\Data\\Pilot2\\EEG\\{sub}\\ses-S001\\eeg\\{file}.npz", allow_pickle=True)

            # Access the data for each stimulus
            eeg_epochs[f] = {stim_label: loaded_data[f][stim_label] for stim_label in loaded_data[f]}

            # Import settings
            with open(f"..\\Data\\Pilot2\\EEG\\{sub}\\ses-S001\\eeg\\{file}.json", "r") as file_object:
                settings[f] = json.load(file_object)

# Calculate the amplitude

In [3]:
target_freqs = [5, 10, 20, 30]  # Define your SSVEP frequencies

# Preallocate
eeg_features = [None] * len(files)  # [subjects][stim_label][epoch][channel][freq]
eeg_features_norm = [None] * len(files)  # Normalized features
eeg_f = [None] * len(files)
eeg_amp = [None] * len(files)
eeg_amp_norm = [None] * len(files)

for f, file in enumerate(files):
    srate = settings[f]["eeg_srate"]
    eeg_features[f] = {}
    eeg_features_norm[f] = {}
    eeg_f[f] = {}
    eeg_amp[f] = {}
    eeg_amp_norm[f] = {}

    for stim_label, epochs in eeg_epochs[f].items():
        eeg_features[f][stim_label] = []
        eeg_features_norm[f][stim_label] = []
        eeg_f[f][stim_label] = []
        eeg_amp[f][stim_label] = []
        eeg_amp_norm[f][stim_label] = []

        for epoch in epochs:
            if epoch.ndim == 2:  # Multi-channel
                n_channels, n_time = epoch.shape
            else:  # Single-channel
                epoch = np.expand_dims(epoch, axis=0)
                n_channels, n_time = epoch.shape

            freqs = np.fft.fftfreq(n_time, d=1/srate)
            pos_freqs = freqs > 0
            freqs = freqs[pos_freqs]

            ch_features = []  # will be [channels x target_freqs]
            ch_features_norm = []
            ch_ffts = []      # full spectrum, per channel
            ch_ffts_norm = []  # full spectrum, per channel, normalized

            for ch_data in epoch:
                fft_vals = np.fft.fft(ch_data)
                fft_mag = np.abs(fft_vals)[pos_freqs]  # Absolute mag
                fft_mag_norm = fft_mag/np.sum(fft_mag)  # Normalize to sum = 1

                amp_at_targets = []
                amp_at_targets_norm = []
                for tf in target_freqs:
                    idx = np.argmin(np.abs(freqs - tf))
                    amp_at_targets.append(fft_mag[idx])
                    amp_at_targets_norm.append(fft_mag_norm[idx])

                ch_features.append(amp_at_targets)
                ch_features_norm.append(amp_at_targets_norm)
                ch_ffts.append(fft_mag)
                ch_ffts_norm.append(fft_mag_norm)

            eeg_features[f][stim_label].append(ch_features)        # [channels x target_freqs]
            eeg_features_norm[f][stim_label].append(ch_features_norm)
            eeg_amp[f][stim_label].append(np.array(ch_ffts))       # [channels x full_freqs]
            eeg_amp_norm[f][stim_label].append(np.array(ch_ffts_norm))
            eeg_f[f][stim_label].append(freqs)                     # [full_freqs]

        eeg_features[f][stim_label] = np.array(eeg_features[f][stim_label])  # [epochs x channels x target_freqs]
        eeg_features_norm[f][stim_label] = np.array(eeg_features_norm[f][stim_label])
        eeg_amp[f][stim_label] = np.array(eeg_amp[f][stim_label])            # [epochs x channels x full_freqs]
        eeg_amp_norm[f][stim_label] = np.array(eeg_amp_norm[f][stim_label])
        eeg_f[f][stim_label] = np.array(eeg_f[f][stim_label])                # [epochs x full_freqs]

# Get target freqs in their own arrays

In [4]:
abs_fft_5 = [{} for _ in range(len(files))]
abs_fft_10 = [{} for _ in range(len(files))]
abs_fft_20 = [{} for _ in range(len(files))]
abs_fft_30 = [{} for _ in range(len(files))]

norm_fft_5 = [{} for _ in range(len(files))]
norm_fft_10 = [{} for _ in range(len(files))]
norm_fft_20 = [{} for _ in range(len(files))]
norm_fft_30 = [{} for _ in range(len(files))]

for f, file in enumerate(files):
    for stim_label in eeg_features[f]:
        abs_fft_5[f][stim_label] = eeg_features[f][stim_label][:, :, 0]
        abs_fft_10[f][stim_label] = eeg_features[f][stim_label][:, :, 1]
        abs_fft_20[f][stim_label] = eeg_features[f][stim_label][:, :, 2]
        abs_fft_30[f][stim_label] = eeg_features[f][stim_label][:, :, 3]

    for stim_label in eeg_features_norm[f]:
        norm_fft_5[f][stim_label] = eeg_features_norm[f][stim_label][:, :, 0]
        norm_fft_10[f][stim_label] = eeg_features_norm[f][stim_label][:, :, 1]
        norm_fft_20[f][stim_label] = eeg_features_norm[f][stim_label][:, :, 2]
        norm_fft_30[f][stim_label] = eeg_features_norm[f][stim_label][:, :, 3]

# Plot Absolute

In [5]:
plot_full_fft = True
file_to_plot = 1
num_stimuli = 12
selected_channels = [14]  # Choose channels to show
f_limits = [5, 35]

if plot_full_fft:
    fig, axes = plt.subplots(3, 4, figsize=(15, 10))
    fig.suptitle(f'ABSOLUTE FFT Spectrum with Target Frequencies: File {files[file_to_plot]}')
    axes = axes.flatten()

    for stim_idx in range(num_stimuli):
        stim_label = list(eeg_amp[file_to_plot].keys())[stim_idx]
        amp_data = np.array(eeg_amp[file_to_plot][stim_label])   # [epochs x channels x freqs]
        freq_axis = eeg_f[file_to_plot][stim_label][0]           # [freqs]
        fmask = (freq_axis >= f_limits[0]) & (freq_axis <= f_limits[1])
        freqs_plot = freq_axis[fmask]

        for ch_idx in selected_channels:
            for epoch_idx in range(amp_data.shape[0]):
                amp = amp_data[epoch_idx, ch_idx, fmask]
                axes[stim_idx].plot(freqs_plot, amp, label=f'Epoch {epoch_idx+1}, Ch {ch_idx}', alpha=0.5)

        # Vertical lines at target freqs
        for tf in target_freqs:
            if f_limits[0] <= tf <= f_limits[1]:
                axes[stim_idx].axvline(tf, color='red', linestyle='--', alpha=0.3)

        axes[stim_idx].set_title(f'Stimulus: {stim_label}')
        axes[stim_idx].set_xlim(f_limits)
        axes[stim_idx].set_ylabel("Absolute FFT Amplitude")

        if stim_idx >= 8:
            axes[stim_idx].set_xlabel("Frequency [Hz]")

        axes[stim_idx].legend(fontsize='xx-small', ncol=2)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Plot normalized

In [6]:
plot_full_fft = True
file_to_plot = 1
num_stimuli = 12
selected_channels = [14]  # Choose channels to show
f_limits = [5, 35]

if plot_full_fft:
    fig, axes = plt.subplots(3, 4, figsize=(15, 10))
    fig.suptitle(f'NORMALIZED FFT Spectrum with Target Frequencies: File {files[file_to_plot]}')
    axes = axes.flatten()

    for stim_idx in range(num_stimuli):
        stim_label = list(eeg_amp_norm[file_to_plot].keys())[stim_idx]
        amp_data = np.array(eeg_amp_norm[file_to_plot][stim_label])   # [epochs x channels x freqs]
        freq_axis = eeg_f[file_to_plot][stim_label][0]           # [freqs]
        fmask = (freq_axis >= f_limits[0]) & (freq_axis <= f_limits[1])
        freqs_plot = freq_axis[fmask]

        for ch_idx in selected_channels:
            for epoch_idx in range(amp_data.shape[0]):
                amp = amp_data[epoch_idx, ch_idx, fmask]
                axes[stim_idx].plot(freqs_plot, amp, label=f'Epoch {epoch_idx+1}, Ch {ch_idx}', alpha=0.5)

        # Vertical lines at target freqs
        for tf in target_freqs:
            if f_limits[0] <= tf <= f_limits[1]:
                axes[stim_idx].axvline(tf, color='red', linestyle='--', alpha=0.3)

        axes[stim_idx].set_title(f'Stimulus: {stim_label}')
        axes[stim_idx].set_xlim(f_limits)
        axes[stim_idx].set_ylabel("Normalized FFT Amplitude")

        if stim_idx >= 8:
            axes[stim_idx].set_xlabel("Frequency [Hz]")

        axes[stim_idx].legend(fontsize='xx-small', ncol=2)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Export

In [7]:
export = True

if export:
    records = []

    for f, file in enumerate(files):
        participant_id = file.split('_')[0]

        all_fft_data = {
            'Abs_fft_5': abs_fft_5[f],
            'Abs_fft_10': abs_fft_10[f],
            'Abs_fft_20': abs_fft_20[f],
            'Abs_fft_30': abs_fft_30[f],
            'Norm_fft_5': norm_fft_5[f],
            'Norm_fft_10': norm_fft_10[f],
            'Norm_fft_20': norm_fft_20[f],
            'Norm_fft_30': norm_fft_30[f],
        }

        # Iterate over stimuli present in the band data
        for stim_label in list(all_fft_data.values())[0].keys():
            n_epochs = len(all_fft_data["Abs_fft_5"][stim_label])  # Safe number of available epochs
       
            for epoch_idx in range(n_epochs):
                print(f"Processing Participant: {participant_id}, Stimulus: {stim_label}, Epoch: {epoch_idx}")
                row = {
                    'Participant': participant_id,
                    'Stimulus': stim_label,
                    'Epoch': epoch_idx,
                }

                for fft_name, fft_dict in all_fft_data.items():
                    temp_fft = fft_dict[stim_label][epoch_idx]  # shape: (channels,)
                    for ch_idx, ch_name in enumerate(settings[f]['new_ch_names']):
                        col_name = f"{ch_name}_{fft_name.lower()}"
                        row[col_name] = temp_fft[ch_idx]

                records.append(row)

    # Combine all into one DataFrame and save
    df_all = pd.DataFrame(records)
    df_all.to_csv("all_participants_eeg_amp.csv", index=False)

Processing Participant: sub-P001, Stimulus: Contrast1Size1, Epoch: 0
Processing Participant: sub-P001, Stimulus: Contrast1Size1, Epoch: 1
Processing Participant: sub-P001, Stimulus: Contrast1Size1, Epoch: 2
Processing Participant: sub-P001, Stimulus: Contrast1Size2, Epoch: 0
Processing Participant: sub-P001, Stimulus: Contrast1Size2, Epoch: 1
Processing Participant: sub-P001, Stimulus: Contrast1Size2, Epoch: 2
Processing Participant: sub-P001, Stimulus: Contrast1Size3, Epoch: 0
Processing Participant: sub-P001, Stimulus: Contrast1Size3, Epoch: 1
Processing Participant: sub-P001, Stimulus: Contrast2Size1, Epoch: 0
Processing Participant: sub-P001, Stimulus: Contrast2Size1, Epoch: 1
Processing Participant: sub-P001, Stimulus: Contrast2Size1, Epoch: 2
Processing Participant: sub-P001, Stimulus: Contrast2Size2, Epoch: 0
Processing Participant: sub-P001, Stimulus: Contrast2Size2, Epoch: 1
Processing Participant: sub-P001, Stimulus: Contrast2Size2, Epoch: 2
Processing Participant: sub-P001, 