Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 60 additions & 21 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
#define TIMESTEPS 1000
#define FLUX_TIMESTEPS 1000

struct SigmaSchedule {
int version = 0;
struct SigmaScheduler {
typedef std::function<float(float)> t_to_sigma_t;

virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
};

struct DiscreteSchedule : SigmaSchedule {
struct DiscreteScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result;

Expand All @@ -42,7 +41,7 @@ struct DiscreteSchedule : SigmaSchedule {
}
};

struct ExponentialSchedule : SigmaSchedule {
struct ExponentialScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas;

Expand Down Expand Up @@ -149,7 +148,10 @@ std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
/*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/
struct AYSSchedule : SigmaSchedule {
struct AYSScheduler : SigmaScheduler {
SDVersion version;
explicit AYSScheduler(SDVersion version)
: version(version) {}
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
const std::vector<float> noise_levels[] = {
/* SD1.5 */
Expand All @@ -169,19 +171,19 @@ struct AYSSchedule : SigmaSchedule {
std::vector<float> results(n + 1);

if (sd_version_is_sd2((SDVersion)version)) {
LOG_WARN("AYS not designed for SD2.X models");
LOG_WARN("AYS_SCHEDULER not designed for SD2.X models");
} /* fallthrough */
else if (sd_version_is_sd1((SDVersion)version)) {
LOG_INFO("AYS using SD1.5 noise levels");
LOG_INFO("AYS_SCHEDULER using SD1.5 noise levels");
inputs = noise_levels[0];
} else if (sd_version_is_sdxl((SDVersion)version)) {
LOG_INFO("AYS using SDXL noise levels");
LOG_INFO("AYS_SCHEDULER using SDXL noise levels");
inputs = noise_levels[1];
} else if (version == VERSION_SVD) {
LOG_INFO("AYS using SVD noise levels");
LOG_INFO("AYS_SCHEDULER using SVD noise levels");
inputs = noise_levels[2];
} else {
LOG_ERROR("Version not compatible with AYS scheduler");
LOG_ERROR("Version not compatible with AYS_SCHEDULER scheduler");
return results;
}

Expand All @@ -203,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
/*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/
struct GITSSchedule : SigmaSchedule {
struct GITSScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
if (sigma_max <= 0.0f) {
return std::vector<float>{};
Expand Down Expand Up @@ -232,7 +234,7 @@ struct GITSSchedule : SigmaSchedule {
}
};

struct SGMUniformSchedule : SigmaSchedule {
struct SGMUniformScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
std::vector<float> result;
if (n == 0) {
Expand All @@ -251,7 +253,7 @@ struct SGMUniformSchedule : SigmaSchedule {
}
};

struct KarrasSchedule : SigmaSchedule {
struct KarrasScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
// These *COULD* be function arguments here,
// but does anybody ever bother to touch them?
Expand All @@ -270,7 +272,7 @@ struct KarrasSchedule : SigmaSchedule {
}
};

struct SimpleSchedule : SigmaSchedule {
struct SimpleScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result_sigmas;

Expand Down Expand Up @@ -299,8 +301,8 @@ struct SimpleSchedule : SigmaSchedule {
}
};

// Close to Beta Schedule, but increadably simple in code.
struct SmoothStepSchedule : SigmaSchedule {
// Close to Beta Scheduler, but increadably simple in code.
struct SmoothStepScheduler : SigmaScheduler {
static constexpr float smoothstep(float x) {
return x * x * (3.0f - 2.0f * x);
}
Expand Down Expand Up @@ -329,7 +331,6 @@ struct SmoothStepSchedule : SigmaSchedule {
};

struct Denoiser {
std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
virtual float sigma_min() = 0;
virtual float sigma_max() = 0;
virtual float sigma_to_t(float sigma) = 0;
Expand All @@ -338,8 +339,47 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;

virtual std::vector<float> get_sigmas(uint32_t n) {
virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) {
case DISCRETE_SCHEDULER:
LOG_INFO("get_sigmas with discrete scheduler");
scheduler = std::make_shared<DiscreteScheduler>();
break;
case KARRAS_SCHEDULER:
LOG_INFO("get_sigmas with Karras scheduler");
scheduler = std::make_shared<KarrasScheduler>();
break;
case EXPONENTIAL_SCHEDULER:
LOG_INFO("get_sigmas exponential scheduler");
scheduler = std::make_shared<ExponentialScheduler>();
break;
case AYS_SCHEDULER:
LOG_INFO("get_sigmas with Align-Your-Steps scheduler");
scheduler = std::make_shared<AYSScheduler>(version);
break;
case GITS_SCHEDULER:
LOG_INFO("get_sigmas with GITS scheduler");
scheduler = std::make_shared<GITSScheduler>();
break;
case SGM_UNIFORM_SCHEDULER:
LOG_INFO("get_sigmas with SGM Uniform scheduler");
scheduler = std::make_shared<SGMUniformScheduler>();
break;
case SIMPLE_SCHEDULER:
LOG_INFO("get_sigmas with Simple scheduler");
scheduler = std::make_shared<SimpleScheduler>();
break;
case SMOOTHSTEP_SCHEDULER:
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
break;
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();
break;
}
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
}
};
Expand Down Expand Up @@ -426,7 +466,6 @@ struct EDMVDenoiser : public CompVisVDenoiser {

EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) {
scheduler = std::make_shared<ExponentialSchedule>();
}

float t_to_sigma(float t) override {
Expand Down Expand Up @@ -1109,7 +1148,7 @@ static void sample_k_diffusion(sample_method_t method,
// end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteScheduler)
// (k-diffusion's LMSDiscreteSchedulerr)
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod;
Expand Down Expand Up @@ -1137,7 +1176,7 @@ static void sample_k_diffusion(sample_method_t method,

for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedules and Sample Steps
// "Common Diffusion Noise Schedulers and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming
//
Expand Down
34 changes: 10 additions & 24 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -912,34 +912,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1;
};

auto on_schedule_arg = [&](int argc, const char** argv, int index) {
auto on_scheduler_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.sample_params.scheduler = str_to_schedule(arg);
if (params.sample_params.scheduler == SCHEDULE_COUNT) {
params.sample_params.scheduler = str_to_scheduler(arg);
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
fprintf(stderr, "error: invalid scheduler %s\n",
arg);
return -1;
}
return 1;
};

auto on_high_noise_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.high_noise_sample_params.scheduler = str_to_schedule(arg);
if (params.high_noise_sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid high noise scheduler %s\n",
arg);
return -1;
}
return 1;
};

auto on_prediction_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
Expand Down Expand Up @@ -1212,7 +1198,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_schedule_arg},
on_scheduler_arg},
{"",
"--skip-layers",
"layers to skip for SLG steps (default: [7,8,9])",
Expand All @@ -1222,10 +1208,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
" default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg},
{"",
"--high-noise-scheduler",
"(high noise) denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_high_noise_schedule_arg},
{"",
"--high-noise-skip-layers",
"(high noise) layers to skip for SLG steps (default: [7,8,9])",
Expand Down Expand Up @@ -1442,8 +1424,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
}
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
if (params.sample_params.scheduler != DEFAULT) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
if (params.sample_params.scheduler != SCHEDULER_COUNT) {
parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
Expand Down Expand Up @@ -1925,6 +1907,10 @@ int main(int argc, const char* argv[]) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}

if (params.sample_params.scheduler == SCHEDULER_COUNT) {
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
}

if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(),
Expand Down
Loading
Loading