# evaluate_on_test.py
import numpy as np
import cPickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
import os
from utilities import label_img_to_color
from model import ENet_model
# ---------------------------------------------------------------------------
# Configuration.
# project_dir is the project root; the trained checkpoint is restored from
# project_dir + "training_logs/..." further down.
# ---------------------------------------------------------------------------
#project_dir = "/home/fregu856/segmentation/"
project_dir = "/root/segmentation/"
data_dir = project_dir + "data/"
model_id = "test" # (change this to not overwrite all log data when you train the model)

# Every input frame is resized to img_width x img_height before inference,
# and the network is fed exactly batch_size images per run.
batch_size = 4
img_height = 512
img_width = 1024

model = ENet_model(model_id, img_height=img_height, img_width=img_width, batch_size=batch_size)
no_of_classes = model.no_of_classes

# Per-channel mean that is subtracted from every input image (presumably the
# mean computed over the training set -- confirm against the training script).
# The pickle is opened in binary mode and closed deterministically. NOTE: the
# path is relative to the current working directory, not to data_dir.
with open("data/mean_channels.pkl", "rb") as mean_file:
    train_mean_channels = cPickle.load(mean_file)
# Directory holding the raw frames of the sequence to segment.
seq_frames_dir = "/root/data/cityscapes/leftImg8bit/demoVideo/stuttgart_02/"

# Full path of every directory entry, in lexicographic (sorted) order.
# A progress counter is printed every 100 entries.
frame_names = sorted(os.listdir(seq_frames_dir))
seq_frame_paths = []
for idx, name in enumerate(frame_names):
    if not idx % 100:
        print(idx)
    seq_frame_paths.append(seq_frames_dir + name)
# ---------------------------------------------------------------------------
# Run the trained model over every frame and save a 30/70 blend of each frame
# with its colour-coded prediction as <frame name>_pred.png in results_dir.
# ---------------------------------------------------------------------------

# Number of batches needed to cover *every* frame. Rounding up (instead of
# truncating) ensures the trailing frames that do not fill a whole batch are
# processed too. Integer arithmetic keeps this identical on Python 2 and 3.
no_of_frames = len(seq_frame_paths)
no_of_batches = (no_of_frames + batch_size - 1) // batch_size

results_dir = model.project_dir + "results_on_seq/"
# cv2.imwrite does not create directories and only returns False on failure,
# so make sure the output directory exists before writing anything.
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)

# create a saver for restoring variables/parameters (trainable variables only):
saver = tf.train.Saver(tf.trainable_variables(), write_version=tf.train.SaverDef.V2)

with tf.Session() as sess:
    # initialize all variables/parameters; only the trainable ones are then
    # overwritten by the restore below, and any others keep their initial
    # values (NOTE(review): confirm the model has no non-trainable state,
    # e.g. moving averages, that should also come from the checkpoint):
    sess.run(tf.global_variables_initializer())
    # restore the best trained model:
    saver.restore(sess, project_dir + "training_logs/best_model/model_1_epoch_23.ckpt")

    for step in range(no_of_batches):
        batch_start = step * batch_size
        # The last slice may hold fewer than batch_size paths.
        batch_paths = seq_frame_paths[batch_start:batch_start + batch_size]

        # Always feed exactly batch_size images, because ENet_model is
        # constructed with batch_size (presumably its input placeholder has a
        # fixed batch dimension -- confirm). Unused slots stay zero-filled and
        # their predictions are discarded below.
        batch_imgs = np.zeros((batch_size, img_height, img_width, 3), dtype=np.float32)
        for i, img_path in enumerate(batch_paths):
            img = cv2.imread(img_path, -1)
            if img is None:
                # imread signals failure by returning None; fail loudly here
                # instead of with an obscure error inside cv2.resize.
                raise IOError("could not read image: %s" % img_path)
            img = cv2.resize(img, (img_width, img_height))
            batch_imgs[i] = img - train_mean_channels

        batch_feed_dict = model.create_feed_dict(imgs_batch=batch_imgs,
                    early_drop_prob=0.0, late_drop_prob=0.0)
        logits = sess.run(model.logits, feed_dict=batch_feed_dict)
        print("step: %d/%d" % (step+1, no_of_batches))
        # Per-pixel class index = argmax over the last (class) axis of logits.
        predictions = np.argmax(logits, axis=3)

        # Only the first len(batch_paths) entries correspond to real frames.
        for i, img_path in enumerate(batch_paths):
            pred_img_color = label_img_to_color(predictions[i])
            # Undo the mean subtraction to recover the resized input frame.
            img = batch_imgs[i] + train_mean_channels
            img_name = os.path.basename(img_path).split(".png")[0]
            pred_path = results_dir + img_name + "_pred.png"
            merged_img = 0.3*img + 0.7*pred_img_color
            # Convert the float blend to 8-bit explicitly (round + saturate)
            # instead of relying on imwrite's implicit depth conversion.
            merged_img = np.clip(np.rint(merged_img), 0, 255).astype(np.uint8)
            if not cv2.imwrite(pred_path, merged_img):
                raise IOError("could not write image: %s" % pred_path)
# ---------------------------------------------------------------------------
# Stitch the *.png files in results_dir (in sorted filename order) into an
# MJPG-encoded AVI at 20 fps.
# ---------------------------------------------------------------------------

# cv2.cv.CV_FOURCC only exists in OpenCV 2.x; OpenCV 3+ provides
# cv2.VideoWriter_fourcc instead. Use whichever this installation has.
if hasattr(cv2, "VideoWriter_fourcc"):
    fourcc = cv2.VideoWriter_fourcc("M", "J", "P", "G")
else:
    fourcc = cv2.cv.CV_FOURCC("M", "J", "P", "G")

out = cv2.VideoWriter(results_dir + "cityscapes_stuttgart_02_pred.avi", fourcc, 20.0, (img_width, img_height))
if not out.isOpened():
    # Otherwise every out.write() below would be silently ignored.
    raise IOError("could not open video writer in: %s" % results_dir)

frame_names = sorted(os.listdir(results_dir))
try:
    for step, frame_name in enumerate(frame_names):
        if step % 100 == 0:
            print(step)
        # Match on the extension only (not on ".png" appearing anywhere in
        # the name).
        if frame_name.endswith(".png"):
            frame_path = results_dir + frame_name
            frame = cv2.imread(frame_path, -1)
            if frame is None:
                raise IOError("could not read frame: %s" % frame_path)
            out.write(frame)
finally:
    # Release explicitly so the container is finalized even if a frame fails,
    # rather than depending on the writer being garbage-collected at exit.
    out.release()