In [1]:
# back to project root
%cd ~/research

import argparse
import gc
import os
import sys
import time
from glob import glob

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from tqdm import tqdm
import yaml
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from torch import nn, optim

sys.path.append("src")
from group.passing.dataset import make_data_loaders, make_all_data
from group.passing.lstm_model import LSTMModel
from utility.activity_loader import load_individuals
from utility.logger import logger
from tools.train_passing import init_model, init_loss, init_optim, train, test

/raid6/home/yokoyama/research


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global matplotlib style applied to every figure drawn below.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    "font.size": 24,
    'xtick.direction': 'in',  # x-axis ticks point into the axes
    'ytick.direction': 'in',  # y-axis ticks point into the axes
})

In [3]:
# Expose only GPU index 1 to this process, then fall back to the CPU when
# no CUDA device is available.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
def _read_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, "r") as f:
        return yaml.safe_load(f)


# The training config also stores the paths of the individual and group configs.
cfg_path = "config/passing/pass_train.yaml"
train_cfg = _read_yaml(cfg_path)
ind_cfg = _read_yaml(train_cfg["config_path"]["individual"])
grp_cfg = _read_yaml(train_cfg["config_path"]["group"])

In [5]:
# Map "<room_num>_<surgery_num>" to the sorted list of entries found under
# data/<room_num>/<surgery_num>/passing/.
data_dirs_all = {
    f"{room_num}_{surgery_num}": sorted(
        glob(os.path.join("data", room_num, surgery_num, "passing", "*"))
    )
    for room_num, surgery_items in train_cfg["dataset"]["setting"].items()
    for surgery_num in surgery_items.keys()
}

# Load the individuals of every directory and index them by
# "<room_num>_<surgery_num>_<directory name>_<pid>".
inds = {}
for key_prefix, data_dirs in tqdm(data_dirs_all.items()):
    for data_dir in data_dirs:
        dir_name = data_dir.rsplit("/", 1)[-1]
        loaded = load_individuals(
            os.path.join(data_dir, ".json", "individual.json"), ind_cfg
        )
        inds.update(
            {f"{key_prefix}_{dir_name}_{pid}": ind for pid, ind in loaded.items()}
        )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:05<00:00, 10.98s/it]


# グリッドサーチ

In [6]:
# Build the train / validation / test loaders from the loaded individuals.
passing_defs = grp_cfg["passing"]["default"]
dataset_cfg = train_cfg["dataset"]
loaders = make_data_loaders(inds, dataset_cfg, passing_defs, logger)
train_loader, val_loader, test_loader = loaders

2022-08-13 20:29:05,285 => createing time series 02_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:04<00:00,  3.98it/s]
2022-08-13 20:29:09,563 => createing time series 07_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [00:29<00:00,  1.65it/s]
2022-08-13 20:29:38,661 => createing time series 08_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:05<00:00,  6.59it/s]
2022-08-13 20:29:44,433 => createing time series 08_002
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 45/45 [00:31<00:00,  1.44it/s]


In [7]:
# Model settings shared by every grid-search candidate.
mdl_cfg = dict(
    dropouts=[0.1, 0],
    hidden_dims=[128, 64],
    n_classes=2,
    n_linears=2,
    rnn_dropout=0.1,
    size=4,
)

# Values explored by the grid search.
params = dict(
    n_rnns=[1, 2, 3],
    rnn_hidden_dim=[128, 256],
    pos_weight=[8, 16, 32],
)

# Number of training epochs; fixed here rather than read from
# train_cfg["optim"]["epoch"].
epoch_len = 30

In [None]:
# Best-so-far trackers, one per metric. Each holds
# [scores as (accuracy, precision, recall, f1), hyper-parameters].
max_acc = [[0, 0, 0, 0], None]
max_pre = [[0, 0, 0, 0], None]
max_rcl = [[0, 0, 0, 0], None]
max_f1 = [[0, 0, 0, 0], None]
# Best model per metric, in the same order: accuracy, precision, recall, f1.
max_models = [None] * 4

# Every combination of the grid, with n_rnns varying slowest and pos_weight fastest.
grid = [
    (n_rnns, dim, weight)
    for n_rnns in params['n_rnns']
    for dim in params['rnn_hidden_dim']
    for weight in params['pos_weight']
]

