In [None]:
# -*- coding: utf-8 -*-
import os
# Restrict this process to GPU index 0. The variable is set in the first cell,
# before torch is imported further down in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

このノートブックの役割
- 検出モデルの学習

#### --- ライブラリ ---

In [None]:
import numpy as np
from numpy.random import *
from os.path import join as pj
from os import getcwd as cwd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data as data
import visdom

# IO関連
from IO.logger import Logger
from IO.visdom import visualize
# 検出データセット
from dataset.detection.dataset import insects_dataset_from_voc_style_txt, collate_fn
# モデル
from model.refinedet.refinedet import RefineDet
from model.refinedet.loss.multiboxloss import RefineDetMultiBoxLoss
from model.refinedet.utils.predict import test_prediction
from model.optimizer import AdamW, RAdam
# 評価関数
from evaluation.detection.evaluate import Voc_Evaluater

#### --- 学習コンフィグ ---

In [None]:
class args:
    # Training configuration for the detection experiment. The class itself is
    # used as a namespace (it is never instantiated in this notebook) and is
    # passed as-is to Logger(args) further down.
    # experiment name
    experiment_name: str = ""
    # paths
    data_root: str = pj(cwd(), "")
    train_image_root: str = pj(cwd(), "")
    train_target_root: str = pj(cwd(), "")
    test_image_root: str = pj(cwd(), "")
    test_target_root: str = pj(cwd(), "")
    # NOTE: the two paths below are evaluated once, at class-definition time,
    # from the experiment_name above; reassigning args.experiment_name later
    # does not update them.
    model_root: str = pj(cwd(), "", experiment_name)
    prc_root: str = pj(cwd(), "", experiment_name)
    # training settings
    input_size: int = 512 # choose one of [320, 512, 1024]
    crop_num = (5, 5) # (w: int, h: int)
    batch_size: int = 2
    lr: float = 1e-4
    lamda: float = 1e-4
    tcb_layer_num: int = 6
    use_extra_layer: bool = True
    max_epoch: int = 100
    valid_interval: int = 5
    save_interval: int = 20
    pretrain: bool = True
    freeze: bool = True
    optimizer: str = "AdamW" # choose one of ["Adam", "AdamW", "RAdam"]
    activation_function: str = "ReLU" # choose one of ["ReLU", "LeakyReLU", "RReLU"]; various others can be used as well
    init_function: str = "xavier_uniform_" # choose one of ["xavier_uniform_", "xavier_normal_", "kaiming_uniform_", "kaiming_normal_"]; various others can be used as well
    method_aug = ["All"] # choose from the methods in dataset.classification.dataset.create_aug_seq (multiple allowed)
    size_normalization: bool = False
    augment_target: bool = False
    use_GN_WS: bool = False
    # Visdom
    visdom: bool = False
    visdom_port: int = 8097
    # model type
    model_detect_type: str = "all" # choose one of ["all", "each", "det2cls"]

In [None]:
# Number of classes the detector predicts for each supported detection mode.
_CLASS_NUM_BY_DETECT_TYPE = {"all": 2, "each": 13, "det2cls": 3}

# Fail immediately on an unknown mode: merely printing a message would leave
# args.class_num undefined and surface later as an unrelated AttributeError.
if args.model_detect_type not in _CLASS_NUM_BY_DETECT_TYPE:
    raise ValueError(
        "error! choice from all, each, det2cls (got {!r})".format(args.model_detect_type)
    )
args.class_num = _CLASS_NUM_BY_DETECT_TYPE[args.model_detect_type]

#### --- CUDA関連 ---

In [None]:
# Make newly created tensors default to the GPU when one is available,
# otherwise fall back to CPU float tensors.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')
# force=True keeps this cell re-runnable: without it, executing the cell a
# second time in the same kernel raises RuntimeError because the start method
# has already been set.
torch.multiprocessing.set_start_method('spawn', force=True)

#### --- Visdom ---

