@@ -64,25 +64,30 @@ def import_saved_model(
64
64
* , output_path , saved_model_dir , exported_names , import_type , tags
65
65
):
66
66
# From here there be dragons.
67
- from tensorflow .python import pywrap_mlir
67
+ from tensorflow .mlir .experimental import (
68
+ convert_saved_model_to_mlir ,
69
+ convert_saved_model_v1_to_mlir ,
70
+ run_pass_pipeline ,
71
+ write_bytecode ,
72
+ )
68
73
69
74
if import_type == "savedmodel_v2" :
70
- result = pywrap_mlir . experimental_convert_saved_model_to_mlir (
75
+ result = convert_saved_model_to_mlir (
71
76
saved_model_dir , exported_names = exported_names , show_debug_info = False
72
77
)
73
78
elif import_type == "savedmodel_v1" :
74
79
# You saw it here, folks: The TF team just adds random positional params
75
80
# without explanation or default. So we detect and default them on our
76
81
# own. Because this is normal and fine.
77
- sig = inspect .signature (pywrap_mlir . experimental_convert_saved_model_v1_to_mlir )
82
+ sig = inspect .signature (convert_saved_model_v1_to_mlir )
78
83
dumb_extra_kwargs = {}
79
84
if "include_variables_in_initializers" in sig .parameters :
80
85
dumb_extra_kwargs ["include_variables_in_initializers" ] = False
81
86
if "upgrade_legacy" in sig .parameters :
82
87
dumb_extra_kwargs ["upgrade_legacy" ] = False
83
88
if "lift_variables" in sig .parameters :
84
89
dumb_extra_kwargs ["lift_variables" ] = True
85
- result = pywrap_mlir . experimental_convert_saved_model_v1_to_mlir (
90
+ result = convert_saved_model_v1_to_mlir (
86
91
saved_model_dir ,
87
92
exported_names = exported_names ,
88
93
tags = tags ,
@@ -97,15 +102,13 @@ def import_saved_model(
97
102
# This is fine and normal, and totally to be expected. :(
98
103
result = re .sub (r"func @__inference_(.+)_[0-9]+\(" , r"func @\1(" , result )
99
104
pipeline = ["tf-lower-to-mlprogram-and-hlo" ]
100
- result = pywrap_mlir .experimental_run_pass_pipeline (
101
- result , "," .join (pipeline ), show_debug_info = False
102
- )
105
+ result = run_pass_pipeline (result , "," .join (pipeline ), show_debug_info = False )
103
106
104
- # TODO: The experimental_write_bytecode function does not register the
107
+ # TODO: The write_bytecode function does not register the
105
108
# stablehlo dialect. Once fixed, remove this bypass.
106
109
WRITE_BYTECODE = False
107
110
if WRITE_BYTECODE :
108
- result = pywrap_mlir . experimental_write_bytecode (output_path , result )
111
+ result = write_bytecode (output_path , result )
109
112
else :
110
113
with open (output_path , "wt" ) as f :
111
114
f .write (result )
0 commit comments