In [None]:
import os
# Expose only GPU index 0 to CUDA. This cell runs before the cell that
# imports torch, so the restriction is in place when torch is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
from os import listdir as ld
from os.path import join as pj
import numpy as np
from numpy.random import *
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from tqdm import tqdm
import visdom

# Logger
from IO.logger import Logger
# Loader
from IO.loader import load_path, load_images, load_annotations_path, load_annotations, get_anno_recs
# Dataset
from dataset.detection.dataset import insects_dataset_from_voc_style_txt, collate_fn
# model
from model.faster_rcnn.faster_rcnn import make_Faster_RCNN
# Predict
from model.faster_rcnn.predict import test_prediction
# Evaluate
from evaluation.detection.evaluate import evaluate

# Train Config

In [None]:
class args:
    # Static experiment configuration. The class is never instantiated in this
    # notebook; its attributes are read directly (args.lr, args.batch_size, ...)
    # and the class object itself is handed to Logger.

    # experiment name (also the last component of model_root)
    experiment_name = "crop_b2_2_4_8_16_32_aaaaa"

    # paths
    # All dataset locations are derived from data_root so that pointing the
    # notebook at another copy of the data only requires editing one line.
    data_root = "/home/tanida/workspace/Insect_Phenology_Detector/data"
    train_image_root = pj(data_root, "train_refined_images")
    train_target_root = pj(data_root, "train_detection_data", "refinedet_all")
    test_image_root = pj(data_root, "test_refined_images")
    # directory that receives the .pth checkpoints written by the training loop
    model_root = pj("/home/tanida/workspace/Insect_Phenology_Detector/output_model/detection/Faster_RCNN", experiment_name)
    # annotation folders passed to load_path when building the test records
    test_anno_folders = ["annotations_4"]

    # train config
    b_bone = "vgg16"  # backbone name forwarded to make_Faster_RCNN
    anchor_size = ((2, 4, 8, 16, 32),)  # forwarded to make_Faster_RCNN
    aspect_ratio = ((0.5, 1.0, 2.0),)  # forwarded to make_Faster_RCNN
    input_size = 512 # choices=[320, 512, 1024]
    crop_num = (5, 5)  # forwarded to the dataset, test_prediction and validate
    lr = 1e-4  # Adam learning rate
    lamda = 1e-2  # weight of the parameter-norm penalty in train_per_epoch (0 disables it)
    batch_size = 2  # number of crops per optimisation step in train_per_epoch
    # NOTE(review): num_workers is not referenced by the DataLoader construction
    # in this notebook (the loaders hard-code num_workers=1) -- confirm intent.
    num_workers = 2
    max_epoch = 100  # number of iterations of the outer training loop
    valid_interval = 2  # validate every N epochs (epoch 0 is skipped)
    save_interval = 20  # write a checkpoint every N epochs (epoch 0 is skipped)
    max_insect_per_image = 20  # forwarded to make_Faster_RCNN
    pretrain = True  # forwarded to make_Faster_RCNN

    # visualization
    visdom = True  # when False no visdom windows are created or updated
    port = 8097  # port given to visdom.Visdom

    # class label (len(labels) is used as the number of classes)
    labels = ['__background__','insects']

# Set Visdom

In [None]:
def _make_line_window(board, title, ylabel='loss'):
    """Create one visdom line plot seeded with the single point (0, 0).

    Args:
        board: object exposing a visdom-style ``line`` method.
        title: plot title.
        ylabel: y-axis label; the x-axis is always labelled 'epoch'.

    Returns:
        Whatever ``board.line`` returns; the notebook later passes this value
        back as ``win=`` when appending points.
    """
    return board.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts=dict(
            title=title,
            xlabel='epoch',
            ylabel=ylabel,
            width=800,
            height=400
        )
    )


if args.visdom:
    # Create visdom
    vis = visdom.Visdom(port=args.port)

    # train_loss: one window per loss term, created in the same order as before
    win_cls_loss = _make_line_window(vis, 'cls_loss')
    win_cls_box_loss = _make_line_window(vis, 'cls_box_loss')
    win_obj_loss = _make_line_window(vis, 'obj_loss')
    win_rpn_box_loss = _make_line_window(vis, 'rpn_box_loss')
    win_total_norm_loss = _make_line_window(vis, 'norm_loss')
    win_train_loss = _make_line_window(vis, 'train_loss')

    # validation_accuracy
    win_valid_acc = _make_line_window(vis, 'validation_accuracy', ylabel='average precision')

