In [1]:
import numpy as np
import os
import random
import scipy
from scipy.interpolate import griddata
from scipy import signal
from numpy.fft import fft
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.initializers import he_normal

from sklearn.model_selection import train_test_split

import pandas as pd
import pickle

import seaborn as sns
sns.set(font='Yu Gothic')
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.simplefilter(action='ignore', category=RuntimeWarning)

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
# Names of the 14 EEG channels. The plotting cells below iterate over this list
# and use entry i as the title for column i of the data array.
channels = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
channels

In [None]:
# List the recordings in the session directory. os.listdir returns entries in
# arbitrary order, so sort them to get the same listing on every run/machine.
files = sorted(os.listdir("201103/"))
files

In [None]:
# Load one recording. skiprows=1 skips the first line of the file, so the
# second line is used as the column header.
O = pd.read_csv("201103/Sekiguchi_O_03.11.20_23.28.22.md.csv", skiprows=1)
O

In [None]:
# Keep columns 3..21 (19 columns) as a NumPy array. Later cells plot the first
# 14 of them (one per entry in `channels`) and triggerExtract reads the last
# column as the trigger code.
# NOTE: this cell rebinds O from a DataFrame to an ndarray, so it cannot be
# re-run on its own (an ndarray has no `.values`); re-run the read_csv cell first.
O = O.values[:, 3:22]
O.shape

In [None]:
# Overview of the whole recording: every column except the last five,
# plotted against the sample index.
fig, ax = plt.subplots(figsize=(24, 16))
ax.plot(O[:, :-5])
plt.show()

In [None]:
# Zoom into the half-second window 30.0-30.5 s of the recording.
# Sample k is taken at k / 128 seconds (128 Hz, the same value assigned to Fs
# further down). Building the axis with linspace(0, n // 128, n) floors the
# duration and includes the endpoint, which slightly rescales every time stamp.
tt = np.arange(O.shape[0]) / 128
plt.figure(figsize=(24, 16))
plt.plot(tt, O[:, 0:-5])
plt.xlim(30, 30.5)
plt.xlabel("Time[Second]")
plt.ylabel("Amplitude")
plt.show()

In [None]:
Fs = 128  # sampling rate in Hz


# Apply the FIR filter
def firFilter(x):
    """Return a band-pass filtered copy of ``x``.

    A 511-tap FIR band-pass with cut-offs at 1 Hz and 15 Hz (normalised by the
    Nyquist frequency ``Fs / 2``) is applied forwards and backwards along
    axis 0 with ``signal.filtfilt``. Every column except the last one is
    filtered; the last column is copied through unchanged.

    Parameters
    ----------
    x : array-like, shape (n_samples, n_columns)
        Data to filter. It is NOT modified: the function works on a float
        copy, so calling it repeatedly on the same input always yields the
        same output.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``x``.

    Raises
    ------
    ValueError
        Raised by ``signal.filtfilt`` when ``x`` has too few rows for its
        default edge padding (3 * 511 samples).
    """
    numtaps = 511
    nyquist = Fs / 2
    b = signal.firwin(numtaps, [1.0 / nyquist, 15.0 / nyquist], pass_zero=False)

    # np.array(...) always copies, so the caller's array is never written to.
    filtered = np.array(x, dtype=float)
    filtered[:, :-1] = signal.filtfilt(b, 1, filtered[:, :-1], axis=0)
    return filtered

In [None]:
# Extract N points starting at each trigger
N = 128


def triggerExtract(x, trigger):
    """Cut fixed-length epochs out of ``x`` starting at every trigger sample.

    The last column of ``x`` is read as the trigger code. For every row whose
    code matches ``trigger`` the ``N`` rows starting there are returned as one
    epoch.

    Parameters
    ----------
    x : numpy.ndarray, shape (n_samples, n_columns)
        Recording; the last column holds the trigger codes.
    trigger : scalar or sequence of scalars
        One trigger code, or several codes to match at once.

    Returns
    -------
    numpy.ndarray, shape (n_epochs, N, n_columns)
        Float array of epochs in order of occurrence. Triggers that are
        closer than ``N`` samples to the end of ``x`` are dropped, so the
        result never contains zero-filled placeholder epochs.

    Notes
    -----
    Prints the number of matching trigger samples (before any are dropped).
    When ``trigger`` is the scalar ``-1`` the last two occurrences are
    discarded, as in the original implementation (reason not documented).
    """
    # np.isin handles both a single code and a list of codes; a plain `==`
    # against a list would not test membership.
    onsets = np.where(np.isin(x[:, -1], trigger))[0]
    print(len(onsets))

    if np.ndim(trigger) == 0 and trigger == -1:
        onsets = onsets[:-2]

    # Keep only triggers whose full N-sample window fits inside the data.
    onsets = onsets[onsets + N <= x.shape[0]]

    # Row indices of every epoch: shape (n_epochs, N).
    window = onsets[:, None] + np.arange(N)
    return np.asarray(x[window], dtype=float)

