In [None]:
import os

# Restrict this process to GPU 0. This is done in the very first cell so the
# variable is already set when torch is imported in the following cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
import copy
import numpy as np
from numpy.random import *
from os import listdir as ld
from os.path import join as pj
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import visdom
from PIL import Image

# IO
from IO.utils import refine_result_by_ovthresh, output_formatter, write_output_xml
from IO.loader import load_path, load_images
from IO.build_ds import build_classification_ds_from_result
# utils
from utils.crop import crop_adjusted_std, crop_adjusted_std_resize
# Dataset
from dataset.detection.dataset import insects_dataset_from_voc_style_txt, collate_fn
# det model
from model.refinedet.refinedet import RefineDet
# cls model
from model.resnet.resnet import ResNet
from model.resnet.predict import test_classification
# Predict
from model.refinedet.utils.predict import test_prediction
# Evaluate
from evaluation.detection.evaluate import Voc_Evaluater, visualize_mean_index
from evaluation.det2cls.visualize import vis_detections

# Test Config

In [None]:
class args:
    """Static configuration namespace for this test notebook.

    Attributes are read directly off the class (e.g. ``args.input_size``);
    the visible cells never instantiate it. ``det_labels`` and ``cls_labels``
    are attached to the class later, from the two ``*_divide_flag`` values.
    """
    # paths -- every location is derived from one project root, so moving the
    # notebook to another machine only requires editing ``project_root``
    project_root = "/home/tanida/workspace/Insect_Phenology_Detector"
    data_root = pj(project_root, "data")
    test_image_root = pj(data_root, "ooe_pict")
    det_model_root = pj(project_root, "output_model", "detection", "RefineDet", "crop_b2_2_4_8_16_32_im512_other")
    # NOTE(review): the directory is named "ResNet101" while cls_model_name below
    # is "resnet50" -- confirm this is the intended checkpoint.
    cls_model_root = pj(project_root, "output_model", "classification", "ResNet101", "resnet50_b20_r45_lr1e-5_crossvalid")
    # det model config (forwarded to RefineDet / the detection dataset)
    input_size = 512 # choices=[320, 512, 1024]
    crop_num = (5, 5)
    tcb_layer_num = 5
    use_extra_layer = False
    det_activation_function = "ReLU"
    use_GN_WS = False
    # cls model config (forwarded to ResNet)
    cls_model_name = "resnet50"
    cls_activation_function = "ReLU"
    decoder = None
    # test config: select which label sets are attached as det_labels / cls_labels
    det_divide_flag = True
    cls_divide_flag = False

In [None]:
# Class labels: pick the detection and classification label sets from the
# two divide flags and attach them to the config namespace.
args.det_labels = (
    ['Aquatic_insects', 'Other_insects'] if args.det_divide_flag is True else ['insects']
)

args.cls_labels = ['Diptera', 'Ephemeridae', 'Ephemeroptera',
                   'Lepidoptera', 'Plecoptera', 'Trichoptera']
if args.cls_divide_flag is True:
    # the divided classifier has one additional catch-all class
    args.cls_labels.append('Other')

# Set the default tensor type (CUDA when available)

In [None]:
# New tensors default to float32 on the GPU when CUDA is usable, otherwise
# to float32 on the CPU.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
)

# Detection Model

In [None]:
# Build the detector. The class count passed in is the number of detection
# labels plus one (NOTE(review): the extra class is presumably background --
# confirm against RefineDet).
det_model = RefineDet(
    args.input_size,
    len(args.det_labels) + 1,
    args.tcb_layer_num,
    activation_function=args.det_activation_function,
    use_extra_layer=args.use_extra_layer,
    use_GN_WS=args.use_GN_WS,
)
load_name = pj(args.det_model_root, 'RefineDet{}_{}.pth'.format(args.input_size, "final"))
# Deserialize the checkpoint onto the CPU first. Without map_location, torch.load
# tries to restore every tensor to the device it was saved from, which fails when
# that GPU index is not visible here (only one device is exposed) or when no GPU
# is present. load_state_dict then copies the values into the model's parameters
# on whatever device they already live on.
det_model.load_state_dict(torch.load(load_name, map_location="cpu"))

# Classification Model

In [None]:
# Build the classifier with one output per classification label.
cls_model = ResNet(
    args.cls_model_name,
    len(args.cls_labels),
    activation_function=args.cls_activation_function,
    decoder=args.decoder,
)
# Move to the GPU only when one is usable, so the notebook also runs on a
# CPU-only host (matching the CUDA check used for the default tensor type).
if torch.cuda.is_available():
    cls_model = cls_model.cuda()
load_name = pj(args.cls_model_root, "final.pth")
# Deserialize onto the CPU first so the checkpoint loads regardless of which
# device it was saved from; load_state_dict copies the values onto the model's
# own device.
cls_model.load_state_dict(torch.load(load_name, map_location="cpu"))

### Make the test dataset and data loader

In [None]:
print('Loading dataset for test ...')
# Detection dataset over the test image directory, built in non-training mode.
test_dataset = insects_dataset_from_voc_style_txt(
    args.test_image_root,
    args.input_size,
    args.crop_num,
    "RefineDet",
    training=False,
)
# One dataset item per batch, in dataset order, loaded by a single worker.
test_data_loader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

