Skip to content

Commit

Permalink
tensorflow frozen facenet test
Browse files Browse the repository at this point in the history
Note that the facenet link will expire after one day and it is better to upload the model to a stable server.
  • Loading branch information
rainLiuplus committed Sep 14, 2018
1 parent 4351ef9 commit 5ade339
Show file tree
Hide file tree
Showing 17 changed files with 191 additions and 41 deletions.
7 changes: 6 additions & 1 deletion mmdnn/conversion/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _multi_thread_download(url, file_name, file_size, thread_count):
return file_name


def download_file(url, directory='./', local_fname=None, force_write=False, auto_unzip=False):
def download_file(url, directory='./', local_fname=None, force_write=False, auto_unzip=False, compre_type=''):
"""Download the data from source url, unless it's already here.
Args:
Expand All @@ -193,6 +193,11 @@ def download_file(url, directory='./', local_fname=None, force_write=False, auto
if not local_fname:
k = url.rfind('/')
local_fname = url[k + 1:]
import re
if local_fname != re.sub(r'[/:*?<>|]','',local_fname): # name is complex and translate it into simple one with the compression type
local_fname = 'temp_'
local_fname += compre_type


local_fname = os.path.join(directory, local_fname)

Expand Down
6 changes: 6 additions & 0 deletions mmdnn/conversion/coreml/coreml_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ def _get_padding(IR_node):

return pads

def emit_Mul(self, IR_node):
"""
Skiped
"""
pass

def _emit_merge(self, IR_node, func):
"""
Convert concat layer to coreml.
Expand Down
3 changes: 2 additions & 1 deletion mmdnn/conversion/examples/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from PIL import Image



class TestKit(object):

truth = {
Expand Down Expand Up @@ -110,6 +111,7 @@ class TestKit(object):
'mobilenet_v2_1.0_224' : lambda path : TestKit.Standard(path, 224),
'nasnet-a_large' : lambda path : TestKit.Standard(path, 331),
'inception_resnet_v2' : lambda path : TestKit.Standard(path, 299),
'facenet' : lambda path: TestKit.Standard(path, 160)
},

'keras' : {
Expand Down Expand Up @@ -331,7 +333,6 @@ def print_result(self, predict):
self.result = predict
print (self.result)


@staticmethod
def print_intermediate_result(intermediate_output, if_transpose=False):
intermediate_output = np.squeeze(intermediate_output)
Expand Down
6 changes: 3 additions & 3 deletions mmdnn/conversion/examples/mxnet/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def print_result(self):
def inference(self, image_path):
self.preprocess(image_path)

# self.print_intermediate_result('pooling0', False)
self.print_intermediate_result('InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/Conv2D', False)

self.print_result()
# self.print_result()

self.test_truth()
# self.test_truth()


def print_intermediate_result(self, layer_name, if_transpose = False):
Expand Down
50 changes: 38 additions & 12 deletions mmdnn/conversion/examples/tensorflow/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ class tensorflow_extractor(base_extractor):
'inception_v1_frozen' : {
'url' : 'https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
'filename' : 'inception_v1_2016_08_28_frozen.pb',
'tensor_out' : 'InceptionV1/Logits/Predictions/Reshape_1:0',
'tensor_in' : 'input:0',
'input_shape' : [224, 224, 3],
'tensor_out' : ['InceptionV1/Logits/Predictions/Reshape_1:0'],
'tensor_in' : ['input:0'],
'input_shape' : [[224, 224, 3]], # input_shape of the elem in tensor_in
'feed_dict' :lambda img: {'input:0':img},
'num_classes' : 1001,
},
'inception_v3' : {
Expand All @@ -69,9 +70,10 @@ class tensorflow_extractor(base_extractor):
'inception_v3_frozen' : {
'url' : 'https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz',
'filename' : 'inception_v3_2016_08_28_frozen.pb',
'tensor_out' : 'InceptionV3/Predictions/Softmax:0',
'tensor_in' : 'input:0',
'input_shape' : [299, 299, 3],
'tensor_out' : ['InceptionV3/Predictions/Softmax:0'],
'tensor_in' : ['input:0'],
'input_shape' : [[299, 299, 3]], # input_shape of the elem in tensor_in
'feed_dict' :lambda img: {'input:0':img},
'num_classes' : 1001,
},
'resnet_v1_50' : {
Expand Down Expand Up @@ -133,9 +135,10 @@ class tensorflow_extractor(base_extractor):
'mobilenet_v1_1.0_frozen' : {
'url' : 'https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz',
'filename' : 'mobilenet_v1_1.0_224/frozen_graph.pb',
'tensor_out' : 'MobilenetV1/Predictions/Softmax:0',
'tensor_in' : 'input:0',
'input_shape' : [224, 224, 3],
'tensor_out' : ['MobilenetV1/Predictions/Softmax:0'],
'tensor_in' : ['input:0'],
'input_shape' : [[224, 224, 3]], # input_shape of the elem in tensor_in
'feed_dict' :lambda img: {'input:0':img},
'num_classes' : 1001,
},
'mobilenet_v2_1.0_224':{
Expand All @@ -162,6 +165,26 @@ class tensorflow_extractor(base_extractor):
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 331, 331, 3]),
'num_classes' : 1001,
},
'facenet' : {
'url' : 'https://drive.google.com/facenet_frozen',
'filename' : 'facenet.pb',
'tensor_out' : 'embeddings:0',
'tensor_in' : 'input:0',
'phase_train' : 'phase_train:0',
'input_shape' : [160, 160, 3],
'num_classes' : 0,
},
# Note that the link will expire after one day and it is better to upload the model to a stable server.
'facenet_frozen' : {
'url' : 'https://doc-0k-7k-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/bilqocet7fepe2lnqpoacs1gkhrv2kq8/1536890400000/18056234690049221457/*/1R77HmFADxe87GmoLwzfgMu_HY0IhcyBz?e=download',
'filename' : '20180408-102900/20180408-102900.pb',
'compre_type' : '.zip',
'tensor_out' : ['InceptionResnetV1/Logits/AvgPool_1a_8x8/AvgPool:0'],
'tensor_in' : ['input:0','phase_train:0'],
'input_shape' : [[160, 160, 3],1], # input_shape of the elem in tensor_in
'feed_dict' : lambda img: {'input:0':img,'phase_train:0':False},
'num_classes' : 0,
}
}