In [None]:
def _new_line_window(vis, title, ylabel='loss'):
    """
        Create one Visdom line-plot window seeded with the single point (0, 0).
        Args:
            - vis: visdom.Visdom, connection used to create the window
            - title: str, window title
            - ylabel: str, y-axis label ('loss' by default)
        Returns:
            whatever vis.line returns (used later as the window handle
            passed to visualize)
    """
    return vis.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel=ylabel,
            width=800,
            height=400
        )
    )


if args.visdom:
    # Connect to the Visdom server
    vis = visdom.Visdom(port=args.visdom_port)

    # ARM losses
    win_arm_loc = _new_line_window(vis, 'arm_loc_loss')
    win_arm_conf = _new_line_window(vis, 'arm_conf_loss')
    # ODM losses
    win_odm_loc = _new_line_window(vis, 'odm_loc_loss')
    win_odm_conf = _new_line_window(vis, 'odm_conf_loss')
    # Weight-normalization loss
    win_norm_loss = _new_line_window(vis, 'normalization_loss')
    # Total loss
    win_all_loss = _new_line_window(vis, 'train_loss')

    # Accuracy windows depend on the detection mode.
    # NOTE(review): no accuracy windows are created for "each"; the main loop
    # likewise has no reporting branch for that mode.
    if args.model_detect_type == "all":
        win_train_acc = _new_line_window(vis, 'train_accuracy', ylabel='average precision')
        win_test_acc = _new_line_window(vis, 'test_accuracy', ylabel='average precision')
    elif args.model_detect_type == "det2cls":
        win_train_aquatic_acc = _new_line_window(vis, 'train_aquatic_accuracy', ylabel='average precision')
        win_train_other_acc = _new_line_window(vis, 'train_other_accuracy', ylabel='average precision')
        win_test_aquatic_acc = _new_line_window(vis, 'test_aquatic_accuracy', ylabel='average precision')
        win_test_other_acc = _new_line_window(vis, 'test_other_accuracy', ylabel='average precision')

#### --- 学習関連 ---

