In [1]:
# back to project root
%cd ~/research

import argparse
import gc
import os
import sys
import time
from glob import glob

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from tqdm import tqdm
import yaml
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from torch import nn, optim

sys.path.append("src")
from group.passing.dataset import make_data_loaders, make_all_data
from group.passing.lstm_model import LSTMModel
from utility.activity_loader import load_individuals
from utility.logger import logger
from tools.train_passing import init_model, init_loss, init_optim, train, test

/raid6/home/yokoyama/research


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Figure defaults shared by every plot in this notebook.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    "font.size": 24,
    'xtick.direction': 'in',  # x-axis ticks point inward
    'ytick.direction': 'in',  # y-axis ticks point inward
})

In [3]:
# Expose only GPU index 6 to this process, then use CUDA when it is available.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Load the training config first; it names the individual and group config files.
cfg_path = "config/passing/pass_train.yaml"
with open(cfg_path, "r") as f:
    train_cfg = yaml.safe_load(f)

config_paths = train_cfg["config_path"]
with open(config_paths["individual"], "r") as f:
    ind_cfg = yaml.safe_load(f)
with open(config_paths["group"], "r") as f:
    grp_cfg = yaml.safe_load(f)

In [5]:
# Collect the sorted "passing" data directories for every room/surgery pair in the config,
# keyed "<room>_<surgery>".
data_dirs_all = {
    f"{room_num}_{surgery_num}": sorted(
        glob(os.path.join("data", room_num, surgery_num, "passing", "*"))
    )
    for room_num, surgery_items in train_cfg["dataset"]["setting"].items()
    for surgery_num in surgery_items.keys()
}

# Load every individual, keyed "<room>_<surgery>_<directory name>_<pid>".
inds = {}
for key_prefix, data_dirs in tqdm(data_dirs_all.items()):
    for data_dir in data_dirs:
        dir_name = data_dir.split("/")[-1]
        individuals = load_individuals(
            os.path.join(data_dir, ".json", "individual.json"), ind_cfg
        )
        inds.update(
            {f"{key_prefix}_{dir_name}_{pid}": ind for pid, ind in individuals.items()}
        )

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:07<00:00, 11.25s/it]


# グリッドサーチ

In [6]:
# Build the train / validation / test loaders from the loaded individuals.
passing_defs = grp_cfg["passing"]["default"]
dataset_cfg = train_cfg["dataset"]
loaders = make_data_loaders(inds, dataset_cfg, passing_defs, logger)
train_loader, val_loader, test_loader = loaders

2022-08-13 20:29:24,408 => createing time series 02_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:04<00:00,  4.04it/s]
2022-08-13 20:29:28,618 => createing time series 07_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [00:29<00:00,  1.63it/s]
2022-08-13 20:29:57,998 => createing time series 08_001
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:05<00:00,  6.50it/s]
2022-08-13 20:30:03,847 => createing time series 08_002
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 45/45 [00:32<00:00,  1.40it/s]


In [7]:
# Model settings shared by every grid-search run.
mdl_cfg = dict(
    dropouts=[0.1, 0],
    hidden_dims=[128, 64],
    n_classes=2,
    n_linears=2,
    rnn_dropout=0.1,
    size=4,
)

# Hyper-parameter values explored by the grid search (every combination is trained).
params = dict(
    n_rnns=[1, 2, 3],
    rnn_hidden_dim=[128, 256],
    pos_weight=[8, 16, 32],
)

# Number of training epochs; fixed here instead of reading train_cfg["optim"]["epoch"].
epoch_len = 150

In [None]:
# Best-so-far trackers, one per metric: [scores (acc, pre, rcl, f1), hyper-parameters].
# They start at 0 and are replaced only by a strictly greater value, so an entry
# (and its slot in max_models) stays None if no run scores above 0 on that metric.
max_acc = [[0, 0, 0, 0], None]
max_pre = [[0, 0, 0, 0], None]
max_rcl = [[0, 0, 0, 0], None]
max_f1 = [[0, 0, 0, 0], None]
max_models = [None for _ in range(4)]
best_trackers = [max_acc, max_pre, max_rcl, max_f1]  # same order as the score tuple

