Add latent caching for Dreambooth Flax. #1936
Conversation
If enabled, pre-computes the latents from the VAE to avoid redoing the calculation every epoch.
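The caching idea can be sketched as follows. This is a minimal illustration of the pattern, not the PR's actual code: `vae_encode` is a stand-in for the real Flax VAE encoder, and the shapes are just representative (Stable Diffusion's VAE maps a 512x512x3 image to a far smaller latent). The point is that encoding happens once, up front, instead of once per epoch.

```python
import numpy as np

def vae_encode(image):
    # Stand-in for the real Flax VAE encoder: downsamples a 512x512x3
    # image to a 64x64x3 "latent" by block-averaging 8x8 patches.
    return image.reshape(64, 8, 64, 8, 3).mean(axis=(1, 3))

def cache_latents(images):
    # Single pass over the dataset; afterwards the VAE is no longer needed.
    return [vae_encode(img) for img in images]

images = [np.random.rand(512, 512, 3).astype(np.float32) for _ in range(4)]
latents = cache_latents(images)  # computed once, before training starts

for epoch in range(3):
    for z in latents:
        # train_step(z, ...) would go here; no VAE forward pass per epoch
        pass
```

Whether this pays off depends on how expensive the (compiled) VAE forward pass is relative to the rest of the step, which is what the discussion below is about.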
A 30% speed-up would really surprise me, since a compiled VAE should be quite fast. But you're definitely right that there is potential for a speed-up at the expense of more required memory. In general, I'm fine with adding such a PR. @patil-suraj @pcuenca what do you think?
@patrickvonplaten it's very possible I was/am doing something else wrong, but on a v4 TPU, fine-tuning with 4 images (plus 300 class images) for 600 steps, adding this flag takes the time from ~5 mins to ~3.2 mins.
Thanks a lot for the PR. I'm also surprised that this could give a 30% speed-up. The goal for the example scripts is to be simple, easy to follow, and to let users customize them to their needs. It's important to keep the training loop/loss function clean and uncluttered, so I'm not in favour of adding caching to the training scripts.
Why don't we add latents caching to
@zetyquickly Good question! The goal for the example scripts is to be simple and easy to read, so any user can go through them and adapt them to their needs. They are not intended to provide all features out of the box; that would make the scripts more complex and harder to understand and modify. Hope this makes sense.
Agreed, less is better. But since @yasyf has already started adding this, why not copy it across all the scripts? The LoRA version took the original script as its source, and it has already diverged from the Flax version in terms of input options. In the future there may be a fourth script inheriting from one of these three, and the discrepancy will keep growing.
P.S. I only came here trying to find out why train_dreambooth doesn't have latent caching, and ended up learning that the other script will get it after this PR, but not the one I'm interested in. :)
We could add this as a new script and put it under
I've been tinkering with this. It works well to use the VAE and tokenizer only in the data prep, then destroy them and ignore them entirely in the train step function, saving memory and processing time. That is, just like the Dreambooth script builds and destroys the pipeline to create the class images, do the same with the latents: build the VAE to convert all images to latents, then destroy the VAE/params and all source images. Roughly 24 VAE latent outputs fit in the space of a single 512x512 image, so the data is small enough that you don't need a data loader to feed the model; just build a static dict of latents and captions and sample from it over the training loop. Once you need more than a hundred or so images, a loader starts to make sense. But most VMs have a lot of RAM, so you can still host everything in RAM and device_put/replicate it as you sample it in the train loop.
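The build-then-destroy pattern described above can be sketched like this. All names here (`prepare_latent_cache`, `sample_batch`, the toy `encode`) are illustrative, not from the repo, and a numpy stand-in replaces the real VAE so the sketch is self-contained; on TPU the `jax.device_put`/replicate step mentioned in the comment would happen at the point marked below.

```python
import numpy as np

def prepare_latent_cache(images, captions, encode):
    # Run every image through the VAE encoder exactly once; the caller can
    # then delete the VAE and its params (and the source images) entirely.
    return {i: {"latent": encode(img), "caption": cap}
            for i, (img, cap) in enumerate(zip(images, captions))}

def sample_batch(cache, rng, batch_size):
    idx = rng.choice(len(cache), size=batch_size, replace=False)
    latents = np.stack([cache[int(i)]["latent"] for i in idx])
    captions = [cache[int(i)]["caption"] for i in idx]
    # On TPU, this is where you would jax.device_put / replicate `latents`
    # before handing them to the train step.
    return latents, captions

encode = lambda img: img[::8, ::8, :1]  # toy encoder: 512x512x3 -> 64x64x1
images = [np.zeros((512, 512, 3), np.float32) for _ in range(8)]
captions = [f"a photo of sks dog {i}" for i in range(8)]

cache = prepare_latent_cache(images, captions, encode)  # VAE no longer needed
batch, caps = sample_batch(cache, np.random.default_rng(0), batch_size=4)
```

Because the cache is a plain in-memory dict, sampling is trivial; a proper data loader only becomes worthwhile once the latent set outgrows host RAM.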
Results in a ~30% speedup on average for us. Note that the logic of `DreamBoothDataset` wasn't actually changed; I just rewrote it to be a little easier to follow.