-
Notifications
You must be signed in to change notification settings - Fork 26
/
freeze.py
56 lines (43 loc) · 1.53 KB
/
freeze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#
# Copyright (c) 2017 cTuning foundation.
# See CK COPYRIGHT.txt for copyright details.
#
# SPDX-License-Identifier: BSD-3-Clause.
# See CK LICENSE.txt for licensing details.
#
import imp
import os
import tensorflow as tf
MODEL_MODULE = os.getenv('CK_ENV_TENSORFLOW_MODEL_MODULE')
MODEL_WEIGHTS = os.getenv('CK_ENV_TENSORFLOW_MODEL_WEIGHTS')
TARGET_FILE = os.getenv('CK_TARGET_PB_FILE')
FREEZE_AS_TEXT = os.getenv('CK_FREEZE_AS_TEXT')
from tensorflow.python.platform import gfile
if FREEZE_AS_TEXT == "YES":
from google.protobuf import text_format
def main(_):
print('Model module: ' + MODEL_MODULE)
print('Model weights: ' + MODEL_WEIGHTS)
target_file = TARGET_FILE if TARGET_FILE else 'graph.pb'
print('Target file: ' + target_file)
# Load model implementation module
model = imp.load_source('tf_model', MODEL_MODULE)
# Load weights
model.load_weights(MODEL_WEIGHTS)
with tf.Graph().as_default(), tf.Session() as sess:
# Build net
input_shape = (None, 227, 227, 3)
input_node = tf.placeholder(dtype=tf.float32, shape=input_shape, name="input")
output_node = model.inference(input_node)
graph_def = tf.get_default_graph().as_graph_def()
if FREEZE_AS_TEXT == "YES":
print('Serialize as text...')
with gfile.GFile(target_file, "w") as f:
f.write(text_format.MessageToString(graph_def))
else:
print('Serialize as binary...')
with gfile.GFile(target_file, "wb") as f:
f.write(graph_def.SerializeToString())
print('OK')
if __name__ == '__main__':
tf.app.run()