Skip to content

Commit a83cc0c

Browse files
bamps53sayakpaul
andauthored
Standardize model card for Controlnet flax (#6909)
* controlnet-flax * style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent db5194a commit a83cc0c

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

examples/controlnet/train_controlnet_flax.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
FlaxUNet2DConditionModel,
4949
)
5050
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
51+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5152

5253

5354
# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
@@ -145,28 +146,33 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
145146
make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
146147
img_str += f"![images_{i})](./images_{i}.png)\n"
147148

148-
yaml = f"""
149-
---
150-
license: creativeml-openrail-m
151-
base_model: {base_model}
152-
tags:
153-
- stable-diffusion
154-
- stable-diffusion-diffusers
155-
- text-to-image
156-
- diffusers
157-
- controlnet
158-
- jax-diffusers-event
159-
inference: true
160-
---
161-
"""
162-
model_card = f"""
149+
model_description = f"""
163150
# controlnet- {repo_id}
164151
165152
These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n
166153
{img_str}
167154
"""
168-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
169-
f.write(yaml + model_card)
155+
156+
model_card = load_or_create_model_card(
157+
repo_id_or_path=repo_id,
158+
from_training=True,
159+
license="creativeml-openrail-m",
160+
base_model=base_model,
161+
model_description=model_description,
162+
inference=True,
163+
)
164+
165+
tags = [
166+
"stable-diffusion",
167+
"stable-diffusion-diffusers",
168+
"text-to-image",
169+
"diffusers",
170+
"controlnet",
171+
"jax-diffusers-event",
172+
]
173+
model_card = populate_model_card(model_card, tags=tags)
174+
175+
model_card.save(os.path.join(repo_folder, "README.md"))
170176

171177

172178
def parse_args():

0 commit comments

Comments
 (0)