Skip to content

Commit

Permalink
fix some bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hizhangp committed Feb 28, 2018
1 parent d87db1f commit 88aba9d
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 132 deletions.
63 changes: 42 additions & 21 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import tensorflow as tf
import numpy as np
import os
import cv2
import argparse
import numpy as np
import tensorflow as tf
import yolo.config as cfg
from yolo.yolo_net import YOLONet
from utils.timer import Timer
Expand All @@ -22,12 +22,13 @@ def __init__(self, net, weight_file):
self.threshold = cfg.THRESHOLD
self.iou_threshold = cfg.IOU_THRESHOLD
self.boundary1 = self.cell_size * self.cell_size * self.num_class
self.boundary2 = self.boundary1 + self.cell_size * self.cell_size * self.boxes_per_cell
self.boundary2 = self.boundary1 +\
self.cell_size * self.cell_size * self.boxes_per_cell

self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())

print 'Restoring weights from: ' + self.weights_file
print('Restoring weights from: ' + self.weights_file)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.weights_file)

Expand All @@ -40,7 +41,11 @@ def draw_result(self, img, result):
cv2.rectangle(img, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2)
cv2.rectangle(img, (x - w, y - h - 20),
(x + w, y - h), (125, 125, 125), -1)
cv2.putText(img, result[i][0] + ' : %.2f' % result[i][5], (x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.CV_AA)
lineType = cv2.LINE_AA if cv2.__version__ > '3' else cv2.CV_AA
cv2.putText(
img, result[i][0] + ' : %.2f' % result[i][5],
(x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(0, 0, 0), 1, lineType)

def detect(self, img):
img_h, img_w, _ = img.shape
Expand Down Expand Up @@ -71,11 +76,22 @@ def detect_from_cvmat(self, inputs):
def interpret_output(self, output):
probs = np.zeros((self.cell_size, self.cell_size,
self.boxes_per_cell, self.num_class))
class_probs = np.reshape(output[0:self.boundary1], (self.cell_size, self.cell_size, self.num_class))
scales = np.reshape(output[self.boundary1:self.boundary2], (self.cell_size, self.cell_size, self.boxes_per_cell))
boxes = np.reshape(output[self.boundary2:], (self.cell_size, self.cell_size, self.boxes_per_cell, 4))
offset = np.transpose(np.reshape(np.array([np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell),
[self.boxes_per_cell, self.cell_size, self.cell_size]), (1, 2, 0))
class_probs = np.reshape(
output[0:self.boundary1],
(self.cell_size, self.cell_size, self.num_class))
scales = np.reshape(
output[self.boundary1:self.boundary2],
(self.cell_size, self.cell_size, self.boxes_per_cell))
boxes = np.reshape(
output[self.boundary2:],
(self.cell_size, self.cell_size, self.boxes_per_cell, 4))
offset = np.array(
[np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell)
offset = np.transpose(
np.reshape(
offset,
[self.boxes_per_cell, self.cell_size, self.cell_size]),
(1, 2, 0))

boxes[:, :, :, 0] += offset
boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
Expand All @@ -94,8 +110,9 @@ def interpret_output(self, output):
boxes_filtered = boxes[filter_mat_boxes[0],
filter_mat_boxes[1], filter_mat_boxes[2]]
probs_filtered = probs[filter_mat_probs]
classes_num_filtered = np.argmax(filter_mat_probs, axis=3)[filter_mat_boxes[
0], filter_mat_boxes[1], filter_mat_boxes[2]]
classes_num_filtered = np.argmax(
filter_mat_probs, axis=3)[
filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]

argsort = np.array(np.argsort(probs_filtered))[::-1]
boxes_filtered = boxes_filtered[argsort]
Expand All @@ -116,8 +133,13 @@ def interpret_output(self, output):

result = []
for i in range(len(boxes_filtered)):
result.append([self.classes[classes_num_filtered[i]], boxes_filtered[i][0], boxes_filtered[
i][1], boxes_filtered[i][2], boxes_filtered[i][3], probs_filtered[i]])
result.append(
[self.classes[classes_num_filtered[i]],
boxes_filtered[i][0],
boxes_filtered[i][1],
boxes_filtered[i][2],
boxes_filtered[i][3],
probs_filtered[i]])

return result

Expand All @@ -126,11 +148,8 @@ def iou(self, box1, box2):
max(box1[0] - 0.5 * box1[2], box2[0] - 0.5 * box2[2])
lr = min(box1[1] + 0.5 * box1[3], box2[1] + 0.5 * box2[3]) - \
max(box1[1] - 0.5 * box1[3], box2[1] - 0.5 * box2[3])
if tb < 0 or lr < 0:
intersection = 0
else:
intersection = tb * lr
return intersection / (box1[2] * box1[3] + box2[2] * box2[3] - intersection)
inter = 0 if tb < 0 or lr < 0 else tb * lr
return inter / (box1[2] * box1[3] + box2[2] * box2[3] - inter)

def camera_detector(self, cap, wait=10):
detect_timer = Timer()
Expand All @@ -141,7 +160,8 @@ def camera_detector(self, cap, wait=10):
detect_timer.tic()
result = self.detect(frame)
detect_timer.toc()
print('Average detecting time: {:.3f}s'.format(detect_timer.average_time))
print('Average detecting time: {:.3f}s'.format(
detect_timer.average_time))

self.draw_result(frame, result)
cv2.imshow('Camera', frame)
Expand All @@ -156,7 +176,8 @@ def image_detector(self, imname, wait=0):
detect_timer.tic()
result = self.detect(image)
detect_timer.toc()
print('Average detecting time: {:.3f}s'.format(detect_timer.average_time))
print('Average detecting time: {:.3f}s'.format(
detect_timer.average_time))

self.draw_result(image, result)
cv2.imshow('Image', image)
Expand Down
43 changes: 21 additions & 22 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import tensorflow as tf
import datetime
import os
import argparse
import datetime
import tensorflow as tf
import yolo.config as cfg
from yolo.yolo_net import YOLONet
from utils.timer import Timer
from utils.pascal_voc import pascal_voc

slim = tf.contrib.slim


class Solver(object):

Expand All @@ -28,24 +30,19 @@ def __init__(self, net, data):
self.save_cfg()

self.variable_to_restore = tf.global_variables()
self.restorer = tf.train.Saver(self.variable_to_restore, max_to_keep=None)
self.saver = tf.train.Saver(self.variable_to_restore, max_to_keep=None)
self.ckpt_file = os.path.join(self.output_dir, 'save.ckpt')
self.ckpt_file = os.path.join(self.output_dir, 'yolo')
self.summary_op = tf.summary.merge_all()
self.writer = tf.summary.FileWriter(self.output_dir, flush_secs=60)

self.global_step = tf.get_variable(
'global_step', [], initializer=tf.constant_initializer(0), trainable=False)
self.global_step = tf.train.create_global_step()
self.learning_rate = tf.train.exponential_decay(
self.initial_learning_rate, self.global_step, self.decay_steps,
self.decay_rate, self.staircase, name='learning_rate')
self.optimizer = tf.train.GradientDescentOptimizer(
learning_rate=self.learning_rate).minimize(
self.net.total_loss, global_step=self.global_step)
self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)
self.averages_op = self.ema.apply(tf.trainable_variables())
with tf.control_dependencies([self.optimizer]):
self.train_op = tf.group(self.averages_op)
learning_rate=self.learning_rate)
self.train_op = slim.learning.create_train_op(
self.net.total_loss, self.optimizer, global_step=self.global_step)

gpu_options = tf.GPUOptions()
config = tf.ConfigProto(gpu_options=gpu_options)
Expand All @@ -54,7 +51,7 @@ def __init__(self, net, data):

if self.weights_file is not None:
print('Restoring weights from: ' + self.weights_file)
self.restorer.restore(self.sess, self.weights_file)
self.saver.restore(self.sess, self.weights_file)

self.writer.add_graph(self.sess.graph)

Expand All @@ -63,12 +60,13 @@ def train(self):
train_timer = Timer()
load_timer = Timer()

for step in xrange(1, self.max_iter + 1):
for step in range(1, self.max_iter + 1):

load_timer.tic()
images, labels = self.data.get()
load_timer.toc()
feed_dict = {self.net.images: images, self.net.labels: labels}
feed_dict = {self.net.images: images,
self.net.labels: labels}

if step % self.summary_iter == 0:
if step % (self.summary_iter * 10) == 0:
Expand All @@ -79,10 +77,10 @@ def train(self):
feed_dict=feed_dict)
train_timer.toc()

log_str = ('{} Epoch: {}, Step: {}, Learning rate: {},'
' Loss: {:5.3f}\nSpeed: {:.3f}s/iter,'
' Load: {:.3f}s/iter, Remain: {}').format(
datetime.datetime.now().strftime('%m/%d %H:%M:%S'),
log_str = '''{} Epoch: {}, Step: {}, Learning rate: {},'''
''' Loss: {:5.3f}\nSpeed: {:.3f}s/iter,'''
'''' Load: {:.3f}s/iter, Remain: {}'''.format(
datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
self.data.epoch,
int(step),
round(self.learning_rate.eval(session=self.sess), 6),
Expand All @@ -108,10 +106,10 @@ def train(self):

if step % self.save_iter == 0:
print('{} Saving checkpoint file to: {}'.format(
datetime.datetime.now().strftime('%m/%d %H:%M:%S'),
datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
self.output_dir))
self.saver.save(self.sess, self.ckpt_file,
global_step=self.global_step)
self.saver.save(
self.sess, self.ckpt_file, global_step=self.global_step)

def save_cfg(self):

Expand Down Expand Up @@ -159,6 +157,7 @@ def main():
solver.train()
print('Done training.')


if __name__ == '__main__':

# python train.py --weights YOLO_small.ckpt --gpu 0
Expand Down
40 changes: 24 additions & 16 deletions utils/pascal_voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import xml.etree.ElementTree as ET
import numpy as np
import cv2
import cPickle
import pickle
import copy
import yolo.config as cfg

Expand All @@ -16,7 +16,7 @@ def __init__(self, phase, rebuild=False):
self.image_size = cfg.IMAGE_SIZE
self.cell_size = cfg.CELL_SIZE
self.classes = cfg.CLASSES
self.class_to_ind = dict(zip(self.classes, xrange(len(self.classes))))
self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
self.flipped = cfg.FLIPPED
self.phase = phase
self.rebuild = rebuild
Expand All @@ -26,8 +26,10 @@ def __init__(self, phase, rebuild=False):
self.prepare()

def get(self):
images = np.zeros((self.batch_size, self.image_size, self.image_size, 3))
labels = np.zeros((self.batch_size, self.cell_size, self.cell_size, 25))
images = np.zeros(
(self.batch_size, self.image_size, self.image_size, 3))
labels = np.zeros(
(self.batch_size, self.cell_size, self.cell_size, 25))
count = 0
while count < self.batch_size:
imname = self.gt_labels[self.cursor]['imname']
Expand Down Expand Up @@ -58,23 +60,27 @@ def prepare(self):
gt_labels_cp = copy.deepcopy(gt_labels)
for idx in range(len(gt_labels_cp)):
gt_labels_cp[idx]['flipped'] = True
gt_labels_cp[idx]['label'] = gt_labels_cp[idx]['label'][:, ::-1, :]
for i in xrange(self.cell_size):
for j in xrange(self.cell_size):
gt_labels_cp[idx]['label'] =\
gt_labels_cp[idx]['label'][:, ::-1, :]
for i in range(self.cell_size):
for j in range(self.cell_size):
if gt_labels_cp[idx]['label'][i, j, 0] == 1:
gt_labels_cp[idx]['label'][i, j, 1] = self.image_size - 1 - gt_labels_cp[idx]['label'][i, j, 1]
gt_labels_cp[idx]['label'][i, j, 1] = \
self.image_size - 1 -\
gt_labels_cp[idx]['label'][i, j, 1]
gt_labels += gt_labels_cp
np.random.shuffle(gt_labels)
self.gt_labels = gt_labels
return gt_labels

def load_labels(self):
cache_file = os.path.join(self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')
cache_file = os.path.join(
self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl')

if os.path.isfile(cache_file) and not self.rebuild:
print('Loading gt_labels from: ' + cache_file)
with open(cache_file, 'rb') as f:
gt_labels = cPickle.load(f)
gt_labels = pickle.load(f)
return gt_labels

print('Processing gt_labels from: ' + self.data_path)
Expand All @@ -83,11 +89,11 @@ def load_labels(self):
os.makedirs(self.cache_path)

if self.phase == 'train':
txtname = os.path.join(self.data_path, 'ImageSets', 'Main',
'trainval.txt')
txtname = os.path.join(
self.data_path, 'ImageSets', 'Main', 'trainval.txt')
else:
txtname = os.path.join(self.data_path, 'ImageSets', 'Main',
'test.txt')
txtname = os.path.join(
self.data_path, 'ImageSets', 'Main', 'test.txt')
with open(txtname, 'r') as f:
self.image_index = [x.strip() for x in f.readlines()]

Expand All @@ -97,10 +103,12 @@ def load_labels(self):
if num == 0:
continue
imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
gt_labels.append({'imname': imname, 'label': label, 'flipped': False})
gt_labels.append({'imname': imname,
'label': label,
'flipped': False})
print('Saving gt_labels to: ' + cache_file)
with open(cache_file, 'wb') as f:
cPickle.dump(gt_labels, f)
pickle.dump(gt_labels, f)
return gt_labels

def load_pascal_annotation(self, index):
Expand Down
7 changes: 5 additions & 2 deletions utils/timer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import time, datetime
import time
import datetime


class Timer(object):
'''
A simple timer.
'''

def __init__(self):
self.init_time = time.time()
self.total_time = 0.
Expand Down Expand Up @@ -33,5 +36,5 @@ def remain(self, iters, max_iters):
self.remain_time = 0
else:
self.remain_time = (time.time() - self.init_time) * \
(max_iters - iters) / iters
(max_iters - iters) / iters
return str(datetime.timedelta(seconds=int(self.remain_time)))
2 changes: 1 addition & 1 deletion yolo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

OUTPUT_DIR = os.path.join(PASCAL_PATH, 'output')

WEIGHTS_DIR = os.path.join(PASCAL_PATH, 'weight')
WEIGHTS_DIR = os.path.join(PASCAL_PATH, 'weights')

WEIGHTS_FILE = None
# WEIGHTS_FILE = os.path.join(DATA_PATH, 'weights', 'YOLO_small.ckpt')
Expand Down
Loading

0 comments on commit 88aba9d

Please sign in to comment.