Expand Down Expand Up @@ -200,13 +223,15 @@ def handle_frozen_graph(cls, architecture, path):
@classmethod
def get_frozen_para(cls, architecture):
frozenname = architecture + '_frozen'
return cls.architecture_map[frozenname]['filename'], cls.architecture_map[frozenname]['input_shape'], cls.architecture_map[frozenname]['tensor_in'], cls.architecture_map[frozenname]['tensor_out']
tensor_in = list(map(lambda x:x.split(':')[0], cls.architecture_map[frozenname]['tensor_in']))
tensor_out = list(map(lambda x:x.split(':')[0], cls.architecture_map[frozenname]['tensor_out']))
return cls.architecture_map[frozenname]['filename'], cls.architecture_map[frozenname]['input_shape'], tensor_in, tensor_out


@classmethod
def download(cls, architecture, path="./"):
if cls.sanity_check(architecture):
architecture_file = download_file(cls.architecture_map[architecture]['url'], directory=path, auto_unzip=True)
architecture_file = download_file(cls.architecture_map[architecture]['url'], directory=path, auto_unzip=True, compre_type=cls.architecture_map[architecture].get('compre_type',''))
if not architecture_file:
return None

Expand Down Expand Up @@ -249,11 +274,12 @@ def inference(cls, architecture, files, path, image_path, is_frozen=False):
original_gdef.ParseFromString(serialized)
tf_output_name = cls.architecture_map[architecture_]['tensor_out']
tf_input_name = cls.architecture_map[architecture_]['tensor_in']
feed_dict = cls.architecture_map[architecture_]['feed_dict']

with tf.Graph().as_default() as g:
tf.import_graph_def(original_gdef, name='')
with tf.Session(graph = g) as sess:
tf_out = sess.run(tf_output_name, feed_dict={tf_input_name: img})
tf_out = sess.run(tf_output_name[0], feed_dict=feed_dict(img)) # temporarily think the num of out nodes is one
predict = np.squeeze(tf_out)
return predict

Expand Down
22 changes: 22 additions & 0 deletions mmdnn/conversion/keras/keras2_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, model):
else:
network_path = model[0]
weight_path = model[1]
self._load_weights(weight_path)

self.IR_graph = IRGraph(network_path)
self.IR_graph.build()
Expand All @@ -51,6 +52,7 @@ def header_code(self):
from keras import layers
import keras.backend as K
import numpy as np
import keras.layers.core as core
def load_weights_from_file(weight_file):
Expand Down Expand Up @@ -245,6 +247,18 @@ def emit_Conv(self, IR_node):
def emit_UNKNOWN(self, IR_node):
print (IR_node.name)

