wip: add tcd sampler
adapted code from comfyanonymous/ComfyUI#3370
TODO: check whether a virtual scheduler tcd is needed when using ModelSamplingDiscreteDistilled as sampling_base, or whether it's better to use sgm_uniform directly without patching
mashb1t committed May 12, 2024
1 parent 6308fb8 commit 77acf81
Showing 8 changed files with 65 additions and 13 deletions.
20 changes: 20 additions & 0 deletions ldm_patched/contrib/external_custom_sampler.py
@@ -230,6 +230,25 @@ def get_sampler(self, eta, s_noise, r, noise_device):
        sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
        return (sampler, )


class SamplerTCD:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }
    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"

    FUNCTION = "get_sampler"

    def get_sampler(self, eta=0.3):
        sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta})
        return (sampler, )


class SamplerCustom:
    @classmethod
    def INPUT_TYPES(s):
@@ -292,6 +311,7 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler,
"KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerTCD": SamplerTCD,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
}
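
For orientation, a minimal usage sketch of the new node (an assumption-laden standalone snippet, not part of the commit; it assumes ldm_patched is importable from the repository root):

# Hypothetical usage: build the TCD SAMPLER object exactly as
# SamplerTCD.get_sampler does, then hand it to SamplerCustom.
import ldm_patched.modules.samplers

sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": 0.3})
# "eta" controls TCD's stochasticity; 0.0 is deterministic (see sample_tcd below).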
5 changes: 4 additions & 1 deletion ldm_patched/contrib/external_model_advanced.py
@@ -70,7 +70,7 @@ class ModelSamplingDiscrete:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "sampling": (["eps", "v_prediction", "lcm"],),
                              "sampling": (["eps", "v_prediction", "lcm", "tcd"],),
                              "zsnr": ("BOOLEAN", {"default": False}),
                              }}

@@ -90,6 +90,9 @@ def patch(self, model, sampling, zsnr):
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
elif sampling == "tcd":
sampling_type = ldm_patched.modules.model_sampling.EPS
sampling_base = ModelSamplingDiscreteDistilled

class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
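
For sampling == "tcd", patch() pairs the distilled (LCM-style) timestep schedule with plain epsilon prediction rather than the LCM prediction type. A self-contained toy illustration of that mixin pattern (the stand-in classes below are assumptions for illustration; the real ones live in ldm_patched.modules.model_sampling):

# Toy stand-ins for ModelSamplingDiscreteDistilled and EPS.
class DistilledScheduleBase:
    # distilled models expose a reduced timestep schedule
    original_timesteps = 50

class EpsPrediction:
    # standard eps-prediction: x0 = x - sigma * eps
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_input - model_output * sigma

# Schedule comes from the base class, prediction type from the mixin,
# mirroring ModelSamplingAdvanced(sampling_base, sampling_type) above.
class ModelSamplingAdvanced(DistilledScheduleBase, EpsPrediction):
    pass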
28 changes: 27 additions & 1 deletion ldm_patched/k_diffusion/sampling.py
@@ -752,7 +752,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    return x



@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
@@ -808,3 +807,30 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
        d_prime = w1 * d + w2 * d_2 + w3 * d_3
        x = x + d_prime * dt
    return x


@torch.no_grad()
def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3):
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.inner_model.model_sampling
    # TCD's "strategic stochastic sampling": shorten each timestep t to s = floor((1 - eta) * t).
    timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu()
    timesteps_s[-1] = 0
    alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
    beta_prod_s = 1 - alpha_prod_s
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)  # predicted_original_sample
        eps = (x - denoised) / sigmas[i]
        # deterministically re-noise the x0 prediction down to the shorter timestep s
        denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps

        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})

        x = denoised
        if eta > 0 and sigmas[i + 1] > 0:
            # stochastic step: fresh noise bridges from timestep s back to the next sigma level
            noise = noise_sampler(sigmas[i], sigmas[i + 1])
            x = x / alpha_prod_s[i + 1].sqrt() + noise * (sigmas[i + 1] ** 2 + 1 - 1 / alpha_prod_s[i + 1]).sqrt()

    return x
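
The remapping at the top of sample_tcd is TCD's "strategic stochastic sampling": every denoising timestep t is shortened to s = floor((1 - eta) * t), so eta = 0 reduces to deterministic x0-hopping. A quick standalone check of the mapping (the timestep values are hypothetical):

import torch

eta = 0.3
t = torch.tensor([999., 749., 499., 249., 0.])  # stand-ins for model_sampling.timestep(sigmas)
s = torch.floor((1 - eta) * t).to(dtype=torch.long)
print(s.tolist())  # [699, 524, 349, 174, 0]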
8 changes: 4 additions & 4 deletions ldm_patched/modules/model_sampling.py
@@ -50,17 +50,17 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
        self.linear_start = linear_start
        self.linear_end = linear_end

        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.set_sigmas(sigmas)
        self.set_alphas_cumprod(alphas_cumprod.float())

    def set_sigmas(self, sigmas):
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())

    def set_alphas_cumprod(self, alphas_cumprod):
        self.register_buffer("alphas_cumprod", alphas_cumprod.float())

    @property
    def sigma_min(self):
        return self.sigmas[0]
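
set_sigmas and the new set_alphas_cumprod store two views of the same schedule: sigma = sqrt((1 - a) / a) for cumulative alpha a, hence a = 1 / (1 + sigma^2). A small consistency check with toy values (not part of the commit):

import torch

alphas_cumprod = torch.tensor([0.9, 0.5, 0.1])  # toy schedule
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
assert torch.allclose(1 / (1 + sigmas ** 2), alphas_cumprod)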
2 changes: 1 addition & 1 deletion ldm_patched/modules/samplers.py
@@ -523,7 +523,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None,

KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]

class KSAMPLER(Sampler):
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
8 changes: 4 additions & 4 deletions modules/async_worker.py
@@ -773,19 +773,19 @@ def handler(async_task):
    final_sampler_name = sampler_name
    final_scheduler_name = scheduler_name

    if scheduler_name == 'lcm':
    if scheduler_name in ['lcm', 'tcd']:
        final_scheduler_name = 'sgm_uniform'
        if pipeline.final_unet is not None:
            pipeline.final_unet = core.opModelSamplingDiscrete.patch(
                pipeline.final_unet,
                sampling='lcm',
                sampling=scheduler_name,
                zsnr=False)[0]
        if pipeline.final_refiner_unet is not None:
            pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
                pipeline.final_refiner_unet,
                sampling='lcm',
                sampling=scheduler_name,
                zsnr=False)[0]
        print('Using lcm scheduler.')
        print(f'Using {scheduler_name} scheduler.')

    async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])

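
The effect of this branch: selecting lcm or tcd in the UI patches the UNet's model_sampling while sigma spacing falls back to sgm_uniform. A toy restatement of the scheduler routing (hypothetical helper name, not part of the commit):

def resolve_scheduler(scheduler_name: str) -> str:
    # distilled schedulers reuse sgm_uniform sigma spacing
    return 'sgm_uniform' if scheduler_name in ['lcm', 'tcd'] else scheduler_name

assert resolve_scheduler('tcd') == 'sgm_uniform'
assert resolve_scheduler('karras') == 'karras'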
5 changes: 3 additions & 2 deletions modules/flags.py
@@ -34,7 +34,8 @@
"dpmpp_3m_sde": "",
"dpmpp_3m_sde_gpu": "",
"ddpm": "",
"lcm": "LCM"
"lcm": "LCM",
"tcd": "TCD"
}

SAMPLER_EXTRA = {
@@ -47,7 +48,7 @@

KSAMPLER_NAMES = list(KSAMPLER.keys())

SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "tcd"]
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())

sampler_list = SAMPLER_NAMES
2 changes: 2 additions & 0 deletions modules/patch_precision.py
@@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
    self.linear_end = linear_end
    sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
    self.set_sigmas(sigmas)
    alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
    self.set_alphas_cumprod(alphas_cumprod)
    return


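
Registering alphas_cumprod in the precision-patched schedule matters because sample_tcd indexes model_sampling.alphas_cumprod directly; without this buffer the TCD sampler would have nothing to look up on precision-patched models. A toy check of the conversion (the beta schedule below is a hypothetical stand-in):

import numpy as np
import torch

betas = np.linspace(1e-4, 2e-2, 1000)              # hypothetical linear schedule
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)   # monotonically decreasing
buf = torch.tensor(alphas_cumprod, dtype=torch.float32)
assert buf.dtype == torch.float32 and bool(buf[-1] < buf[0])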
