Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reuse sess in backend #273

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
49 changes: 25 additions & 24 deletions onnx_tf/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
self._graph = graph
self._inputs = inputs or []
self._outputs = outputs or []
self._sess = None
self._tensor_dict = tensor_dict or {}

@property
Expand Down Expand Up @@ -60,30 +61,30 @@ def run(self, inputs, **kwargs):

# TODO: handle name scope if necessary
with self.graph.as_default():
with tf.Session() as sess:
if isinstance(inputs, dict):
feed_dict = inputs
elif isinstance(inputs, list) or isinstance(inputs, tuple):
if len(self.inputs) != len(inputs):
raise RuntimeError('Expected {} values for uninitialized '
'graph inputs ({}), but got {}.'.format(
len(self.inputs), ', '.join(self.inputs),
len(inputs)))
feed_dict = dict(zip(self.inputs, inputs))
else:
# single input
feed_dict = dict([(self.inputs[0], inputs)])

feed_dict = {
self.tensor_dict[key]: feed_dict[key]
for key in self.inputs
}

sess.run(tf.global_variables_initializer())
outputs = [self.tensor_dict[output] for output in self.outputs]

output_values = sess.run(outputs, feed_dict=feed_dict)
return namedtupledict('Outputs', self.outputs)(*output_values)
sess = self._sess or tf.Session()
self._sess = sess

if isinstance(inputs, dict):
feed_dict = inputs
elif isinstance(inputs, list) or isinstance(inputs, tuple):
if len(self.inputs) != len(inputs):
raise RuntimeError('Expected {} values for uninitialized '
'graph inputs ({}), but got {}.'.format(
len(self.inputs), ', '.join(self.inputs),
len(inputs)))
feed_dict = dict(zip(self.inputs, inputs))
else:
# single input
feed_dict = dict([(self.inputs[0], inputs)])

feed_dict = {self.tensor_dict[key]: feed_dict[key] for key in self.inputs}

sess.run(tf.global_variables_initializer())
outputs = [self.tensor_dict[output] for output in self.outputs]

output_values = sess.run(outputs, feed_dict=feed_dict)

return namedtupledict('Outputs', self.outputs)(*output_values)

def export_graph(self, path):
"""Export backend representation to a Tensorflow proto file.
Expand Down