Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ struct SDParams {
std::vector<int> high_noise_skip_layers = {7, 8, 9};
sd_sample_params_t high_noise_sample_params;

float moe_boundary = 0.875f;

int video_frames = 1;
int fps = 16;

Expand Down Expand Up @@ -117,6 +119,7 @@ struct SDParams {
// Default constructor: fills both sampler parameter sets through
// sd_sample_params_init, then overrides the high-noise step count.
SDParams() {
sd_sample_params_init(&sample_params);
sd_sample_params_init(&high_noise_sample_params);
// Must come after the init call above, which writes the whole struct.
// -1 is the "auto" setting: the --high-noise-steps usage text documents
// the default as "-1 = auto".
high_noise_sample_params.sample_steps = -1;
}
};

Expand Down Expand Up @@ -167,6 +170,7 @@ void print_params(SDParams params) {
printf(" height: %d\n", params.height);
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
printf(" moe_boundary: %.3f\n", params.moe_boundary);
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
printf(" seed: %ld\n", params.seed);
Expand Down Expand Up @@ -243,7 +247,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete)\n");
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" (high noise) sampling method (default: \"euler_a\")\n");
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: 20)\n");
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n");
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n");
Expand Down Expand Up @@ -274,6 +278,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
printf(" --video-frames video frames (default: 1)\n");
printf(" --fps fps (default: 24)\n");
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
printf(" Only enabled if `--high-noise-steps` is set to -1\n");
printf(" -v, --verbose print extra info\n");
}

Expand Down Expand Up @@ -362,7 +368,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
std::string arg;
for (int i = 1; i < argc; i++) {
bool found_arg = false;
arg = argv[i];
arg = argv[i];

for (auto& option : options.string_options) {
if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) {
Expand Down Expand Up @@ -423,7 +429,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
for (auto& option : options.manual_options) {
if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) {
found_arg = true;
int ret = option.cb(argc, argv, i);
int ret = option.cb(argc, argv, i);
if (ret < 0) {
invalid_arg = true;
break;
Expand All @@ -435,7 +441,7 @@ bool parse_options(int argc, const char** argv, ArgOptions& options) {
break;
}
if (!found_arg) {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
return false;
}
}
Expand Down Expand Up @@ -507,6 +513,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--strength", "", &params.strength},
{"", "--style-ratio", "", &params.style_ratio},
{"", "--control-strength", "", &params.control_strength},
{"", "--moe-boundary", "", &params.moe_boundary},
};

options.bool_options = {
Expand Down Expand Up @@ -767,8 +774,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
}

if (params.high_noise_sample_params.sample_steps <= 0) {
fprintf(stderr, "error: the high_noise_sample_steps must be greater than 0\n");
exit(1);
params.high_noise_sample_params.sample_steps = -1;
}

if (params.strength < 0.f || params.strength > 1.f) {
Expand Down Expand Up @@ -1222,6 +1228,7 @@ int main(int argc, const char* argv[]) {
params.height,
params.sample_params,
params.high_noise_sample_params,
params.moe_boundary,
params.strength,
params.seed,
params.video_frames,
Expand Down
31 changes: 25 additions & 6 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1727,11 +1727,13 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t));
sd_sample_params_init(&sd_vid_gen_params->sample_params);
sd_sample_params_init(&sd_vid_gen_params->high_noise_sample_params);
sd_vid_gen_params->width = 512;
sd_vid_gen_params->height = 512;
sd_vid_gen_params->strength = 0.75f;
sd_vid_gen_params->seed = -1;
sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->high_noise_sample_params.sample_steps = -1;
sd_vid_gen_params->width = 512;
sd_vid_gen_params->height = 512;
sd_vid_gen_params->strength = 0.75f;
sd_vid_gen_params->seed = -1;
sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->moe_boundary = 0.875f;
}

struct sd_ctx_t {
Expand Down Expand Up @@ -2381,7 +2383,24 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
}

std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
int total_steps = sample_steps;

if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps;
}
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps);

if (high_noise_sample_steps < 0) {
// For flow models (e.g. Wan2.2 A14B) timesteps are proportional to sigmas, so moe_boundary is compared against sigma directly.
for (size_t i = 0; i < sigmas.size(); ++i) {
if (sigmas[i] < sd_vid_gen_params->moe_boundary) {
high_noise_sample_steps = i;
break;
}
}
LOG_DEBUG("switching from high noise model at step %d", high_noise_sample_steps);
sample_steps = total_steps - high_noise_sample_steps;
}

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB
Expand Down
1 change: 1 addition & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ typedef struct {
int height;
sd_sample_params_t sample_params;
sd_sample_params_t high_noise_sample_params;
float moe_boundary;
float strength;
int64_t seed;
int video_frames;
Expand Down
Loading