From 0d5433a724d2f1cc0b7e51d88c102f795a611736 Mon Sep 17 00:00:00 2001 From: Mrinal Jain Date: Fri, 18 Mar 2022 07:29:24 -0400 Subject: [PATCH] Consistent saved_model output format (#7032) --- export.py | 2 +- models/common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index d4f980fdb993..2d4a68e62f89 100644 --- a/export.py +++ b/export.py @@ -275,7 +275,7 @@ def export_saved_model(model, im, file, dynamic, m = m.get_concrete_function(spec) frozen_func = convert_variables_to_constants_v2(m) tfm = tf.Module() - tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec]) + tfm.__call__ = tf.function(lambda x: frozen_func(x)[0], [spec]) tfm.__call__(im) tf.saved_model.save( tfm, diff --git a/models/common.py b/models/common.py index 4ad040fcd7f1..5561d92ecb73 100644 --- a/models/common.py +++ b/models/common.py @@ -441,7 +441,7 @@ def forward(self, im, augment=False, visualize=False, val=False): else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) if self.saved_model: # SavedModel - y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy() + y = (self.model(im, training=False) if self.keras else self.model(im)).numpy() elif self.pb: # GraphDef y = self.frozen_func(x=self.tf.constant(im)).numpy() else: # Lite or Edge TPU