# --- detection result ---

In [None]:
def get_det_result(det_model, data_loader, crop_num, num_classes, nms_thresh=0.3, ovthresh=0.3):
    """Run detection over ``data_loader`` and post-filter the result.

    Args:
        det_model: detection model forwarded to ``test_prediction``.
        data_loader: loader forwarded to ``test_prediction``.
        crop_num: crop setting forwarded to ``test_prediction``.
        num_classes: class count forwarded to ``test_prediction``.
        nms_thresh: threshold forwarded to ``test_prediction``
            (presumably the NMS overlap threshold -- confirm in that helper).
        ovthresh: threshold forwarded to ``refine_result_by_ovthresh``.

    Returns:
        Whatever ``refine_result_by_ovthresh`` returns for the raw prediction.
    """
    raw_result = test_prediction(det_model, data_loader, crop_num, num_classes, nms_thresh)
    return refine_result_by_ovthresh(raw_result, ovthresh)

In [None]:
# Detect on the whole test set with the default thresholds; the class count
# is the number of detection labels plus one.
det_result = get_det_result(
    det_model=det_model,
    data_loader=test_data_loader,
    crop_num=args.crop_num,
    num_classes=len(args.det_labels) + 1,
)

# --- output labelImg XML ---

In [None]:
def get_det2cls_result(cls_model, insect_dataset, det_result, det_divide_flag=False, cls_labels=None, bs=2):
    """Regroup detected boxes by the label the classifier assigns to them.

    Args:
        cls_model: classification model handed to ``test_classification``.
        insect_dataset: mapping ``image_id -> numpy array`` of images to
            classify (assumed to hold one image per box of
            ``det_result[image_id][0]``, in the same order -- TODO confirm
            against ``build_classification_ds_from_result``).
        det_result: mapping ``image_id -> {detection label index: boxes}``.
            The per-image dicts are shallow-copied, not modified.
        det_divide_flag: when true, the boxes stored under detection index 1
            are kept and re-filed under index ``len(cls_labels)``.
        cls_labels: classification label names. Defaults to
            ``args.cls_labels`` so existing callers behave as before.
        bs: batch size forwarded to ``test_classification``.

    Returns:
        dict mapping ``image_id -> {label index: boxes}`` where indices
        ``0 .. len(cls_labels) - 1`` hold the boxes of ``det_result[image_id][0]``
        that were predicted as that label.
    """
    if cls_labels is None:
        cls_labels = args.cls_labels
    num_cls_labels = len(cls_labels)
    # Only move inputs to the GPU when one is usable, so the function also
    # works on a CPU-only host.
    use_cuda = torch.cuda.is_available()

    det2cls_result = {}
    for image_id, imgs in insect_dataset.items():
        print("classify images: {}".format(image_id))
        # Shallow copy: keys are replaced below without touching det_result.
        det2cls_result_per_image = copy.copy(det_result[image_id])
        if det_divide_flag:
            # Index 1 is about to be overwritten by classification label 1,
            # so park the original entry just past the classification labels.
            det2cls_result_per_image[num_cls_labels] = det_result[image_id][1]

        imgs = torch.from_numpy(imgs)
        if use_cuda:
            imgs = imgs.cuda()
        lbls = test_classification(cls_model, imgs, bs=bs)
        coords = det_result[image_id][0]
        for lbl in range(num_cls_labels):
            # boolean mask selects the boxes predicted as this label
            det2cls_result_per_image[lbl] = coords[lbls == lbl]
        det2cls_result[image_id] = det2cls_result_per_image
    return det2cls_result

In [None]:
# Index -> insect name lookup, passed to output_formatter when the XML is written.
insect_names = [
    'Diptera', 'Ephemeridae', 'Ephemeroptera',
    'Lepidoptera', 'Plecoptera', 'Trichoptera', 'Other',
]
label_map = dict(enumerate(insect_names))
label_map

In [None]:
# os.listdir returns entries in arbitrary order; sort so the images are
# processed in the same order on every run.
image_paths = sorted(
    pj(args.test_image_root, image_name) for image_name in ld(args.test_image_root)
)

In [None]:
# Directory passed to write_output_xml; defined once instead of inside the loop.
output_xml_root = "/home/tanida/workspace/Insect_Phenology_Detector/output_xml/ooe_pict_20200806"

for image_path in image_paths:
    # id = file name up to its first dot, used as the key into det_result
    file_id = image_path.split("/")[-1].split(".")[0]
    if file_id not in det_result:
        # Directory entries with no detection result (e.g. non-image files)
        # would otherwise abort the whole export with a KeyError.
        print("skip (no detection result): {}".format(image_path))
        continue
    # Only a 2-D box array is processed; anything else is skipped
    # (NOTE(review): presumably the "no detections" case -- confirm).
    if len(det_result[file_id][0].shape) != 2:
        continue
    image = load_images([image_path])
    sample_det_result = {file_id: det_result[file_id]}
    insect_dataset = build_classification_ds_from_result(image, sample_det_result)
    det2cls_result = get_det2cls_result(cls_model, insect_dataset, sample_det_result, det_divide_flag=args.det_divide_flag)
    output = output_formatter(det2cls_result, label_map)
    write_output_xml(output, output_xml_root, add_flag=True)