Conversation

duongna21
Contributor

What does this PR do?

Add Flax example for DreamBooth.

How to run (74% faster than the PyTorch example with the same args on a Tesla A100)

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --num_class_images=200 \
  --max_train_steps=800

Prompt: a photo of sks dog

[image]

Who can review?

cc @patil-suraj @patrickvonplaten

@patil-suraj
Contributor

You are on fire @duongna21 ! 🔥

@patil-suraj patil-suraj self-assigned this Oct 26, 2022
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 26, 2022

The documentation is not available anymore as the PR was closed or merged.

Contributor

@patil-suraj patil-suraj left a comment

Looks very good, amazing work! Just left some comments.

Let's make sure that seed is not None, as PRNGKey will break. Also, let's update the README to show how to run this example. Then this should be good to merge :)
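
For readers following along, here is a minimal sketch of one way to guard against a missing seed. `jax.random.PRNGKey` expects an integer seed, so the helper below (the name `make_rng` is hypothetical and not taken from the PR) substitutes a generated seed when none is supplied. Giving the `--seed` argument an integer default would be an equally reasonable fix.

```python
import secrets
from typing import Optional

import jax


def make_rng(seed: Optional[int]):
    """Return a JAX PRNG key, choosing a seed when the caller gave none."""
    if seed is None:
        # PRNGKey needs an integer, so pick one explicitly (hypothetical fallback).
        seed = secrets.randbits(31)
    return jax.random.PRNGKey(seed)


rng = make_rng(None)       # works even though no seed was provided
rng_fixed = make_rng(0)    # reproducible key for a fixed seed
```
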

Comment on lines +383 to +398
for example in tqdm(
    sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
):
    prompt_ids = pipeline.prepare_inputs(example["prompt"])
    prompt_ids = shard(prompt_ids)
    p_params = jax_utils.replicate(params)
    rng = jax.random.split(rng)[0]
    sample_rng = jax.random.split(rng, jax.device_count())
    images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    images = pipeline.numpy_to_pil(np.array(images))

    for i, image in enumerate(images):
        hash_image = hashlib.sha1(image.tobytes()).hexdigest()
        image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
        image.save(image_filename)
Contributor

Very cool!
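
The reshape in the hunk above collapses the leading device and per-device batch axes into a single batch axis before the images are converted to PIL. A small sketch with NumPy, using assumed sizes (4 devices, 2 images per device, 64x64 RGB) rather than the values from the PR:

```python
import numpy as np

# Assumed sizes for illustration only.
num_devices, per_device_batch = 4, 2
images = np.zeros((num_devices, per_device_batch, 64, 64, 3), dtype=np.float32)

# Same expression as in the hunk: merge the first two axes, keep (H, W, C).
flat = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
print(flat.shape)  # (8, 64, 64, 3)
```
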

@patil-suraj
Contributor

patil-suraj commented Oct 27, 2022

(For a future PR, maybe also add an option to train the text_encoder; this has been found to improve results significantly.)

@duongna21
Contributor Author

(For a future PR, maybe also add an option to train the text_encoder; this has been found to improve results significantly.)

@patil-suraj Actually, train_text_encoder is already supported in this PR. Please check it out.
Also, thank you for the other very helpful comments. I have addressed them!

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot @duongna21 !

Also, it seems like text_encoder is always trained, no? I think we should add an option called --train_text_encoder and train it only when it's True, because training the text_encoder is not always needed.
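
A minimal sketch of the kind of opt-in flag being described, using only `argparse`. The helper `trainable_modules` and the module names are hypothetical illustrations, not code from the PR; the actual script may gate training differently.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical subset of the training script's arguments.
    parser = argparse.ArgumentParser(description="DreamBooth flag sketch")
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Also fine-tune the text encoder (off by default).",
    )
    return parser.parse_args(argv)


def trainable_modules(args):
    """Names of the modules whose parameters would receive updates."""
    modules = ["unet"]
    if args.train_text_encoder:
        modules.append("text_encoder")
    return modules


print(trainable_modules(parse_args([])))                        # ['unet']
print(trainable_modules(parse_args(["--train_text_encoder"])))  # ['unet', 'text_encoder']
```
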

Comment on lines +447 to +451
weight_dtype = jnp.float32
if args.mixed_precision == "fp16":
    weight_dtype = jnp.float16
elif args.mixed_precision == "bf16":
    weight_dtype = jnp.bfloat16
Contributor

very cool!

Comment on lines +61 to +75
Or use the Flax implementation if you need a speedup

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
```
Contributor

Thanks for adding this!

@duongna21
Contributor Author

Also, it seems like text_encoder is always trained?

@patil-suraj No. You can search the script (Ctrl+F) for "if args.train_text_encoder" and check whether anything is wrong.

@patil-suraj
Contributor

Ahh, sorry. I missed that, all looks good now, thanks a lot for working on this!

@patil-suraj patil-suraj merged commit 90f91ad into huggingface:main Oct 27, 2022
@duongna21 duongna21 deleted the add-dreambooth-flax branch October 27, 2022 14:19
@patrickvonplaten
Contributor

Great work here!

weight_dtype = jnp.bfloat16

# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
Contributor

Hi, why is this version pulled from the hub instead of using the one in the text_encoder subfolder of args.pretrained_model_name_or_path?

Is there something wrong with the other version?

Contributor Author

@duongna21 duongna21 Nov 3, 2022

@skirsten Nice question. Look at this please.

Contributor

@patrickvonplaten has the PR in transformers that allows loading Flax CLIP with a subfolder been merged?

Contributor

@douwekiela Awesome! Thanks for explaining it

Member

@patil-suraj yes, it's merged

Contributor

Awesome, @duongna21 would you like to open a PR to update this then? :) We will also need to update the installation instructions for transformers to include that fix.
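
The follow-up discussed in this thread could look roughly like the sketch below, which loads the text encoder from the `text_encoder` subfolder of the checkpoint instead of from "openai/clip-vit-large-patch14". It assumes a transformers release that includes the subfolder fix mentioned above; `load_text_encoder` is a hypothetical helper name, and the function is only defined here, not called, because calling it downloads model weights.

```python
import jax.numpy as jnp


def load_text_encoder(pretrained_model_name_or_path: str, dtype=jnp.float32):
    """Load the Flax CLIP text encoder stored alongside the diffusion checkpoint."""
    # Imported lazily so that defining the helper does not require transformers.
    from transformers import FlaxCLIPTextModel

    return FlaxCLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",  # assumes a transformers version with Flax subfolder support
        dtype=dtype,
    )
```

In the training script this would presumably be called as `load_text_encoder(args.pretrained_model_name_or_path, weight_dtype)`.
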
