In [None]:
import os
# Expose only GPU 0 to CUDA; this runs before torch is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import copy
import numpy as np
from os.path import join as pj
import pandas as pd
from PIL import Image
import torch
import torch.utils.data as data

# IO
from IO.utils import refine_result_by_ovthresh, output_formatter, write_output_xml
from IO.loader import load_path, load_images
from IO.build_ds import build_classification_ds_from_result
# Dataset
from dataset.detection.dataset import insects_dataset_from_voc_style_txt, collate_fn
from dataset.classification.dataset import insects_dataset
# det model
from model.refinedet.refinedet import RefineDet
from model.refinedet.utils.predict import test_prediction
# cls model
from model.resnet.resnet import ResNet
from model.resnet.predict import test_classification
# Evaluate
from evaluation.detection.evaluate import Voc_Evaluater, visualize_mean_index
from evaluation.det2cls.visualize import vis_detections

### Test Config

In [None]:
class args:
    """Static configuration for the det2cls test run.

    Used as a plain namespace: attributes are read directly off the class.
    All paths are absolute and built from the working directory at the time
    this cell is executed.
    """
    # ---- test inputs ----
    data_root = pj(os.getcwd(), "data")
    test_image_root = pj(data_root, "refined_images")
    test_target_root = pj(data_root, "test_detection_data/target_with_other_alldata")
    test_anno_folders = ["annotations_0", "annotations_2", "annotations_3", "annotations_4", "annotations_20200806"]

    # ---- trained checkpoints ----
    det_model_root = pj(os.getcwd(), "output_model/detection/RefineDet", "master_paper/crop_b2/tcb5_im512_freeze_All0to2_withOther")
    cls_model_root = pj(os.getcwd(), "output_model/classification", "master_paper/resnet50/b20_lr1e-5/crossvalid_20200806_OS_All5to6withResize")
    size_model_root = pj(os.getcwd(), "output_model/image2size", "ResNet34_b80_lr1e-4_all02")

    # ---- outputs (annotated images live inside the figure directory) ----
    figure_root = pj(os.getcwd(), "figure/det2cls", "master_paper/refinedet_plus_other_resnet_size")
    save_img_root = pj(figure_root, "output_image")
    save_xml_root = pj(os.getcwd(), "output_xml/det2cls", "master_paper/refinedet_plus_other_resnet_size")

    # ---- detection model (RefineDet) ----
    input_size = 512
    crop_num = (5, 5)
    tcb_layer_num = 6
    use_extra_layer = True
    det_activation_function = "ReLU"

    # ---- classification model (ResNet) ----
    cls_model_name = "resnet50"
    cls_activation_function = "ReLU"
    cls_use_dropout = True

    # ---- size-regression model (ResNet) ----
    size_model_name = "resnet34"
    size_activation_function = "ReLU"
    size_use_dropout = True

In [None]:
# The detector separates only two classes; detections are then re-labelled by
# the classifier into the six insect orders listed second.
args.det_labels, args.cls_labels = (
    ["Aquatic Insect", "Other"],
    ["Diptera", "Ephemeridae", "Ephemeroptera", "Lepidoptera", "Plecoptera", "Trichoptera"],
)

In [None]:
# Newly created float tensors (including model parameters) default to the GPU
# when one is available, otherwise to the CPU.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
)

### Detection Model

In [None]:
# RefineDet with len(det_labels)+1 output classes (the extra class is presumably
# background -- confirm against RefineDet).
det_model = RefineDet(args.input_size, len(args.det_labels)+1, args.tcb_layer_num, activation_function=args.det_activation_function, use_extra_layer=args.use_extra_layer, use_GN_WS=False)
load_name = pj(args.det_model_root, 'best.pth')
# Deserialize onto the CPU first; load_state_dict() then copies the weights onto
# the device of the model's own parameters. Without map_location, a checkpoint
# saved on a GPU index hidden by CUDA_VISIBLE_DEVICES fails to load.
det_model.load_state_dict(torch.load(load_name, map_location="cpu"))

### Classification Model

