In [1]:
# Change to the project root so the relative paths used throughout the
# notebook ("config/...", "data/...", "models/...", "src") resolve.
# NOTE(review): hardcoded absolute path (looks like a WSL mount); consider a
# configurable project root instead.
%cd /mnt/c/research

import argparse
import os
import sys
import time
from glob import glob

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import yaml
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from torch import nn, optim

# Make the project packages under src/ importable (must precede the imports below).
sys.path.append("src")
from group.passing.dataset import make_data_loaders, make_all_data
from group.passing.lstm_model import LSTMModel
from utility.activity_loader import load_individuals
from utility.logger import logger
from tools.train_passing import init_model, init_loss, init_optim, train, test
# NOTE(review): the evaluation cells call `create_sequence`, which is not
# imported anywhere in this notebook — TODO import it from its defining module.
# argparse, time, nn, optim and LSTMModel are not used in the cells shown.

/mnt/c/research


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global figure defaults applied to every plot below: Times New Roman text at
# size 24, with tick marks on both axes drawn pointing inward.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 24,
    'xtick.direction': 'in',  # x-axis ticks point inward
    'ytick.direction': 'in',  # y-axis ticks point inward
})

In [3]:
# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised, so it must be set before any CUDA call in the session.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
device = ("cuda" if torch.cuda.is_available() else "cpu")
device

'cuda'

In [4]:
def _read_yaml(path):
    """Open ``path`` and return its contents parsed with ``yaml.safe_load``."""
    with open(path, "r") as stream:
        return yaml.safe_load(stream)


# The training config points at the individual and group configs; the group
# config in turn points at the passing-model config.
cfg_path = "config/passing/pass_train.yaml"
train_cfg = _read_yaml(cfg_path)
ind_cfg = _read_yaml(train_cfg["config_path"]["individual"])
grp_cfg = _read_yaml(train_cfg["config_path"]["group"])
mdl_cfg_path = grp_cfg["passing"]["cfg_path"]
mdl_cfg = _read_yaml(mdl_cfg_path)

In [5]:
# For every "<room>_<surgery>" setting listed in the training config, collect
# the sorted entries under data/<room>/<surgery>/passing/.
data_dirs_all = {}
for room_num, surgery_items in train_cfg["dataset"]["setting"].items():
    for surgery_num in surgery_items.keys():
        dirs = sorted(glob(os.path.join("data", room_num, surgery_num, "passing", "*")))
        data_dirs_all[f"{room_num}_{surgery_num}"] = dirs

# Load the individuals of every directory and key them as
# "<room>_<surgery>_<directory name>_<pid>".
logger.info(f"=> loading individuals from {data_dirs_all}")
inds = {}
for key_prefix, dirs in data_dirs_all.items():
    for data_dir in dirs:
        # The path was built with os.path.join/glob, so take its last component
        # with os.path.basename rather than splitting on a hardcoded "/".
        num = os.path.basename(data_dir)
        json_path = os.path.join(data_dir, ".json", "individual.json")
        tmp_inds, _ = load_individuals(json_path, ind_cfg)
        for pid, ind in tmp_inds.items():
            inds[f"{key_prefix}_{num}_{pid}"] = ind

2022-05-24 16:18:32 [INFO]: => loading individuals from {'02_001': ['data/02/001/passing/01', 'data/02/001/passing/02', 'data/02/001/passing/03', 'data/02/001/passing/04', 'data/02/001/passing/05', 'data/02/001/passing/06', 'data/02/001/passing/07', 'data/02/001/passing/08', 'data/02/001/passing/09', 'data/02/001/passing/10', 'data/02/001/passing/11', 'data/02/001/passing/12', 'data/02/001/passing/13', 'data/02/001/passing/14', 'data/02/001/passing/15', 'data/02/001/passing/16', 'data/02/001/passing/17', 'data/02/001/passing/18', 'data/02/001/passing/19', 'data/02/001/passing/20', 'data/02/001/passing/21', 'data/02/001/passing/22', 'data/02/001/passing/23'], '08_001': ['data/08/001/passing/01', 'data/08/001/passing/02', 'data/08/001/passing/03', 'data/08/001/passing/04', 'data/08/001/passing/05', 'data/08/001/passing/06', 'data/08/001/passing/07', 'data/08/001/passing/08', 'data/08/001/passing/09', 'data/08/001/passing/10', 'data/08/001/passing/11', 'data/08/001/passing/12', 'data/08/0

In [6]:
# Build a model from the base model config.
# NOTE(review): this instance is replaced inside the grid-search loop before
# it is trained or used, so this cell only allocates a throwaway model (at most
# a smoke test that mdl_cfg is valid).
model = init_model(mdl_cfg, device)

# 深層学習

In [7]:
# Build the train / validation / test loaders from the loaded individuals.
# NOTE(review): every epoch in the grid-search log reports `val loss: nan`,
# which suggests val_loader yields no usable batches — TODO confirm the split.
dataset_cfg = train_cfg["dataset"]
passing_defs = grp_cfg["passing"]["default"]
loaders = make_data_loaders(inds, dataset_cfg, passing_defs, logger)
train_loader, val_loader, test_loader = loaders

2022-05-24 16:18:44 [INFO]: => createing time series 02_001
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:02<00:00,  8.59it/s]
2022-05-24 16:18:47 [INFO]: => createing time series 08_001
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:02<00:00, 14.50it/s]
2022-05-24 16:18:49 [INFO]: => createing time series 09_001
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:01<00:00,  7.73it/s]
2022-05-24 16:18:51 [INFO]: => extracting feature 02_001
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 