##add mul op, just implement layer * constant
def emit_Mul(self, IR_node):
if IR_node.name in self.weights_dict and 'weights' in self.weights_dict[IR_node.name]:
self.used_layers.add('KerasMul')
weight_factor = "weights_dict['{}'].get('weights',1.0)".format(IR_node.name)
self.add_body(1, "{:<15} = mul_constant(weight_factor={}, layer_name= {})".format(
IR_node.variable_name,
weight_factor,
''.join(self.IR_graph.get_node(IR_node.in_edges[0]).real_variable_name)))
else:
raise NotImplementedError()


def emit_Add(self, IR_node):
self._emit_merge(IR_node, "add")
Expand Down Expand Up @@ -644,6 +658,14 @@ def emit_region(self, IR_node):
IR_node.get_attr("coord_scale"),
]

def _layer_KerasMul(self):
self.add_body(0, '''
def mul_constant(weight_factor, layer_name):
weight = core.Lambda(lambda x: x*weight_factor)
weight(layer_name)
return weight.output
''')

def _layer_Yolo(self):
self.add_body(0, '''
def yolo_parameter():
Expand Down
1 change: 1 addition & 0 deletions mmdnn/conversion/keras/keras2_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def gen_IR(self):
for layer in self.keras_graph.topological_sort:
current_node = self.keras_graph.get_node(layer)
node_type = current_node.type

if hasattr(self, "rename_" + node_type):
func = getattr(self, "rename_" + node_type)
func(current_node)
Expand Down
17 changes: 14 additions & 3 deletions mmdnn/conversion/mxnet/mxnet_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,10 +925,21 @@ def emit_Add(self, IR_node):

def emit_Mul(self, IR_node):

# code = "{:<15} = mx.sym.broadcast_mul({}, {})".format(
# IR_node.variable_name,
# self.parent_variable_name(IR_node),
# self.parent_variable_name(IR_node, [1]))

if IR_node.name in self.weights and 'weights' in self.weights[IR_node.name]:
second_node = "mx.sym.Variable('{}', shape=(1,))".format(IR_node.name+'_weight')
self.output_weights[IR_node.name + '_weight'] = [self.weights[IR_node.name]['weights']]
else:
second_node = self.parent_variable_name(IR_node, [1])

code = "{:<15} = mx.sym.broadcast_mul({}, {})".format(
IR_node.variable_name,
self.parent_variable_name(IR_node),
self.parent_variable_name(IR_node, [1]))
IR_node.variable_name,
self.parent_variable_name(IR_node),
second_node)

return code

Expand Down
15 changes: 13 additions & 2 deletions mmdnn/conversion/tensorflow/tensorflow_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,20 @@ def emit_Flatten(self, IR_node):


def emit_Mul(self, IR_node):
self.add_body(1, "{:<15} = {}".format(

if IR_node.name in self.weights_dict and 'weights' in self.weights_dict[IR_node.name]:
weight_str = "* tf.Variable(__weights_dict['{}'].get('weights',1.0))".format(IR_node.name)
else:
weight_str = ""

# self.add_body(1, "{:<15} = {}".format(
# IR_node.variable_name,
# ' * '.join('%s' % self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges)))

self.add_body(1,"{:<15} = {}{}".format(
IR_node.variable_name,
' * '.join('%s' % self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges)))
' * '.join('%s' % self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges),
weight_str))

def emit_Const(self, IR_node):
if 'dtype' in IR_node.layer.attr:
Expand Down
11 changes: 8 additions & 3 deletions mmdnn/conversion/tensorflow/tensorflow_frozenparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
x = tensorflow.placeholder(dtype)

input_map[in_nodes[i] + ':0'] = x

tensorflow.import_graph_def(model, name='', input_map=input_map)

with tensorflow.Session(graph = g) as sess:
Expand Down Expand Up @@ -340,19 +340,21 @@ def _skip_node(cls, source_node):


def gen_IR(self):

for layer in self.src_graph.topological_sort:
current_node = self.src_graph.get_node(layer)

if self._skip_node(current_node):
continue

node_type = current_node.type

if hasattr(self, "rename_" + node_type):

func = getattr(self, "rename_" + node_type)
func(current_node)
else:

self.rename_UNKNOWN(current_node)


Expand Down Expand Up @@ -762,6 +764,9 @@ def rename_Shape(self, source_node):
IR_node = self._convert_identity_operation(source_node, new_op = 'Shape')
input_node = self.src_graph.get_parent(source_node.name, [0])
kwargs = {}
# print(input_node.layer)
# print(input_node.get_attr('_output_shapes'))
# print(self.tensor_shape_to_list(input_node.get_attr('_output_shapes')))
kwargs['shape'] = self.tensor_shape_to_list(input_node.get_attr('_output_shapes'))[0]

assign_IRnode_values(IR_node, kwargs)
Expand Down

0 comments on commit 5ade339

Please sign in to comment.