# 前処理

In [1]:
import os
import json

import numpy as np
from sklearn.model_selection import KFold
from torchinfo import summary

from eeg_utils import *
from eeg_lstm import *

In [2]:
# Load the training recordings and their labels; both are indexed by subject
# key in the cells below (train_data[sub][i], target[sub][i]).
train_data, target = prepare_all()

# Channel indices combined pairwise during preprocessing: left_channels[k] is
# combined with right_channels[k] (difference / sum).
left_channels = [
    0, 2, 4, 6, 8, 10, 12, 14,
    20, 22, 24, 26, 28,
    32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58,
    64, 66, 68, 70,
]
right_channels = [
    1, 3, 5, 7, 9, 11, 13, 15,
    21, 23, 25, 27, 29,
    33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59,
    65, 67, 69, 71,
]
# Every index in 0..71 that belongs to neither list, in ascending order.
other_channels = sorted(set(range(72)) - set(left_channels) - set(right_channels))

In [3]:
def make_train_data(sub, train_index,val_index):
    """Build the normalized train/validation arrays for one subject and one CV split.

    Every session in ``train_data[sub]`` is preprocessed independently:
    each left/right channel pair is replaced by its difference (stored in the
    left channel) and its sum (stored in the right channel), every channel is
    divided by its standard deviation over the first two axes, and the result
    is clipped to [-4, 4]. No mean is subtracted.

    Reads the module-level ``train_data``, ``target``, ``left_channels`` and
    ``right_channels``. Assumes each session array is laid out as
    (trials, time, channels) -- only the channel axis is certain from the
    indexing here; TODO confirm the other two.

    Args:
        sub: Subject key into ``train_data`` and ``target``.
        train_index: Session indices concatenated into the training set.
        val_index: Session indices concatenated into the validation set.

    Returns:
        ``(X_train, X_val, y_train, y_val)``. ``X_train`` keeps the full
        second axis; ``X_val`` keeps only samples 25..274 of it. Both keep
        the first 72 channels.
    """
    normalized_data=[]
    # Cover every session the subject has instead of a hard-coded 3, so any
    # index produced by splitting train_data[sub] is valid.
    for i in range(len(train_data[sub])):
        # left/right difference
        data = train_data[sub][i].copy()

        left=data[:,:,left_channels].copy()
        right=data[:,:,right_channels].copy()
        data[:,:,left_channels] = left-right
        data[:,:,right_channels] = left+right

        # Per-channel scale; constant channels would divide by zero, so their
        # scale is forced to 1 and they pass through unchanged.
        std = data.std(axis=(0,1),keepdims=True)
        std[std==0]=1.0
        normalized = data/std
        normalized = np.clip(normalized, -4,4)
        normalized_data.append(normalized)

    # Full-length windows for training. The original note here says "+/-50ms";
    # presumably the 25 extra samples on each side are margin for cropping
    # inside the estimator -- TODO confirm.
    X_train = np.concatenate([normalized_data[i] for i in train_index])[:,:,:72]
    y_train = np.concatenate([target[sub][i] for i in train_index])

    # Use every validation session, not only val_index[0]; with a single
    # validation session this yields the same arrays as before.
    X_val = np.concatenate([normalized_data[i] for i in val_index])[:,25:275,:72].copy()
    y_val = np.concatenate([target[sub][i] for i in val_index])

    return X_train,X_val,y_train,y_val


def make_test_data(sub):
    """Return the preprocessed test array for subject ``sub``.

    Works on a copy of the first 72 channels of the module-level
    ``test_data[sub]``. Each left/right channel pair is replaced by its
    difference (left slot) and sum (right slot); every channel is then divided
    by its own standard deviation taken over the first two axes and clipped to
    [-4, 4]. Channels with zero deviation are left unscaled.
    """
    signals = test_data[sub][:, :, :72].copy()

    # left/right difference
    left_part = signals[:, :, left_channels].copy()
    right_part = signals[:, :, right_channels].copy()
    signals[:, :, left_channels] = left_part - right_part
    signals[:, :, right_channels] = left_part + right_part

    # Per-channel scale factor; guard against division by zero.
    scale = signals.std(axis=(0, 1), keepdims=True)
    scale[scale == 0] = 1

    return np.clip(signals / scale, -4, 4)


# 学習

In [4]:
def save_prediction(pred_dict,sub,isplit,save_path):
    """Serialize ``pred_dict`` as JSON to ``<save_path>/<sub>_<isplit>.json``.

    The directory is created when missing; an existing file is overwritten.
    """
    os.makedirs(save_path, exist_ok=True)
    out_file = os.path.join(save_path, f"{sub}_{isplit}.json")
    serialized = json.dumps(pred_dict)
    with open(out_file, 'w') as handle:
        handle.write(serialized)

