# 前処理

In [1]:
import os
import json

import numpy as np
from sklearn.model_selection import KFold
from torchinfo import summary

from eeg_utils import *
from eeg_lstm import *

In [2]:
# Load the training recordings and their labels. Later cells index both as
# [subject][session] and iterate subjects through train_data.keys().
train_data, target = prepare_all()

# Channel indices of the left/right pairs that rl_std combines into a
# difference and a sum. In this layout every right-hand index is its
# left-hand partner plus one.
left_channels = [
    *range(0, 16, 2),
    *range(20, 30, 2),
    *range(32, 60, 2),
    *range(64, 72, 2),
]
right_channels = [ch + 1 for ch in left_channels]
# Every index in 0..71 that belongs to neither list, in ascending order.
other_channels = sorted(set(range(72)) - set(left_channels) - set(right_channels))

In [3]:
def _floored_std(trials):
    """Per-trial, per-channel std over axis 1, floored at the channel median.

    A std of exactly 0 is replaced by 1.0 first so that a flat channel never
    yields a zero divisor. The median is then taken across trials (axis 0)
    and used as a lower bound for every trial's divisor.
    """
    std = trials.std(axis=1)  # (N, C)
    std[std == 0] = 1.0
    # Broadcasting the (C,) medians against (N, C) works for any channel
    # count, unlike a per-channel loop over a fixed number of channels.
    return np.maximum(std, np.median(std, axis=0))


def rl_std(data, left=None, right=None):
    """Scale trials per channel, mix left/right channel pairs, and rescale.

    Steps:
      1. Divide every trial by its floored per-channel std (see
         ``_floored_std``).
      2. Overwrite the paired channels: each ``left`` slot receives
         ``left - right`` and each ``right`` slot receives ``left + right``.
         All other channels are left as they are.
      3. Recompute the floored std on the mixed signal and divide by twice
         that value.

    Parameters:
        data:  array indexed as (trial, time, channel).
        left:  channel indices of the left-hand members of each pair.
               Defaults to the module-level ``left_channels``.
        right: channel indices of the matching right-hand members.
               Defaults to the module-level ``right_channels``.

    Returns:
        A new array with the same shape as ``data``; ``data`` itself is not
        modified.

    Raises:
        ValueError: if ``left`` and ``right`` differ in length.
    """
    left = np.asarray(left_channels if left is None else left, dtype=int)
    right = np.asarray(right_channels if right is None else right, dtype=int)
    if left.shape != right.shape:
        raise ValueError("left and right channel lists must have the same length")

    normalized = data / _floored_std(data)[:, None, :]

    # Left/right difference and sum. Integer-array indexing returns copies,
    # so both operands are captured before either slot is overwritten.
    left_part = normalized[:, :, left]
    right_part = normalized[:, :, right]
    normalized[:, :, left] = left_part - right_part
    normalized[:, :, right] = left_part + right_part

    return normalized / (2 * _floored_std(normalized)[:, None, :])

def make_train_data(sub, train_index,val_index):
    """Build the training and validation arrays for one subject and one fold.

    Every session of ``train_data[sub]`` is passed through ``rl_std`` and
    clipped to [-3, 3]. The sessions listed in ``train_index`` are
    concatenated into the training set and those in ``val_index`` into the
    validation set.

    Parameters:
        sub:         key into the module-level ``train_data`` / ``target``.
        train_index: session indices used for training.
        val_index:   session indices used for validation. All of them are
                     used; previously only the first one was, which silently
                     dropped the rest whenever a fold held more than one
                     session.

    Returns:
        ``(X_train, X_val, y_train, y_val)``. ``X_train`` keeps the full time
        axis (the original note on that line reads "+/-50ms"), while
        ``X_val`` is cropped to time samples 25..274. Both keep the first 72
        channels.
        NOTE(review): presumably the uncropped training window leaves a
        margin for cropping inside the estimator - confirm against
        SignalEstimator.
    """
    sessions = train_data[sub]
    # Normalise every available session instead of assuming exactly three.
    normalized_data = [
        np.clip(rl_std(sessions[i].copy()), -3, 3)
        for i in range(len(sessions))
    ]

    X_train = np.concatenate([normalized_data[i] for i in train_index])[:, :, :72]
    y_train = np.concatenate([target[sub][i] for i in train_index])

    X_val = np.concatenate([normalized_data[i] for i in val_index])[:, 25:275, :72].copy()
    y_val = np.concatenate([target[sub][i] for i in val_index])

    return X_train, X_val, y_train, y_val


def make_test_data(sub):
    """Return the test recording of ``sub`` after ``rl_std`` and clipping.

    Reads the module-level ``test_data`` (assigned in a later cell, so that
    cell must run before this function is called). The stored recording is
    copied first; the result is limited to the range [-3, 3].
    """
    scaled = rl_std(test_data[sub].copy())
    return np.clip(scaled, -3, 3)


# 学習