# グリッドサーチ

In [8]:
# Hyper-parameter grid explored below (2 x 3 x 2 = 12 training runs).
params = dict(
    n_rnns=[1, 2],
    rnn_hidden_dim=[128, 256, 512],
    pos_weight=[8, 16],
)
# Number of training epochs per grid point, from the training config.
optim_cfg = train_cfg["optim"]
epoch_len = optim_cfg["epoch"]

In [9]:
# Grid search: train one model per hyper-parameter combination, score it with
# `test`, and keep the best score / params / model for each metric.
# Each max_* record is [score, param]; score is unpacked as
# (acc, pre, rcl, f1). max_models uses the same index order
# (0=accuracy, 1=precision, 2=recall, 3=f1).
# The sentinel is below any attainable score so the first run always fills
# every record. With a 0 sentinel, a metric that is 0.0 for every run (e.g.
# precision when no positive is ever predicted) would leave its model as None
# and break the model-saving cells.
NEG_INF_SCORE = [float("-inf")] * 4
max_acc = [list(NEG_INF_SCORE), None]
max_pre = [list(NEG_INF_SCORE), None]
max_rcl = [list(NEG_INF_SCORE), None]
max_f1 = [list(NEG_INF_SCORE), None]
max_models = [None] * 4
# The same record objects as above, in metric-index order, so one loop updates all.
best_records = [max_acc, max_pre, max_rcl, max_f1]

# NOTE(review): hyper-parameters are selected on the *test* split, so the
# reported test scores are optimistically biased. Selecting on val_loader would
# be preferable, but the log shows `val loss: nan` for every epoch — TODO fix
# the validation split first.
for n_rnns in params['n_rnns']:
    for dim in params['rnn_hidden_dim']:
        for weight in params['pos_weight']:
            # The positive-class weight is stored under the key "weight"; this
            # key is also copied into the saved model configs below.
            param = dict(n_rnns=n_rnns, rnn_hidden_dim=dim, weight=weight)
            print(param)

            # model config = base model config overridden by this grid point
            config = {**mdl_cfg, **param}
            pos_weight = param["weight"]

            # init model, loss (presumably [negative, positive] class weights), optim
            model = init_model(config, device)
            criterion = init_loss([1, pos_weight], device)
            optimizer, scheduler = init_optim(
                model, train_cfg["optim"]["lr"], train_cfg["optim"]["lr_rate"]
            )

            # training
            model, epoch, history = train(
                model, train_loader, val_loader,
                criterion, optimizer, scheduler,
                epoch_len, logger, device
            )

            # test
            score = test(model, test_loader, logger, device)
            acc, pre, rcl, f1 = score

            # update the best record of every metric this run improves
            for idx, record in enumerate(best_records):
                if score[idx] > record[0][idx]:
                    record[0] = score
                    record[1] = param
                    max_models[idx] = model

            # drop per-run optimizer/loss references before releasing cached GPU memory
            del optimizer, scheduler, criterion
            torch.cuda.empty_cache()

2022-05-24 09:47:50 [INFO]: => start training


{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 8}


2022-05-24 09:48:23 [INFO]: Epoch[1/50] train loss: 0.54059, val loss: nan, lr: 0.0010000, time: 32.72
2022-05-24 09:48:55 [INFO]: Epoch[2/50] train loss: 0.48087, val loss: nan, lr: 0.0010000, time: 31.55
2022-05-24 09:49:26 [INFO]: Epoch[3/50] train loss: 0.45065, val loss: nan, lr: 0.0010000, time: 31.19
2022-05-24 09:49:57 [INFO]: Epoch[4/50] train loss: 0.43386, val loss: nan, lr: 0.0010000, time: 31.49
2022-05-24 09:50:29 [INFO]: Epoch[5/50] train loss: 0.42408, val loss: nan, lr: 0.0010000, time: 31.38
2022-05-24 09:51:00 [INFO]: Epoch[6/50] train loss: 0.41802, val loss: nan, lr: 0.0010000, time: 31.62
2022-05-24 09:51:32 [INFO]: Epoch[7/50] train loss: 0.41342, val loss: nan, lr: 0.0010000, time: 31.48
2022-05-24 09:52:03 [INFO]: Epoch[8/50] train loss: 0.41050, val loss: nan, lr: 0.0010000, time: 31.38
2022-05-24 09:52:34 [INFO]: Epoch[9/50] train loss: 0.40790, val loss: nan, lr: 0.0010000, time: 31.13
2022-05-24 09:53:05 [INFO]: Epoch[10/50] train loss: 0.40586, val loss: n

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 16}


