25 changes: 22 additions & 3 deletions examples/textual_inversion/textual_inversion_flax.py
@@ -121,6 +121,12 @@ def parse_args():
default=5000,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--save_steps",
type=int,
default=500,
help="Save learned_embeds.bin every X updates steps.",
)
parser.add_argument(
"--learning_rate",
type=float,
@@ -136,6 +142,13 @@ def parse_args():
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
@@ -420,9 +433,9 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)

# Create sampling rng
rng = jax.random.PRNGKey(args.seed)
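
For context, `revision` selects a specific branch, tag, or commit hash of the model repo on the Hub. A minimal sketch of what the new flag enables (the repo id and the `bf16` branch below are illustrative assumptions, not part of this PR):

```python
from diffusers import FlaxUNet2DConditionModel

# Illustrative only: pull the UNet weights from a specific Hub revision.
# The repo id and the "bf16" branch are assumptions for this sketch; any
# branch name, tag, or commit hash can be passed as `revision`.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
    revision="bf16",
)
```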
@@ -619,6 +632,12 @@ def compute_loss(params):

if global_step >= args.max_train_steps:
break
# Periodically save the learned embedding for the placeholder token
if global_step % args.save_steps == 0:
    learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
        placeholder_token_id
    ]
    learned_embeds_dict = {args.placeholder_token: learned_embeds}
    jnp.save(os.path.join(args.output_dir, f"learned_embeds-{global_step}.npy"), learned_embeds_dict)

train_metric = jax_utils.unreplicate(train_metric)
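
Since `jnp.save` pickles the `{placeholder_token: embedding}` dict, the checkpoints written above can be read back with plain NumPy. A minimal sketch (not part of this PR; the output path is hypothetical):

```python
import numpy as np

# np.save pickles non-array objects into a 0-d object array, so reading
# the dict back requires allow_pickle=True followed by .item().
learned = np.load("output/learned_embeds-500.npy", allow_pickle=True).item()
for token, embedding in learned.items():
    print(token, embedding.shape)
```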
