Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ static void sample_k_diffusion(sample_method_t method,
size_t steps = sigmas.size() - 1;
// sample_euler_ancestral
switch (method) {
case EULER_A: {
case EULER_A_SAMPLE_METHOD: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -672,7 +672,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case EULER: // Implemented without any sigma churn
case EULER_SAMPLE_METHOD: // Implemented without any sigma churn
{
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -705,7 +705,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case HEUN: {
case HEUN_SAMPLE_METHOD: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -755,7 +755,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case DPM2: {
case DPM2_SAMPLE_METHOD: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -807,7 +807,7 @@ static void sample_k_diffusion(sample_method_t method,
}

} break;
case DPMPP2S_A: {
case DPMPP2S_A_SAMPLE_METHOD: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -871,7 +871,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case DPMPP2M: // DPM++ (2M) from Karras et al (2022)
case DPMPP2M_SAMPLE_METHOD: // DPM++ (2M) from Karras et al (2022)
{
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -910,7 +910,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
case DPMPP2Mv2_SAMPLE_METHOD: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
{
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);

Expand Down Expand Up @@ -953,7 +953,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
case IPNDM_SAMPLE_METHOD: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{
int max_order = 4;
ggml_tensor* x_next = x;
Expand Down Expand Up @@ -1028,7 +1028,7 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
case IPNDM_V_SAMPLE_METHOD: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{
int max_order = 4;
std::vector<ggml_tensor*> buffer_model;
Expand Down Expand Up @@ -1102,7 +1102,7 @@ static void sample_k_diffusion(sample_method_t method,
d_cur = ggml_dup_tensor(work_ctx, x_next);
}
} break;
case LCM: // Latent Consistency Models
case LCM_SAMPLE_METHOD: // Latent Consistency Models
{
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
Expand Down Expand Up @@ -1137,8 +1137,8 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing
case DDIM_TRAILING_SAMPLE_METHOD: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing
{
// See J. Song et al., "Denoising Diffusion Implicit
// Models", arXiv:2010.02502 [cs.LG]
Expand Down Expand Up @@ -1331,8 +1331,8 @@ static void sample_k_diffusion(sample_method_t method,
// factor c_in.
}
} break;
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
// Trajectory Consistency Distillation
case TCD_SAMPLE_METHOD: // Strategic Stochastic Sampling (Algorithm 4) in
// Trajectory Consistency Distillation
{
// See J. Zheng et al., "Trajectory Consistency
// Distillation: Improved Latent Consistency Distillation
Expand Down
6 changes: 5 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1902,10 +1902,14 @@ int main(int argc, const char* argv[]) {
return 1;
}

if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}

if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}

if (params.sample_params.scheduler == SCHEDULER_COUNT) {
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
}
Expand Down
57 changes: 33 additions & 24 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ const char* model_version_to_str[] = {
};

const char* sampling_methods_str[] = {
"default",
"Euler",
"Euler A",
"Heun",
"DPM2",
"DPM++ (2s)",
Expand All @@ -59,7 +59,6 @@ const char* sampling_methods_str[] = {
"LCM",
"DDIM \"trailing\"",
"TCD",
"Euler A",
};

/*================================================== Helper Functions ================================================*/
Expand Down Expand Up @@ -2228,8 +2227,8 @@ enum rng_type_t str_to_rng_type(const char* str) {
}

const char* sample_method_to_str[] = {
"default",
"euler",
"euler_a",
"heun",
"dpm2",
"dpm++2s_a",
Expand All @@ -2240,7 +2239,6 @@ const char* sample_method_to_str[] = {
"lcm",
"ddim_trailing",
"tcd",
"euler_a",
};

const char* sd_sample_method_name(enum sample_method_t sample_method) {
Expand Down Expand Up @@ -2468,7 +2466,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_end = 0.2f;
sample_params->guidance.slg.scale = 0.f;
sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
sample_params->sample_method = SAMPLE_METHOD_COUNT;
sample_params->sample_steps = 20;
}

Expand Down Expand Up @@ -2626,19 +2624,19 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {

enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
SDVersion version = sd_ctx->sd->version;
if (sd_version_is_dit(version))
return EULER;
else
return EULER_A;
if (sd_version_is_dit(sd_ctx->sd->version)) {
return EULER_SAMPLE_METHOD;
}
}
return SAMPLE_METHOD_COUNT;
return EULER_A_SAMPLE_METHOD;
}

enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
}
}
return DISCRETE_SCHEDULER;
}
Expand Down Expand Up @@ -2826,7 +2824,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int C = sd_ctx->sd->get_latent_channel();
int W = width / sd_ctx->sd->get_vae_scale_factor();
int H = height / sd_ctx->sd->get_vae_scale_factor();
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);

struct ggml_tensor* control_latent = nullptr;
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
Expand Down Expand Up @@ -3055,10 +3052,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->rng->manual_seed(seed);
sd_ctx->sd->sampler_rng->manual_seed(seed);

int sample_steps = sd_img_gen_params->sample_params.sample_steps;

size_t t0 = ggml_time_ms();

enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);

int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);

ggml_tensor* init_latent = nullptr;
Expand Down Expand Up @@ -3247,11 +3249,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
}

enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_DEFAULT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}

sd_image_t* result_images = generate_image_internal(sd_ctx,
work_ctx,
init_latent,
Expand Down Expand Up @@ -3301,6 +3298,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s

int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();

enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);

int high_noise_sample_steps = 0;
if (sd_ctx->sd->high_noise_diffusion_model) {
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
Expand Down Expand Up @@ -3569,6 +3572,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
// High Noise Sample
if (high_noise_sample_steps > 0) {
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
if (high_noise_sample_method == SAMPLE_METHOD_COUNT) {
high_noise_sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);

int64_t sampling_start = ggml_time_ms();

std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
Expand All @@ -3587,7 +3596,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->high_noise_sample_params.guidance,
sd_vid_gen_params->high_noise_sample_params.eta,
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
sd_vid_gen_params->high_noise_sample_params.sample_method,
high_noise_sample_method,
high_noise_sigmas,
-1,
{},
Expand Down Expand Up @@ -3624,7 +3633,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->sample_params.guidance,
sd_vid_gen_params->sample_params.eta,
sd_vid_gen_params->sample_params.shifted_timestep,
sd_vid_gen_params->sample_params.sample_method,
sample_method,
sigmas,
-1,
{},
Expand Down
25 changes: 12 additions & 13 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,18 @@ enum rng_type_t {
};

enum sample_method_t {
SAMPLE_METHOD_DEFAULT,
EULER,
HEUN,
DPM2,
DPMPP2S_A,
DPMPP2M,
DPMPP2Mv2,
IPNDM,
IPNDM_V,
LCM,
DDIM_TRAILING,
TCD,
EULER_A,
EULER_SAMPLE_METHOD,
EULER_A_SAMPLE_METHOD,
HEUN_SAMPLE_METHOD,
DPM2_SAMPLE_METHOD,
DPMPP2S_A_SAMPLE_METHOD,
DPMPP2M_SAMPLE_METHOD,
DPMPP2Mv2_SAMPLE_METHOD,
IPNDM_SAMPLE_METHOD,
IPNDM_V_SAMPLE_METHOD,
LCM_SAMPLE_METHOD,
DDIM_TRAILING_SAMPLE_METHOD,
TCD_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT
};

Expand Down
Loading