From f75828eb35a86099b489549f05c9c21241471eb4 Mon Sep 17 00:00:00 2001 From: cosmo3769 Date: Sun, 11 Feb 2024 22:57:48 +0530 Subject: [PATCH] standardize model card template t2i-sdxl --- .../text_to_image/train_text_to_image_sdxl.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 0f408e742cb4..8a5948350d50 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -48,6 +48,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -77,29 +78,33 @@ def save_model_card( image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f"![img_{i}](./image_{i}.png)\n" - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {base_model} -dataset: {dataset_name} -tags: -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- text-to-image -- diffusers -inference: true ---- - """ - model_card = f""" + model_description = f""" # Text-to-image finetuning - {repo_id} -This pipeline was finetuned from **{base_model}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n +This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n {img_str} Special VAE used for training: {vae_path}. """ - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion-xl", + "stable-diffusion-xl-diffusers", + "text-to-image", + "diffusers", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) def import_model_class_from_model_name_or_path(