In [None]:
import os
# Expose only GPU 0 to this process. This has to happen before torch initializes CUDA,
# which is why it sits in its own cell ahead of the `import torch` cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import cv2
import numpy as np
from numpy.random import *
from os import listdir as ld
from os.path import join as pj
import sys
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.utils.data as data
import visdom

# Logger
from IO.logger import Logger
# Loader
from IO.loader import load_path, load_images, load_annotations_path, load_annotations, get_anno_recs
# Dataset
from dataset.detection.dataset import insects_dataset_from_voc_style_txt, collate_fn
# Loss Function
from model.refinedet.loss.multiboxloss import RefineDetMultiBoxLoss
# Model initializer
from model.refinedet.utils.initializer import initialize_model
# Predict
from model.refinedet.utils.predict import test_prediction
# Evaluate
from evaluation.detection.evaluate import evaluate

# Train Config

In [None]:
class args:
    """Hyper-parameters and paths for this run, read as class attributes (e.g. ``args.lr``)."""
    # experiment name (also the checkpoint sub-directory, see model_root)
    experiment_name = "crop_b2_2_4_8_16_32_im512_CSL_param2"
    # paths -- all derived from project_root; set INSECT_PHENOLOGY_ROOT to relocate the project
    # without editing this cell (the default reproduces the previously hard-coded absolute paths)
    project_root = os.environ.get("INSECT_PHENOLOGY_ROOT", "/home/tanida/workspace/Insect_Phenology_Detector")
    data_root = pj(project_root, "data")
    train_image_root = pj(data_root, "train_refined_images")
    train_target_root = pj(data_root, "train_detection_data", "refinedet_all")
    test_image_root = pj(data_root, "test_refined_images")
    basenet = pj(project_root, "output_model", "detection", "RefineDet", "weights", "vgg16_reducedfc.pth")
    model_root = pj(project_root, "output_model", "detection", "RefineDet", experiment_name)
    test_anno_folders = ["annotations_4"]
    # training config
    input_size = 512  # choices=[320, 512, 1024]
    crop_num = (5, 5)  # crop grid handed to the datasets and to test_prediction
    batch_size = 2  # number of crops per optimizer step (crops of one image are re-batched)
    num_workers = 2  # intended DataLoader worker count
    lr = 1e-4
    lamda = 1e-4  # weight of the explicit parameter L2 penalty; spelling kept because it is read as args.lamda
    tcb_layer_num = 5  # together with rm_last, selects the RefineDet config module
    rm_last = True
    max_epoch = 100
    valid_interval = 2  # validate every N epochs
    save_interval = 20  # save a checkpoint every N epochs
    pretrain = False
    use_CSL = True
    CSL_weight = [0.8, 1.2]
    # visualization
    visdom = True
    visdom_port = 8097

# Model Config

In [None]:
# Pick the RefineDet configuration matching the number of TCB layers and rm_last.
# Only the combinations below have a config module; anything else fails here with a
# clear message instead of a later NameError on `insect_refinedet`.
if args.tcb_layer_num == 4 and args.rm_last == False:
    from model.refinedet.config import tcb_4_rm_false as insect_refinedet
elif args.tcb_layer_num == 4 and args.rm_last == True:
    from model.refinedet.config import tcb_4_rm_true as insect_refinedet
elif args.tcb_layer_num == 5 and args.rm_last == False:
    from model.refinedet.config import tcb_5_rm_false as insect_refinedet
elif args.tcb_layer_num == 5 and args.rm_last == True:
    from model.refinedet.config import tcb_5_rm_true as insect_refinedet
elif args.tcb_layer_num == 6 and args.rm_last == False:
    from model.refinedet.config import tcb_6_rm_false as insect_refinedet
else:
    raise ValueError(
        "no RefineDet config for tcb_layer_num={} with rm_last={}".format(args.tcb_layer_num, args.rm_last)
    )

# Set CUDA default tensor type

In [None]:
# New float tensors default to the GPU when one is available, otherwise to the CPU.
_default_tensor_type = 'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
torch.set_default_tensor_type(_default_tensor_type)

# Set Visdom