In [4]:
def save_prediction(pred_dict,sub,isplit,save_path):
    """Write ``pred_dict`` as JSON to ``<save_path>/<sub>_<isplit>.json``.

    ``save_path`` is created when it does not exist yet. A file already
    present under the same name is overwritten.
    """
    out_file = os.path.join(save_path, f"{sub}_{isplit}.json")
    os.makedirs(save_path, exist_ok=True)
    with open(out_file, mode='w') as handle:
        json.dump(pred_dict, handle)

In [5]:
# The fold layout is the same for every subject: KFold without shuffling is
# stateless, so a single splitter is shared by all subjects.
cv=3
kf=KFold(n_splits=cv)

weights={}  # subject -> list with one model state dict per fold
for sub in train_data.keys():
    weights[sub]=[]
    for isplit, (train_index, val_index) in enumerate(kf.split(train_data[sub])):
        X_train,X_val,y_train,y_val=make_train_data(sub, train_index, val_index)

        # Rebuilt for every fold so each estimator receives its own objects.
        param={
            "num_epochs": 500,
            "num_classes": 3,
            "lr": 2e-6,
            "weight_decay": 1e-0,
            "batch_size": 16,
            "dropout": 0.0,
            "conv_params": [
                {"out_channels": 2048, "kernel_size": 5, "stride": 2},
                {"out_channels": 512, "kernel_size": 3, "stride": 1},
            ],
            "seed": 1,
        }
        est=SignalEstimator(**param)
        # 250 x 72 matches one cropped validation trial from make_train_data.
        model_summary=summary(model=est.model, input_size=(1,250,72))
        train_loss,val_loss=est.fit(X_train,y_train,val_X=X_val,val_y=y_val,verbose=0)

        # NOTE(review): presumably est.model is a torch module, in which case
        # state_dict() returns references to the live tensors rather than
        # copies. The prediction cell later assigns this entry to
        # est.minimum_loss_weight - confirm that the weights captured here
        # are the ones intended (final epoch vs. lowest validation loss).
        weights[sub].append(est.model.state_dict())

        predictions = est(X_val)
        predictions_dict={f"{sub}-{isplit}":predictions.tolist()}
        save_prediction(predictions_dict,sub,isplit,"conv1d_val")

        print("loss",est.minimum_loss)
        # Labels are reduced modulo 10 before scoring; fit() above receives
        # them unreduced. TODO confirm how SignalEstimator encodes labels.
        print("acc",est.score(X_val, y_val%10))

Epoch [500/500], Loss: 0.0041, Val Loss: 0.2728, Acc: 0.905
loss 0.22065124028845678
acc 0.9177215189873418
Epoch [500/500], Loss: 0.0035, Val Loss: 0.4229, Acc: 0.856
loss 0.31453006267547606
acc 0.8875
Epoch [500/500], Loss: 0.0071, Val Loss: 0.4720, Acc: 0.830
loss 0.3747638966302452
acc 0.8364779874213837
Epoch [500/500], Loss: 0.0058, Val Loss: 0.7153, Acc: 0.806
loss 0.4784692764282227
acc 0.81875
Epoch [500/500], Loss: 0.0079, Val Loss: 0.3940, Acc: 0.863
loss 0.3565009593963623
acc 0.85625
Epoch [500/500], Loss: 0.0045, Val Loss: 0.4052, Acc: 0.881
loss 0.3328737222923423
acc 0.8742138364779874
Epoch [500/500], Loss: 0.0076, Val Loss: 0.6136, Acc: 0.794
loss 0.4050909519195557
acc 0.85625
Epoch [500/500], Loss: 0.0044, Val Loss: 0.5140, Acc: 0.830
loss 0.3615709460756314
acc 0.8805031446540881
Epoch [500/500], Loss: 0.0086, Val Loss: 0.5734, Acc: 0.823
loss 0.38253861439378956
acc 0.879746835443038
Epoch [500/500], Loss: 0.0030, Val Loss: 0.3736, Acc: 0.875
loss 0.3065016508102

# 予測

In [6]:
# Load the held-out recordings; make_test_data reads this global and indexes
# it by subject.
test_data = prepare_testdata()

In [7]:
# Predict the test recordings with every fold model of every subject and
# store one JSON file per (subject, fold) under "conv1d_test".
for sub in train_data.keys():
    # The normalised test data depends only on the subject, so compute it
    # once instead of once per fold.
    X_test=make_test_data(sub)

    for isplit,weight in enumerate(weights[sub]):
        # Must stay identical to the configuration used in the training cell,
        # otherwise the stored weights will not fit the rebuilt model.
        param=dict(
            num_epochs=500,
            num_classes=3,
            lr=2e-6,
            weight_decay=1e-0,
            batch_size=16,
            dropout=0.0,
            conv_params=[
                dict(out_channels=2048, kernel_size=5, stride=2),
                dict(out_channels=512, kernel_size=3, stride=1),
            ],
            seed=1,
        )
        est=SignalEstimator(**param)
        # NOTE(review): the weights are only attached as an attribute here;
        # presumably SignalEstimator loads minimum_loss_weight into the model
        # when the estimator is called - confirm in eeg_lstm.
        est.minimum_loss_weight=weight

        predictions = est(X_test)
        predictions_dict={f"{sub}-{isplit}":predictions.tolist()}
        save_prediction(predictions_dict,sub,isplit,"conv1d_test")