In [None]:
# Take the ensemble (additive) average of the waveforms
def WaveAverage(x):
    """Average a stack of epochs over its first axis.

    Parameters
    ----------
    x : numpy.ndarray, shape (n_epochs, n_samples, n_columns)

    Returns
    -------
    numpy.ndarray, shape (n_samples, n_columns)
        Sum of all epochs divided by ``n_epochs``, accumulated in float64.
    """
    n_epochs = x.shape[0]
    epoch_sum = x.sum(axis=0, dtype=np.float64)
    return epoch_sum / n_epochs

In [None]:
def preprocess(x):
    """Filter a recording, average target / non-target epochs and plot them.

    Steps: band-pass filter with ``firFilter``, cut epochs with
    ``triggerExtract`` (trigger codes 2 and 8 are targets, codes
    0, 1, 3, 4, 5, 6, 7, 9, 10, 11 are non-targets), average each class with
    ``WaveAverage`` and draw one subplot per entry in ``channels`` with the
    target average in red and the non-target average in blue.

    Parameters
    ----------
    x : array-like, shape (n_samples, n_columns)
        Recording whose last column holds the trigger codes. Not modified.

    Returns
    -------
    (ttrigger_ave, ntrigger_ave) : tuple of numpy.ndarray
        Averaged target and non-target epochs, each (n_epoch_samples, n_columns).
    """
    target_codes = [2, 8]
    nontarget_codes = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11]

    # Work on a private float copy so filtering never alters the caller's array.
    x = firFilter(np.array(x, dtype=float))

    # Extract epochs one trigger code at a time and stack them, the same way
    # the exploratory cells of this notebook do.
    ttrigger = np.concatenate([triggerExtract(x, code) for code in target_codes], axis=0)
    ntrigger = np.concatenate([triggerExtract(x, code) for code in nontarget_codes], axis=0)
    ttrigger_ave = WaveAverage(ttrigger)
    ntrigger_ave = WaveAverage(ntrigger)

    tt = np.linspace(0, 1, ttrigger.shape[1])
    plt.figure(figsize=(24, 18))
    plt.subplots_adjust(wspace=0.4, hspace=0.8)
    for i in range(len(channels)):
        plt.subplot(8, 4, i+1)
        plt.plot(tt, ttrigger_ave[:, i], 'r', label="target")
        plt.plot(tt, ntrigger_ave[:, i], 'b', label="nontarget")
        plt.xticks(np.arange(0, 1.1, 0.1))
        plt.xlabel("Second[s]")
        plt.ylabel("Amplitude[μV]")
        plt.title(channels[i], fontsize=18)
    return ttrigger_ave, ntrigger_ave

In [None]:
# One panel per channel, showing the complete trace of that column of O.
# NOTE(review): the x axis is stretched over 0-2 regardless of how long the
# recording is, so the "Second[s]" label only holds for a 2 s recording -
# confirm whether a 2 s excerpt was intended here.
tt = np.linspace(0, 2, O.shape[0])
channel = 15
plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.8)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O[:, i])
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_xticks(np.arange(0, 2.25, 0.25))
    ax.set_title(ch_name, fontsize=18)

In [None]:
# Band-pass filter the recording (all columns except the last one).
# NOTE: O is overwritten with the filtered result, so re-running this cell
# filters the already-filtered data a second time.
O = firFilter(O)

In [None]:
# Target epochs: trigger codes 2 and 8.
O_target = np.concatenate([triggerExtract(O, num) for num in (2, 8)], axis=0)

