In [None]:
import os

# Expose only GPU 0 to this process. This cell runs before torch is imported,
# so the setting is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import cv2
import numpy as np
from os import listdir as ld
from os.path import join as pj
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from tqdm import tqdm

from model.utils.config import cfg, cfg_from_file, cfg_from_list
from model.rpn.bbox_transform import bbox_transform_inv
from model.rpn.bbox_transform import clip_boxes
from model.utils.net_utils import weights_normal_init, save_net, load_net, \
      adjust_learning_rate, save_checkpoint, clip_gradient, vis_detections

# Loader
from IO.dataset import load_path, load_images, load_images_path, load_annotations_path, load_annotations, get_all_anno_recs
# Dataset
from dataset.dataset import insects_dataset_from_voc_style_txt, collate_fn
# Predict
from evaluation.predict import test_prediction
# Evaluate
from evaluation.evaluate import evaluate
# Statistics
from evaluation.statistics import plot_df_distrib_size, compute_size_df, plot_df_distrib_class, plot_df_error, plot_pr_curve
# Visualize
from evaluation.visualize import vis_detections

import matplotlib.pyplot as plt
# Apply matplotlib's "ggplot" style sheet to every figure drawn in this notebook.
plt.style.use("ggplot")
# IPython magic: render matplotlib figures inline in the cell output.
%matplotlib inline

# Test Config

In [None]:
class args:
    """Static configuration namespace for the Faster R-CNN test notebook.

    The class is never instantiated; values are read as ``args.<name>``.
    Every path is derived from ``project_root`` so the notebook can be moved
    to another machine by changing one value (or by setting the
    ``INSECT_PROJECT_ROOT`` environment variable before running).
    """
    # experiment name (also the leaf directory of model_root / figure_root)
    experiment_name = "crop_b2_2_4_8_16_32_not_pretrain"
    # paths
    # Single source of truth for the project location; the default is the
    # original hard-coded location.
    project_root = os.environ.get("INSECT_PROJECT_ROOT", "/home/tanida/workspace/Insect_project")
    data_root = pj(project_root, "data")
    train_image_root = pj(data_root, "train_refined_images")
    test_image_root = pj(data_root, "test_refined_images")
    train_target_root = pj(data_root, "train_detection_data", "refinedet_all")
    test_target_root = pj(data_root, "test_detection_data", "refinedet_all")
    model_root = pj(project_root, "model", "detection", "Faster_RCNN", experiment_name)
    figure_root = pj(project_root, "figure", "detection", "Faster_RCNN", experiment_name)
    # annotation folders (passed to load_path together with data_root)
    test_anno_folders = ["annotations_4"]
    # train config
    # backbone selector read by make_Faster_RCNN ("vgg16", anything else -> ResNet-34)
    b_bone = "vgg16"
    # anchor sizes / aspect ratios handed to AnchorGenerator (one feature map)
    anchor_size = ((2, 4, 8, 16, 32),)
    aspect_ratio = ((0.5, 1.0, 2.0),)
    # used as FasterRCNN min_size/max_size and as the dataset resize_size
    input_size = 512
    # crop grid passed to the dataset and to test_prediction
    crop_num = (5, 5)
    # The following training settings are not read by any cell visible in this
    # notebook; they are kept so the config matches the training run.
    batch_size = 2
    num_worker = 2
    lr = 1e-4
    lamda = 1e-2  # (sic) name kept as-is for compatibility
    max_epoch = 100
    valid_interval = 2
    save_interval = 20
    # forwarded to FasterRCNN as box_detections_per_img
    max_insect_per_image = 20
    # visualization
    visdom = True
    port = 8097
    # class label (len(labels) is the detector's number of classes)
    labels = ['__background__','insects']

# Model