In [None]:
def train_per_epoch(epoch, train_dataloader, opt, model, arm_loss, odm_loss, l2_loss, batch_size, lamda=0., visdom=False):
    """
        Train the model for one epoch.

        Every item yielded by train_dataloader is unpacked into five values;
        only the first two (images, targets) are used, and element 0 of each
        is taken. Those crops are split into mini-batches of batch_size that
        are visited in random order.

        Args:
            - epoch: int, current epoch (used for printing / visualization)
            - train_dataloader: data loader
            - opt: optimizer
            - model: model
            - arm_loss: callable(out, targets) -> (loc_loss, conf_loss) for the ARM
            - odm_loss: callable(out, targets) -> (loc_loss, conf_loss) for the ODM
            - l2_loss: callable(param, zeros) used for the weight penalty
            - batch_size: int, mini-batch size
            - lamda: float, weight of the parameter-norm penalty (0 disables it)
            - visdom: bool, whether to plot the losses with visdom
        Notes:
            - Batches whose total loss is NaN are skipped (no backward/step,
              nothing accumulated).
            - Printed/plotted losses are sums over batches, not averages.
            - When visdom is True this function reads the module-level names
              vis and win_* created in the Visdom cell.
    """
    # set model train mode
    model.train()

    # Same device choice as the CUDA setup cell, so training also runs on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # create loss counters
    arm_loc_loss = 0
    arm_conf_loss = 0
    odm_loc_loss = 0
    odm_conf_loss = 0
    all_norm_loss = 0

    # training
    for images, targets, _, _, _ in tqdm(train_dataloader, leave=False):
        crops = np.array(images[0])
        crop_targets = targets[0]

        # number of mini-batches (ceil division)
        batch_num = (crops.shape[0] + batch_size - 1) // batch_size

        # visit the mini-batches in random order
        for i in np.random.permutation(batch_num):
            batch_slice = slice(i * batch_size, (i + 1) * batch_size)
            batch_images = torch.from_numpy(crops[batch_slice]).to(device)
            batch_targets = [ann.to(device) for ann in crop_targets[batch_slice]]

            # forward
            out = model(batch_images)

            # calculate loss
            opt.zero_grad()
            arm_loss_l, arm_loss_c = arm_loss(out, batch_targets)
            odm_loss_l, odm_loss_c = odm_loss(out, batch_targets)
            # Local names must not shadow the arm_loss / odm_loss callables,
            # otherwise the next mini-batch would try to call a tensor.
            loss = (arm_loss_l + arm_loss_c) + (odm_loss_l + odm_loss_c)

            norm_loss = 0
            if lamda != 0:
                for param in model.parameters():
                    norm_loss += l2_loss(param, torch.zeros_like(param))
                norm_loss = norm_loss * lamda
                loss = loss + norm_loss

            # skip batches that produced a NaN loss
            if torch.isnan(loss):
                continue

            loss.backward()
            opt.step()
            arm_loc_loss += arm_loss_l.item()
            arm_conf_loss += arm_loss_c.item()
            odm_loc_loss += odm_loss_l.item()
            odm_conf_loss += odm_loss_c.item()
            # float() handles both the tensor case and the plain 0 used when
            # the penalty is disabled
            all_norm_loss += float(norm_loss)

    print('epoch ' + str(epoch) + ' || ARM_L Loss: %.4f ARM_C Loss: %.4f ODM_L Loss: %.4f ODM_C Loss: %.4f NORM Loss: %.4f ||' \
    % (arm_loc_loss, arm_conf_loss, odm_loc_loss, odm_conf_loss, all_norm_loss))

    # visualize
    if visdom:
        visualize(vis, epoch+1, arm_loc_loss, win_arm_loc)
        visualize(vis, epoch+1, arm_conf_loss, win_arm_conf)
        visualize(vis, epoch+1, odm_loc_loss, win_odm_loc)
        visualize(vis, epoch+1, odm_conf_loss, win_odm_conf)
        visualize(vis, epoch+1, all_norm_loss, win_norm_loss)
        visualize(vis, epoch+1, arm_loc_loss + arm_conf_loss + odm_loc_loss + odm_conf_loss + all_norm_loss, win_all_loss)

#### --- 評価関連 ---

In [None]:
def validate(evaluater, dataloader, model, crop_num, num_classes=2, nms_thresh=0.5):
    """
        Run prediction over a dataloader and return the evaluater's metrics.

        The predictions produced by test_prediction are handed to
        evaluater.set_result, and the value of evaluater.get_eval_metrics()
        is returned unchanged.

        Args:
            - evaluater: Voc_Evaluater, object that computes VOC-AP
            - dataloader: data loader
            - model: model
            - crop_num: (int, int), number of crops per axis
              NOTE(review): the original docstring gives the order as
              (vertical, horizontal) while the config comment says (w, h);
              confirm against test_prediction.
            - num_classes: int, number of classes (foreground + background)
            - nms_thresh: float, threshold forwarded to test_prediction for
              Non Maximum Suppression
    """
    evaluater.set_result(
        test_prediction(model, dataloader, crop_num, num_classes, nms_thresh)
    )
    return evaluater.get_eval_metrics()

#### --- コンフィグの保存 ---

In [None]:
# Save the configuration held in `args` through the project's Logger.
# NOTE(review): the output location/format is defined in IO.logger and is not
# visible from this notebook.
args_logger = Logger(args)
args_logger.save()

#### --- データ作成 ---

In [None]:
# Training set: built with training=True plus the crop/augmentation options
# from the config. Each loader yields one dataset item at a time
# (batch_size=1); mini-batching of crops happens inside train_per_epoch.
print('Loading dataset for train ...')
train_dataset = insects_dataset_from_voc_style_txt(
    args.train_image_root,
    args.input_size,
    args.crop_num,
    training=True,
    target_root=args.train_target_root,
    method_crop="SPREAD_ALL_OVER",
    method_aug=args.method_aug,
    model_detect_type=args.model_detect_type,
    size_normalization=args.size_normalization,
    augment_target=args.augment_target,
)
train_dataloader = data.DataLoader(
    train_dataset,
    batch_size=1,
    num_workers=0,
    shuffle=True,
    collate_fn=collate_fn,
)

