-
Notifications
You must be signed in to change notification settings - Fork 1
/
tf_inf_resnet18.py
47 lines (33 loc) · 981 Bytes
/
tf_inf_resnet18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import tensorflow as tf
import cv2
import numpy as np
# Smoke-test inference on a frozen TensorFlow 1.x graph: load the GraphDef,
# pick the last node as the output, import the graph and run it once on an
# all-zero input tensor.

# Path to the serialized (frozen) GraphDef file.
# NOTE(review): the path names densenet161 while the script file name says
# resnet18 -- confirm this is the intended model.
graph_def = 'weights/densenet161.pb'

# Read the binary protobuf and deserialize it into a GraphDef message.
with tf.gfile.GFile(graph_def, "rb") as f:
    restored_graph_def = tf.GraphDef()
    restored_graph_def.ParseFromString(f.read())

n = len(restored_graph_def.node)
print(n)
if n == 0:
    # Fail early with a clear message instead of a NameError on
    # `output_node` further down.
    raise ValueError("no nodes found in graph: {}".format(graph_def))

# The last node of the GraphDef is used as the network output.
# NOTE(review): presumably the exporter writes the output op last; an
# alternative is to search for nodes whose name contains 'output'.
last_node = restored_graph_def.node[-1]
output_node = last_node.name
print(last_node.name, last_node.op)

# Import into the default graph; name="" keeps the original node names so
# tensors can be addressed as "<node name>:0" below.
tf.import_graph_def(
    restored_graph_def,
    input_map=None,
    return_elements=None,
    name="")

# To run on a real image instead of zeros (kept from the original script):
# img = cv2.imread('data/samples/bus.jpg')
# img = cv2.resize(img, (416, 416))
# img = img[None, :, :, :]
# img = np.transpose(img, [0, 3, 1, 2])

# Dummy float32 input of shape (1, 3, 224, 224).
# NOTE(review): presumably NCHW layout (batch, channels, height, width) --
# confirm against the exported model.
img = np.zeros((1, 3, 224, 224), dtype=np.float32)

with tf.Session() as sess:
    # Fetch output 0 of the chosen node, feeding the graph's
    # 'actual_input_1' input tensor.
    pred = sess.run(output_node + ":0", feed_dict={'actual_input_1:0': img})
print(pred)

# To dump the graph for TensorBoard (inside the session block above):
# writer = tf.summary.FileWriter("./graph", sess.graph)
# writer.close()