diff --git a/ldm_patched/contrib/external_custom_sampler.py b/ldm_patched/contrib/external_custom_sampler.py
index 8f92e841f..985b03a0a 100644
--- a/ldm_patched/contrib/external_custom_sampler.py
+++ b/ldm_patched/contrib/external_custom_sampler.py
@@ -230,6 +230,25 @@ def get_sampler(self, eta, s_noise, r, noise_device):
         sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
         return (sampler, )
 
+
+class SamplerTCD:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+            }
+        }
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, eta=0.3):
+        sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta})
+        return (sampler, )
+
+
 class SamplerCustom:
     @classmethod
     def INPUT_TYPES(s):
@@ -292,6 +311,7 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler,
     "KSamplerSelect": KSamplerSelect,
     "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
+    "SamplerTCD": SamplerTCD,
     "SplitSigmas": SplitSigmas,
     "FlipSigmas": FlipSigmas,
 }
diff --git a/ldm_patched/contrib/external_model_advanced.py b/ldm_patched/contrib/external_model_advanced.py
index 03a2f0454..9b52c36b5 100644
--- a/ldm_patched/contrib/external_model_advanced.py
+++ b/ldm_patched/contrib/external_model_advanced.py
@@ -70,7 +70,7 @@ class ModelSamplingDiscrete:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["eps", "v_prediction", "lcm"],),
+                              "sampling": (["eps", "v_prediction", "lcm", "tcd"],),
                               "zsnr": ("BOOLEAN", {"default": False}),
                               }}
@@ -90,6 +90,9 @@ def patch(self, model, sampling, zsnr):
         elif sampling == "lcm":
             sampling_type = LCM
             sampling_base = ModelSamplingDiscreteDistilled
+        elif sampling == "tcd":
+            sampling_type = ldm_patched.modules.model_sampling.EPS
+            sampling_base = ModelSamplingDiscreteDistilled
 
         class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass
diff --git a/ldm_patched/k_diffusion/sampling.py b/ldm_patched/k_diffusion/sampling.py
index 761c2e0ef..d1bc1e4b2 100644
--- a/ldm_patched/k_diffusion/sampling.py
+++ b/ldm_patched/k_diffusion/sampling.py
@@ -752,7 +752,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
     return x
 
 
-
 @torch.no_grad()
 def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
@@ -808,3 +807,30 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         d_prime = w1 * d + w2 * d_2 + w3 * d_3
         x = x + d_prime * dt
     return x
+
+
+@torch.no_grad()
+def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3):
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    model_sampling = model.inner_model.inner_model.model_sampling
+    timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu()  # pull each boundary timestep back by (1 - eta)
+    timesteps_s[-1] = 0
+    alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
+    beta_prod_s = 1 - alpha_prod_s
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)  # predicted_original_sample
+        eps = (x - denoised) / sigmas[i]  # noise implied by the x0 prediction
+        denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps  # forward-process sample at the pulled-back timestep s
+
+        if callback is not None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+
+        x = denoised
+        if eta > 0 and sigmas[i + 1] > 0:
+            noise = noise_sampler(sigmas[i], sigmas[i + 1])
+            x = x / alpha_prod_s[i + 1].sqrt() + noise * (sigmas[i + 1] ** 2 + 1 - 1 / alpha_prod_s[i + 1]).sqrt()  # re-noise from sigma_s up to sigmas[i + 1]
+
+    return x
\ No newline at end of file
diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py
index f39e275d3..57f51a000 100644
--- a/ldm_patched/modules/model_sampling.py
+++ b/ldm_patched/modules/model_sampling.py
@@ -50,17 +50,17 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
         self.linear_start = linear_start
         self.linear_end = linear_end
 
-        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
-        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
-        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
-
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
         self.set_sigmas(sigmas)
+        self.set_alphas_cumprod(alphas_cumprod.float())
 
     def set_sigmas(self, sigmas):
         self.register_buffer('sigmas', sigmas)
         self.register_buffer('log_sigmas', sigmas.log())
 
+    def set_alphas_cumprod(self, alphas_cumprod):
+        self.register_buffer("alphas_cumprod", alphas_cumprod.float())
+
     @property
     def sigma_min(self):
         return self.sigmas[0]
diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py
index 1f69d2b10..35cb3d738 100644
--- a/ldm_patched/modules/samplers.py
+++ b/ldm_patched/modules/samplers.py
@@ -523,7 +523,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
 
 KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
+                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
diff --git a/modules/async_worker.py b/modules/async_worker.py
index d8a1e072d..0d0001f15 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -773,19 +773,19 @@ def handler(async_task):
         final_sampler_name = sampler_name
         final_scheduler_name = scheduler_name
 
-        if scheduler_name == 'lcm':
+        if scheduler_name in ['lcm', 'tcd']:
             final_scheduler_name = 'sgm_uniform'
             if pipeline.final_unet is not None:
                 pipeline.final_unet = core.opModelSamplingDiscrete.patch(
                     pipeline.final_unet,
-                    sampling='lcm',
+                    sampling=scheduler_name,
                     zsnr=False)[0]
             if pipeline.final_refiner_unet is not None:
                 pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
                     pipeline.final_refiner_unet,
-                    sampling='lcm',
+                    sampling=scheduler_name,
                     zsnr=False)[0]
-            print('Using lcm scheduler.')
+            print(f'Using {scheduler_name} scheduler.')
 
         async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
diff --git a/modules/flags.py b/modules/flags.py
index c9d13fd81..c6d1fad79 100644
--- a/modules/flags.py
+++ b/modules/flags.py
@@ -34,7 +34,8 @@
     "dpmpp_3m_sde": "",
     "dpmpp_3m_sde_gpu": "",
     "ddpm": "",
-    "lcm": "LCM"
+    "lcm": "LCM",
+    "tcd": "TCD"
 }
 
 SAMPLER_EXTRA = {
@@ -47,7 +48,7 @@
 
 KSAMPLER_NAMES = list(KSAMPLER.keys())
 
-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "tcd"]
 SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
 
 sampler_list = SAMPLER_NAMES
diff --git a/modules/patch_precision.py b/modules/patch_precision.py
index 83569bdd1..22ffda0ad 100644
--- a/modules/patch_precision.py
+++ b/modules/patch_precision.py
@@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
     self.linear_end = linear_end
     sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
     self.set_sigmas(sigmas)
+    alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
+    self.set_alphas_cumprod(alphas_cumprod)
     return
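
Review note on the math in sample_tcd: each step turns the model's x0 prediction into an epsilon estimate, forms the forward-process sample at the pulled-back timestep s = floor((1 - eta) * t), and, because 1 / alpha_prod_s == 1 + sigma_s**2, the stochastic branch adds noise with variance exactly sigmas[i + 1]**2 - sigma_s**2, i.e. it re-noises the sample from sigma_s back up to the next sigma. The sketch below replays one such update on bare tensors to make that bookkeeping checkable; the linear-beta schedule, the helper name tcd_step, and the toy inputs are illustrative assumptions, not part of this patch.

import torch

# Stand-in schedule (assumption): linear-beta DDPM alpha-bars, the same
# quantity ModelSamplingDiscrete now exposes via set_alphas_cumprod.
T = 1000
betas = torch.linspace(1e-4, 2e-2, T, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def tcd_step(x, denoised, sigma, sigma_next, t_next, eta=0.3):
    # One loop body of sample_tcd, outside the KSAMPLER plumbing.
    s = int((1 - eta) * t_next)          # pulled-back timestep, cf. timesteps_s
    a_s = alphas_cumprod[s]
    eps = (x - denoised) / sigma         # noise implied by the x0 prediction
    x_s = a_s.sqrt() * denoised + (1 - a_s).sqrt() * eps  # forward sample at s
    if eta > 0 and sigma_next > 0:
        # 1 / a_s == 1 + sigma_s**2, so this noise term re-noises the sample
        # from sigma_s up to sigma_next.
        noise = torch.randn_like(x)
        return x_s / a_s.sqrt() + noise * (sigma_next ** 2 + 1 - 1 / a_s).sqrt()
    return x_s

# Sanity check that the two variance expressions coincide.
t_next, eta = 500, 0.3
a_s = alphas_cumprod[int((1 - eta) * t_next)]
sigma_s = ((1 - a_s) / a_s).sqrt()
sigma_next = ((1 - alphas_cumprod[t_next]) / alphas_cumprod[t_next]).sqrt()
assert torch.allclose(sigma_next ** 2 + 1 - 1 / a_s, sigma_next ** 2 - sigma_s ** 2)

One caveat worth flagging in review: eta == 0 is reachable from SamplerTCD's widget (min 0.0), and in that case the loop leaves x as the VP-scaled x_s without rescaling it back to the x0 + sigma * eps convention the k-diffusion wrapper expects at the next model call, so the default eta = 0.3 (or any eta > 0) is the path this implementation is written for.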