# Evaluation sets: the test images, and the training images reloaded with
# training=False so that AP on the training data can be measured as well.
print('Loading dataset for test ...')
test_dataset = insects_dataset_from_voc_style_txt(
    args.test_image_root,
    args.input_size,
    args.crop_num,
    training=False,
)
test_dataloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn,
)
train_valid_dataset = insects_dataset_from_voc_style_txt(
    args.train_image_root,
    args.input_size,
    args.crop_num,
    training=False,
)
train_valid_dataloader = data.DataLoader(
    train_valid_dataset,
    batch_size=1,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn,
)

#### --- モデル作成 ---

In [None]:
# Build the RefineDet detector from the training configuration and print its
# layer structure.
model = RefineDet(
    args.input_size,
    args.class_num,
    args.tcb_layer_num,
    pretrain=args.pretrain,
    freeze=args.freeze,
    activation_function=args.activation_function,
    init_function=args.init_function,
    use_extra_layer=args.use_extra_layer,
    use_GN_WS=args.use_GN_WS,
)
print(model)

#### --- 最適化器作成 ---

In [None]:
# Map each supported optimizer name to its constructor.
_OPTIMIZER_FACTORIES = {
    "Adam": torch.optim.Adam,
    "AdamW": AdamW,
    "RAdam": RAdam,
}

# Reject unknown names here: silently falling through would leave `opt`
# undefined and only fail later, inside the training loop.
if args.optimizer not in _OPTIMIZER_FACTORIES:
    raise ValueError(
        "unknown optimizer {!r}; choose one of {}".format(
            args.optimizer, sorted(_OPTIMIZER_FACTORIES)
        )
    )
print("optimizer == {}".format(args.optimizer))
opt = _OPTIMIZER_FACTORIES[args.optimizer](model.parameters(), lr=args.lr)

#### --- 誤差定義 ---

In [None]:
# Loss functions handed to train_per_epoch. l2_loss is applied there between
# each model parameter and a zero tensor, as a weight penalty scaled by lamda.
arm_loss = RefineDetMultiBoxLoss(2, use_ARM=False) # ARM loss; the number of classes is fixed at 2
odm_loss = RefineDetMultiBoxLoss(args.class_num, use_ARM=True) # ODM loss; the number of classes is background + foreground classes
l2_loss = nn.MSELoss(reduction='mean')

#### --- メイン処理 ---

In [None]:
def save_model_and_disp_ap(test_ap, best_test_ap, model, model_root, epoch):
    """
        Save the model as the new best one when its AP beats the best so far.

        When test_ap > best_test_ap, the model's state_dict is written to
        <model_root>/best.pth and the epoch/AP pair to
        <model_root>/best_AP.txt (overwriting earlier files). Otherwise
        nothing is written.

        Args:
            - test_ap: float, current AP
            - best_test_ap: float, best AP seen so far
            - model: model
            - model_root: str, directory in which to save the model
              (created if it does not exist)
            - epoch: int, current epoch
        Returns:
            the updated best AP (max of the two under a strict '>' comparison)
    """
    if test_ap > best_test_ap:
        best_test_ap = test_ap
        # The directory is not guaranteed to exist yet; without this,
        # torch.save fails on a fresh experiment directory.
        os.makedirs(model_root, exist_ok=True)
        torch.save(model.state_dict(), pj(model_root, "best.pth"))
        with open(pj(model_root, "best_AP.txt"), mode="w") as f:
            f.write("epoch = {}, test_ap = {}".format(epoch, test_ap))
    return best_test_ap

