Skip to content

Commit

Permalink
support python 2.7 and 3.6, simplify training process
Browse files Browse the repository at this point in the history
  • Loading branch information
experiencor committed Mar 26, 2018
2 parents 796ed20 + 64f6c27 commit 6e11249
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 107 deletions.
25 changes: 3 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,9 @@ This repo contains the implementation of YOLOv2 in Keras with Tensorflow backend

## Todo list:
- [x] Warmup training
- [x] Raccoon detection
- [x] Self-driving car
- [x] Kangaroo detection
- [x] SqueezeNet backend
- [x] MobileNet backend
- [x] InceptionV3 backend
- [x] VGG16 backend
- [x] ResNet50 backend
- [x] Raccoon detection, Self-driving car, and Kangaroo detection
- [x] SqueezeNet, MobileNet, InceptionV3, and ResNet50 backends
- [x] Support Python 2.7 and 3.6
- [ ] Multiple-GPU training
- [ ] Multiscale training
- [ ] mAP Evaluation
Expand Down Expand Up @@ -132,20 +127,6 @@ Copy the generated anchors printed on the terminal to the ```anchors``` setting

### 4. Start the training process

#### Warm up the network

Set ```warmup_epochs``` in config.json to 3 (empirically found; 4 or 5 are also fine).

`python train.py -c config.json`

This process saves the trained weights to the file specified in the ```saved_weights_name``` setting.

#### Actual network training

Set the ```pretrained_weights``` setting in ```config.json``` to the warmup weights (whatever file is specified in ```saved_weights_name```).

Set ```warmup_epochs``` in config.json to 0.

`python train.py -c config.json`

By the end of this process, the code will write the weights of the best model to the file best_weights.h5 (or whatever name is specified in the "saved_weights_name" setting in the config.json file). The training process stops when the loss on the validation set has not improved for 3 consecutive epochs.
Expand Down
12 changes: 6 additions & 6 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@
"input_size": 416,
"anchors": [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828],
"max_box_per_image": 10,
"labels": ["raccoon"]
"labels": ["kangaroo"]
},

"train": {
"train_image_folder": "/home/andy/github/raccoon_dataset/images/",
"train_annot_folder": "/home/andy/github/raccoon_dataset/annotations/",
"train_image_folder": "/home/andy/data/kangaroo/images/",
"train_annot_folder": "/home/andy/data/kangaroo/annots/",

"train_times": 10,
"pretrained_weights": "full_yolo_raccoon.h5",
"pretrained_weights": "",
"batch_size": 16,
"learning_rate": 1e-4,
"nb_epoch": 50,
"nb_epochs": 50,
"warmup_epochs": 3,

"object_scale": 5.0 ,
"no_object_scale": 1.0,
"coord_scale": 1.0,
"class_scale": 1.0,

"saved_weights_name": "full_yolo_raccoon.h5",
"saved_weights_name": "full_yolo_kangaroo3.h5",
"debug": true
},

Expand Down
44 changes: 24 additions & 20 deletions frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, architecture,

self.labels = list(labels)
self.nb_class = len(self.labels)
self.nb_box = len(anchors)/2
self.nb_box = len(anchors)//2
self.class_wt = np.ones(self.nb_class, dtype='float32')
self.anchors = anchors

Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, architecture,
else:
raise Exception('Architecture not supported! Only support Full Yolo, Tiny Yolo, MobileNet, SqueezeNet, VGG16, ResNet50, and Inception3 at the moment!')

print self.feature_extractor.get_output_shape()
print(self.feature_extractor.get_output_shape())
self.grid_h, self.grid_w = self.feature_extractor.get_output_shape()
features = self.feature_extractor.extract(input_image)