# Every (n_rnns, rnn_hidden_dim, pos_weight) combination, in nested-loop order.
grid = [
    (n_rnns, dim, weight)
    for n_rnns in params['n_rnns']
    for dim in params['rnn_hidden_dim']
    for weight in params['pos_weight']
]

for n_rnns, dim, weight in grid:
    param = dict(n_rnns=n_rnns, rnn_hidden_dim=dim, weight=weight)
    print(param)

    # Base model config overlaid with this run's hyper-parameters.
    config = {**mdl_cfg, **param}
    pos_weight = param["weight"]

    # init model, loss, optim
    model = init_model(config, device)
    criterion = init_loss([1, pos_weight], device)
    optimizer, scheduler = init_optim(
        model, train_cfg["optim"]["lr"], train_cfg["optim"]["lr_rate"]
    )

    # training
    model, epoch, history = train(
        model, train_loader, val_loader,
        criterion, optimizer, scheduler,
        epoch_len, logger, device
    )

    # NOTE(review): the best configuration is chosen from test-loader scores — confirm
    # this is intended rather than selecting on the validation split.
    score = test(model, test_loader, logger, device)
    acc, pre, rcl, f1 = score

    # Keep this run wherever it beats the current best for a metric.
    for idx, (value, tracker) in enumerate(zip((acc, pre, rcl, f1), best_trackers)):
        if value > tracker[0][idx]:
            tracker[0] = score
            tracker[1] = param
            max_models[idx] = model

    torch.cuda.empty_cache()

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 8}


2022-08-13 20:34:18,813 => start training
2022-08-13 20:34:34,688 Epoch[1/150] train loss: 0.44169, val loss: nan, lr: 0.0010000, time: 15.87
2022-08-13 20:34:51,725 Epoch[2/150] train loss: 0.39675, val loss: nan, lr: 0.0010000, time: 17.04
2022-08-13 20:35:08,566 Epoch[3/150] train loss: 0.39408, val loss: nan, lr: 0.0010000, time: 16.84
2022-08-13 20:35:25,968 Epoch[4/150] train loss: 0.39292, val loss: nan, lr: 0.0010000, time: 17.40
2022-08-13 20:35:42,884 Epoch[5/150] train loss: 0.39169, val loss: nan, lr: 0.0010000, time: 16.91
2022-08-13 20:35:59,812 Epoch[6/150] train loss: 0.39053, val loss: nan, lr: 0.0010000, time: 16.93
2022-08-13 20:36:16,743 Epoch[7/150] train loss: 0.38935, val loss: nan, lr: 0.0010000, time: 16.93
2022-08-13 20:36:34,796 Epoch[8/150] train loss: 0.38854, val loss: nan, lr: 0.0010000, time: 18.05
2022-08-13 20:36:52,401 Epoch[9/150] train loss: 0.38782, val loss: nan, lr: 0.0010000, time: 17.60
2022-08-13 20:37:08,964 Epoch[10/150] train loss: 0.38691,

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 16}


2022-08-13 21:17:38,091 Epoch[1/150] train loss: 0.50015, val loss: nan, lr: 0.0010000, time: 17.23
2022-08-13 21:17:54,331 Epoch[2/150] train loss: 0.42949, val loss: nan, lr: 0.0010000, time: 16.24
2022-08-13 21:18:09,304 Epoch[3/150] train loss: 0.42559, val loss: nan, lr: 0.0010000, time: 14.97
2022-08-13 21:18:26,046 Epoch[4/150] train loss: 0.42368, val loss: nan, lr: 0.0010000, time: 16.74
2022-08-13 21:18:42,763 Epoch[5/150] train loss: 0.42201, val loss: nan, lr: 0.0010000, time: 16.72
2022-08-13 21:19:01,382 Epoch[6/150] train loss: 0.42061, val loss: nan, lr: 0.0010000, time: 18.62
2022-08-13 21:19:18,373 Epoch[7/150] train loss: 0.41930, val loss: nan, lr: 0.0010000, time: 16.99
2022-08-13 21:19:35,313 Epoch[8/150] train loss: 0.41788, val loss: nan, lr: 0.0010000, time: 16.94
2022-08-13 21:19:51,854 Epoch[9/150] train loss: 0.41592, val loss: nan, lr: 0.0010000, time: 16.54
2022-08-13 21:20:08,847 Epoch[10/150] train loss: 0.41537, val loss: nan, lr: 0.0010000, time: 16.99

