In [None]:
import os, sys
import numpy as np
from renom.cuda import set_cuda_active
from renom_img.api.detection.yolo_v1 import Yolov1
from renom_img.api.utility.load import parse_xml_detection
from PIL import Image
from tqdm import tqdm

from renom_img.api.model.darknet import Darknet
from renom_img.api.utility.distributor.distributor import ImageDistributor

from renom_img.api.utility.augmentation.process import Shift, Rotate, Flip, WhiteNoise, ContrastNorm
from renom_img.api.utility.augmentation.augmentation import Augmentation

from renom_img.api.utility.misc.display import draw_box



# Seed NumPy's global random generator so that draws made through np.random
# (such as the dataset permutation computed later) repeat across fresh runs.
np.random.seed(2018)

In [None]:
# Switch ReNom's CUDA flag on before any model is constructed.
# NOTE(review): presumably this requires a CUDA-capable GPU and a CUDA-enabled
# ReNom build — confirm the behavior on CPU-only machines.
set_cuda_active(True)

In [None]:
# Root folders of the VOC2007 and VOC2012 datasets.
prefix_path1 = '/home/yamada/dataset/VOCdevkit/VOC2007/'
prefix_path2 = '/home/yamada/dataset/VOCdevkit/VOC2012/'

# Image folder inside each dataset root.
image_dir1 = os.path.join(prefix_path1, 'JPEGImages')
image_dir2 = os.path.join(prefix_path2, 'JPEGImages')

# Full path of every directory entry, ordered by sorted entry name.
file_list1 = [os.path.join(image_dir1, entry) for entry in sorted(os.listdir(image_dir1))]
file_list2 = [os.path.join(image_dir2, entry) for entry in sorted(os.listdir(image_dir2))]


In [None]:
# Merge both datasets into one list: VOC2007 entries first, then VOC2012.
file_list = [*file_list1, *file_list2]

In [None]:
def _annotation_path(image_path):
    """Return the annotation XML path that corresponds to an image path.

    The image is expected at ``<root>/JPEGImages/<stem>.<ext>``; the result is
    ``<root>/Annotations/<stem>.xml``. Only the last directory component and
    the file extension are rewritten, so a parent directory or file stem that
    happens to contain ``jpg`` or ``JPEGImages`` is left intact (a plain
    ``str.replace`` would rewrite every occurrence in the path).
    """
    image_dir, image_name = os.path.split(image_path)
    dataset_root = os.path.dirname(image_dir)
    stem = os.path.splitext(image_name)[0]
    return os.path.join(dataset_root, 'Annotations', stem + '.xml')


# One annotation path per image, in the same order as file_list.
annot_file_list = [_annotation_path(image_path) for image_path in file_list]

In [None]:
# Parse all annotation XML files in one call. Two values come back:
#   annot     - indexed per input file further down, so presumably one entry per image (confirm)
#   class_map - a mapping that is reduced to a sorted list of its keys in a later cell
annot, class_map = parse_xml_detection(annot_file_list)

In [None]:
# Reduce the class mapping to a sorted list of its keys (presumably the class
# names — confirm against parse_xml_detection). Iterating the mapping directly
# yields the same keys as .keys() on the first run, and it keeps this cell
# idempotent: re-running it on the already-sorted list no longer raises
# AttributeError.
class_map = sorted(class_map)

In [None]:
# Random ordering of the sample indices 0..len(file_list)-1, drawn from
# NumPy's global generator; used to shuffle images and annotations in step.
perm = np.random.permutation(len(file_list))

In [None]:
# Apply the same permutation to images and annotations so pairs stay aligned.
# The Python lists are indexed directly instead of going through np.array():
#   * converting per-image annotation lists of differing lengths to an array
#     raises ValueError on recent NumPy, and silently builds a 2-D array
#     (changing the element type) when all lengths happen to match;
#   * the paths stay plain ``str`` objects rather than ``np.str_``.
file_list = [file_list[i] for i in perm]
annot = [annot[i] for i in perm]

In [None]:
# Hold out the last 20% of the shuffled samples for validation.
n_train = int(0.8*len(file_list))

# Images and annotations are cut at the same index so they stay paired.
train_img_path_list, valid_img_path_list = file_list[:n_train], file_list[n_train:]
train_annot_list, valid_annot_list = annot[:n_train], annot[n_train:]

In [None]:
# Sanity check: pass the third training image and its annotation to draw_box.
# NOTE(review): presumably this renders the image with its bounding boxes
# overlaid as the cell output — confirm against the renom_img documentation.
draw_box(train_img_path_list[2], train_annot_list[2])

In [None]:
# Construct the YOLO v1 detector for the parsed classes: pretrained weights
# requested, whole-network training enabled, image size 448x448.
yolo = Yolov1(
    class_map=class_map,
    load_pretrained_weight=True,
    train_whole_network=True,
    imsize=(448, 448),
)

# Augmentation steps handed to training, in the order listed.
augmentation = Augmentation([Flip(), Rotate(), WhiteNoise(), ContrastNorm([0.5, 1.0])])


In [None]:
def callback_end_epoch(epoch, model, avg_train_loss_list, avg_valid_loss_list,
                       checkpoint_dir='/home/yamada/checkpoints/yolo'):
    """Save a checkpoint of ``model`` at the end of an epoch.

    Args:
        epoch: Epoch identifier; inserted into the file name ``model_<epoch>.h5``.
        model: Object exposing ``save(path)``; called once with the checkpoint path.
        avg_train_loss_list: Unused; accepted to match the callback signature.
        avg_valid_loss_list: Unused; accepted to match the callback signature.
        checkpoint_dir: Directory that receives the checkpoint files. Defaults
            to the location used previously, so existing callers are unaffected.
    """
    # Create the target directory up front: without it the save would fail
    # only after a complete training epoch has already been spent.
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.save(os.path.join(checkpoint_dir, 'model_{}.h5'.format(epoch)))

In [None]:
# Train on the training split, with the validation split, the augmentation
# pipeline and the end-of-epoch checkpoint callback passed through to fit().
yolo.fit(
    train_img_path_list=train_img_path_list,
    train_annotation_list=train_annot_list,
    valid_img_path_list=valid_img_path_list,
    valid_annotation_list=valid_annot_list,
    augmentation=augmentation,
    batch_size=16,
    callback_end_epoch=callback_end_epoch,
)