In [None]:
import os

# Expose only GPU 0 to this process. CUDA reads this variable when it initializes,
# so it is set here, before torch is imported in a later cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import numpy as np
from numpy.random import *
from os.path import join as pj
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import visdom

# Logger
from IO.logger import Logger
# Optimizer
from model.optimizer import AdamW
# Dataset
from dataset.detection.dataset import insects_dataset_from_voc_style_txt, collate_fn
# Loss Function
from model.refinedet.loss.multiboxloss import RefineDetMultiBoxLoss
# Model initializer
from model.refinedet.refinedet import RefineDet
# Predict
from model.refinedet.utils.predict import test_prediction
# Evaluate
from evaluation.detection.evaluate import Voc_Evaluater

In [None]:
class args:
    """Static configuration for this notebook: data paths and crop settings.

    Accessed as a class-level namespace (``args.input_size`` etc.).
    """
    # paths -- every data path is derived from data_root, so relocating the data
    # only requires editing this one line.
    data_root = "/home/tanida/workspace/Insect_Phenology_Detector/data"
    # Snapshot date shared by images and targets; keeping it in one place stops the
    # image set and the annotation set from drifting apart.
    dataset_date = "20200806"
    train_image_root = pj(data_root, "train_refined_images", dataset_date)
    train_target_root = pj(data_root, "train_detection_data", "refinedet_all_" + dataset_date)
    # training config
    input_size = 512  # input side length passed to the dataset constructor
    crop_num = (5, 5)  # crop grid passed to the dataset -- presumably (rows, cols); confirm

In [None]:
# Make newly created float tensors default to the GPU when one is available, else the CPU.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch releases
# (torch.set_default_dtype / torch.set_default_device replace it); kept as-is for the
# torch version this project targets.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
)

In [None]:
# Build the training set (RefineDet-style crops of each image) and a loader yielding
# one image per batch.
# NOTE(review): with the CUDA default tensor type set above, newer PyTorch versions can
# fail inside RandomSampler (shuffle=True) or in forked worker processes -- confirm
# against the installed torch version.
print('Loading dataset for train ...')
train_dataset = insects_dataset_from_voc_style_txt(
    args.train_image_root,
    args.input_size,
    args.crop_num,
    "RefineDet",
    training=True,
    target_root=args.train_target_root,
)
train_data_loader = data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    collate_fn=collate_fn,
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from evaluation.det2cls.visualize import vis_detections

In [None]:
# Take batches until one contains at least one crop with ground-truth boxes, then keep
# only those annotated crops (imgs / tars) for the visualization cells below.
# The loader uses batch size 1, so images[0] / targets[0] belong to the single image.
# NOTE(review): assumes images[0] is array-like of shape (num_crops, C, H, W) and
# targets[0][i] is a torch tensor with one row per box -- confirm against collate_fn.
for images, targets, _, _, data_id in tqdm(train_data_loader, leave=False):
    crops = np.asarray(images[0])
    crop_targets = targets[0]
    # Indices of crops whose target tensor has at least one row (i.e. one box).
    kept = [i for i in range(crops.shape[0]) if crop_targets[i].size(0) > 0]
    if kept:
        break
else:
    # Reached only if the loader is empty or no batch had an annotated crop; without
    # this the later cells would fail with a confusing NameError / IndexError.
    raise RuntimeError("train_data_loader yielded no crop with annotations")
print("annotated crop indices:", kept)
imgs = crops[kept]  # same result as np.asarray([crops[i] for i in kept]) for non-empty kept
tars = [crop_targets[i] for i in kept]
print(data_id[0][0])

In [None]:
# Index of the annotated crop shown by the visualization cell below; that cell
# increments it after each display, so re-run this cell to start again from crop 0.
idx = 0

In [None]:
# Show annotated crop `idx` with its ground-truth boxes. Re-run this cell to step
# through the crops; idx advances at the end and wraps around after the last crop.
if len(imgs) == 0:
    raise IndexError("no annotated crops to visualize; re-run the batch-loading cell")
idx %= len(imgs)

# create img: CHW float -> HWC uint8. Clip before casting, because out-of-range floats
# would otherwise wrap around instead of saturating.
# NOTE(review): assumes pixel values are normalized to [0, 1] -- confirm with the dataset.
img = np.clip(imgs[idx].transpose(1, 2, 0) * 255, 0, 255).astype("uint8")
# The transpose leaves a non-C-contiguous memory layout; .copy() returns a C-contiguous
# buffer. Presumably vis_detections draws in place (e.g. via OpenCV), which needs that.
img = img.copy()

# create target: tensor -> numpy (.cpu() is a no-op for CPU tensors and required for
# CUDA ones), then scale normalized coordinates to pixels of the input_size crop.
tar = tars[idx].cpu().numpy() * args.input_size
# Overwrite the last column with 1.0.
# NOTE(review): presumably rows are [x1, y1, x2, y2, label] and vis_detections reads the
# last column as a confidence score -- confirm.
tar[:, -1] = 1.0

# visualize img
img = vis_detections(img, tar, class_name="insects", color_name="green")
plt.imshow(img)
plt.title("annotated crop idx={} of {}".format(idx, len(imgs)))
plt.show()
idx += 1