# Compute the SNR of pilot data

## Goals:
1. Import the preprocessed, epoched EEG data (`.npz`) and the matching settings (`.json`)
2. Calculate the Signal-to-Noise Ratio (SNR) for each epoch (for use in stats)
3. Export the SNR, averaged across epochs, to CSV


# Import Libraries

In [None]:
# Standard libraries
import json
import numpy as np
import pandas as pd
import scipy.signal as signal

# Custom libraries
from Functions import processing

# Import Epoched Data and Settings

In [2]:
# Load list of files to import (base names, without extension)
files = [
    "sub-P003_ses-S001_task-T1_run-001_eeg"
]

# Get unique subject IDs (the first "_"-separated token of each file name)
subject_ids = [file.split('_')[0] for file in files]
unique_subject_ids = list(set(subject_ids))

# Preallocate variables to store EEG data and settings (one entry per file)
eeg_epochs = [None] * len(files)
settings = [None] * len(files)
rs_open_data = [None] * len(files)

# Import data
for f, file in enumerate(files):
    # Build the folder from the subject/session tokens of the file name instead
    # of hard-coding them, so every listed file is read from its own folder.
    # NOTE(review): assumes the Data/Pilot2/EEG/<sub>/<ses>/eeg layout used for
    # the file listed above holds for all recordings -- confirm before adding more.
    subject, session = file.split('_')[:2]
    data_dir = f"Data/Pilot2/EEG/{subject}/{session}/eeg"  # "/" works on Windows and POSIX

    # Import EEG data. It is stored in a compressed numpy file (.npz), so np.load
    # is used; the context manager closes the archive once the arrays are copied out.
    # NOTE: allow_pickle=True can execute arbitrary code -- only load trusted files.
    with np.load(f"{data_dir}/{file}.npz", allow_pickle=True) as loaded_data:
        # Access the data for each stimulus (one array per key stored in the archive)
        eeg_epochs[f] = {stim_label: loaded_data[stim_label] for stim_label in loaded_data.files}

    # Import settings
    with open(f"{data_dir}/{file}.json", "r", encoding="utf-8") as file_object:
        settings[f] = json.load(file_object)

    # Import RS eyes open data. The lazily-read archive object itself is stored
    # (as before), so its file handle intentionally stays open.
    rs_open_data[f] = np.load(f"{data_dir}/{file}_baseline.npz", allow_pickle=True)

# Calculate PSD of all Epochs

In [3]:
# Welch PSD configuration
# Window length in seconds; the frequency resolution is 1 / window_size
# (10 -> 0.1 Hz, 5 -> 0.2 Hz, 2 -> 0.5 Hz)
window_size = 5

# One slot per file. Lists (rather than a single array) are used in case not
# all files have the same number of channels.
eeg_f = [None] * len(files)
eeg_pxx = [None] * len(files)

for f, file in enumerate(files):
    # Welch parameters depend only on the file's sampling rate, so build them once
    srate = settings[f]["eeg_srate"]
    samples_per_window = window_size * srate
    welch_kwargs = dict(
        fs=srate,
        nfft=int(samples_per_window),
        nperseg=samples_per_window,
        noverlap=samples_per_window * 0.5,  # 50% overlap between windows
    )

    freqs_by_stim = {}
    pxx_by_stim = {}

    for stim_label, epochs in eeg_epochs[f].items():
        # One (frequencies, power) pair per epoch of this stimulus
        spectra = [signal.welch(x=epoch, **welch_kwargs) for epoch in epochs]

        # Stack the per-epoch results into arrays, with epochs along the first axis
        freqs_by_stim[stim_label] = np.array([freqs for freqs, _ in spectra])
        pxx_by_stim[stim_label] = np.array([pxx for _, pxx in spectra])

    eeg_f[f] = freqs_by_stim
    eeg_pxx[f] = pxx_by_stim