In [None]:
def _create_line_window(viz, title, ylabel='loss'):
    """Create an 800x400 visdom line plot seeded with the point (0, 0).

    Args:
        viz: the ``visdom.Visdom`` client to draw on.
        title: plot title.
        ylabel: y-axis label; the x-axis is always labelled 'epoch'.

    Returns:
        The window handle returned by ``viz.line``, used later to append points.
    """
    return viz.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel=ylabel,
            width=800,
            height=400
        )
    )


if args.visdom:
    # Create visdom
    vis = visdom.Visdom(port=args.visdom_port)

    # training-loss curves (train_per_epoch appends one point per epoch)
    win_arm_loc = _create_line_window(vis, 'arm_loc_loss')
    win_arm_conf = _create_line_window(vis, 'arm_conf_loss')
    win_odm_loc = _create_line_window(vis, 'odm_loc_loss')
    win_odm_conf = _create_line_window(vis, 'odm_conf_loss')
    win_norm_loss = _create_line_window(vis, 'normalization_loss')
    win_all_loss = _create_line_window(vis, 'train_loss')
    # validation curve
    win_valid_acc = _create_line_window(vis, 'validation_accuracy', ylabel='average precision')

In [None]:
def visualize(phase, visualized_data, window):
    """Append one ``(phase, visualized_data)`` point to an existing visdom line plot.

    Args:
        phase: x coordinate (the callers pass an epoch number).
        visualized_data: y coordinate to plot.
        window: handle of the target window, as returned by ``vis.line``.

    Uses the module-level visdom client ``vis``.
    """
    x_point = np.array([phase])
    y_point = np.array([visualized_data])
    vis.line(X=x_point, Y=y_point, win=window, update='append')

# Model

In [None]:
# Choose the network builder: the "rmlast" variant only when args.rm_last == True.
if args.rm_last != True:
    from model.refinedet.refinedet import build_refinedet
else:
    from model.refinedet.refinedet_rmlast import build_refinedet

# Train and Test

In [None]:
# ARM / ODM losses (RefineDet's two heads) share every setting except use_ARM.
# NOTE(review): the positional values presumably map to (num_classes, overlap_thresh,
# prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu)
# as in RefineDet.pytorch -- confirm against model/refinedet/loss/multiboxloss.py.
_multibox_positional = (2, 0.5, True, 0, True, 3, 0.5, False, True)
_csl_options = dict(use_CSL=args.use_CSL, CSL_weight=args.CSL_weight)
arm_criterion = RefineDetMultiBoxLoss(*_multibox_positional, **_csl_options)
odm_criterion = RefineDetMultiBoxLoss(*_multibox_positional, use_ARM=True, **_csl_options)
# Mean squared error, used in train_per_epoch as the explicit parameter L2 penalty.
# 'mean' replaces the deprecated 'elementwise_mean' spelling (same reduction, no warning);
# MSELoss holds no parameters or buffers, so the former .cuda() call was a no-op.
l2_loss = nn.MSELoss(reduction='mean')

In [None]:
def _drop_crops_without_targets(imgs, tars):
    """Keep only crops whose target has at least one row (i.e. at least one box).

    Args:
        imgs: array of crops indexed along axis 0.
        tars: per-crop target tensors aligned with ``imgs``.

    Returns:
        tuple: (kept crops as an ``np.ndarray``, list of their targets).
    """
    kept_imgs = []
    kept_tars = []
    for i in range(imgs.shape[0]):
        if tars[i].size(0) > 0:
            kept_imgs.append(imgs[i])
            kept_tars.append(tars[i])
    return np.asarray(kept_imgs), kept_tars


def _parameter_norm_loss(model):
    """Return ``args.lamda`` times the sum over all parameters of ``l2_loss(param, 0)``."""
    norm_loss = 0
    for param in model.parameters():
        # zeros_like matches the parameter's device and dtype
        norm_loss += l2_loss(param, torch.zeros_like(param))
    return norm_loss * args.lamda


