
Implement time step discretization for Karras samplers #23

Open · wants to merge 1 commit into base: master
Conversation


@Birch-san Birch-san commented Aug 31, 2022

Improves support for diffusion models with discrete time-steps (such as Stable Diffusion's DDIM).
I have some questions though, so this may need some iterating.

The user would invoke like so:

from k_diffusion.sampling import sample_heun, get_sigmas_karras, make_quantizer

sigmas = get_sigmas_karras(
    n=opt.steps,
    # 0.0292
    sigma_min=model_k_wrapped.sigmas[0].item(),
    # 14.6146
    sigma_max=model_k_wrapped.sigmas[-1].item(),
    rho=7.,
    device=device,
    # zero would be smaller than sigma_min
    concat_zero=False
)

samples = sample_heun(
    model_k_config,
    x,
    sigmas,
    extra_args=extra_args,
    decorate_sigma_hat=make_quantizer(model_k_wrapped.sigmas))

Implements the change to "Algorithm 2, line 5" described in the Elucidating paper arXiv:2206.00364 section C.3.4 "iDDPM practical considerations" practical challenge 3.
In other words we round sigmas to the nearest sigma supported by the DDIM.
For your convenience, here's the sigmas supported by Stable Diffusion DDIM:
https://gist.github.com/Birch-san/6cd1574e51871a5e2b88d59f0f3d4fd3
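for illustration, a quantizer along these lines just rounds to the nearest entry in the model's sigma table (a sketch; the PR's actual make_quantizer may be factored differently):

```python
import torch

def make_quantizer(supported_sigmas: torch.Tensor):
    """Return a function that rounds a sigma to the nearest supported sigma."""
    def quantize(sigma: torch.Tensor) -> torch.Tensor:
        # index of the table entry closest to sigma
        return supported_sigmas[(supported_sigmas - sigma).abs().argmin()]
    return quantize
```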


You may be wondering: "okay, rounding sigma_hat solves challenge 3, but what about challenge 2?"
There's an argument that solving challenge 3 solves challenge 2 in some situations.
When gamma == 0, rounding sigma_hat is equivalent to rounding sigma (which is what challenge 2 requires you to do for any output of get_sigmas_karras()).
The problem here is the final sigma we'll receive, 0. We probably don't want to apply the same rounding rules to that, especially because there's a special case predicated on 0. Should that be predicated on the quantized minimum sigma instead, or perhaps on "have we reached the final sigma?"
edit: maybe the only reason they special-case 0 is that they want to avoid dividing by zero?

If we do care about satisfying challenge 2 in the gamma > 0 situation, we'd want to round-to-nearest-sigma what comes out of get_sigmas_karras(). I happen to have made a torch snippet for running argmin on every element returned by get_sigmas_karras() simultaneously:

# broadcasting compares every schedule sigma against every supported sigma
sigmas = model_k_wrapped.sigmas[(sigmas.reshape(-1, 1) - model_k_wrapped.sigmas).abs().argmin(dim=1)]

But again, not sure of what the implications are for the 0 it returns.
Anyway, maybe we can look at the outputs to decide. We'll try with keeping the 0 and without.
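as a concrete illustration of that concern, rounding a schedule that ends in 0 sends the 0 back up to sigma_min (toy table and schedule, just for demonstration):

```python
import torch

supported = torch.tensor([0.0292, 0.1072, 0.3211, 14.6146])  # toy sigma table
schedule = torch.tensor([14.6146, 0.3211, 0.1072, 0.0292, 0.0])

# round every schedule entry to its nearest supported sigma via broadcasting
rounded = supported[(schedule[:, None] - supported).abs().argmin(dim=1)]
# the trailing 0.0 has no counterpart in the table, so it becomes sigma_min again
```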

I tried to stress this to its limits by using as few steps as I could manage before it looked bad. All images are:

Heun, 7 steps

Excluding 0 from get_sigmas_karras()

The better-looking result was when I excluded the 0 returned by get_sigmas_karras(), in favour of ramping for 1 more step.

Recall that SD's sigmas run from max = 14.6146 to min = 0.0292.

Sigmas returned by get_sigmas_karras():
sample_heun only iterates to n-1, so never touches the 0.0292.

['14.6146', '7.9029', '4.0277', '1.9104', '0.8289', '0.3211', '0.1072', '0.0292']
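for reference, the ramp behind these numbers (a pure-PyTorch sketch of the Karras schedule; last-digit differences are expected because the sigma_min/sigma_max above are quoted rounded):

```python
import torch

def karras_schedule(n: int, sigma_min: float, sigma_max: float, rho: float = 7.) -> torch.Tensor:
    """Karras et al. noise schedule, without the appended 0."""
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

sigmas = karras_schedule(8, 0.0292, 14.6146)
```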

Time-step discretization enabled

Sigmas (up to n-1) after discretization:

['14.6146', '7.9216', '4.0300', '1.9103', '0.8299', '0.3213', '0.1072']

00349 s68673924 n0 i0_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night_heun7_kns_dcrt_nz

Original k-diffusion behaviour (no discretization)

00352 s68673924 n0 i0_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night_heun7_kns_nz

Not much perceptible difference. The discrete one defines the far sleeve better, but for the other subtle differences it's hard for me to say which is the better generation.

Keeping 0 from get_sigmas_karras()

So the paper didn't mention this, but the result is terrible at low step counts if you actually implement the 0 as they describe. Maybe this is just a problem for discrete time models?

Sigmas returned by get_sigmas_karras():
sample_heun only iterates to n-1, so never touches the 0.

['14.6146', '7.0944', '3.1686', '1.2741', '0.4469', '0.1303', '0.0292', '0.0000']

Time-step discretization enabled

Sigmas (up to n-1) after discretization:

['14.6146', '7.0796', '3.1667', '1.2721', '0.4471', '0.1308', '0.0292']

00350 s68673924 n0 i0_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night_heun7_kns_dcrt

Original k-diffusion behaviour (no discretization)

00351 s68673924 n0 i0_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night_heun7_kns

Slightly more perceptible difference. The discrete one did better on the eyes and has slightly more clothing definition.

Conclusion

Removing the "concat 0" from get_sigmas_karras() seems to be hugely beneficial for small numbers of steps. This is not backed up by the literature. The reason I tried it was a misunderstanding: I saw that if I discretized the whole schedule, I'd end up with a repeated minimum sigma (…, 0.0292, 0.0292]). I removed the concat 0 to ensure I didn't produce duplicates. I didn't realize, though, that the sampler stops at n-1, so repeats aren't actually a problem. But it seems that, for a different reason, the results are far better.

Discretization of time-steps doesn't have the dramatic impact I was hoping for, but is probably still a sensible thing to do on the basis that the paper recommended it.

Heun, 50 steps, excluding 0

Let's do one more example, to 50 steps

Time-step discretization enabled

grid-0194 s68673924_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night_heun50_kns_dcrt_nz

Original k-diffusion behaviour (no discretization)

grid-0195 s68673924_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night_heun50_kns_nz

Discretization seems to be more noticeable over 50 steps. The discretized image seems to have sharper hair and clothing, and highlights are brighter. Not sure I could say which is "better" though.

It's hard to compare images scrolling on GitHub; personally I flicked between these using QuickLook in the Finder.
If you know a better way to evaluate whether this is an improvement: I'm all ears! 👂

(as described in the Elucidating paper arXiv:2206.00364, section C.3.4, "practical challenge" 3).
Also add a way to opt out of receiving a zero in the Karras noise schedule (this makes less sense when discretizing, because 0 can be out of range -- i.e. lower than sigma_min -- and would be rounded back up to sigma_min again).
For what it's worth, I moved the .to(device) to happen a little earlier in get_sigmas_karras(), on the basis that the other get_sigmas_* functions move to device *before* appending zero.
@Birch-san
Author

Birch-san commented Aug 31, 2022

d'oh, I just noticed you already implemented a quantize boolean on ingress to the model:
https://github.com/crowsonkb/k-diffusion/blob/3f93c28088890e4d6bc593072739e1d6e759b392/k_diffusion/external.py
and certainly that approach gives me other ideas about how to factor this (i.e. make use of the DiscreteSchedule class rather than doing everything in the sampler).

our solutions are equivalent when gamma == 0,
but when gamma > 0, I think your quantize=True diverges slightly from the paper.

the paper says to quantize sigma_hat after line 5, such that it impacts the computation of x on line 6.

whereas currently k-diffusion passes into the model:

  • an x computed from a non-discretized sigma_hat
  • a discretized sigma_hat
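to make the divergence concrete, here's a toy sketch of both orderings of Algorithm 2's lines 5-6 (names and numbers are illustrative, not k-diffusion's actual code):

```python
import torch

supported = torch.tensor([0.0292, 0.1072, 0.3211, 14.6146])

def quantize(s: torch.Tensor) -> torch.Tensor:
    return supported[(supported - s).abs().argmin()]

sigma, gamma = torch.tensor(0.30), 0.1
eps = torch.tensor(0.5)  # noise sample, fixed for the comparison
x = torch.tensor(1.0)

# paper: quantize sigma_hat on line 5, so line 6's x sees the quantized value
sigma_hat_paper = quantize(sigma * (gamma + 1))
x_paper = x + (sigma_hat_paper ** 2 - sigma ** 2).sqrt() * eps

# current quantize=True: x is computed from the unquantized sigma_hat;
# only the sigma handed to the model is quantized, on ingress
sigma_hat_raw = sigma * (gamma + 1)
x_current = x + (sigma_hat_raw ** 2 - sigma ** 2).sqrt() * eps
model_sigma = quantize(sigma_hat_raw)
```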


@crowsonkb
Owner

So the paper didn't mention this, but the result is terrible at low step counts if you actually implement the 0 as they describe. Maybe this is just a problem for discrete time models?

You need the 0 on the end so the sampler outputs a fully denoised image, the ODE needs to be integrated from sigma_max to 0 for this to happen. I think the thing you are observing happens because sigma_min (the last noise level the model is evaluated at) is too low for low step counts. Have you tried increasing sigma_min instead, but keeping the concatenation of 0?

@Birch-san
Author

Birch-san commented Sep 1, 2022

thanks very much @crowsonkb for explaining the importance of the 0!

okay, so we need to keep the 0. but ramping all the way down to sigma_min inclusive isn't the best use of our limited sigmas.
looking at the next-lowest sigma, the successful picture sampled 0.1072. the unsuccessful picture sampled 0.0292.

so one idea is to formalize the wacky way that 0.1072 was computed, so we can intentionally use it as our sigma_min.

the 0.1072 can be obtained like this:

steps=7
get_sigmas_karras(
  # there's an argument that steps+1 is wacky, so let's remember to try without the +1 too
  n=steps+1,
  # 14.6146
  sigma_max=model.sigmas[-1].item(),
  # 0.0292
  sigma_min=model.sigmas[0].item(),
  rho=7.
)[-3] # skip nth because it's 0, skip n-1th because it's the known-bad sigma_min

or more efficiently like this:

# gets the N-1th sigma from a Karras noise schedule
def get_awesome_sigma_min(
    steps: int,
    sigma_max: float,
    sigma_min_nominal: float,
    rho: float
) -> float:
    min_inv_rho = sigma_min_nominal ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = (steps-2) * 1/(steps-1)
    sigma_min = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return sigma_min

steps=7
# 14.6146
sigma_max=model.sigmas[-1].item()
sigma_min = get_awesome_sigma_min(
    steps=steps+1,
    sigma_max=sigma_max,
    # 0.0292
    sigma_min_nominal=model.sigmas[0].item(),
    rho=7.
)
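as a quick sanity check (restating the helper so the snippet runs standalone), the rounded SD values land within rounding distance of the 0.1072 quoted above:

```python
def get_awesome_sigma_min(steps: int, sigma_max: float, sigma_min_nominal: float, rho: float) -> float:
    # index steps-2 of a steps-long Karras ramp, i.e. the sigma just above sigma_min
    min_inv_rho = sigma_min_nominal ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = (steps - 2) / (steps - 1)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

sigma_min = get_awesome_sigma_min(steps=8, sigma_max=14.6146, sigma_min_nominal=0.0292, rho=7.)
# ≈ 0.107; the published 0.1072 came from the model's unrounded sigmas
```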

having computed a new sigma_min 0.1072 using steps+1 (a bit arbitrary but matches my original experiment),

we call the (unmodified) get_sigmas_karras() the normal way, with our new sigma_min:

sigmas = get_sigmas_karras(
    n=opt.steps,
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    rho=rho,
)

it returns the following noise schedule, identical to our first experiment except ending with 0 instead of 0.0292:

[14.6146, 7.9029, 4.0277, 1.9104, 0.8289, 0.3211, 0.1072, 0.0000]

the sigma_hats get discretized to these before being passed into the model:

[14.6146, 7.9216, 4.0300, 1.9103, 0.8299, 0.3213, 0.1072]

Picture still looks good (the power of the 0.1072 sigma, probably):

00601 s68673924_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night__ 14 6146, 7 9029, 4 0277, 1 9104, 0 8289, 0 3211, 0 1072, 0 0000 _heun7_kns_dcrt

now let's try simplifying that ugly get_awesome_sigma_min(steps=steps+1) to just steps=steps.

new sigmas out of the oven. spends more time on the middle sigmas. "more time in the middle" sounds closer to the behaviour of DiscreteSchedule, based on my comparison in this thread:
https://twitter.com/Birchlabs/status/1565114066548527104

[14.6146, 8.0451, 4.1888, 2.0392, 0.9141, 0.3692, 0.1303, 0.0000]

the sigma_hats discretize to:

[14.6146, 8.0461, 4.1878, 2.0402, 0.9136, 0.3687, 0.1308]

Picture still looks good:

00602 s68673924_masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night__ 14 6146, 8 0451, 4 1888, 2 0392, 0 9141, 0 3692, 0 1303, 0 0000 _heun7_kns_dcrt

We got a new floral pattern on the sleeve! plus some new hair detail.
though the eyebrows, eye sockets, irises and shading on the facial bones suffered a bit. probably not good to lose so many low sigmas.

steps+1 probably better for keeping the details we care about.

but overall, keeping 0 seems to make this a nicer algorithm than we started with.

@Birch-san
Author

Birch-san commented Sep 1, 2022

so, I think we don't need the concat_zero change from this PR, since the zero wasn't the problem; wasteful use of the lowest sigma, sigma_min, was.
the same schedule can be computed without changes to k-diffusion. unless you wanted to accept a get_awesome_sigma_min helper under "wacky experiments". 😛

but I still think two problems remain regarding adhering to the paper:

  • to comply with practical challenge 3:
    we need to discretize sigma_hat earlier, so that x can be computed from discretized sigma_hat
    • the current quantize=True quantizes too late (i.e. on model ingress)
    • the contribution in the PR quantizes at the right time, but a better way to factor the code would be to utilise your DiscreteSchedule abstraction somehow
  • to comply with practical challenge 2:
    we need a way to discretize the sigmas (except 0) that come out of get_sigmas_karras
    • the current quantize=True achieves the same effect, but only when gamma=0 (i.e. when sigma_hat == sigma)
    • one way to do this would be to create a DiscreteSchedule#get_sigmas_karras method, which delegates to get_sigmas_karras and applies an argmin afterward.

@crowsonkb
Owner

If you do model.t_to_sigma(model.sigma_to_t(sigma)) inside the sampler you can get the quantized sigma... but you can't count on those methods being there because the user could just pass in any arbitrarily wrapped model. I'm not really sure what to do tbh.
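(A toy sketch of that roundtrip, with a nearest-index sigma_to_t standing in for DiscreteSchedule's real interpolation, so this only mirrors the quantize=True behaviour:)

```python
import torch

class ToySchedule:
    """Toy stand-in for DiscreteSchedule: t is an integer index into a sigma table."""
    def __init__(self, sigmas: torch.Tensor):
        self.sigmas = sigmas

    def sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return (self.sigmas - sigma).abs().argmin()

    def t_to_sigma(self, t: torch.Tensor) -> torch.Tensor:
        return self.sigmas[t]

sched = ToySchedule(torch.tensor([0.0292, 0.1072, 0.3211]))
# the roundtrip snaps an arbitrary sigma to the nearest supported one
quantized = sched.t_to_sigma(sched.sigma_to_t(torch.tensor(0.1)))
```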

@Birch-san
Author

Birch-san commented Sep 1, 2022

an older factoring of the code that I tried was to expose a quanta parameter:
Birch-san@8ebcd00
but feels kinda like a failure to make use of the model-wrapping idiom.

given that discrete sigmas are a DiscreteSchedule concern, it feels like DiscreteSchedule should provide peers of all karras sampler methods. like DiscreteSchedule#sample_heun.

DiscreteSchedule#sample_heun would forward calls to sample_heun. or perhaps DiscreteSchedule#sample_heun and sample_heun would mutually forward calls to a private function, _sample_heun() (which could expose an optional quantize callback).

factoring out a common core might not be the craziest thing to do, since sample_euler() and sample_heun() both have a lot of code they could share.

@crowsonkb
Owner

crowsonkb commented Sep 1, 2022

The samplers are supposed to be independent of the models, though, that would duplicate a ton of code and I might add new samplers later etc. Is there some reasonable way to guarantee that a wrapper class has all the required methods? The usual idiom here is subclassing but that doesn't really work with the wrapper idiom...

@Birch-san
Author

you mean a way to sniff model to see if it provides a way to quantize?

if we're ruling out "checking if it extends a class/mixin", then I guess that leaves "check whether it has a particular method decorated with a decorator you provide"?

@crowsonkb
Owner

crowsonkb commented Sep 1, 2022

Maybe there could be a model wrapper class that has all of the methods that the samplers etc. expect, and the default implementation of these methods just forwards to the wrapped model, and users could override these methods to customize the behavior. That is, all model wrappers would subclass this and override methods, maybe just forward() but they could also alter the other methods if they did something more complicated.

@Birch-san
Author

yes, that would be a good way to do it.

if you're a continuous-time model, you don't want to quantize sigma_hat at all.

so maybe a new base model wrapper class would be introduced (from which VDenoiser and DiscreteSchedule would then inherit).

the base model wrapper class would have a decorate_karras_sigma_hat(tensor: Tensor) -> Tensor which would just be the identity function. DiscreteSchedule would override this to (if quantize=True) quantize the tensor.
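a sketch of that shape (class names are hypothetical, per the suggestion above, not mainline code):

```python
import torch
from torch import Tensor

class BaseDenoiserWrapper:
    def decorate_karras_sigma_hat(self, sigma_hat: Tensor) -> Tensor:
        return sigma_hat  # identity: continuous-time models leave sigma_hat alone

class DiscreteScheduleWrapper(BaseDenoiserWrapper):
    def __init__(self, sigmas: Tensor, quantize: bool = True):
        self.sigmas = sigmas
        self.quantize = quantize

    def decorate_karras_sigma_hat(self, sigma_hat: Tensor) -> Tensor:
        if not self.quantize:
            return sigma_hat
        # round to the nearest sigma the discrete-time model supports
        return self.sigmas[(self.sigmas - sigma_hat).abs().argmin()]
```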

@crowsonkb
Owner

crowsonkb commented Sep 1, 2022

I need to think about which methods to make standard on the wrapper...

  • forward()/__call__() obviously
  • sigma_to_t() and t_to_sigma() (for k-diffusion native models these can just be the identity function). If you have these two methods you can quantize by going from sigma to t and back.
  • loss() probably.

Maybe forward get_sigmas() if it exists on the inner model, and don't forward it if it doesn't? Or just make the base model class have this method but raise NotImplementedError.

Oh! Maybe add encode() and decode() methods that are the identity function by default but which, for latent diffusion models, encode/decode using the autoencoder.

Maybe also have a sigma_min and a sigma_max property for easy access to the valid timestep range.

@Birch-san
Author

disclaimer: my design patterns are based on Java experience, not Python.

I'd start by only implementing stuff that you actually have a user for.
easy to add more later once the requirement is discovered. very hard to take back something once it's shipped.

I'd start from "who consumes a base class?".
your samplers will. they currently only do one thing with model: __call__().
and we'll want at least one new capability to help us discretize sigmas. maybe that's sigma_to_t() + t_to_sigma(), as you say. but whilst it's good composition, it's anti-performance (it prevents simplifying it to just a one-line argmin). so I'd say there's a performance case for a quantize_sigma() method. this could be provided in addition to the two other functions of course.

will end-users consume this base class? I don't know a use-case that would mean they'd ever see the base class.
I, for example, construct a CompVisDenoiser(model). I know the subclass, so I don't need the base class to be descriptive.
the only reason I'd lose this information is if I'm doing something like a strategy factory, to pick a wrapper at runtime. I don't think anybody would actually do this…

another consideration vis-à-vis forcing subclasses to adhere to the same method signatures… we already see some divergence here; DiscreteSchedule's sigma_to_t() has an additional quantize: bool. it's compatible, but would enforcing method signatures from a base class restrict your design options in future (e.g. if one of the subclasses needs some bespoke, additional params on a method)?
not sure how polymorphism works in Python (i.e. whether it's still considered "overriding" a function if the method signatures are different but compatible).
I guess you can get around anything by picking a lowest common denominator, and spreading **kwargs to make it extensible.
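for what it's worth, Python is fine with this: an override may add keyword extras so long as callers who only know the base signature still work. a contrived sketch (stand-in logic, not real schedule code):

```python
class BaseSchedule:
    def sigma_to_t(self, sigma, **kwargs):
        return sigma  # k-diffusion-native models: t == sigma

class DiscreteScheduleLike(BaseSchedule):
    # compatible override: adds a bespoke quantize param without breaking base callers
    def sigma_to_t(self, sigma, quantize=False, **kwargs):
        t = sigma * 2  # stand-in for a real table lookup / interpolation
        return round(t) if quantize else t
```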

Maybe forward get_sigmas() if it exists on the inner model, and don't forward it if it doesn't? Or just make the base model class have this method but raise NotImplementedError.

of the choices, I prefer raise NotImplementedError (principle of least astonishment — better to be limited and clear than to try and be helpful in an unpredictable way).
if it's a method that's mandatory, but for which no sensible default implementation can be provided: abstract methods are a good way to force the subclass to make the decision of how to implement instead (but yeah that could just be raise NotImplementedError).

Oh! Maybe add encode() and decode() methods that are the identity function by default but which, for latent diffusion models, encode/decode using the autoencoder.

This might be another situation where — for performance reasons — it would be good to support roundtrip() (which as you say could for some models be the identity function) rather than forcing to go through encode(decode()).
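a sketch of what I mean (roundtrip() is a hypothetical name, not an existing k-diffusion method): with identity defaults the roundtrip costs nothing, and a latent wrapper could override roundtrip() alone when fusing (or skipping) the autoencoder pass is cheaper.

```python
class BaseWrapper:
    # identity defaults: pixel-space models have no autoencoder
    def encode(self, x):
        return x

    def decode(self, z):
        return z

    def roundtrip(self, z):
        # default composes the two; a latent wrapper could override this
        # to avoid a needless decode-then-encode through the autoencoder
        return self.encode(self.decode(z))
```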

Maybe also have a sigma_min and a sigma_max property for easy access to the valid timestep range.

hmm I guess that's something I'd use (I'm currently resorting to model_k_wrapped.sigmas[-1].item()) but feels like it's something that only discrete timestep models would need?
so maybe it wouldn't go as low as the base class, but rather into a mixin or superclass that discrete models inherit?

Birch-san added a commit to Birch-san/stable-diffusion that referenced this pull request Sep 2, 2022
…ent equivalent/better ramp (--end_karras_ramp_early) without requiring a custom fork of k-diffusion crowsonkb/k-diffusion#23 (comment)
Birch-san added a commit to Birch-san/stable-diffusion that referenced this pull request Sep 3, 2022
…ucting CompVisDenoiser with quantize=True. this means we don't need a custom fork of k-diffusion (except for MPS fixes). only downside compared to my original approach is that we cannot set churn>0 (see crowsonkb/k-diffusion#23 (comment)), but we never used that. I reckon the ability to quantize sigma_hat will be added to mainline k-diffusion eventually (discussing here: crowsonkb/k-diffusion#23 (comment)), so I think it's best to keep the k-diffusion branch free of bespoke changes (with the exception of MPS), to keep it easy to rebase onto mainline. remove ability to opt in/out of discretization, now that I've finished comparing them (crowsonkb/k-diffusion#23) -- the difference is barely perceptible but discretization is the better choice in theory.
@crowsonkb
Owner

crowsonkb commented Sep 7, 2022

I was thinking about something along the lines of the following:

from torch import nn

class BaseModelWrapper(nn.Module):
    """The base wrapper class for the k-diffusion model wrapper idiom. Model
    wrappers should subclass this class and customize the behavior of the
    wrapped model by implementing or overriding methods."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def __dir__(self):
        return list(set(super().__dir__() + dir(self.inner_model)))

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.inner_model, name)

    def forward(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

I'm not sure where the standard methods should go yet, on this base wrapper or separately implemented on the different denoiser wrappers, which would be changed to be subclasses of this.
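A quick illustration of the forwarding behaviour (class condensed and restated so the snippet runs standalone; Inner is a stand-in model):

```python
import torch
from torch import nn

class BaseModelWrapper(nn.Module):
    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def __getattr__(self, name):
        try:
            # nn.Module resolves parameters/buffers/submodules here
            return super().__getattr__(name)
        except AttributeError:
            # anything else falls through to the wrapped model
            return getattr(self.inner_model, name)

    def forward(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

class Inner(nn.Module):
    sigma_min = 0.0292  # attribute only the inner model knows about

    def forward(self, x):
        return x * 2

wrapped = BaseModelWrapper(Inner())
```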

@Birch-san
Author

Birch-san commented Sep 25, 2022

hey, sorry for slow reply.

okay, so this wrapper creates the illusion that the wrapped instance is a subtype of inner_model. seems reasonable (we want the wrapped instance to be a substitute that can be used in all the same situations).

as for where the standard methods (e.g. sigma_to_t()) should go...
I guess it depends what this base wrapper claims its responsibilities are.

if it's a "general model wrapper" (i.e. nothing to do with diffusion, but perhaps with generic responsibilities like logging), then I wouldn't put sigma_to_t() this low.

if it's a "diffusion model wrapper" (and I assume it is), then I think it makes sense to put sigma_to_t() this low if (and only if) that's something that every diffusion model needs.

if there's a one-size-fits-all implementation of sigma_to_t() that can be put here, it can go here.
if "it depends", then I think it should be an abstract method. the base wrapper forces subclasses to provide an implementation.

generally, the decision of "should I put sigma_to_t() -- or at least an abstract interface for it -- this low", is answered by "what model type will the samplers integrate against?"
if the samplers expect the user to pass in an instance of BaseModelWrapper, then sigma_to_t() needs to be on BaseModelWrapper, or the sampler needs to be prepared to sniff the model instance and look for a more specific subclass. it's preferable to avoid that.

@crowsonkb
Owner

hey, sorry for slow reply.

okay, so this wrapper creates the illusion that the wrapped instance is a subtype of inner_model. seems reasonable (we want the wrapped instance to be a substitute that can be used in all the same situations).

as for where the standard methods (e.g. sigma_to_t()) should go... I guess it depends what this base wrapper claims its responsibilities are.

if it's a "general model wrapper" (i.e. nothing to do with diffusion, but perhaps with generic responsibilities like logging), then I wouldn't put sigma_to_t() this low.

if it's a "diffusion model wrapper" (and I assume it is), then I think it makes sense to put sigma_to_t() this low if (and only if) that's something that every diffusion model needs.

It is yeah.

if there's a one-size-fits-all implementation of sigma_to_t() that can be put here, it can go here. if "it depends", then I think it should be an abstract method. the base wrapper forces subclasses to provide an implementation.

For a k-diffusion native model, sigma(t) = t, so the default implementation can simply return its input, same for t_to_sigma().

@Birch-san
Author

okay sure, then yes: let's put a default implementation (identity function) of sigma_to_t() and t_to_sigma() in the base diffusion model wrapper.
