From 1f18854ebd2b2906caa38a719b5c562dc33addb2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 6 Mar 2023 16:56:49 +0000 Subject: [PATCH] Support revision in Flax text-to-image training. --- .../text_to_image/train_text_to_image_flax.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index ddcc8dcc4c07..8655634dfc34 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -48,6 +48,13 @@ def parse_args(): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) parser.add_argument( "--dataset_name", type=str, @@ -386,15 +393,17 @@ def collate_fn(examples): weight_dtype = jnp.bfloat16 # Load models and create wrapper for stable diffusion - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer" + ) text_encoder = FlaxCLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype + args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype + args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype ) unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype + args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype ) # Optimization