Skip to content

Commit

Permalink
Add tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
freddyaboulton committed Aug 8, 2023
1 parent 0dc49b4 commit ddba531
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
44 changes: 31 additions & 13 deletions gradio/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ def setup(self, components: list[IOComponent], flagging_dir: str):
repo_type="dataset",
exist_ok=True,
).repo_id
path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
huggingface_hub.metadata_update(
repo_id=self.dataset_id,
repo_type="dataset",
metadata={
"configs": [
{
"config_name": "default",
"data_files": [{"split": "train", "path": path_glob}],
}
]
},
overwrite=True,
)

# Setup flagging dir
self.components = components
Expand Down Expand Up @@ -284,7 +298,7 @@ def flag(
if self.separate_dirs:
# JSONL files to support dataset preview on the Hub
unique_id = str(uuid.uuid4())
components_dir = self.dataset_dir / str(uuid.uuid4())
components_dir = self.dataset_dir / unique_id
data_file = components_dir / "metadata.jsonl"
path_in_repo = unique_id # upload in sub folder (safer for concurrency)
else:
Expand Down Expand Up @@ -416,28 +430,32 @@ def _deserialize_components(
features[label] = {"dtype": "string", "_type": "Value"}
try:
assert Path(deserialized).exists()
row.append(Path(deserialized).name)
row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
except (AssertionError, TypeError, ValueError):
row.append(str(deserialized))

# If component is eligible for a preview, add the URL of the file
# Be mindful that images and audio can be None
if isinstance(component, tuple(file_preview_types)): # type: ignore
for _component, _type in file_preview_types.items():
if isinstance(component, _component):
features[label + " file"] = {"_type": _type}
break
path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
Path(deserialized).relative_to(self.dataset_dir)
).replace(
"\\", "/"
)
row.append(
huggingface_hub.hf_hub_url(
repo_id=self.dataset_id,
filename=path_in_repo,
repo_type="dataset",
if deserialized:
path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
Path(deserialized).relative_to(self.dataset_dir)
).replace(
"\\", "/"
)
)
row.append(
huggingface_hub.hf_hub_url(
repo_id=self.dataset_id,
filename=path_in_repo,
repo_type="dataset",
)
)
else:
row.append("")
features["flag"] = {"dtype": "string", "_type": "Value"}
features["username"] = {"dtype": "string", "_type": "Value"}
row.append(flag_option)
Expand Down
9 changes: 6 additions & 3 deletions test/test_flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class TestHuggingFaceDatasetSaver:
return_value=MagicMock(repo_id="gradio-tests/test"),
)
@patch("huggingface_hub.hf_hub_download")
def test_saver_setup(self, mock_download, mock_create):
@patch("huggingface_hub.metadata_update")
def test_saver_setup(self, metadata_update, mock_download, mock_create):
flagger = flagging.HuggingFaceDatasetSaver("test_token", "test")
with tempfile.TemporaryDirectory() as tmpdirname:
flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
Expand All @@ -60,8 +61,9 @@ def test_saver_setup(self, mock_download, mock_create):
@patch("huggingface_hub.hf_hub_download")
@patch("huggingface_hub.upload_folder")
@patch("huggingface_hub.upload_file")
@patch("huggingface_hub.metadata_update")
def test_saver_flag_same_dir(
self, mock_upload_file, mock_upload, mock_download, mock_create
self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
Expand Down Expand Up @@ -89,8 +91,9 @@ def test_saver_flag_same_dir(
@patch("huggingface_hub.hf_hub_download")
@patch("huggingface_hub.upload_folder")
@patch("huggingface_hub.upload_file")
@patch("huggingface_hub.metadata_update")
def test_saver_flag_separate_dirs(
self, mock_upload_file, mock_upload, mock_download, mock_create
self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
Expand Down

0 comments on commit ddba531

Please sign in to comment.