[Flax] Add DreamBooth #1001
Conversation
You are on fire @duongna21! 🔥
Looks very good, amazing work! Just left some comments.

Let's make sure that `seed` is not `None`, as `PRNGKey` will break otherwise. Also let's update the README to show how to run this example. Then this should be good to merge :)
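The point about `seed` can be sketched as a small guard: `jax.random.PRNGKey` requires an integer, so passing `None` raises. A minimal, hypothetical sketch (the `make_rng` helper and the default seed of `0` are assumptions, not part of the PR):

```python
import jax

# Hypothetical guard: jax.random.PRNGKey needs an integer seed, so
# PRNGKey(None) would raise; fall back to a fixed default when no
# --seed argument is given.
def make_rng(seed):
    return jax.random.PRNGKey(seed if seed is not None else 0)
```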
```python
for example in tqdm(
    sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
):
    prompt_ids = pipeline.prepare_inputs(example["prompt"])
    prompt_ids = shard(prompt_ids)
    p_params = jax_utils.replicate(params)
    rng = jax.random.split(rng)[0]
    sample_rng = jax.random.split(rng, jax.device_count())
    images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    images = pipeline.numpy_to_pil(np.array(images))

    for i, image in enumerate(images):
        hash_image = hashlib.sha1(image.tobytes()).hexdigest()
        image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
        image.save(image_filename)
```
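The `reshape` in that snippet collapses the per-device leading axes of the `pmap` output back into one flat batch. A minimal sketch with a hypothetical shape (8 devices, 2 images per device), using NumPy only:

```python
import numpy as np

# pmap-style sampling returns images shaped (num_devices, per_device_batch, H, W, C);
# collapse the two leading axes into a single flat batch dimension.
images = np.zeros((8, 2, 64, 64, 3))  # hypothetical: 8 devices, 2 images each
flat = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
print(flat.shape)  # (16, 64, 64, 3)
```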
Very cool!
(For a future PR, maybe also enable the option to allow training the `text_encoder`.)

@patil-suraj Actually, `train_text_encoder` has already been added in this PR. Please check it out.
Thanks a lot @duongna21!

Also, it seems like `text_encoder` is always trained, no? I think we should add an option called `--train_text_encoder` and train it only when it's `True`, because training the `text_encoder` is not always needed.
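The suggested opt-in flag could look like the following sketch. The placeholder parameter dicts and the `trainable` variable are hypothetical stand-ins for the real Flax parameter pytrees, not code from the PR:

```python
import argparse

# Hypothetical sketch of an opt-in flag: the text encoder is fine-tuned
# only when --train_text_encoder is passed on the command line.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_text_encoder",
    action="store_true",
    help="Also fine-tune the CLIP text encoder (off by default).",
)
args = parser.parse_args(["--train_text_encoder"])

# Placeholder pytrees standing in for the real Flax model parameters.
unet_params = {"w": 0.0}
text_encoder_params = {"w": 0.0}

# Only the selected modules' parameters enter the optimizer state;
# anything left out stays frozen.
trainable = {"unet": unet_params}
if args.train_text_encoder:
    trainable["text_encoder"] = text_encoder_params
```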
```python
weight_dtype = jnp.float32
if args.mixed_precision == "fp16":
    weight_dtype = jnp.float16
elif args.mixed_precision == "bf16":
    weight_dtype = jnp.bfloat16
```
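The same selection could equivalently be written as a dictionary lookup. A minimal sketch (the `resolve_weight_dtype` helper is hypothetical, not part of the PR):

```python
import jax.numpy as jnp

# Map the --mixed_precision flag value to a JAX dtype;
# anything unrecognized falls back to full precision.
DTYPE_MAP = {"fp16": jnp.float16, "bf16": jnp.bfloat16}

def resolve_weight_dtype(mixed_precision):
    return DTYPE_MAP.get(mixed_precision, jnp.float32)
```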
Very cool!
Or use the Flax implementation if you need a speedup:

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
```
Thanks for adding this!
@patil-suraj No. You can Ctrl+F for `train_text_encoder`.
Ahh, sorry, I missed that. All looks good now, thanks a lot for working on this!

Great work here!
```python
    weight_dtype = jnp.bfloat16

# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
```
Hi, why is this version pulled from the Hub instead of just using the one in the `text_encoder` subfolder of `args.pretrained_model_name_or_path`? Is there something wrong with the other version?
@patrickvonplaten is the PR in `transformers` merged that allows loading Flax CLIP with `subfolder`?
@douwekiela Awesome! Thanks for explaining it
@patil-suraj yes, it's merged
Awesome, @duongna21 would you like to open a PR to update this then? :) We will also need to update the installation instructions for `transformers` to include that fix.
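With that `transformers` fix available, the hard-coded Hub checkpoint could be replaced by the checkpoint's own text encoder. A hypothetical sketch, not the PR's actual code (it assumes `subfolder` support in `FlaxCLIPTextModel.from_pretrained` and that `weight_dtype` is defined as above):

```python
from transformers import FlaxCLIPTextModel

def load_text_encoder(pretrained_model_name_or_path, weight_dtype):
    # Load the text encoder bundled with the Stable Diffusion checkpoint
    # (its `text_encoder` subfolder) instead of the hard-coded
    # openai/clip-vit-large-patch14 Hub model.
    return FlaxCLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
    )
```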
What does this PR do?
Add Flax example for DreamBooth.
How to run (74% faster than the PyTorch example with the same args on a Tesla A100)
Prompt:
a photo of sks dog
Who can review?
cc @patil-suraj @patrickvonplaten