Skip to content

Commit

Permalink
Refactor sampling code for more advanced sampler nodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Apr 4, 2024
1 parent 6c6a392 commit 57753c9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 43 deletions.
38 changes: 25 additions & 13 deletions comfy/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,23 @@ def convert_cond(cond):
out.append(temp)
return out

def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
gligen = []

for i in range(len(conds)):
cnets += get_models_from_cond(conds[i], "control")
gligen += get_models_from_cond(conds[i], "gligen")

control_nets = set(cnets)

inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)

gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
Expand All @@ -73,24 +79,25 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()

def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
def prepare_sampling(model, noise_shape, conds, noise_mask):
device = model.load_device
positive = convert_cond(positive)
negative = convert_cond(negative)
for i in range(len(conds)):
conds[i] = convert_cond(conds[i])

if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)

real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model

return real_model, positive, negative, noise_mask, models
return real_model, conds, noise_mask, models


def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
positive_copy, negative_copy = conds_copy

noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
Expand All @@ -105,14 +112,19 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
return samples

def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sigmas = sigmas.to(model.load_device)

samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))

control_cleanup = []
for i in range(len(conds)):
control_cleanup += get_models_from_cond(conds[i], "control")

cleanup_additional_models(set(control_cleanup))
return samples

72 changes: 42 additions & 30 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return cfg_result

class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
def __init__(self, model, cond_scale=1.0):
super().__init__()
self.inner_model = model
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
self.cond_scale = cond_scale
def apply_model(self, x, timestep, conds, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
Expand All @@ -274,13 +275,13 @@ def __init__(self, model, sigmas):
super().__init__()
self.inner_model = model
self.sigmas = sigmas
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
return out
Expand Down Expand Up @@ -568,44 +569,55 @@ def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, *

return KSAMPLER(sampler_function, extra_options, inpaint_options)

def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
return model_denoise

def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
negative = negative[:]
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)

resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
for k in conds:
calculate_start_end_timesteps(model, conds[k])

model_wrap = wrap_model(model)
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area
for k in conds:
for c in conds[k]:
for kk in conds:
if k != kk:
create_cond_with_same_area_if_none(conds[kk], c)

for k in conds:
pre_run_control(model, conds[k])

if "positive" in conds:
positive = conds["positive"]
for k in conds:
if k != "positive":
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])

return conds


def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)

if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed)
model_wrap = guider_class(model)

#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
extra_args = {"conds": conds, "model_options": model_options, "seed":seed}

pre_run_control(model, negative + positive)
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))

apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed)

samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))

SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
Expand Down

0 comments on commit 57753c9

Please sign in to comment.