In [None]:
def visualize(phase, visualized_data, window):
    """Append a single point to an existing visdom line plot.

    Args:
        phase: x-coordinate of the new point (callers pass ``epoch + 1``).
        visualized_data: y-coordinate of the new point.
        window: window handle to update, passed to visdom as ``win``.

    Uses the module-level ``vis`` client, which only exists when
    ``args.visdom`` is true.
    """
    x_point = np.array([phase])
    y_point = np.array([visualized_data])
    vis.line(X=x_point, Y=y_point, win=window, update='append')

# Train and Test

In [None]:
# Mean-reduced squared-error criterion. train_per_epoch evaluates it between
# each model parameter and an all-zero tensor to build the parameter-norm
# penalty that is scaled by args.lamda.
l2_loss = nn.MSELoss(reduction='mean').cuda()

In [None]:
def train_per_epoch(model, data_loader, optimizer, epoch):
    """Run one training pass over ``data_loader`` and report the summed losses.

    Every loader item is unpacked as ``(images, targets, _, _, _)``; only
    ``images[0]`` and ``targets[0]`` are used and they are treated as a
    sequence of crops with one target dict per crop (presumably the crops of a
    single source image -- verify against the dataset's collate_fn). Crops
    whose ``"boxes"`` entry is empty are dropped, the remaining crops are cut
    into mini-batches of ``args.batch_size`` and the mini-batches are visited
    in random order.

    Args:
        model: detection model. In train mode ``model(images, targets)`` must
            return a dict containing at least ``loss_classifier``,
            ``loss_box_reg``, ``loss_objectness`` and ``loss_rpn_box_reg``.
        data_loader: iterable yielding the 5-tuples described above.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: zero-based epoch index; printed as is and plotted as
            ``epoch + 1``.

    Returns:
        None. The per-epoch sums are printed and, when ``args.visdom`` is set,
        appended to the visdom windows.

    Notes:
        A mini-batch whose total loss is NaN or infinite is skipped entirely:
        no backward pass, no optimizer step, and none of its terms (including
        the norm penalty) are added to the running sums.
    """
    cls_loss = 0
    cls_box_loss = 0
    obj_loss = 0
    rpn_box_loss = 0
    total_norm_loss = 0
    total_loss = 0
    # set Faster_RCNN to train mode
    model.train()

    # train
    for images, targets, _, _, _ in tqdm(data_loader, leave=False):
        crop_imgs = np.asarray(images[0])
        crop_tars = np.asarray(targets[0])

        # keep only the crops that carry at least one ground-truth box
        keep = [i for i in range(crop_imgs.shape[0]) if len(crop_tars[i]["boxes"]) > 0]
        if not keep:
            # nothing to learn from this item; skip it explicitly instead of
            # sampling from an empty range
            continue
        kept_imgs = [crop_imgs[i] for i in keep]
        kept_tars = [crop_tars[i] for i in keep]

        # number of mini-batches (ceil division); the last one may be shorter
        batch_num = (len(keep) + args.batch_size - 1) // args.batch_size

        # visit the mini-batches in random order
        for batch_idx in np.random.permutation(batch_num):
            start = int(batch_idx) * args.batch_size
            stop = start + args.batch_size

            # set cuda
            batch_images = [torch.from_numpy(im).cuda() for im in kept_imgs[start:stop]]
            batch_targets = [{k: v.cuda() for k, v in t.items()} for t in kept_tars[start:stop]]

            # forward
            optimizer.zero_grad()
            loss_dict = model(batch_images, batch_targets)

            # sum loss
            losses = sum(loss for loss in loss_dict.values())

            # parameter-norm penalty: mean squared value of every parameter,
            # summed over parameters and scaled by args.lamda
            norm_loss = None
            if args.lamda != 0:
                # zeros_like allocates directly on the parameter's device, so
                # no CPU tensor is built and copied for every parameter
                norm_loss = sum(
                    l2_loss(param, torch.zeros_like(param))
                    for param in model.parameters()
                ) * args.lamda
                losses = losses + norm_loss

            # a non-finite loss would only push NaN/inf gradients into the
            # optimizer state, so drop the whole mini-batch
            if not torch.isfinite(losses):
                continue

            # backward
            losses.backward()
            optimizer.step()

            cls_loss += loss_dict['loss_classifier'].item()
            cls_box_loss += loss_dict['loss_box_reg'].item()
            obj_loss += loss_dict['loss_objectness'].item()
            rpn_box_loss += loss_dict['loss_rpn_box_reg'].item()
            if norm_loss is not None:
                # counted only for batches that were actually applied, so it
                # stays consistent with total_loss
                total_norm_loss += norm_loss.item()
            total_loss += losses.item()

    print("epoch: {0}, cls_l: {1}, cls_box_l: {2}, obj_l: {3}, rpn_box_l: {4}, total_norm_l: {5}, total_l: {6}".format(epoch, cls_loss, cls_box_loss, obj_loss, rpn_box_loss, total_norm_loss, total_loss))
    # visualize
    if args.visdom:
        visualize(epoch + 1, cls_loss, win_cls_loss)
        visualize(epoch + 1, cls_box_loss, win_cls_box_loss)
        visualize(epoch + 1, obj_loss, win_obj_loss)
        visualize(epoch + 1, rpn_box_loss, win_rpn_box_loss)
        visualize(epoch + 1, total_norm_loss, win_total_norm_loss)
        visualize(epoch + 1, total_loss, win_train_loss)

