Skip to content

Commit

Permalink
Fix saved_model_cli tensorrt conversion
Browse files Browse the repository at this point in the history
The existing saved_model_cli convert tensorrt script fails in 2.X with module
not found "tensorflow.contrib". Updated the script to use the V2 API for
TensorRT to convert a saved_model.

The max_batch_size and is_dynamic_op parameters are not valid for the V2 API
so they have been removed.
  • Loading branch information
wdirons committed Dec 8, 2019
1 parent b41fbcb commit bf00bd6
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions tensorflow/python/tools/saved_model_cli.py
Expand Up @@ -747,19 +747,17 @@ def convert_with_tensorrt(args):
"""
# Import here instead of at top, because this will crash if TensorRT is
# not installed
from tensorflow.contrib import tensorrt # pylint: disable=g-import-not-at-top
tensorrt.create_inference_graph(
None,
None,
max_batch_size=args.max_batch_size,
max_workspace_size_bytes=args.max_workspace_size_bytes,
precision_mode=args.precision_mode,
minimum_segment_size=args.minimum_segment_size,
is_dynamic_op=args.is_dynamic_op,
input_saved_model_dir=args.dir,
input_saved_model_tags=args.tag_set.split(','),
output_saved_model_dir=args.output_dir)

from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top

params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_workspace_size_bytes=args.max_workspace_size_bytes,
precision_mode=args.precision_mode,
minimum_segment_size=args.minimum_segment_size)
converter = trt.TrtGraphConverterV2(input_saved_model_dir=args.dir,
input_saved_model_tags=args.tag_set.split(','),
conversion_params=params)
converter.convert()
converter.save(output_saved_model_dir=args.output_dir)

def create_parser():
"""Creates a parser that parses the command line arguments.
Expand Down Expand Up @@ -949,11 +947,6 @@ def create_parser():
'tensorrt',
description='Convert the SavedModel with Tensorflow-TensorRT integration',
formatter_class=argparse.RawTextHelpFormatter)
parser_convert_with_tensorrt.add_argument(
'--max_batch_size',
type=int,
default=1,
help='max size for the input batch')
parser_convert_with_tensorrt.add_argument(
'--max_workspace_size_bytes',
type=int,
Expand All @@ -971,12 +964,6 @@ def create_parser():
default=3,
help=('the minimum number of nodes required for a subgraph to be replaced'
'in a TensorRT node'))
parser_convert_with_tensorrt.add_argument(
'--is_dynamic_op',
type=bool,
default=False,
help=('whether to generate dynamic TRT ops which will build the TRT '
'network and engine at run time'))
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)

return parser
Expand Down

0 comments on commit bf00bd6

Please sign in to comment.