for n_rnns, dim, weight in grid:
    param = {"n_rnns": n_rnns, "rnn_hidden_dim": dim, "weight": weight}
    print(param)

    # Candidate config = shared model settings overridden by the grid values.
    config = {**mdl_cfg, **param}
    pos_weight = param["weight"]

    # Fresh model, loss and optimizer for every candidate.
    model = init_model(config, device)
    criterion = init_loss([1, pos_weight], device)
    optimizer, scheduler = init_optim(
        model, train_cfg["optim"]["lr"], train_cfg["optim"]["lr_rate"]
    )

    model, epoch, history = train(
        model, train_loader, val_loader,
        criterion, optimizer, scheduler,
        epoch_len, logger, device
    )

    # Candidates are ranked by their scores on the test loader.
    score = test(model, test_loader, logger, device)
    acc, pre, rcl, f1 = score

    # A tracker is replaced only on a strict improvement of its own metric.
    trackers = (max_acc, max_pre, max_rcl, max_f1)
    for idx, (best, value) in enumerate(zip(trackers, (acc, pre, rcl, f1))):
        if value > best[0][idx]:
            best[0] = score
            best[1] = param
            max_models[idx] = model

    torch.cuda.empty_cache()

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 8}


2022-08-13 20:33:58,804 => start training
2022-08-13 20:34:09,732 Epoch[1/30] train loss: 0.45273, val loss: nan, lr: 0.0010000, time: 10.93
2022-08-13 20:34:21,468 Epoch[2/30] train loss: 0.39710, val loss: nan, lr: 0.0010000, time: 11.73
2022-08-13 20:34:35,222 Epoch[3/30] train loss: 0.39456, val loss: nan, lr: 0.0010000, time: 13.75
2022-08-13 20:34:51,687 Epoch[4/30] train loss: 0.39283, val loss: nan, lr: 0.0010000, time: 16.46
2022-08-13 20:35:07,775 Epoch[5/30] train loss: 0.39210, val loss: nan, lr: 0.0010000, time: 16.09
2022-08-13 20:35:23,470 Epoch[6/30] train loss: 0.39115, val loss: nan, lr: 0.0010000, time: 15.69
2022-08-13 20:35:40,173 Epoch[7/30] train loss: 0.39045, val loss: nan, lr: 0.0010000, time: 16.70
2022-08-13 20:35:56,619 Epoch[8/30] train loss: 0.38957, val loss: nan, lr: 0.0010000, time: 16.44
2022-08-13 20:36:12,460 Epoch[9/30] train loss: 0.38884, val loss: nan, lr: 0.0010000, time: 15.84
2022-08-13 20:36:29,178 Epoch[10/30] train loss: 0.38766, val loss:

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 16}


2022-08-13 20:42:28,765 Epoch[1/30] train loss: 0.49457, val loss: nan, lr: 0.0010000, time: 15.05
2022-08-13 20:42:44,362 Epoch[2/30] train loss: 0.42916, val loss: nan, lr: 0.0010000, time: 15.59
2022-08-13 20:43:01,019 Epoch[3/30] train loss: 0.42532, val loss: nan, lr: 0.0010000, time: 16.66
2022-08-13 20:43:17,363 Epoch[4/30] train loss: 0.42372, val loss: nan, lr: 0.0010000, time: 16.34
2022-08-13 20:43:33,285 Epoch[5/30] train loss: 0.42235, val loss: nan, lr: 0.0010000, time: 15.92
2022-08-13 20:43:49,514 Epoch[6/30] train loss: 0.42102, val loss: nan, lr: 0.0010000, time: 16.23
2022-08-13 20:44:05,357 Epoch[7/30] train loss: 0.41871, val loss: nan, lr: 0.0010000, time: 15.84
2022-08-13 20:44:21,571 Epoch[8/30] train loss: 0.41790, val loss: nan, lr: 0.0010000, time: 16.21
2022-08-13 20:44:38,051 Epoch[9/30] train loss: 0.41673, val loss: nan, lr: 0.0010000, time: 16.48
2022-08-13 20:44:54,632 Epoch[10/30] train loss: 0.41457, val loss: nan, lr: 0.0010000, time: 16.58
2022-08-1

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 32}