In [None]:
# ResNet classifier with one output per classifier label, moved to the GPU.
cls_model = ResNet(args.cls_model_name, len(args.cls_labels), use_dropout=args.cls_use_dropout, activation_function=args.cls_activation_function, decoder=None).cuda()
load_name = pj(args.cls_model_root, "valid_3_best.pth")
# Deserialize onto the CPU first; load_state_dict() copies the tensors onto the
# model's (GPU) parameters. This avoids failures when the checkpoint was saved
# on a GPU index hidden by CUDA_VISIBLE_DEVICES.
cls_model.load_state_dict(torch.load(load_name, map_location="cpu"))

### Size Estimation Model

In [None]:
# ResNet regressor with a single output (the estimated size), moved to the GPU.
size_model = ResNet(args.size_model_name, 1, use_dropout=args.size_use_dropout, activation_function=args.size_activation_function, decoder=None).cuda()
load_name = pj(args.size_model_root, "valid_3_best.pth")
# Deserialize onto the CPU first; load_state_dict() copies the tensors onto the
# model's (GPU) parameters. This avoids failures when the checkpoint was saved
# on a GPU index hidden by CUDA_VISIBLE_DEVICES.
size_model.load_state_dict(torch.load(load_name, map_location="cpu"))

### Load test data

In [None]:
# Detection input: the test images, cropped on a crop_num grid by the dataset.
print('Loading dataset for test ...')
test_dataset = insects_dataset_from_voc_style_txt(args.test_image_root, args.input_size, args.crop_num, training=False)
test_data_loader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

# Full-resolution images, used later to cut out the detected regions.
print('Loading images ...')
anno_paths, image_paths = load_path(args.data_root, "refined_images", args.test_anno_folders)
images = load_images(image_paths)

# --- result analysis ---

In [None]:
# exist_ok=True keeps the cell idempotent and avoids the check-then-create race
# of testing os.path.exists() first.
os.makedirs(args.figure_root, exist_ok=True)

In [None]:
def get_det_result(det_model, data_loader, crop_num, num_classes, nms_thresh=0.5, ovthresh=0.3):
    """Run the detector over ``data_loader`` and post-process its boxes.

    Args:
        det_model: detection network, passed to ``test_prediction``.
        data_loader: loader over the (cropped) test images.
        crop_num: crop grid, forwarded to ``test_prediction``.
        num_classes: number of detector output classes.
        nms_thresh: threshold forwarded to ``test_prediction``.
        ovthresh: overlap threshold forwarded to ``refine_result_by_ovthresh``.

    Returns:
        The result of ``refine_result_by_ovthresh`` applied to the raw
        predictions (indexed by image id by the callers in this notebook).
    """
    raw_result = test_prediction(det_model, data_loader, crop_num, num_classes, nms_thresh)
    return refine_result_by_ovthresh(raw_result, ovthresh)

In [None]:
det_result = get_det_result(
    det_model,
    test_data_loader,
    crop_num=args.crop_num,
    num_classes=len(args.det_labels) + 1,  # same class count RefineDet was built with
)

In [None]:
# Build the classifier input: a mapping image_id -> images of the detected
# regions (iterated with .items() in get_det2cls_result). Presumably one crop per
# box in det_result[image_id][0] -- confirm in IO.build_ds.
insect_dataset = build_classification_ds_from_result(images, det_result)

In [None]:
def estimate_size(model, dataloader):
    """image2size estimation function.

    Runs ``model`` in eval mode over every batch from ``dataloader`` and
    collects the raw outputs.

    Args:
        model: torch module mapping an image batch to size predictions.
        dataloader: iterable yielding image tensors (no labels).

    Returns:
        np.ndarray: outputs of all batches stacked along the first axis,
        one row per input sample.
    """
    # Move inputs to the model's device instead of hard-coding .cuda(), so the
    # function also works for CPU models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    size_array = []
    try:
        # Inference only: skip building the autograd graph (saves memory/time).
        with torch.no_grad():
            for image in dataloader:
                if device is not None:
                    image = image.to(device)
                out = model(image)
                size_array.extend(out.cpu().numpy())
    finally:
        # Restore whatever mode the caller had, even if inference raised.
        model.train(was_training)
    return np.array(size_array)

