Skip to content

Commit

Permalink
Improve save and load repocard metadata (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
elishowk committed Sep 27, 2021
1 parent dea8c4a commit b26da2a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_version() -> str:
"filelock",
"requests",
"tqdm",
"pyyaml",
"ruamel.yaml==0.17.16",
"typing-extensions",
"importlib_metadata;python_version<'3.8'",
"packaging>=20.9",
Expand Down
57 changes: 44 additions & 13 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import io
import os
import re
from pathlib import Path
from typing import Dict, Optional, Union

import yaml
from ruamel.yaml import YAML


# the default loader/dumper type is 'rt' (round-trip), which preserves existing YAML formatting
# 'rt' derives from the safe loader/dumper
yaml = YAML()

# exact same regex as in the Hub server. Please keep in sync.
REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r]")

Expand All @@ -14,7 +20,7 @@ def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
match = REGEX_YAML_BLOCK.search(content)
if match:
yaml_block = match.group(1)
data = yaml.safe_load(yaml_block)
data = yaml.load(yaml_block)
if isinstance(data, dict):
return data
else:
Expand All @@ -24,15 +30,40 @@ def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:


def metadata_save(local_path: Union[str, Path], data: Dict) -> None:
    """
    Save the metadata dict in the YAML block at the top of a file.

    If the file already contains a YAML block (as matched by
    ``REGEX_YAML_BLOCK``), that block is replaced; otherwise a new block is
    prepended to the existing content. The file is created if it does not
    exist.

    The line break used by the existing file is detected and reused for the
    written block. Files are opened with ``newline=""`` so that line endings
    are neither translated on read nor on write, see:
    https://docs.python.org/3/library/functions.html#open

    NOTE(review): an earlier version of this docstring stated that bare
    "^M" (``\\r``) line breaks are not supported -- not re-verified here.

    Args:
        local_path: path of the file to update or create.
        data: metadata to serialize into the YAML block.
    """
    line_break = "\n"
    content = ""
    # Try to detect the newline convention of the existing file.
    if os.path.exists(local_path):
        with open(local_path, "r", newline="", encoding="utf-8") as readme:
            content = readme.read()
            # `newlines` is only populated once data has been read, so it
            # must be inspected after read(), not before.
            detected = readme.newlines
        if isinstance(detected, tuple):
            # Mixed line endings: pick the first one reported.
            line_break = detected[0]
        elif isinstance(detected, str):
            line_break = detected

    # Build the complete output *before* opening the target for writing, so a
    # serialization failure cannot leave a truncated file behind.
    with io.StringIO() as stream:
        yaml.dump(data, stream)
        data_yaml = stream.getvalue()
    if line_break != "\n":
        # The dumper emits "\n"; align it with the file's own convention.
        data_yaml = data_yaml.replace("\n", line_break)

    yaml_block = f"---{line_break}{data_yaml}---{line_break}"
    match = REGEX_YAML_BLOCK.search(content)
    if match:
        output = content[: match.start()] + yaml_block + content[match.end() :]
    else:
        output = yaml_block + content

    # newline="" writes the line breaks exactly as assembled above.
    with open(local_path, "w", newline="", encoding="utf-8") as readme:
        readme.write(output)
49 changes: 49 additions & 0 deletions tests/test_repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@
from .testing_utils import set_write_permission_and_retry


# Model card used by test_metadata_roundtrip: its text is expected to be
# unchanged after a metadata_load / metadata_save cycle. The test expects
# `language: no` to load as the string "no" and the non-ASCII widget text
# to be preserved.
ROUND_TRIP_MODELCARD_CASE = """
---
language: no
datasets: CLUECorpusSmall
widget:
- text: 北京是[MASK]国的首都。
---
# Title
"""

DUMMY_MODELCARD = """
Hi
Expand Down Expand Up @@ -53,6 +64,11 @@
Hello
"""

# Expected file content after test_metadata_save_new_file calls
# metadata_save with {"meaning_of_life": 42}: the YAML block only.
DUMMY_NEW_MODELCARD_TARGET = """---
meaning_of_life: 42
---
"""

# Card body without any YAML block; test_no_metadata_returns_none expects
# metadata_load to return None for a file holding this text.
DUMMY_MODELCARD_TARGET_NO_TAGS = """
Hello
"""
Expand Down Expand Up @@ -94,9 +110,42 @@ def test_metadata_save_from_file_no_yaml(self):
content = filepath.read_text()
self.assertEqual(content, DUMMY_MODELCARD_TARGET_NO_YAML)

def test_metadata_save_new_file(self):
    """metadata_save on a path this test never writes beforehand yields only the YAML block."""
    target = Path(REPOCARD_DIR) / "new_dummy_target.md"
    metadata_save(target, {"meaning_of_life": 42})
    self.assertEqual(target.read_text(), DUMMY_NEW_MODELCARD_TARGET)

def test_no_metadata_returns_none(self):
    """metadata_load returns None for a card that has no YAML block."""
    card_path = Path(REPOCARD_DIR) / "dummy_target_3.md"
    card_path.write_text(DUMMY_MODELCARD_TARGET_NO_TAGS)
    loaded = metadata_load(card_path)
    self.assertEqual(loaded, None)

def test_metadata_roundtrip(self):
    """Load, save and reload a card: the metadata and the file text must not change."""
    expected_metadata = {
        "language": "no",
        "datasets": "CLUECorpusSmall",
        "widget": [{"text": "北京是[MASK]国的首都。"}],
    }
    card_path = Path(REPOCARD_DIR) / "dummy_target.md"
    card_path.write_text(ROUND_TRIP_MODELCARD_CASE)

    # First load: the parsed metadata matches the expected mapping.
    loaded = metadata_load(card_path)
    self.assertDictEqual(expected_metadata, loaded)

    # Saving what was just loaded must leave the file text untouched.
    metadata_save(card_path, loaded)
    self.assertEqual(card_path.read_text(), ROUND_TRIP_MODELCARD_CASE)

    # Second load: the metadata is still the same after the save.
    reloaded = metadata_load(card_path)
    self.assertDictEqual(expected_metadata, reloaded)

0 comments on commit b26da2a

Please sign in to comment.