2022-08-13 20:50:52,118 Epoch[1/30] train loss: 0.51521, val loss: nan, lr: 0.0010000, time: 16.35
2022-08-13 20:51:07,263 Epoch[2/30] train loss: 0.47905, val loss: nan, lr: 0.0010000, time: 15.14
2022-08-13 20:51:23,940 Epoch[3/30] train loss: 0.47616, val loss: nan, lr: 0.0010000, time: 16.68
2022-08-13 20:51:40,086 Epoch[4/30] train loss: 0.47297, val loss: nan, lr: 0.0010000, time: 16.14
2022-08-13 20:51:56,468 Epoch[5/30] train loss: 0.47165, val loss: nan, lr: 0.0010000, time: 16.38
2022-08-13 20:52:12,321 Epoch[6/30] train loss: 0.46954, val loss: nan, lr: 0.0010000, time: 15.85
2022-08-13 20:52:27,994 Epoch[7/30] train loss: 0.46673, val loss: nan, lr: 0.0010000, time: 15.67
2022-08-13 20:52:44,942 Epoch[8/30] train loss: 0.46510, val loss: nan, lr: 0.0010000, time: 16.95
2022-08-13 20:53:00,683 Epoch[9/30] train loss: 0.46372, val loss: nan, lr: 0.0010000, time: 15.74
2022-08-13 20:53:16,598 Epoch[10/30] train loss: 0.46173, val loss: nan, lr: 0.0010000, time: 15.91
2022-08-1