In [None]:
def get_det2cls_result(cls_model, insect_dataset, det_result, size_model=None):
    det2cls_result = {}
    if size_model is not None:
        lbl_array = []
        size_array = []
    
    for image_id, imgs in insect_dataset.items():
        print("classify images: {}".format(image_id))
        det2cls_result_per_image = copy.copy(det_result[image_id])
        det2cls_result_per_image.update({len(args.cls_labels): det_result[image_id][1]})
        
        test_dataset = insects_dataset(imgs, training=False)
        test_dataloader = data.DataLoader(test_dataset, 1, num_workers=1, shuffle=False)
        
        # classification
        lbls = test_classification(cls_model, test_dataloader)
        
        # size estimation
        if size_model is not None:
            sizes = estimate_size(size_model, test_dataloader)
            lbl_array.extend(lbls)
            size_array.extend(sizes)
        
        coords = det_result[image_id][0]
        for lbl in range(len(args.cls_labels)):
            lbl_filter = lbls == lbl
            filtered_coords = coords[lbl_filter]
            det2cls_result_per_image.update({lbl: filtered_coords})
        det2cls_result.update({image_id: det2cls_result_per_image})
    
    if size_model is not None:
        cls_and_size_df = pd.DataFrame({"lbl": np.array(lbl_array), "size": np.array(size_array)})
        return det2cls_result, cls_and_size_df
    else:
        return det2cls_result

In [None]:
# size_model is supplied, so a (result dict, label/size DataFrame) pair comes back.
det2cls_result, cls_and_size_df = get_det2cls_result(
    cls_model=cls_model,
    insect_dataset=insect_dataset,
    det_result=det_result,
    size_model=size_model,
)

In [None]:
# Evaluate the re-labelled detections against the target files in
# test_target_root (VOC-style, per the evaluator's name). figure_root is handed
# to the evaluator, presumably as its output directory -- confirm.
evaluater = Voc_Evaluater(args.test_image_root, args.test_target_root, args.figure_root)
evaluater.set_result(det2cls_result)
eval_metrics = evaluater.get_eval_metrics()

In [None]:
# Plot the evaluation metrics; figures presumably go under args.figure_root.
visualize_mean_index(eval_metrics, figure_root=args.figure_root)

In [None]:
# Persist the per-region (label, size) table next to the figures, then display it.
cls_and_size_df.to_csv(path_or_buf=os.path.join(args.figure_root, "cls_and_size_df.csv"))
cls_and_size_df

### --- Output images with results ---

In [None]:
# exist_ok=True keeps the cell idempotent and avoids the check-then-create race.
os.makedirs(args.save_img_root, exist_ok=True)

In [None]:
def get_imagenames_from_anno_paths(anno_paths):
    """Return the image ids referenced by a list of annotation paths.

    Jupyter ``.ipynb_checkpoints`` entries are skipped.

    Args:
        anno_paths: iterable of annotation file paths.

    Returns:
        list[str]: the basename of each path, truncated at its first '.'.
    """
    imagenames = []
    for anno_path in anno_paths:
        # os.path.basename honours the platform separator; split('/') did not.
        filename = os.path.basename(anno_path)
        if filename == '.ipynb_checkpoints':
            continue
        # Truncate at the first '.', keeping the existing image-id convention.
        imagenames.append(filename.split('.')[0])
    return imagenames

In [None]:
imagenames = get_imagenames_from_anno_paths(anno_paths)
# Drawing color per class index; only the first len(insect_names) entries are used.
colors = ["white", "red", "lime", "blue", "yellow", "fuchsia", "aqua", "gray", "maroon", "green", "navy", "olive", "purple", "teal"]
# Classifier labels followed by the "Other" class, in det2cls_result index order.
# Derived from args so the label list is defined in one place.
insect_names = args.cls_labels + ['Other']

