Skip to content

Commit

Permalink
fix: k-diff samplers made more stable by skipping second to last step
Browse files Browse the repository at this point in the history
  • Loading branch information
mattstern31 committed Mar 1, 2023
1 parent e0206a9 commit d1b0343
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
5 changes: 3 additions & 2 deletions imaginairy/samplers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def sample(
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]

if orig_latent is not None:
# t_start is none if init image strength set to 0
if orig_latent is not None and t_start is not None:
noisy_latent = self.noise_an_image(
init_latent=orig_latent, t=t_start, schedule=schedule, noise=noise
init_latent=orig_latent, t=t_start - 1, schedule=schedule, noise=noise
)
else:
noisy_latent = noise
Expand Down
12 changes: 6 additions & 6 deletions imaginairy/samplers/kdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from imaginairy.utils import get_device
from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from imaginairy.vendored.k_diffusion.sampling import get_sigmas_karras


class StandardCompVisDenoiser(CompVisDenoiser):
Expand Down Expand Up @@ -96,12 +97,17 @@ def sample(
t_start = num_steps - t_start + 1
sigmas = self.cv_denoiser.get_sigmas(num_steps)[t_start:]

# see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666
if self.short_name in (SamplerName.K_DPM_2, SamplerName.K_DPMPP_2M, SamplerName.K_DPM_2_ANCESTRAL):
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

# if our number of steps is zero, just return the initial latent
if sigmas.nelement() == 0:
if orig_latent is not None:
return orig_latent
return noise

# t_start is none if init image strength set to 0
if orig_latent is not None and t_start is not None:
noisy_latent = noise * sigmas[0] + orig_latent
else:
Expand Down Expand Up @@ -141,12 +147,6 @@ def callback(data):

return samples

@torch.no_grad()
def noise_an_image(self, init_latent, t, sigmas, noise=None):
if isinstance(t, int):
t = torch.tensor([t], device=get_device())
t = t.clamp(0, 1000)


class DPMFastSampler(KDiffusionSampler):
short_name = SamplerName.K_DPM_FAST
Expand Down
5 changes: 3 additions & 2 deletions imaginairy/samplers/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def sample(

old_eps = []

if orig_latent is not None:
# t_start is none if init image strength set to 0
if orig_latent is not None and t_start is not None:
noisy_latent = self.noise_an_image(
init_latent=orig_latent, t=t_start, schedule=schedule, noise=noise
init_latent=orig_latent, t=t_start - 1, schedule=schedule, noise=noise
)
else:
noisy_latent = noise
Expand Down

0 comments on commit d1b0343

Please sign in to comment.