{'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 8}


2022-08-13 20:59:21,307 Epoch[1/30] train loss: 0.48750, val loss: nan, lr: 0.0010000, time: 16.50
2022-08-13 20:59:37,555 Epoch[2/30] train loss: 0.40020, val loss: nan, lr: 0.0010000, time: 16.25
2022-08-13 20:59:54,491 Epoch[3/30] train loss: 0.39568, val loss: nan, lr: 0.0010000, time: 16.93
2022-08-13 21:00:10,535 Epoch[4/30] train loss: 0.39414, val loss: nan, lr: 0.0010000, time: 16.04
2022-08-13 21:00:27,084 Epoch[5/30] train loss: 0.39274, val loss: nan, lr: 0.0010000, time: 16.55
2022-08-13 21:00:43,526 Epoch[6/30] train loss: 0.39218, val loss: nan, lr: 0.0010000, time: 16.44
2022-08-13 21:01:00,541 Epoch[7/30] train loss: 0.39102, val loss: nan, lr: 0.0010000, time: 17.01
2022-08-13 21:01:17,634 Epoch[8/30] train loss: 0.39002, val loss: nan, lr: 0.0010000, time: 17.09
2022-08-13 21:01:33,480 Epoch[9/30] train loss: 0.38922, val loss: nan, lr: 0.0010000, time: 15.84
2022-08-13 21:01:50,414 Epoch[10/30] train loss: 0.38800, val loss: nan, lr: 0.0010000, time: 16.93
2022-08-1

{'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 16}


2022-08-13 21:08:03,320 Epoch[1/30] train loss: 0.47584, val loss: nan, lr: 0.0010000, time: 16.76
2022-08-13 21:08:19,679 Epoch[2/30] train loss: 0.42855, val loss: nan, lr: 0.0010000, time: 16.36
2022-08-13 21:08:36,141 Epoch[3/30] train loss: 0.42515, val loss: nan, lr: 0.0010000, time: 16.46
2022-08-13 21:08:52,810 Epoch[4/30] train loss: 0.42339, val loss: nan, lr: 0.0010000, time: 16.67
2022-08-13 21:09:09,392 Epoch[5/30] train loss: 0.42164, val loss: nan, lr: 0.0010000, time: 16.58
2022-08-13 21:09:26,423 Epoch[6/30] train loss: 0.42060, val loss: nan, lr: 0.0010000, time: 17.03
2022-08-13 21:09:43,042 Epoch[7/30] train loss: 0.41916, val loss: nan, lr: 0.0010000, time: 16.62
2022-08-13 21:09:59,308 Epoch[8/30] train loss: 0.41740, val loss: nan, lr: 0.0010000, time: 16.26
2022-08-13 21:10:15,457 Epoch[9/30] train loss: 0.41558, val loss: nan, lr: 0.0010000, time: 16.15
2022-08-13 21:10:31,955 Epoch[10/30] train loss: 0.41399, val loss: nan, lr: 0.0010000, time: 16.50
2022-08-1

{'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 32}


2022-08-13 21:16:46,400 Epoch[1/30] train loss: 0.51960, val loss: nan, lr: 0.0010000, time: 16.98
2022-08-13 21:17:02,475 Epoch[2/30] train loss: 0.47905, val loss: nan, lr: 0.0010000, time: 16.07
2022-08-13 21:17:18,368 Epoch[3/30] train loss: 0.47639, val loss: nan, lr: 0.0010000, time: 15.89
2022-08-13 21:17:33,772 Epoch[4/30] train loss: 0.47333, val loss: nan, lr: 0.0010000, time: 15.40
2022-08-13 21:17:50,289 Epoch[5/30] train loss: 0.47275, val loss: nan, lr: 0.0010000, time: 16.51
2022-08-13 21:18:07,099 Epoch[6/30] train loss: 0.47001, val loss: nan, lr: 0.0010000, time: 16.81
2022-08-13 21:18:24,166 Epoch[7/30] train loss: 0.46670, val loss: nan, lr: 0.0010000, time: 17.07
2022-08-13 21:18:41,385 Epoch[8/30] train loss: 0.46499, val loss: nan, lr: 0.0010000, time: 17.22
2022-08-13 21:18:58,556 Epoch[9/30] train loss: 0.46188, val loss: nan, lr: 0.0010000, time: 17.17
2022-08-13 21:19:15,440 Epoch[10/30] train loss: 0.46117, val loss: nan, lr: 0.0010000, time: 16.88
2022-08-1

{'n_rnns': 2, 'rnn_hidden_dim': 128, 'weight': 8}


2022-08-13 21:25:33,376 Epoch[1/30] train loss: 0.43976, val loss: nan, lr: 0.0010000, time: 16.23
2022-08-13 21:25:50,077 Epoch[2/30] train loss: 0.39734, val loss: nan, lr: 0.0010000, time: 16.70
2022-08-13 21:26:06,781 Epoch[3/30] train loss: 0.39491, val loss: nan, lr: 0.0010000, time: 16.70
2022-08-13 21:26:23,926 Epoch[4/30] train loss: 0.39330, val loss: nan, lr: 0.0010000, time: 17.14
2022-08-13 21:26:40,601 Epoch[5/30] train loss: 0.39206, val loss: nan, lr: 0.0010000, time: 16.67
2022-08-13 21:26:56,981 Epoch[6/30] train loss: 0.39149, val loss: nan, lr: 0.0010000, time: 16.38
2022-08-13 21:27:13,500 Epoch[7/30] train loss: 0.38975, val loss: nan, lr: 0.0010000, time: 16.52
2022-08-13 21:27:29,693 Epoch[8/30] train loss: 0.38925, val loss: nan, lr: 0.0010000, time: 16.19
2022-08-13 21:27:45,789 Epoch[9/30] train loss: 0.38797, val loss: nan, lr: 0.0010000, time: 16.09
2022-08-13 21:28:02,441 Epoch[10/30] train loss: 0.38661, val loss: nan, lr: 0.0010000, time: 16.65
2022-08-1

{'n_rnns': 2, 'rnn_hidden_dim': 128, 'weight': 16}


2022-08-13 21:34:13,839 Epoch[1/30] train loss: 0.48112, val loss: nan, lr: 0.0010000, time: 15.87
2022-08-13 21:34:29,854 Epoch[2/30] train loss: 0.42961, val loss: nan, lr: 0.0010000, time: 16.01
2022-08-13 21:34:46,462 Epoch[3/30] train loss: 0.42597, val loss: nan, lr: 0.0010000, time: 16.61
2022-08-13 21:35:02,731 Epoch[4/30] train loss: 0.42423, val loss: nan, lr: 0.0010000, time: 16.27
2022-08-13 21:35:18,931 Epoch[5/30] train loss: 0.42301, val loss: nan, lr: 0.0010000, time: 16.20
2022-08-13 21:35:35,206 Epoch[6/30] train loss: 0.42163, val loss: nan, lr: 0.0010000, time: 16.27
2022-08-13 21:35:51,606 Epoch[7/30] train loss: 0.42051, val loss: nan, lr: 0.0010000, time: 16.40
2022-08-13 21:36:08,070 Epoch[8/30] train loss: 0.41840, val loss: nan, lr: 0.0010000, time: 16.46
2022-08-13 21:36:24,509 Epoch[9/30] train loss: 0.41759, val loss: nan, lr: 0.0010000, time: 16.44
2022-08-13 21:36:40,566 Epoch[10/30] train loss: 0.41577, val loss: nan, lr: 0.0010000, time: 16.05
2022-08-1

{'n_rnns': 2, 'rnn_hidden_dim': 128, 'weight': 32}


2022-08-13 21:42:48,321 Epoch[1/30] train loss: 0.52745, val loss: nan, lr: 0.0010000, time: 16.52
2022-08-13 21:43:05,311 Epoch[2/30] train loss: 0.48048, val loss: nan, lr: 0.0010000, time: 16.99
2022-08-13 21:43:21,446 Epoch[3/30] train loss: 0.47776, val loss: nan, lr: 0.0010000, time: 16.13
2022-08-13 21:43:38,166 Epoch[4/30] train loss: 0.47469, val loss: nan, lr: 0.0010000, time: 16.72
2022-08-13 21:43:54,791 Epoch[5/30] train loss: 0.47258, val loss: nan, lr: 0.0010000, time: 16.62
2022-08-13 21:44:11,455 Epoch[6/30] train loss: 0.47081, val loss: nan, lr: 0.0010000, time: 16.66
2022-08-13 21:44:27,581 Epoch[7/30] train loss: 0.46905, val loss: nan, lr: 0.0010000, time: 16.12
2022-08-13 21:44:44,460 Epoch[8/30] train loss: 0.46730, val loss: nan, lr: 0.0010000, time: 16.88
2022-08-13 21:45:00,689 Epoch[9/30] train loss: 0.46515, val loss: nan, lr: 0.0010000, time: 16.23
2022-08-13 21:45:17,217 Epoch[10/30] train loss: 0.46235, val loss: nan, lr: 0.0010000, time: 16.53
2022-08-1

{'n_rnns': 2, 'rnn_hidden_dim': 256, 'weight': 8}


2022-08-13 21:51:39,826 Epoch[1/30] train loss: 0.46634, val loss: nan, lr: 0.0010000, time: 25.78
2022-08-13 21:52:05,268 Epoch[2/30] train loss: 0.39891, val loss: nan, lr: 0.0010000, time: 25.44
2022-08-13 21:52:30,870 Epoch[3/30] train loss: 0.39586, val loss: nan, lr: 0.0010000, time: 25.60
2022-08-13 21:52:56,440 Epoch[4/30] train loss: 0.39422, val loss: nan, lr: 0.0010000, time: 25.57
2022-08-13 21:53:21,952 Epoch[5/30] train loss: 0.39283, val loss: nan, lr: 0.0010000, time: 25.51
2022-08-13 21:53:48,128 Epoch[6/30] train loss: 0.39174, val loss: nan, lr: 0.0010000, time: 26.17
2022-08-13 21:54:13,900 Epoch[7/30] train loss: 0.39047, val loss: nan, lr: 0.0010000, time: 25.77
2022-08-13 21:54:39,534 Epoch[8/30] train loss: 0.38866, val loss: nan, lr: 0.0010000, time: 25.63
2022-08-13 21:55:05,340 Epoch[9/30] train loss: 0.38702, val loss: nan, lr: 0.0010000, time: 25.80
2022-08-13 21:55:31,039 Epoch[10/30] train loss: 0.38605, val loss: nan, lr: 0.0010000, time: 25.70
2022-08-1

{'n_rnns': 2, 'rnn_hidden_dim': 256, 'weight': 16}


2022-08-13 22:05:00,825 Epoch[1/30] train loss: 0.47475, val loss: nan, lr: 0.0010000, time: 25.46
2022-08-13 22:05:26,467 Epoch[2/30] train loss: 0.42929, val loss: nan, lr: 0.0010000, time: 25.64
2022-08-13 22:05:51,967 Epoch[3/30] train loss: 0.42622, val loss: nan, lr: 0.0010000, time: 25.50
2022-08-13 22:06:18,092 Epoch[4/30] train loss: 0.42479, val loss: nan, lr: 0.0010000, time: 26.12
2022-08-13 22:06:44,026 Epoch[5/30] train loss: 0.42345, val loss: nan, lr: 0.0010000, time: 25.93
2022-08-13 22:07:10,089 Epoch[6/30] train loss: 0.42212, val loss: nan, lr: 0.0010000, time: 26.06
2022-08-13 22:07:35,691 Epoch[7/30] train loss: 0.42047, val loss: nan, lr: 0.0010000, time: 25.60
2022-08-13 22:08:01,175 Epoch[8/30] train loss: 0.41948, val loss: nan, lr: 0.0010000, time: 25.48
2022-08-13 22:08:26,633 Epoch[9/30] train loss: 0.41777, val loss: nan, lr: 0.0010000, time: 25.46
2022-08-13 22:08:52,782 Epoch[10/30] train loss: 0.41568, val loss: nan, lr: 0.0010000, time: 26.15
2022-08-1

In [None]:
print(f"epoch={epoch}")

# One report per tracked metric: the winning hyper-parameters followed by
# all four scores of that winner.
for title, best in (
    ('max accuracy: ', max_acc),
    ('max precision: ', max_pre),
    ('max recall: ', max_rcl),
    ('max f1: ', max_f1),
):
    print(title, best[1])
    acc, pre, rcl, f1 = best[0]
    print(
        'accuracy: {:.3f}'.format(acc),
        'precision: {:.3f}'.format(pre),
        'recall: {:.3f}'.format(rcl),
        'f1_score: {:.3f}'.format(f1),
    )

## モデル保存

In [None]:
# Take the model with the best recall (index 2 of max_models) and rebuild its
# full config: shared model settings overridden by its grid values.
model = max_models[2]
param = max_rcl[1]
config = dict(mdl_cfg.items())
config.update(param.items())

In [None]:
# Save only the weights (state_dict); the file name records the epoch value.
model_path = f'models/passing/pass_model_lstm_recall_ep{epoch}.pth'
weights = model.state_dict()
torch.save(weights, model_path)

In [None]:
# Store the weight file location in the config, then write the config as YAML.
config["pretrained_path"] = model_path
config_out_path = f'config/passing/pass_model_lstm_recall_ep{epoch}.yaml'
with open(config_out_path, 'w') as f:
    yaml.dump(config, f)

In [None]:
# Take the model with the best f1 (index 3 of max_models) and rebuild its
# full config: shared model settings overridden by its grid values.
model = max_models[3]
param = max_f1[1]
config = dict(mdl_cfg.items())
config.update(param.items())

In [None]:
# Save only the weights (state_dict); the file name records the epoch value.
model_path = f'models/passing/pass_model_lstm_f1_ep{epoch}.pth'
weights = model.state_dict()
torch.save(weights, model_path)

In [None]:
# Store the weight file location in the config, then write the config as YAML.
config["pretrained_path"] = model_path
config_out_path = f'config/passing/pass_model_lstm_f1_ep{epoch}.yaml'
with open(config_out_path, 'w') as f:
    yaml.dump(config, f)

# 検証
## モデルロード

In [None]:
# Choose which saved model to evaluate. Both values are substituted into the
# config file name, so they must match a config written by the saving cells.
epoch = 30
rcl_f1 = "f1"

# Release a model left over from earlier cells. The reference is dropped and
# collected before the CUDA cache is emptied, because empty_cache() can only
# return memory that is no longer in use.
# NOTE: if the grid-search cell was run, `max_models` still references its
# models, so their memory stays allocated.
try:
    del model
except NameError:
    pass  # fresh kernel: there is no previous model
gc.collect()
torch.cuda.empty_cache()

# NOTE: this rebinds `mdl_cfg` (also defined by the grid-search setup cell) to
# the saved config of the selected model.
mdl_cfg_path = f'config/passing/pass_model_lstm_{rcl_f1}_ep{epoch}.yaml'
with open(mdl_cfg_path, "r") as f:
    mdl_cfg = yaml.safe_load(f)
model = init_model(mdl_cfg, device)

# map_location makes weights saved from a GPU loadable on the current `device`
# (including a CPU-only machine).
state_dict = torch.load(mdl_cfg["pretrained_path"], map_location=device)
model.load_state_dict(state_dict)

## データロード

In [None]:
# Build the two dictionaries that the evaluation cells below index by key:
# x_dict[key] holds the features and y_dict[key] the labels.
all_data = make_all_data(
    inds,
    train_cfg["dataset"]["setting"],
    grp_cfg["passing"]["default"],
    logger,
)
x_dict, y_dict = all_data

In [None]:
# Split the keys into train / validation / test sets. Keys whose labels contain
# at least one 1 ("positive") and keys without any 1 ("negative") are shuffled
# and split separately with the same ratios.
# NOTE(review): presumably this is meant to reproduce the split made inside
# make_data_loaders - confirm that both use the same slicing.
np.random.seed(train_cfg["dataset"]["random_seed"])

seq_len = grp_cfg["passing"]["default"]["seq_len"]
size = mdl_cfg["size"]

keys_1 = [key for key in x_dict if 1 in y_dict[key]]
keys_0 = [key for key in x_dict if 1 not in y_dict[key]]
# Positives are drawn before negatives; the order matters for the seeded RNG.
random_keys_1 = np.random.choice(keys_1, size=len(keys_1), replace=False)
random_keys_0 = np.random.choice(keys_0, size=len(keys_0), replace=False)

train_ratio = train_cfg["dataset"]["train_ratio"]
val_ratio = train_cfg["dataset"]["val_ratio"]
train_len_1 = int(len(keys_1) * train_ratio)
train_len_0 = int(len(keys_0) * train_ratio)
val_len_1 = int(len(keys_1) * val_ratio)
val_len_0 = int(len(keys_0) * val_ratio)

# Positive keys: [train | val | test] in shuffled order.
train_keys_1 = random_keys_1[:train_len_1].tolist()
val_keys_1 = random_keys_1[train_len_1 : train_len_1 + val_len_1].tolist()
test_keys_1 = random_keys_1[train_len_1 + val_len_1 :].tolist()

# Negative keys follow the same layout. The validation slice must come from
# random_keys_0 (not random_keys_1), and the test slice must start after the
# validation slice so that no key is in both validation and test.
train_keys_0 = random_keys_0[:train_len_0].tolist()
val_keys_0 = random_keys_0[train_len_0 : train_len_0 + val_len_0].tolist()
test_keys_0 = random_keys_0[train_len_0 + val_len_0 :].tolist()

train_keys = sorted(train_keys_1 + train_keys_0)
val_keys = sorted(val_keys_1 + val_keys_0)
test_keys = sorted(test_keys_1 + test_keys_0)

# Every key must land in exactly one of the three sets.
assert len(train_keys) + len(val_keys) + len(test_keys) == len(x_dict)

In [None]:
def create_sequence(x_lst, y_lst, seq_len=30, size=4):
    """Cut a series into overlapping windows of ``seq_len`` consecutive items.

    Window ``i`` is ``x_lst[i:i + seq_len]`` and is paired with the label of
    its last item, ``y_lst[i + seq_len - 1]``.

    Args:
        x_lst: Sliceable sequence of inputs.
        y_lst: Indexable sequence of labels aligned with ``x_lst``.
        seq_len: Number of items per window.
        size: Accepted but not used by this function.

    Returns:
        Tuple ``(x_seq, y_seq)`` of two equally long lists; both are empty
        when ``len(x_lst) < seq_len``.
    """
    starts = range(len(x_lst) - seq_len + 1)
    x_seq = [x_lst[start:start + seq_len] for start in starts]
    y_seq = [y_lst[start + seq_len - 1] for start in starts]
    return x_seq, y_seq


# Legend labels of the feature curves, indexed by feature column.
columns = ["distance", "body_direction", "arm_ave", "wrist_distance"]


def plot(x_lst, y_lst, pred, seq_len=30, path=None):
    """Plot predictions, ground truth and input features of one sequence.

    Every series is prefixed with a single 0 followed by ``seq_len - 1`` NaNs
    before it is drawn.

    Args:
        x_lst: 2-D numpy array of features (``.shape[1]`` and ``.tolist()``
            are used); column ``i`` is labelled ``columns[i]``.
        y_lst: List of ground-truth labels.
        pred: Numpy array of predicted labels (``.tolist()`` is used).
        seq_len: Window length that determines the length of the prefix.
        path: When not None, the figure is saved there before being shown.
    """
    n_features = x_lst.shape[1]
    n_gap = seq_len - 1

    # Prefix shared by all series: one 0, then NaNs.
    scalar_pad = [0] + [np.nan] * n_gap
    feature_pad = [[0] * n_features] + [[np.nan] * n_features for _ in range(n_gap)]

    # Transposed so that each row is the full curve of one feature.
    feature_curves = np.array(feature_pad + x_lst.tolist()).T
    truth_curve = scalar_pad + y_lst
    pred_curve = scalar_pad + pred.tolist()

    fig = plt.figure(figsize=(13, 4))
    ax = fig.add_axes((0.04, 0.17, 0.80, 0.81))

    ax.plot(pred_curve, label='pred')
    ax.plot(truth_curve, linestyle=':', label='ground truth')
    for idx in range(len(feature_curves)):
        ax.plot(feature_curves[idx], alpha=0.4, label=columns[idx])

    ax.set_ylim((-0.05, 1.05))
    ax.set_xlabel('frame')
    # Legend is anchored just outside the right edge of the axes.
    ax.legend(
        bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0,
        fontsize=20, handlelength=0.8, handletextpad=0.2
    )

    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    if path is not None:
        fig.savefig(path)
    plt.show()

## トレインデータ

In [None]:
# Keys whose plot is also written to disk by the evaluation loop below.
save_keys = ['02_06_1_3']

In [None]:
# Accumulators: *_all_* collect every label / prediction, *_eve_* hold one
# boolean per key telling whether a 1 occurs in that key's labels / predictions.
y_all_train, pred_all_train = [], []
y_eve_train, pred_eve_train = [], []

model.eval()
with torch.no_grad():
    for key in train_keys:
        features = np.array(x_dict[key])
        labels = y_dict[key]

        windows, _ = create_sequence(features, labels, seq_len, size)
        windows = torch.Tensor(np.array(windows)).float().to(device)

        # Nothing to predict when not even one window could be formed.
        if len(windows) == 0:
            continue

        # .max(1)[1] is the index of the largest model output along dim 1.
        pred = model(windows).max(1)[1].cpu().numpy()
        pred_list = pred.tolist()

        # Drop the first seq_len - 1 items so that features and labels line up
        # with the predictions (one per window).
        features = features[seq_len - 1:]
        labels = labels[seq_len - 1:]
        has_positive = 1 in labels

        y_all_train.extend(labels)
        pred_all_train.extend(pred_list)
        y_eve_train.append(has_positive)
        pred_eve_train.append(1 in pred_list)

        # Only keys with at least one positive label are plotted.
        if has_positive:
            print(key)
            if key in save_keys:
                path = os.path.join("data", "passing", "image", f"rnn_test_{key}.pdf")
            else:
                path = None
            plot(features, labels, pred, seq_len, path=path)

In [None]:
# Metrics over every label / prediction pair collected from the train keys.
for template, metric in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric(y_all_train, pred_all_train)))

In [None]:
# Per event: one boolean per train key (does a 1 occur in its labels /
# predictions).
for template, metric in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric(y_eve_train, pred_eve_train)))