{'n_rnns': 1, 'rnn_hidden_dim': 128, 'weight': 32}


2022-08-13 22:00:23,345 Epoch[1/150] train loss: 0.52872, val loss: nan, lr: 0.0010000, time: 16.62
2022-08-13 22:00:39,786 Epoch[2/150] train loss: 0.47989, val loss: nan, lr: 0.0010000, time: 16.44
2022-08-13 22:00:55,945 Epoch[3/150] train loss: 0.47532, val loss: nan, lr: 0.0010000, time: 16.16
2022-08-13 22:01:13,021 Epoch[4/150] train loss: 0.47224, val loss: nan, lr: 0.0010000, time: 17.07
2022-08-13 22:01:30,104 Epoch[5/150] train loss: 0.46946, val loss: nan, lr: 0.0010000, time: 17.08
2022-08-13 22:01:47,487 Epoch[6/150] train loss: 0.46707, val loss: nan, lr: 0.0010000, time: 17.38
2022-08-13 22:02:04,434 Epoch[7/150] train loss: 0.46506, val loss: nan, lr: 0.0010000, time: 16.95
2022-08-13 22:02:22,336 Epoch[8/150] train loss: 0.46381, val loss: nan, lr: 0.0010000, time: 17.90
2022-08-13 22:02:39,848 Epoch[9/150] train loss: 0.46231, val loss: nan, lr: 0.0010000, time: 17.51
2022-08-13 22:02:56,579 Epoch[10/150] train loss: 0.46021, val loss: nan, lr: 0.0010000, time: 16.73

In [None]:
# Report, for each metric, the winning hyper-parameters and all four scores of that run.
print(f"epoch={epoch}")
score_formats = ('accuracy: {:.3f}', 'precision: {:.3f}', 'recall: {:.3f}', 'f1_score: {:.3f}')
for title, best in (
    ('max accuracy: ', max_acc),
    ('max precision: ', max_pre),
    ('max recall: ', max_rcl),
    ('max f1: ', max_f1),
):
    print(title, best[1])
    acc, pre, rcl, f1 = best[0]
    print(*(fmt.format(value) for fmt, value in zip(score_formats, (acc, pre, rcl, f1))))

## モデル保存

In [None]:
# Pick the run with the best recall and rebuild its full model config
# (base settings overlaid with that run's hyper-parameters).
param = max_rcl[1]
model = max_models[2]
config = dict(mdl_cfg)
config.update(param.items())

In [None]:
# Save the selected weights; the epoch value is part of the file name.
model_path = f'models/passing/pass_model_lstm_recall_ep{epoch}.pth'
weights = model.state_dict()
torch.save(weights, model_path)

In [None]:
# Record where the weights were written, then dump the model config as YAML.
config["pretrained_path"] = model_path
model_cfg_path = f'config/passing/pass_model_lstm_recall_ep{epoch}.yaml'
with open(model_cfg_path, 'w') as f:
    yaml.dump(config, f)

In [None]:
# Pick the run with the best F1 and rebuild its full model config
# (base settings overlaid with that run's hyper-parameters).
param = max_f1[1]
model = max_models[3]
config = dict(mdl_cfg)
config.update(param.items())

In [None]:
# Save the selected weights; the epoch value is part of the file name.
model_path = f'models/passing/pass_model_lstm_f1_ep{epoch}.pth'
weights = model.state_dict()
torch.save(weights, model_path)

