In [None]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
import pandas as pd
from sklearn.model_selection import train_test_split
from joblib import dump,load


In [None]:
# Signal recordings: EDF files whose name ends in a digit right before ".edf".
raw_data = glob('/Users/niraj/Desktop/Research Works/recordings/*[0-9].edf')
raw_data.sort()

# Scoring files: EDF files whose name ends in "sleepscoring.edf".
annot_data = glob('/Users/niraj/Desktop/Research Works/recordings/*sleepscoring.edf')
annot_data.sort()



In [None]:
# Open every recording file without preloading the signal (preload=False);
# infer_types=True asks MNE to derive channel types from the channel labels.
raw_train = [
    mne.io.read_raw_edf(edf_path, infer_types=True, preload=False)
    for edf_path in raw_data
]


In [None]:
# Read the annotations stored in each scoring file, printing a summary of
# each one as it is loaded.
annot_train = []
for scoring_path in annot_data:
    stage_annotations = mne.read_annotations(scoring_path)
    print(stage_annotations)
    annot_train.append(stage_annotations)

In [None]:
# Attach each set of annotations to its recording.  The pairing is purely
# positional (index i of raw_train with index i of annot_train), so a missing
# or extra file would shift every later pair.  Refuse to continue in that case
# instead of failing with an IndexError or training on mislabelled data.
if len(raw_train) != len(annot_train):
    raise ValueError(
        f"Found {len(raw_train)} recordings but {len(annot_train)} annotation "
        "sets; they cannot be paired by position."
    )
for recording, scoring in zip(raw_train, annot_train):
    recording.set_annotations(scoring, emit_warning=False, on_missing='ignore')


In [None]:
# df = pd.concat((data) for data in raw_train)
# Display the first loaded recording as this cell's output.
raw_train[0]

In [None]:
# raw_train[3].plot(start = 60,duration= 120,scalings= dict(eeg=1e-3,emg = 1e-2,eog=1e-4,misc=10))
# plt.show()

In [None]:
# Annotation description -> integer event code, used when turning the
# annotations into events.  N3 and N4 share code 4, so both end up in the
# same class.
annotations_event_id = {
    "Sleep stage W": 1,
    "Sleep stage N1": 2,
    "Sleep stage N2": 3,
    "Sleep stage N3": 4,
    "Sleep stage N4": 4,
    "Sleep stage R": 5,
}

# events_train, _ = mne.events_from_annotations(
#     raw_train[0], event_id=annotations_event_id, chunk_duration=30.0)

In [None]:
# Epoch label -> integer event code, passed to mne.Epochs.  The labels here
# are spelled differently from the keys of annotations_event_id; the two
# mappings are linked only through the shared integer codes 1-5.
event_id = {
    'Sleep stage W': 1,
    'Sleep stage 1': 2,
    'Sleep stage 2': 3,
    'Sleep stage 3/4': 4,
    'Sleep stage R': 5,
}

# List of colours taken from the "color" entry of matplotlib's
# axes.prop_cycle rc setting; used to give each sleep stage its own colour.
stage_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]



In [None]:
# tmax = 30.0  # tmax in included


# epochs_train = mne.Epochs(raw=raw_train[100], events=events_train, on_missing='warn',
#                           event_id=event_id, tmin=0., tmax=tmax, baseline=None)
# print(epochs_train)
# print(raw_train[100])

In [None]:
# Build the training set: cut recordings 0-4 into 30 s epochs (one per 30 s
# chunk of a scored annotation) and merge them into one Epochs object.
epochs_data = []
for i in range(0, 5):
    # Annotation descriptions -> events, using the codes in annotations_event_id.
    events_train, _ = mne.events_from_annotations(
        raw_train[i], event_id=annotations_event_id, chunk_duration=30.0)

    # mne.Epochs includes the sample at tmax, so stop one sample short of
    # 30 s; with tmax=30.0 every epoch would also contain the first sample of
    # the following epoch.
    tmax = 30.0 - 1.0 / raw_train[i].info["sfreq"]

    # event_id only relabels the integer codes; on_missing='warn' tolerates
    # recordings in which some stage never occurs.
    epochs_train = mne.Epochs(raw=raw_train[i], events=events_train, on_missing='warn',
                              event_id=event_id, tmin=0.0, tmax=tmax, baseline=None)
    epochs_data.append(epochs_train)

