-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo_video.py
151 lines (122 loc) · 4.82 KB
/
demo_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Nghia Tran
# Edited from demo.py
"""
Detects cars in a video using KittiBox.
Input: Video
Output: Video (with cars plotted in green)
Utilizes: Trained KittiBox weights loaded from the given logdir.
Usage:
python demo_video.py video logdir [--save output_video]
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import sys
import imageio
# Set up the root logger (INFO and above, written to stdout). This runs
# before the tensorflow import further down; see the issue linked below.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf
import time
import cv2
import argparse
sys.path.insert(1, 'incl')
from utils import train_utils as kittibox_utils
try:
    # Check whether setup was done correctly
    import tensorvision.utils as tv_utils
    import tensorvision.core as core
except ImportError:
    # The tensorvision submodule could not be imported; tell the user how
    # to fetch it and stop.
    logging.error("Could not import the submodules.")
    logging.error("Please execute:"
                  "'git submodule update --init --recursive'")
    # Use sys.exit rather than the `exit` helper, which is injected by the
    # `site` module and is not guaranteed to exist in every interpreter.
    sys.exit(1)
# Command-line interface: two required positional paths (input video and
# model logdir) plus an optional output location.
parser = argparse.ArgumentParser(
    description='Run the KittiBox demo on a video.')
parser.add_argument('video', type=str, help='Path to video.')
parser.add_argument('logdir', type=str, help='Path to logdir.')
parser.add_argument('--save', '-s', type=str, default='', help='Save file.')
def _default_save_path(video_path):
    """Return ``video_path`` with ``_rect`` inserted before its extension."""
    stem, extension = os.path.splitext(video_path)
    return stem + '_rect' + extension


def _write_video(video_path, save_file, hypes):
    """Re-encode ``video_path`` into ``save_file`` and return the frame count.

    Every frame is converted from BGR to RGB and resized to
    ``hypes["image_height"]`` x ``hypes["image_width"]`` before it is
    appended to the output, which is written at 20 fps.

    Raises:
        IOError: if ``video_path`` cannot be opened.
    """
    vidcap = cv2.VideoCapture(video_path)
    cnt = 0
    try:
        if not vidcap.isOpened():
            raise IOError('Could not open video: %s' % video_path)

        # Remove a stale output only once the input is known to be readable.
        if os.path.isfile(save_file):
            os.remove(save_file)

        frame_shape = (hypes["image_height"], hypes["image_width"])
        with imageio.get_writer(save_file, mode='I', fps=20) as writer:
            while True:
                success, frame = vidcap.read()
                if not success:
                    break
                cnt += 1
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = scp.misc.imresize(frame, frame_shape)

                # NOTE: inference is currently disabled, so the resized
                # input frame is written unchanged. Re-enabling the code
                # below requires passing sess, prediction and image_pl
                # into this helper and writing output_image instead.
                #
                # feed = {image_pl: frame}
                #
                # # Run KittiBox model on image
                # pred_boxes = prediction['pred_boxes_new']
                # pred_confidences = prediction['pred_confidences']
                # (np_pred_boxes, np_pred_confidences) = sess.run(
                #     [pred_boxes, pred_confidences], feed_dict=feed)
                #
                # # Apply non-maximal suppression
                # # and draw predictions on the image
                # threshold = 0.5
                # output_image, _ = kittibox_utils.add_rectangles(
                #     hypes, [frame], np_pred_confidences,
                #     np_pred_boxes, show_removed=False,
                #     use_stitching=True, rnn_len=1,
                #     min_conf=threshold, tau=hypes['tau'],
                #     color_acc=(0, 255, 0))
                writer.append_data(frame)
    finally:
        # Always hand the capture back, even if reading or encoding failed.
        vidcap.release()
        cv2.destroyAllWindows()
    return cnt


def main():
    """Load the KittiBox model from the logdir and re-encode the input video.

    Command-line arguments come from the module-level ``parser``: the input
    video, the logdir holding hypes/modules/weights, and an optional
    ``--save`` path. Without ``--save`` the output is written next to the
    input as ``<name>_rect<ext>``. An existing output file is replaced.

    After processing, the output path, the number of frames, the elapsed
    time and the resulting frames-per-second figure are logged.

    Exits through ``parser.error`` when the output path is the input video
    itself, because the output file is deleted before writing.

    Raises:
        IOError: if the input video cannot be opened.
    """
    args = parser.parse_args()

    save_file = args.save or _default_save_path(args.video)
    # The output file is removed before writing; never let that hit the
    # very video that is about to be read.
    if os.path.realpath(save_file) == os.path.realpath(args.video):
        parser.error('Output file must differ from the input video: %s'
                     % save_file)

    tv_utils.set_gpus_to_use()
    logdir = args.logdir

    # Loading hyperparameters from logdir
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')
    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    with tf.Graph().as_default():
        # Create placeholder for input
        image_pl = tf.placeholder(
            tf.float32,
            shape=(hypes["image_height"], hypes["image_width"], 3))
        image = tf.expand_dims(image_pl, 0)

        # Build the Tensorflow graph using the model from logdir. The
        # result is only consumed by the (currently disabled) inference
        # step, but building the graph creates the variables the saver
        # restores.
        prediction = core.build_inference_graph(hypes, modules,
                                                image=image)
        logging.info("Graph build successfully.")

        # The context manager closes the session even if processing fails.
        with tf.Session() as sess:
            saver = tf.train.Saver()

            # Load weights from logdir
            core.load_weights(logdir, sess, saver)
            logging.info("Weights loaded successfully.")

            start = time.time()
            logging.info("Making video")
            cnt = _write_video(args.video, save_file, hypes)
            time_taken = time.time() - start

    # Guard against a zero elapsed time on coarse clocks.
    fps = cnt / time_taken if time_taken > 0 else 0.0
    logging.info('Video saved as %s', save_file)
    logging.info('Number of images: %d', cnt)
    logging.info('Time takes: %.2f s', time_taken)
    logging.info('Frequency: %.2f fps', fps)
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()