|
48 | 48 | FlaxUNet2DConditionModel,
|
49 | 49 | )
|
50 | 50 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
| 51 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
51 | 52 |
|
52 | 53 |
|
53 | 54 | # To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
|
@@ -145,28 +146,33 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
|
145 | 146 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
146 | 147 | img_str += f"\n"
|
147 | 148 |
|
148 |
| - yaml = f""" |
149 |
| ---- |
150 |
| -license: creativeml-openrail-m |
151 |
| -base_model: {base_model} |
152 |
| -tags: |
153 |
| -- stable-diffusion |
154 |
| -- stable-diffusion-diffusers |
155 |
| -- text-to-image |
156 |
| -- diffusers |
157 |
| -- controlnet |
158 |
| -- jax-diffusers-event |
159 |
| -inference: true |
160 |
| ---- |
161 |
| - """ |
162 |
| - model_card = f""" |
| 149 | + model_description = f""" |
163 | 150 | # controlnet- {repo_id}
|
164 | 151 |
|
165 | 152 | These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n
|
166 | 153 | {img_str}
|
167 | 154 | """
|
168 |
| - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
169 |
| - f.write(yaml + model_card) |
| 155 | + |
| 156 | + model_card = load_or_create_model_card( |
| 157 | + repo_id_or_path=repo_id, |
| 158 | + from_training=True, |
| 159 | + license="creativeml-openrail-m", |
| 160 | + base_model=base_model, |
| 161 | + model_description=model_description, |
| 162 | + inference=True, |
| 163 | + ) |
| 164 | + |
| 165 | + tags = [ |
| 166 | + "stable-diffusion", |
| 167 | + "stable-diffusion-diffusers", |
| 168 | + "text-to-image", |
| 169 | + "diffusers", |
| 170 | + "controlnet", |
| 171 | + "jax-diffusers-event", |
| 172 | + ] |
| 173 | + model_card = populate_model_card(model_card, tags=tags) |
| 174 | + |
| 175 | + model_card.save(os.path.join(repo_folder, "README.md")) |
170 | 176 |
|
171 | 177 |
|
172 | 178 | def parse_args():
|
|
0 commit comments