Idea for Improvement #40

Open · bonlime opened this issue May 25, 2023 · 8 comments

Comments

@bonlime commented May 25, 2023

Hey, I've read through your papers and I like the idea of token merging. I've experimented a little bit with applying it to Stable Diffusion and found one potential source of improvement. Currently you predefine the % of tokens to merge, which makes the speedup from ToMe exactly predictable, but it also means some tokens always get merged, even when they aren't similar enough. I've plotted the mean similarity of merged tokens for multiple images with r=0.2 and got this (each point is the mean from a single layer). You can see that the 20% quantile sometimes still merges dissimilar tokens, and the plot gets even worse if you set r=0.5.
[plot: mean similarity of merged tokens, one point per layer, r=0.2]

Instead, why not predefine a similarity threshold for tokens to be merged, together with some reasonable cap to avoid merging too many tokens at once (because sometimes >90% of tokens have >0.95 similarity), and merge based on that? It would lead to an adaptive "merging ratio", where similar activations get merged much more than dissimilar ones. For example, in my measurements a merging threshold of >0.9 still merges ~40% of tokens on average, but leads to less degradation. I haven't conducted a proper comparison yet, though.
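
To make this concrete, here's roughly what I have in mind (untested sketch; the adaptive_r name, the 0.9 threshold, the 50% cap, and the alternating token split are just placeholders, not anything from this repo):

```python
import torch

def adaptive_r(metric: torch.Tensor, sim_threshold: float = 0.9, max_frac: float = 0.5) -> int:
    """Choose how many tokens to merge from a similarity threshold instead of a fixed %.

    metric: (B, N, C) token features used for matching.
    Returns r = number of tokens to merge, capped at max_frac of the candidates
    so that very uniform feature maps don't collapse entirely.
    """
    metric = metric / metric.norm(dim=-1, keepdim=True)
    a, b = metric[..., ::2, :], metric[..., 1::2, :]          # bipartite split of the tokens
    best_sim = (a @ b.transpose(-1, -2)).max(dim=-1).values   # best match per candidate token
    frac = (best_sim > sim_threshold).float().mean().item()   # fraction "similar enough"
    n_candidates = a.shape[-2]
    return min(int(frac * n_candidates), int(max_frac * n_candidates))
```

The cap is what keeps the ">90% of tokens above 0.95 similarity" case from merging nearly everything.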

@dbolya (Owner) commented May 26, 2023

Hi @bonlime, thanks for the suggestion!

Actually, this is something I considered during the development of the original ToMe, and it's in fact what other methods already do (e.g., adaptive ViT).

The main reason I decided to use a fixed number of tokens to merge is batching. If you merge a different number of tokens in each element of a batch (which you would do by enforcing a threshold), each image would end up with a different number of tokens and thus couldn't be batched together anymore.

Now, this was a huge problem for ViTs, which use batch sizes of 256+, where running the model on each image one by one is extremely expensive. But for stable diffusion, where you only have one image, it should be okay, right?

Well, if you can live with running with a batch size of 1 (which is pretty common for the big images ToMe benefits from, so that's okay), there's still another problem: stable diffusion doesn't generate just 1 image. Even if you set the batch size to 1, stable diffusion uses "classifier-free guidance": during the denoising process, it generates 2 images---one with the prompt and one without. Then, it subtracts the two and multiplies the difference by the "cfg scale", which is what then gets added to the current noise.

Thus, we're actually always using a batch size of at least 2. So simply setting a threshold won't work, because running with a batch size of 1 there is probably slower even if you add ToMe into the mix. But maybe there's a more clever way of doing it: set the threshold for one of the images, use that to determine the "r" value, and then apply that same "r" value to every element in the batch.

If you want to try to implement something like that, go for it and let me know how it goes! I think it could work in the case where you're only generating one large image (where your effective batch size is 2).

@bonlime (Author) commented May 26, 2023

@dbolya I like that you instantly noticed the flaws of such an approach :) And thanks for the reference to adaptive ViT, I haven't seen it yet. You're right that it would not work for generating large batches, but it would work pretty well for BS=1. Even in the presence of CFG, from my experiments/observations the number of similar tokens above the threshold is very close in both the cond and uncond predictions, so you could determine r as the min of r_cond and r_uncond. I'm working on the implementation and will get back if it works.
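
Concretely, something like this (rough sketch building on the thresholding idea above; none of it is from the repo):

```python
import torch

def shared_r(metric: torch.Tensor, sim_threshold: float = 0.9) -> int:
    """Compute an adaptive r per batch element, then take the minimum so the
    cond/uncond pair (and anything else in the batch) keep the same token count."""
    metric = metric / metric.norm(dim=-1, keepdim=True)
    a, b = metric[..., ::2, :], metric[..., 1::2, :]
    best_sim = (a @ b.transpose(-1, -2)).max(dim=-1).values  # (B, N/2) best match per token
    r_per_elem = (best_sim > sim_threshold).sum(dim=-1)      # e.g. [r_cond, r_uncond]
    return int(r_per_elem.min().item())                      # min keeps batching intact
```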

In the meantime I have another question. I have a hard requirement of being able to torch.compile everything (because our prod env has a high enough load to care about inference speed), and your current implementation can't be compiled because it can't accept a generator as an input arg (for some reason). What do you think of the following idea: instead of storing args["generator"], store args["counter"], which would be incremented by each call to bipartite_soft_matching_random2d; inside, we would then select the pixel with index counter % (sx * sy) from every patch. It would still have randomness in the sense that different layers are guaranteed to merge different tokens, but do you think that would be sufficient, or is randomness in the index selection within a single feature map also required?
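
Roughly, the change I'm picturing (untested sketch; only the counter % (sx * sy) part is the actual proposal, the rest is scaffolding):

```python
import torch

def deterministic_dst_idx(h: int, w: int, sx: int, sy: int, counter: int,
                          device: torch.device) -> torch.Tensor:
    """Pick one 'destination' token per (sy, sx) patch without a torch.Generator.

    Every patch uses the same offset counter % (sx * sy); the counter is
    incremented once per call to bipartite_soft_matching_random2d, so different
    layers/steps still pick different positions, and torch.compile stays happy.
    """
    hsy, wsx = h // sy, w // sx                   # number of patches in each direction
    idx = counter % (sx * sy)                     # shared offset for this call
    return torch.full((hsy, wsx, 1), idx, dtype=torch.int64, device=device)
```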

@dbolya (Owner) commented May 26, 2023

I'm working on the implementation and will get back if it works.

Looking forward to it!


What do you think of the following idea

I think that's a very reasonable workaround, though the randomness within a layer probably does matter. I think I tried an experiment where I chose a random index for the entire batch and used that same index in each patch (a bit like your counter idea, but random instead of incrementing). I found that this was worse than the fully random approach, so I axed it. But maybe some cheap pseudorandom implementation would work, like taking an arange, multiplying by some number, and taking it mod some other number (with the counter controlling the "seed"). It probably doesn't need very strong randomness.
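
Something along these lines is what I mean (untested sketch; the multiplier constants are arbitrary, not from any existing code):

```python
import torch

def pseudo_random_dst_idx(hsy: int, wsx: int, sx: int, sy: int, counter: int,
                          device: torch.device) -> torch.Tensor:
    """Cheap, compile-friendly pseudo-randomness: give each patch its own index
    via an arange hashed with multiply/mod, with the counter acting as the seed."""
    patch_ids = torch.arange(hsy * wsx, device=device)
    # arbitrary odd multipliers; the counter shifts the whole sequence each call
    idx = (patch_ids * 2654435761 + counter * 40503) % (sx * sy)
    return idx.view(hsy, wsx, 1)
```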

FYI, the generator bit was there just to fix Hugging Face diffusers implementations that didn't seed everything properly; otherwise, generating the same image with ToMe would drastically change it. So if you're always using ToMe and you don't care what the image looks like without ToMe, you can just remove the generator argument (in which case the current global generator is used).

The randomness has caused a lot of problems though, so I'm totally in favor of ditching it for something deterministic that works just as well.

@fingerk28

Hi @dbolya, thank you for the discussion; I have benefited greatly. I would like to ask why you mentioned that ToMe might not be able to provide a speedup when the effective batch size is 1?

@dbolya (Owner) commented Mar 22, 2024

Hi @dbolya, thank you for the discussion; I have benefited greatly. I would like to ask why you mentioned that ToMe might not be able to provide a speedup when the effective batch size is 1?

GPUs are massively parallel devices with tons of CUDA/tensor cores. If those cores aren't saturated, then the GPU is not the bottleneck for speed. And if the GPU is not the bottleneck, how can we hope to speed up the model by reducing the amount of GPU compute? Increasing the image size or batch size gives the GPU more work, making it more likely that it is the bottleneck.

@andupotorac commented Jun 3, 2024

@bonlime I would also be interested in your approach where the merging is adaptive for the best output, even if the batch size is 1 (+1). Did you manage to implement your approach, and did it work?

We're thinking of adding it (the adaptive approach from A-ViT) as an option to this algo at some point. I'm wondering how much it would improve things overall, or whether it would keep the quality of the images unchanged.

@bonlime (Author) commented Jun 6, 2024

@andupotorac hey, it was a long time ago when I experimented with this. We ended up not using ToMe due to problems with getting it to work with torch.compile. I don't have any public implementation of the idea mentioned, but I remember that the change was pretty trivial: just adding something like torch.quantile in a few places in the code instead of using hard-coded values.
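
From memory it was something along these lines (untested reconstruction, not the actual code we had):

```python
import torch

def r_from_scores(scores: torch.Tensor, sim_threshold: float = 0.9,
                  max_frac: float = 0.5) -> int:
    """Replace a hard-coded r with one derived from the similarity scores.

    scores: (B, N/2) best-match cosine similarities from the bipartite matching.
    Merge everything above sim_threshold, but never more than max_frac of the
    candidates; torch.quantile gives the similarity cut-off for that cap.
    """
    cap_cutoff = torch.quantile(scores.float(), 1.0 - max_frac)        # similarity at the cap
    cutoff = torch.maximum(cap_cutoff, scores.new_tensor(sim_threshold).float())
    return int((scores > cutoff).sum(dim=-1).min().item())             # min over batch keeps shapes equal
```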

@andupotorac

@andupotorac hey, it was a long time ago when I experimented with this. We ended up not using ToMe due to problems with getting it to work with torch.compile. I don't have any public implementation of the idea mentioned, but I remember that the change was pretty trivial: just adding something like torch.quantile in a few places in the code instead of using hard-coded values.

Thanks, we'll look into it!