def validate(model, data_loader, recs, input_size, crop_num, nms_thresh=0.3, ovthresh=0.3):
    """Evaluate ``model`` on ``data_loader`` and return the average precision.

    Args:
        model: detection model; switched to eval mode before prediction.
        data_loader: loader handed unchanged to ``test_prediction``.
        recs: ground-truth records handed unchanged to ``evaluate``.
        input_size: forwarded to ``test_prediction``.
        crop_num: forwarded to ``test_prediction``.
        nms_thresh: forwarded to ``test_prediction`` as ``nms_thresh``.
        ovthresh: forwarded to ``evaluate`` as ``ovthresh``. Defaults to 0.3,
            the value that was previously hard-coded.

    Returns:
        The third element of the tuple returned by ``evaluate`` (named
        average precision by this notebook).
    """
    # set Faster_RCNN to eval mode
    model.eval()
    # predict and evaluate
    result = test_prediction(model, data_loader, input_size, crop_num, nms_thresh=nms_thresh)
    # evaluate returns (recall, precision, avg_precision, gt_dict); only the
    # average precision is needed here
    _, _, avg_precision, _ = evaluate(result, recs, ovthresh=ovthresh)

    return avg_precision

### Save args

In [None]:
# Hand the configuration class to the project's Logger (IO.logger) and call
# save(). Presumably this persists the experiment settings next to the model
# output -- verify in IO/logger.py.
args_logger = Logger(args)
args_logger.save()

### Make data

In [None]:
# --- training set -----------------------------------------------------------
print('Loading dataset for train ...')
train_dataset = insects_dataset_from_voc_style_txt(
    args.train_image_root,
    args.input_size,
    args.crop_num,
    "Faster_RCNN",
    training=True,
    target_root=args.train_target_root,
)
# The loader yields one dataset item at a time; args.batch_size is applied
# later, to the crops of each item, inside train_per_epoch.
train_data_loader = data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=collate_fn,
)

# --- test set ---------------------------------------------------------------
print('Loading dataset for test ...')
test_dataset = insects_dataset_from_voc_style_txt(
    args.test_image_root,
    args.input_size,
    args.crop_num,
    "Faster_RCNN",
    training=False,
)
test_data_loader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

# --- ground-truth records used by validate() ---------------------------------
print('Loading annotation for test...')
test_annos, test_imgs = load_path(
    args.data_root, "refined_images", args.test_anno_folders
)
test_images = load_images(test_imgs)
test_annotations_path = load_annotations_path(test_annos, test_images)
test_anno = load_annotations(test_annotations_path)
test_imagenames, test_recs = get_anno_recs(test_anno)

### Make model

In [None]:
# Build the detector from the configuration and move it to the GPU.
# len(args.labels) is the class count, background label included.
model = make_Faster_RCNN(
    len(args.labels),
    args.input_size,
    args.anchor_size,
    args.aspect_ratio,
    args.b_bone,
    max_insect_per_image=args.max_insect_per_image,
    pretrain=args.pretrain,
).cuda()
print(model)

# Adam over every model parameter, using the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

# Train

In [None]:
# Create the checkpoint directory up front so that the first torch.save,
# which only happens after several epochs, cannot fail on a missing folder.
os.makedirs(args.model_root, exist_ok=True)

for epoch in range(args.max_epoch):
    train_per_epoch(model, train_data_loader, optimizer, epoch)

    # validate model (every args.valid_interval epochs, never at epoch 0)
    if epoch != 0 and epoch % args.valid_interval == 0:
        average_precision = validate(model, test_data_loader, test_recs, args.input_size, args.crop_num)
        print("epoch: {}, ap={}".format(epoch, average_precision))
        if args.visdom:
            visualize(epoch+1, average_precision, win_valid_acc)

    # save model (every args.save_interval epochs, never at epoch 0)
    if epoch != 0 and epoch % args.save_interval == 0:
        print('Saving state, epoch: ' + str(epoch))
        torch.save(
            model.state_dict(),
            pj(args.model_root, 'Faster_RCNN{}_{}.pth'.format(args.input_size, epoch)),
        )

# final weights are always written, whatever the save interval
print('Saving state, final')
torch.save(
    model.state_dict(),
    pj(args.model_root, 'Faster_RCNN{}_final.pth'.format(args.input_size)),
)