In [None]:
def read_ground_truth(test_target_path, insect_names, height, width):
    """Load ground-truth boxes from a whitespace-separated annotation file.

    Each non-blank line is ``x1 y1 x2 y2 label``. The x values are multiplied by
    ``width`` and the y values by ``height`` (presumably normalized
    coordinates). ``label`` is an integer index into ``insect_names``.

    Args:
        test_target_path: path of the annotation text file.
        insect_names: class names; one output entry is created per index.
        height: image height in pixels.
        width: image width in pixels.

    Returns:
        dict[int, np.ndarray]: label index -> ``(k, 4)`` float array of
        ``[x1, y1, x2, y2]`` boxes. Labels without boxes map to an empty
        ``(0, 4)`` array.

    Raises:
        KeyError: if a line's label is not a valid index into ``insect_names``.
    """
    gt_coord = {lbl: [] for lbl in range(len(insect_names))}

    with open(test_target_path, mode="r") as f:
        for line in f:
            # split() (no argument) tolerates repeated spaces, tabs and '\r'.
            elements = line.split()
            if not elements:
                continue  # skip blank / trailing lines
            target_lbl = int(elements[4])
            x1 = float(elements[0]) * width
            y1 = float(elements[1]) * height
            x2 = float(elements[2]) * width
            y2 = float(elements[3]) * height
            gt_coord[target_lbl].append([x1, y1, x2, y2])

    # reshape(-1, 4) keeps empty classes 2-D, so column slicing still works.
    return {lbl: np.asarray(coords, dtype=float).reshape(-1, 4) for lbl, coords in gt_coord.items()}

In [None]:
def output_img_with_result(det2cls_result, test_image_root, test_target_root, imagenames, insect_names, save_img_root, color_names=None, score_thresh=0.5):
    """Draw predicted and ground-truth boxes on each test image and save it as PNG.

    Args:
        det2cls_result: image_id -> {class index: predicted boxes}.
        test_image_root: directory with ``<image_id>.png`` inputs.
        test_target_root: directory with ``<image_id>.txt`` ground-truth files.
        imagenames: image ids to render.
        insect_names: class names, indexed like ``det2cls_result``.
        save_img_root: output directory; files keep their ``<image_id>.png`` name.
        color_names: color per class index; defaults to the notebook-level ``colors``.
        score_thresh: ``thresh`` passed to ``vis_detections`` for predicted boxes
            (ground-truth boxes are drawn without it).
    """
    if color_names is None:
        color_names = colors
    for imagename in imagenames:
        print("output image: {}".format(imagename+".png"))
        # Context manager closes the file handle. np.array (not asarray) yields
        # a writable copy; arrays backed by a PIL image may be read-only.
        with Image.open(pj(test_image_root, imagename+".png")) as pil_img:
            img = np.array(pil_img)
        # shape[:2] also works for single-channel images.
        height, width = img.shape[:2]
        gt_coord = read_ground_truth(pj(test_target_root, imagename+".txt"), insect_names, height, width)
        coord_per_image = det2cls_result[imagename]
        for lbl, class_name in enumerate(insect_names):
            color_name = color_names[lbl]
            img = vis_detections(img, coord_per_image[lbl], class_name=class_name, color_name=color_name, thresh=score_thresh)
            img = vis_detections(img, gt_coord[lbl], class_name=class_name, color_name=color_name)

        Image.fromarray(img).save(pj(save_img_root, imagename+".png"))

In [None]:
output_img_with_result(
    det2cls_result,
    test_image_root=args.test_image_root,
    test_target_root=args.test_target_root,
    imagenames=imagenames,
    insect_names=insect_names,
    save_img_root=args.save_img_root,
)

### --- Output labelImg XML ---

In [None]:
# exist_ok=True keeps the cell idempotent and avoids the check-then-create race.
os.makedirs(args.save_xml_root, exist_ok=True)

In [None]:
# Class-index -> name map for the XML export: the classifier labels followed by
# the "Other" class, matching det2cls_result's indices. Derived from args so the
# label list is defined in one place.
insect_names = args.cls_labels + ['Other']
label_map = dict(enumerate(insect_names))
label_map

In [None]:
# Convert det2cls_result into the structure consumed by write_output_xml, naming
# classes through label_map (index -> insect name).
output = output_formatter(det2cls_result, label_map)

In [None]:
# Write the formatted results as labelImg-style XML into args.save_xml_root
# (per the section title; presumably one file per image -- confirm in IO.utils).
write_output_xml(output, args.save_xml_root)