Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reuse sess in backend #273

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions onnx_tf/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,40 +49,43 @@ def tensor_dict(self):
def tensor_dict(self, tensor_dict):
self._tensor_dict = tensor_dict

def run(self, inputs, **kwargs):
def run(self, inputs, sess=None, **kwargs):
""" Run TensorflowRep.

:param inputs: Given inputs.
:param sess: tf.Session. The environment in which Operation objects are executed,
and Tensor objects are evaluated.
:param kwargs: Other args.
:return: Outputs.
"""
super(TensorflowRep, self).run(inputs, **kwargs)

should_close_sess = sess is None
# TODO: handle name scope if necessary
with self.graph.as_default():
with tf.Session() as sess:
if isinstance(inputs, dict):
feed_dict = inputs
elif isinstance(inputs, list) or isinstance(inputs, tuple):
if len(self.inputs) != len(inputs):
raise RuntimeError('Expected {} values for uninitialized '
'graph inputs ({}), but got {}.'.format(
len(self.inputs), ', '.join(self.inputs),
len(inputs)))
feed_dict = dict(zip(self.inputs, inputs))
else:
# single input
feed_dict = dict([(self.inputs[0], inputs)])

feed_dict = {
self.tensor_dict[key]: feed_dict[key] for key in self.inputs
}

sess.run(tf.global_variables_initializer())
outputs = [self.tensor_dict[output] for output in self.outputs]

output_values = sess.run(outputs, feed_dict=feed_dict)
return namedtupledict('Outputs', self.outputs)(*output_values)
sess = sess or tf.Session()
if isinstance(inputs, dict):
feed_dict = inputs
elif isinstance(inputs, list) or isinstance(inputs, tuple):
if len(self.inputs) != len(inputs):
raise RuntimeError('Expected {} values for uninitialized '
'graph inputs ({}), but got {}.'.format(
len(self.inputs), ', '.join(self.inputs),
len(inputs)))
feed_dict = dict(zip(self.inputs, inputs))
else:
# single input
feed_dict = dict([(self.inputs[0], inputs)])

feed_dict = {self.tensor_dict[key]: feed_dict[key] for key in self.inputs}

sess.run(tf.global_variables_initializer())
outputs = [self.tensor_dict[output] for output in self.outputs]

output_values = sess.run(outputs, feed_dict=feed_dict)
if should_close_sess:
sess.close()
return namedtupledict('Outputs', self.outputs)(*output_values)

def export_graph(self, path):
"""Export backend representation to a Tensorflow proto file.
Expand All @@ -99,3 +102,12 @@ def export_graph(self, path):
file = open(path, "wb")
file.write(graph_proto.SerializeToString())
file.close()

def create_session(self):
""" Create tf.Session object by using current graph.
Pass it to `run` function could reduce the overhead of initialization
when doing inference consecutively.

:returns: A Session object.
"""
return tf.Session(graph=self.graph)