Commit
Support lcm models.
Use the "lcm" sampler to sample them; you also have to use the
ModelSamplingDiscrete node to set them as LCM models so they are handled properly.
comfyanonymous committed Nov 9, 2023
1 parent ca71e54 commit 002aefa
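For context, here is a minimal sketch of how the two pieces described in the commit message fit together, assuming a `model` patcher already loaded elsewhere (e.g. by a checkpoint loader node). Only `ModelSamplingDiscrete` and the "lcm" sampler name come from this commit; the surrounding glue, including the `lcm_model` name, is hypothetical:

# Illustrative only -- in practice these are nodes wired up in the ComfyUI graph.
from comfy_extras.nodes_model_advanced import ModelSamplingDiscrete

# 1. Patch the loaded model so it is treated as an LCM model
#    (swaps in the LCM prediction type and sigma schedule added below).
(lcm_model,) = ModelSamplingDiscrete().patch(model, "lcm", zsnr=False)

# 2. Sample with sampler_name="lcm" in the usual KSampler setup;
#    LCM checkpoints typically want few steps and low CFG.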
Showing 3 changed files with 88 additions and 4 deletions.
15 changes: 14 additions & 1 deletion comfy/k_diffusion/sampling.py
@@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
    mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
    return mu


def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)

@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        # LCM jumps straight to the model's denoised prediction, then re-noises
        # it to the next sigma level (skipped at the final, sigma == 0 step).
        x = denoised
        if sigmas[i + 1] > 0:
            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
    return x
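To make the loop above concrete, here is a self-contained toy run (illustrative only, not part of this commit) with a dummy model that always predicts zeros: each step jumps straight to the prediction and re-noises it, and the final step returns the prediction unchanged:

import torch

def toy_model(x, sigma, **extra_args):
    return torch.zeros_like(x)  # pretend the network fully denoises in one call

def toy_sample_lcm(model, x, sigmas):
    for i in range(len(sigmas) - 1):
        denoised = model(x, sigmas[i])
        x = denoised
        if sigmas[i + 1] > 0:  # re-noise to the next sigma, except at sigma == 0
            x = x + sigmas[i + 1] * torch.randn_like(x)
    return x

x = torch.randn(1, 4, 8, 8) * 14.6  # start from noise at (roughly) SD's sigma_max
out = toy_sample_lcm(toy_model, x, torch.tensor([14.6, 1.0, 0.0]))
print(out.abs().max())  # tensor(0.) -- the final step returns the prediction as-is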
2 changes: 1 addition & 1 deletion comfy/samplers.py
@@ -519,7 +519,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N

KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                 "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
+                 "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]

def ksampler(sampler_name, extra_options={}, inpaint_options={}):
    class KSAMPLER(Sampler):
75 changes: 73 additions & 2 deletions comfy_extras/nodes_model_advanced.py
@@ -1,6 +1,72 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import torch

class LCM(comfy.model_sampling.EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
        timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        x0 = model_input - model_output * sigma

        sigma_data = 0.5
        scaled_timestep = timestep * 10.0  # timestep_scaling

        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5

        return c_out * x0 + c_skip * model_input
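A quick numeric sanity check of these coefficients (illustrative, not part of the diff): at timestep 0 they reduce to c_skip = 1 and c_out = 0, so calculate_denoised returns the input untouched (the consistency-model boundary condition), while at large timesteps the weighting flips almost entirely to the x0 prediction:

sigma_data = 0.5
for t in (0.0, 999.0):
    st = t * 10.0  # timestep_scaling
    c_skip = sigma_data**2 / (st**2 + sigma_data**2)
    c_out = st / (st**2 + sigma_data**2) ** 0.5
    print(t, c_skip, c_out)
# 0.0   -> c_skip = 1.0,     c_out = 0.0
# 999.0 -> c_skip ~ 2.5e-9,  c_out ~ 1.0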

class ModelSamplingDiscreteLCM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sigma_data = 1.0
        timesteps = 1000
        beta_start = 0.00085
        beta_end = 0.012

        betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        original_timesteps = 50
        self.skip_steps = timesteps // original_timesteps

        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
        for x in range(original_timesteps):
            alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]

        sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas):
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)

    def sigma(self, timestep):
        t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def percent_to_sigma(self, percent):
        return self.sigma(torch.tensor(percent * 999.0))
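The schedule construction in __init__ can be checked in isolation (again illustrative, not part of the commit): with timesteps = 1000 and original_timesteps = 50, skip_steps is 20, so the retained timesteps are 19, 39, ..., 999, the last step of each 20-step block:

import torch

timesteps, original_timesteps = 1000, 50
skip_steps = timesteps // original_timesteps  # 20

betas = torch.linspace(0.00085**0.5, 0.012**0.5, timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

valid_t = [timesteps - 1 - x * skip_steps for x in range(original_timesteps)][::-1]
print(valid_t[:3], valid_t[-1])  # [19, 39, 59] ... 999

sigmas = ((1 - alphas_cumprod[valid_t]) / alphas_cumprod[valid_t]) ** 0.5
print(sigmas[0], sigmas[-1])  # sigma_min and sigma_max of the 50-entry table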


def rescale_zero_terminal_snr_sigmas(sigmas):
@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
-                             "sampling": (["eps", "v_prediction"],),
+                             "sampling": (["eps", "v_prediction", "lcm"],),
                              "zsnr": ("BOOLEAN", {"default": False}),
                              }}

@@ -38,17 +104,22 @@ def INPUT_TYPES(s):
    def patch(self, model, sampling, zsnr):
        m = model.clone()

        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
        if sampling == "eps":
            sampling_type = comfy.model_sampling.EPS
        elif sampling == "v_prediction":
            sampling_type = comfy.model_sampling.V_PREDICTION
        elif sampling == "lcm":
            sampling_type = LCM
            sampling_base = ModelSamplingDiscreteLCM

-       class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
+       class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced()
        if zsnr:
            model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))

        m.add_object_patch("model_sampling", model_sampling)
        return (m, )

1 comment on commit 002aefa

@halr9000


Is ModelSamplingDiscrete needed when running a regular model, then patching it with LCM-lora? I am not seeing a difference when I take it out.