def train_per_epoch(model, data_loader, optimizer, epoch):
    """Train ``model`` for one epoch over cropped images and report the losses.

    Each item from ``data_loader`` is unpacked as ``(images, targets, _, _, _)``;
    ``images[0]`` / ``targets[0]`` hold the crops of one sample and their targets.
    Crops without targets are dropped; the rest are split into mini-batches of
    ``args.batch_size`` crops, visited in random order, with one optimizer step per
    mini-batch. Steps whose total loss is NaN are skipped and counted.

    Args:
        model: the RefineDet network (switched to train mode here).
        data_loader: iterable yielding the 5-tuples described above.
        optimizer: optimizer over ``model``'s parameters.
        epoch: 0-based epoch index; plot points are appended at ``epoch + 1``.

    Side effects:
        Prints the per-component losses summed over the epoch and, when
        ``args.visdom`` is set, appends them to the visdom loss windows.
    """
    # set refinedet to train mode
    model.train()
    model.phase = "train"

    # loss counters: sums over every successful step of this epoch
    arm_loc_loss = 0
    arm_conf_loss = 0
    odm_loc_loss = 0
    odm_conf_loss = 0
    all_norm_loss = 0
    skipped_steps = 0

    for images, targets, _, _, _ in tqdm(data_loader, leave=False):
        crop_imgs, crop_tars = _drop_crops_without_targets(np.asarray(images[0]), targets[0])

        # ceil(num_crops / batch_size) mini-batches; the last one may be smaller
        batch_num = (crop_imgs.shape[0] + args.batch_size - 1) // args.batch_size

        # random visiting order (same draw as choice(range(batch_num), batch_num, replace=False))
        for b in np.random.permutation(batch_num):
            start = b * args.batch_size
            stop = start + args.batch_size
            batch_images = torch.from_numpy(crop_imgs[start:stop]).cuda()
            batch_targets = [ann.cuda() for ann in crop_tars[start:stop]]

            # forward
            out = model(batch_images)

            # calculate loss
            optimizer.zero_grad()
            arm_loss_l, arm_loss_c = arm_criterion(out, batch_targets)
            odm_loss_l, odm_loss_c = odm_criterion(out, batch_targets)
            loss = (arm_loss_l + arm_loss_c) + (odm_loss_l + odm_loss_c)

            if args.lamda != 0:
                norm_loss = _parameter_norm_loss(model)
                loss = loss + norm_loss
            else:
                # plain 0 has no .item(); float() below handles both cases
                norm_loss = 0

            if torch.isnan(loss):
                skipped_steps += 1
                continue

            loss.backward()
            optimizer.step()
            arm_loc_loss += arm_loss_l.item()
            arm_conf_loss += arm_loss_c.item()
            odm_loc_loss += odm_loss_l.item()
            odm_conf_loss += odm_loss_c.item()
            all_norm_loss += float(norm_loss)

    print('epoch ' + str(epoch) + ' || ARM_L Loss: %.4f ARM_C Loss: %.4f ODM_L Loss: %.4f ODM_C Loss: %.4f NORM Loss: %.4f ||'
          % (arm_loc_loss, arm_conf_loss, odm_loc_loss, odm_conf_loss, all_norm_loss))
    if skipped_steps:
        print('epoch {} || skipped {} step(s) with NaN loss'.format(epoch, skipped_steps))

    # visualize
    if args.visdom:
        total_loss = arm_loc_loss + arm_conf_loss + odm_loc_loss + odm_conf_loss + all_norm_loss
        visualize(epoch+1, arm_loc_loss, win_arm_loc)
        visualize(epoch+1, arm_conf_loss, win_arm_conf)
        visualize(epoch+1, odm_loc_loss, win_odm_loc)
        visualize(epoch+1, odm_conf_loss, win_odm_conf)
        visualize(epoch+1, all_norm_loss, win_norm_loss)
        visualize(epoch+1, total_loss, win_all_loss)

        
