Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def save_model_card(
)

model_description = f"""
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}

<Gallery />

Expand All @@ -139,7 +139,7 @@ def save_model_card(
[Download]({repo_id}/tree/main) them in the Files & versions tab.

"""
if "playgroundai" in args.pretrained_model_name_or_path:
if "playground" in base_model:
model_description += """\n
## License

Expand All @@ -148,7 +148,7 @@ def save_model_card(
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
Expand All @@ -162,7 +162,7 @@ def save_model_card(
"lora" if not use_dora else "dora",
"template:sd-lora",
]
if "playgroundai" in base_model:
if "playground" in base_model:
tags.extend(["playground", "playground-diffusers"])
else:
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
Expand Down Expand Up @@ -206,7 +206,7 @@ def log_validation(
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)

with inference_ctx:
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if accelerator.is_main_process:
tracker_name = (
"dreambooth-lora-sd-xl"
if "playgroundai" not in args.pretrained_model_name_or_path
if "playground" not in args.pretrained_model_name_or_path
else "dreambooth-lora-playground"
)
accelerator.init_trackers(tracker_name, config=vars(args))
Expand Down