Skip to content

Commit

Permalink
Add SamplerDPMPP_2M_SDE node.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Sep 29, 2023
1 parent 26b7372 commit 66756de
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
4 changes: 2 additions & 2 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]

def ksampler(sampler_name):
def ksampler(sampler_name, extra_options={}):
class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
Expand Down Expand Up @@ -627,7 +627,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
elif sampler_name == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
return samples
return KSAMPLER

Expand Down
25 changes: 25 additions & 0 deletions comfy_extras/nodes_custom_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_class(sampler_name)()
return (sampler, )

class SamplerDPMPP_2M_SDE:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"solver_type": (['midpoint', 'heun'], ),
"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"noise_device": (['gpu', 'cpu'], ),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling"

FUNCTION = "get_sampler"

def get_sampler(self, solver_type, eta, s_noise, noise_device):
if noise_device == 'cpu':
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
return (sampler, )


class SamplerCustom:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -132,6 +156,7 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler,
"SamplerCustom": SamplerCustom,
"KarrasScheduler": KarrasScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"BasicScheduler": BasicScheduler,
"SplitSigmas": SplitSigmas,
}

0 comments on commit 66756de

Please sign in to comment.