In [None]:
# Create every output directory the main loop writes to. exist_ok=True makes
# the cell safe to re-run. args.model_root is included because the periodic
# checkpoint in the main loop calls torch.save into it directly.
for _out_dir in (pj(args.prc_root, "train"), pj(args.prc_root, "test"), args.model_root):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# Evaluators for the training and test data; each receives its own
# sub-directory of prc_root.
train_evaluater = Voc_Evaluater(
    args.train_image_root, args.train_target_root, pj(args.prc_root, "train")
)
test_evaluater = Voc_Evaluater(
    args.test_image_root, args.test_target_root, pj(args.prc_root, "test")
)
# best test AP observed so far
best_test_ap = 0

# NOTE(review): validation and checkpointing both skip epoch 0 and fire only
# when epoch is a multiple of the interval, so with the default settings the
# weights after the last epoch (max_epoch - 1) are never saved — confirm this
# is intended.
for epoch in range(args.max_epoch):
    train_per_epoch(
        epoch,
        train_dataloader,
        opt,
        model,
        arm_loss,
        odm_loss,
        l2_loss,
        args.batch_size,
        lamda=args.lamda,
        visdom=args.visdom,
    )

    # periodic validation on both the training and the test data
    if epoch > 0 and epoch % args.valid_interval == 0:
        train_eval_metrics = validate(
            train_evaluater, train_valid_dataloader, model, args.crop_num,
            num_classes=args.class_num, nms_thresh=0.5,
        )
        test_eval_metrics = validate(
            test_evaluater, test_dataloader, model, args.crop_num,
            num_classes=args.class_num, nms_thresh=0.5,
        )

        # NOTE(review): there is no reporting/saving branch for the "each"
        # detection type; its metrics are computed above and then unused.
        if args.model_detect_type == "all":
            train_ap = train_eval_metrics[0]['AP']
            test_ap = test_eval_metrics[0]['AP']
            best_test_ap = save_model_and_disp_ap(
                test_ap, best_test_ap, model, args.model_root, epoch
            )
            print("epoch: {}, train_ap={}, test_ap={}".format(epoch, train_ap, test_ap))
            if args.visdom:
                visualize(vis, epoch+1, train_ap, win_train_acc)
                visualize(vis, epoch+1, test_ap, win_test_acc)
        elif args.model_detect_type == "det2cls":
            # index 0 -> "aquatic" class, index 1 -> "other" class
            train_aquatic_ap, train_other_ap = (
                train_eval_metrics[0]['AP'],
                train_eval_metrics[1]['AP'],
            )
            test_aquatic_ap, test_other_ap = (
                test_eval_metrics[0]['AP'],
                test_eval_metrics[1]['AP'],
            )
            train_map = (train_aquatic_ap + train_other_ap) / 2
            test_map = (test_aquatic_ap + test_other_ap) / 2
            # the best model is selected on the aquatic AP, not on the mean AP
            best_test_ap = save_model_and_disp_ap(
                test_aquatic_ap, best_test_ap, model, args.model_root, epoch
            )
            print("epoch: {}".format(epoch))
            print("train: aquatic_ap={}, other_ap={}".format(train_aquatic_ap, train_other_ap))
            print("test: aquatic_ap={}, other_ap={}".format(test_aquatic_ap, test_other_ap))
            print("mean_train_ap={}, mean_test_ap={}".format(train_map, test_map))
            if args.visdom:
                visualize(vis, epoch+1, train_aquatic_ap, win_train_aquatic_acc)
                visualize(vis, epoch+1, train_other_ap, win_train_other_acc)
                visualize(vis, epoch+1, test_aquatic_ap, win_test_aquatic_acc)
                visualize(vis, epoch+1, test_other_ap, win_test_other_acc)

    # periodic checkpoint
    if epoch > 0 and epoch % args.save_interval == 0:
        print('Saving state, epoch: ' + str(epoch))
        checkpoint_path = pj(args.model_root, "epoch{}.pth".format(epoch))
        torch.save(model.state_dict(), checkpoint_path)