<a href="https://colab.research.google.com/github/vitaldb/examples/blob/master/eeg_mac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prediction of anesthetic concentration from EEG
In this example, we will build a model to predict anesthetic concentration (age-related MAC) from EEG during Sevoflurane anesthesia.

> Note that <b>all users of VitalDB, an open biosignal dataset, must agree to the Data Use Agreement below.</b>
If you do not agree, please close this window.
Click here: [Data Use Agreement](https://vitaldb.net/dataset/?query=overview&documentId=13qqajnNZzkN7NZ9aXnaQ-47NWy7kx-a6gbrcEsi-gak&sectionId=h.vcpgs1yemdb5)

## Required libraries

In [1]:
!pip install vitaldb

import vitaldb
import numpy as np
import pandas as pd
import os
import scipy.signal
import matplotlib.pyplot as plt
import random
import itertools as it
import numpy as np
from matplotlib import pyplot as plt

Collecting vitaldb
  Downloading vitaldb-0.0.11-py3-none-any.whl (42 kB)
[?25l[K     |███████▋                        | 10 kB 21.4 MB/s eta 0:00:01[K     |███████████████▎                | 20 kB 25.2 MB/s eta 0:00:01[K     |███████████████████████         | 30 kB 12.5 MB/s eta 0:00:01[K     |██████████████████████████████▋ | 40 kB 9.6 MB/s eta 0:00:01[K     |████████████████████████████████| 42 kB 744 kB/s 
Installing collected packages: vitaldb
Successfully installed vitaldb-0.0.11


## Preprocessing

Load data using the <code>VitalDB Web API</code>.

In [None]:
SRATE = 128  # sampling rate requested when loading the tracks, in Hz
SEGLEN = 4 * SRATE  # length of one EEG segment in samples (4 seconds)
BATCH_SIZE = 1024
MAX_CASES = 100  # maximum number of cases that contribute samples

# The extracted dataset is cached in an .npz file named after the segment length and the case
# count, so the download loop below only runs when no cache file exists yet.
cachefile = '{}sec_{}cases.npz'.format(SEGLEN // SRATE, MAX_CASES)
if os.path.exists(cachefile):
    # np.load returns an NpzFile that keeps the archive open; the context manager closes it
    # once the four arrays have been read.
    with np.load(cachefile) as dat:
        x, y, b, c = dat['x'], dat['y'], dat['b'], dat['c']
else:
    df_trks = pd.read_csv("https://api.vitaldb.net/trks")  # track information
    df_cases = pd.read_csv("https://api.vitaldb.net/cases")  # patient information

    # Column order when loading data
    EEG = 0
    SEVO = 1
    BIS = 2

    # Inclusion criteria: age above 18 and all three required tracks present
    caseids = set(df_cases.loc[df_cases['age'] > 18, 'caseid']) &\
        set(df_trks.loc[df_trks['tname'] == 'BIS/EEG1_WAV', 'caseid']) &\
        set(df_trks.loc[df_trks['tname'] == 'BIS/BIS', 'caseid']) &\
        set(df_trks.loc[df_trks['tname'] == 'Primus/EXP_SEVO', 'caseid'])

    x = []  # eeg segments
    y = []  # sevo (age-related MAC)
    b = []  # bis
    c = []  # caseids
    icase = 0  # number of loaded cases
    # A set has no guaranteed iteration order; sorting makes the selected cases (and therefore
    # the cache file) the same on every run.
    for caseid in sorted(caseids):
        print('loading {} ({}/{})'.format(caseid, icase, MAX_CASES), end='...', flush=True)

        # Exclusion criteria: skip cases in which any of these tracks exceeds its threshold
        if np.any(vitaldb.load_case(caseid, 'Orchestra/PPF20_CE') > 0.2):
            print('propofol')
            continue
        if np.any(vitaldb.load_case(caseid, 'Primus/EXP_DES') > 1):
            print('desflurane')
            continue
        if np.any(vitaldb.load_case(caseid, 'Primus/FEN2O') > 2):
            print('n2o')
            continue
        if np.any(vitaldb.load_case(caseid, 'Orchestra/RFTN50_CE') > 0.2):
            print('remifentanil')
            continue

        # Extract data
        vals = vitaldb.load_case(caseid, ['BIS/EEG1_WAV', 'Primus/EXP_SEVO', 'BIS/BIS'], 1 / SRATE)
        sevo = vals[:, SEVO]
        # np.nanmax of an all-NaN column is NaN (and raises on an empty one), which would let
        # the case slip through the "< 1" test, so that situation is checked explicitly first.
        if np.all(np.isnan(sevo)) or np.nanmax(sevo) < 1:
            print('all sevo <= 1')
            continue

        # Convert etsevo to the age related mac
        # NOTE(review): 1.80 is presumably the reference sevoflurane MAC at age 40 - confirm.
        age = df_cases.loc[df_cases['caseid'] == caseid, 'age'].values[0]
        vals[:, SEVO] /= 1.80 * 10 ** (-0.00269 * (age - 40))

        if not np.any(vals[:, BIS] > 0):
            print('all bis <= 0')
            continue

        # Since the EEG should come out well, keep only the span between the first and the
        # last position where a bis value was calculated.
        valid_bis_idx = np.where(vals[:, BIS] > 0)[0]
        first_bis_idx = valid_bis_idx[0]
        last_bis_idx = valid_bis_idx[-1]
        vals = vals[first_bis_idx:last_bis_idx + 1, :]

        if len(vals) < 1800 * SRATE:  # Do not use cases that are less than 30 minutes
            print('{} len < 30 min'.format(caseid))
            continue

        # Forward fill in MAC value and BIS value up to 5 seconds
        vals[:, SEVO:] = pd.DataFrame(vals[:, SEVO:]).ffill(limit=5 * SRATE).values

        # Every 1 second, take the SEGLEN EEG samples that precede the current row and pair
        # them with the MAC and BIS values at that row.
        oldlen = len(y)
        for irow in range(SEGLEN, len(vals), SRATE):
            bis = vals[irow, BIS]
            mac = vals[irow, SEVO]
            if np.isnan(bis) or np.isnan(mac) or bis == 0:
                continue
            # add dataset
            eeg = vals[irow - SEGLEN:irow, EEG]
            x.append(eeg)
            y.append(mac)
            b.append(bis)
            c.append(caseid)

        # A case that yielded no samples must not use up one of the MAX_CASES slots
        if len(y) == oldlen:
            print('no valid samples')
            continue

        # Valid case
        icase += 1
        print('{} samples read -> total {} samples ({}/{})'.format(len(y) - oldlen, len(y), icase, MAX_CASES))
        if icase >= MAX_CASES:
            break

    # Change the input dataset to a numpy array
    x = np.array(x)
    y = np.array(y)
    b = np.array(b)
    c = np.array(c)

    # Save cache file
    np.savez(cachefile, x=x, y=y, b=b, c=c)


loading 1 (0/100)...desflurane
loading 2 (0/100)...10381 samples read -> total 10381 samples (1/100)
loading 3 (1/100)...propofol
loading 4 (1/100)...14367 samples read -> total 24748 samples (2/100)
loading 5 (2/100)...propofol
loading 10 (2/100)...14509 samples read -> total 39257 samples (3/100)
loading 12 (3/100)...21126 samples read -> total 60383 samples (4/100)
loading 18 (4/100)...all bis <= 0
loading 19 (4/100)...propofol
loading 20 (4/100)...propofol
loading 21 (4/100)...8087 samples read -> total 68470 samples (5/100)
loading 24 (5/100)...3599 samples read -> total 72069 samples (6/100)
loading 25 (6/100)...9665 samples read -> total 81734 samples (7/100)
loading 26 (7/100)...desflurane
loading 27 (7/100)...11541 samples read -> total 93275 samples (8/100)
loading 30 (8/100)...propofol
loading 33 (8/100)...2797 samples read -> total 96072 samples (9/100)
loading 34 (9/100)...propofol
loading 35 (9/100)...propofol
loading 38 (9/100)...propofol
loading 43 (9/100)...9689 sample

## Filtering input data

In [None]:
# Step 1: discard segments that are unusable before any filtering
print('invalid samples...', end='', flush=True)
valid_mask = ~np.isnan(x).any(axis=1)  # drop segments that contain any NaN
# BIS impedance check: drop segments whose overall EEG range is 12 or less
valid_mask &= (x.max(axis=1) - x.min(axis=1)) > 12
x, y, b, c = (arr[valid_mask] for arr in (x, y, b, c))
print('{:.1f}% removed'.format(100 * (1 - valid_mask.mean())))

# Step 2: remove baseline drift by subtracting a Savitzky-Golay smoothed copy of each segment
print('baseline drift...', end='', flush=True)
x -= scipy.signal.savgol_filter(x, window_length=91, polyorder=3)
print('removed')

# Step 3: discard segments whose absolute amplitude reaches 100 or more after drift removal
print('noisy samples...', end='', flush=True)
valid_mask = np.nanmax(np.abs(x), axis=1) < 100

# The same mask is applied to every array so that the rows stay aligned
x, y, b, c = (arr[valid_mask] for arr in (x, y, b, c))
print('{:.1f}% removed'.format(100 * (1 - valid_mask.mean())))

## Splitting samples into training and testing dataset

In [None]:
# Unique case IDs present in the filtered dataset (np.unique returns them sorted)
caseids = list(np.unique(c))

# Shuffle with a dedicated, seeded generator so that the train/test split is identical on every
# run; using a separate Random instance leaves the global `random` state untouched.
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(caseids)

# Split by case rather than by sample, so that no case contributes to both sets.
# 20% of the cases (at least one) go to the test set.
ntest = max(1, int(len(caseids) * 0.2))
caseids_train = caseids[ntest:]
caseids_test = caseids[:ntest]

train_mask = np.isin(c, caseids_train)
test_mask = np.isin(c, caseids_test)
x_train = x[train_mask]
y_train = y[train_mask]
x_test = x[test_mask]
y_test = y[test_mask]
b_test = b[test_mask]
c_test = c[test_mask]

print('====================================================')
print('total: {} cases {} samples'.format(len(caseids), len(y)))
print('train: {} cases {} samples'.format(len(np.unique(c[train_mask])), len(y_train)))
print('test: {} cases {} samples'.format(len(np.unique(c_test)), len(y_test)))
print('====================================================')

train: 45 cases 53222 samples, testing: 5 cases 5266 samples


## Modeling and Evaluation

In [None]:
import keras.models
import tensorflow as tf
from keras.models import Model
from keras.layers import Layer, LayerNormalization, Dense, Dropout, Conv1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D, Input, concatenate, multiply, dot, MultiHeadAttention
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Hyperparameter search space. The loop below unpacks each combination positionally, so the
# key order here must match the order of the loop variables.
tests = {
    "nfilt" : [16, 32, 64],
    "fnode" : [32, 64, 128],
    "clayer" : [1, 2, 3, 4],
    "droprate" : [0.1, 0.2],
    "filtsize" : [5, 7, 9, 11],
    'poolsize' : [2, 4, 8],
    "pooltype" : ['avg', 'max']
}

# https://keras.io/examples/nlp/text_classification_with_transformer/
# NOTE(review): the link above does not appear to relate to the grid enumeration below -
# confirm whether it is still needed.
# Build every combination of the hyperparameter values and visit them in random order.
keys, values = zip(*tests.items())
permutations_dicts = it.product(*values)
permutations_dicts = list(permutations_dicts)
random.shuffle(permutations_dicts)
for nfilt, fnode, clayer, droprate, filtsize, poolsize, pooltype in permutations_dicts:

    # Reset the Keras global state before building the next model
    keras.backend.clear_session()

    # One output directory per hyperparameter combination
    odir = '{}cases_{}sec'.format(MAX_CASES, SEGLEN // SRATE)
    odir += '_cnn{} filt{} size{} pool{} {} do{}'.format(clayer, nfilt, filtsize, poolsize, pooltype, droprate)
    print("============================")
    print(odir)
    print("============================")

    # Input: one EEG segment with a single channel
    out = inp = Input(shape=(x_train.shape[1], 1))
    # Initial convolution layer (no activation)
    out = Conv1D(filters=nfilt, kernel_size=filtsize, padding='same')(out)
    # `clayer` blocks of convolution followed by max pooling
    for _ in range(clayer):
        out = Conv1D(filters=nfilt, kernel_size=filtsize, padding='same', activation='relu')(out)
        out = MaxPooling1D(poolsize, padding='same')(out)
    # Collapse the time axis with global average or global max pooling
    if pooltype == "avg":
        out = GlobalAveragePooling1D()(out)
    else:
        out = GlobalMaxPooling1D()(out)

    if droprate:
        out = Dropout(droprate)(out)
    # NOTE(review): no activation is given, so this layer is linear - confirm this is intended.
    out = Dense(fnode)(out)
    if droprate:
        out = Dropout(droprate)(out)
    out = Dense(1)(out)  # single regression output

    # -------------
    #  Save models
    # -------------
    # makedirs with exist_ok avoids the separate existence check and its race
    os.makedirs(odir, exist_ok=True)

    cache_path = odir + "/weights.hdf5"
    model = Model(inputs=[inp], outputs=[out])
    model.summary()
    model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['mean_absolute_error'])
    # restore_best_weights=True makes the in-memory model match the best checkpoint; without
    # it the prediction below would use the weights of the last (worse) epoch.
    hist = model.fit(x_train[..., None], y_train, validation_split=0.2, epochs=10, batch_size=BATCH_SIZE,
                    callbacks=[ModelCheckpoint(monitor='val_loss', filepath=cache_path, verbose=1, save_best_only=True),
                               EarlyStopping(monitor='val_loss', patience=1, verbose=1, mode='auto',
                                             restore_best_weights=True),
                               ])

    # Prediction
    pred_test = model.predict(x_test[..., None], batch_size=BATCH_SIZE).flatten()

    # Smooth the predictions of each test case with a 31-sample median filter
    for caseid in np.unique(c_test):
        case_mask = (c_test == caseid)
        pred_test[case_mask] = scipy.signal.medfilt(pred_test[case_mask], 31)

    # Overall performance, computed after smoothing so that it is based on the same
    # predictions as the per-case MAE reported below.
    test_mae = np.mean(np.abs(y_test - pred_test))

    # Per-case evaluation and plot
    for caseid in np.unique(c_test):
        case_mask = (c_test == caseid)
        case_len = np.sum(case_mask)

        our_mae = np.mean(np.abs(y_test[case_mask] - pred_test[case_mask]))
        print('Total MAE={:.4f}, CaseID {}, MAE={:.4f}'.format(test_mae, caseid, our_mae))

        t = np.arange(0, case_len)
        plt.figure(figsize=(20, 5))
        plt.plot(t, y_test[case_mask], label='MAC')  # measured values
        plt.plot(t, pred_test[case_mask], label='Ours ({:.4f})'.format(our_mae))
        plt.xlabel('sample index')
        plt.ylabel('age-related MAC')
        plt.legend(loc="upper left")
        plt.tight_layout()
        plt.xlim([0, case_len])
        plt.ylim([0, 2])
        plt.show()

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