2022-05-24 10:15:53 [INFO]: Epoch[1/50] train loss: 0.75312, val loss: nan, lr: 0.0010000, time: 31.34
2022-05-24 10:16:24 [INFO]: Epoch[2/50] train loss: 0.63875, val loss: nan, lr: 0.0010000, time: 30.70
2022-05-24 10:16:55 [INFO]: Epoch[3/50] train loss: 0.55906, val loss: nan, lr: 0.0010000, time: 30.83
2022-05-24 10:17:26 [INFO]: Epoch[4/50] train loss: 0.51497, val loss: nan, lr: 0.0010000, time: 31.49
2022-05-24 10:17:58 [INFO]: Epoch[5/50] train loss: 0.49021, val loss: nan, lr: 0.0010000, time: 31.57
2022-05-24 10:18:30 [INFO]: Epoch[6/50] train loss: 0.47466, val loss: nan, lr: 0.0010000, time: 31.50
2022-05-24 10:19:01 [INFO]: Epoch[7/50] train loss: 0.46723, val loss: nan, lr: 0.0010000, time: 31.74
2022-05-24 10:19:32 [INFO]: Epoch[8/50] train loss: 0.45731, val loss: nan, lr: 0.0010000, time: 31.12
2022-05-24 10:20:04 [INFO]: Epoch[9/50] train loss: 0.45331, val loss: nan, lr: 0.0010000, time: 31.23
2022-05-24 10:20:35 [INFO]: Epoch[10/50] train loss: 0.44753, val loss: n

