For some reason, Diffusers has historically implemented all of its sampling methods inside-out: the solver cannot actually access the ODE itself, and is instead fed a fixed stream of predictions one at a time. While this is fine for the Euler method and linear multi-step methods, it makes higher-order single-step methods extremely challenging, and variable-step methods might even be impossible.
Effectively, diffusers does
    for t in timesteps:
        output = model(sample, t)  # This is outside of the sampler. The sampler class is inside-out.
        sample = sampler.step(sample, output, t)
instead of
    sample = sampler.sample_model(sample, model, timesteps)  # sample_model will call model(x, t) as many times as it needs to
It's not terribly difficult to rewrite the denoise loop in a single pipeline, but doing so for all hundred-something pipelines in the library is extremely impractical.
With the recent efforts to break up the monolithic pipelines, I've found it was actually fairly easy to convert the modular denoise wrapper block to a system where the sampler is responsible for the denoise loop itself. As part of my efforts to add single-step methods to my library, I've written a small demonstrator using Flux.1 and Runge-Kutta which seems a lot more reasonable.
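To make the shape concrete, here is a minimal sketch of a sampler that owns the loop. It is not the skrample or Diffusers code: `sample_euler`, `sample_rk4`, and the `model(x, t)` signature returning dx/dt are all names assumed for illustration, and plain floats stand in for tensors.

```python
from typing import Callable, Sequence

# Hypothetical signature: model(x, t) returns the ODE derivative dx/dt.
Model = Callable[[float, float], float]


def sample_euler(sample: float, model: Model, timesteps: Sequence[float]) -> float:
    """First-order solver: one model call per step."""
    for t, t_next in zip(timesteps, timesteps[1:]):
        sample = sample + (t_next - t) * model(sample, t)
    return sample


def sample_rk4(sample: float, model: Model, timesteps: Sequence[float]) -> float:
    """Classic fourth-order Runge-Kutta: four model calls per step.

    The intermediate calls use points the solver chooses itself, which is
    what a loop that only feeds one prediction per timestep cannot provide.
    """
    for t, t_next in zip(timesteps, timesteps[1:]):
        h = t_next - t
        k1 = model(sample, t)
        k2 = model(sample + h / 2 * k1, t + h / 2)
        k3 = model(sample + h / 2 * k2, t + h / 2)
        k4 = model(sample + h * k3, t_next)
        sample = sample + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return sample
```

Both functions share the same call shape, so the caller does not need to know how many model evaluations a given solver performs per step.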
However, even if the code is straightforward, doing this for every model family will still become very tedious once more architectures are made modular. Therefore I think it would be significantly better to move Modular Diffusers to a design where the sampler is functional/stateless, like you would typically expect from a differential equation solver, especially since the project is still young. It is easy to make a one-size-fits-all adapter to map any current inside-out sampler/scheduler into a functional one for backwards-compatibility, so to me it makes sense to migrate.
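As a rough sketch of what such an adapter could look like, assuming the simplified `step(sample, output, t)` shape from the pseudocode above: `StepSampler`, `as_functional`, and `ToyEuler` are hypothetical names, not Diffusers classes. Real Diffusers schedulers take their arguments in a different order and return an output object, so a real adapter would need to account for that.

```python
from typing import Callable, Protocol, Sequence

# Hypothetical signature: model(x, t) returns the prediction for that point.
Model = Callable[[float, float], float]


class StepSampler(Protocol):
    """Assumed inside-out interface: receives one prediction at a time."""

    def step(self, sample: float, output: float, t: float) -> float: ...


def as_functional(sampler: StepSampler) -> Callable[[float, Model, Sequence[float]], float]:
    """Wrap a step-style sampler so that it owns the denoise loop."""

    def sample_model(sample: float, model: Model, timesteps: Sequence[float]) -> float:
        for t in timesteps:
            output = model(sample, t)  # the loop now lives inside the sampler
            sample = sampler.step(sample, output, t)
        return sample

    return sample_model


class ToyEuler:
    """Toy step-style Euler sampler with a fixed step size, for illustration only."""

    def __init__(self, step_size: float) -> None:
        self.step_size = step_size

    def step(self, sample: float, output: float, t: float) -> float:
        return sample + self.step_size * output
```

With this, `as_functional(ToyEuler(0.1))(sample, model, timesteps)` has the same call shape as a natively functional solver, so existing step-style samplers could keep working while new ones are written directly in the functional form.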
Implementation details aside, I'm wondering foremost if this is something that's even being considered? I don't know why it was done this way in the first place, as I believe just about all other diffusion applications write sampling the traditional way.
cc @yiyixuxu probably as someone who works on Modular a lot
`sampling` module with new `functional` submodule: Beinsezii/skrample#53