## テストデータ

In [None]:
# Keys whose plot is also written to disk by the evaluation loop below.
save_keys = ['08_03_2_5']

In [None]:
# Accumulators: *_all_* collect every label / prediction, *_eve_* hold one
# boolean per key telling whether a 1 occurs in that key's labels / predictions.
y_all_test = []
pred_all_test = []
y_eve_test = []
pred_eve_test = []
# Per-key counters:
#   tn - keys without a positive label and without a positive prediction; keys
#        too short to form a single window are counted here too, without
#        looking at their labels.
#   fn - keys without a positive label but with a positive prediction.
#        NOTE: despite its name this counts false positives; the name is kept
#        because the reporting cell reads `fn`.
tn, fn = 0, 0

model.eval()
with torch.no_grad():
    for key in test_keys:
        x_lst = np.array(x_dict[key])
        y_lst = y_dict[key]

        x, _ = create_sequence(x_lst, y_lst, seq_len, size)
        if len(x) == 0:
            tn += 1
            continue

        # Stack the windows into one ndarray first (as the train cell does):
        # building a tensor straight from a list of ndarrays is very slow.
        x = torch.Tensor(np.array(x)).float().to(device)

        pred = model(x)
        pred = pred.max(1)[1]  # index of the largest output along dim 1
        pred = pred.cpu().numpy()
        pred_lst = pred.tolist()

        # Drop the first seq_len - 1 items so that features and labels line up
        # with the predictions (one per window).
        x_lst = x_lst[seq_len - 1:]
        y_lst = y_lst[seq_len - 1:]
        has_event = 1 in y_lst
        pred_event = 1 in pred_lst

        y_all_test += y_lst
        pred_all_test += pred_lst
        y_eve_test.append(has_event)
        pred_eve_test.append(pred_event)
        if not has_event:
            if pred_event:
                fn += 1
            else:
                tn += 1

        # Plot only keys where a 1 occurs in the labels or in the predictions.
        if not (has_event or pred_event):
            continue

        print(key)
        path = None
        if key in save_keys:
            path = os.path.join("data", "passing", "image", f"rnn_test_{key}.pdf")
        plot(x_lst, y_lst, pred, seq_len, path=path)

In [None]:
# Metrics over every label / prediction pair collected from the test keys.
for template, metric in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric(y_all_test, pred_all_test)))

# Confusion matrix of the same pairs, shown as a heat map.
cm = confusion_matrix(y_all_test, pred_all_test)
sns.heatmap(cm, cmap='Blues')

In [None]:
# Per event: one boolean per test key (does a 1 occur in its labels /
# predictions).
print('accuracy: {:.3f}'.format(accuracy_score(y_eve_test, pred_eve_test)))
print('precision: {:.3f}'.format(precision_score(y_eve_test, pred_eve_test)))
print('recall: {:.3f}'.format(recall_score(y_eve_test, pred_eve_test)))
print('f1_score: {:.3f}'.format(f1_score(y_eve_test, pred_eve_test)))

print('true negative:', tn)
# `fn` is incremented by the evaluation loop for keys that have no positive
# label but a positive prediction, i.e. false positives, so it is reported
# under that name.
print('false positive:', fn)

# Confusion matrix of the per-key booleans, shown as a heat map.
cm = confusion_matrix(y_eve_test, pred_eve_test)
sns.heatmap(cm, cmap='Blues')