{'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 8}


2022-05-24 10:43:33 [INFO]: Epoch[1/50] train loss: 0.73985, val loss: nan, lr: 0.0010000, time: 32.59
2022-05-24 10:44:04 [INFO]: Epoch[2/50] train loss: 0.61264, val loss: nan, lr: 0.0010000, time: 31.73
2022-05-24 10:44:37 [INFO]: Epoch[3/50] train loss: 0.53920, val loss: nan, lr: 0.0010000, time: 32.82
2022-05-24 10:45:10 [INFO]: Epoch[4/50] train loss: 0.49274, val loss: nan, lr: 0.0010000, time: 33.36
2022-05-24 10:45:44 [INFO]: Epoch[5/50] train loss: 0.46298, val loss: nan, lr: 0.0010000, time: 33.46
2022-05-24 10:46:16 [INFO]: Epoch[6/50] train loss: 0.44353, val loss: nan, lr: 0.0010000, time: 32.22
2022-05-24 10:46:49 [INFO]: Epoch[7/50] train loss: 0.43229, val loss: nan, lr: 0.0010000, time: 33.02
2022-05-24 10:47:21 [INFO]: Epoch[8/50] train loss: 0.42516, val loss: nan, lr: 0.0010000, time: 32.09
2022-05-24 10:47:54 [INFO]: Epoch[9/50] train loss: 0.41946, val loss: nan, lr: 0.0010000, time: 32.54
2022-05-24 10:48:27 [INFO]: Epoch[10/50] train loss: 0.41419, val loss: n

{'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 16}


2022-05-24 11:11:40 [INFO]: Epoch[1/50] train loss: 0.69209, val loss: nan, lr: 0.0010000, time: 32.10
2022-05-24 11:12:12 [INFO]: Epoch[2/50] train loss: 0.59096, val loss: nan, lr: 0.0010000, time: 31.95
2022-05-24 11:12:45 [INFO]: Epoch[3/50] train loss: 0.53709, val loss: nan, lr: 0.0010000, time: 32.69
2022-05-24 11:13:18 [INFO]: Epoch[4/50] train loss: 0.51286, val loss: nan, lr: 0.0010000, time: 32.63
2022-05-24 11:13:49 [INFO]: Epoch[5/50] train loss: 0.48657, val loss: nan, lr: 0.0010000, time: 31.77
2022-05-24 11:14:21 [INFO]: Epoch[6/50] train loss: 0.47120, val loss: nan, lr: 0.0010000, time: 32.00
2022-05-24 11:14:54 [INFO]: Epoch[7/50] train loss: 0.46124, val loss: nan, lr: 0.0010000, time: 32.37
2022-05-24 11:15:26 [INFO]: Epoch[8/50] train loss: 0.45697, val loss: nan, lr: 0.0010000, time: 32.66
2022-05-24 11:15:58 [INFO]: Epoch[9/50] train loss: 0.45645, val loss: nan, lr: 0.0010000, time: 31.74
2022-05-24 11:16:30 [INFO]: Epoch[10/50] train loss: 0.44938, val loss: n

{'n_rnns': 1, 'rnn_hidden_dim': 512, 'weight': 8}


2022-05-24 11:39:53 [INFO]: Epoch[1/50] train loss: 0.86613, val loss: nan, lr: 0.0010000, time: 32.05
2022-05-24 11:40:25 [INFO]: Epoch[2/50] train loss: 0.68848, val loss: nan, lr: 0.0010000, time: 31.76
2022-05-24 11:40:56 [INFO]: Epoch[3/50] train loss: 0.62213, val loss: nan, lr: 0.0010000, time: 30.93
2022-05-24 11:41:27 [INFO]: Epoch[4/50] train loss: 0.57756, val loss: nan, lr: 0.0010000, time: 31.14
2022-05-24 11:41:59 [INFO]: Epoch[5/50] train loss: 0.55213, val loss: nan, lr: 0.0010000, time: 31.70
2022-05-24 11:42:31 [INFO]: Epoch[6/50] train loss: 0.50243, val loss: nan, lr: 0.0010000, time: 32.52
2022-05-24 11:43:04 [INFO]: Epoch[7/50] train loss: 0.47601, val loss: nan, lr: 0.0010000, time: 32.33
2022-05-24 11:43:36 [INFO]: Epoch[8/50] train loss: 0.45829, val loss: nan, lr: 0.0010000, time: 32.49
2022-05-24 11:44:11 [INFO]: Epoch[9/50] train loss: 0.44516, val loss: nan, lr: 0.0010000, time: 34.28
2022-05-24 11:44:47 [INFO]: Epoch[10/50] train loss: 0.43462, val loss: n

{'n_rnns': 1, 'rnn_hidden_dim': 512, 'weight': 16}


2022-05-24 12:08:52 [INFO]: Epoch[1/50] train loss: 0.65922, val loss: nan, lr: 0.0010000, time: 31.47
2022-05-24 12:09:24 [INFO]: Epoch[2/50] train loss: 0.57397, val loss: nan, lr: 0.0010000, time: 31.32
2022-05-24 12:09:57 [INFO]: Epoch[3/50] train loss: 0.53114, val loss: nan, lr: 0.0010000, time: 33.25
2022-05-24 12:10:30 [INFO]: Epoch[4/50] train loss: 0.50734, val loss: nan, lr: 0.0010000, time: 33.35
2022-05-24 12:11:01 [INFO]: Epoch[5/50] train loss: 0.48240, val loss: nan, lr: 0.0010000, time: 30.76
2022-05-24 12:11:32 [INFO]: Epoch[6/50] train loss: 0.47254, val loss: nan, lr: 0.0010000, time: 30.77
2022-05-24 12:12:03 [INFO]: Epoch[7/50] train loss: 0.46297, val loss: nan, lr: 0.0010000, time: 31.25
2022-05-24 12:12:33 [INFO]: Epoch[8/50] train loss: 0.45839, val loss: nan, lr: 0.0010000, time: 30.53
2022-05-24 12:13:04 [INFO]: Epoch[9/50] train loss: 0.45507, val loss: nan, lr: 0.0010000, time: 30.89
2022-05-24 12:13:35 [INFO]: Epoch[10/50] train loss: 0.45035, val loss: n

{'n_rnns': 2, 'rnn_hidden_dim': 128, 'weight': 8}


2022-05-24 12:36:14 [INFO]: Epoch[1/50] train loss: 0.79652, val loss: nan, lr: 0.0010000, time: 32.27
2022-05-24 12:36:46 [INFO]: Epoch[2/50] train loss: 0.66115, val loss: nan, lr: 0.0010000, time: 31.67
2022-05-24 12:37:18 [INFO]: Epoch[3/50] train loss: 0.56710, val loss: nan, lr: 0.0010000, time: 32.86
2022-05-24 12:37:52 [INFO]: Epoch[4/50] train loss: 0.50286, val loss: nan, lr: 0.0010000, time: 33.22
2022-05-24 12:38:25 [INFO]: Epoch[5/50] train loss: 0.46636, val loss: nan, lr: 0.0010000, time: 33.07
2022-05-24 12:38:57 [INFO]: Epoch[6/50] train loss: 0.44554, val loss: nan, lr: 0.0010000, time: 32.20
2022-05-24 12:39:29 [INFO]: Epoch[7/50] train loss: 0.43122, val loss: nan, lr: 0.0010000, time: 31.92
2022-05-24 12:40:01 [INFO]: Epoch[8/50] train loss: 0.42314, val loss: nan, lr: 0.0010000, time: 32.06
2022-05-24 12:40:33 [INFO]: Epoch[9/50] train loss: 0.41944, val loss: nan, lr: 0.0010000, time: 32.00
2022-05-24 12:41:06 [INFO]: Epoch[10/50] train loss: 0.41562, val loss: n

{'n_rnns': 2, 'rnn_hidden_dim': 128, 'weight': 16}


2022-05-24 13:05:40 [INFO]: Epoch[1/50] train loss: 0.72649, val loss: nan, lr: 0.0010000, time: 33.20
2022-05-24 13:06:12 [INFO]: Epoch[2/50] train loss: 0.62019, val loss: nan, lr: 0.0010000, time: 32.28
2022-05-24 13:06:45 [INFO]: Epoch[3/50] train loss: 0.55300, val loss: nan, lr: 0.0010000, time: 32.76
2022-05-24 13:07:18 [INFO]: Epoch[4/50] train loss: 0.51431, val loss: nan, lr: 0.0010000, time: 32.82
2022-05-24 13:07:51 [INFO]: Epoch[5/50] train loss: 0.48887, val loss: nan, lr: 0.0010000, time: 32.83
2022-05-24 13:08:23 [INFO]: Epoch[6/50] train loss: 0.47132, val loss: nan, lr: 0.0010000, time: 31.84
2022-05-24 13:08:55 [INFO]: Epoch[7/50] train loss: 0.46329, val loss: nan, lr: 0.0010000, time: 31.97
2022-05-24 13:09:27 [INFO]: Epoch[8/50] train loss: 0.45646, val loss: nan, lr: 0.0010000, time: 32.37
2022-05-24 13:09:59 [INFO]: Epoch[9/50] train loss: 0.45380, val loss: nan, lr: 0.0010000, time: 31.69
2022-05-24 13:10:31 [INFO]: Epoch[10/50] train loss: 0.45082, val loss: n

{'n_rnns': 2, 'rnn_hidden_dim': 256, 'weight': 8}


2022-05-24 13:34:10 [INFO]: Epoch[1/50] train loss: 0.66191, val loss: nan, lr: 0.0010000, time: 30.78
2022-05-24 13:34:41 [INFO]: Epoch[2/50] train loss: 0.57322, val loss: nan, lr: 0.0010000, time: 31.21
2022-05-24 13:35:12 [INFO]: Epoch[3/50] train loss: 0.51147, val loss: nan, lr: 0.0010000, time: 30.58
2022-05-24 13:35:42 [INFO]: Epoch[4/50] train loss: 0.47315, val loss: nan, lr: 0.0010000, time: 30.59
2022-05-24 13:36:14 [INFO]: Epoch[5/50] train loss: 0.44804, val loss: nan, lr: 0.0010000, time: 31.41
2022-05-24 13:36:45 [INFO]: Epoch[6/50] train loss: 0.43478, val loss: nan, lr: 0.0010000, time: 31.34
2022-05-24 13:37:16 [INFO]: Epoch[7/50] train loss: 0.42878, val loss: nan, lr: 0.0010000, time: 30.87
2022-05-24 13:37:48 [INFO]: Epoch[8/50] train loss: 0.42008, val loss: nan, lr: 0.0010000, time: 32.62
2022-05-24 13:38:19 [INFO]: Epoch[9/50] train loss: 0.41698, val loss: nan, lr: 0.0010000, time: 30.73
2022-05-24 13:38:51 [INFO]: Epoch[10/50] train loss: 0.42007, val loss: n

{'n_rnns': 2, 'rnn_hidden_dim': 256, 'weight': 16}


2022-05-24 14:01:52 [INFO]: Epoch[1/50] train loss: 0.86539, val loss: nan, lr: 0.0010000, time: 31.86
2022-05-24 14:02:25 [INFO]: Epoch[2/50] train loss: 0.74932, val loss: nan, lr: 0.0010000, time: 32.64
2022-05-24 14:02:57 [INFO]: Epoch[3/50] train loss: 0.68485, val loss: nan, lr: 0.0010000, time: 31.86
2022-05-24 14:03:28 [INFO]: Epoch[4/50] train loss: 0.63199, val loss: nan, lr: 0.0010000, time: 30.62
2022-05-24 14:03:59 [INFO]: Epoch[5/50] train loss: 0.59177, val loss: nan, lr: 0.0010000, time: 30.99
2022-05-24 14:04:29 [INFO]: Epoch[6/50] train loss: 0.56587, val loss: nan, lr: 0.0010000, time: 30.62
2022-05-24 14:05:00 [INFO]: Epoch[7/50] train loss: 0.53287, val loss: nan, lr: 0.0010000, time: 30.39
2022-05-24 14:05:30 [INFO]: Epoch[8/50] train loss: 0.51065, val loss: nan, lr: 0.0010000, time: 30.32
2022-05-24 14:06:00 [INFO]: Epoch[9/50] train loss: 0.49954, val loss: nan, lr: 0.0010000, time: 29.89
2022-05-24 14:06:30 [INFO]: Epoch[10/50] train loss: 0.49270, val loss: n

{'n_rnns': 2, 'rnn_hidden_dim': 512, 'weight': 8}


2022-05-24 14:29:06 [INFO]: Epoch[1/50] train loss: 0.77516, val loss: nan, lr: 0.0010000, time: 36.79
2022-05-24 14:29:43 [INFO]: Epoch[2/50] train loss: 0.66046, val loss: nan, lr: 0.0010000, time: 36.45
2022-05-24 14:30:19 [INFO]: Epoch[3/50] train loss: 0.56901, val loss: nan, lr: 0.0010000, time: 36.31
2022-05-24 14:30:55 [INFO]: Epoch[4/50] train loss: 0.50977, val loss: nan, lr: 0.0010000, time: 36.16
2022-05-24 14:31:32 [INFO]: Epoch[5/50] train loss: 0.47449, val loss: nan, lr: 0.0010000, time: 36.67
2022-05-24 14:32:08 [INFO]: Epoch[6/50] train loss: 0.45279, val loss: nan, lr: 0.0010000, time: 36.56
2022-05-24 14:32:45 [INFO]: Epoch[7/50] train loss: 0.44219, val loss: nan, lr: 0.0010000, time: 36.92
2022-05-24 14:33:22 [INFO]: Epoch[8/50] train loss: 0.43499, val loss: nan, lr: 0.0010000, time: 37.13
2022-05-24 14:34:00 [INFO]: Epoch[9/50] train loss: 0.43162, val loss: nan, lr: 0.0010000, time: 37.21
2022-05-24 14:34:37 [INFO]: Epoch[10/50] train loss: 0.42844, val loss: n

{'n_rnns': 2, 'rnn_hidden_dim': 512, 'weight': 16}


2022-05-24 15:01:29 [INFO]: Epoch[1/50] train loss: 0.93263, val loss: nan, lr: 0.0010000, time: 36.46
2022-05-24 15:02:05 [INFO]: Epoch[2/50] train loss: 0.82941, val loss: nan, lr: 0.0010000, time: 36.21
2022-05-24 15:02:42 [INFO]: Epoch[3/50] train loss: 0.72291, val loss: nan, lr: 0.0010000, time: 36.70
2022-05-24 15:03:18 [INFO]: Epoch[4/50] train loss: 0.63112, val loss: nan, lr: 0.0010000, time: 35.97
2022-05-24 15:03:54 [INFO]: Epoch[5/50] train loss: 0.56852, val loss: nan, lr: 0.0010000, time: 35.68
2022-05-24 15:04:30 [INFO]: Epoch[6/50] train loss: 0.52718, val loss: nan, lr: 0.0010000, time: 36.25
2022-05-24 15:05:06 [INFO]: Epoch[7/50] train loss: 0.50087, val loss: nan, lr: 0.0010000, time: 36.15
2022-05-24 15:05:42 [INFO]: Epoch[8/50] train loss: 0.48951, val loss: nan, lr: 0.0010000, time: 36.01
2022-05-24 15:06:18 [INFO]: Epoch[9/50] train loss: 0.48298, val loss: nan, lr: 0.0010000, time: 36.47
2022-05-24 15:06:55 [INFO]: Epoch[10/50] train loss: 0.47265, val loss: n

In [10]:
# Report the best grid-search run (params and all four scores) per metric,
# separated by blank lines.
summary = (
    ('max accuracy: ', max_acc),
    ('max precision: ', max_pre),
    ('max recall: ', max_rcl),
    ('max f1: ', max_f1),
)
for position, (title, (best_score, best_param)) in enumerate(summary):
    if position > 0:
        print('\n')
    print(title, best_param)
    acc, pre, rcl, f1 = best_score
    print('accuracy: {:.3f}'.format(acc), 'precision: {:.3f}'.format(pre), 'recall: {:.3f}'.format(rcl), 'f1_score: {:.3f}'.format(f1))

max accuracy:  {'n_rnns': 1, 'rnn_hidden_dim': 512, 'weight': 8}
accuracy: 0.995 precision: 0.000 recall: 0.000 f1_score: 0.000


max precision:  {'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 8}
accuracy: 0.993 precision: 0.328 recall: 0.528 f1_score: 0.405


max recall:  {'n_rnns': 2, 'rnn_hidden_dim': 256, 'weight': 16}
accuracy: 0.981 precision: 0.156 recall: 0.706 f1_score: 0.256


max f1:  {'n_rnns': 1, 'rnn_hidden_dim': 256, 'weight': 8}
accuracy: 0.993 precision: 0.328 recall: 0.528 f1_score: 0.405


## モデル保存

In [11]:
# Pick the run with the best test recall and rebuild its model config
# (base model config overridden by that run's grid-search parameters).
model = max_models[2]
param = max_rcl[1]
config = {**mdl_cfg, **param}

In [12]:
# Save the weights of the recall-selected model. torch.save does not create
# missing directories, so make sure models/passing/ exists first.
model_path = 'models/passing/pass_model_lstm_recall.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

In [13]:
# Record where the weights were written, then dump the recall-selected model
# config as YAML.
config["pretrained_path"] = model_path
recall_yaml_path = 'config/passing/pass_model_lstm_recall.yaml'
with open(recall_yaml_path, 'w') as yaml_file:
    yaml.dump(config, yaml_file)

In [14]:
# Pick the run with the best test F1 and rebuild its model config
# (base model config overridden by that run's grid-search parameters).
model = max_models[3]
param = max_f1[1]
config = {**mdl_cfg, **param}

In [15]:
# Save the weights of the F1-selected model. torch.save does not create
# missing directories, so make sure models/passing/ exists first.
model_path = 'models/passing/pass_model_lstm_f1.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

In [16]:
# Record where the weights were written, then dump the F1-selected model
# config as YAML.
config["pretrained_path"] = model_path
f1_yaml_path = 'config/passing/pass_model_lstm_f1.yaml'
with open(f1_yaml_path, 'w') as yaml_file:
    yaml.dump(config, yaml_file)

## 検証

In [17]:
# Build the complete (unsplit) feature series x_dict and label series y_dict
# for every loaded individual; later cells index both with the same keys.
dataset_setting = train_cfg["dataset"]["setting"]
passing_defaults = grp_cfg["passing"]["default"]
x_dict, y_dict = make_all_data(inds, dataset_setting, passing_defaults, logger)

2022-05-24 15:33:05 [INFO]: => createing time series 02_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:02<00:00,  8.17it/s]
2022-05-24 15:33:08 [INFO]: => createing time series 08_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:02<00:00, 15.58it/s]
2022-05-24 15:33:10 [INFO]: => createing time series 09_001
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:01<00:00,  8.76it/s]
2022-05-24 15:33:12 [INFO]: => extracting feature 02_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [18]:
# Shuffle the keys of x_dict with the dataset seed and split them into
# train (train + val ratio) and test keys.
# NOTE(review): this key-level split is computed here and is not shown to
# match the split inside make_data_loaders, so "test" keys may overlap the
# data the model was trained on — TODO confirm.
split_cfg = train_cfg["dataset"]
np.random.seed(split_cfg["random_seed"])
key_pool = list(x_dict.keys())
random_keys = np.random.choice(key_pool, size=len(x_dict), replace=False)