concatenated_epochs = mne.concatenate_epochs(epochs_data)





# (epochs_data[0:10])
# print(type(epochs_data))
# concatenated_epochs = mne.concatenate_epochs(epochs_data)
# concatenated_epochs

In [None]:
# Display the measurement info of the concatenated epochs.
concatenated_epochs.info

In [None]:
# Display the channel names of the concatenated epochs.
concatenated_epochs.info['ch_names']

In [None]:
# Average power spectral density of the concatenated epochs, one curve per
# sleep stage, restricted to the data channels.
fig, ax1 = plt.subplots(ncols=1)

stages = sorted(event_id.keys())
for stage, color in zip(stages, stage_colors):
    spectrum = concatenated_epochs[stage].compute_psd(fmin=0.1, fmax=20.0)
    spectrum.plot(
        ci=None,
        color=color,
        axes=ax1,
        show=False,
        average=True,
        spatial_colors=False,
        picks="data",
        exclude="bads",
    )
ax1.set(
    title="PSD per sleep stage (data channels)",
    xlabel="Frequency (Hz)",
    ylabel="µV²/Hz (dB)",
)

# Without a legend the coloured curves cannot be told apart.  Proxy lines are
# used as handles so the legend does not depend on how many artists
# spectrum.plot() adds to the axes for each stage.
legend_handles = [
    plt.Line2D([0], [0], color=color) for _, color in zip(stages, stage_colors)
]
ax1.legend(legend_handles, stages)

In [None]:
# Display the channel names of the concatenated epochs (same expression as in
# an earlier cell).
concatenated_epochs.info['ch_names']


In [None]:
# Display the measurement info of the concatenated epochs (same expression as
# in an earlier cell).
concatenated_epochs.info

In [None]:
# Average power spectral density of the concatenated epochs, one curve per
# sleep stage, restricted to the 'ECG' pick.
fig, ax1 = plt.subplots(ncols=1)

stages = sorted(event_id.keys())
for stage, color in zip(stages, stage_colors):
    spectrum = concatenated_epochs[stage].compute_psd(fmin=0.1, fmax=20.0)
    spectrum.plot(
        ci=None,
        color=color,
        axes=ax1,
        show=False,
        average=True,
        spatial_colors=False,
        picks='ECG',
        exclude="bads",
    )
ax1.set(
    title="PSD per sleep stage (ECG)",
    xlabel="Frequency (Hz)",
    ylabel="µV²/Hz (dB)",
)

# Without a legend the coloured curves cannot be told apart.  Proxy lines are
# used as handles so the legend does not depend on how many artists
# spectrum.plot() adds to the axes for each stage.
legend_handles = [
    plt.Line2D([0], [0], color=color) for _, color in zip(stages, stage_colors)
]
ax1.legend(legend_handles, stages)

In [None]:
# Build the held-out set: cut recordings 10-11 into 30 s epochs (one per 30 s
# chunk of a scored annotation) and merge them into one Epochs object.
epochs_data = []
for i in range(10, 12):
    # Annotation descriptions -> events, using the codes in annotations_event_id.
    events_train, _ = mne.events_from_annotations(
        raw_train[i], event_id=annotations_event_id, chunk_duration=30.0)

    # mne.Epochs includes the sample at tmax, so stop one sample short of
    # 30 s; with tmax=30.0 every epoch would also contain the first sample of
    # the following epoch.
    tmax = 30.0 - 1.0 / raw_train[i].info["sfreq"]

    # event_id only relabels the integer codes; on_missing='warn' tolerates
    # recordings in which some stage never occurs.
    epochs_train = mne.Epochs(raw=raw_train[i], events=events_train, on_missing='warn',
                              event_id=event_id, tmin=0.0, tmax=tmax, baseline=None)
    epochs_data.append(epochs_train)

test_epochs = mne.concatenate_epochs(epochs_data)

In [None]:
# concatenated_epochs.shape

