diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6e920d1a228c..96d7c3168db5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -114,7 +114,7 @@ def save_model_card( ) model_description = f""" -# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} +# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id} @@ -139,7 +139,7 @@ def save_model_card( [Download]({repo_id}/tree/main) them in the Files & versions tab. """ - if "playgroundai" in args.pretrained_model_name_or_path: + if "playground" in base_model: model_description += """\n ## License @@ -148,7 +148,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community", + license="openrail++" if "playground" not in base_model else "playground-v2dot5-community", base_model=base_model, prompt=instance_prompt, model_description=model_description, @@ -162,7 +162,7 @@ def save_model_card( "lora" if not use_dora else "dora", "template:sd-lora", ] - if "playgroundai" in base_model: + if "playground" in base_model: tags.extend(["playground", "playground-diffusers"]) else: tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"]) @@ -206,7 +206,7 @@ def log_validation( # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 inference_ctx = ( - contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast() + contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast() ) with inference_ctx: @@ -1509,7 +1509,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if accelerator.is_main_process: tracker_name = ( "dreambooth-lora-sd-xl" - if "playgroundai" not in args.pretrained_model_name_or_path + if "playground" not in args.pretrained_model_name_or_path else "dreambooth-lora-playground" ) accelerator.init_trackers(tracker_name, config=vars(args))