train_ratio = split_cfg["train_ratio"] + split_cfg["val_ratio"]
train_len = int(len(x_dict) * train_ratio)
train_keys, test_keys = random_keys[:train_len], random_keys[train_len:]

In [19]:
def plot(x_lst, y_lst, pred, seq_len=30, path=None, columns=None):
    """Plot predictions, ground truth and input features for one key.

    Callers drop the first ``seq_len - 1`` frames (they only seed the first
    window) before calling, so every series is drawn at frame indices
    ``seq_len - 1, seq_len, ...`` while the x axis still starts at frame 0.

    Args:
        x_lst: 2-D numpy array of input features, one row per predicted frame.
        y_lst: ground-truth labels, one per row of ``x_lst``.
        pred: 1-D numpy array of predicted labels, one per row of ``x_lst``.
        seq_len: window length used to build the model inputs.
        path: optional output file; when given the figure is also saved there
            (missing parent directories are created).
        columns: optional legend labels for the feature columns; defaults to
            ``"feature 0"``, ``"feature 1"``, ...
    """
    features = np.asarray(x_lst)
    n_features = features.shape[1]
    if columns is None:
        columns = [f"feature {i}" for i in range(n_features)]

    # Explicit frame indices: the first prediction belongs to frame
    # seq_len - 1, exactly the number of leading frames the caller dropped.
    frames = np.arange(len(features)) + (seq_len - 1)

    fig = plt.figure(figsize=(13, 4))
    # leave the right-hand strip of the figure free for the legend
    ax = fig.add_axes((0.04, 0.17, 0.80, 0.81))

    ax.plot(frames, pred, label='pred')
    ax.plot(frames, list(y_lst), linestyle=':', label='ground truth')
    for i, feature in enumerate(features.T):
        ax.plot(frames, feature, alpha=0.4, label=columns[i])

    ax.set_xlim(left=0)
    ax.set_ylim((-0.05, 1.05))
    ax.set_xlabel('frame')
    ax.legend(
        bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0,
        fontsize=20, handlelength=0.8, handletextpad=0.2
    )

    if path is not None:
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(path)
    plt.show()
    # this is called once per key inside evaluation loops; release the figure
    plt.close(fig)

