Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@

import sys


def multiline_str_representer(dumper, data):
if len(data.splitlines()) > 1:
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
else:
return dumper.represent_scalar("tag:yaml.org,2002:str", data)


try:
import yaml
from yaml import YAMLObject as _YAMLObject, add_representer

add_representer(str, multiline_str_representer)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"This tool requires PyYAML but it was not installed. "
f"Recommend: {sys.executable} -m pip install PyYAML"
) from e

__all__ = [
"yaml_dump",
"yaml_dump_all",
"YAMLObject",
]
class _YAMLObject:
pass


class YAMLObject(yaml.YAMLObject):
class YAMLObject(_YAMLObject):
@classmethod
def to_yaml(cls, dumper, self):
"""Default to a custom dictionary mapping."""
Expand All @@ -33,21 +36,34 @@ def as_linalg_yaml(self):
return yaml_dump(self)


def multiline_str_representer(dumper, data):
if len(data.splitlines()) > 1:
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
else:
return dumper.represent_scalar("tag:yaml.org,2002:str", data)
def yaml_dump(data, sort_keys=False, **kwargs):
try:
import yaml

return yaml.dump(data, sort_keys=sort_keys, **kwargs)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"This tool requires PyYAML but it was not installed. "
f"Recommend: {sys.executable} -m pip install PyYAML"
) from e

yaml.add_representer(str, multiline_str_representer)

def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
try:
import yaml

def yaml_dump(data, sort_keys=False, **kwargs):
return yaml.dump(data, sort_keys=sort_keys, **kwargs)
return yaml.dump_all(
data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"This tool requires PyYAML but it was not installed. "
f"Recommend: {sys.executable} -m pip install PyYAML"
) from e


def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
return yaml.dump_all(
data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
)
__all__ = [
"yaml_dump",
"yaml_dump_all",
"YAMLObject",
]
6 changes: 4 additions & 2 deletions mlir/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# BUILD dependencies
nanobind>=2.9, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
typing_extensions>=4.12.2
# RUN dependencies
numpy>=1.19.5, <=2.1.2
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
typing_extensions>=4.12.2