# Compute SNR for all Epochs
- SNR is calculated for each epoch and channel, then averaged across epochs for each stimulus

In [4]:
# Settings
noise_band = 1    # Single-sided noise band [Hz]
nharms = 2        # Number of harmonics used
db_out = True     # Boolean to get output in dB
# NOTE: this single placeholder frequency is applied to every stimulus below;
# replace it with the actual stimulation frequency before using the results.
stim_freq = 10.0

# 1. Collect all unique channel names from all files
all_channel_names = set()
for file_settings in settings:
    all_channel_names.update(file_settings["new_ch_names"])
all_channel_names = sorted(all_channel_names)  # Sorted to keep a consistent order

# 2. Initialize containers
snr = [None] * len(files)
epoch_count_snr = {}
# Epoch counts are taken from the first file only (valid if consistent across files)
for stim_label in settings[0]["stimuli"].values():
    epoch_count_snr[stim_label] = eeg_pxx[0][stim_label].shape[0]

# 3. Compute SNR with per-file channel alignment
for f0 in range(len(files)):
    stim_labels = list(settings[f0]["stimuli"].values())
    file_channels = settings[f0]["new_ch_names"]
    ch_idx_map = {ch: i for i, ch in enumerate(file_channels)}

    # Rows follow the order of the stimuli in the settings, columns follow
    # all_channel_names. Start from NaN so that a channel missing from this file
    # (or a stimulus without epochs) is not mistaken for a real SNR of 0.
    temp_snr = np.full([len(stim_labels), len(all_channel_names)], np.nan)

    for s, stim_label in enumerate(stim_labels):
        num_epochs = eeg_pxx[f0][stim_label].shape[0]
        if num_epochs == 0:
            # Nothing to average (np.stack would raise); leave this row as NaN
            continue

        # SNR of every epoch; each call is expected to return one value per channel
        channel_snr_list = [
            processing.ssvep_snr(
                f=eeg_f[f0][stim_label][epoch],  # frequencies of this epoch
                pxx=eeg_pxx[f0][stim_label][epoch, :, :],  # PSD of this epoch (2-D slice)
                stim_freq=stim_freq,
                noise_band=noise_band,
                nharms=nharms,
                db_out=db_out
            )
            for epoch in range(num_epochs)
        ]

        # Average across epochs (first axis of the stacked per-epoch results)
        mean_snr = np.mean(np.stack(channel_snr_list), axis=0)

        # Assign into unified SNR array, mapping to correct channel indices
        for i, ch_name in enumerate(all_channel_names):
            if ch_name in ch_idx_map:
                temp_snr[s, i] = mean_snr[ch_idx_map[ch_name]]

    snr[f0] = temp_snr

# Export SNR

In [5]:
save_snr = True  # Boolean to save SNRs to CSV

# Preallocate empty list to store all DataFrames (one per file)
dfs = []

for f0, file in enumerate(files):
    # Rows of snr[f0] are ordered by the position of each stimulus in the
    # settings, so the column labels are taken in that same order rather than
    # by converting the dictionary keys to row indices.
    col_names = [f"{stimuli}" for stimuli in settings[f0]["stimuli"].values()]

    # (n_stimuli, n_channels) -> (n_channels, n_stimuli); copy so the DataFrame
    # never shares memory with snr[f0]
    temp_snr = snr[f0].T.copy()

    # One row per channel, labelled with the subject ID taken from the file name
    row_names = [f"{file.split('_')[0]} - {channel}" for channel in all_channel_names]

    # pandas raises a ValueError here if the SNR array does not match the labels
    dfs.append(
        pd.DataFrame(
            data=temp_snr,
            columns=col_names,
            index=row_names
        )
    )

# Concatenate all DataFrames
snr_df = pd.concat(dfs)

# Save SNRs to CSV
if save_snr:
    # Keep the index: it is the only place the subject/channel of a row is stored
    snr_df.to_csv("averaged_snr_output.csv")