In [20]:
# Training keys whose plot is also written to a PDF.
# NOTE(review): confirm this still matches a key of x_dict; otherwise nothing is saved.
save_keys = ["02_06_1_3"]

In [21]:
# Run the selected model over every *training* key. Collect frame-level
# (y_all_train / pred_all_train) and event-level (y_eve_train / pred_eve_train)
# labels and predictions, and plot every key whose ground truth contains a
# passing (label 1).
# NOTE(review): `create_sequence` is not imported in this notebook — TODO
# import it before running on a fresh kernel.
# Window length from the model config, so slicing and plotting use one value.
seq_len = config["seq_len"]

y_all_train = []
pred_all_train = []
y_eve_train = []
pred_eve_train = []

model.eval()
with torch.no_grad():
    for key in train_keys:
        x_lst = np.array(x_dict[key])
        y_lst = y_dict[key]

        x, _ = create_sequence(x_lst, y_lst, **config)
        if len(x) == 0:
            # no windows were produced for this key; nothing to predict
            continue
        x = torch.Tensor(x).float().to(device)

        # index of the highest-scoring class for each window
        pred = model(x).max(1)[1].cpu().numpy()

        # assumes one prediction per window, the first aligned with frame
        # seq_len - 1: drop the leading frames so inputs/labels match pred
        x_lst = x_lst[seq_len - 1:]
        y_lst = y_lst[seq_len - 1:]

        y_all_train += y_lst
        pred_all_train += pred.tolist()
        y_eve_train.append(1 in y_lst)
        pred_eve_train.append(1 in pred.tolist())

        if 1 not in y_lst:
            continue

        print(key)
        path = None
        if key in save_keys:
            path = os.path.join("data", "image", "passing", f"rnn_test_{key}.pdf")
        plot(x_lst, y_lst, pred, seq_len, path=path)

