2 changes: 2 additions & 0 deletions examples/cli/README.md
@@ -54,6 +54,8 @@ Context Options:
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0,
then threads will be set to the number of CPU physical cores
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
--max-vram <float> maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables
graph splitting
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM
when needed
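With this flag, any positive value enables graph-cut segmented execution under a VRAM budget of that many GiB, while the default of 0 keeps the compute graph whole. As an illustration only (the binary name and other arguments are assumptions, not part of this diff), an invocation like `./sd -m model.safetensors -p "a photo of a cat" --max-vram 6.5` would cap segmented execution at roughly 6.5 GiB.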
9 changes: 8 additions & 1 deletion examples/common/common.cpp
@@ -394,7 +394,12 @@ ArgOptions SDContextParams::get_options() {
&chroma_t5_mask_pad},
};

- options.float_options = {};
+ options.float_options = {
+ {"",
+ "--max-vram",
+ "maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables graph splitting",
+ &max_vram},
+ };

options.bool_options = {
{"",
@@ -670,6 +675,7 @@ std::string SDContextParams::to_string() const {
<< " rng_type: " << sd_rng_type_name(rng_type) << ",\n"
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
<< " max_vram: " << max_vram << ",\n"
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
@@ -744,6 +750,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
chroma_use_t5_mask,
chroma_t5_mask_pad,
qwen_image_zero_cond_t,
max_vram,
};
return sd_ctx_params;
}
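The option is stored as a GiB float and copied verbatim into `sd_ctx_params_t` above, while the `set_max_graph_vram_bytes(size_t)` setters added later in this diff take a byte count, so a GiB-to-bytes conversion has to happen somewhere in between. That conversion is not part of the hunks shown here; the following is a minimal sketch of what it plausibly looks like (the helper name and the call site are assumptions):

```cpp
#include <cstddef>

// Hypothetical helper: convert the user-facing GiB budget to a byte budget.
// A non-positive value preserves the documented "0 disables graph splitting".
static std::size_t max_vram_gib_to_bytes(float gib) {
    if (gib <= 0.0f) {
        return 0;
    }
    return static_cast<std::size_t>(gib * 1024.0f * 1024.0f * 1024.0f);
}

// Assumed call site during context setup, fanning the budget out through the
// setters this PR adds:
//   std::size_t budget = max_vram_gib_to_bytes(params.max_vram);
//   diffusion_model->set_max_graph_vram_bytes(budget);
//   cond_stage_model->set_max_graph_vram_bytes(budget);
```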
1 change: 1 addition & 0 deletions examples/common/common.h
@@ -109,6 +109,7 @@ struct SDContextParams {
rng_type_t rng_type = CUDA_RNG;
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
bool offload_params_to_cpu = false;
float max_vram = 0.f;
bool enable_mmap = false;
bool control_net_cpu = false;
bool clip_on_cpu = false;
2 changes: 2 additions & 0 deletions examples/server/README.md
@@ -156,6 +156,8 @@ Context Options:
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0,
then threads will be set to the number of CPU physical cores
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
--max-vram <float> maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables
graph splitting
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM
when needed
1 change: 1 addition & 0 deletions include/stable-diffusion.h
@@ -203,6 +203,7 @@ typedef struct {
bool chroma_use_t5_mask;
int chroma_t5_mask_pad;
bool qwen_image_zero_cond_t;
float max_vram;
} sd_ctx_params_t;

typedef struct {
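`max_vram` is appended as the last member of `sd_ctx_params_t`, which is why the aggregate initializer in `common.cpp` above also lists it last. A minimal sketch of setting it from C API caller code (the zero-initialization idiom is the caller's choice, not something this header mandates):

```cpp
// Zero-initialize so max_vram starts at 0.0f, i.e. graph splitting disabled.
sd_ctx_params_t params = {0};
params.max_vram = 8.0f;  // opt in: 8 GiB budget for graph-cut segmented execution
// ... fill in the remaining fields, then create the context as usual ...
```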
6 changes: 6 additions & 0 deletions src/anima.hpp
@@ -499,9 +499,15 @@ namespace Anima {
encoder_hidden_states = adapted_context;
}

sd::ggml_graph_cut::mark_graph_cut(x, "anima.prelude", "x");
sd::ggml_graph_cut::mark_graph_cut(embedded_timestep, "anima.prelude", "embedded_timestep");
sd::ggml_graph_cut::mark_graph_cut(temb, "anima.prelude", "temb");
sd::ggml_graph_cut::mark_graph_cut(encoder_hidden_states, "anima.prelude", "context");

for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
sd::ggml_graph_cut::mark_graph_cut(x, "anima.blocks." + std::to_string(i), "x");
}

x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]
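The marking pattern here is representative of the whole PR: the prelude tensors (`x`, `embedded_timestep`, `temb`, the adapted context) are marked once, and each transformer block output is marked under a scope that encodes its index, giving the splitter a stable, evenly spaced set of candidate cut points. The splitter itself is not part of this diff; the sketch below shows only one plausible way such marks could be packed into segments under a byte budget, and every name in it is an illustrative assumption:

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: greedily close a segment whenever the running cost of the
// tensors since the last cut would exceed the VRAM budget. Returns the indices
// of the marks chosen as segment boundaries.
std::vector<std::size_t> plan_segments(const std::vector<std::size_t>& mark_costs,
                                       std::size_t max_vram_bytes) {
    std::vector<std::size_t> boundaries;
    if (max_vram_bytes == 0) {
        return boundaries;  // a budget of 0 means no graph splitting
    }
    std::size_t running = 0;
    for (std::size_t i = 0; i < mark_costs.size(); i++) {
        if (running > 0 && running + mark_costs[i] > max_vram_bytes) {
            boundaries.push_back(i);  // cut just before mark i
            running = 0;
        }
        running += mark_costs[i];
    }
    return boundaries;
}
```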
10 changes: 10 additions & 0 deletions src/auto_encoder_kl.hpp
@@ -328,6 +328,7 @@ class Encoder : public GGMLBlock {
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);

auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.prelude", "h");

// downsampling
size_t num_resolutions = ch_mult.size();
@@ -337,19 +338,22 @@
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);

h = down_block->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".block." + std::to_string(j), "h");
}
if (i != num_resolutions - 1) {
std::string name = "down." + std::to_string(i) + ".downsample";
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);

h = down_sample->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".downsample", "h");
}
}

// middle
h = mid_block_1->forward(ctx, h);
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.mid", "h");

// end
h = norm_out->forward(ctx, h);
@@ -450,13 +454,15 @@ class Decoder : public GGMLBlock {

// conv_in
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.prelude", "h");

// middle
h = mid_block_1->forward(ctx, h);
// return h;

h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.mid", "h");

// upsampling
int num_resolutions = static_cast<int>(ch_mult.size());
@@ -466,12 +472,14 @@
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);

h = up_block->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".block." + std::to_string(j), "h");
}
if (i != 0) {
std::string name = "up." + std::to_string(i) + ".upsample";
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);

h = up_sample->forward(ctx, h);
// sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".upsample", "h");
}
}

@@ -599,6 +607,7 @@ class AutoEncoderKLModel : public GGMLBlock {
if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
// sd::ggml_graph_cut::mark_graph_cut(z, "vae.decode.prelude", "z");
}
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);

@@ -616,6 +625,7 @@
if (use_quant) {
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
// sd::ggml_graph_cut::mark_graph_cut(z, "vae.encode.final", "z");
}
if (sd_version_uses_flux2_vae(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
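Unlike the Anima marks above, every cut mark in this file is committed commented out, so the VAE encoder and decoder are not split by default. The disabled calls still follow the same scope/name convention (`vae.encoder.down.<i>.block.<j>`, `vae.decoder.mid`, and so on), so they read as pre-placed cut points that can be enabled later without re-deriving the boundaries.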
14 changes: 10 additions & 4 deletions src/clip.hpp
@@ -95,8 +95,9 @@ struct CLIPEncoder : public GGMLBlock {

ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
- ggml_tensor* mask = nullptr,
- int clip_skip = -1) {
+ ggml_tensor* mask = nullptr,
+ int clip_skip = -1,
+ const std::string& graph_cut_prefix = "") {
// x: [N, n_token, d_model]
int layer_idx = n_layer - 1;
// LOG_DEBUG("clip_skip %d", clip_skip);
@@ -112,6 +113,9 @@
std::string name = "layers." + std::to_string(i);
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
if (!graph_cut_prefix.empty()) {
sd::ggml_graph_cut::mark_graph_cut(x, graph_cut_prefix + ".layers." + std::to_string(i), "x");
}
// LOG_DEBUG("layer %d", i);
}
return x;
@@ -304,7 +308,8 @@ class CLIPTextModel : public GGMLBlock {
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);

auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
- x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip);
+ sd::ggml_graph_cut::mark_graph_cut(x, "clip_text.prelude", "x");
+ x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip, "clip_text");
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
@@ -368,7 +373,8 @@ class CLIPVisionModel : public GGMLBlock {

auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
- x = encoder->forward(ctx, x, nullptr, clip_skip);
+ sd::ggml_graph_cut::mark_graph_cut(x, "clip_vision.prelude", "x");
+ x = encoder->forward(ctx, x, nullptr, clip_skip, "clip_vision");

auto last_hidden_state = x;

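Because `CLIPEncoder` is shared by the text and vision towers, the new `graph_cut_prefix` parameter keeps per-layer cut names unique: `CLIPTextModel` passes `"clip_text"` and `CLIPVisionModel` passes `"clip_vision"`, producing marks such as `clip_text.layers.11`. The default empty prefix skips marking entirely, so callers that do not opt in see no change in behavior.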
45 changes: 44 additions & 1 deletion src/conditioner.hpp
@@ -85,7 +85,8 @@ struct Conditioner {
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
- virtual void set_flash_attention_enabled(bool enabled) = 0;
+ virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) {}
+ virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(int n_threads,
const ConditionerParams& conditioner_params) {
@@ -165,6 +166,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
text_model->set_max_graph_vram_bytes(max_vram_bytes);
if (sd_version_is_sdxl(version)) {
text_model2->set_max_graph_vram_bytes(max_vram_bytes);
}
}

void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
@@ -781,6 +789,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (clip_l) {
clip_l->set_max_graph_vram_bytes(max_vram_bytes);
}
if (clip_g) {
clip_g->set_max_graph_vram_bytes(max_vram_bytes);
}
if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes);
}
}

void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
@@ -1124,6 +1144,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (clip_l) {
clip_l->set_max_graph_vram_bytes(max_vram_bytes);
}
if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes);
}
}

void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
@@ -1349,6 +1378,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
if (t5) {
t5->set_max_graph_vram_bytes(max_vram_bytes);
}
}

void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
@@ -1525,6 +1560,10 @@ struct AnimaConditioner : public Conditioner {
return llm->get_params_buffer_size();
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
llm->set_max_graph_vram_bytes(max_vram_bytes);
}

void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
@@ -1657,6 +1696,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
llm->set_max_graph_vram_bytes(max_vram_bytes);
}

void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
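Note the design choice at the top of this file: the base `Conditioner` gives `set_max_graph_vram_bytes` an empty default body rather than declaring it pure virtual, so only embedders that own sub-models need to override it. Each override simply fans the budget out to whichever runners exist, null-checking the optional ones (`clip_l`, `clip_g`, `t5`) in the same style as the neighboring `set_flash_attention_enabled` overrides.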
33 changes: 33 additions & 0 deletions src/diffusion_model.hpp
@@ -49,6 +49,7 @@ struct DiffusionModel {
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};

@@ -98,6 +99,10 @@ struct UNetModel : public DiffusionModel {
unet.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
unet.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
unet.set_circular_axes(circular_x, circular_y);
}
@@ -164,6 +169,10 @@ struct MMDiTModel : public DiffusionModel {
mmdit.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
mmdit.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
mmdit.set_circular_axes(circular_x, circular_y);
}
@@ -229,6 +238,10 @@ struct FluxModel : public DiffusionModel {
flux.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
flux.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
flux.set_circular_axes(circular_x, circular_y);
}
@@ -299,6 +312,10 @@ struct AnimaModel : public DiffusionModel {
anima.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
anima.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
anima.set_circular_axes(circular_x, circular_y);
}
@@ -364,6 +381,10 @@ struct WanModel : public DiffusionModel {
wan.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
wan.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
wan.set_circular_axes(circular_x, circular_y);
}
@@ -433,6 +454,10 @@ struct QwenImageModel : public DiffusionModel {
qwen_image.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
qwen_image.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
qwen_image.set_circular_axes(circular_x, circular_y);
}
@@ -499,6 +524,10 @@ struct ZImageModel : public DiffusionModel {
z_image.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
z_image.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
z_image.set_circular_axes(circular_x, circular_y);
}
@@ -564,6 +593,10 @@ struct ErnieImageModel : public DiffusionModel {
ernie_image.set_flash_attention_enabled(enabled);
}

void set_max_graph_vram_bytes(size_t max_vram_bytes) override {
ernie_image.set_max_graph_vram_bytes(max_vram_bytes);
}

void set_circular_axes(bool circular_x, bool circular_y) override {
ernie_image.set_circular_axes(circular_x, circular_y);
}
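In contrast to `Conditioner`, `DiffusionModel` declares `set_max_graph_vram_bytes` pure virtual, so a newly added backend will not compile until it decides how to handle the budget; every existing wrapper satisfies the contract with a one-line delegation to its underlying runner (`unet`, `mmdit`, `flux`, and so on).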