# Non-target epochs: every remaining code, each extracted exactly once.
# (Seeding the array with code 0 and then looping over a list that also starts
# with 0 would put the code-0 epochs into the set twice and bias the average.)
nontarget_codes = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11]
O_nontarget = np.concatenate([triggerExtract(O, num) for num in nontarget_codes], axis=0)

O_target.shape, O_nontarget.shape

In [None]:
# Single-epoch comparison: epoch number 5 of each class, one panel per channel.
tt = np.linspace(0, 1, O_target.shape[1])

plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O_target[5, :, i], 'r', label='target')
    ax.plot(tt, O_nontarget[5, :, i], 'b', label='nontarget')
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_title(ch_name, fontsize=18)
    ax.legend()

In [None]:
# Average over the epoch axis (axis 0) separately for each class; each result
# has shape (samples per epoch, columns).
O_target_ave = WaveAverage(O_target)
O_nontarget_ave = WaveAverage(O_nontarget)

# Display both shapes as the cell output.
O_target_ave.shape, O_nontarget_ave.shape

In [None]:
# Averaged waveforms per channel: target in red, non-target in blue.
tt = np.linspace(0, 1, O_target.shape[1])
channel = 15
plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O_target_ave[:, i], 'r', label="target")
    ax.plot(tt, O_nontarget_ave[:, i], 'b', label="nontarget")
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_title(ch_name, fontsize=18)

## Target (O) vs. non-target (A)

In [None]:
# Target epochs: trigger codes 2 and 8, stacked along the epoch axis.
O_target = np.concatenate((triggerExtract(O, 2), triggerExtract(O, 8)), axis=0)

# Non-target epochs for this comparison: trigger code 0 followed by code 6.
O_nontarget = triggerExtract(O, 0)
for num in (6,):
    O_nontarget = np.concatenate((O_nontarget, triggerExtract(O, num)), axis=0)

tt = np.linspace(0, 1, O_target.shape[1])

# Epoch number 5 of each class, one panel per channel.
plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O_target[5, :, i], 'r', label='target')
    ax.plot(tt, O_nontarget[5, :, i], 'b', label='nontarget')
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_title(ch_name, fontsize=18)
    ax.legend()

In [None]:
# Class averages over the epoch axis for the current target / non-target sets.
O_target_ave = WaveAverage(O_target)
O_nontarget_ave = WaveAverage(O_nontarget)

tt = np.linspace(0, 1, O_target.shape[1])
channel = 15

# Averaged waveforms per channel: target in red, non-target in blue.
plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O_target_ave[:, i], 'r', label="target")
    ax.plot(tt, O_nontarget_ave[:, i], 'b', label="nontarget")
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_title(ch_name, fontsize=18)

## Target (O) vs. non-target (P)

In [None]:
# Target epochs: trigger codes 2 and 8, stacked along the epoch axis.
O_target = np.concatenate((triggerExtract(O, 2), triggerExtract(O, 8)), axis=0)

# Non-target epochs for this comparison: trigger code 2 followed by code 9.
O_nontarget = triggerExtract(O, 2)
for num in (9,):
    O_nontarget = np.concatenate((O_nontarget, triggerExtract(O, num)), axis=0)

tt = np.linspace(0, 1, O_target.shape[1])

# The target example is epoch 5; the non-target example is taken fifth from
# the END of the stacked array (index -5).
plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O_target[5, :, i], 'r', label='target')
    ax.plot(tt, O_nontarget[-5, :, i], 'b', label='nontarget')
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_title(ch_name, fontsize=18)
    ax.legend()

In [None]:
# Class averages over the epoch axis for the current target / non-target sets.
O_target_ave = WaveAverage(O_target)
O_nontarget_ave = WaveAverage(O_nontarget)

tt = np.linspace(0, 1, O_target.shape[1])
channel = 15

# Averaged waveforms per channel: target in red, non-target in blue.
plt.figure(figsize=(24, 18))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
for i, ch_name in enumerate(channels):
    ax = plt.subplot(4, 4, i + 1)
    ax.plot(tt, O_target_ave[:, i], 'r', label="target")
    ax.plot(tt, O_nontarget_ave[:, i], 'b', label="nontarget")
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel("Second[s]")
    ax.set_ylabel("Amplitude[μV]")
    ax.set_title(ch_name, fontsize=18)