def validate(model, data_loader, recs, crop_num, nms_thresh=0.3, ovthresh=0.3):
    """Run detection over ``data_loader`` and return the average precision.

    Args:
        model: the RefineDet network; switched to eval mode here.
        data_loader: loader consumed by ``test_prediction``.
        recs: ground-truth records (the ones built from ``get_anno_recs``).
        crop_num: crop grid forwarded to ``test_prediction``.
        nms_thresh: NMS threshold forwarded to ``test_prediction``.
        ovthresh: overlap threshold forwarded to ``evaluate`` (previously
            hard-coded to 0.3, which remains the default).

    Returns:
        The ``avg_precision`` value reported by ``evaluate``.

    NOTE(review): ``model.phase`` is left as set by the caller ("train" after
    train_per_epoch) unless ``test_prediction`` changes it -- confirm.
    """
    # set RefineDet to eval mode
    model.eval()
    # predict and evaluate
    result = test_prediction(model, data_loader, crop_num, nms_thresh=nms_thresh)
    _, _, avg_precision, _ = evaluate(result, recs, ovthresh=ovthresh)
    return avg_precision

### Save args

In [None]:
# Persist the run configuration held in `args`.
# NOTE(review): where and in what format it is written is decided by IO.logger.Logger (not visible here).
args_logger = Logger(args)
args_logger.save()

### Get model config

In [None]:
# Look up the layer configuration for the chosen input size (keys are the size as a string).
# NOTE(review): `cfg` is not referenced by any later visible cell -- build_refinedet receives the
# whole `insect_refinedet` mapping instead; confirm whether this lookup is still needed.
cfg = insect_refinedet[str(args.input_size)]

### Make data

In [None]:
print('Loading dataset for train ...')
train_dataset = insects_dataset_from_voc_style_txt(args.train_image_root, args.input_size, args.crop_num, "RefineDet", training=True, target_root=args.train_target_root)
# Loader batch size is 1: each item is one sample whose crops train_per_epoch re-batches itself.
# Worker count now comes from args.num_workers (it was previously hard-coded to 1).
train_data_loader = data.DataLoader(train_dataset, 1, num_workers=args.num_workers, shuffle=True, collate_fn=collate_fn)
print('Loading dataset for test ...')
test_dataset = insects_dataset_from_voc_style_txt(args.test_image_root, args.input_size, args.crop_num, "RefineDet", training=False)
test_data_loader = data.DataLoader(test_dataset, 1, num_workers=args.num_workers, shuffle=False, collate_fn=collate_fn)
print('Loading annotation for test...')
# ground-truth records used by validate() to score predictions
test_annos, test_imgs = load_path(args.data_root, "refined_images", args.test_anno_folders)
test_images = load_images(test_imgs)
test_annotations_path = load_annotations_path(test_annos, test_images)
test_anno = load_annotations(test_annotations_path)
test_imagenames, test_recs = get_anno_recs(test_anno)

### Make model

In [None]:
# Build the network in the 'train' phase and move it to the GPU.
model = build_refinedet('train', insect_refinedet, args.input_size, args.tcb_layer_num)
model = model.cuda()
# Initialize weights; args.basenet / args.pretrain control the source (see initialize_model).
initialize_model(model, args.basenet, pretrain=args.pretrain)
print(model)
# Adam without weight_decay: regularization is the explicit args.lamda L2 term in train_per_epoch.
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Train

In [None]:
# Make sure the checkpoint directory exists so torch.save below cannot fail on a missing path.
os.makedirs(args.model_root, exist_ok=True)

for epoch in range(args.max_epoch):
    train_per_epoch(model, train_data_loader, optimizer, epoch)

    # validate model every valid_interval epochs (epoch 0 is skipped)
    if epoch != 0 and epoch % args.valid_interval == 0:
        average_precision = validate(model, test_data_loader, test_recs, args.crop_num)
        print("epoch: {}, ap={}".format(epoch, average_precision))
        if args.visdom:
            visualize(epoch + 1, average_precision, win_valid_acc)

    # save a checkpoint every save_interval epochs (epoch 0 is skipped)
    if epoch != 0 and epoch % args.save_interval == 0:
        print('Saving state, epoch: ' + str(epoch))
        torch.save(model.state_dict(), pj(args.model_root, 'RefineDet{}_{}.pth'.format(args.input_size, epoch)))

# final save model
print('Saving state, final')
torch.save(model.state_dict(), pj(args.model_root, 'RefineDet{}_final.pth'.format(args.input_size)))