In [5]:
# Train one model per subject and per cross-validation split. The weights kept
# in weights[sub] are consumed later by the test-time prediction loop.
weights={}
cv=3
kf=KFold(cv)  # unshuffled KFold is stateless, so one splitter serves every subject
for sub in train_data.keys():
    weights[sub]=[]
    for isplit, (train_index, val_index) in enumerate(kf.split(train_data[sub])):
        X_train,X_val,y_train,y_val=make_train_data(sub, train_index, val_index)

        # Built fresh for every split so no estimator shares these nested containers.
        param=dict(
            num_epochs=50,
            num_classes=3,
            lr=3e-5,
            weight_decay=1e-3,
            batch_size=16,
            dropout=0.0,
            conv_params=[
                dict(out_channels=128, kernel_size=5, stride=2),
            ],
            lstm_param=dict(
                hidden_size=256,
                num_layers=6,
            ),
            seed=1,
        )
        est=SignalEstimator(**param)
        # (1, 250, 72) matches one validation window: 250 samples x 72 channels.
        model_summary=summary(model=est.model, input_size=(1,250,72))
        train_loss,val_loss=est.fit(X_train,y_train,val_X=X_val,val_y=y_val,verbose=0)

        # Keep the checkpoint the estimator itself predicts from. The prediction
        # cell restores a model by assigning est.minimum_loss_weight, so storing
        # est.model.state_dict() taken right after fit() could hand the test run
        # different weights from the ones validated below. Fall back to the live
        # model state when no such checkpoint is recorded.
        best_weight=getattr(est,"minimum_loss_weight",None)
        if best_weight is None:
            best_weight=est.model.state_dict()
        weights[sub].append(best_weight)

        predictions = est(X_val)
        predictions_dict={f"{sub}-{isplit}":predictions.tolist()}
        save_prediction(predictions_dict,sub,isplit,"lstm_val")

        print("loss",est.minimum_loss)
        # NOTE(review): labels are reduced with %10 here but passed unchanged to
        # fit(); presumably fit() applies the same mapping internally -- confirm.
        print("acc",est.score(X_val, y_val%10))

Epoch [50/50], Loss: 0.3179, Val Loss: 1.4003, Acc: 0.563
loss 0.6891713202754154
acc 0.6582278481012658
Epoch [50/50], Loss: 0.3543, Val Loss: 0.6714, Acc: 0.731
loss 0.5152502536773682
acc 0.725
Epoch [50/50], Loss: 0.3146, Val Loss: 0.5981, Acc: 0.761
loss 0.5094840091729315
acc 0.7295597484276729
Epoch [50/50], Loss: 0.1964, Val Loss: 0.6531, Acc: 0.781
loss 0.517423677444458
acc 0.7
Epoch [50/50], Loss: 0.0916, Val Loss: 0.5583, Acc: 0.844
loss 0.5583277225494385
acc 0.84375
Epoch [50/50], Loss: 0.3462, Val Loss: 0.7298, Acc: 0.711
loss 0.5817463233036065
acc 0.7169811320754716
Epoch [50/50], Loss: 0.0222, Val Loss: 1.2884, Acc: 0.669
loss 0.7336989402770996
acc 0.675
Epoch [50/50], Loss: 0.0355, Val Loss: 0.8152, Acc: 0.805
loss 0.4693814643523978
acc 0.8364779874213837
Epoch [50/50], Loss: 0.0331, Val Loss: 1.1672, Acc: 0.728
loss 0.667559949657585
acc 0.7278481012658228
Epoch [50/50], Loss: 0.0303, Val Loss: 1.0109, Acc: 0.762
loss 0.33721065521240234
acc 0.89375
Epoch [50/50],

# 予測

In [7]:
# Load the test recordings. make_test_data() reads this module-level name
# (test_data[sub]), so this cell has to run before that function is called.
test_data=prepare_testdata()

In [8]:
# Predict the test set with every stored split model of every subject and save
# each split's predictions to lstm_test/<sub>_<isplit>.json.
for sub in train_data.keys():
    # Preprocessing depends only on the subject, so it is done once per subject
    # instead of once per split.
    X_test=make_test_data(sub)
    for isplit,weight in enumerate(weights[sub]):
        # Same configuration as the training cell, so the rebuilt estimator is
        # configured identically to the one the weights came from.
        param=dict(
            num_epochs=50,
            num_classes=3,
            lr=3e-5,
            weight_decay=1e-3,
            batch_size=16,
            dropout=0.0,
            conv_params=[
                dict(out_channels=128, kernel_size=5, stride=2),
            ],
            lstm_param=dict(
                hidden_size=256,
                num_layers=6,
            ),
            seed=1,
        )
        est=SignalEstimator(**param)
        # Hand the trained weights to the fresh estimator; presumably it loads
        # minimum_loss_weight into its model when called -- verify in SignalEstimator.
        est.minimum_loss_weight=weight

        predictions = est(X_test)
        predictions_dict={f"{sub}-{isplit}":predictions.tolist()}
        save_prediction(predictions_dict,sub,isplit,"lstm_test")