In [None]:
# Record where the weights were written, then dump the model config as YAML.
config["pretrained_path"] = model_path
model_cfg_path = f'config/passing/pass_model_lstm_f1_ep{epoch}.yaml'
with open(model_cfg_path, 'w') as f:
    yaml.dump(config, f)

# 検証
## モデルロード

In [None]:
# Load a previously saved model for evaluation.
epoch = 150
rcl_f1 = "f1"  # which saved model to load: "f1" or "recall" (the suffixes used when saving)

# Release the model left over from earlier cells, if any. The reference must be
# dropped and collected *before* emptying the CUDA cache, otherwise the cache
# still holds the model's memory. Models still referenced by `max_models`
# are not freed by this.
try:
    del model
except NameError:
    pass
gc.collect()
torch.cuda.empty_cache()

mdl_cfg_path = f'config/passing/pass_model_lstm_{rcl_f1}_ep{epoch}.yaml'
with open(mdl_cfg_path, "r") as f:
    mdl_cfg = yaml.safe_load(f)
model = init_model(mdl_cfg, device)

# map_location lets a checkpoint saved from a GPU be restored when `device` is 'cpu'.
state_dict = torch.load(mdl_cfg["pretrained_path"], map_location=device)
model.load_state_dict(state_dict)

## データロード

In [None]:
# Build the per-key model inputs (x_dict) and labels (y_dict) used by the evaluation cells.
x_dict, y_dict = make_all_data(
    inds,
    train_cfg["dataset"]["setting"],
    grp_cfg["passing"]["default"],
    logger,
)

In [None]:
# Rebuild the train / validation / test key split with the configured seed.
# NOTE(review): presumably this mirrors the split made inside make_data_loaders —
# confirm the two stay in sync.
np.random.seed(train_cfg["dataset"]["random_seed"])

seq_len = grp_cfg["passing"]["default"]["seq_len"]
size = mdl_cfg["size"]

# Split keys whose labels contain a 1 and keys whose labels never do separately,
# so both groups are divided with the same ratios.
keys_1 = [key for key in x_dict if 1 in y_dict[key]]
keys_0 = [key for key in x_dict if 1 not in y_dict[key]]
random_keys_1 = np.random.choice(keys_1, size=len(keys_1), replace=False)
random_keys_0 = np.random.choice(keys_0, size=len(keys_0), replace=False)

train_ratio = train_cfg["dataset"]["train_ratio"]
val_ratio = train_cfg["dataset"]["val_ratio"]
train_len_1 = int(len(keys_1) * train_ratio)
train_len_0 = int(len(keys_0) * train_ratio)
val_len_1 = int(len(keys_1) * val_ratio)
val_len_0 = int(len(keys_0) * val_ratio)

# Layout of each shuffled array: [ train | val | test ].
train_keys_1 = random_keys_1[:train_len_1].tolist()
val_keys_1 = random_keys_1[train_len_1 : train_len_1 + val_len_1].tolist()
test_keys_1 = random_keys_1[train_len_1 + val_len_1 :].tolist()
train_keys_0 = random_keys_0[:train_len_0].tolist()
# Fixed: the validation negatives were sliced from random_keys_1 (the positives),
# and the test negatives started at train_len_0, overlapping the validation slice.
val_keys_0 = random_keys_0[train_len_0 : train_len_0 + val_len_0].tolist()
test_keys_0 = random_keys_0[train_len_0 + val_len_0 :].tolist()

train_keys = sorted(train_keys_1 + train_keys_0)
val_keys = sorted(val_keys_1 + val_keys_0)
test_keys = sorted(test_keys_1 + test_keys_0)

# The three splits must not share any key.
assert set(train_keys).isdisjoint(val_keys)
assert set(train_keys).isdisjoint(test_keys)
assert set(val_keys).isdisjoint(test_keys)

