# Import Libraries

In [1]:
# Standard libraries
import json
from pathlib import Path

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.signal as signal
from scipy.fft import fft, fftfreq

# Custom libraries
from Functions import processing

# Enable interactive plots
%matplotlib qt

# Import epoched data and settings

In [2]:
# Root folder of the exported EEG recordings, relative to the notebook.
# Recordings are laid out as <subject>/<session>/eeg/<file>.{npz,json}.
DATA_DIR = Path("Data") / "Pilot2" / "EEG"

# List of recordings (file stems without extension) to import
files = [
    "sub-P003_ses-S001_task-T1_run-001_eeg"
]

# Unique subject IDs (the first "_"-separated field of each file name),
# sorted so the order is deterministic across runs (set order is not).
subject_ids = [file.split('_')[0] for file in files]
unique_subject_ids = sorted(set(subject_ids))

# Preallocate containers for the EEG epochs and settings of each file
eeg_epochs = [None] * len(files)
settings = [None] * len(files)

for f, file in enumerate(files):
    # Derive the subject/session folders from the file name instead of
    # hard-coding them, so recordings of other subjects/sessions are read
    # from their own folders. pathlib keeps the path OS-independent.
    subject, session = file.split('_')[:2]
    file_dir = DATA_DIR / subject / session / "eeg"

    # The epochs are stored in a compressed numpy archive (.npz), one entry
    # per stimulus label. allow_pickle=True is kept (presumably the archive
    # holds object arrays); only load archives from a trusted source, since
    # unpickling can execute arbitrary code.
    # The `with` block closes the archive's file handle; indexing the archive
    # reads each array into memory, so the data remains valid afterwards.
    with np.load(file_dir / f"{file}.npz", allow_pickle=True) as loaded_data:
        eeg_epochs[f] = {stim_label: loaded_data[stim_label] for stim_label in loaded_data.files}

    # Import the settings stored next to the epochs
    with open(file_dir / f"{file}.json", "r") as file_object:
        settings[f] = json.load(file_object)

# Calculate the normalized FFT amplitude spectrum of each epoch

In [None]:
# Epoch length in seconds. NOTE(review): not referenced in this cell; kept in
# case later cells rely on it.
window_size = 5


def _normalized_amplitude_spectrum(data, srate):
    """Return the normalized one-sided FFT amplitude spectrum along the last axis.

    Parameters
    ----------
    data : array_like
        Time series with time on the last axis, e.g. (n_samples,) for one
        channel or (n_channels, n_samples) for channels x time.
    srate : float
        Sampling rate in Hz.

    Returns
    -------
    freqs : np.ndarray
        Strictly positive frequency bins in Hz (DC and negative bins dropped).
    amplitude : np.ndarray
        |FFT| at those bins, with the same leading shape as ``data``; each
        spectrum is scaled so it sums to 1. A spectrum whose total magnitude
        is 0 (e.g. a flat channel) is returned as zeros rather than NaN, so it
        cannot turn a channel average into NaN.
    """
    data = np.asarray(data)
    freqs = np.fft.fftfreq(data.shape[-1], d=1 / srate)
    positive = freqs > 0

    magnitude = np.abs(np.fft.fft(data, axis=-1)[..., positive])
    total = magnitude.sum(axis=-1, keepdims=True)
    amplitude = np.divide(magnitude, total, out=np.zeros_like(magnitude), where=total != 0)
    return freqs[positive], amplitude


# Per file: {stimulus label: array of frequency bins [Hz], one row per epoch}
eeg_f = [None] * len(files)
# Per file: {stimulus label: normalized amplitudes, one entry per epoch}
eeg_amp = [None] * len(files)

for f in range(len(files)):
    srate = settings[f]["eeg_srate"]  # same for every epoch of this file
    eeg_f[f] = {}
    eeg_amp[f] = {}

    for stim_label, epochs in eeg_epochs[f].items():
        # Each epoch is either (n_samples,) or channels x time; the helper
        # handles both by transforming along the last (time) axis.
        spectra = [_normalized_amplitude_spectrum(epoch, srate) for epoch in epochs]
        eeg_f[f][stim_label] = np.array([freqs for freqs, _ in spectra])
        eeg_amp[f][stim_label] = np.array([amp for _, amp in spectra])


# Plot the normalized amplitude spectra per stimulus

In [4]:
plot_amp = True     # Enable to see plots
f_limits = [5, 35]  # Frequency limits for the plots [min, max][Hz]
file_to_plot = 0    # Select index of file to be plotted
num_stimuli = 12    # Maximum number of stimuli to plot
n_cols = 4          # Subplot columns; rows are added as needed

if plot_amp:
    # Plot at most `num_stimuli` stimuli, but never more than the file
    # contains (indexing past the available labels would raise IndexError).
    stim_labels = list(eeg_f[file_to_plot].keys())[:num_stimuli]
    n_rows = max(1, int(np.ceil(len(stim_labels) / n_cols)))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10), squeeze=False)
    fig.suptitle(f'Normalized FFT Amplitude for All Stimuli in File: {files[file_to_plot]}')
    axes = axes.flatten()

    for stim_idx, stim_label in enumerate(stim_labels):
        ax = axes[stim_idx]

        # Frequency mask from the first epoch's bins; assumes every epoch of
        # a stimulus has the same length and therefore the same bins.
        f_values = eeg_f[file_to_plot][stim_label][0]
        fmask = (f_values >= f_limits[0]) & (f_values <= f_limits[1])
        temp_freq = f_values[fmask]

        for epoch_idx, epoch_amp in enumerate(eeg_amp[file_to_plot][stim_label]):
            # (n_channels, n_freqs): average across channels; (n_freqs,): as is
            avg_amp = epoch_amp.mean(axis=0) if epoch_amp.ndim == 2 else epoch_amp
            ax.plot(temp_freq, avg_amp[fmask], label=f'Epoch {epoch_idx+1}')

        ax.set_title(f'Stimulus: {stim_label}')
        ax.set_xlim(f_limits)
        ax.set_ylabel("Normalized Amplitude")

        # Label the x-axis only on the lowest plotted panel of each column
        if stim_idx + n_cols >= len(stim_labels):
            ax.set_xlabel("Frequency [Hz]")

        ax.legend()

    # Hide grid cells left empty when there are fewer stimuli than panels
    for ax in axes[len(stim_labels):]:
        ax.set_visible(False)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
