Add latent caching for Dreambooth Flax. #1936

Closed
Wants to merge 1 commit.

Conversation

yasyf
Contributor

@yasyf yasyf commented Jan 6, 2023

If enabled, pre-computes the latents from the VAE to avoid redoing the calculation every epoch. Results in a ~30% speedup on average for us.

Note that the logic of DreamBoothDataset wasn't actually changed; I just rewrote it to be a little easier to follow.
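For reference, a minimal sketch of the kind of pre-computation this adds (illustrative only, not the PR's actual diff; the model name, dataloader, and batch layout below are assumptions that mirror the existing Flax training scripts):

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

# Load only the VAE; it is no longer needed once the latents are cached.
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", dtype=jnp.float32
)

@jax.jit
def encode_to_latents(pixel_values, rng):
    # pixel_values: image batch as produced by the existing collate_fn (NCHW, scaled to [-1, 1]).
    outputs = vae.apply({"params": vae_params}, pixel_values, deterministic=True, method=vae.encode)
    latents = outputs.latent_dist.sample(rng)
    latents = jnp.transpose(latents, (0, 3, 1, 2))  # NHWC latents -> NCHW for the UNet
    return latents * 0.18215  # SD v1 VAE scaling factor

# Run once before training; every epoch then reads from latent_cache instead of calling the VAE.
# Note this fixes the per-image latent sample across epochs; one could cache the
# distribution parameters instead if fresh samples per epoch matter.
rng = jax.random.PRNGKey(0)
latent_cache = []
for batch in train_dataloader:  # hypothetical dataloader yielding {"pixel_values": ...}
    rng, sample_rng = jax.random.split(rng)
    latent_cache.append(jax.device_get(encode_to_latents(batch["pixel_values"], sample_rng)))
```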

@yasyf
Contributor Author

yasyf commented Jan 6, 2023

cc @patrickvonplaten @patil-suraj

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@patrickvonplaten
Contributor

patrickvonplaten commented Jan 10, 2023

A 30% speed-up would really surprise me, since a compiled VAE should be quite fast. But you're definitely right that there is potential for a speed-up at the expense of more memory.

In general, I'm fine with adding such a PR. @patil-suraj @pcuenca what do you think?

@yasyf
Contributor Author

yasyf commented Jan 12, 2023

@patrickvonplaten it's very possible I was/am doing something else wrong, but on a v4 TPU fine-tuning with 4 images (with 300 class images) for 600 steps, adding this flag takes the time from ~5 mins to ~3.2 mins.

@patil-suraj
Contributor

Thanks a lot for the PR. I'm also surprised that this could give a 30% speed-up. Also, the goal of the example scripts is to be simple and easy to follow, and to let users customize them to their needs. It's important to keep the training loop and loss function clean and uncluttered, so I'm not in favour of adding caching to the training scripts.

@zetyquickly
Contributor

Why don't we add latent caching to train_dreambooth.py and train_dreambooth_lora.py too?

@patil-suraj
Contributor

@zetyquickly Good question!

The goal of the example scripts is to be simple and easy to read, so any user can go through them and adapt them to their needs. They are not intended to provide every feature out of the box; that would make the scripts more complex and harder to understand and modify. Hope this makes sense.

@zetyquickly
Contributor

Agreed, less is better. But since @yasyf has already started adding this, why not copy it across all the scripts?

The LoRA version used the original script as its starting point, and its input options have already drifted out of sync with the Flax version. In the future there might be a fourth script that inherits from one of these three, and the discrepancy will only grow.

@zetyquickly
Contributor

zetyquickly commented Jan 27, 2023

P.S. I only came here because I was trying to find out why train_dreambooth doesn't have latent caching, and ended up learning that the other script will get it after this PR, but not the one I'm interested in :)

@patil-suraj
Contributor

We could add this as a new script under the examples/research_projects directory. @yasyf it would be nice if you could move the files there.

@krahnikblis

I've been tinkering with this. It works well to use the VAE and tokenizer only during data prep, then destroy them and ignore them entirely in the train step, which saves both memory and processing time. In other words, just as the Dreambooth script builds and destroys the pipeline to create the class images, do the same for the latents: build the VAE, convert all images to latents, then destroy the VAE/params and the source images. Roughly 24 VAE latent outputs fit in the space of a single 512x512 image, so the cache is small enough that you don't need a data loader to feed the model; just build a static dict of latents and captions and sample from it over the training loop. Once you need more than a hundred or so images a loader starts to make sense, but most VMs have plenty of RAM, so you can still keep everything in host RAM and device_put/replicate each batch as you sample it in the train loop.
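A rough sketch of that flow (illustrative names, not code from this PR): encode everything once, drop the VAE, keep the cache in host RAM, and move only the sampled batch to the device.

```python
import random
import numpy as np
import jax
import jax.numpy as jnp

def build_latent_cache(vae, vae_params, images, captions, rng):
    # images: list of (C, H, W) arrays in [-1, 1]; encoded one at a time to keep peak memory low.
    cache = []
    for image, caption in zip(images, captions):
        rng, sample_rng = jax.random.split(rng)
        out = vae.apply({"params": vae_params}, image[None], deterministic=True, method=vae.encode)
        latents = jnp.transpose(out.latent_dist.sample(sample_rng), (0, 3, 1, 2)) * 0.18215
        cache.append({"latents": jax.device_get(latents[0]), "caption": caption})
    return cache  # lives in host RAM; vae, vae_params, and the source images can be deleted now

def sample_batch(cache, batch_size):
    # Pick a random batch from the host-RAM cache and move only that batch to the device.
    # With pmap-style training you would shard/replicate this batch across devices instead.
    picks = random.sample(cache, batch_size)
    latents = jax.device_put(np.stack([p["latents"] for p in picks]))
    captions = [p["caption"] for p in picks]
    return latents, captions
```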

@github-actions

github-actions bot commented Mar 5, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 5, 2023
@github-actions github-actions bot closed this Mar 13, 2023