In [None]:
import mne
import os

import numpy as np
import pandas as pd
from scipy.stats import zscore

In [None]:
class EEG_Data:
    """Thin wrapper around an MNE raw recording loaded from an .edf or .set file.

    Attributes:
        file_path (str): Path the recording was loaded from.
        raw: MNE Raw object returned by ``mne.io.read_raw_edf`` or
            ``mne.io.read_raw_eeglab`` (loaded with ``preload=True``).
        samp_freq (float): Sampling rate in Hz that ``filter_data`` resamples to.
    """

    def __init__(self, file_path, samp_freq=256):
        """Initializes the EEG_Data object by loading an EEG file.

        Args:
            file_path (str): File path to the EEG data file. Supports .edf and .set files.
            samp_freq (float, optional): Target sampling rate (Hz) used by
                ``filter_data`` when resampling. Defaults to 256.

        Raises:
            FileNotFoundError: If the file path does not point to an existing file.
            ValueError: If the file extension is not .edf or .set.
        """

        # Check the path up front so a missing file gives a clear error
        # instead of an MNE reader failure.
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        # Initialize variables
        self.file_path = file_path
        self.raw = None  # Holds the MNE Raw object once loaded
        self.samp_freq = samp_freq

        # Extension (lower-cased) decides which MNE reader to use
        ext = os.path.splitext(file_path)[1].lower()

        # Load with the appropriate reader; preload so filtering can run in memory
        if ext == '.edf':
            self.raw = mne.io.read_raw_edf(file_path, preload=True)
        elif ext == '.set':
            self.raw = mne.io.read_raw_eeglab(file_path, preload=True)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

        print(f"Loaded {ext} file: {file_path}")

    @property
    def info(self):
        """Returns the measurement info of the loaded raw data.

        Returns:
            mne.Info: dict-like metadata such as channel names, sampling rate,
            high-pass / low-pass settings and number of channels.
        """
        return self.raw.info

    def filter_data(self, low_pass=0.5, high_pass=40, notch_freq=60, samp_freq=None):
        """Band-pass filters, resamples, and notch filters the raw data in place.

        Naming note: ``low_pass`` is passed to MNE as ``l_freq``, the LOWER
        edge of the pass band (i.e. it acts as a high-pass cut-off). ``high_pass``
        is passed as ``h_freq``, the UPPER edge (a low-pass cut-off). The names
        are kept for backward compatibility.

        Args:
            low_pass (float, optional): Lower pass-band edge in Hz (``l_freq``).
                Defaults to 0.5.
            high_pass (float, optional): Upper pass-band edge in Hz (``h_freq``).
                Defaults to 40.
            notch_freq (float | array-like | None, optional): Frequency (or
                frequencies) in Hz to notch out, e.g. 60 for North American
                line noise. ``None`` skips the notch step. Defaults to 60.
            samp_freq (float | None, optional): Sampling rate in Hz to resample
                to. ``None`` uses ``self.samp_freq``. When given, it also
                replaces ``self.samp_freq`` so the attribute matches the data.
        """
        if samp_freq is not None:
            self.samp_freq = samp_freq

        # Band-pass filter. MNE's Raw.filter always modifies the object in
        # place and has no ``inplace`` keyword (passing one raises TypeError).
        self.raw.filter(l_freq=low_pass, h_freq=high_pass)

        # Resample to the target sampling rate
        self.raw.resample(sfreq=self.samp_freq)

        # Notch filter (also in place; no ``inplace`` keyword exists)
        if notch_freq is not None:
            self.raw.notch_filter(freqs=notch_freq)

    # def z_score_std(self):
    #     """Standardizes the values with a mean = 0 and standard deviation of 1
    #     """

    #     data = self.raw.get_data() # Gets data in a given range
    #     z_data = zscore(data, axis=1) # Applies z score standardization across all data
    #     self.raw._data[:] = z_data # Modifies stored data inside self.raw object

    def ch_names_data(self):
        """Retrieves the names and samples of the EEG channels in the raw data.

        Channels are selected with ``mne.pick_types(info, eeg=True)``, whose
        default ``exclude='bads'`` also drops channels listed in ``info['bads']``.

        Returns:
            tuple[list[str], numpy.ndarray]: The EEG channel names, and a 2-D
            array of shape (n_channels, n_times) holding their data in the
            same order.
        """
        # Indices of EEG-type channels only
        picks = mne.pick_types(self.raw.info, eeg=True)

        # Channel names for those indices, in pick order
        eeg_channel_names = [self.raw.ch_names[i] for i in picks]

        # Data for the same channels: 2-D array (n_channels, n_times)
        eeg_data = self.raw.get_data(picks=picks)

        return eeg_channel_names, eeg_data
    
    # def save_edf(self, file_name):
    #     """Saves the EEG data (channel names and signals) to a new edf file.

    #     Args:
    #         file_name (str): Name of the output EDF file. Must end with '.edf'.
    #     """
        
    #     ch_names, eeg_data = self.ch_names_data()
        
    #     edf_info = mne.create_info(ch_names=ch_names, sfreq=self.samp_freq, ch_types=['eeg']*len(ch_names))
        
    #     edf_raw = mne.io.RawArray(eeg_data, edf_info)
        
    #     if file_name.endswith('.edf'):
    #         mne.export.export_raw(file_name, edf_raw, fmt='edf')
    #         print(f"Saved edf file to: {file_name}")
    #     else:
    #         print(f"{file_name} needs to end with '.edf'")