Expand All @@ -69,6 +69,7 @@ def __init__(self, architecture,
output = Lambda(lambda args: args[0])([output, self.true_boxes])

self.model = Model([input_image, self.true_boxes], output)


# initialize the weights of the detection layer
layer = self.model.layers[-4]
Expand Down Expand Up @@ -194,9 +195,11 @@ def custom_loss(self, y_true, y_pred):
no_boxes_mask = tf.to_float(coord_mask < self.coord_scale/2.)
seen = tf.assign_add(seen, 1.)

true_box_xy, true_box_wh, coord_mask = tf.cond(tf.less(seen, self.warmup_bs),
true_box_xy, true_box_wh, coord_mask = tf.cond(tf.less(seen, self.warmup_batches+1),
lambda: [true_box_xy + (0.5 + cell_grid) * no_boxes_mask,
true_box_wh + tf.ones_like(true_box_wh) * np.reshape(self.anchors, [1,1,1,self.nb_box,2]) * no_boxes_mask,
true_box_wh + tf.ones_like(true_box_wh) * \
np.reshape(self.anchors, [1,1,1,self.nb_box,2]) * \
no_boxes_mask,
tf.ones_like(coord_mask)],
lambda: [true_box_xy,
true_box_wh,
Expand All @@ -215,7 +218,9 @@ def custom_loss(self, y_true, y_pred):
loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=true_box_class, logits=pred_box_class)
loss_class = tf.reduce_sum(loss_class * class_mask) / (nb_class_box + 1e-6)

loss = loss_xy + loss_wh + loss_conf + loss_class
loss = tf.cond(tf.less(seen, self.warmup_batches+1),
lambda: loss_xy + loss_wh + loss_conf + loss_class + 10,
lambda: loss_xy + loss_wh + loss_conf + loss_class)

if self.debug:
nb_true_box = tf.reduce_sum(y_true[..., 4])
Expand Down Expand Up @@ -320,13 +325,13 @@ def decode_netout(self, netout, obj_threshold=0.3, nms_threshold=0.3):
for c in range(self.nb_class):
sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))

for i in xrange(len(sorted_indices)):
for i in range(len(sorted_indices)):
index_i = sorted_indices[i]

if boxes[index_i].classes[c] == 0:
continue
else:
for j in xrange(i+1, len(sorted_indices)):
for j in range(i+1, len(sorted_indices)):
index_j = sorted_indices[j]

if self.bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:
Expand Down Expand Up @@ -354,7 +359,7 @@ def train(self, train_imgs, # the list of images to train the model
valid_imgs, # the list of images used to validate the model
train_times, # the number of time to repeat the training set, often used for small datasets
valid_times, # the number of times to repeat the validation set, often used for small datasets
nb_epoch, # number of epoches
nb_epochs, # number of epoches
learning_rate, # the learning rate
batch_size, # the size of the batch
warmup_epochs, # number of initial batches to let the model familiarize with the new dataset
Expand All @@ -366,7 +371,6 @@ def train(self, train_imgs, # the list of images to train the model
debug=False):

self.batch_size = batch_size
self.warmup_bs = warmup_epochs * (train_times*(len(train_imgs)/batch_size+1) + valid_times*(len(valid_imgs)/batch_size+1))

self.object_scale = object_scale
self.no_object_scale = no_object_scale
Expand All @@ -375,15 +379,6 @@ def train(self, train_imgs, # the list of images to train the model

self.debug = debug

if warmup_epochs > 0: nb_epoch = warmup_epochs # if it's warmup stage, don't train more than warmup_epochs

############################################
# Compile the model
############################################

optimizer = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
self.model.compile(loss=self.custom_loss, optimizer=optimizer)

############################################
# Make train and validation generators
############################################
Expand All @@ -407,7 +402,16 @@ def train(self, train_imgs, # the list of images to train the model
valid_batch = BatchGenerator(valid_imgs,
generator_config,
norm=self.feature_extractor.normalize,
jitter=False)
jitter=False)

self.warmup_batches = warmup_epochs * (train_times*len(train_batch) + valid_times*len(valid_batch))

############################################
# Compile the model
############################################

optimizer = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
self.model.compile(loss=self.custom_loss, optimizer=optimizer)

############################################
# Make a few callbacks
Expand Down Expand Up @@ -437,7 +441,7 @@ def train(self, train_imgs, # the list of images to train the model

self.model.fit_generator(generator = train_batch,
steps_per_epoch = len(train_batch) * train_times,
epochs = nb_epoch,
epochs = warmup_epochs + nb_epochs,
verbose = 1,
validation_data = valid_batch,
validation_steps = len(valid_batch) * valid_times,
Expand Down
10 changes: 5 additions & 5 deletions gen_anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def print_anchors(centroids):
r += '%0.2f,%0.2f' % (anchors[sorted_indices[-1:],0], anchors[sorted_indices[-1:],1])
r += "]"

print r
print(r)

def run_kmeans(ann_dims, anchor_num):
ann_num = ann_dims.shape[0]
Expand All @@ -82,7 +82,7 @@ def run_kmeans(ann_dims, anchor_num):
distances.append(d)
distances = np.array(distances) # distances.shape = (ann_num, anchor_num)

print "iteration {}: dists = {}".format(iteration, np.sum(np.abs(old_distances-distances)))
print("iteration {}: dists = {}".format(iteration, np.sum(np.abs(old_distances-distances))))

#assign samples to centroids
assignments = np.argmin(distances,axis=1)
Expand Down Expand Up @@ -123,13 +123,13 @@ def main(argv):
for obj in image['object']:
relative_w = (float(obj['xmax']) - float(obj['xmin']))/cell_w
relatice_h = (float(obj["ymax"]) - float(obj['ymin']))/cell_h
annotation_dims.append(map(float, (relative_w,relatice_h)))
annotation_dims = np.array(annotation_dims)
annotation_dims.append(tuple(map(float, (relative_w,relatice_h))))

annotation_dims = np.array(annotation_dims)
centroids = run_kmeans(annotation_dims, num_anchors)

# write anchors to file
print '\naverage IOU for', num_anchors, 'anchors:', '%0.2f' % avg_IOU(annotation_dims, centroids)
print('\naverage IOU for', num_anchors, 'anchors:', '%0.2f' % avg_IOU(annotation_dims, centroids))
print_anchors(centroids)

if __name__ == '__main__':
Expand Down
5 changes: 2 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _main_(args):
# Load trained weights
###############################

print weights_path
print(weights_path)
yolo.load_weights(weights_path)

###############################
Expand All @@ -63,7 +63,6 @@ def _main_(args):

if image_path[-4:] == '.mp4':
video_out = image_path[:-4] + '_detected' + image_path[-4:]

video_reader = cv2.VideoCapture(image_path)

nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
Expand All @@ -90,7 +89,7 @@ def _main_(args):
boxes = yolo.predict(image)
image = draw_boxes(image, boxes, config['model']['labels'])

print len(boxes), 'boxes are found'
print(len(boxes), 'boxes are found')

cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)

Expand Down
6 changes: 3 additions & 3 deletions preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def parse_annotation(ann_dir, img_dir, labels=[]):

for ann in sorted(os.listdir(ann_dir)):
img = {'object':[]}

tree = ET.parse(ann_dir + ann)

for elem in tree.iter():
Expand Down Expand Up @@ -222,7 +222,7 @@ def __getitem__(self, idx):
# increase instance counter in current batch
instance_count += 1

#print ' new batch created', idx
#print(' new batch created', idx)

return [x_batch, b_batch], y_batch

Expand All @@ -233,7 +233,7 @@ def aug_image(self, train_instance, jitter):
image_name = train_instance['filename']
image = cv2.imread(image_name)

if image is None: print 'Cannot find ', image_name
if image is None: print('Cannot find ', image_name)

h, w, c = image.shape
all_objs = copy.deepcopy(train_instance['object'])
Expand Down
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
tensorflow-gpu==1.3
keras==2.0.8
imgaug
opencv-python
h5py
55 changes: 7 additions & 48 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,5 @@
#! /usr/bin/env python

"""
This script takes in a configuration file and produces the best model.
The configuration file is a json file and looks like this:
{
"model" : {
"architecture": "Full Yolo",
"input_size": 416,
"anchors": [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828],
"max_box_per_image": 10,
"labels": ["raccoon"]
},
"train": {
"train_image_folder": "/home/andy/data/raccoon_dataset/images/",
"train_annot_folder": "/home/andy/data/raccoon_dataset/anns/",
"train_times": 10,
"pretrained_weights": "",
"batch_size": 16,
"learning_rate": 1e-4,
"nb_epoch": 50,
"warmup_epochs": 3,
"object_scale": 5.0 ,
"no_object_scale": 1.0,
"coord_scale": 1.0,
"class_scale": 1.0,
"debug": true
},
"valid": {
"valid_image_folder": "",
"valid_annot_folder": "",
"valid_times": 1
}
}
"""

import argparse
import os
import numpy as np
Expand Down Expand Up @@ -90,15 +49,15 @@ def _main_(args):
if len(config['model']['labels']) > 0:
overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))

print 'Seen labels:\t', train_labels
print 'Given labels:\t', config['model']['labels']
print 'Overlap labels:\t', overlap_labels
print('Seen labels:\t', train_labels)
print('Given labels:\t', config['model']['labels'])
print('Overlap labels:\t', overlap_labels)

if len(overlap_labels) < len(config['model']['labels']):
print 'Some labels have no annotations! Please revise the list of labels in the config.json file!'
print('Some labels have no annotations! Please revise the list of labels in the config.json file!')
return
else:
print 'No labels are provided. Train on all seen labels.'
print('No labels are provided. Train on all seen labels.')
config['model']['labels'] = train_labels.keys()

###############################
Expand All @@ -116,7 +75,7 @@ def _main_(args):
###############################

if os.path.exists(config['train']['pretrained_weights']):
print "Loading pre-trained weights in", config['train']['pretrained_weights']
print("Loading pre-trained weights in", config['train']['pretrained_weights'])
yolo.load_weights(config['train']['pretrained_weights'])

###############################
Expand All @@ -127,7 +86,7 @@ def _main_(args):
valid_imgs = valid_imgs,
train_times = config['train']['train_times'],
valid_times = config['valid']['valid_times'],
nb_epoch = config['train']['nb_epoch'],
nb_epochs = config['train']['nb_epochs'],
learning_rate = config['train']['learning_rate'],
batch_size = config['train']['batch_size'],
warmup_epochs = config['train']['warmup_epochs'],
Expand Down

0 comments on commit 6e11249

Please sign in to comment.