In [None]:
class TwoMLPHead(nn.Module):
    """Box head made of two fully connected layers, each followed by ReLU and dropout.

    Args:
        in_channels: length of the flattened per-sample input vector.
        representation_size: width of both linear layers (and of the output).
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()
        # The attribute names become state-dict keys, so they must stay as they are.
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        """Flatten all axes after the first, then apply fc6 and fc7 in turn."""
        features = x.flatten(1)
        for layer in (self.fc6, self.fc7):
            activated = F.relu(layer(features))
            features = self.drop(activated)
        return features

def make_Faster_RCNN(b_bone=args.b_bone, n_class=len(args.labels), anchor_size=args.anchor_size,
                     aspect_ratio=args.aspect_ratio, pretrained=True):
    """Assemble a single-feature-map Faster R-CNN on a truncated ImageNet backbone.

    Args:
        b_bone: ``"vgg16"`` selects the VGG16 convolutional stack; any other
            value selects ResNet-34 (same fall-through as before).
        n_class: number of classes given to the box predictor.
        anchor_size: ``sizes`` argument for ``AnchorGenerator``.
        aspect_ratio: ``aspect_ratios`` argument for ``AnchorGenerator``.
        pretrained: forwarded to the torchvision backbone constructor. Pass
            ``False`` to skip the ImageNet weights when a checkpoint is loaded
            right afterwards. Defaults to ``True`` (previous behaviour).

    Returns:
        A ``FasterRCNN`` whose input is resized to ``args.input_size`` and that
        keeps at most ``args.max_insect_per_image`` detections per image.
    """
    if b_bone == "vgg16":
        # VGG16 feature extractor without its last child (the final max-pool);
        # the remaining stack ends in a 512-channel convolution.
        backbone = torchvision.models.vgg16(pretrained=pretrained).features
        backbone = nn.Sequential(*list(backbone.children())[:-1])
        b_outchannels = 512
    else:
        # ResNet-34 without its last three children (layer4, avgpool, fc), so
        # the network ends at layer3. layer3 of ResNet-34 emits 256 channels;
        # the previous value of 512 mis-sized the RPN and the box head.
        backbone = torchvision.models.resnet34(pretrained=pretrained)
        backbone = nn.Sequential(*list(backbone.children())[:-3])
        b_outchannels = 256
    # FasterRCNN reads this attribute to size the RPN head.
    backbone.out_channels = b_outchannels

    # Everything below is independent of the chosen backbone.
    representation_channels = 4096
    anchor_generator = AnchorGenerator(sizes=anchor_size, aspect_ratios=aspect_ratio)
    # NOTE(review): newer torchvision releases key a plain-tensor backbone
    # output as the string '0'; confirm featmap_names against the installed version.
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0], output_size=7, sampling_ratio=2)
    # The box head consumes the flattened (channels x 7 x 7) RoI feature.
    box_header = TwoMLPHead(b_outchannels * roi_pooler.output_size[0] ** 2, representation_channels)
    predictor = FastRCNNPredictor(representation_channels, n_class)
    model = FasterRCNN(backbone, min_size=args.input_size, max_size=args.input_size,
                       rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
                       box_head=box_header, box_predictor=predictor,
                       box_detections_per_img=args.max_insect_per_image)

    return model

# Load Annotations for Test

In [None]:
# Build the ground-truth records for the test annotation folders.
# All helpers come from IO.dataset; their return types are not visible in
# this notebook, so the intermediate values are simply chained as written.
annos, imgs = load_path(args.data_root, args.test_anno_folders)
images = load_images(imgs)
annotations_path = load_annotations_path(annos, images)
images_path = load_images_path(imgs, annotations_path)
anno = load_annotations(annotations_path)
# Later cells index imagenames by position and recs by image name.
imagenames, recs = get_all_anno_recs(anno)

# Load Data and Model

In [None]:
# Path of the checkpoint to evaluate. The "40" tag is hard-coded here
# (presumably the training epoch — confirm against the training notebook).
print("loading model for test ...")
load_name = pj(
    args.model_root,
    'Faster_RCNN{}_{}.pth'.format(args.input_size, "40"),
)

# Test dataset and a one-image-per-batch, unshuffled loader over it.
print('Loading dataset for test ...')
test_dataset = insects_dataset_from_voc_style_txt(
    args.test_image_root,
    training=False,
    target_root=args.test_target_root,
    resize_size=args.input_size,
    crop_num=args.crop_num,
)
test_data_loader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

In [None]:
# Build the detector on the GPU.
model = make_Faster_RCNN().cuda()
# Inference mode: disables the box head's dropout and makes FasterRCNN return
# detections instead of expecting targets. Called before the last line so the
# cell still displays the load_state_dict() report.
model.eval()
# Deserialise the checkpoint on the CPU; load_state_dict then copies the
# values into the GPU parameters, so no second GPU copy is held and the
# checkpoint also loads when it was saved from a different device index.
model.load_state_dict(torch.load(load_name, map_location="cpu"))

In [None]:
# Display the module tree of the loaded detector.
model

# --- result analysis ---

In [None]:
# Run the detector over the test loader with an NMS threshold of 0.3.
# test_prediction is a project helper; the structure of `result` is not
# visible here, but a later cell indexes it by image name.
result = test_prediction(model, test_data_loader, args.input_size, args.crop_num, nms_thresh=0.3)

### --- visualize accuracy distribution by size ---

In [None]:
# Score the detections against the ground-truth records using evaluate()'s
# default settings.
recall, precision, avg_precision, gt_dict = evaluate(result, recs)
# Plot the per-size distribution computed from gt_dict; save/output_csv are
# enabled and args.figure_root is passed along (presumably the output
# directory — confirm in evaluation.statistics).
plot_df_distrib_size(compute_size_df(gt_dict), args.figure_root, save=True, output_csv=True)

In [None]:
# Plot the precision-recall curve from the evaluate() outputs of the previous
# cell; save=True is passed together with args.figure_root.
plot_pr_curve(precision, recall, args.figure_root, save=True)

In [None]:
# Show the average precision returned by the most recent evaluate() call.
avg_precision

### --- visualize accuracy distribution by class ---

In [None]:
# Class name -> integer id, numbered consecutively from 0 in the listed order.
each_label_dic = {
    class_name: class_id
    for class_id, class_name in enumerate([
        'Coleoptera',
        'Diptera',
        'Ephemeridae',
        'Ephemeroptera',
        'Hemiptera',
        'Lepidoptera',
        'Plecoptera',
        'Trichoptera',
        'small insect',
        'medium insect',
    ])
}

In [None]:
# NOTE(review): same evaluate(result, recs) call as in the size-distribution
# cell; re-run here so this cell works on its own. The earlier gt_dict could
# be reused if evaluate() is deterministic — confirm.
recall, precision, avg_precision, gt_dict = evaluate(result, recs)
# Plot the per-class distribution using the class-name -> id mapping.
plot_df_distrib_class(each_label_dic, gt_dict, args.figure_root, save=True, output_csv=True, color="green")

### --- visualize error count per class ---

In [None]:
# NOTE(review): evaluate(result, recs) is recomputed with the same arguments
# as in the previous sections.
recall, precision, avg_precision, gt_dict = evaluate(result, recs)
# Plot the per-class error counts using the class-name -> id mapping.
plot_df_error(each_label_dic, gt_dict, args.figure_root, save=True, output_csv=True, color="green")

### --- compare ground truth and output ---

In [None]:
import cv2
import matplotlib.pyplot as plt
from  PIL import Image
%matplotlib inline

In [None]:
# NOTE(review): identical to the test_prediction call in the result-analysis
# section; it repeats the full inference pass and overwrites `result`.
# Consider reusing the earlier value if nothing changed in between.
result = test_prediction(model, test_data_loader, args.input_size, args.crop_num, nms_thresh=0.3)

In [None]:
# Position (in imagenames) of the image handled by the comparison cell; that
# cell increments it, so re-run this cell to start again from the first image.
im_index = 0


def index_to_key(x):
    """Return the image name stored at position ``x`` of ``imagenames``."""
    return imagenames[x]


# Create the output directory (and parents) if missing; exist_ok avoids the
# separate existence check and the race between checking and creating.
os.makedirs(args.figure_root, exist_ok=True)

In [None]:
# Overlay predictions (blue) and ground truth (red) on one test image, save
# the figure under args.figure_root and advance to the next image.
image_name = index_to_key(im_index)  # resolved once and reused below
# Ground-truth boxes with a constant 1 appended to each (presumably a dummy
# score so they match the layout vis_detections expects — confirm).
gt = np.asarray([rec["bbox"] + [1] for rec in recs[image_name]])
# "gazo" = image; read from <data_root>/refined_images instead of a second
# hard-coded absolute path, so it follows args.data_root.
gazo = np.asarray(Image.open(pj(args.data_root, "refined_images", image_name + ".png")))
x = vis_detections(gazo, result[image_name], color_name="blue")
x = vis_detections(x, gt, color_name="red")
x = Image.fromarray(x)
x.save(pj(args.figure_root, image_name + ".png"))
print(image_name)
# Stateful on purpose: each re-run of this cell processes the next image.
im_index += 1
# x  # uncomment to display the saved figure inline

### --- compare AP at different IoU thresholds ---

In [None]:
# Report the average precision for each overlap threshold passed to evaluate().
# Note: recall/precision/avg_precision/gt_dict end up holding the values of
# the last threshold in the sequence.
for thresh in (0.3, 0.5):
    recall, precision, avg_precision, gt_dict = evaluate(
        result,
        recs,
        ovthresh=thresh,
    )
    print("thresh == {0}, ap == {1}".format(thresh, avg_precision))