In [None]:
def create_sequence(x_lst, y_lst, seq_len=30, size=4):
    """Cut a series into overlapping windows of ``seq_len`` consecutive items.

    Args:
        x_lst: sliceable sequence of per-frame inputs.
        y_lst: indexable sequence of per-frame labels, at least as long as ``x_lst``.
        seq_len: number of consecutive items in each window.
        size: unused; kept so existing callers keep working.

    Returns:
        ``(x_seq, y_seq)``: the list of windows, and for each window the label
        of its last item. Both are empty when ``x_lst`` is shorter than ``seq_len``.
    """
    n_windows = len(x_lst) - seq_len + 1
    x_seq = [x_lst[start:start + seq_len] for start in range(n_windows)]
    y_seq = [y_lst[start + seq_len - 1] for start in range(n_windows)]
    return x_seq, y_seq


# Legend labels for the feature curves, in feature-column order.
columns = ["distance", "body_direction", "arm_ave", "wrist_distance"]
def plot(x_lst, y_lst, pred, seq_len=30, path=None):
    """Plot predictions, ground truth and input features over frames.

    Each series is prefixed with one 0 followed by ``seq_len - 1`` NaNs before
    plotting, which shifts the curves right by ``seq_len`` positions.

    Args:
        x_lst: 2-D array-like of features, one row per frame and one column per feature.
        y_lst: sequence of ground-truth labels, one per row of ``x_lst``.
        pred: array-like of predicted labels, one per row of ``x_lst``.
        seq_len: window length used to size the leading padding.
        path: when given, the figure is also saved to this path.
    """
    x_arr = np.asarray(x_lst)
    n_features = x_arr.shape[1]
    n_pad = seq_len - 1

    x_padded = [[0] * n_features] + [[np.nan] * n_features for _ in range(n_pad)] + x_arr.tolist()
    y_padded = [0] + [np.nan] * n_pad + list(y_lst)
    pred_padded = [0] + [np.nan] * n_pad + np.asarray(pred).tolist()

    fig = plt.figure(figsize=(13, 4))
    ax = fig.add_axes((0.04, 0.17, 0.80, 0.81))

    ax.plot(pred_padded, label='pred')
    ax.plot(y_padded, linestyle=':', label='ground truth')
    for i, feature in enumerate(np.array(x_padded).T):
        # Fall back to a generic label instead of raising when there are more
        # feature columns than names in `columns`.
        label = columns[i] if i < len(columns) else f"feature_{i}"
        ax.plot(feature, alpha=0.4, label=label)

    ax.set_ylim((-0.05, 1.05))
    ax.set_xlabel('frame')
    ax.legend(
        bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0,
        fontsize=20, handlelength=0.8, handletextpad=0.2
    )

    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    if path is not None:
        fig.savefig(path)
    plt.show()
    # This function is called once per key inside loops; close the figure so
    # figures do not accumulate in memory.
    plt.close(fig)

## トレインデータ

In [None]:
# Keys whose plot is also written to disk by the evaluation loop.
save_keys = ['02_06_1_3']

In [None]:
# Run the model over every training key, collecting per-frame labels/predictions
# (y_all_train / pred_all_train) and per-key "contains a 1" flags
# (y_eve_train / pred_eve_train).
y_all_train, pred_all_train = [], []
y_eve_train, pred_eve_train = [], []

model.eval()
with torch.no_grad():
    for key in train_keys:
        x_lst = np.array(x_dict[key])
        y_lst = y_dict[key]

        x, _ = create_sequence(x_lst, y_lst, seq_len, size)
        x = torch.Tensor(np.array(x)).float().to(device)

        # No window could be built for this key: nothing to predict.
        if len(x) == 0:
            continue

        # Predicted class = index of the largest output along dim 1.
        pred = model(x).max(1)[1].cpu().numpy()
        pred_labels = pred.tolist()

        # Drop the first seq_len - 1 frames so labels line up with the predictions.
        x_lst, y_lst = x_lst[seq_len - 1:], y_lst[seq_len - 1:]

        y_all_train.extend(y_lst)
        pred_all_train.extend(pred_labels)
        y_eve_train.append(1 in y_lst)
        pred_eve_train.append(1 in pred_labels)

        # Only keys whose ground truth contains a 1 are plotted.
        if 1 in y_lst:
            print(key)
            path = os.path.join("data", "passing", "image", f"rnn_test_{key}.pdf") if key in save_keys else None
            plot(x_lst, y_lst, pred, seq_len, path=path)

