In [1]:
import os
import mne
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from joblib import dump, load

In [3]:
data_dir = "/Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette" 
output_dir = "/Users/janet/projectnyx/sleep_data/EEG_preprocessed"       # Directory to save preprocessed files in a seperate project folder!
os.makedirs(output_dir, exist_ok=True)

In [5]:
def preprocess_sleep_edf(data_dir, output_dir, window_size=30):
    for file in os.listdir(data_dir):
        if file.endswith('PSG.edf'):
            file_path = os.path.join(data_dir, file)
            try:
                raw = mne.io.read_raw_edf(file_path, preload=True)
                raw.pick_channels(['EEG Fpz-Cz', 'EEG Pz-Oz'])
                raw.filter(0.5, 40)
                data = raw.get_data()
                sfreq = raw.info['sfreq'] # what does this do
                epoch_samples = int(window_size * sfreq)
                epochs = np.array([np.array_split(channel, len(channel) // epoch_samples) for channel in data]) # saving as numpy array
                output_file = os.path.join(output_dir, file.replace('PSG.edf', 'preprocessed.npz'))
                np.savez(output_file, epochs=epochs, sfreq=sfreq)
                print(f"Processed: {file}")
            except Exception as e:
                print(f"Error processing {file}: {str(e)}")

def visualize_data():
    # raw data
    raw.plot(duration= 60, n_channels= 2, title = "Raw Data")
    plt.savefig(os.path.join(output_dir, f"{filename}_raw.png"))
    plt.close()
    
    # first epoch in each preprocessed data

In [None]:
def load_data(output_dir):
    X = []
    y = []
    for file in os.listdir(output_dir):
        if file.endswith('.npz'):
            data = np.load(os.path.join(output_dir, file))
            epochs = data['epochs']
            # Flatten the epochs and add to X
            X.extend([epoch.flatten() for epoch in epochs[0]])  # Using only the first channel for simplicity
            # Generate random labels (0-4) for demonstration. Replace with actual labels.
            y.extend(np.random.randint(0, 5, size=epochs.shape[1]))
    return np.array(X), np.array(y)

def train_random_forest(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_classifier.fit(X_train, y_train)
    y_pred = rf_classifier.predict(X_test) #prediction
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.2f}")
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))

    return rf_classifier

# Main
if __name__ == "__main__":
    data_dir = "/Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette"
    output_dir = "/Users/janet/projectnyx/sleep_data/EEG_preprocessed"
    os.makedirs(output_dir, exist_ok=True)

    preprocess_sleep_edf(data_dir, output_dir)
    print("Preprocessing completed. Preprocessed files saved in:", output_dir)
    X, y = load_data(output_dir)

    rf_model = train_random_forest(X, y) #training


    dump(rf_model, os.path.join(output_dir, 'random_forest_model.joblib'))
    print("Model trained and saved as 'random_forest_model.joblib'")

Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4002E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


Reading 0 ... 8489999  =      0.000 ... 84899.990 secs...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4002E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4061E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 3611999  =      0.000 ... 36119.990 

  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4061E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4031E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8459999  =      0.000 ... 84599.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4031E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4052E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8417999  =      0.000 ... 84179.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4052E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4022E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8267999  =      0.000 ... 82679.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4022E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4041E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7709999  =      0.000 ... 77099.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4041E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4011E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8405999  =      0.000 ... 84059.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4011E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4021E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8411999  =      0.000 ... 84119.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 661 samples (6.610 s)

Processed: SC4021E0-PSG.edf
Extracting EDF parameters from /Users/janet/projectnyx/sleep_data/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/SC4042E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8375999  =      0.000 ... 83759.990 secs...


  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
  raw = mne.io.read_raw_edf(file_path, preload=True)