In [None]:
# fig, (ax1,ax2)= plt.subplots(ncols=2)
# stages= sorted(event_id.keys())
# for ax, title, epochs in zip([ax1, ax2], [raw_train,raw_train1],[epochs_train, epochs_train1]):
#     for stage, color in zip(stages, stage_colors):
#         spectrum = epochs[stage].compute_psd(fmin=0.1,fmax=30.0)
#         # print("This is stage: ",stage)
#         spectrum.plot(ci=None,color=color,axes=ax,show=False,average=True,spatial_colors=False,picks='all',exclude='bads')
#     ax.set(title=title,xlabel='Frequency(Hz)')
# ax1.set(ylabel='µV²/Hz (dB)')

# ax2.legend(ax2.lines[2::3],stages)
# ax1.legend(ax1.lines[2::3],stages)
# plt.show()

In [None]:
def eeg_power_band(epochs, freq_bands=None):
    """Relative band-power feature extraction.

    Computes the power spectral density of every epoch and channel (picks
    ``['data']``), normalises each spectrum so that it sums to one, and
    averages the normalised power inside each frequency band.  The result is
    a plain 2-D array that scikit-learn estimators can consume.

    Parameters
    ----------
    epochs : Epochs
        The data.  Only its ``compute_psd`` method is used.
    freq_bands : dict of str -> (float, float) | None
        Mapping from band name to ``(fmin, fmax)`` in Hz.  A frequency bin
        belongs to a band when ``fmin <= f < fmax``.  The PSD is computed
        from the lowest ``fmin`` to the highest ``fmax`` of all bands.
        ``None`` (default) uses delta 0.5-4.5, theta 4.5-8.5, alpha 8.5-11.5,
        sigma 11.5-15.5 and beta 15.5-30 Hz.

    Returns
    -------
    X : numpy array of shape [n_epochs, n_bands * n_channels]
        Band-major feature matrix: the first ``n_channels`` columns hold the
        first band for every channel, the next ``n_channels`` columns the
        second band, and so on.

    Raises
    ------
    ValueError
        If ``freq_bands`` is given but empty.
    """
    if freq_bands is None:
        freq_bands = {
            "delta": [0.5, 4.5],
            "theta": [4.5, 8.5],
            "alpha": [8.5, 11.5],
            "sigma": [11.5, 15.5],
            "beta": [15.5, 30],
        }

    band_edges = [(low, high) for low, high in freq_bands.values()]
    if not band_edges:
        raise ValueError("freq_bands must contain at least one band.")

    # Compute the PSD over the full span of the requested bands
    # (0.5-30.0 Hz for the default bands).
    psd_fmin = float(min(low for low, _ in band_edges))
    psd_fmax = float(max(high for _, high in band_edges))
    spectrum = epochs.compute_psd(picks=['data'], fmin=psd_fmin, fmax=psd_fmax)
    psds, freqs = spectrum.get_data(return_freqs=True)

    # Normalise along the frequency axis; done out of place so the array
    # handed back by get_data() is never modified.
    psds = psds / np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in band_edges:
        in_band = (freqs >= fmin) & (freqs < fmax)
        psds_band = psds[:, :, in_band].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)

In [None]:
# Run the feature extraction directly on the concatenated epochs and display
# the shape of the resulting feature matrix.
temp = eeg_power_band(concatenated_epochs)
temp.shape

In [None]:
# temp.shape

In [None]:
# df = pd.DataFrame(temp)

In [None]:
# concatenated_epochs.shape

In [None]:
# test_epochs.shape

In [None]:
# Labels are the integer event codes in the third column of ``events``.
y = concatenated_epochs.events[:, 2]
y_test = test_epochs.events[:, 2]

# Band-power features (eeg_power_band) feeding a random forest.
pipe = make_pipeline(
    FunctionTransformer(eeg_power_band, validate=False),
    RandomForestClassifier(n_estimators=100, random_state=42),
)

# Fit on the concatenated epochs, then score on the separate test epochs.
pipe.fit(concatenated_epochs, y)
y_pred = pipe.predict(test_epochs)

acc = accuracy_score(y_test, y_pred)
print(f"Accuracy score: {acc}")

In [None]:
# dump(pipe,"model.joblib")