In [None]:
# Per-frame metrics on the training split.
for fmt, metric in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(fmt.format(metric(y_all_train, pred_all_train)))

In [None]:
# Per-event metrics on the training split (one boolean per key).
for fmt, metric in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(fmt.format(metric(y_eve_train, pred_eve_train)))

## テストデータ

In [None]:
# Keys whose plot is also written to disk by the evaluation loop.
save_keys = ['08_03_2_5']

In [None]:
# Run the model over every test key, collecting per-frame labels/predictions
# (y_all_test / pred_all_test) and per-key "contains a 1" flags
# (y_eve_test / pred_eve_test).
y_all_test = []
pred_all_test = []
y_eve_test = []
pred_eve_test = []
# tn: keys without a positive label where nothing was predicted.
# fn: NOTE — despite its name this counts keys without a positive label where the
#     model still predicted one (i.e. false positives). The name is kept because
#     a later cell reads `fn`.
tn, fn = 0, 0

model.eval()
with torch.no_grad():
    for key in test_keys:
        x_lst = np.array(x_dict[key])
        y_lst = y_dict[key]

        x, _ = create_sequence(x_lst, y_lst, seq_len, size)
        # Stack the windows into one ndarray before building the tensor (as the
        # training-data loop does); converting a list of ndarrays directly is slow.
        x = torch.Tensor(np.array(x)).float().to(device)

        if len(x) == 0:
            # No window could be built for this key: counted as a true negative and
            # left out of the per-frame / per-event lists.
            tn += 1
            continue

        # Predicted class = index of the largest output along dim 1.
        pred = model(x)
        pred = pred.max(1)[1]
        pred = pred.cpu().numpy()
        pred_labels = pred.tolist()

        # Drop the first seq_len - 1 frames so labels line up with the predictions.
        x_lst = x_lst[seq_len - 1:]
        y_lst = y_lst[seq_len - 1:]

        has_positive = 1 in y_lst
        predicted_positive = 1 in pred_labels

        y_all_test += y_lst
        pred_all_test += pred_labels
        y_eve_test.append(has_positive)
        pred_eve_test.append(predicted_positive)
        if not has_positive:
            if predicted_positive:
                fn += 1
            else:
                tn += 1

        # Skip plotting when neither the labels nor the predictions contain a 1.
        if not has_positive and not predicted_positive:
            continue

        print(key)
        path = None
        if key in save_keys:
            path = os.path.join("data", "passing", "image", f"rnn_test_{key}.pdf")
        plot(x_lst, y_lst, pred, seq_len, path=path)

In [None]:
# Per-frame metrics on the test split.
for fmt, metric in (
    ('accuracy: {:.3f}', accuracy_score),
    ('precision: {:.3f}', precision_score),
    ('recall: {:.3f}', recall_score),
    ('f1_score: {:.3f}', f1_score),
):
    print(fmt.format(metric(y_all_test, pred_all_test)))

# Per-frame confusion matrix, shown as a heatmap.
cm = confusion_matrix(y_all_test, pred_all_test)
sns.heatmap(cm, cmap='Blues')

In [None]:
# Per-event metrics on the test split (one boolean per key).
print('accuracy: {:.3f}'.format(accuracy_score(y_eve_test, pred_eve_test)))
print('precision: {:.3f}'.format(precision_score(y_eve_test, pred_eve_test)))
print('recall: {:.3f}'.format(recall_score(y_eve_test, pred_eve_test)))
print('f1_score: {:.3f}'.format(f1_score(y_eve_test, pred_eve_test)))

print('true negative:', tn)
# `fn` is incremented when a key without a positive label gets a positive
# prediction, which is a false positive; the printed label now says so.
print('false positive:', fn)

# Per-event confusion matrix, shown as a heatmap.
cm = confusion_matrix(y_eve_test, pred_eve_test)
sns.heatmap(cm, cmap='Blues')