<a href="https://colab.research.google.com/github/vitaldb/examples/blob/master/eeg_mac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 뇌파로부터 마취제 농도 예측 인공지능 모델 실습
Sevoflurane 마취 중 뇌파(EEG)로부터 마취제 농도(age-related MAC)를 예측하는 모델

## VitalDB 데이터 셋 이용
본 예제에서는 오픈 생체 신호 데이터셋인 VitalDB를 이용합니다. VitalDB를 이용하는 모든 사용자는 반드시 아래 Data Use Agreement에 동의하여야 합니다.

https://vitaldb.net/data-bank/?query=guide&documentId=13qqajnNZzkN7NZ9aXnaQ-47NWy7kx-a6gbrcEsi-gak&sectionId=h.usmoena3l4rb

동의하지 않을 경우 이 창을 닫으세요.

## 본 프로그램에서 이용할 라이브러리 설치 및 import

In [1]:
!pip install vitaldb

import vitaldb
import numpy as np
import pandas as pd
import os
import scipy.signal
import matplotlib.pyplot as plt
import random
import itertools as it
import numpy as np
from matplotlib import pyplot as plt

Collecting vitaldb
  Downloading vitaldb-0.0.11-py3-none-any.whl (42 kB)
[?25l[K     |███████▋                        | 10 kB 21.4 MB/s eta 0:00:01[K     |███████████████▎                | 20 kB 25.2 MB/s eta 0:00:01[K     |███████████████████████         | 30 kB 12.5 MB/s eta 0:00:01[K     |██████████████████████████████▋ | 40 kB 9.6 MB/s eta 0:00:01[K     |████████████████████████████████| 42 kB 744 kB/s 
Installing collected packages: vitaldb
Successfully installed vitaldb-0.0.11


# Data loading 및 전처리

VitalDB Web API를 통해 데이터 로딩


In [None]:
SRATE = 128  # sampling rate the tracks are loaded at (Hz)
SEGLEN = 4 * SRATE  # samples per EEG input segment (4 seconds)
BATCH_SIZE = 1024
MAX_CASES = 100  # stop after this many cases have passed every criterion

# Cache the extracted dataset so a re-run does not re-download every case.
# SRATE is part of the name: SEGLEN // SRATE alone is identical for e.g. 128 Hz and 256 Hz,
# which would otherwise silently reuse segments extracted at a different sampling rate.
cachefile = '{}sec_{}hz_{}cases.npz'.format(SEGLEN // SRATE, SRATE, MAX_CASES)
if os.path.exists(cachefile):
    dat = np.load(cachefile)
    x, y, b, c = dat['x'], dat['y'], dat['b'], dat['c']
else:
    df_trks = pd.read_csv("https://api.vitaldb.net/trks")  # track list
    df_cases = pd.read_csv("https://api.vitaldb.net/cases")  # per-case clinical info

    # column order of the array returned by the main vitaldb.load_case call below
    EEG = 0
    SEVO = 1
    BIS = 2

    # inclusion criteria: age > 18 and the EEG, BIS and expired-sevoflurane tracks all present
    caseids = set(df_cases.loc[df_cases['age'] > 18, 'caseid']) &\
        set(df_trks.loc[df_trks['tname'] == 'BIS/EEG1_WAV', 'caseid']) &\
        set(df_trks.loc[df_trks['tname'] == 'BIS/BIS', 'caseid']) &\
        set(df_trks.loc[df_trks['tname'] == 'Primus/EXP_SEVO', 'caseid'])
    # Visit cases in a fixed order: because of MAX_CASES the iteration order decides
    # which cases end up in the dataset, and set iteration order is not a guaranteed contract.
    caseids = sorted(caseids)

    # exclusion criteria, checked in this order: (track, threshold, label printed on exclusion).
    # A case is excluded if any value of the track exceeds the threshold.
    exclusions = [
        ('Orchestra/PPF20_CE', 0.2, 'propofol'),
        ('Primus/EXP_DES', 1, 'desflurane'),
        ('Primus/FEN2O', 2, 'n2o'),
        ('Orchestra/RFTN50_CE', 0.2, 'remifentanil'),
    ]

    x = []  # EEG segments, SEGLEN samples each
    y = []  # age-related MAC at the end of each segment (regression target)
    b = []  # BIS at the end of each segment
    c = []  # caseid of each segment
    icase = 0  # number of cases accepted so far
    for caseid in caseids:
        print('loading {} ({}/{})'.format(caseid, icase, MAX_CASES), end='...', flush=True)

        # The generator is lazy, so tracks are downloaded one at a time and loading
        # stops at the first criterion that matches (same order as the table above).
        excluded_by = next((label for tname, threshold, label in exclusions
                            if np.any(vitaldb.load_case(caseid, tname) > threshold)), None)
        if excluded_by is not None:
            print(excluded_by)
            continue

        # load EEG, expired sevoflurane and BIS on a common 1/SRATE-second grid
        vals = vitaldb.load_case(caseid, ['BIS/EEG1_WAV', 'Primus/EXP_SEVO', 'BIS/BIS'], 1 / SRATE)
        if np.nanmax(vals[:, SEVO]) < 1:
            print('all sevo <= 1')
            continue

        # convert expired sevo to age-related MAC: divide by 1.80 * 10^(-0.00269 * (age - 40)),
        # i.e. 1.80 at age 40, decreasing with age
        age = df_cases.loc[df_cases['caseid'] == caseid, 'age'].values[0]
        vals[:, SEVO] /= 1.80 * 10 ** (-0.00269 * (age - 40))

        if not np.any(vals[:, BIS] > 0):
            print('all bis <= 0')
            continue

        # keep only the span from the first to the last valid (> 0) BIS value,
        # so the EEG is taken from the period where the monitor was producing output
        valid_bis_idx = np.where(vals[:, BIS] > 0)[0]
        first_bis_idx = valid_bis_idx[0]
        last_bis_idx = valid_bis_idx[-1]
        vals = vals[first_bis_idx:last_bis_idx + 1, :]

        if len(vals) < 1800 * SRATE:  # skip cases shorter than 30 min
            print('{} len < 30 min'.format(caseid))
            continue

        # forward-fill MAC and BIS over gaps of at most 5 seconds
        vals[:, SEVO:] = pd.DataFrame(vals[:, SEVO:]).ffill(limit=5 * SRATE).values

        # one segment per second: the SEGLEN EEG samples preceding row irow,
        # labelled with the MAC and BIS at row irow
        oldlen = len(y)
        for irow in range(SEGLEN, len(vals), SRATE):
            bis = vals[irow, BIS]
            mac = vals[irow, SEVO]
            if np.isnan(bis) or np.isnan(mac) or bis == 0:
                continue
            x.append(vals[irow - SEGLEN:irow, EEG])
            y.append(mac)
            b.append(bis)
            c.append(caseid)

        # the case passed every criterion
        icase += 1
        print('{} samples read -> total {} samples ({}/{})'.format(len(y) - oldlen, len(y), icase, MAX_CASES))
        if icase >= MAX_CASES:
            break

    # Convert to numpy arrays. reshape keeps x 2-D (0, SEGLEN) even when no segment was
    # collected, so the axis=1 reductions in the filtering step do not fail.
    x = np.array(x).reshape(-1, SEGLEN)
    y = np.array(y)
    b = np.array(b)
    c = np.array(c)

    # save cache file
    np.savez(cachefile, x=x, y=y, b=b, c=c)


loading 1 (0/100)...desflurane
loading 2 (0/100)...10381 samples read -> total 10381 samples (1/100)
loading 3 (1/100)...propofol
loading 4 (1/100)...14367 samples read -> total 24748 samples (2/100)
loading 5 (2/100)...propofol
loading 10 (2/100)...14509 samples read -> total 39257 samples (3/100)
loading 12 (3/100)...21126 samples read -> total 60383 samples (4/100)
loading 18 (4/100)...all bis <= 0
loading 19 (4/100)...propofol
loading 20 (4/100)...propofol
loading 21 (4/100)...8087 samples read -> total 68470 samples (5/100)
loading 24 (5/100)...3599 samples read -> total 72069 samples (6/100)
loading 25 (6/100)...9665 samples read -> total 81734 samples (7/100)
loading 26 (7/100)...desflurane
loading 27 (7/100)...11541 samples read -> total 93275 samples (8/100)
loading 30 (8/100)...propofol
loading 33 (8/100)...2797 samples read -> total 96072 samples (9/100)
loading 34 (9/100)...propofol
loading 35 (9/100)...propofol
loading 38 (9/100)...propofol
loading 43 (9/100)...9689 sample

## 뇌파 입력 데이터 필터링

In [None]:
# Quality filtering of the EEG segments loaded above.
# NOTE: this cell rebinds x, y, b, c. Re-running it without re-running the loading cell
# applies the filters (including baseline removal) a second time.


def _select_rows(mask, *arrays):
    """Return a tuple with each array restricted to the rows where `mask` is True."""
    return tuple(arr[mask] for arr in arrays)


# drop segments with missing values or an almost flat signal
print('invalid samples...', end='', flush=True)
valid_mask = ~np.isnan(x).any(axis=1)  # any NaN in the segment -> drop
# peak-to-peak range must exceed 12; the original note treats this as a BIS impedance check
valid_mask &= np.ptp(x, axis=1) > 12
x, y, b, c = _select_rows(valid_mask, x, y, b, c)
print('{:.1f}% removed'.format(100*(1-np.mean(valid_mask))))

# remove baseline drift: subtract a Savitzky-Golay smoothed version (window 91, order 3)
print('baseline drift...', end='', flush=True)
x = x - scipy.signal.savgol_filter(x, 91, 3)
print('removed')

# drop noisy segments: any absolute amplitude of 100 or more after drift removal
print('noisy samples...', end='', flush=True)
valid_mask = np.nanmax(np.abs(x), axis=1) < 100
# x stays 2-D here; the channel axis the Conv1D model needs is added later with x[..., None]
x, y, b, c = _select_rows(valid_mask, x, y, b, c)
print('{:.1f}% removed'.format(100*(1-np.mean(valid_mask))))

## 데이터를 학습(train)과 테스트(test)로 나누기

In [None]:
# Split by case, not by segment, so that no patient contributes segments to both
# the training set and the test set.
SPLIT_SEED = 0  # fixed seed -> the same train/test cases on every run

# cases that remain after loading and filtering
caseids = list(np.unique(c))
if len(caseids) < 2:
    raise ValueError('need at least 2 cases to split into train and test, got {}'.format(len(caseids)))
# a local RNG makes the split reproducible without disturbing the global `random` state
random.Random(SPLIT_SEED).shuffle(caseids)

# hold out 20% of the cases (at least one) for testing
ntest = max(1, int(len(caseids) * 0.2))
caseids_train = caseids[ntest:]
caseids_test = caseids[:ntest]

train_mask = np.isin(c, caseids_train)
test_mask = np.isin(c, caseids_test)
x_train = x[train_mask]
y_train = y[train_mask]
x_test = x[test_mask]
y_test = y[test_mask]
b_test = b[test_mask]
c_test = c[test_mask]

print('====================================================')
print('total: {} cases {} samples'.format(len(caseids), len(y)))
print('train: {} cases {} samples'.format(len(np.unique(c[train_mask])), len(y_train)))
print('test {} cases {} samples'.format(len(np.unique(c_test)), len(y_test)))
print('====================================================')

train: 45 cases 53222 samples, testing: 5 cases 5266 samples


# Model building


In [None]:
import keras.models
import tensorflow as tf
from keras.models import Model
from keras.layers import Layer, LayerNormalization, Dense, Dropout, Conv1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D, Input, concatenate, multiply, dot, MultiHeadAttention
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Hyperparameter grid. Every combination is trained once, in a shuffled order.
tests = {
    "nfilt" : [16, 32, 64],         # filters per Conv1D layer
    "fnode" : [32, 64, 128],        # units of the hidden Dense layer
    "clayer" : [1, 2, 3, 4],        # number of Conv1D + MaxPooling1D stages
    "droprate" : [0.1, 0.2],        # dropout rate around the hidden Dense layer
    "filtsize" : [5, 7, 9, 11],     # Conv1D kernel size
    'poolsize' : [2, 4, 8],         # MaxPooling1D pool size
    "pooltype" : ['avg', 'max']     # global pooling before the dense head
}
SEARCH_SEED = 0  # fixed seed -> reproducible order of the hyperparameter search

# all combinations; tuple order follows the key order of `tests` (unpacked in the for-loop below)
keys, values = zip(*tests.items())
permutations_dicts = list(it.product(*values))
random.Random(SEARCH_SEED).shuffle(permutations_dicts)
for nfilt, fnode, clayer, droprate, filtsize, poolsize, pooltype in permutations_dicts:

    keras.backend.clear_session()  # drop state left over from the previous trial

    odir = '{}cases_{}sec'.format(MAX_CASES, SEGLEN // SRATE)
    odir += '_cnn{} filt{} size{} pool{} {} do{}'.format(clayer, nfilt, filtsize, poolsize, pooltype, droprate)
    print("============================")
    print(odir)
    print("============================")

    # input: one EEG segment with a single channel axis
    out = inp = Input(shape=(x_train.shape[1], 1))
    # initial (linear) conv layer
    out = Conv1D(filters=nfilt, kernel_size=filtsize, padding='same')(out)
    # stacked conv + pooling stages
    for _ in range(clayer):
        out = Conv1D(filters=nfilt, kernel_size=filtsize, padding='same', activation='relu')(out)
        out = MaxPooling1D(poolsize, padding='same')(out)
    out = (GlobalAveragePooling1D() if pooltype == "avg" else GlobalMaxPooling1D())(out)

    if droprate:
        out = Dropout(droprate)(out)
    # ReLU is required here: without a nonlinearity this layer and the linear output layer
    # compose into a single linear map, so fnode would not change what the model can represent.
    out = Dense(fnode, activation='relu')(out)
    if droprate:
        out = Dropout(droprate)(out)
    out = Dense(1)(out)  # predicted age-related MAC

    os.makedirs(odir, exist_ok=True)

    cache_path = odir + "/weights.hdf5"
    model = Model(inputs=[inp], outputs=[out])
    model.summary()
    model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['mean_absolute_error'])
    # restore_best_weights: when early stopping triggers, the model otherwise keeps the weights
    # of the (worse) last epoch and the best checkpoint written by ModelCheckpoint goes unused.
    hist = model.fit(x_train[..., None], y_train, validation_split=0.2, epochs=10, batch_size=BATCH_SIZE,
                    callbacks=[ModelCheckpoint(monitor='val_loss', filepath=cache_path, verbose=1, save_best_only=True),
                               EarlyStopping(monitor='val_loss', patience=1, verbose=1, mode='auto',
                                             restore_best_weights=True),
                               ])

    # prediction on the held-out cases
    pred_test = model.predict(x_test[..., None], batch_size=BATCH_SIZE).flatten()

    # overall MAE of the raw per-segment predictions (computed before the per-case smoothing below)
    test_mae = np.mean(np.abs(y_test - pred_test))

    for caseid in np.unique(c_test):
        case_mask = (c_test == caseid)
        case_len = np.sum(case_mask)

        # smooth this case's predictions with a 31-point median filter, then score and plot them
        pred_test[case_mask] = scipy.signal.medfilt(pred_test[case_mask], 31)
        our_mae = np.mean(np.abs(y_test[case_mask] - pred_test[case_mask]))
        print('Total MAE={:.4f}, CaseID {}, MAE={:.4f}'.format(test_mae, caseid, our_mae))

        t = np.arange(0, case_len)
        fig, ax = plt.subplots(figsize=(20, 5))
        ax.plot(t, y_test[case_mask], label='MAC')  # measured value
        ax.plot(t, pred_test[case_mask], label='Ours ({:.4f})'.format(our_mae))
        ax.set_xlabel('segment index')
        ax.set_ylabel('age-related MAC')
        ax.legend(loc="upper left")
        ax.set_xlim([0, case_len])
        ax.set_ylim([0, 2])
        fig.tight_layout()
        plt.show()

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
