diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py index 1672656b3a1f8..2235bb2865c0d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py @@ -5,22 +5,25 @@ import sys + +def multiline_str_representer(dumper, data): + if len(data.splitlines()) > 1: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + try: - import yaml + from yaml import YAMLObject as _YAMLObject, add_representer + + add_representer(str, multiline_str_representer) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"This tool requires PyYAML but it was not installed. " - f"Recommend: {sys.executable} -m pip install PyYAML" - ) from e -__all__ = [ - "yaml_dump", - "yaml_dump_all", - "YAMLObject", -] + class _YAMLObject: + pass -class YAMLObject(yaml.YAMLObject): +class YAMLObject(_YAMLObject): @classmethod def to_yaml(cls, dumper, self): """Default to a custom dictionary mapping.""" @@ -33,21 +36,34 @@ def as_linalg_yaml(self): return yaml_dump(self) -def multiline_str_representer(dumper, data): - if len(data.splitlines()) > 1: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") - else: - return dumper.represent_scalar("tag:yaml.org,2002:str", data) +def yaml_dump(data, sort_keys=False, **kwargs): + try: + import yaml + return yaml.dump(data, sort_keys=sort_keys, **kwargs) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e -yaml.add_representer(str, multiline_str_representer) +def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): + try: + import yaml -def yaml_dump(data, sort_keys=False, **kwargs): - return yaml.dump(data, sort_keys=sort_keys, **kwargs) + return yaml.dump_all( + data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs + ) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e -def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): - return yaml.dump_all( - data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs - ) +__all__ = [ + "yaml_dump", + "yaml_dump_all", + "YAMLObject", +] diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index abe09259bb1e8..a1ff6e815d2f2 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,7 +1,9 @@ +# BUILD dependencies nanobind>=2.9, <3.0 -numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 +typing_extensions>=4.12.2 +# RUN dependencies +numpy>=1.19.5, <=2.1.2 ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16 ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" -typing_extensions>=4.12.2