KeyError: 'weight'

In [None]:
# Frame-level scores over all training keys.
for template, metric_fn in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric_fn(y_all_train, pred_all_train)))

In [None]:
# Event-level scores: one (contains passing?, predicted passing?) pair per training key.
for template, metric_fn in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric_fn(y_eve_train, pred_eve_train)))

In [None]:
# Test keys whose plot is also written to a PDF.
# NOTE(review): confirm this still matches a key of x_dict; otherwise nothing is saved.
save_keys = ["08_03_2_5"]

In [None]:
# Run the selected model over every *test* key. Collect frame-level
# (y_all_test / pred_all_test) and event-level (y_eve_test / pred_eve_test)
# results, count event-level outcomes, and plot every key where either the
# ground truth or the prediction contains a passing (label 1).
# NOTE(review): `create_sequence` is not imported in this notebook — TODO
# import it before running on a fresh kernel.
# Window length from the model config, so slicing and plotting use one value.
seq_len = config["seq_len"]

y_all_test = []
pred_all_test = []
y_eve_test = []
pred_eve_test = []
# Event-level outcome counts (a key is positive if any frame is labelled 1):
#   tn: no passing, none predicted
#   fp: no passing, but one predicted
#   fn: passing present, none predicted
# The summary cell prints `fn` as "false negative", so `fn` must count missed
# passings; flagged negatives belong in `fp`.
tn, fp, fn = 0, 0, 0

