From bc52fbcd5521df4cbefecf45e9c2d429ef71047c Mon Sep 17 00:00:00 2001
From: haixinxu <33156632+haixinxu@users.noreply.github.com>
Date: Mon, 23 Jan 2023 19:59:08 +0800
Subject: [PATCH 1/4] Update textual_inversion_flax.py

---
 .../textual_inversion/textual_inversion_flax.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 948346c4f74e..eb05bfc65b00 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -136,6 +136,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--lr_scheduler",
         type=str,
@@ -420,9 +427,9 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",revision=args.revision)
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",revision=args.revision)
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",revision=args.revision)
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)

From 62529b26e06c65f04dbb5215461663f494c10d6a Mon Sep 17 00:00:00 2001
From: haixinxu <33156632+haixinxu@users.noreply.github.com>
Date: Mon, 23 Jan 2023 21:37:44 +0800
Subject: [PATCH 2/4] Update textual_inversion_flax.py

---
 examples/textual_inversion/textual_inversion_flax.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index eb05bfc65b00..7f96f245137b 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -121,6 +121,12 @@ def parse_args():
         default=5000,
         help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
     )
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save learned_embeds.bin every X updates steps.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -626,6 +632,12 @@ def compute_loss(params):
 
             if global_step >= args.max_train_steps:
                 break
+            if global_step % arg.save_steps==0:
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+                    placeholder_token_id
+                ]
+                learned_embeds_dict = {args.placeholder_token: learned_embeds}
+                jnp.save(os.path.join(args.output_dir, "learned_embeds-"+str(global_step)+".npy"), learned_embeds_dict)
 
         train_metric = jax_utils.unreplicate(train_metric)
 

From b163481995d2cc3f8677ee2d8bb561b17d3203ab Mon Sep 17 00:00:00 2001
From: haixinxu <33156632+haixinxu@users.noreply.github.com>
Date: Mon, 23 Jan 2023 21:42:11 +0800
Subject: [PATCH 3/4] Typo sorry.

---
 examples/textual_inversion/textual_inversion_flax.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 7f96f245137b..1b3e864285cf 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -632,7 +632,7 @@ def compute_loss(params):
 
             if global_step >= args.max_train_steps:
                 break
-            if global_step % arg.save_steps==0:
+            if global_step % args.save_steps==0:
                 learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
                     placeholder_token_id
                 ]

From be5dd32f430218d0585a2f8d8a85485492334fb4 Mon Sep 17 00:00:00 2001
From: haixinxu <33156632+haixinxu@users.noreply.github.com>
Date: Mon, 23 Jan 2023 22:01:42 +0800
Subject: [PATCH 4/4] Format source

---
 examples/textual_inversion/textual_inversion_flax.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 1b3e864285cf..1a980acb39e0 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -632,7 +632,7 @@ def compute_loss(params):
 
             if global_step >= args.max_train_steps:
                 break
-            if global_step % args.save_steps==0:
+            if global_step % args.save_steps == 0:
                 learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
                     placeholder_token_id
                 ]
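
Note on the checkpoint format introduced by patches 2-4: the new --save_steps branch extracts the single learned embedding row for placeholder_token_id and writes it as a NumPy .npy file containing a pickled {placeholder_token: embedding} dict. Below is a minimal sketch of how such a file could be read back, assuming a run used --save_steps 500 and --output_dir out (both values hypothetical, chosen only for illustration):

    import numpy as np

    # np.save stores a dict as a 0-d object array, so allow_pickle=True
    # and .item() are needed to recover the original dict.
    ckpt = np.load("out/learned_embeds-500.npy", allow_pickle=True).item()
    token, embedding = next(iter(ckpt.items()))
    print(token, embedding.shape)  # one row of the token embedding matrix, e.g. (768,)

Unlike the learned_embeds.bin written at the end of training, these intermediate files are plain NumPy arrays rather than torch checkpoints, so using one with the PyTorch pipelines would first require converting the array to a torch tensor.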