Skip to content

Commit 4939b9a

Browse files
authored
[tf] Switch off of pywrap entry point (#14303)
This is expected to be available in 2.14 and pre-fetching so that we can move off of more internal API end points.
1 parent ccf886b commit 4939b9a

File tree

2 files changed

+14
-11
lines changed
  • integrations/tensorflow/python_projects
    • iree_tflite/iree/tools/tflite/scripts/iree_import_tflite
    • iree_tf/iree/tools/tf/scripts/iree_import_tf

2 files changed

+14
-11
lines changed

integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/scripts/iree_import_tf/__main__.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,25 +64,30 @@ def import_saved_model(
6464
*, output_path, saved_model_dir, exported_names, import_type, tags
6565
):
6666
# From here there be dragons.
67-
from tensorflow.python import pywrap_mlir
67+
from tensorflow.mlir.experimental import (
68+
convert_saved_model_to_mlir,
69+
convert_saved_model_v1_to_mlir,
70+
run_pass_pipeline,
71+
write_bytecode,
72+
)
6873

6974
if import_type == "savedmodel_v2":
70-
result = pywrap_mlir.experimental_convert_saved_model_to_mlir(
75+
result = convert_saved_model_to_mlir(
7176
saved_model_dir, exported_names=exported_names, show_debug_info=False
7277
)
7378
elif import_type == "savedmodel_v1":
7479
# You saw it here, folks: The TF team just adds random positional params
7580
# without explanation or default. So we detect and default them on our
7681
# own. Because this is normal and fine.
77-
sig = inspect.signature(pywrap_mlir.experimental_convert_saved_model_v1_to_mlir)
82+
sig = inspect.signature(convert_saved_model_v1_to_mlir)
7883
dumb_extra_kwargs = {}
7984
if "include_variables_in_initializers" in sig.parameters:
8085
dumb_extra_kwargs["include_variables_in_initializers"] = False
8186
if "upgrade_legacy" in sig.parameters:
8287
dumb_extra_kwargs["upgrade_legacy"] = False
8388
if "lift_variables" in sig.parameters:
8489
dumb_extra_kwargs["lift_variables"] = True
85-
result = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
90+
result = convert_saved_model_v1_to_mlir(
8691
saved_model_dir,
8792
exported_names=exported_names,
8893
tags=tags,
@@ -97,15 +102,13 @@ def import_saved_model(
97102
# This is fine and normal, and totally to be expected. :(
98103
result = re.sub(r"func @__inference_(.+)_[0-9]+\(", r"func @\1(", result)
99104
pipeline = ["tf-lower-to-mlprogram-and-hlo"]
100-
result = pywrap_mlir.experimental_run_pass_pipeline(
101-
result, ",".join(pipeline), show_debug_info=False
102-
)
105+
result = run_pass_pipeline(result, ",".join(pipeline), show_debug_info=False)
103106

104-
# TODO: The experimental_write_bytecode function does not register the
107+
# TODO: The write_bytecode function does not register the
105108
# stablehlo dialect. Once fixed, remove this bypass.
106109
WRITE_BYTECODE = False
107110
if WRITE_BYTECODE:
108-
result = pywrap_mlir.experimental_write_bytecode(output_path, result)
111+
result = write_bytecode(output_path, result)
109112
else:
110113
with open(output_path, "wt") as f:
111114
f.write(result)

integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/scripts/iree_import_tflite/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212
import sys
1313
import iree.tools.tflite
14-
from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode
14+
from tensorflow.mlir.experimental import tflite_to_tosa_bytecode
1515

1616

1717
def main():
@@ -66,7 +66,7 @@ def tflite_to_tosa(
6666
ordered_input_arrays=None,
6767
ordered_output_arrays=None,
6868
):
69-
experimental_tflite_to_tosa_bytecode(
69+
tflite_to_tosa_bytecode(
7070
flatbuffer,
7171
bytecode,
7272
use_external_constant,

0 commit comments

Comments
 (0)