model.eval()
with torch.no_grad():
    for key in test_keys:
        x_lst = np.array(x_dict[key])
        y_lst = y_dict[key]

        x, _ = create_sequence(x_lst, y_lst, **config)
        if len(x) == 0:
            # no windows for this key; counted as a true negative (assumes such
            # short series contain no passing — TODO confirm)
            tn += 1
            continue
        x = torch.Tensor(x).float().to(device)

        # index of the highest-scoring class for each window
        pred = model(x).max(1)[1].cpu().numpy()

        # assumes one prediction per window, the first aligned with frame
        # seq_len - 1: drop the leading frames so inputs/labels match pred
        x_lst = x_lst[seq_len - 1:]
        y_lst = y_lst[seq_len - 1:]

        has_passing = 1 in y_lst
        predicted_passing = 1 in pred.tolist()

        y_all_test += y_lst
        pred_all_test += pred.tolist()
        y_eve_test.append(has_passing)
        pred_eve_test.append(predicted_passing)

        if not has_passing:
            if predicted_passing:
                fp += 1
            else:
                tn += 1
        elif not predicted_passing:
            fn += 1

        if not (has_passing or predicted_passing):
            continue

        print(key)
        path = None
        if key in save_keys:
            # same output directory as the training-set plots
            path = os.path.join("data", "image", "passing", f"rnn_test_{key}.pdf")
        plot(x_lst, y_lst, pred, seq_len, path=path)

In [None]:
# Frame-level scores over all test keys, followed by the frame-level confusion matrix.
for template, metric_fn in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric_fn(y_all_test, pred_all_test)))

cm = confusion_matrix(y_all_test, pred_all_test)
sns.heatmap(cm, cmap='Blues')

In [None]:
# Event-level scores: one (contains passing?, predicted passing?) pair per test key.
for template, metric_fn in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(template.format(metric_fn(y_eve_test, pred_eve_test)))

# Event-level outcome counts accumulated by the evaluation loop.
print('true negative:', tn)
print('false negative:', fn)

cm = confusion_matrix(y_eve_test, pred_eve_test)
sns.heatmap(cm, cmap='Blues')