[Metadata utils] fix: json lines ordering. #7744

Merged: 1 commit, Apr 23, 2024
utils/update_metadata.py: 25 changes (22 additions, 3 deletions)
@@ -30,7 +30,7 @@

 import pandas as pd
 from datasets import Dataset
-from huggingface_hub import upload_folder
+from huggingface_hub import hf_hub_download, upload_folder

 from diffusers.pipelines.auto_pipeline import (
     AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
@@ -39,6 +39,9 @@
 )


+PIPELINE_TAG_JSON = "pipeline_tags.json"
+
+
 def get_supported_pipeline_table() -> dict:
     """
     Generates a dictionary containing the supported auto classes for each pipeline type,
@@ -57,8 +60,8 @@ def get_supported_pipeline_table() -> dict:
         (class_name.__name__, "image-to-image", "AutoPipelineForInpainting")
         for _, class_name in AUTO_INPAINT_PIPELINES_MAPPING.items()
     ]
-    all_supported_pipeline_classes.sort(key=lambda x: x[0])
     all_supported_pipeline_classes = list(set(all_supported_pipeline_classes))
+    all_supported_pipeline_classes.sort(key=lambda x: x[0])

     data = {}
     data["pipeline_class"] = [sample[0] for sample in all_supported_pipeline_classes]
@@ -79,8 +82,24 @@ def update_metadata(commit_sha: str):
     pipelines_table = pd.DataFrame(pipelines_table)
     pipelines_dataset = Dataset.from_pandas(pipelines_table)

+    hub_pipeline_tags_json = hf_hub_download(
+        repo_id="huggingface/diffusers-metadata",
+        filename=PIPELINE_TAG_JSON,
+        repo_type="dataset",
+    )
+    with open(hub_pipeline_tags_json) as f:
+        hub_pipeline_tags_json = f.read()
+
     with tempfile.TemporaryDirectory() as tmp_dir:
-        pipelines_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
+        pipelines_dataset.to_json(os.path.join(tmp_dir, PIPELINE_TAG_JSON))
+
+        with open(os.path.join(tmp_dir, PIPELINE_TAG_JSON)) as f:
+            pipeline_tags_json = f.read()
+
+        hub_pipeline_tags_equal = hub_pipeline_tags_json == pipeline_tags_json
+        if hub_pipeline_tags_equal:
+            print("No updates, not pushing the metadata files.")
+            return

         if commit_sha is not None:
             commit_message = (
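
The other half of the change avoids no-op pushes: it downloads the pipeline_tags.json currently published in the huggingface/diffusers-metadata dataset repo, regenerates the file locally, and returns early when the raw JSON Lines text is identical. A standalone sketch of that pattern, assuming only what the diff shows; the push_if_changed wrapper and its arguments are hypothetical:

# Sketch of the "download, regenerate, compare, push only on change" pattern.
# Repo id and file name come from the diff; the function itself is hypothetical.
import os
import tempfile

from datasets import Dataset
from huggingface_hub import hf_hub_download, upload_folder

PIPELINE_TAG_JSON = "pipeline_tags.json"
METADATA_REPO_ID = "huggingface/diffusers-metadata"


def push_if_changed(dataset: Dataset, commit_message: str) -> None:
    # Raw text of the JSON Lines file currently on the Hub.
    hub_path = hf_hub_download(
        repo_id=METADATA_REPO_ID, filename=PIPELINE_TAG_JSON, repo_type="dataset"
    )
    with open(hub_path) as f:
        hub_json = f.read()

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Dataset.to_json writes one JSON object per line (JSON Lines), so a
        # deterministic row order is what makes this string comparison reliable.
        local_path = os.path.join(tmp_dir, PIPELINE_TAG_JSON)
        dataset.to_json(local_path)
        with open(local_path) as f:
            local_json = f.read()

        if hub_json == local_json:
            print("No updates, not pushing the metadata files.")
            return

        # Requires write access to the target repo (e.g. an HF token in the environment).
        upload_folder(
            repo_id=METADATA_REPO_ID,
            folder_path=tmp_dir,
            repo_type="dataset",
            commit_message=commit_message,
        )

A caller would build the Dataset from the pipeline table and pass a commit message, mirroring what update_metadata does in the diff.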