From ee0162278381314c0afa7625ae472a32d0be2abf Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 7 Jun 2026 00:52:22 +0800 Subject: [PATCH] refactor: unify model config detection --- CONTRIBUTING.md | 2 + docs/model_config.md | 118 ++++++++ src/anima.hpp | 124 ++++---- src/conditioner.hpp | 4 +- src/ernie_image.hpp | 206 ++++++------- src/flux.hpp | 465 ++++++++++++++--------------- src/ggml_extend.hpp | 2 +- src/hidream_o1.hpp | 98 ++++--- src/ideogram4.hpp | 54 ++-- src/lens.hpp | 171 +++++------ src/llm.hpp | 402 ++++++++++++------------- src/ltx_audio_vae.h | 11 +- src/ltxv.hpp | 678 ++++++++++++++++++++++--------------------- src/mmdit.hpp | 252 ++++++++++------ src/pid.hpp | 215 +++++++------- src/qwen_image.hpp | 146 +++++----- src/t5.hpp | 64 ++-- src/unet.hpp | 216 ++++++++++---- src/wan.hpp | 334 ++++++++++----------- src/z_image.hpp | 219 +++++++++----- 20 files changed, 2134 insertions(+), 1647 deletions(-) create mode 100644 docs/model_config.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9ba9177a3..f94e39049 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -44,6 +44,8 @@ Naming conventions: Some older code in the project may not fully follow the current conventions. Please do not submit PRs that only rewrite existing code to match style rules. +When adding or modifying model implementations, follow the model config and weight detection conventions in [docs/model_config.md](docs/model_config.md). + ## AI-Assisted Contributions AI tools may be used to assist development, but contributors are responsible for the quality and correctness of the submitted code. diff --git a/docs/model_config.md b/docs/model_config.md new file mode 100644 index 000000000..8c562fff4 --- /dev/null +++ b/docs/model_config.md @@ -0,0 +1,118 @@ +# Model Configuration Conventions + +This document describes the conventions for model configuration structs and +weight-based configuration detection. 
+
+## Config Types
+
+Model configuration should live in a model-specific `*Config` struct.
+
+Examples:
+
+- `ZImageConfig`
+- `UNetConfig`
+- `MMDiTConfig`
+- `LLMConfig`
+
+Preserve established acronym casing in type names, such as `UNet`, `MMDiT`,
+`LLM`, `VAE`, and `T5`.
+
+Place the config struct near the top of the model header, before the main model
+blocks and runner types that consume it.
+
+## Config Variables
+
+Variables and members that hold a config should be named `config`.
+
+Examples:
+
+```cpp
+UNetConfig config;
+UnetModelBlock unet;
+
+MMDiTRunner(...)
+    : DiffusionModelRunner(backend, params_backend, prefix),
+      config(MMDiTConfig::detect_from_weights(tensor_storage_map, prefix)),
+      mmdit(config) {
+}
+```
+
+Avoid alternate names such as `params`, `params_cfg`, `model_params`, or
+model-specific aliases unless an existing public API requires them.
+
+## Weight Detection
+
+If a model can derive configuration from loaded weight metadata, expose that
+logic as a static method on the config type:
+
+```cpp
+static XxxConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
+                                     const std::string& prefix);
+```
+
+Additional selector arguments are allowed when required by an existing model
+family, for example `SDVersion version` or an architecture enum:
+
+```cpp
+static UNetConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
+                                      const std::string& prefix,
+                                      SDVersion version = VERSION_SD1);
+```
+
+Use `TensorStorage` metadata, especially `n_dims` and `ne`, to infer shapes.
+Do not load or parse tensor data for config detection.
+
+Detection should respect `prefix`. For nested weights, construct full names from
+`prefix + "." + suffix` or filter entries with `starts_with(name, prefix)`.
+
+Do not add persistent config fields such as `inferred_from_weights` only to
+record whether detection happened. If the function needs to decide whether to
+print a debug line, keep that as local control flow inside `detect_from_weights`.
+
+## Logging
+
+When config values are inferred from weights, print one `LOG_DEBUG` line at the
+end of `detect_from_weights`.
+
+Example:
+
+```cpp
+LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
+          config.num_layers,
+          config.vocab_size,
+          config.hidden_size,
+          config.intermediate_size);
+```
+
+Only print the config detection log when the function actually inferred values
+from weights. Do not duplicate the same config summary in runner constructors or
+model loading code.
+
+Use the correct format specifiers for field types, such as `%" PRId64 "` for
+`int64_t` and `%d` for `int`.
+
+## Runner And Model Responsibilities
+
+Runners should detect the config once and pass it into the model block:
+
+```cpp
+struct XxxRunner : public DiffusionModelRunner {
+    XxxConfig config;
+    XxxModel model;
+
+    XxxRunner(..., const String2TensorStorage& tensor_storage_map, const std::string prefix)
+        : DiffusionModelRunner(backend, params_backend, prefix),
+          config(XxxConfig::detect_from_weights(tensor_storage_map, prefix)),
+          model(config) {
+        model.init(params_ctx, tensor_storage_map, prefix);
+    }
+};
+```
+
+Model blocks should consume `config` directly instead of re-scanning weights in
+their constructors. Keep config-derived behavior centralized in the config
+struct.
+
+If a model has no weight-derived config today, it may still provide
+`detect_from_weights` for API consistency, but it should not print a config
+detection log unless it actually derives values from weights.
diff --git a/src/anima.hpp b/src/anima.hpp index fb4745c20..71f187641 100644 --- a/src/anima.hpp +++ b/src/anima.hpp @@ -1,6 +1,7 @@ #ifndef __ANIMA_HPP__ #define __ANIMA_HPP__ +#include #include #include #include @@ -14,6 +15,47 @@ namespace Anima { constexpr int ANIMA_GRAPH_SIZE = 65536; + struct AnimaConfig { + int64_t in_channels = 16; + int64_t out_channels = 16; + int64_t hidden_size = 2048; + int64_t text_embed_dim = 1024; + int64_t num_heads = 16; + int64_t head_dim = 128; + int patch_size = 2; + int64_t num_layers = 28; + std::vector axes_dim = {44, 42, 42}; + int theta = 10000; + + static AnimaConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + AnimaConfig config; + int64_t detected_layers = 0; + std::string layer_tag = prefix.empty() ? "blocks." : prefix + ".blocks."; + for (const auto& [name, _] : tensor_storage_map) { + size_t pos = name.find(layer_tag); + if (pos == std::string::npos) { + continue; + } + size_t start = pos + layer_tag.size(); + size_t end = name.find('.', start); + if (end == std::string::npos) { + continue; + } + int64_t layer_id = atoll(name.substr(start, end - start).c_str()); + detected_layers = std::max(detected_layers, layer_id + 1); + } + if (detected_layers > 0) { + config.num_layers = detected_layers; + LOG_DEBUG("anima: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %" PRId64 ", head_dim = %" PRId64, + config.num_layers, + config.hidden_size, + config.num_heads, + config.head_dim); + } + return config; + } + }; + __STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx, ggml_tensor* x, ggml_tensor* gate) { @@ -418,31 +460,22 @@ namespace Anima { struct AnimaNet : public GGMLBlock { public: - int64_t in_channels = 16; - int64_t out_channels = 16; - int64_t hidden_size = 2048; - int64_t text_embed_dim = 1024; - int64_t num_heads = 16; - int64_t head_dim = 128; - int patch_size = 2; - int64_t num_layers = 28; - std::vector axes_dim = {44, 42, 42}; - int 
theta = 10000; + AnimaConfig config; public: AnimaNet() = default; - explicit AnimaNet(int64_t num_layers) - : num_layers(num_layers) { - blocks["x_embedder"] = std::make_shared((in_channels + 1) * patch_size * patch_size, hidden_size); - blocks["t_embedder"] = std::make_shared(hidden_size, hidden_size * 3); - blocks["t_embedding_norm"] = std::make_shared(hidden_size, 1e-6f); - for (int i = 0; i < num_layers; i++) { - blocks["blocks." + std::to_string(i)] = std::make_shared(hidden_size, - text_embed_dim, - num_heads, - head_dim); + explicit AnimaNet(AnimaConfig config) + : config(config) { + blocks["x_embedder"] = std::make_shared((config.in_channels + 1) * config.patch_size * config.patch_size, config.hidden_size); + blocks["t_embedder"] = std::make_shared(config.hidden_size, config.hidden_size * 3); + blocks["t_embedding_norm"] = std::make_shared(config.hidden_size, 1e-6f); + for (int i = 0; i < config.num_layers; i++) { + blocks["blocks." + std::to_string(i)] = std::make_shared(config.hidden_size, + config.text_embed_dim, + config.num_heads, + config.head_dim); } - blocks["final_layer"] = std::make_shared(hidden_size, patch_size, out_channels); + blocks["final_layer"] = std::make_shared(config.hidden_size, config.patch_size, config.out_channels); blocks["llm_adapter"] = std::make_shared(1024, 1024, 1024, 6, 16); } @@ -469,11 +502,11 @@ namespace Anima { auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]); x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W] - x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw] + x = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size); // [N, h*w, (C+1)*ph*pw] x = x_embedder->forward(ctx, x); - auto timestep_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast(hidden_size)); + auto timestep_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast(config.hidden_size)); auto temb = 
t_embedder->forward(ctx, timestep_proj); auto embedded_timestep = t_embedding_norm->forward(ctx, timestep_proj); @@ -505,7 +538,7 @@ namespace Anima { sd::ggml_graph_cut::mark_graph_cut(temb, "anima.prelude", "temb"); sd::ggml_graph_cut::mark_graph_cut(encoder_hidden_states, "anima.prelude", "context"); - for (int i = 0; i < num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe); sd::ggml_graph_cut::mark_graph_cut(x, "anima.blocks." + std::to_string(i), "x"); @@ -513,7 +546,7 @@ namespace Anima { x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C] - x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W] + x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, config.patch_size, config.patch_size, false); // [N, C, H, W] return x; } @@ -524,35 +557,16 @@ namespace Anima { std::vector image_pe_vec; std::vector adapter_q_pe_vec; std::vector adapter_k_pe_vec; + AnimaConfig config; AnimaNet net; AnimaRunner(ggml_backend_t backend, ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "model.diffusion_model") - : DiffusionModelRunner(backend, params_backend, prefix) { - int64_t num_layers = 0; - std::string layer_tag = prefix + ".net.blocks."; - for (const auto& kv : tensor_storage_map) { - const std::string& tensor_name = kv.first; - size_t pos = tensor_name.find(layer_tag); - if (pos == std::string::npos) { - continue; - } - size_t start = pos + layer_tag.size(); - size_t end = tensor_name.find('.', start); - if (end == std::string::npos) { - continue; - } - int64_t layer_id = atoll(tensor_name.substr(start, end - start).c_str()); - num_layers = std::max(num_layers, layer_id + 1); - } - if (num_layers <= 0) { - num_layers = 28; - } - LOG_INFO("anima net layers: %" 
PRId64, num_layers); - - net = AnimaNet(num_layers); + : DiffusionModelRunner(backend, params_backend, prefix), + config(AnimaConfig::detect_from_weights(tensor_storage_map, prefix + ".net")) { + net = AnimaNet(config); net.init(params_ctx, tensor_storage_map, prefix + ".net"); } @@ -623,22 +637,22 @@ namespace Anima { GGML_ASSERT(x->ne[3] == 1); ggml_cgraph* gf = new_graph_custom(ANIMA_GRAPH_SIZE); - int64_t pad_h = (net.patch_size - x->ne[1] % net.patch_size) % net.patch_size; - int64_t pad_w = (net.patch_size - x->ne[0] % net.patch_size) % net.patch_size; + int64_t pad_h = (config.patch_size - x->ne[1] % config.patch_size) % config.patch_size; + int64_t pad_w = (config.patch_size - x->ne[0] % config.patch_size) % config.patch_size; int64_t h_pad = x->ne[1] + pad_h; int64_t w_pad = x->ne[0] + pad_w; image_pe_vec = gen_anima_image_pe_vec(1, static_cast(h_pad), static_cast(w_pad), - static_cast(net.patch_size), - net.theta, - net.axes_dim, + static_cast(config.patch_size), + config.theta, + config.axes_dim, 4.0f, 4.0f, 1.0f); - int64_t image_pos_len = static_cast(image_pe_vec.size()) / (2 * 2 * (net.head_dim / 2)); - auto image_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, net.head_dim / 2, image_pos_len); + int64_t image_pos_len = static_cast(image_pe_vec.size()) / (2 * 2 * (config.head_dim / 2)); + auto image_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.head_dim / 2, image_pos_len); set_backend_tensor_data(image_pe, image_pe_vec.data()); ggml_tensor* adapter_q_pe = nullptr; diff --git a/src/conditioner.hpp b/src/conditioner.hpp index f0f8e3e47..862bfc871 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -1971,7 +1971,7 @@ struct LLMEmbedder : public Conditioner { for (int i = 0; i < conditioner_params.ref_images->size(); i++) { const auto& image = (*conditioner_params.ref_images)[i]; - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + double factor = llm->config.vision.patch_size * 
llm->config.vision.spatial_merge_size; int height = static_cast(image.shape()[1]); int width = static_cast(image.shape()[0]); int h_bar = static_cast(std::round(height / factor) * factor); @@ -2042,7 +2042,7 @@ struct LLMEmbedder : public Conditioner { for (int i = 0; i < conditioner_params.ref_images->size(); i++) { const auto& image = (*conditioner_params.ref_images)[i]; - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + double factor = llm->config.vision.patch_size * llm->config.vision.spatial_merge_size; int height = static_cast(image.shape()[1]); int width = static_cast(image.shape()[0]); int h_bar = static_cast(std::round(height / factor) * factor); diff --git a/src/ernie_image.hpp b/src/ernie_image.hpp index 355468950..0a0f2c950 100644 --- a/src/ernie_image.hpp +++ b/src/ernie_image.hpp @@ -13,6 +13,76 @@ namespace ErnieImage { constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960; + struct ErnieImageConfig { + int64_t hidden_size = 4096; + int64_t num_heads = 32; + int64_t num_layers = 36; + int64_t ffn_hidden_size = 12288; + int64_t in_channels = 128; + int64_t out_channels = 128; + int patch_size = 1; + int64_t text_in_dim = 3072; + int theta = 256; + std::vector axes_dim = {32, 48, 48}; + int axes_dim_sum = 128; + float eps = 1e-6f; + + static ErnieImageConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + ErnieImageConfig config; + config.num_layers = 0; + int64_t detected_head_dim = 0; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) { + config.patch_size = static_cast(tensor_storage.ne[0]); + config.in_channels = tensor_storage.ne[2]; + config.hidden_size = tensor_storage.ne[3]; + } else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) { + config.text_in_dim = tensor_storage.ne[0]; + } else if 
(ends_with(name, "layers.0.self_attention.norm_q.weight")) { + detected_head_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) { + config.ffn_hidden_size = tensor_storage.ne[1]; + } else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) { + int64_t out_dim = tensor_storage.ne[1]; + int64_t patch_area = config.patch_size * config.patch_size; + config.out_channels = out_dim / patch_area; + } + + size_t pos = name.find("layers."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > config.num_layers) { + config.num_layers = block_index + 1; + } + } + } + } + if (config.num_layers == 0) { + config.num_layers = 36; + } + if (detected_head_dim > 0) { + config.num_heads = config.hidden_size / detected_head_dim; + } + config.axes_dim_sum = 0; + for (int axis_dim : config.axes_dim) { + config.axes_dim_sum += axis_dim; + } + LOG_DEBUG("ernie_image: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %" PRId64 ", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64, + config.num_layers, + config.hidden_size, + config.num_heads, + config.ffn_hidden_size, + config.in_channels, + config.out_channels); + return config; + } + }; + __STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx, ggml_tensor* timesteps, int dim, @@ -208,51 +278,36 @@ namespace ErnieImage { } }; - struct ErnieImageParams { - int64_t hidden_size = 4096; - int64_t num_heads = 32; - int64_t num_layers = 36; - int64_t ffn_hidden_size = 12288; - int64_t in_channels = 128; - int64_t out_channels = 128; - int patch_size = 1; - int64_t text_in_dim = 3072; - int theta = 256; - std::vector axes_dim = {32, 48, 48}; - int axes_dim_sum = 128; - float eps = 1e-6f; - }; - class ErnieImageModel : public GGMLBlock { public: - ErnieImageParams 
params; + ErnieImageConfig config; ErnieImageModel() = default; - ErnieImageModel(ErnieImageParams params) - : params(params) { - blocks["x_embedder.proj"] = std::make_shared(params.in_channels, - params.hidden_size, - std::pair{params.patch_size, params.patch_size}, - std::pair{params.patch_size, params.patch_size}, + ErnieImageModel(ErnieImageConfig config) + : config(config) { + blocks["x_embedder.proj"] = std::make_shared(config.in_channels, + config.hidden_size, + std::pair{config.patch_size, config.patch_size}, + std::pair{config.patch_size, config.patch_size}, std::pair{0, 0}, std::pair{1, 1}, true); - if (params.text_in_dim != params.hidden_size) { - blocks["text_proj"] = std::make_shared(params.text_in_dim, params.hidden_size, false); + if (config.text_in_dim != config.hidden_size) { + blocks["text_proj"] = std::make_shared(config.text_in_dim, config.hidden_size, false); } - blocks["time_embedding"] = std::make_shared(params.hidden_size, params.hidden_size); - blocks["adaLN_modulation.1"] = std::make_shared(params.hidden_size, 6 * params.hidden_size, true); - - for (int i = 0; i < params.num_layers; i++) { - blocks["layers." + std::to_string(i)] = std::make_shared(params.hidden_size, - params.num_heads, - params.ffn_hidden_size, - params.eps); + blocks["time_embedding"] = std::make_shared(config.hidden_size, config.hidden_size); + blocks["adaLN_modulation.1"] = std::make_shared(config.hidden_size, 6 * config.hidden_size, true); + + for (int i = 0; i < config.num_layers; i++) { + blocks["layers." 
+ std::to_string(i)] = std::make_shared(config.hidden_size, + config.num_heads, + config.ffn_hidden_size, + config.eps); } - blocks["final_norm"] = std::make_shared(params.hidden_size, params.eps); - blocks["final_linear"] = std::make_shared(params.hidden_size, - params.patch_size * params.patch_size * params.out_channels, + blocks["final_norm"] = std::make_shared(config.hidden_size, config.eps); + blocks["final_linear"] = std::make_shared(config.hidden_size, + config.patch_size * config.patch_size * config.out_channels, true); } @@ -265,12 +320,12 @@ namespace ErnieImage { // context: [N, text_tokens, 3072] // pe: [image_tokens + text_tokens, head_dim/2, 2, 2] GGML_ASSERT(context != nullptr); - GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0); + GGML_ASSERT(x->ne[1] % config.patch_size == 0 && x->ne[0] % config.patch_size == 0); int64_t W = x->ne[0]; int64_t H = x->ne[1]; - int64_t Hp = H / params.patch_size; - int64_t Wp = W / params.patch_size; + int64_t Hp = H / config.patch_size; + int64_t Wp = W / config.patch_size; int64_t n_img = Hp * Wp; int64_t N = x->ne[3]; @@ -292,7 +347,7 @@ namespace ErnieImage { auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1); // [N, image_tokens + text_tokens, hidden_size] - auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast(params.hidden_size)); + auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast(config.hidden_size)); auto c = time_embedding->forward(ctx, sample); // [N, hidden_size] auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 6 * hidden_size] @@ -305,7 +360,7 @@ namespace ErnieImage { temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1])); // [N, 1, hidden_size] } - for (int i = 0; i < params.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto layer = std::dynamic_pointer_cast(blocks["layers." 
+ std::to_string(i)]); hidden_states = layer->forward(ctx, hidden_states, pe, temb); sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.layers." + std::to_string(i), "hidden_states"); @@ -319,15 +374,15 @@ namespace ErnieImage { patches, Hp, Wp, - params.patch_size, - params.patch_size, + config.patch_size, + config.patch_size, false); // [N, out_channels, H, W] return out; } }; struct ErnieImageRunner : public DiffusionModelRunner { - ErnieImageParams ernie_params; + ErnieImageConfig config; ErnieImageModel ernie_image; std::vector pe_vec; @@ -335,58 +390,9 @@ namespace ErnieImage { ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") - : DiffusionModelRunner(backend, params_backend, prefix) { - ernie_params.num_layers = 0; - for (const auto& [name, tensor_storage] : tensor_storage_map) { - if (!starts_with(name, prefix)) { - continue; - } - if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) { - ernie_params.patch_size = static_cast(tensor_storage.ne[0]); - ernie_params.in_channels = tensor_storage.ne[2]; - ernie_params.hidden_size = tensor_storage.ne[3]; - } else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) { - ernie_params.text_in_dim = tensor_storage.ne[0]; - } else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) { - int64_t head_dim = tensor_storage.ne[0]; - ernie_params.num_heads = ernie_params.hidden_size / head_dim; - } else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) { - ernie_params.ffn_hidden_size = tensor_storage.ne[1]; - } else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) { - int64_t out_dim = tensor_storage.ne[1]; - ernie_params.out_channels = out_dim / ernie_params.patch_size / ernie_params.patch_size; - } - - size_t pos = name.find("layers."); - if (pos != std::string::npos) { - std::string layer_name = name.substr(pos); - auto items = 
split_string(layer_name, '.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - if (block_index + 1 > ernie_params.num_layers) { - ernie_params.num_layers = block_index + 1; - } - } - } - } - if (ernie_params.num_layers == 0) { - ernie_params.num_layers = 36; - } - ernie_params.axes_dim_sum = 0; - for (int axis_dim : ernie_params.axes_dim) { - ernie_params.axes_dim_sum += axis_dim; - } - - LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64 - ", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64, - ernie_params.num_layers, - ernie_params.hidden_size, - ernie_params.num_heads, - ernie_params.ffn_hidden_size, - ernie_params.in_channels, - ernie_params.out_channels); - - ernie_image = ErnieImageModel(ernie_params); + : DiffusionModelRunner(backend, params_backend, prefix), + config(ErnieImageConfig::detect_from_weights(tensor_storage_map, prefix)) { + ernie_image = ErnieImageModel(config); ernie_image.init(params_ctx, tensor_storage_map, prefix); } @@ -410,15 +416,15 @@ namespace ErnieImage { pe_vec = Rope::gen_ernie_image_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), - ernie_params.patch_size, + config.patch_size, static_cast(x->ne[3]), static_cast(context->ne[1]), - ernie_params.theta, + config.theta, circular_y_enabled, circular_x_enabled, - ernie_params.axes_dim); - int pos_len = static_cast(pe_vec.size() / ernie_params.axes_dim_sum / 2); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2); + config.axes_dim); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, config.axes_dim_sum, 1, pos_len, 2); set_backend_tensor_data(pe, pe_vec.data()); auto runner_ctx = get_context(); diff --git a/src/flux.hpp b/src/flux.hpp index 15a3f9228..4027f8fe0 100644 --- a/src/flux.hpp +++ b/src/flux.hpp @@ -13,6 +13,155 @@ namespace Flux { + struct 
ChromaRadianceConfig { + int64_t nerf_hidden_size = 64; + int nerf_mlp_ratio = 4; + int nerf_depth = 4; + int nerf_max_freqs = 8; + bool use_x0 = false; + bool fake_patch_size_x2 = false; + }; + + struct FluxConfig { + SDVersion version = VERSION_FLUX; + bool is_chroma = false; + int patch_size = 2; + int64_t in_channels = 64; + int64_t out_channels = 64; + int64_t vec_in_dim = 768; + int64_t context_in_dim = 4096; + int64_t hidden_size = 3072; + float mlp_ratio = 4.0f; + int num_heads = 24; + int depth = 19; + int depth_single_blocks = 38; + std::vector axes_dim = {16, 56, 56}; + int axes_dim_sum = 128; + int theta = 10000; + bool qkv_bias = true; + bool guidance_embed = true; + int64_t in_dim = 64; + bool disable_bias = false; + bool share_modulation = false; + bool semantic_txt_norm = false; + bool use_yak_mlp = false; + bool use_mlp_silu_act = false; + float ref_index_scale = 1.f; + ChromaRadianceConfig chroma_radiance_params; + + static FluxConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + SDVersion version = VERSION_FLUX) { + FluxConfig config; + config.version = version; + config.guidance_embed = false; + config.depth = 0; + config.depth_single_blocks = 0; + if (version == VERSION_FLUX_FILL) { + config.in_channels = 384; + } else if (version == VERSION_FLUX_CONTROLS) { + config.in_channels = 128; + } else if (version == VERSION_FLEX_2) { + config.in_channels = 196; + } else if (version == VERSION_CHROMA_RADIANCE) { + config.in_channels = 3; + config.patch_size = 16; + } else if (version == VERSION_OVIS_IMAGE) { + config.semantic_txt_norm = true; + config.use_yak_mlp = true; + config.vec_in_dim = 0; + } else if (sd_version_is_flux2(version)) { + config.in_channels = 128; + config.patch_size = 1; + config.out_channels = 128; + config.mlp_ratio = 3.f; + config.theta = 2000; + config.axes_dim = {32, 32, 32, 32}; + config.vec_in_dim = 0; + config.qkv_bias = false; + config.disable_bias = true; + 
config.share_modulation = true; + config.ref_index_scale = 10.f; + config.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + config.context_in_dim = 3584; + config.vec_in_dim = 0; + } + + int64_t head_dim = 0; + int64_t actual_radiance_patch_size = -1; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (name.find("guidance_in.in_layer.weight") != std::string::npos) { + config.guidance_embed = true; + } + if (name.find("__x0__") != std::string::npos) { + LOG_DEBUG("using x0 prediction"); + config.chroma_radiance_params.use_x0 = true; + } + if (name.find("__32x32__") != std::string::npos) { + LOG_DEBUG("using patch size 32"); + config.patch_size = 32; + } + if (name.find("img_in_patch.weight") != std::string::npos) { + actual_radiance_patch_size = tensor_storage.ne[0]; + LOG_DEBUG("actual radiance patch size: %" PRId64, actual_radiance_patch_size); + } + if (name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + config.is_chroma = true; + } + size_t db = name.find("double_blocks."); + if (db != std::string::npos) { + std::string block_name = name.substr(db); + int block_depth = atoi(block_name.substr(14, block_name.find(".", 14)).c_str()); + if (block_depth + 1 > config.depth) { + config.depth = block_depth + 1; + } + } + size_t sb = name.find("single_blocks."); + if (sb != std::string::npos) { + std::string block_name = name.substr(sb); + int block_depth = atoi(block_name.substr(14, block_name.find(".", 14)).c_str()); + if (block_depth + 1 > config.depth_single_blocks) { + config.depth_single_blocks = block_depth + 1; + } + } + if (ends_with(name, "txt_in.weight")) { + config.context_in_dim = tensor_storage.ne[0]; + config.hidden_size = tensor_storage.ne[1]; + } + if (ends_with(name, "single_blocks.0.norm.key_norm.scale")) { + head_dim = tensor_storage.ne[0]; + } + if (ends_with(name, "double_blocks.0.txt_attn.norm.key_norm.scale")) { + head_dim = 
tensor_storage.ne[0]; + } + } + if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != config.patch_size) { + GGML_ASSERT(config.patch_size == 2 * actual_radiance_patch_size); + LOG_DEBUG("using fake x2 patch size"); + config.chroma_radiance_params.fake_patch_size_x2 = true; + } + if (head_dim > 0) { + config.num_heads = static_cast(config.hidden_size / head_dim); + } + config.axes_dim_sum = 0; + for (int axis_dim : config.axes_dim) { + config.axes_dim_sum += axis_dim; + } + LOG_DEBUG("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %d", + config.depth, + config.depth_single_blocks, + config.guidance_embed ? "true" : "false", + config.context_in_dim, + config.hidden_size, + config.num_heads); + return config; + } + }; + struct MLPEmbedder : public UnaryBlock { public: MLPEmbedder(int64_t in_dim, int64_t hidden_dim, bool bias = true) { @@ -723,127 +872,90 @@ namespace Flux { } }; - struct ChromaRadianceParams { - int64_t nerf_hidden_size = 64; - int nerf_mlp_ratio = 4; - int nerf_depth = 4; - int nerf_max_freqs = 8; - bool use_x0 = false; - bool fake_patch_size_x2 = false; - }; - - struct FluxParams { - SDVersion version = VERSION_FLUX; - bool is_chroma = false; - int patch_size = 2; - int64_t in_channels = 64; - int64_t out_channels = 64; - int64_t vec_in_dim = 768; - int64_t context_in_dim = 4096; - int64_t hidden_size = 3072; - float mlp_ratio = 4.0f; - int num_heads = 24; - int depth = 19; - int depth_single_blocks = 38; - std::vector axes_dim = {16, 56, 56}; - int axes_dim_sum = 128; - int theta = 10000; - bool qkv_bias = true; - bool guidance_embed = true; - int64_t in_dim = 64; - bool disable_bias = false; - bool share_modulation = false; - bool semantic_txt_norm = false; - bool use_yak_mlp = false; - bool use_mlp_silu_act = false; - float ref_index_scale = 1.f; - ChromaRadianceParams chroma_radiance_params; - }; - struct Flux : public GGMLBlock { public: - 
FluxParams params; + FluxConfig config; Flux() {} - Flux(FluxParams params) - : params(params) { - if (params.version == VERSION_CHROMA_RADIANCE) { - std::pair kernel_size = {params.patch_size, params.patch_size}; - if (params.chroma_radiance_params.fake_patch_size_x2) { - kernel_size = {params.patch_size / 2, params.patch_size / 2}; + Flux(FluxConfig config) + : config(config) { + if (config.version == VERSION_CHROMA_RADIANCE) { + std::pair kernel_size = {config.patch_size, config.patch_size}; + if (config.chroma_radiance_params.fake_patch_size_x2) { + kernel_size = {config.patch_size / 2, config.patch_size / 2}; } std::pair stride = kernel_size; - blocks["img_in_patch"] = std::make_shared(params.in_channels, - params.hidden_size, + blocks["img_in_patch"] = std::make_shared(config.in_channels, + config.hidden_size, kernel_size, stride); } else { - blocks["img_in"] = std::make_shared(params.in_channels, params.hidden_size, !params.disable_bias); + blocks["img_in"] = std::make_shared(config.in_channels, config.hidden_size, !config.disable_bias); } - if (params.is_chroma) { - blocks["distilled_guidance_layer"] = std::make_shared(params.in_dim, params.hidden_size); + if (config.is_chroma) { + blocks["distilled_guidance_layer"] = std::make_shared(config.in_dim, config.hidden_size); } else { - blocks["time_in"] = std::make_shared(256, params.hidden_size, !params.disable_bias); - if (params.vec_in_dim > 0) { - blocks["vector_in"] = std::make_shared(params.vec_in_dim, params.hidden_size, !params.disable_bias); + blocks["time_in"] = std::make_shared(256, config.hidden_size, !config.disable_bias); + if (config.vec_in_dim > 0) { + blocks["vector_in"] = std::make_shared(config.vec_in_dim, config.hidden_size, !config.disable_bias); } - if (params.guidance_embed) { - blocks["guidance_in"] = std::make_shared(256, params.hidden_size, !params.disable_bias); + if (config.guidance_embed) { + blocks["guidance_in"] = std::make_shared(256, config.hidden_size, !config.disable_bias); } } 
- if (params.semantic_txt_norm) { - blocks["txt_norm"] = std::make_shared(params.context_in_dim); + if (config.semantic_txt_norm) { + blocks["txt_norm"] = std::make_shared(config.context_in_dim); } - blocks["txt_in"] = std::make_shared(params.context_in_dim, params.hidden_size, !params.disable_bias); + blocks["txt_in"] = std::make_shared(config.context_in_dim, config.hidden_size, !config.disable_bias); - for (int i = 0; i < params.depth; i++) { - blocks["double_blocks." + std::to_string(i)] = std::make_shared(params.hidden_size, - params.num_heads, - params.mlp_ratio, + for (int i = 0; i < config.depth; i++) { + blocks["double_blocks." + std::to_string(i)] = std::make_shared(config.hidden_size, + config.num_heads, + config.mlp_ratio, i, - params.qkv_bias, - params.is_chroma, - params.share_modulation, - !params.disable_bias, - params.use_yak_mlp, - params.use_mlp_silu_act); + config.qkv_bias, + config.is_chroma, + config.share_modulation, + !config.disable_bias, + config.use_yak_mlp, + config.use_mlp_silu_act); } - for (int i = 0; i < params.depth_single_blocks; i++) { - blocks["single_blocks." + std::to_string(i)] = std::make_shared(params.hidden_size, - params.num_heads, - params.mlp_ratio, + for (int i = 0; i < config.depth_single_blocks; i++) { + blocks["single_blocks." 
+ std::to_string(i)] = std::make_shared(config.hidden_size, + config.num_heads, + config.mlp_ratio, i, 0.f, - params.is_chroma, - params.share_modulation, - !params.disable_bias, - params.use_yak_mlp, - params.use_mlp_silu_act); + config.is_chroma, + config.share_modulation, + !config.disable_bias, + config.use_yak_mlp, + config.use_mlp_silu_act); } - if (params.version == VERSION_CHROMA_RADIANCE) { - blocks["nerf_image_embedder"] = std::make_shared(params.in_channels, - params.chroma_radiance_params.nerf_hidden_size, - params.chroma_radiance_params.nerf_max_freqs); + if (config.version == VERSION_CHROMA_RADIANCE) { + blocks["nerf_image_embedder"] = std::make_shared(config.in_channels, + config.chroma_radiance_params.nerf_hidden_size, + config.chroma_radiance_params.nerf_max_freqs); - for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) { - blocks["nerf_blocks." + std::to_string(i)] = std::make_shared(params.hidden_size, - params.chroma_radiance_params.nerf_hidden_size, - params.chroma_radiance_params.nerf_mlp_ratio); + for (int i = 0; i < config.chroma_radiance_params.nerf_depth; i++) { + blocks["nerf_blocks." 
+ std::to_string(i)] = std::make_shared(config.hidden_size, + config.chroma_radiance_params.nerf_hidden_size, + config.chroma_radiance_params.nerf_mlp_ratio); } - blocks["nerf_final_layer_conv"] = std::make_shared(params.chroma_radiance_params.nerf_hidden_size, - params.in_channels); + blocks["nerf_final_layer_conv"] = std::make_shared(config.chroma_radiance_params.nerf_hidden_size, + config.in_channels); } else { - blocks["final_layer"] = std::make_shared(params.hidden_size, 1, params.out_channels, params.is_chroma, !params.disable_bias); + blocks["final_layer"] = std::make_shared(config.hidden_size, 1, config.out_channels, config.is_chroma, !config.disable_bias); } - if (params.share_modulation) { - blocks["double_stream_modulation_img"] = std::make_shared(params.hidden_size, true, !params.disable_bias); - blocks["double_stream_modulation_txt"] = std::make_shared(params.hidden_size, true, !params.disable_bias); - blocks["single_stream_modulation"] = std::make_shared(params.hidden_size, false, !params.disable_bias); + if (config.share_modulation) { + blocks["double_stream_modulation_img"] = std::make_shared(config.hidden_size, true, !config.disable_bias); + blocks["double_stream_modulation_txt"] = std::make_shared(config.hidden_size, true, !config.disable_bias); + blocks["single_stream_modulation"] = std::make_shared(config.hidden_size, false, !config.disable_bias); } } @@ -866,7 +978,7 @@ namespace Flux { ggml_tensor* vec; ggml_tensor* txt_img_mask = nullptr; - if (params.is_chroma) { + if (config.is_chroma) { int64_t mod_index_length = 344; auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]); auto distill_timestep = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 16, 10000, 1000.f); @@ -894,7 +1006,7 @@ namespace Flux { } else { auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f)); - if (params.guidance_embed) { + if 
(config.guidance_embed) { GGML_ASSERT(guidance != nullptr); auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); // bf16 and fp16 result is different @@ -902,7 +1014,7 @@ namespace Flux { vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in)); } - if (params.vec_in_dim > 0) { + if (config.vec_in_dim > 0) { auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y)); } @@ -911,7 +1023,7 @@ namespace Flux { std::vector ds_img_mods; std::vector ds_txt_mods; std::vector ss_mods; - if (params.share_modulation) { + if (config.share_modulation) { auto double_stream_modulation_img = std::dynamic_pointer_cast(blocks["double_stream_modulation_img"]); auto double_stream_modulation_txt = std::dynamic_pointer_cast(blocks["double_stream_modulation_txt"]); auto single_stream_modulation = std::dynamic_pointer_cast(blocks["single_stream_modulation"]); @@ -921,7 +1033,7 @@ namespace Flux { ss_mods = single_stream_modulation->forward(ctx, vec); } - if (params.semantic_txt_norm) { + if (config.semantic_txt_norm) { auto semantic_txt_norm = std::dynamic_pointer_cast(blocks["txt_norm"]); txt = semantic_txt_norm->forward(ctx, txt); @@ -932,7 +1044,7 @@ namespace Flux { sd::ggml_graph_cut::mark_graph_cut(txt, "flux.prelude", "txt"); sd::ggml_graph_cut::mark_graph_cut(vec, "flux.prelude", "vec"); - for (int i = 0; i < params.depth; i++) { + for (int i = 0; i < config.depth; i++) { if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { continue; } @@ -947,8 +1059,8 @@ namespace Flux { } auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] - for (int i = 0; i < params.depth_single_blocks; i++) { - if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) { + for (int i = 0; i < config.depth_single_blocks; i++) { + if (skip_layers.size() 
> 0 && std::find(skip_layers.begin(), skip_layers.end(), i + config.depth) != skip_layers.end()) { continue; } auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); @@ -999,14 +1111,14 @@ namespace Flux { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t C = x->ne[2]; - int patch_size = params.patch_size; + int patch_size = config.patch_size; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = DiT::pad_to_patch_size(ctx, x, params.patch_size, params.patch_size); + auto img = DiT::pad_to_patch_size(ctx, x, config.patch_size, config.patch_size); auto orig_img = img; - if (params.chroma_radiance_params.fake_patch_size_x2) { + if (config.chroma_radiance_params.fake_patch_size_x2) { // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch? // img = F.interpolate(img, size=(H//2, W//2), mode="nearest") @@ -1037,7 +1149,7 @@ namespace Flux { auto nerf_hidden = ggml_reshape_2d(ctx->ggml_ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size] auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size] - for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) { + for (int i = 0; i < config.chroma_radiance_params.nerf_depth; i++) { auto block = std::dynamic_pointer_cast(blocks["nerf_blocks." 
+ std::to_string(i)]); img_dct = block->forward(ctx, img_dct, nerf_hidden); @@ -1049,7 +1161,7 @@ namespace Flux { out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W] - if (params.chroma_radiance_params.use_x0) { + if (config.chroma_radiance_params.use_x0) { out = _apply_x0_residual(ctx, out, orig_img, timestep); } @@ -1073,14 +1185,14 @@ namespace Flux { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t C = x->ne[2]; - int patch_size = params.patch_size; + int patch_size = config.patch_size; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); int64_t img_tokens = img->ne[1]; - if (params.version == VERSION_FLUX_FILL) { + if (config.version == VERSION_FLUX_FILL) { GGML_ASSERT(c_concat != nullptr); ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); @@ -1089,7 +1201,7 @@ namespace Flux { mask = DiT::pad_and_patchify(ctx, mask, patch_size, patch_size); img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0); - } else if (params.version == VERSION_FLEX_2) { + } else if (config.version == VERSION_FLEX_2) { GGML_ASSERT(c_concat != nullptr); ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); @@ -1100,7 +1212,7 @@ namespace Flux { control = DiT::pad_and_patchify(ctx, control, patch_size, patch_size); img = ggml_concat(ctx->ggml_ctx, img, 
ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0); - } else if (params.version == VERSION_FLUX_CONTROLS) { + } else if (config.version == VERSION_FLUX_CONTROLS) { GGML_ASSERT(c_concat != nullptr); auto control = DiT::pad_and_patchify(ctx, c_concat, patch_size, patch_size); @@ -1147,7 +1259,7 @@ namespace Flux { // pe: (L, d_head/2, 2, 2) // return: (N, C, H, W) - if (params.version == VERSION_CHROMA_RADIANCE) { + if (config.version == VERSION_CHROMA_RADIANCE) { return forward_chroma_radiance(ctx, x, timestep, @@ -1179,7 +1291,7 @@ namespace Flux { struct FluxRunner : public DiffusionModelRunner { public: - FluxParams flux_params; + FluxConfig config; Flux flux; std::vector pe_vec; std::vector mod_index_arange_vec; @@ -1194,114 +1306,15 @@ namespace Flux { const std::string prefix = "", SDVersion version = VERSION_FLUX, bool use_mask = false) - : DiffusionModelRunner(backend, params_backend, prefix), version(version), use_mask(use_mask) { - flux_params.version = version; - flux_params.guidance_embed = false; - flux_params.depth = 0; - flux_params.depth_single_blocks = 0; - if (version == VERSION_FLUX_FILL) { - flux_params.in_channels = 384; - } else if (version == VERSION_FLUX_CONTROLS) { - flux_params.in_channels = 128; - } else if (version == VERSION_FLEX_2) { - flux_params.in_channels = 196; - } else if (version == VERSION_CHROMA_RADIANCE) { - flux_params.in_channels = 3; - flux_params.patch_size = 16; - } else if (version == VERSION_OVIS_IMAGE) { - flux_params.semantic_txt_norm = true; - flux_params.use_yak_mlp = true; - flux_params.vec_in_dim = 0; - } else if (sd_version_is_flux2(version)) { - flux_params.in_channels = 128; - flux_params.patch_size = 1; - flux_params.out_channels = 128; - flux_params.mlp_ratio = 3.f; - flux_params.theta = 2000; - flux_params.axes_dim = {32, 32, 32, 32}; - flux_params.vec_in_dim = 0; - flux_params.qkv_bias = false; - flux_params.disable_bias = true; - flux_params.share_modulation = true; - 
flux_params.ref_index_scale = 10.f; - flux_params.use_mlp_silu_act = true; - } else if (sd_version_is_longcat(version)) { - flux_params.context_in_dim = 3584; - flux_params.vec_in_dim = 0; - } - int64_t head_dim = 0; - int64_t actual_radiance_patch_size = -1; - for (auto pair : tensor_storage_map) { - std::string tensor_name = pair.first; - if (!starts_with(tensor_name, prefix)) - continue; - if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) { - flux_params.guidance_embed = true; - } - if (tensor_name.find("__x0__") != std::string::npos) { - LOG_DEBUG("using x0 prediction"); - flux_params.chroma_radiance_params.use_x0 = true; - } - if (tensor_name.find("__32x32__") != std::string::npos) { - LOG_DEBUG("using patch size 32"); - flux_params.patch_size = 32; - } - if (tensor_name.find("img_in_patch.weight") != std::string::npos) { - actual_radiance_patch_size = pair.second.ne[0]; - LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size); - } - if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { - // Chroma - flux_params.is_chroma = true; - } - size_t db = tensor_name.find("double_blocks."); - if (db != std::string::npos) { - tensor_name = tensor_name.substr(db); // remove prefix - int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str()); - if (block_depth + 1 > flux_params.depth) { - flux_params.depth = block_depth + 1; - } - } - size_t sb = tensor_name.find("single_blocks."); - if (sb != std::string::npos) { - tensor_name = tensor_name.substr(sb); // remove prefix - int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str()); - if (block_depth + 1 > flux_params.depth_single_blocks) { - flux_params.depth_single_blocks = block_depth + 1; - } - } - if (ends_with(tensor_name, "txt_in.weight")) { - flux_params.context_in_dim = pair.second.ne[0]; - flux_params.hidden_size = pair.second.ne[1]; - } - if (ends_with(tensor_name, 
"single_blocks.0.norm.key_norm.scale")) { - head_dim = pair.second.ne[0]; - } - if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) { - head_dim = pair.second.ne[0]; - } - } - if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) { - GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size); - LOG_DEBUG("using fake x2 patch size"); - flux_params.chroma_radiance_params.fake_patch_size_x2 = true; - } - - flux_params.num_heads = static_cast(flux_params.hidden_size / head_dim); - - LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64 - ", hidden_size = %" PRId64 ", num_heads = %d", - flux_params.depth, - flux_params.depth_single_blocks, - flux_params.guidance_embed ? "true" : "false", - flux_params.context_in_dim, - flux_params.hidden_size, - flux_params.num_heads); - if (flux_params.is_chroma) { + : DiffusionModelRunner(backend, params_backend, prefix), + config(FluxConfig::detect_from_weights(tensor_storage_map, prefix, version)), + version(version), + use_mask(use_mask) { + if (config.is_chroma) { LOG_INFO("Using pruned modulation (Chroma)"); } - flux = Flux(flux_params); + flux = Flux(config); flux.init(params_ctx, tensor_storage_map, prefix); } @@ -1377,10 +1390,10 @@ namespace Flux { ggml_tensor* context = make_optional_input(context_tensor); ggml_tensor* c_concat = make_optional_input(c_concat_tensor); ggml_tensor* y = make_optional_input(y_tensor); - if (flux_params.guidance_embed || flux_params.is_chroma) { + if (config.guidance_embed || config.is_chroma) { if (!guidance_tensor.empty()) { this->guidance_tensor = guidance_tensor; - if (flux_params.is_chroma) { + if (config.is_chroma) { this->guidance_tensor.fill_(0.f); } } @@ -1398,7 +1411,7 @@ namespace Flux { ggml_tensor* mod_index_arange = nullptr; ggml_tensor* dct = nullptr; // for chroma radiance - if (flux_params.is_chroma) { + if (config.is_chroma) { if (!use_mask) { y = nullptr; } 
@@ -1417,29 +1430,29 @@ namespace Flux { } pe_vec = Rope::gen_flux_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), - flux_params.patch_size, + config.patch_size, static_cast(x->ne[3]), static_cast(context->ne[1]), txt_arange_dims, ref_latents, increase_ref_index, - flux_params.ref_index_scale, - flux_params.theta, + config.ref_index_scale, + config.theta, circular_y_enabled, circular_x_enabled, - flux_params.axes_dim, + config.axes_dim, sd_version_is_longcat(version)); - int pos_len = static_cast(pe_vec.size() / flux_params.axes_dim_sum / 2); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = nullptr; set_backend_tensor_data(pe, pe_vec.data()); if (version == VERSION_CHROMA_RADIANCE) { - int patch_size = flux_params.patch_size; - int nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs; + int patch_size = config.patch_size; + int nerf_max_freqs = config.chroma_radiance_params.nerf_max_freqs; dct_vec = fetch_dct_pos(patch_size, nerf_max_freqs); dct = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, nerf_max_freqs * nerf_max_freqs, patch_size * patch_size); // dct->data = dct_vec.data(); diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 1f32c9bc9..0e1ac4daa 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1707,7 +1707,7 @@ struct GGMLRunner { uint64_t resident_state_token = 0; size_t max_graph_vram_bytes = 0; - bool stream_layers_enabled = false; + bool stream_layers_enabled = false; size_t observed_max_effective_budget_ = 0; sd::layer_registry::LayerRegistry layer_registry_; diff --git a/src/hidream_o1.hpp b/src/hidream_o1.hpp index c85e04b95..bf64d2f4f 100644 --- a/src/hidream_o1.hpp +++ 
b/src/hidream_o1.hpp @@ -23,6 +23,39 @@ namespace HiDreamO1 { constexpr int IMAGE_TOKEN_ID = 151655; constexpr int VISION_START_TOKEN_ID = 151652; + struct HiDreamO1Config { + LLM::LLMConfig llm; + int patch_size = PATCH_SIZE; + + static HiDreamO1Config detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + (void)tensor_storage_map; + (void)prefix; + HiDreamO1Config config; + config.llm.arch = LLM::LLMArch::QWEN3_VL; + config.llm.hidden_size = 4096; + config.llm.intermediate_size = 12288; + config.llm.num_layers = 36; + config.llm.num_heads = 32; + config.llm.num_kv_heads = 8; + config.llm.head_dim = 128; + config.llm.qkv_bias = false; + config.llm.qk_norm = true; + config.llm.vocab_size = 151936; + config.llm.rms_norm_eps = 1e-6f; + config.llm.vision.arch = LLM::LLMVisionArch::QWEN3_VL; + config.llm.vision.num_layers = 27; + config.llm.vision.hidden_size = 1152; + config.llm.vision.intermediate_size = 4304; + config.llm.vision.num_heads = 16; + config.llm.vision.out_hidden_size = 4096; + config.llm.vision.patch_size = 16; + config.llm.vision.spatial_merge_size = 2; + config.llm.vision.temporal_patch_size = 2; + config.llm.vision.num_position_embeddings = 2304; + return config; + } + }; + static inline std::string repeat_special_token(const std::string& token, int64_t count) { std::string out; out.reserve(static_cast(count) * token.size()); @@ -205,50 +238,19 @@ namespace HiDreamO1 { } }; - struct HiDreamO1Params { - LLM::LLMParams llm; - int patch_size = PATCH_SIZE; - }; - - static inline HiDreamO1Params make_hidream_o1_params() { - HiDreamO1Params params; - params.llm.arch = LLM::LLMArch::QWEN3_VL; - params.llm.hidden_size = 4096; - params.llm.intermediate_size = 12288; - params.llm.num_layers = 36; - params.llm.num_heads = 32; - params.llm.num_kv_heads = 8; - params.llm.head_dim = 128; - params.llm.qkv_bias = false; - params.llm.qk_norm = true; - params.llm.vocab_size = 151936; - params.llm.rms_norm_eps = 1e-6f; - 
params.llm.vision.arch = LLM::LLMVisionArch::QWEN3_VL; - params.llm.vision.num_layers = 27; - params.llm.vision.hidden_size = 1152; - params.llm.vision.intermediate_size = 4304; - params.llm.vision.num_heads = 16; - params.llm.vision.out_hidden_size = 4096; - params.llm.vision.patch_size = 16; - params.llm.vision.spatial_merge_size = 2; - params.llm.vision.temporal_patch_size = 2; - params.llm.vision.num_position_embeddings = 2304; - return params; - } - struct HiDreamO1Model : public GGMLBlock { - HiDreamO1Params params; + HiDreamO1Config config; HiDreamO1Model() = default; - explicit HiDreamO1Model(HiDreamO1Params params) - : params(std::move(params)) { - blocks["language_model"] = std::make_shared(this->params.llm); - blocks["t_embedder1"] = std::make_shared(this->params.llm.hidden_size); - blocks["x_embedder"] = std::make_shared(this->params.patch_size * this->params.patch_size * 3, - this->params.llm.hidden_size / 4, - this->params.llm.hidden_size); - blocks["final_layer2"] = std::make_shared(this->params.llm.hidden_size, - this->params.patch_size * this->params.patch_size * 3); + explicit HiDreamO1Model(HiDreamO1Config config) + : config(std::move(config)) { + blocks["language_model"] = std::make_shared(this->config.llm); + blocks["t_embedder1"] = std::make_shared(this->config.llm.hidden_size); + blocks["x_embedder"] = std::make_shared(this->config.patch_size * this->config.patch_size * 3, + this->config.llm.hidden_size / 4, + this->config.llm.hidden_size); + blocks["final_layer2"] = std::make_shared(this->config.llm.hidden_size, + this->config.patch_size * this->config.patch_size * 3); } std::shared_ptr text_model() { @@ -269,7 +271,7 @@ namespace HiDreamO1 { }; struct HiDreamO1VisionRunner : public GGMLRunner { - HiDreamO1Params params; + HiDreamO1Config config; std::shared_ptr model; std::vector window_index_vec; @@ -284,8 +286,8 @@ namespace HiDreamO1 { const String2TensorStorage& tensor_storage_map = {}, const std::string& prefix = "model.visual") : 
GGMLRunner(backend, params_backend), - params(make_hidream_o1_params()), - model(std::make_shared(false, params.llm.vision)) { + config(HiDreamO1Config::detect_from_weights(tensor_storage_map, prefix)), + model(std::make_shared(false, config.llm.vision)) { model->init(params_ctx, tensor_storage_map, prefix); } @@ -302,7 +304,7 @@ namespace HiDreamO1 { compute_ctx, runner_ctx, image, - params.llm.vision, + config.llm.vision, model, window_index_vec, window_inverse_index_vec, @@ -331,7 +333,7 @@ namespace HiDreamO1 { }; struct HiDreamO1Runner : public DiffusionModelRunner { - HiDreamO1Params params; + HiDreamO1Config config; HiDreamO1Model model; std::vector attention_mask_vec; @@ -341,8 +343,8 @@ namespace HiDreamO1 { const String2TensorStorage& tensor_storage_map = {}, const std::string& prefix = "model") : DiffusionModelRunner(backend, params_backend, prefix), - params(make_hidream_o1_params()) { - model = HiDreamO1Model(params); + config(HiDreamO1Config::detect_from_weights(tensor_storage_map, prefix)) { + model = HiDreamO1Model(config); model.init(params_ctx, tensor_storage_map, prefix); } diff --git a/src/ideogram4.hpp b/src/ideogram4.hpp index 58cd7638a..46193adca 100644 --- a/src/ideogram4.hpp +++ b/src/ideogram4.hpp @@ -38,6 +38,34 @@ namespace Ideogram4 { std::vector mrope_section = {DEFAULT_MROPE_SECTION_T, DEFAULT_MROPE_SECTION_H, DEFAULT_MROPE_SECTION_W}; + + static Ideogram4Config detect_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix) { + Ideogram4Config config; + int64_t detected_layers = 0; + std::string layer_prefix = prefix.empty() ? "layers." 
: prefix + ".layers."; + for (const auto& [name, _] : tensor_storage_map) { + if (name.find(layer_prefix) != 0) { + continue; + } + std::string tail = name.substr(layer_prefix.size()); + size_t dot = tail.find('.'); + if (dot == std::string::npos) { + continue; + } + int layer_idx = std::atoi(tail.substr(0, dot).c_str()); + detected_layers = std::max(detected_layers, layer_idx + 1); + } + if (detected_layers > 0) { + config.num_layers = detected_layers; + LOG_DEBUG("ideogram4: num_layers = %" PRId64 ", emb_dim = %" PRId64 ", num_heads = %" PRId64 ", intermediate_size = %" PRId64, + config.num_layers, + config.emb_dim, + config.num_heads, + config.intermediate_size); + } + return config; + } }; __STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx, @@ -380,26 +408,6 @@ namespace Ideogram4 { class Ideogram4Runner : public DiffusionModelRunner { protected: - static int64_t detect_num_layers(const String2TensorStorage& tensor_storage_map, - const std::string& prefix) { - int64_t detected_layers = 0; - std::string layer_prefix = prefix.empty() ? "layers." 
: prefix + ".layers."; - for (const auto& pair : tensor_storage_map) { - const std::string& name = pair.first; - if (name.find(layer_prefix) != 0) { - continue; - } - std::string tail = name.substr(layer_prefix.size()); - size_t dot = tail.find('.'); - if (dot == std::string::npos) { - continue; - } - int layer_idx = std::atoi(tail.substr(0, dot).c_str()); - detected_layers = std::max(detected_layers, layer_idx + 1); - } - return detected_layers; - } - bool should_use_uncond_model(const DiffusionParams& diffusion_params) const { return has_uncond_model && diffusion_params.context == nullptr && @@ -421,12 +429,8 @@ namespace Ideogram4 { const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") : DiffusionModelRunner(backend, params_backend, prefix), + config(Ideogram4Config::detect_from_weights(tensor_storage_map, prefix)), uncond_prefix(prefix + ".uncond") { - int64_t detected_layers = detect_num_layers(tensor_storage_map, prefix); - if (detected_layers > 0) { - config.num_layers = detected_layers; - } - model = Ideogram4Transformer(config); model.init(params_ctx, tensor_storage_map, prefix); for (const auto& pair : tensor_storage_map) { diff --git a/src/lens.hpp b/src/lens.hpp index b5ff06832..072c9e08d 100644 --- a/src/lens.hpp +++ b/src/lens.hpp @@ -13,6 +13,71 @@ namespace Lens { constexpr int LENS_GRAPH_SIZE = 40960; + struct LensConfig { + int patch_size = 2; + int64_t in_channels = 128; + int64_t out_channels = 32; + int num_layers = 48; + int64_t attention_head_dim = 64; + int64_t num_attention_heads = 24; + int64_t joint_attention_dim = 2880; + int selected_layer_count = 4; + int theta = 10000; + std::vector axes_dim = {8, 28, 28}; + int axes_dim_sum = 64; + + static LensConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + LensConfig config; + config.num_layers = 0; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; 
+ } + if (ends_with(name, "img_in.weight") && tensor_storage.n_dims == 2) { + config.in_channels = tensor_storage.ne[0]; + int64_t inner_dim = tensor_storage.ne[1]; + if (config.attention_head_dim > 0) { + config.num_attention_heads = inner_dim / config.attention_head_dim; + } + } else if (ends_with(name, "txt_in.weight") && tensor_storage.n_dims == 2) { + config.selected_layer_count = static_cast(tensor_storage.ne[0] / config.joint_attention_dim); + } else if (ends_with(name, "proj_out.weight") && tensor_storage.n_dims == 2) { + int64_t patch_area = config.patch_size * config.patch_size; + config.out_channels = tensor_storage.ne[1] / patch_area; + } else if (ends_with(name, "transformer_blocks.0.attn.norm_q.weight") && tensor_storage.n_dims == 1) { + config.attention_head_dim = tensor_storage.ne[0]; + } + + size_t pos = name.find("transformer_blocks."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > config.num_layers) { + config.num_layers = block_index + 1; + } + } + } + } + if (config.num_layers == 0) { + config.num_layers = 48; + } + config.axes_dim_sum = 0; + for (int axis_dim : config.axes_dim) { + config.axes_dim_sum += axis_dim; + } + LOG_DEBUG("lens: num_layers = %d, selected_layer_count = %d, hidden_size = %" PRId64 ", num_attention_heads = %" PRId64 ", attention_head_dim = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64, + config.num_layers, + config.selected_layer_count, + config.num_attention_heads * config.attention_head_dim, + config.num_attention_heads, + config.attention_head_dim, + config.in_channels, + config.out_channels); + return config; + } + }; + struct LensTimestepProjEmbeddings : public GGMLBlock { LensTimestepProjEmbeddings(int64_t embedding_dim) { blocks["timestep_embedder"] = std::make_shared(256, embedding_dim); @@ -209,41 +274,27 @@ namespace Lens { } }; - struct LensParams { - int 
patch_size = 2; - int64_t in_channels = 128; - int64_t out_channels = 32; - int num_layers = 48; - int64_t attention_head_dim = 64; - int64_t num_attention_heads = 24; - int64_t joint_attention_dim = 2880; - int selected_layer_count = 4; - int theta = 10000; - std::vector axes_dim = {8, 28, 28}; - int axes_dim_sum = 64; - }; - class LensModel : public GGMLBlock { public: - LensParams params; + LensConfig config; LensModel() = default; - LensModel(LensParams params) - : params(params) { - int64_t inner_dim = params.num_attention_heads * params.attention_head_dim; + LensModel(LensConfig config) + : config(config) { + int64_t inner_dim = config.num_attention_heads * config.attention_head_dim; blocks["time_text_embed"] = std::make_shared(inner_dim); - blocks["img_in"] = std::make_shared(params.in_channels, inner_dim, true); - blocks["txt_in"] = std::make_shared(params.joint_attention_dim * params.selected_layer_count, inner_dim, true); - for (int i = 0; i < params.selected_layer_count; ++i) { - blocks["txt_norm." + std::to_string(i)] = std::make_shared(params.joint_attention_dim, 1e-5f); + blocks["img_in"] = std::make_shared(config.in_channels, inner_dim, true); + blocks["txt_in"] = std::make_shared(config.joint_attention_dim * config.selected_layer_count, inner_dim, true); + for (int i = 0; i < config.selected_layer_count; ++i) { + blocks["txt_norm." + std::to_string(i)] = std::make_shared(config.joint_attention_dim, 1e-5f); } - for (int i = 0; i < params.num_layers; ++i) { + for (int i = 0; i < config.num_layers; ++i) { blocks["transformer_blocks." 
+ std::to_string(i)] = std::make_shared(inner_dim, - params.num_attention_heads, - params.attention_head_dim); + config.num_attention_heads, + config.attention_head_dim); } blocks["norm_out"] = std::make_shared(inner_dim, 1e-6f); - blocks["proj_out"] = std::make_shared(inner_dim, params.patch_size * params.patch_size * params.out_channels, true); + blocks["proj_out"] = std::make_shared(inner_dim, config.patch_size * config.patch_size * config.out_channels, true); } ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -269,9 +320,9 @@ namespace Lens { img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); img = img_in->forward(ctx, img); - std::vector txt_chunks = ggml_ext_chunk(ctx->ggml_ctx, context, params.selected_layer_count, 0); + std::vector txt_chunks = ggml_ext_chunk(ctx->ggml_ctx, context, config.selected_layer_count, 0); ggml_tensor* txt = nullptr; - for (int i = 0; i < params.selected_layer_count; ++i) { + for (int i = 0; i < config.selected_layer_count; ++i) { auto txt_norm = std::dynamic_pointer_cast(blocks["txt_norm." + std::to_string(i)]); auto chunk = txt_norm->forward(ctx, txt_chunks[i]); txt = txt == nullptr ? chunk : ggml_concat(ctx->ggml_ctx, txt, chunk, 0); @@ -281,7 +332,7 @@ namespace Lens { sd::ggml_graph_cut::mark_graph_cut(img, "lens.prelude", "img"); sd::ggml_graph_cut::mark_graph_cut(txt, "lens.prelude", "txt"); - for (int i = 0; i < params.num_layers; ++i) { + for (int i = 0; i < config.num_layers; ++i) { auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." 
+ std::to_string(i)]); auto out = block->forward(ctx, img, txt, t_emb, pe); img = out.first; @@ -294,13 +345,13 @@ namespace Lens { img = proj_out->forward(ctx, img); auto out = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); - out = ggml_reshape_4d(ctx->ggml_ctx, out, W, H, params.patch_size * params.patch_size * params.out_channels, N); + out = ggml_reshape_4d(ctx->ggml_ctx, out, W, H, config.patch_size * config.patch_size * config.out_channels, N); return out; } }; struct LensRunner : public DiffusionModelRunner { - LensParams lens_params; + LensConfig config; LensModel lens; std::vector pe_vec; @@ -308,53 +359,9 @@ namespace Lens { ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") - : DiffusionModelRunner(backend, params_backend, prefix) { - lens_params.num_layers = 0; - for (const auto& [name, tensor_storage] : tensor_storage_map) { - if (!starts_with(name, prefix)) { - continue; - } - if (ends_with(name, "img_in.weight") && tensor_storage.n_dims == 2) { - lens_params.in_channels = tensor_storage.ne[0]; - int64_t inner_dim = tensor_storage.ne[1]; - lens_params.num_attention_heads = inner_dim / lens_params.attention_head_dim; - } else if (ends_with(name, "txt_in.weight") && tensor_storage.n_dims == 2) { - lens_params.selected_layer_count = static_cast(tensor_storage.ne[0] / lens_params.joint_attention_dim); - } else if (ends_with(name, "proj_out.weight") && tensor_storage.n_dims == 2) { - lens_params.out_channels = tensor_storage.ne[1] / lens_params.patch_size / lens_params.patch_size; - } else if (ends_with(name, "transformer_blocks.0.attn.norm_q.weight") && tensor_storage.n_dims == 1) { - lens_params.attention_head_dim = tensor_storage.ne[0]; - } - - size_t pos = name.find("transformer_blocks."); - if (pos != std::string::npos) { - std::string layer_name = name.substr(pos); - auto items = split_string(layer_name, '.'); - if (items.size() > 1) { - int block_index 
= atoi(items[1].c_str()); - if (block_index + 1 > lens_params.num_layers) { - lens_params.num_layers = block_index + 1; - } - } - } - } - if (lens_params.num_layers == 0) { - lens_params.num_layers = 48; - } - lens_params.axes_dim_sum = 0; - for (int axis_dim : lens_params.axes_dim) { - lens_params.axes_dim_sum += axis_dim; - } - - LOG_INFO("lens: layers = %d, in_channels = %" PRId64 ", out_channels = %" PRId64 - ", heads = %" PRId64 ", head_dim = %" PRId64, - lens_params.num_layers, - lens_params.in_channels, - lens_params.out_channels, - lens_params.num_attention_heads, - lens_params.attention_head_dim); - - lens = LensModel(lens_params); + : DiffusionModelRunner(backend, params_backend, prefix), + config(LensConfig::detect_from_weights(tensor_storage_map, prefix)) { + lens = LensModel(config); lens.init(params_ctx, tensor_storage_map, prefix); } @@ -380,12 +387,12 @@ namespace Lens { static_cast(x->ne[0]), static_cast(x->ne[3]), static_cast(context->ne[1]), - lens_params.theta, + config.theta, circular_y_enabled, circular_x_enabled, - lens_params.axes_dim); - int pos_len = static_cast(pe_vec.size() / lens_params.axes_dim_sum / 2); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, lens_params.axes_dim_sum / 2, pos_len); + config.axes_dim); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len); set_backend_tensor_data(pe, pe_vec.data()); auto runner_ctx = get_context(); diff --git a/src/llm.hpp b/src/llm.hpp index e3614b22a..927c7c8b5 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -63,7 +63,7 @@ namespace LLM { QWEN3_VL, }; - struct LLMVisionParams { + struct LLMVisionConfig { LLMVisionArch arch = LLMVisionArch::QWEN2_5_VL; int num_layers = 32; int64_t hidden_size = 1280; @@ -79,7 +79,7 @@ namespace LLM { std::set fullatt_block_indexes = {7, 15, 23, 31}; }; - struct LLMParams { + struct LLMConfig { LLMArch arch = LLMArch::QWEN2_5_VL; 
int64_t num_layers = 28; int64_t hidden_size = 3584; @@ -101,7 +101,129 @@ namespace LLM { std::vector sliding_attention; int64_t num_experts = 0; int64_t num_experts_per_tok = 0; - LLMVisionParams vision; + LLMVisionConfig vision; + bool have_vision_weight = false; + bool llama_cpp_style = false; + + static LLMConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + LLMArch arch) { + LLMConfig config; + config.arch = arch; + if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { + config.head_dim = 128; + config.num_heads = 32; + config.num_kv_heads = 8; + config.qkv_bias = false; + config.rms_norm_eps = 1e-5f; + } else if (arch == LLMArch::QWEN3 || arch == LLMArch::QWEN3_VL) { + config.head_dim = 128; + config.num_heads = 32; + config.num_kv_heads = 8; + config.qkv_bias = false; + config.qk_norm = true; + config.rms_norm_eps = 1e-6f; + if (arch == LLMArch::QWEN3_VL) { + config.max_position_embeddings = 262144; + config.rope_thetas = {5000000.f}; + config.vision.arch = LLMVisionArch::QWEN3_VL; + } + } else if (arch == LLMArch::GEMMA3_12B) { + config.head_dim = 256; + config.num_heads = 16; + config.num_kv_heads = 8; + config.qkv_bias = false; + config.qk_norm = true; + config.rms_norm_eps = 1e-6f; + config.rms_norm_add = false; + config.normalize_input = true; + config.max_position_embeddings = 131072; + config.mlp_activation = MLPActivation::GELU_TANH; + config.rope_thetas = {1000000.f, 10000.f}; + config.rope_scales = {8.f, 1.f}; + config.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0}; + } else if (arch == LLMArch::GEMMA2_2B) { + config.head_dim = 256; + config.num_heads = 8; + config.num_kv_heads = 4; + config.qkv_bias = false; + config.qk_norm = false; + config.rms_norm_eps = 1e-6f; + config.rms_norm_add = true; + config.normalize_input = true; + config.max_position_embeddings = 8192; + config.mlp_activation = MLPActivation::GELU_TANH; + config.hidden_size = 2304; + 
config.intermediate_size = 9216; + config.num_layers = 26; + config.vocab_size = 256000; + } else if (arch == LLMArch::GPT_OSS_20B) { + config.head_dim = 64; + config.num_heads = 64; + config.num_kv_heads = 8; + config.qkv_bias = true; + config.attention_out_bias = true; + config.qk_norm = false; + config.rms_norm_eps = 1e-5f; + config.hidden_size = 2880; + config.intermediate_size = 2880; + config.num_layers = 24; + config.vocab_size = 201088; + config.max_position_embeddings = 131072; + config.rope_thetas = {150000.f}; + config.rope_scales = {32.f}; + config.sliding_attention = {128, 0}; + config.num_experts = 32; + config.num_experts_per_tok = 4; + } + + config.num_layers = 0; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + size_t pos = name.find("visual."); + if (pos != std::string::npos) { + config.have_vision_weight = true; + if (contains(name, "attn.q_proj")) { + config.llama_cpp_style = true; + } + continue; + } + pos = name.find("layers."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > config.num_layers) { + config.num_layers = block_index + 1; + } + } + } + if (contains(name, "embed_tokens.weight")) { + config.hidden_size = tensor_storage.ne[0]; + config.vocab_size = tensor_storage.ne[1]; + } + if (contains(name, "layers.0.mlp.gate_proj.weight")) { + config.intermediate_size = tensor_storage.ne[1]; + } + if (contains(name, "layers.0.mlp.experts.gate_up_proj.weight")) { + config.intermediate_size = tensor_storage.ne[1] / 2; + } + if (contains(name, "layers.0.mlp.experts.gate_proj.weight")) { + config.intermediate_size = tensor_storage.ne[1]; + } + } + if (arch == LLMArch::QWEN3 && config.num_layers == 28) { + config.num_heads = 16; + } + LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" 
PRId64, + config.num_layers, + config.vocab_size, + config.hidden_size, + config.intermediate_size); + return config; + } }; struct LLMRMSNorm : public UnaryBlock { @@ -232,11 +354,11 @@ namespace LLM { } public: - GPTOSSMLP(const LLMParams& params) - : hidden_size(params.hidden_size), - intermediate_size(params.intermediate_size), - num_experts(params.num_experts), - num_experts_per_tok(params.num_experts_per_tok) {} + GPTOSSMLP(const LLMConfig& config) + : hidden_size(config.hidden_size), + intermediate_size(config.intermediate_size), + num_experts(config.num_experts), + num_experts_per_tok(config.num_experts_per_tok) {} ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { // x: [N, n_token, hidden_size] @@ -667,7 +789,7 @@ namespace LLM { public: VisionModel(bool llama_cpp_style, - const LLMVisionParams& vision_params, + const LLMVisionConfig& vision_params, float eps = 1e-6f) : arch_(vision_params.arch), num_layers(vision_params.num_layers), @@ -784,23 +906,23 @@ namespace LLM { } public: - Attention(const LLMParams& params) - : arch(params.arch), - num_heads(params.num_heads), - num_kv_heads(params.num_kv_heads), - head_dim(params.head_dim), - qk_norm(params.qk_norm), - max_position_embeddings(params.max_position_embeddings), - rope_thetas(params.rope_thetas), - rope_scales(params.rope_scales), - has_attention_sinks(params.arch == LLMArch::GPT_OSS_20B) { - blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias); - blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); - blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); - blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, params.attention_out_bias); - if (params.qk_norm) { - blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps, params.rms_norm_add); - blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps, 
params.rms_norm_add); + Attention(const LLMConfig& config) + : arch(config.arch), + num_heads(config.num_heads), + num_kv_heads(config.num_kv_heads), + head_dim(config.head_dim), + qk_norm(config.qk_norm), + max_position_embeddings(config.max_position_embeddings), + rope_thetas(config.rope_thetas), + rope_scales(config.rope_scales), + has_attention_sinks(config.arch == LLMArch::GPT_OSS_20B) { + blocks["q_proj"] = std::make_shared(config.hidden_size, num_heads * head_dim, config.qkv_bias); + blocks["k_proj"] = std::make_shared(config.hidden_size, num_kv_heads * head_dim, config.qkv_bias); + blocks["v_proj"] = std::make_shared(config.hidden_size, num_kv_heads * head_dim, config.qkv_bias); + blocks["o_proj"] = std::make_shared(num_heads * head_dim, config.hidden_size, config.attention_out_bias); + if (config.qk_norm) { + blocks["q_norm"] = std::make_shared(head_dim, config.rms_norm_eps, config.rms_norm_add); + blocks["k_norm"] = std::make_shared(head_dim, config.rms_norm_eps, config.rms_norm_add); } } @@ -982,42 +1104,42 @@ namespace LLM { std::string post_ffw_norm_name; public: - TransformerBlock(const LLMParams& params, int layer_index) - : arch(params.arch), + TransformerBlock(const LLMConfig& config, int layer_index) + : arch(config.arch), sliding_attention(0) { - if (params.arch == LLMArch::GEMMA3_12B) { + if (config.arch == LLMArch::GEMMA3_12B) { post_attention_norm_name = "post_attention_norm"; // attn_post_norm pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm post_ffw_norm_name = "post_ffw_norm"; // ffn_post_norm - } else if (params.arch == LLMArch::GEMMA2_2B) { + } else if (config.arch == LLMArch::GEMMA2_2B) { post_attention_norm_name = "post_attention_layernorm"; // ffn_norm pre_ffw_norm_name = "pre_feedforward_layernorm"; post_ffw_norm_name = "post_feedforward_layernorm"; - } else if (params.arch == LLMArch::GPT_OSS_20B) { + } else if (config.arch == LLMArch::GPT_OSS_20B) { pre_ffw_norm_name = "post_attention_norm"; // attn_post_norm } else { 
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm } - blocks["self_attn"] = std::make_shared(params); - if (params.arch == LLMArch::GPT_OSS_20B) { - blocks["mlp"] = std::make_shared(params); + blocks["self_attn"] = std::make_shared(config); + if (config.arch == LLMArch::GPT_OSS_20B) { + blocks["mlp"] = std::make_shared(config); } else { - blocks["mlp"] = std::make_shared(params.hidden_size, - params.intermediate_size, + blocks["mlp"] = std::make_shared(config.hidden_size, + config.intermediate_size, false, - params.mlp_activation); + config.mlp_activation); } - blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); - blocks[pre_ffw_norm_name] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + blocks["input_layernorm"] = std::make_shared(config.hidden_size, config.rms_norm_eps, config.rms_norm_add); + blocks[pre_ffw_norm_name] = std::make_shared(config.hidden_size, config.rms_norm_eps, config.rms_norm_add); if (!post_attention_norm_name.empty()) { - blocks[post_attention_norm_name] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + blocks[post_attention_norm_name] = std::make_shared(config.hidden_size, config.rms_norm_eps, config.rms_norm_add); } if (!post_ffw_norm_name.empty()) { - blocks[post_ffw_norm_name] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + blocks[post_ffw_norm_name] = std::make_shared(config.hidden_size, config.rms_norm_eps, config.rms_norm_add); } - if (!params.sliding_attention.empty()) { - sliding_attention = params.sliding_attention[layer_index % params.sliding_attention.size()]; + if (!config.sliding_attention.empty()) { + sliding_attention = config.sliding_attention[layer_index % config.sliding_attention.size()]; } } @@ -1074,16 +1196,16 @@ namespace LLM { struct TextModel : public GGMLBlock { protected: int64_t num_layers; - LLMParams params; + LLMConfig config; public: - 
TextModel(const LLMParams& params) - : num_layers(params.num_layers), params(params) { - blocks["embed_tokens"] = std::shared_ptr(new Embedding(params.vocab_size, params.hidden_size)); + TextModel(const LLMConfig& config) + : num_layers(config.num_layers), config(config) { + blocks["embed_tokens"] = std::shared_ptr(new Embedding(config.vocab_size, config.hidden_size)); for (int i = 0; i < num_layers; i++) { - blocks["layers." + std::to_string(i)] = std::shared_ptr(new TransformerBlock(params, i)); + blocks["layers." + std::to_string(i)] = std::shared_ptr(new TransformerBlock(config, i)); } - blocks["norm"] = std::shared_ptr(new LLMRMSNorm(params.hidden_size, params.rms_norm_eps, params.rms_norm_add)); + blocks["norm"] = std::shared_ptr(new LLMRMSNorm(config.hidden_size, config.rms_norm_eps, config.rms_norm_add)); } ggml_tensor* embed(GGMLRunnerContext* ctx, @@ -1103,8 +1225,8 @@ namespace LLM { auto norm = std::dynamic_pointer_cast(blocks["norm"]); std::vector intermediate_outputs; - if (params.normalize_input) { - x = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(static_cast(params.hidden_size)), true); + if (config.normalize_input) { + x = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(static_cast(config.hidden_size)), true); } if (return_all_hidden_states) { intermediate_outputs.push_back(x); @@ -1174,15 +1296,15 @@ namespace LLM { struct LLM : public GGMLBlock { bool enable_vision; - LLMParams params; + LLMConfig config; public: LLM() = default; - LLM(LLMParams params, bool enable_vision = false, bool llama_cpp_style = false) - : enable_vision(enable_vision), params(params) { - blocks["model"] = std::shared_ptr(new TextModel(params)); + LLM(LLMConfig config, bool enable_vision = false, bool llama_cpp_style = false) + : enable_vision(enable_vision), config(config) { + blocks["model"] = std::shared_ptr(new TextModel(config)); if (enable_vision) { - blocks["visual"] = std::shared_ptr(new VisionModel(llama_cpp_style, params.vision)); + blocks["visual"] = 
std::shared_ptr(new VisionModel(llama_cpp_style, config.vision)); } } @@ -1226,7 +1348,7 @@ namespace LLM { }; struct LLMRunner : public GGMLRunner { - LLMParams params; + LLMConfig config; bool enable_vision; LLM model; @@ -1242,7 +1364,7 @@ namespace LLM { static ggml_tensor* process_image_common(ggml_context* ctx, ggml_tensor* image, - const LLMVisionParams& vision_params) { + const LLMVisionConfig& vision_params) { // image: [C, H, W] // return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1 int64_t C = image->ne[2]; @@ -1337,7 +1459,7 @@ namespace LLM { ggml_context* compute_ctx, GGMLRunnerContext* runner_ctx, ggml_tensor* image, - const LLMVisionParams& vision_params, + const LLMVisionConfig& vision_params, std::shared_ptr vision_model, std::vector& window_index_vec, std::vector& window_inverse_index_vec, @@ -1452,141 +1574,25 @@ namespace LLM { const String2TensorStorage& tensor_storage_map, const std::string prefix, bool enable_vision_ = false) - : GGMLRunner(backend, params_backend), enable_vision(enable_vision_) { - params.arch = arch; - if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { - params.head_dim = 128; - params.num_heads = 32; - params.num_kv_heads = 8; - params.qkv_bias = false; - params.rms_norm_eps = 1e-5f; - } else if (arch == LLMArch::QWEN3 || arch == LLMArch::QWEN3_VL) { - params.head_dim = 128; - params.num_heads = 32; - params.num_kv_heads = 8; - params.qkv_bias = false; - params.qk_norm = true; - params.rms_norm_eps = 1e-6f; - if (arch == LLMArch::QWEN3_VL) { - params.max_position_embeddings = 262144; - params.rope_thetas = {5000000.f}; - params.vision.arch = LLMVisionArch::QWEN3_VL; - } - } else if (arch == LLMArch::GEMMA3_12B) { - params.head_dim = 256; - params.num_heads = 16; - params.num_kv_heads = 8; - params.qkv_bias = false; - params.qk_norm = true; - params.rms_norm_eps = 1e-6f; - // llama.cpp adds +1 to Gemma3 norm.weight when exporting GGUF, so GGUF loading - // must keep rms_norm_add 
disabled here or the offset gets applied twice. - // Convenient for the converter, less convenient for whoever gets to debug it later. - params.rms_norm_add = false; - params.normalize_input = true; - params.max_position_embeddings = 131072; - params.mlp_activation = MLPActivation::GELU_TANH; - params.rope_thetas = {1000000.f, 10000.f}; - params.rope_scales = {8.f, 1.f}; - params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0}; - } else if (arch == LLMArch::GEMMA2_2B) { - params.head_dim = 256; - params.num_heads = 8; - params.num_kv_heads = 4; - params.qkv_bias = false; - params.qk_norm = false; - params.rms_norm_eps = 1e-6f; - params.rms_norm_add = true; - params.normalize_input = true; - params.max_position_embeddings = 8192; - params.mlp_activation = MLPActivation::GELU_TANH; - params.hidden_size = 2304; - params.intermediate_size = 9216; - params.num_layers = 26; - params.vocab_size = 256000; - } else if (arch == LLMArch::GPT_OSS_20B) { - params.head_dim = 64; - params.num_heads = 64; - params.num_kv_heads = 8; - params.qkv_bias = true; - params.attention_out_bias = true; - params.qk_norm = false; - params.rms_norm_eps = 1e-5f; - params.hidden_size = 2880; - params.intermediate_size = 2880; - params.num_layers = 24; - params.vocab_size = 201088; - params.max_position_embeddings = 131072; - params.rope_thetas = {150000.f}; - params.rope_scales = {32.f}; - params.sliding_attention = {128, 0}; - params.num_experts = 32; - params.num_experts_per_tok = 4; - } - bool have_vision_weight = false; - bool llama_cpp_style = false; - params.num_layers = 0; - for (auto pair : tensor_storage_map) { - std::string tensor_name = pair.first; - if (tensor_name.find(prefix) == std::string::npos) - continue; - size_t pos = tensor_name.find("visual."); - if (pos != std::string::npos) { - have_vision_weight = true; - if (contains(tensor_name, "attn.q_proj")) { - llama_cpp_style = true; - } - continue; - } - pos = tensor_name.find("layers."); - if (pos != std::string::npos) { - 
tensor_name = tensor_name.substr(pos); // remove prefix - auto items = split_string(tensor_name, '.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - if (block_index + 1 > params.num_layers) { - params.num_layers = block_index + 1; - } - } - } - if (contains(tensor_name, "embed_tokens.weight")) { - params.hidden_size = pair.second.ne[0]; - params.vocab_size = pair.second.ne[1]; - } - if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) { - params.intermediate_size = pair.second.ne[1]; - } - if (contains(tensor_name, "layers.0.mlp.experts.gate_up_proj.weight")) { - params.intermediate_size = pair.second.ne[1] / 2; - } - if (contains(tensor_name, "layers.0.mlp.experts.gate_proj.weight")) { - params.intermediate_size = pair.second.ne[1]; - } - } - if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B - params.num_heads = 16; - } - LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64, - params.num_layers, - params.vocab_size, - params.hidden_size, - params.intermediate_size); - if (enable_vision && !have_vision_weight) { + : GGMLRunner(backend, params_backend), + config(LLMConfig::detect_from_weights(tensor_storage_map, prefix, arch)), + enable_vision(enable_vision_) { + if (enable_vision && !config.have_vision_weight) { LOG_WARN("no vision weights detected, vision disabled"); enable_vision = false; } if (enable_vision) { LOG_DEBUG("enable llm vision"); - if (llama_cpp_style) { + if (config.llama_cpp_style) { LOG_DEBUG("llama.cpp style vision weight"); } } - model = LLM(params, enable_vision, llama_cpp_style); + model = LLM(config, enable_vision, config.llama_cpp_style); model.init(params_ctx, tensor_storage_map, prefix); } std::string get_desc() override { - return llm_arch_to_str[static_cast(params.arch)]; + return llm_arch_to_str[static_cast(config.arch)]; } void get_param_tensors(std::map& tensors, const std::string prefix) { @@ -1638,12 +1644,12 
@@ namespace LLM { } int64_t n_tokens = input_ids->ne[0]; - if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || - params.arch == LLMArch::MINISTRAL_3_3B || - params.arch == LLMArch::QWEN3 || - params.arch == LLMArch::GEMMA3_12B || - params.arch == LLMArch::GEMMA2_2B || - params.arch == LLMArch::GPT_OSS_20B) { + if (config.arch == LLMArch::MISTRAL_SMALL_3_2 || + config.arch == LLMArch::MINISTRAL_3_3B || + config.arch == LLMArch::QWEN3 || + config.arch == LLMArch::GEMMA3_12B || + config.arch == LLMArch::GEMMA2_2B || + config.arch == LLMArch::GPT_OSS_20B) { input_pos_vec.resize(n_tokens); for (int i = 0; i < n_tokens; ++i) { input_pos_vec[i] = i; @@ -1682,9 +1688,9 @@ namespace LLM { set_backend_tensor_data(attention_mask, attention_mask_vec.data()); } - if (params.arch == LLMArch::GEMMA3_12B || params.arch == LLMArch::GPT_OSS_20B) { + if (config.arch == LLMArch::GEMMA3_12B || config.arch == LLMArch::GPT_OSS_20B) { int sliding_window = 0; - for (int window : params.sliding_attention) { + for (int window : config.sliding_attention) { sliding_window = std::max(sliding_window, window); } sliding_attention_mask_vec.resize(n_tokens * n_tokens); @@ -1740,15 +1746,15 @@ namespace LLM { int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) { int64_t grid_t = 1; - int64_t grid_h = h / params.vision.patch_size; - int64_t grid_w = w / params.vision.patch_size; - int64_t llm_grid_h = grid_h / params.vision.spatial_merge_size; - int64_t llm_grid_w = grid_w / params.vision.spatial_merge_size; + int64_t grid_h = h / config.vision.patch_size; + int64_t grid_w = w / config.vision.patch_size; + int64_t llm_grid_h = grid_h / config.vision.spatial_merge_size; + int64_t llm_grid_w = grid_w / config.vision.spatial_merge_size; return grid_t * grid_h * grid_w; } ggml_tensor* process_image(ggml_context* ctx, ggml_tensor* image) { - return process_image_common(ctx, image, params.vision); + return process_image_common(ctx, image, config.vision); } ggml_tensor* 
build_patch_pos_embeds(GGMLRunnerContext* runner_ctx, @@ -1770,7 +1776,7 @@ namespace LLM { compute_ctx, runner_ctx, image, - params.vision, + config.vision, model.vision_model(), window_index_vec, window_inverse_index_vec, @@ -1784,8 +1790,8 @@ namespace LLM { ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE); ggml_tensor* image = make_input(image_tensor); - GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0); - GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0); + GGML_ASSERT(image->ne[1] % (config.vision.patch_size * config.vision.spatial_merge_size) == 0); + GGML_ASSERT(image->ne[0] % (config.vision.patch_size * config.vision.spatial_merge_size) == 0); auto runnter_ctx = get_context(); ggml_tensor* hidden_states = encode_image(&runnter_ctx, image); diff --git a/src/ltx_audio_vae.h b/src/ltx_audio_vae.h index 88c376314..51fb19898 100644 --- a/src/ltx_audio_vae.h +++ b/src/ltx_audio_vae.h @@ -58,11 +58,12 @@ namespace LTXV { return base_output_sample_rate(); } - static LTXAudioVAEConfig detect_from_weights(const String2TensorStorage& tensor_storage_map) { + static LTXAudioVAEConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix = "") { LTXAudioVAEConfig config; auto require = [&](const std::string& name) -> const TensorStorage* { - auto iter = tensor_storage_map.find(name); + std::string tensor_name = prefix.empty() ? name : prefix + "." 
+ name; + auto iter = tensor_storage_map.find(tensor_name); if (iter == tensor_storage_map.end()) { return nullptr; } @@ -168,6 +169,12 @@ namespace LTXV { if (config.audio_channels != 2 || config.latent_channels != 8 || config.mel_bins != 64) { return config; } + LOG_DEBUG("ltx_audio_vae: sample_rate = %d, mel_bins = %d, latent_channels = %d, latent_frequency_bins = %d, has_bwe = %s", + config.sample_rate, + config.mel_bins, + config.latent_channels, + config.latent_frequency_bins, + config.has_bwe ? "true" : "false"); return config; } }; diff --git a/src/ltxv.hpp b/src/ltxv.hpp index a7d3fb04e..a430d75c1 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -72,6 +72,200 @@ namespace LTXV { return max_block + 1; } + struct LTXAVConfig { + int64_t in_channels = 128; + int64_t out_channels = 128; + int64_t hidden_size = 3840; + int64_t cross_attention_dim = 4096; + int64_t caption_channels = 3840; + int64_t num_attention_heads = 30; + int64_t attention_head_dim = 128; + int64_t num_layers = 28; + float positional_embedding_theta = 10000.f; + std::vector positional_embedding_max_pos = {20, 2048, 2048}; + std::tuple vae_scale_factors = {8, 32, 32}; + bool causal_temporal_positioning = true; + float timestep_scale_multiplier = 1000.f; + + int64_t audio_in_channels = 128; + int64_t audio_out_channels = 128; + int64_t audio_hidden_size = 2048; + int64_t audio_cross_attention_dim = 2048; + int64_t audio_num_attention_heads = 32; + int64_t audio_attention_head_dim = 64; + std::vector audio_positional_embedding_max_pos = {20}; + float av_ca_timestep_scale_multiplier = 1000.f; + int64_t num_audio_channels = 8; + int64_t audio_frequency_bins = 16; + + bool use_connector = false; + int64_t connector_hidden_size = 3840; + int64_t connector_num_heads = 30; + int64_t connector_head_dim = 128; + int64_t connector_num_layers = 2; + int64_t connector_num_registers = 128; + bool connector_rope_interleaved = false; + bool connector_apply_gated_attention = false; + + bool 
use_audio_connector = false; + int64_t audio_connector_hidden_size = 2048; + int64_t audio_connector_num_heads = 32; + int64_t audio_connector_head_dim = 64; + int64_t audio_connector_num_layers = 2; + int64_t audio_connector_num_registers = 128; + bool audio_connector_rope_interleaved = false; + bool audio_connector_apply_gated_attention = false; + + bool video_rope_interleaved = false; + bool use_middle_indices_grid = true; + bool cross_attention_adaln = false; + + bool use_caption_projection = true; + bool use_audio_caption_projection = true; + bool caption_proj_before_connector = true; + bool caption_projection_first_linear = false; + + bool self_attention_gated = false; + bool cross_attention_gated = false; + + static std::pair infer_attention_layout(int64_t hidden_size, + int64_t preferred_heads = -1) { + if (preferred_heads > 0 && hidden_size % preferred_heads == 0) { + return {preferred_heads, hidden_size / preferred_heads}; + } + const int candidates[] = {128, 96, 80, 64, 48, 40, 32}; + for (int head_dim : candidates) { + if (hidden_size % head_dim == 0) { + int64_t heads = hidden_size / head_dim; + if (heads >= 8 && heads <= 64) { + return {heads, head_dim}; + } + } + } + return {32, hidden_size / 32}; + } + + static int64_t infer_gate_heads(const String2TensorStorage& tensor_storage_map, + const std::string& bias_name, + int64_t fallback_heads) { + auto it = tensor_storage_map.find(bias_name); + if (it != tensor_storage_map.end()) { + return it->second.ne[0]; + } + return fallback_heads; + } + + static LTXAVConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + LTXAVConfig config; + auto patchify_proj_iter = tensor_storage_map.find(prefix + ".patchify_proj.weight"); + if (patchify_proj_iter != tensor_storage_map.end()) { + config.in_channels = patchify_proj_iter->second.ne[0]; + config.hidden_size = patchify_proj_iter->second.ne[1]; + int64_t video_heads = infer_gate_heads(tensor_storage_map, prefix + 
".transformer_blocks.0.attn1.to_gate_logits.bias", 32); + auto attn_layout = infer_attention_layout(config.hidden_size, video_heads); + config.num_attention_heads = attn_layout.first; + config.attention_head_dim = attn_layout.second; + } + + auto audio_patchify_proj_iter = tensor_storage_map.find(prefix + ".audio_patchify_proj.weight"); + if (audio_patchify_proj_iter != tensor_storage_map.end()) { + config.audio_in_channels = audio_patchify_proj_iter->second.ne[0]; + config.audio_hidden_size = audio_patchify_proj_iter->second.ne[1]; + config.audio_out_channels = config.audio_in_channels; + int64_t audio_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.bias", 32); + auto audio_attn_layout = infer_attention_layout(config.audio_hidden_size, audio_heads); + config.audio_num_attention_heads = audio_attn_layout.first; + config.audio_attention_head_dim = audio_attn_layout.second; + } + + auto proj_out_iter = tensor_storage_map.find(prefix + ".proj_out.weight"); + if (proj_out_iter != tensor_storage_map.end()) { + config.out_channels = proj_out_iter->second.ne[1]; + } + auto audio_proj_out_iter = tensor_storage_map.find(prefix + ".audio_proj_out.weight"); + if (audio_proj_out_iter != tensor_storage_map.end()) { + config.audio_out_channels = audio_proj_out_iter->second.ne[1]; + } + + auto attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_k.weight"); + if (attn2_iter != tensor_storage_map.end()) { + config.cross_attention_dim = attn2_iter->second.ne[0]; + } + auto audio_attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_k.weight"); + if (audio_attn2_iter != tensor_storage_map.end()) { + config.audio_cross_attention_dim = audio_attn2_iter->second.ne[0]; + } + if (tensor_storage_map.find(prefix + ".transformer_blocks.0.prompt_scale_shift_table") != tensor_storage_map.end()) { + config.cross_attention_adaln = true; + } + if (tensor_storage_map.find(prefix + 
".transformer_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end() || + tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.weight") != tensor_storage_map.end()) { + config.self_attention_gated = true; + } + if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_gate_logits.weight") != tensor_storage_map.end() || + tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_gate_logits.weight") != tensor_storage_map.end()) { + config.cross_attention_gated = true; + } + if (tensor_storage_map.find(prefix + ".caption_projection.linear_1.weight") == tensor_storage_map.end() && + tensor_storage_map.find(prefix + ".caption_projection.linear_2.weight") == tensor_storage_map.end()) { + config.use_caption_projection = false; + } + if (tensor_storage_map.find(prefix + ".audio_caption_projection.linear_1.weight") == tensor_storage_map.end() && + tensor_storage_map.find(prefix + ".audio_caption_projection.linear_2.weight") == tensor_storage_map.end()) { + config.use_audio_caption_projection = false; + } + + config.num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".", "transformer_blocks."); + + auto connector_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); + if (connector_iter != tensor_storage_map.end()) { + config.use_connector = true; + config.connector_hidden_size = connector_iter->second.ne[1]; + int64_t connector_heads = infer_gate_heads(tensor_storage_map, + prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", + 32); + auto connector_layout = infer_attention_layout(config.connector_hidden_size, connector_heads); + config.connector_num_heads = connector_layout.first; + config.connector_head_dim = connector_layout.second; + config.connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".video_embeddings_connector.", "transformer_1d_blocks."); + auto register_iter = 
tensor_storage_map.find(prefix + ".video_embeddings_connector.learnable_registers"); + if (register_iter != tensor_storage_map.end()) { + config.connector_num_registers = register_iter->second.ne[1]; + } + if (tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { + config.connector_apply_gated_attention = true; + } + } + + auto audio_connector_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); + if (audio_connector_iter != tensor_storage_map.end()) { + config.use_audio_connector = true; + config.audio_connector_hidden_size = audio_connector_iter->second.ne[1]; + int64_t connector_heads = infer_gate_heads(tensor_storage_map, + prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", + 32); + auto connector_layout = infer_attention_layout(config.audio_connector_hidden_size, connector_heads); + config.audio_connector_num_heads = connector_layout.first; + config.audio_connector_head_dim = connector_layout.second; + config.audio_connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".audio_embeddings_connector.", "transformer_1d_blocks."); + auto register_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.learnable_registers"); + if (register_iter != tensor_storage_map.end()) { + config.audio_connector_num_registers = register_iter->second.ne[1]; + } + if (tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { + config.audio_connector_apply_gated_attention = true; + } + } + LOG_DEBUG("ltxav: num_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_attention_heads = %" PRId64 ", audio_hidden_size = %" PRId64 ", audio_num_attention_heads = %" PRId64, + config.num_layers, + config.hidden_size, + config.num_attention_heads, + config.audio_hidden_size, + 
config.audio_num_attention_heads); + return config; + } + }; + __STATIC_INLINE__ std::vector generate_freq_grid(float theta, int positional_dims, int dim) { @@ -749,63 +943,6 @@ namespace LTXV { } }; - struct LTXAVParams { - int64_t in_channels = 128; - int64_t out_channels = 128; - int64_t hidden_size = 3840; - int64_t cross_attention_dim = 4096; - int64_t caption_channels = 3840; - int64_t num_attention_heads = 30; - int64_t attention_head_dim = 128; - int64_t num_layers = 28; - float positional_embedding_theta = 10000.f; - std::vector positional_embedding_max_pos = {20, 2048, 2048}; - std::tuple vae_scale_factors = {8, 32, 32}; - bool causal_temporal_positioning = true; - float timestep_scale_multiplier = 1000.f; - - int64_t audio_in_channels = 128; - int64_t audio_out_channels = 128; - int64_t audio_hidden_size = 2048; - int64_t audio_cross_attention_dim = 2048; - int64_t audio_num_attention_heads = 32; - int64_t audio_attention_head_dim = 64; - std::vector audio_positional_embedding_max_pos = {20}; - float av_ca_timestep_scale_multiplier = 1000.f; - int64_t num_audio_channels = 8; - int64_t audio_frequency_bins = 16; - - bool use_connector = false; - int64_t connector_hidden_size = 3840; - int64_t connector_num_heads = 30; - int64_t connector_head_dim = 128; - int64_t connector_num_layers = 2; - int64_t connector_num_registers = 128; - bool connector_rope_interleaved = false; - bool connector_apply_gated_attention = false; - - bool use_audio_connector = false; - int64_t audio_connector_hidden_size = 2048; - int64_t audio_connector_num_heads = 32; - int64_t audio_connector_head_dim = 64; - int64_t audio_connector_num_layers = 2; - int64_t audio_connector_num_registers = 128; - bool audio_connector_rope_interleaved = false; - bool audio_connector_apply_gated_attention = false; - - bool video_rope_interleaved = false; - bool use_middle_indices_grid = true; - bool cross_attention_adaln = false; - - bool use_caption_projection = true; - bool 
use_audio_caption_projection = true; - bool caption_proj_before_connector = true; - bool caption_projection_first_linear = false; - - bool self_attention_gated = false; - bool cross_attention_gated = false; - }; - __STATIC_INLINE__ std::pair infer_attention_layout(int64_t hidden_size, int64_t preferred_heads = -1) { if (preferred_heads > 0 && hidden_size % preferred_heads == 0) { @@ -1169,92 +1306,92 @@ namespace LTXV { }; struct LTXAVModelBlock : public GGMLBlock { - LTXAVParams cfg; + LTXAVConfig config; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { params["scale_shift_table"] = ggml_new_tensor_2d(ctx, get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32), - cfg.hidden_size, + config.hidden_size, 2); params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, get_type(prefix + "audio_scale_shift_table", tensor_storage_map, GGML_TYPE_F32), - cfg.audio_hidden_size, + config.audio_hidden_size, 2); } - LTXAVModelBlock(const LTXAVParams& params) - : cfg(params) { - blocks["patchify_proj"] = std::make_shared(cfg.in_channels, cfg.hidden_size, true, true); - blocks["audio_patchify_proj"] = std::make_shared(cfg.audio_in_channels, cfg.audio_hidden_size, true, true); - blocks["adaln_single"] = std::make_shared(cfg.hidden_size, cfg.cross_attention_adaln ? 9 : 6); - blocks["audio_adaln_single"] = std::make_shared(cfg.audio_hidden_size, cfg.cross_attention_adaln ? 
9 : 6); - if (cfg.cross_attention_adaln) { - blocks["prompt_adaln_single"] = std::make_shared(cfg.hidden_size, 2); - blocks["audio_prompt_adaln_single"] = std::make_shared(cfg.audio_hidden_size, 2); - } - blocks["av_ca_video_scale_shift_adaln_single"] = std::make_shared(cfg.hidden_size, 4); - blocks["av_ca_a2v_gate_adaln_single"] = std::make_shared(cfg.hidden_size, 1); - blocks["av_ca_audio_scale_shift_adaln_single"] = std::make_shared(cfg.audio_hidden_size, 4); - blocks["av_ca_v2a_gate_adaln_single"] = std::make_shared(cfg.audio_hidden_size, 1); - - if (cfg.use_caption_projection) { - if (cfg.caption_proj_before_connector) { - if (cfg.caption_projection_first_linear) { - blocks["caption_projection"] = std::make_shared(cfg.caption_channels, cfg.hidden_size); + LTXAVModelBlock(const LTXAVConfig& config) + : config(config) { + blocks["patchify_proj"] = std::make_shared(config.in_channels, config.hidden_size, true, true); + blocks["audio_patchify_proj"] = std::make_shared(config.audio_in_channels, config.audio_hidden_size, true, true); + blocks["adaln_single"] = std::make_shared(config.hidden_size, config.cross_attention_adaln ? 9 : 6); + blocks["audio_adaln_single"] = std::make_shared(config.audio_hidden_size, config.cross_attention_adaln ? 
9 : 6); + if (config.cross_attention_adaln) { + blocks["prompt_adaln_single"] = std::make_shared(config.hidden_size, 2); + blocks["audio_prompt_adaln_single"] = std::make_shared(config.audio_hidden_size, 2); + } + blocks["av_ca_video_scale_shift_adaln_single"] = std::make_shared(config.hidden_size, 4); + blocks["av_ca_a2v_gate_adaln_single"] = std::make_shared(config.hidden_size, 1); + blocks["av_ca_audio_scale_shift_adaln_single"] = std::make_shared(config.audio_hidden_size, 4); + blocks["av_ca_v2a_gate_adaln_single"] = std::make_shared(config.audio_hidden_size, 1); + + if (config.use_caption_projection) { + if (config.caption_proj_before_connector) { + if (config.caption_projection_first_linear) { + blocks["caption_projection"] = std::make_shared(config.caption_channels, config.hidden_size); } } else { - blocks["caption_projection"] = std::make_shared(cfg.caption_channels, cfg.hidden_size, cfg.hidden_size); + blocks["caption_projection"] = std::make_shared(config.caption_channels, config.hidden_size, config.hidden_size); } } - if (cfg.use_audio_caption_projection) { - if (cfg.caption_proj_before_connector) { - if (cfg.caption_projection_first_linear) { - blocks["audio_caption_projection"] = std::make_shared(cfg.caption_channels, cfg.audio_hidden_size); + if (config.use_audio_caption_projection) { + if (config.caption_proj_before_connector) { + if (config.caption_projection_first_linear) { + blocks["audio_caption_projection"] = std::make_shared(config.caption_channels, config.audio_hidden_size); } } else { - blocks["audio_caption_projection"] = std::make_shared(cfg.caption_channels, cfg.audio_hidden_size, cfg.audio_hidden_size); + blocks["audio_caption_projection"] = std::make_shared(config.caption_channels, config.audio_hidden_size, config.audio_hidden_size); } } - if (cfg.use_connector) { - blocks["video_embeddings_connector"] = std::make_shared(cfg.connector_hidden_size, - cfg.connector_num_heads, - cfg.connector_head_dim, - cfg.connector_num_layers, - 
cfg.connector_num_registers, - cfg.connector_rope_interleaved, - cfg.connector_apply_gated_attention); - } - if (cfg.use_audio_connector) { - blocks["audio_embeddings_connector"] = std::make_shared(cfg.audio_connector_hidden_size, - cfg.audio_connector_num_heads, - cfg.audio_connector_head_dim, - cfg.audio_connector_num_layers, - cfg.audio_connector_num_registers, - cfg.audio_connector_rope_interleaved, - cfg.audio_connector_apply_gated_attention); - } - - for (int i = 0; i < cfg.num_layers; i++) { - blocks["transformer_blocks." + std::to_string(i)] = std::make_shared(cfg.hidden_size, - cfg.audio_hidden_size, - cfg.num_attention_heads, - cfg.audio_num_attention_heads, - cfg.attention_head_dim, - cfg.audio_attention_head_dim, - cfg.cross_attention_dim, - cfg.audio_cross_attention_dim, - cfg.self_attention_gated || cfg.cross_attention_gated, - cfg.cross_attention_adaln, - cfg.video_rope_interleaved); - } - - blocks["norm_out"] = std::make_shared(cfg.hidden_size, 1e-6f, false); - blocks["proj_out"] = std::make_shared(cfg.hidden_size, cfg.out_channels, true, true); - blocks["audio_norm_out"] = std::make_shared(cfg.audio_hidden_size, 1e-6f, false); - blocks["audio_proj_out"] = std::make_shared(cfg.audio_hidden_size, cfg.audio_out_channels, true, true); + if (config.use_connector) { + blocks["video_embeddings_connector"] = std::make_shared(config.connector_hidden_size, + config.connector_num_heads, + config.connector_head_dim, + config.connector_num_layers, + config.connector_num_registers, + config.connector_rope_interleaved, + config.connector_apply_gated_attention); + } + if (config.use_audio_connector) { + blocks["audio_embeddings_connector"] = std::make_shared(config.audio_connector_hidden_size, + config.audio_connector_num_heads, + config.audio_connector_head_dim, + config.audio_connector_num_layers, + config.audio_connector_num_registers, + config.audio_connector_rope_interleaved, + config.audio_connector_apply_gated_attention); + } + + for (int i = 0; i < 
config.num_layers; i++) { + blocks["transformer_blocks." + std::to_string(i)] = std::make_shared(config.hidden_size, + config.audio_hidden_size, + config.num_attention_heads, + config.audio_num_attention_heads, + config.attention_head_dim, + config.audio_attention_head_dim, + config.cross_attention_dim, + config.audio_cross_attention_dim, + config.self_attention_gated || config.cross_attention_gated, + config.cross_attention_adaln, + config.video_rope_interleaved); + } + + blocks["norm_out"] = std::make_shared(config.hidden_size, 1e-6f, false); + blocks["proj_out"] = std::make_shared(config.hidden_size, config.out_channels, true, true); + blocks["audio_norm_out"] = std::make_shared(config.audio_hidden_size, 1e-6f, false); + blocks["audio_proj_out"] = std::make_shared(config.audio_hidden_size, config.audio_out_channels, true, true); } ggml_tensor* patchify_video(GGMLRunnerContext* ctx, ggml_tensor* x, int64_t n) { @@ -1293,8 +1430,8 @@ namespace LTXV { if (ax == nullptr) { return nullptr; } - ax = ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, cfg.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f] - ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f] + ax = ggml_reshape_4d(ctx->ggml_ctx, ax, config.audio_frequency_bins, config.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f] + ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f] return ax; } @@ -1308,17 +1445,17 @@ namespace LTXV { } bool is_fully_processed_context = - context->ne[0] == cfg.cross_attention_dim + cfg.audio_cross_attention_dim && + context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim && context->ne[1] >= 1024; bool is_unprocessed_dual_context = - context->ne[0] == cfg.cross_attention_dim + cfg.audio_cross_attention_dim && + context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim && context->ne[1] < 1024; if 
(is_fully_processed_context) { - auto v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.cross_attention_dim); + auto v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.cross_attention_dim); ggml_tensor* a_context = nullptr; if (process_audio_context) { - a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.cross_attention_dim, cfg.cross_attention_dim + cfg.audio_cross_attention_dim); + a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.cross_attention_dim, config.cross_attention_dim + config.audio_cross_attention_dim); } return {v_context, a_context}; } @@ -1326,32 +1463,32 @@ namespace LTXV { ggml_tensor* v_context = context; ggml_tensor* a_context = process_audio_context ? context : nullptr; if (is_unprocessed_dual_context) { - v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.cross_attention_dim); + v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.cross_attention_dim); if (process_audio_context) { - a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.cross_attention_dim, cfg.cross_attention_dim + cfg.audio_cross_attention_dim); + a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.cross_attention_dim, config.cross_attention_dim + config.audio_cross_attention_dim); } - } else if (context->ne[0] == cfg.caption_channels * 2) { - v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.caption_channels); + } else if (context->ne[0] == config.caption_channels * 2) { + v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, config.caption_channels); if (process_audio_context) { - a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.caption_channels, cfg.caption_channels * 2); + a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, config.caption_channels, config.caption_channels * 2); } } - if (cfg.caption_proj_before_connector) { - if (cfg.use_caption_projection && + if (config.caption_proj_before_connector) { + if (config.use_caption_projection && 
blocks.count("caption_projection") > 0 && v_context != nullptr && - v_context->ne[0] == cfg.caption_channels) { + v_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["caption_projection"]); if (caption_projection != nullptr) { v_context = caption_projection->forward(ctx, v_context); } } if (process_audio_context && - cfg.use_audio_caption_projection && + config.use_audio_caption_projection && blocks.count("audio_caption_projection") > 0 && a_context != nullptr && - a_context->ne[0] == cfg.caption_channels) { + a_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); if (caption_projection != nullptr) { a_context = caption_projection->forward(ctx, a_context); @@ -1359,34 +1496,34 @@ namespace LTXV { } } - if (cfg.use_connector && v_context != nullptr && v_context->ne[0] == cfg.connector_hidden_size) { + if (config.use_connector && v_context != nullptr && v_context->ne[0] == config.connector_hidden_size) { auto connector = std::dynamic_pointer_cast(blocks["video_embeddings_connector"]); v_context = connector->forward(ctx, v_context, video_connector_pe); } if (process_audio_context && - cfg.use_audio_connector && + config.use_audio_connector && a_context != nullptr && - a_context->ne[0] == cfg.audio_connector_hidden_size) { + a_context->ne[0] == config.audio_connector_hidden_size) { auto connector = std::dynamic_pointer_cast(blocks["audio_embeddings_connector"]); a_context = connector->forward(ctx, a_context, audio_connector_pe); } - if (!cfg.caption_proj_before_connector && - cfg.use_caption_projection && + if (!config.caption_proj_before_connector && + config.use_caption_projection && blocks.count("caption_projection") > 0 && v_context != nullptr && - v_context->ne[0] == cfg.caption_channels) { + v_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["caption_projection"]); if 
(caption_projection != nullptr) { v_context = caption_projection->forward(ctx, v_context); } } if (process_audio_context && - !cfg.caption_proj_before_connector && - cfg.use_audio_caption_projection && + !config.caption_proj_before_connector && + config.use_audio_caption_projection && blocks.count("audio_caption_projection") > 0 && a_context != nullptr && - a_context->ne[0] == cfg.caption_channels) { + a_context->ne[0] == config.caption_channels) { auto caption_projection = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); if (caption_projection != nullptr) { a_context = caption_projection->forward(ctx, a_context); @@ -1428,8 +1565,8 @@ namespace LTXV { auto audio_norm_out = std::dynamic_pointer_cast(blocks["audio_norm_out"]); auto audio_proj_out = std::dynamic_pointer_cast(blocks["audio_proj_out"]); - GGML_ASSERT(vx->ne[3] % cfg.in_channels == 0); - int64_t n = vx->ne[3] / cfg.in_channels; + GGML_ASSERT(vx->ne[3] % config.in_channels == 0); + int64_t n = vx->ne[3] / config.in_channels; int64_t width = vx->ne[0]; int64_t height = vx->ne[1]; int64_t frames = vx->ne[2]; @@ -1452,20 +1589,20 @@ namespace LTXV { a_context = ggml_cont(ctx->ggml_ctx, a_context); } - auto v_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, timestep, cfg.timestep_scale_multiplier); + auto v_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, timestep, config.timestep_scale_multiplier); auto v_pair = adaln_single->forward(ctx, v_timestep_scaled); auto v_timestep_mod = v_pair.first; auto v_embedded_time = v_pair.second; ggml_tensor* effective_audio_timestep = audio_timestep != nullptr ? 
audio_timestep : timestep; - auto a_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, effective_audio_timestep, cfg.timestep_scale_multiplier); + auto a_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, effective_audio_timestep, config.timestep_scale_multiplier); auto a_pair = audio_adaln_single->forward(ctx, a_timestep_scaled); auto a_timestep_mod = a_pair.first; auto a_embedded_time = a_pair.second; ggml_tensor* v_prompt_timestep_mod = nullptr; ggml_tensor* a_prompt_timestep_mod = nullptr; - if (cfg.cross_attention_adaln) { + if (config.cross_attention_adaln) { auto prompt_adaln_single = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); auto audio_prompt_adaln_single = std::dynamic_pointer_cast(blocks["audio_prompt_adaln_single"]); v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, a_timestep_scaled).first; @@ -1474,7 +1611,7 @@ namespace LTXV { auto av_ca_video_timestep = repeat_scalar_timestep_like(ctx, effective_audio_timestep, timestep); auto av_ca_audio_timestep = effective_audio_timestep; - auto av_ca_factor = cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier; + auto av_ca_factor = config.av_ca_timestep_scale_multiplier / config.timestep_scale_multiplier; auto av_ca_video_scale_shift_timestep = std::dynamic_pointer_cast(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, av_ca_video_timestep).first; auto av_ca_a2v_gate_noise_timestep = @@ -1491,7 +1628,7 @@ namespace LTXV { sd::ggml_graph_cut::mark_graph_cut(vx, "ltxav.prelude", "vx"); sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.prelude", "ax"); - for (int i = 0; i < cfg.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." + std::to_string(i)]); auto out = block->forward(ctx, vx, @@ -1517,14 +1654,14 @@ namespace LTXV { sd::ggml_graph_cut::mark_graph_cut(ax, "ltxav.transformer_blocks." 
+ std::to_string(i), "ax"); } - auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size); + auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, config.hidden_size); vx = norm_out->forward(ctx, vx); vx = modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1]); vx = proj_out->forward(ctx, vx); vx = unpatchify_video(ctx, vx, width, height, frames); if (ax != nullptr && audio_time > 0) { - auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size); + auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, config.audio_hidden_size); ax = audio_norm_out->forward(ctx, ax); ax = modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1]); ax = audio_proj_out->forward(ctx, ax); @@ -1536,7 +1673,7 @@ namespace LTXV { }; struct LTXAVRunner : public DiffusionModelRunner { - LTXAVParams params; + LTXAVConfig config; LTXAVModelBlock model; std::vector video_pe_vec; std::vector audio_pe_vec; @@ -1547,124 +1684,13 @@ namespace LTXV { sd::Tensor vx_input_cache; sd::Tensor ax_input_cache; - static int64_t infer_gate_heads(const String2TensorStorage& tensor_storage_map, - const std::string& bias_name, - int64_t fallback_heads) { - auto it = tensor_storage_map.find(bias_name); - if (it != tensor_storage_map.end()) { - return it->second.ne[0]; - } - return fallback_heads; - } - LTXAVRunner(ggml_backend_t backend, ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map = {}, const std::string& prefix = "model.diffusion_model") : DiffusionModelRunner(backend, params_backend, prefix), - params(), - model(params) { - auto patchify_proj_iter = tensor_storage_map.find(prefix + ".patchify_proj.weight"); - if (patchify_proj_iter != tensor_storage_map.end()) { - params.in_channels = patchify_proj_iter->second.ne[0]; - params.hidden_size = 
patchify_proj_iter->second.ne[1]; - int64_t video_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.attn1.to_gate_logits.bias", 32); - auto attn_layout = infer_attention_layout(params.hidden_size, video_heads); - params.num_attention_heads = attn_layout.first; - params.attention_head_dim = attn_layout.second; - } - - auto audio_patchify_proj_iter = tensor_storage_map.find(prefix + ".audio_patchify_proj.weight"); - if (audio_patchify_proj_iter != tensor_storage_map.end()) { - params.audio_in_channels = audio_patchify_proj_iter->second.ne[0]; - params.audio_hidden_size = audio_patchify_proj_iter->second.ne[1]; - params.audio_out_channels = params.audio_in_channels; - int64_t audio_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.bias", 32); - auto audio_attn_layout = infer_attention_layout(params.audio_hidden_size, audio_heads); - params.audio_num_attention_heads = audio_attn_layout.first; - params.audio_attention_head_dim = audio_attn_layout.second; - } - - auto proj_out_iter = tensor_storage_map.find(prefix + ".proj_out.weight"); - if (proj_out_iter != tensor_storage_map.end()) { - params.out_channels = proj_out_iter->second.ne[1]; - } - auto audio_proj_out_iter = tensor_storage_map.find(prefix + ".audio_proj_out.weight"); - if (audio_proj_out_iter != tensor_storage_map.end()) { - params.audio_out_channels = audio_proj_out_iter->second.ne[1]; - } - - auto attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_k.weight"); - if (attn2_iter != tensor_storage_map.end()) { - params.cross_attention_dim = attn2_iter->second.ne[0]; - } - auto audio_attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_k.weight"); - if (audio_attn2_iter != tensor_storage_map.end()) { - params.audio_cross_attention_dim = audio_attn2_iter->second.ne[0]; - } - if (tensor_storage_map.find(prefix + ".transformer_blocks.0.prompt_scale_shift_table") != 
tensor_storage_map.end()) { - params.cross_attention_adaln = true; - } - if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end() || - tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.weight") != tensor_storage_map.end()) { - params.self_attention_gated = true; - } - if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_gate_logits.weight") != tensor_storage_map.end() || - tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_gate_logits.weight") != tensor_storage_map.end()) { - params.cross_attention_gated = true; - } - if (tensor_storage_map.find(prefix + ".caption_projection.linear_1.weight") == tensor_storage_map.end() && - tensor_storage_map.find(prefix + ".caption_projection.linear_2.weight") == tensor_storage_map.end()) { - params.use_caption_projection = false; - } - if (tensor_storage_map.find(prefix + ".audio_caption_projection.linear_1.weight") == tensor_storage_map.end() && - tensor_storage_map.find(prefix + ".audio_caption_projection.linear_2.weight") == tensor_storage_map.end()) { - params.use_audio_caption_projection = false; - } - - params.num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".", "transformer_blocks."); - - auto connector_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); - if (connector_iter != tensor_storage_map.end()) { - params.use_connector = true; - params.connector_hidden_size = connector_iter->second.ne[1]; - int64_t connector_heads = infer_gate_heads(tensor_storage_map, - prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", - 32); - auto connector_layout = infer_attention_layout(params.connector_hidden_size, connector_heads); - params.connector_num_heads = connector_layout.first; - params.connector_head_dim = connector_layout.second; - params.connector_num_layers = 
count_prefix_blocks(tensor_storage_map, prefix + ".video_embeddings_connector.", "transformer_1d_blocks."); - auto register_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.learnable_registers"); - if (register_iter != tensor_storage_map.end()) { - params.connector_num_registers = register_iter->second.ne[1]; - } - if (tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { - params.connector_apply_gated_attention = true; - } - } - - auto audio_connector_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); - if (audio_connector_iter != tensor_storage_map.end()) { - params.use_audio_connector = true; - params.audio_connector_hidden_size = audio_connector_iter->second.ne[1]; - int64_t connector_heads = infer_gate_heads(tensor_storage_map, - prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", - 32); - auto connector_layout = infer_attention_layout(params.audio_connector_hidden_size, connector_heads); - params.audio_connector_num_heads = connector_layout.first; - params.audio_connector_head_dim = connector_layout.second; - params.audio_connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".audio_embeddings_connector.", "transformer_1d_blocks."); - auto register_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.learnable_registers"); - if (register_iter != tensor_storage_map.end()) { - params.audio_connector_num_registers = register_iter->second.ne[1]; - } - if (tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { - params.audio_connector_apply_gated_attention = true; - } - } - - model = LTXAVModelBlock(params); + config(LTXAVConfig::detect_from_weights(tensor_storage_map, prefix)), + model(config) { model.init(params_ctx, 
tensor_storage_map, prefix); } @@ -1692,21 +1718,21 @@ namespace LTXV { int64_t total_channels = x_tensor.shape()[3]; int64_t spatial_size = width * height * frames; - GGML_ASSERT(total_channels >= params.in_channels); + GGML_ASSERT(total_channels >= config.in_channels); - sd::Tensor vx({width, height, frames, params.in_channels}); - size_t video_values = static_cast(params.in_channels * spatial_size); + sd::Tensor vx({width, height, frames, config.in_channels}); + size_t video_values = static_cast(config.in_channels * spatial_size); std::copy_n(x_tensor.data(), video_values, vx.data()); - if (audio_length <= 0 || total_channels == params.in_channels) { + if (audio_length <= 0 || total_channels == config.in_channels) { return {vx, {}}; } - int64_t needed_audio_values = static_cast(audio_length) * params.num_audio_channels * params.audio_frequency_bins; - int64_t packed_audio_values = (total_channels - params.in_channels) * spatial_size; + int64_t needed_audio_values = static_cast(audio_length) * config.num_audio_channels * config.audio_frequency_bins; + int64_t packed_audio_values = (total_channels - config.in_channels) * spatial_size; GGML_ASSERT(packed_audio_values >= needed_audio_values); - sd::Tensor ax({params.audio_frequency_bins, audio_length, params.num_audio_channels, 1}); + sd::Tensor ax({config.audio_frequency_bins, audio_length, config.num_audio_channels, 1}); const float* audio_src = x_tensor.data() + video_values; std::copy_n(audio_src, static_cast(needed_audio_values), ax.data()); return {vx, ax}; @@ -1767,25 +1793,25 @@ namespace LTXV { if (has_video_positions) { GGML_ASSERT(video_positions_tensor.shape()[2] == video_token_count); video_pe_vec = build_video_rope_matrix_from_positions(video_positions_tensor, - static_cast(params.hidden_size), - static_cast(params.num_attention_heads), - params.positional_embedding_theta, - params.positional_embedding_max_pos, - params.use_middle_indices_grid); + static_cast(config.hidden_size), + 
static_cast(config.num_attention_heads), + config.positional_embedding_theta, + config.positional_embedding_max_pos, + config.use_middle_indices_grid); } else { video_pe_vec = build_video_rope_matrix(vx->ne[0], vx->ne[1], vx->ne[2], - static_cast(params.hidden_size), - static_cast(params.num_attention_heads), + static_cast(config.hidden_size), + static_cast(config.num_attention_heads), video_frame_rate, - params.positional_embedding_theta, - params.positional_embedding_max_pos, - params.vae_scale_factors, - params.causal_temporal_positioning, - params.use_middle_indices_grid); + config.positional_embedding_theta, + config.positional_embedding_max_pos, + config.vae_scale_factors, + config.causal_temporal_positioning, + config.use_middle_indices_grid); } - auto video_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.attention_head_dim / 2, video_token_count * params.num_attention_heads); + auto video_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.attention_head_dim / 2, video_token_count * config.num_attention_heads); ggml_set_name(video_pe, "ltxav_video_pe"); set_backend_tensor_data(video_pe, video_pe_vec.data()); @@ -1794,66 +1820,66 @@ namespace LTXV { ggml_tensor* audio_cross_pe = nullptr; if (ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0) { audio_pe_vec = build_audio_rope_matrix(ax->ne[1], - static_cast(params.audio_hidden_size), - static_cast(params.audio_num_attention_heads), - params.positional_embedding_theta, - params.audio_positional_embedding_max_pos[0], - params.use_middle_indices_grid); - audio_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, ax->ne[1] * params.audio_num_attention_heads); + static_cast(config.audio_hidden_size), + static_cast(config.audio_num_attention_heads), + config.positional_embedding_theta, + config.audio_positional_embedding_max_pos[0], + config.use_middle_indices_grid); + audio_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, 
config.audio_attention_head_dim / 2, ax->ne[1] * config.audio_num_attention_heads); ggml_set_name(audio_pe, "ltxav_audio_pe"); set_backend_tensor_data(audio_pe, audio_pe_vec.data()); - int temporal_max_pos = std::max(params.positional_embedding_max_pos[0], params.audio_positional_embedding_max_pos[0]); + int temporal_max_pos = std::max(config.positional_embedding_max_pos[0], config.audio_positional_embedding_max_pos[0]); if (has_video_positions) { video_cross_pe_vec = build_video_temporal_rope_matrix_from_positions(video_positions_tensor, - static_cast(params.audio_cross_attention_dim), - static_cast(params.audio_num_attention_heads), - params.positional_embedding_theta, + static_cast(config.audio_cross_attention_dim), + static_cast(config.audio_num_attention_heads), + config.positional_embedding_theta, temporal_max_pos, true); } else { video_cross_pe_vec = build_video_temporal_rope_matrix(vx->ne[0], vx->ne[1], vx->ne[2], - static_cast(params.audio_cross_attention_dim), - static_cast(params.audio_num_attention_heads), + static_cast(config.audio_cross_attention_dim), + static_cast(config.audio_num_attention_heads), video_frame_rate, - params.positional_embedding_theta, + config.positional_embedding_theta, temporal_max_pos, - std::get<0>(params.vae_scale_factors), - params.causal_temporal_positioning, + std::get<0>(config.vae_scale_factors), + config.causal_temporal_positioning, true); } - video_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, video_token_count * params.audio_num_attention_heads); + video_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, video_token_count * config.audio_num_attention_heads); ggml_set_name(video_cross_pe, "ltxav_video_cross_pe"); set_backend_tensor_data(video_cross_pe, video_cross_pe_vec.data()); audio_cross_pe_vec = build_audio_rope_matrix(ax->ne[1], - static_cast(params.audio_cross_attention_dim), - 
static_cast(params.audio_num_attention_heads), - params.positional_embedding_theta, + static_cast(config.audio_cross_attention_dim), + static_cast(config.audio_num_attention_heads), + config.positional_embedding_theta, temporal_max_pos, true); - audio_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, ax->ne[1] * params.audio_num_attention_heads); + audio_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_attention_head_dim / 2, ax->ne[1] * config.audio_num_attention_heads); ggml_set_name(audio_cross_pe, "ltxav_audio_cross_pe"); set_backend_tensor_data(audio_cross_pe, audio_cross_pe_vec.data()); } bool needs_video_connector_pe = - params.use_connector && + config.use_connector && context != nullptr && - (context->ne[0] == params.connector_hidden_size || - ((context->ne[0] == params.cross_attention_dim + params.audio_cross_attention_dim || - context->ne[0] == params.caption_channels * 2) && + (context->ne[0] == config.connector_hidden_size || + ((context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim || + context->ne[0] == config.caption_channels * 2) && context->ne[1] < 1024)); ggml_tensor* video_connector_pe = nullptr; if (needs_video_connector_pe) { int64_t seq_len = context->ne[1]; int64_t target_len = std::max(1024, seq_len); - int64_t duplications = (target_len + params.connector_num_registers - 1) / params.connector_num_registers; - int64_t full_len = seq_len + duplications * params.connector_num_registers - seq_len; - connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(params.connector_hidden_size), static_cast(params.connector_num_heads), 10000.f, 4096.f, true); - video_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.connector_head_dim / 2, full_len * params.connector_num_heads); + int64_t duplications = (target_len + config.connector_num_registers - 1) / config.connector_num_registers; + int64_t full_len = seq_len + 
duplications * config.connector_num_registers - seq_len; + connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(config.connector_hidden_size), static_cast(config.connector_num_heads), 10000.f, 4096.f, true); + video_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.connector_head_dim / 2, full_len * config.connector_num_heads); ggml_set_name(video_connector_pe, "ltxav_video_connector_pe"); set_backend_tensor_data(video_connector_pe, connector_pe_vec.data()); } @@ -1864,20 +1890,20 @@ namespace LTXV { ax->ne[1] > 0; bool needs_audio_connector_pe = run_audio_context && - params.use_audio_connector && + config.use_audio_connector && context != nullptr && - (context->ne[0] == params.audio_connector_hidden_size || - ((context->ne[0] == params.cross_attention_dim + params.audio_cross_attention_dim || - context->ne[0] == params.caption_channels * 2) && + (context->ne[0] == config.audio_connector_hidden_size || + ((context->ne[0] == config.cross_attention_dim + config.audio_cross_attention_dim || + context->ne[0] == config.caption_channels * 2) && context->ne[1] < 1024)); ggml_tensor* audio_connector_pe = nullptr; if (needs_audio_connector_pe) { int64_t seq_len = context->ne[1]; int64_t target_len = std::max(1024, seq_len); - int64_t duplications = (target_len + params.audio_connector_num_registers - 1) / params.audio_connector_num_registers; - int64_t full_len = seq_len + duplications * params.audio_connector_num_registers - seq_len; - audio_connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(params.audio_connector_hidden_size), static_cast(params.audio_connector_num_heads), 10000.f, 4096.f, true); - audio_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_connector_head_dim / 2, full_len * params.audio_connector_num_heads); + int64_t duplications = (target_len + config.audio_connector_num_registers - 1) / config.audio_connector_num_registers; + int64_t full_len = seq_len + duplications * 
config.audio_connector_num_registers - seq_len; + audio_connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(config.audio_connector_hidden_size), static_cast(config.audio_connector_num_heads), 10000.f, 4096.f, true); + audio_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.audio_connector_head_dim / 2, full_len * config.audio_connector_num_heads); ggml_set_name(audio_connector_pe, "ltxav_audio_connector_pe"); set_backend_tensor_data(audio_connector_pe, audio_connector_pe_vec.data()); } diff --git a/src/mmdit.hpp b/src/mmdit.hpp index 45bc2d916..326bd0797 100644 --- a/src/mmdit.hpp +++ b/src/mmdit.hpp @@ -1,7 +1,10 @@ #ifndef __MMDIT_HPP__ #define __MMDIT_HPP__ +#include #include +#include +#include #include "diffusion_model.hpp" #include "ggml_extend.hpp" @@ -9,6 +12,128 @@ #define MMDIT_GRAPH_SIZE 10240 +struct MMDiTConfig { + int64_t input_size = -1; + int patch_size = 2; + int64_t in_channels = 16; + int64_t d_self = -1; // >=0 for MMdiT-X + int64_t depth = 24; + float mlp_ratio = 4.0f; + int64_t adm_in_channels = 2048; + int64_t out_channels = 16; + int64_t pos_embed_max_size = 192; + int64_t num_patches = 36864; // 192 * 192 + int64_t context_size = 4096; + int64_t context_embedder_out_dim = 1536; + int64_t hidden_size = 1536; + std::string qk_norm; + + static MMDiTConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + MMDiTConfig config; + bool has_weight_config = false; + bool has_pos_embed = false; + bool has_hidden_size = false; + bool has_context_embed = false; + + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + + if (name.find("x_embedder.proj.weight") != std::string::npos && tensor_storage.n_dims == 4) { + has_weight_config = true; + has_hidden_size = true; + config.patch_size = static_cast(tensor_storage.ne[0]); + config.in_channels = tensor_storage.ne[2]; + config.hidden_size = 
tensor_storage.ne[3]; + } else if (name.find("t_embedder.mlp.0.weight") != std::string::npos && tensor_storage.n_dims == 2) { + has_weight_config = true; + has_hidden_size = true; + config.hidden_size = tensor_storage.ne[1]; + } else if (name.find("y_embedder.mlp.0.weight") != std::string::npos && tensor_storage.n_dims == 2) { + has_weight_config = true; + has_hidden_size = true; + config.adm_in_channels = tensor_storage.ne[0]; + config.hidden_size = tensor_storage.ne[1]; + } else if (name.find("context_embedder.weight") != std::string::npos && tensor_storage.n_dims == 2) { + has_weight_config = true; + has_context_embed = true; + config.context_size = tensor_storage.ne[0]; + config.context_embedder_out_dim = tensor_storage.ne[1]; + } else if (name.find("final_layer.linear.weight") != std::string::npos && tensor_storage.n_dims == 2) { + has_weight_config = true; + has_hidden_size = true; + config.hidden_size = tensor_storage.ne[0]; + int64_t patch_area = static_cast(config.patch_size) * config.patch_size; + if (patch_area > 0) { + config.out_channels = tensor_storage.ne[1] / patch_area; + } + } else if (name.find("pos_embed") != std::string::npos && tensor_storage.n_dims == 3) { + has_weight_config = true; + has_pos_embed = true; + has_hidden_size = true; + config.hidden_size = tensor_storage.ne[0]; + config.num_patches = tensor_storage.ne[1]; + for (int64_t size = 1; size * size <= config.num_patches; size++) { + if (size * size == config.num_patches) { + config.pos_embed_max_size = size; + break; + } + } + } + + size_t jb = name.find("joint_blocks."); + if (jb == std::string::npos) { + continue; + } + + has_weight_config = true; + std::string block_name = name.substr(jb); + int64_t block_depth = atoi(block_name.substr(13, block_name.find(".", 13)).c_str()); + if (block_depth + 1 > config.depth) { + config.depth = block_depth + 1; + } + if (block_name.find("attn.ln") != std::string::npos) { + if (block_name.find(".bias") != std::string::npos) { + config.qk_norm = 
"ln"; + } else { + config.qk_norm = "rms"; + } + } + if (block_name.find("attn2") != std::string::npos) { + if (block_depth > config.d_self) { + config.d_self = block_depth; + } + } + } + + if (!has_pos_embed && config.d_self >= 0) { + config.pos_embed_max_size *= 2; + config.num_patches *= 4; + } + if (!has_hidden_size || config.hidden_size <= 0) { + config.hidden_size = 64 * config.depth; + } + if (!has_context_embed || config.context_embedder_out_dim <= 0) { + config.context_embedder_out_dim = config.hidden_size; + } + + if (has_weight_config) { + LOG_DEBUG("mmdit: num_layers = %" PRId64 ", num_mmdit_x_layers = %" PRId64 ", hidden_size = %" PRId64 ", patch_size = %d, in_channels = %" PRId64 ", out_channels = %" PRId64 ", context_size = %" PRId64 ", adm_in_channels = %" PRId64 ", qk_norm = %s", + config.depth, + config.d_self + 1, + config.hidden_size, + config.patch_size, + config.in_channels, + config.out_channels, + config.context_size, + config.adm_in_channels, + config.qk_norm.empty() ? "none" : config.qk_norm.c_str()); + } + return config; + } +}; + struct Mlp : public GGMLBlock { public: Mlp(int64_t in_features, @@ -612,28 +737,16 @@ struct FinalLayer : public GGMLBlock { struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. 
protected: - int64_t input_size = -1; - int patch_size = 2; - int64_t in_channels = 16; - int64_t d_self = -1; // >=0 for MMdiT-X - int64_t depth = 24; - float mlp_ratio = 4.0f; - int64_t adm_in_channels = 2048; - int64_t out_channels = 16; - int64_t pos_embed_max_size = 192; - int64_t num_patchs = 36864; // 192 * 192 - int64_t context_size = 4096; - int64_t context_embedder_out_dim = 1536; - int64_t hidden_size; - std::string qk_norm; - void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { enum ggml_type wtype = GGML_TYPE_F32; - params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1); + params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, config.hidden_size, config.num_patches, 1); } public: - MMDiT(const String2TensorStorage& tensor_storage_map = {}) { + MMDiTConfig config; + + explicit MMDiT(MMDiTConfig config = {}) + : config(config) { // input_size is always None // learn_sigma is always False // register_length is alwalys 0 @@ -646,64 +759,30 @@ struct MMDiT : public GGMLBlock { // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} - for (auto pair : tensor_storage_map) { - std::string tensor_name = pair.first; - if (tensor_name.find("model.diffusion_model.") == std::string::npos) - continue; - size_t jb = tensor_name.find("joint_blocks."); - if (jb != std::string::npos) { - tensor_name = tensor_name.substr(jb); // remove prefix - int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str()); - if (block_depth + 1 > depth) { - depth = block_depth + 1; - } - if (tensor_name.find("attn.ln") != std::string::npos) { - if (tensor_name.find(".bias") != std::string::npos) { - qk_norm = "ln"; - } else { - qk_norm = "rms"; - } - } - if (tensor_name.find("attn2") != std::string::npos) { - if (block_depth > d_self) { - d_self = block_depth; - } - } - } - } - 
- if (d_self >= 0) { - pos_embed_max_size *= 2; - num_patchs *= 4; - } - - LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1); - - int64_t default_out_channels = in_channels; - hidden_size = 64 * depth; - context_embedder_out_dim = 64 * depth; - int64_t num_heads = depth; - - blocks["x_embedder"] = std::shared_ptr(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true)); - blocks["t_embedder"] = std::shared_ptr(new TimestepEmbedder(hidden_size)); + blocks["x_embedder"] = std::shared_ptr(new PatchEmbed(config.input_size, + config.patch_size, + config.in_channels, + config.hidden_size, + true)); + blocks["t_embedder"] = std::shared_ptr(new TimestepEmbedder(config.hidden_size)); - if (adm_in_channels != -1) { - blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(adm_in_channels, hidden_size)); + if (config.adm_in_channels != -1) { + blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(config.adm_in_channels, config.hidden_size)); } - blocks["context_embedder"] = std::shared_ptr(new Linear(4096, context_embedder_out_dim, true, true)); + blocks["context_embedder"] = std::shared_ptr(new Linear(config.context_size, config.context_embedder_out_dim, true, true)); - for (int i = 0; i < depth; i++) { - blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr(new JointBlock(hidden_size, - num_heads, - mlp_ratio, - qk_norm, + for (int i = 0; i < config.depth; i++) { + blocks["joint_blocks." 
+ std::to_string(i)] = std::shared_ptr(new JointBlock(config.hidden_size, + config.depth, + config.mlp_ratio, + config.qk_norm, true, - i == depth - 1, - i <= d_self)); + i == config.depth - 1, + i <= config.d_self)); } - blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels)); + blocks["final_layer"] = std::shared_ptr(new FinalLayer(config.hidden_size, config.patch_size, config.out_channels)); } ggml_tensor* @@ -712,22 +791,22 @@ struct MMDiT : public GGMLBlock { int64_t w) { auto pos_embed = params["pos_embed"]; - h = (h + 1) / patch_size; - w = (w + 1) / patch_size; + h = (h + 1) / config.patch_size; + w = (w + 1) / config.patch_size; - GGML_ASSERT(h <= pos_embed_max_size && h > 0); - GGML_ASSERT(w <= pos_embed_max_size && w > 0); + GGML_ASSERT(h <= config.pos_embed_max_size && h > 0); + GGML_ASSERT(w <= config.pos_embed_max_size && w > 0); - int64_t top = (pos_embed_max_size - h) / 2; - int64_t left = (pos_embed_max_size - w) / 2; + int64_t top = (config.pos_embed_max_size - h) / 2; + int64_t left = (config.pos_embed_max_size - w) / 2; - auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, hidden_size, pos_embed_max_size, pos_embed_max_size); + auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, config.hidden_size, config.pos_embed_max_size, config.pos_embed_max_size); // spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] spatial_pos_embed = ggml_view_3d(ctx, spatial_pos_embed, - hidden_size, - pos_embed_max_size, + config.hidden_size, + config.pos_embed_max_size, h, spatial_pos_embed->nb[1], spatial_pos_embed->nb[2], @@ -735,14 +814,14 @@ struct MMDiT : public GGMLBlock { spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [pos_embed_max_size, h, hidden_size] spatial_pos_embed = ggml_view_3d(ctx, spatial_pos_embed, - hidden_size, + config.hidden_size, h, w, spatial_pos_embed->nb[1], spatial_pos_embed->nb[2], - spatial_pos_embed->nb[2] * left); // [w, 
h, hidden_size] - spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [h, w, hidden_size] - spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, hidden_size, h * w, 1); // [1, h*w, hidden_size] + spatial_pos_embed->nb[2] * left); // [w, h, hidden_size] + spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [h, w, hidden_size] + spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, config.hidden_size, h * w, 1); // [1, h*w, hidden_size] return spatial_pos_embed; } @@ -757,7 +836,7 @@ struct MMDiT : public GGMLBlock { // return: [N, N*W, patch_size * patch_size * out_channels] auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); - for (int i = 0; i < depth; i++) { + for (int i = 0; i < config.depth; i++) { // skip iteration if i is in skip_layers if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { continue; @@ -800,7 +879,7 @@ struct MMDiT : public GGMLBlock { x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size] auto c = t_embedder->forward(ctx, t); // [N, hidden_size] - if (y != nullptr && adm_in_channels != -1) { + if (y != nullptr && config.adm_in_channels != -1) { auto y_embedder = std::dynamic_pointer_cast(blocks["y_embedder"]); y = y_embedder->forward(ctx, y); // [N, hidden_size] @@ -820,19 +899,22 @@ struct MMDiT : public GGMLBlock { x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) - x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, /*patch_last*/ false); // [N, C, H, W] + x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, config.patch_size, config.patch_size, /*patch_last*/ false); // [N, C, H, W] return x; } }; struct MMDiTRunner : public DiffusionModelRunner { + MMDiTConfig config; MMDiT mmdit; MMDiTRunner(ggml_backend_t backend, ggml_backend_t params_backend, const String2TensorStorage& 
tensor_storage_map = {}, const std::string prefix = "") - : DiffusionModelRunner(backend, params_backend, prefix), mmdit(tensor_storage_map) { + : DiffusionModelRunner(backend, params_backend, prefix), + config(MMDiTConfig::detect_from_weights(tensor_storage_map, prefix)), + mmdit(config) { mmdit.init(params_ctx, tensor_storage_map, prefix); } diff --git a/src/pid.hpp b/src/pid.hpp index c29c207f0..39ab004dc 100644 --- a/src/pid.hpp +++ b/src/pid.hpp @@ -16,7 +16,7 @@ namespace Pid { constexpr int PID_GRAPH_SIZE = 196608; constexpr float PID_PI = 3.14159265358979323846f; - struct PixelDiTParams { + struct PixelDiTConfig { int64_t in_channels = 3; int64_t hidden_size = 1536; int64_t num_groups = 24; @@ -38,6 +38,45 @@ namespace Pid { int64_t lq_latent_down_factor = 8; int64_t rope_ref_grid_h = 64; int64_t rope_ref_grid_w = 64; + + static PixelDiTConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + PixelDiTConfig config; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + size_t pos = name.find("patch_blocks."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + config.patch_depth = std::max(config.patch_depth, block_index + 1); + } + } + pos = name.find("pixel_blocks."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + config.pixel_depth = std::max(config.pixel_depth, block_index + 1); + } + } + if (name.find("lq_proj.latent_proj.0.weight") != std::string::npos) { + config.lq_latent_channels = tensor_storage.ne[2]; + config.lq_latent_down_factor = config.lq_latent_channels >= 64 ? 
16 : 8; + } + if (name.find("patch_blocks.0.mlp_x.w1.weight") != std::string::npos) { + config.patch_mlp_hidden_dim = tensor_storage.ne[1]; + } + } + LOG_DEBUG("pid: patch_depth = %" PRId64 ", pixel_depth = %" PRId64 ", patch_mlp_hidden_dim = %" PRId64 ", lq_latent_channels = %" PRId64 ", lq_latent_down_factor = %" PRId64, + config.patch_depth, + config.pixel_depth, + config.patch_mlp_hidden_dim, + config.lq_latent_channels, + config.lq_latent_down_factor); + return config; + } }; inline std::vector make_rope_1d(int length, @@ -466,29 +505,29 @@ namespace Pid { }; struct LQProjection2D : public GGMLBlock { - PixelDiTParams params_cfg; - - LQProjection2D(const PixelDiTParams& params_cfg) - : params_cfg(params_cfg) { - blocks["latent_proj.0"] = std::make_shared(params_cfg.lq_latent_channels, params_cfg.lq_hidden_dim, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1}); - blocks["latent_proj.2"] = std::make_shared(params_cfg.lq_hidden_dim, params_cfg.lq_hidden_dim, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1}); - for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) { - blocks["latent_proj." + std::to_string(3 + i)] = std::make_shared(params_cfg.lq_hidden_dim); + PixelDiTConfig config; + + LQProjection2D(const PixelDiTConfig& config) + : config(config) { + blocks["latent_proj.0"] = std::make_shared(config.lq_latent_channels, config.lq_hidden_dim, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1}); + blocks["latent_proj.2"] = std::make_shared(config.lq_hidden_dim, config.lq_hidden_dim, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1}); + for (int i = 0; i < config.lq_num_res_blocks; ++i) { + blocks["latent_proj." + std::to_string(3 + i)] = std::make_shared(config.lq_hidden_dim); } - int num_outputs = static_cast((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval); + int num_outputs = static_cast((config.patch_depth + config.lq_interval - 1) / config.lq_interval); for (int i = 0; i < num_outputs; ++i) { - blocks["output_heads." 
+ std::to_string(i)] = std::make_shared(params_cfg.lq_hidden_dim, params_cfg.hidden_size, true); - blocks["gate_modules." + std::to_string(i)] = std::make_shared(params_cfg.hidden_size); + blocks["output_heads." + std::to_string(i)] = std::make_shared(config.lq_hidden_dim, config.hidden_size, true); + blocks["gate_modules." + std::to_string(i)] = std::make_shared(config.hidden_size); } } bool is_gate_active(int block_idx) const { - return block_idx % params_cfg.lq_interval == 0; + return block_idx % config.lq_interval == 0; } int get_output_index(int block_idx) const { - return block_idx / static_cast(params_cfg.lq_interval); + return block_idx / static_cast(config.lq_interval); } ggml_tensor* gate(GGMLRunnerContext* ctx, @@ -506,8 +545,8 @@ namespace Pid { int64_t target_pW) { auto conv0 = std::dynamic_pointer_cast(blocks["latent_proj.0"]); auto conv2 = std::dynamic_pointer_cast(blocks["latent_proj.2"]); - float z_to_patch_ratio = static_cast(params_cfg.lq_sr_scale * params_cfg.lq_latent_down_factor) / - static_cast(params_cfg.patch_size); + float z_to_patch_ratio = static_cast(config.lq_sr_scale * config.lq_latent_down_factor) / + static_cast(config.patch_size); GGML_ASSERT(z_to_patch_ratio >= 1.0f); if (lq_latent->ne[0] != target_pW || lq_latent->ne[1] != target_pH) { lq_latent = ggml_interpolate(ctx->ggml_ctx, @@ -522,7 +561,7 @@ namespace Pid { auto feat = conv0->forward(ctx, lq_latent); feat = ggml_silu_inplace(ctx->ggml_ctx, feat); feat = conv2->forward(ctx, feat); - for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) { + for (int i = 0; i < config.lq_num_res_blocks; ++i) { auto block = std::dynamic_pointer_cast(blocks["latent_proj." 
+ std::to_string(3 + i)]); feat = block->forward(ctx, feat); } @@ -533,7 +572,7 @@ namespace Pid { auto tokens = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, feat, 2, 0, 1, 3)); tokens = ggml_reshape_3d(ctx->ggml_ctx, tokens, C, L, B); - int num_outputs = static_cast((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval); + int num_outputs = static_cast((config.patch_depth + config.lq_interval - 1) / config.lq_interval); std::vector outputs; outputs.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { @@ -545,34 +584,34 @@ namespace Pid { }; struct PixelDiT : public GGMLBlock { - PixelDiTParams params_cfg; + PixelDiTConfig config; PixelDiT() = default; - PixelDiT(const PixelDiTParams& params_cfg) - : params_cfg(params_cfg) { - blocks["pixel_embedder"] = std::make_shared(params_cfg.in_channels, params_cfg.pixel_hidden_size); - blocks["s_embedder"] = std::make_shared(params_cfg.in_channels * params_cfg.patch_size * params_cfg.patch_size, params_cfg.hidden_size, false, true); - blocks["t_embedder"] = std::make_shared(params_cfg.hidden_size); - blocks["y_embedder"] = std::make_shared(params_cfg.txt_embed_dim, params_cfg.hidden_size, true, true); - for (int i = 0; i < params_cfg.patch_depth; ++i) { - blocks["patch_blocks." + std::to_string(i)] = std::make_shared(params_cfg.hidden_size, params_cfg.num_groups, params_cfg.patch_mlp_hidden_dim); + PixelDiT(const PixelDiTConfig& config) + : config(config) { + blocks["pixel_embedder"] = std::make_shared(config.in_channels, config.pixel_hidden_size); + blocks["s_embedder"] = std::make_shared(config.in_channels * config.patch_size * config.patch_size, config.hidden_size, false, true); + blocks["t_embedder"] = std::make_shared(config.hidden_size); + blocks["y_embedder"] = std::make_shared(config.txt_embed_dim, config.hidden_size, true, true); + for (int i = 0; i < config.patch_depth; ++i) { + blocks["patch_blocks." 
+ std::to_string(i)] = std::make_shared(config.hidden_size, config.num_groups, config.patch_mlp_hidden_dim); } - for (int i = 0; i < params_cfg.pixel_depth; ++i) { - blocks["pixel_blocks." + std::to_string(i)] = std::make_shared(params_cfg.pixel_hidden_size, - params_cfg.hidden_size, - params_cfg.patch_size, - params_cfg.pixel_attn_hidden_size, - params_cfg.pixel_num_groups); + for (int i = 0; i < config.pixel_depth; ++i) { + blocks["pixel_blocks." + std::to_string(i)] = std::make_shared(config.pixel_hidden_size, + config.hidden_size, + config.patch_size, + config.pixel_attn_hidden_size, + config.pixel_num_groups); } - blocks["final_layer"] = std::make_shared(params_cfg.pixel_hidden_size, params_cfg.in_channels); - blocks["lq_proj"] = std::make_shared(params_cfg); + blocks["final_layer"] = std::make_shared(config.pixel_hidden_size, config.in_channels); + blocks["lq_proj"] = std::make_shared(config); } void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { - params["y_pos_embedding"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, params_cfg.hidden_size, params_cfg.txt_max_length, 1); + params["y_pos_embedding"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, config.hidden_size, config.txt_max_length, 1); } ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -594,21 +633,21 @@ namespace Pid { int64_t W_orig = x->ne[0]; int64_t H_orig = x->ne[1]; - x = DiT::pad_to_patch_size(ctx, x, static_cast(params_cfg.patch_size), static_cast(params_cfg.patch_size)); + x = DiT::pad_to_patch_size(ctx, x, static_cast(config.patch_size), static_cast(config.patch_size)); int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t B = x->ne[3]; - int64_t Hs = H / params_cfg.patch_size; - int64_t Ws = W / params_cfg.patch_size; + int64_t Hs = H / config.patch_size; + int64_t Ws = W / config.patch_size; int64_t L = Hs * Ws; - int64_t P2 = params_cfg.patch_size * params_cfg.patch_size; + int64_t P2 = config.patch_size * config.patch_size; - 
auto x_patches = DiT::patchify(ctx->ggml_ctx, x, static_cast(params_cfg.patch_size), static_cast(params_cfg.patch_size), true); + auto x_patches = DiT::patchify(ctx->ggml_ctx, x, static_cast(config.patch_size), static_cast(config.patch_size), true); auto t_emb = t_embedder->forward(ctx, timesteps); auto condition = ggml_silu(ctx->ggml_ctx, t_emb); GGML_ASSERT(context != nullptr); - int64_t Ltxt = std::min(context->ne[1], params_cfg.txt_max_length); + int64_t Ltxt = std::min(context->ne[1], config.txt_max_length); auto y = ggml_ext_slice(ctx->ggml_ctx, context, 1, 0, Ltxt); auto y_emb = y_embedder->forward(ctx, y); auto y_pos = ggml_ext_slice(ctx->ggml_ctx, params["y_pos_embedding"], 1, 0, Ltxt); @@ -618,7 +657,7 @@ namespace Pid { auto s = s_embedder->forward(ctx, x_patches); - for (int i = 0; i < params_cfg.patch_depth; ++i) { + for (int i = 0; i < config.patch_depth; ++i) { if (lq_proj->is_gate_active(i)) { int out_idx = lq_proj->get_output_index(i); if (out_idx < static_cast(lq_features.size())) { @@ -639,22 +678,22 @@ namespace Pid { } s = ggml_silu(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, s, t_emb)); - auto s_cond = ggml_reshape_2d(ctx->ggml_ctx, s, params_cfg.hidden_size, L * B); - auto pixels = pixel_embedder->forward(ctx, x, params_cfg.patch_size, pixel_pos_full); - for (int i = 0; i < params_cfg.pixel_depth; ++i) { + auto s_cond = ggml_reshape_2d(ctx->ggml_ctx, s, config.hidden_size, L * B); + auto pixels = pixel_embedder->forward(ctx, x, config.patch_size, pixel_pos_full); + for (int i = 0; i < config.pixel_depth; ++i) { auto block = std::dynamic_pointer_cast(blocks["pixel_blocks." + std::to_string(i)]); pixels = block->forward(ctx, pixels, s_cond, H, W, pixel_pos_comp); sd::ggml_graph_cut::mark_graph_cut(pixels, "pid.pixel_blocks." 
+ std::to_string(i), "pixels"); } pixels = final_layer->forward(ctx, pixels); - pixels = ggml_reshape_3d(ctx->ggml_ctx, pixels, params_cfg.in_channels * P2, L, B); + pixels = ggml_reshape_3d(ctx->ggml_ctx, pixels, config.in_channels * P2, L, B); auto out = DiT::unpatchify(ctx->ggml_ctx, pixels, Hs, Ws, - static_cast(params_cfg.patch_size), - static_cast(params_cfg.patch_size), + static_cast(config.patch_size), + static_cast(config.patch_size), false); out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H_orig); out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W_orig); @@ -663,7 +702,7 @@ namespace Pid { }; struct PiDRunner : public DiffusionModelRunner { - PixelDiTParams params_cfg; + PixelDiTConfig config; PixelDiT model; std::vector pos_img_vec; std::vector pos_txt_vec; @@ -674,43 +713,9 @@ namespace Pid { ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map, const std::string prefix = "model.diffusion_model") - : DiffusionModelRunner(backend, params_backend, prefix) { - for (const auto& pair : tensor_storage_map) { - const std::string& tensor_name = pair.first; - if (tensor_name.find(prefix) == std::string::npos) { - continue; - } - size_t pos = tensor_name.find("patch_blocks."); - if (pos != std::string::npos) { - auto items = split_string(tensor_name.substr(pos), '.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - params_cfg.patch_depth = std::max(params_cfg.patch_depth, block_index + 1); - } - } - pos = tensor_name.find("pixel_blocks."); - if (pos != std::string::npos) { - auto items = split_string(tensor_name.substr(pos), '.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - params_cfg.pixel_depth = std::max(params_cfg.pixel_depth, block_index + 1); - } - } - if (tensor_name.find("lq_proj.latent_proj.0.weight") != std::string::npos) { - params_cfg.lq_latent_channels = pair.second.ne[2]; - params_cfg.lq_latent_down_factor = params_cfg.lq_latent_channels >= 64 ? 
16 : 8; - } - if (tensor_name.find("patch_blocks.0.mlp_x.w1.weight") != std::string::npos) { - params_cfg.patch_mlp_hidden_dim = pair.second.ne[1]; - } - } - LOG_INFO("PiD params: patch_depth=%" PRId64 ", pixel_depth=%" PRId64 ", patch_mlp_hidden_dim=%" PRId64 ", lq_latent_channels=%" PRId64 ", lq_latent_down_factor=%" PRId64, - params_cfg.patch_depth, - params_cfg.pixel_depth, - params_cfg.patch_mlp_hidden_dim, - params_cfg.lq_latent_channels, - params_cfg.lq_latent_down_factor); - model = PixelDiT(params_cfg); + : DiffusionModelRunner(backend, params_backend, prefix), + config(PixelDiTConfig::detect_from_weights(tensor_storage_map, prefix)) { + model = PixelDiT(config); model.init(params_ctx, tensor_storage_map, prefix); } @@ -737,60 +742,60 @@ namespace Pid { int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t B = x->ne[3]; - int64_t Wp = align_up(static_cast(W), static_cast(params_cfg.patch_size)); - int64_t Hp = align_up(static_cast(H), static_cast(params_cfg.patch_size)); - int64_t Hs = Hp / params_cfg.patch_size; - int64_t Ws = Wp / params_cfg.patch_size; + int64_t Wp = align_up(static_cast(W), static_cast(config.patch_size)); + int64_t Hp = align_up(static_cast(H), static_cast(config.patch_size)); + int64_t Hs = Hp / config.patch_size; + int64_t Ws = Wp / config.patch_size; pos_img_vec = make_rope_2d(static_cast(Hs), static_cast(Ws), - static_cast(params_cfg.hidden_size / params_cfg.num_groups), + static_cast(config.hidden_size / config.num_groups), 10000.f, 16.f, - static_cast(params_cfg.rope_ref_grid_h), - static_cast(params_cfg.rope_ref_grid_w)); + static_cast(config.rope_ref_grid_h), + static_cast(config.rope_ref_grid_w)); auto pos_img = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, - params_cfg.hidden_size / params_cfg.num_groups / 2, + config.hidden_size / config.num_groups / 2, Hs * Ws); set_backend_tensor_data(pos_img, pos_img_vec.data()); - int64_t Ltxt = std::min(context->ne[1], params_cfg.txt_max_length); + int64_t Ltxt = 
std::min(context->ne[1], config.txt_max_length); pos_txt_vec = make_rope_1d(static_cast(Ltxt), - static_cast(params_cfg.hidden_size / params_cfg.num_groups), - params_cfg.text_rope_theta); + static_cast(config.hidden_size / config.num_groups), + config.text_rope_theta); auto pos_txt = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, - params_cfg.hidden_size / params_cfg.num_groups / 2, + config.hidden_size / config.num_groups / 2, Ltxt); set_backend_tensor_data(pos_txt, pos_txt_vec.data()); pixel_pos_vec = make_pixel_abs_pos(static_cast(Hp), static_cast(Wp), - static_cast(params_cfg.pixel_hidden_size)); + static_cast(config.pixel_hidden_size)); auto pixel_pos = ggml_new_tensor_3d(compute_ctx, GGML_TYPE_F32, - params_cfg.pixel_hidden_size, + config.pixel_hidden_size, Wp * Hp, 1); set_backend_tensor_data(pixel_pos, pixel_pos_vec.data()); pixel_pos_comp_vec = make_rope_2d(static_cast(Hs), static_cast(Ws), - static_cast(params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups), + static_cast(config.pixel_attn_hidden_size / config.pixel_num_groups), 10000.f, 16.f, - static_cast(params_cfg.rope_ref_grid_h), - static_cast(params_cfg.rope_ref_grid_w)); + static_cast(config.rope_ref_grid_h), + static_cast(config.rope_ref_grid_w)); auto pixel_pos_comp = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, - params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups / 2, + config.pixel_attn_hidden_size / config.pixel_num_groups / 2, Hs * Ws); set_backend_tensor_data(pixel_pos_comp, pixel_pos_comp_vec.data()); diff --git a/src/qwen_image.hpp b/src/qwen_image.hpp index bea71b97f..de52c880b 100644 --- a/src/qwen_image.hpp +++ b/src/qwen_image.hpp @@ -10,6 +10,48 @@ namespace Qwen { constexpr int QWEN_IMAGE_GRAPH_SIZE = 20480; + struct QwenImageConfig { + int patch_size = 2; + int64_t in_channels = 64; + int64_t out_channels = 16; + int num_layers = 60; + int64_t attention_head_dim = 128; + int64_t num_attention_heads = 24; + int64_t joint_attention_dim = 3584; + 
int theta = 10000; + std::vector axes_dim = {16, 56, 56}; + int axes_dim_sum = 128; + bool zero_cond_t = false; + + static QwenImageConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + QwenImageConfig config; + config.num_layers = 0; + for (const auto& [name, _] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (name.find("__index_timestep_zero__") != std::string::npos) { + config.zero_cond_t = true; + } + size_t pos = name.find("transformer_blocks."); + if (pos == std::string::npos) { + continue; + } + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > config.num_layers) { + config.num_layers = block_index + 1; + } + } + } + LOG_DEBUG("qwen_image: num_layers = %d, zero_cond_t = %s", + config.num_layers, + config.zero_cond_t ? "true" : "false"); + return config; + } + }; + struct TimestepEmbedding : public GGMLBlock { public: TimestepEmbedding(int64_t in_channels, @@ -350,46 +392,32 @@ namespace Qwen { } }; - struct QwenImageParams { - int patch_size = 2; - int64_t in_channels = 64; - int64_t out_channels = 16; - int num_layers = 60; - int64_t attention_head_dim = 128; - int64_t num_attention_heads = 24; - int64_t joint_attention_dim = 3584; - int theta = 10000; - std::vector axes_dim = {16, 56, 56}; - int axes_dim_sum = 128; - bool zero_cond_t = false; - }; - class QwenImageModel : public GGMLBlock { protected: - QwenImageParams params; + QwenImageConfig config; public: QwenImageModel() {} - QwenImageModel(QwenImageParams params) - : params(params) { - int64_t inner_dim = params.num_attention_heads * params.attention_head_dim; + QwenImageModel(QwenImageConfig config) + : config(config) { + int64_t inner_dim = config.num_attention_heads * config.attention_head_dim; blocks["time_text_embed"] = std::shared_ptr(new QwenTimestepProjEmbeddings(inner_dim)); - blocks["txt_norm"] = 
std::shared_ptr(new RMSNorm(params.joint_attention_dim, 1e-6f)); - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, inner_dim)); - blocks["txt_in"] = std::shared_ptr(new Linear(params.joint_attention_dim, inner_dim)); + blocks["txt_norm"] = std::shared_ptr(new RMSNorm(config.joint_attention_dim, 1e-6f)); + blocks["img_in"] = std::shared_ptr(new Linear(config.in_channels, inner_dim)); + blocks["txt_in"] = std::shared_ptr(new Linear(config.joint_attention_dim, inner_dim)); // blocks - for (int i = 0; i < params.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = std::shared_ptr(new QwenImageTransformerBlock(inner_dim, - params.num_attention_heads, - params.attention_head_dim, + config.num_attention_heads, + config.attention_head_dim, 1e-6f, - params.zero_cond_t)); + config.zero_cond_t)); blocks["transformer_blocks." + std::to_string(i)] = block; } blocks["norm_out"] = std::shared_ptr(new AdaLayerNormContinuous(inner_dim, inner_dim, false, 1e-6f)); - blocks["proj_out"] = std::shared_ptr(new Linear(inner_dim, params.patch_size * params.patch_size * params.out_channels)); + blocks["proj_out"] = std::shared_ptr(new Linear(inner_dim, config.patch_size * config.patch_size * config.out_channels)); } ggml_tensor* forward_orig(GGMLRunnerContext* ctx, @@ -406,7 +434,7 @@ namespace Qwen { auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); auto t_emb = time_text_embed->forward(ctx, timestep); - if (params.zero_cond_t) { + if (config.zero_cond_t) { auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros_like(ctx->ggml_ctx, timestep)); t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1); } @@ -417,7 +445,7 @@ namespace Qwen { sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.prelude", "txt"); // sd::ggml_graph_cut::mark_graph_cut(t_emb, "qwen_image.prelude", "t_emb"); - for (int i = 0; i < params.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = 
std::dynamic_pointer_cast(blocks["transformer_blocks." + std::to_string(i)]); auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index); @@ -427,7 +455,7 @@ namespace Qwen { sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.transformer_blocks." + std::to_string(i), "txt"); } - if (params.zero_cond_t) { + if (config.zero_cond_t) { t_emb = ggml_ext_chunk(ctx->ggml_ctx, t_emb, 2, 1)[0]; } @@ -456,12 +484,12 @@ namespace Qwen { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = DiT::pad_and_patchify(ctx, x, params.patch_size, params.patch_size); + auto img = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size); int64_t img_tokens = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = DiT::pad_and_patchify(ctx, ref, params.patch_size, params.patch_size); + ref = DiT::pad_and_patchify(ctx, ref, config.patch_size, config.patch_size); img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } @@ -474,7 +502,7 @@ namespace Qwen { out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] } - out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, params.patch_size, params.patch_size); // [N, C, H, W] + out = DiT::unpatchify_and_crop(ctx->ggml_ctx, out, H, W, config.patch_size, config.patch_size); // [N, C, H, W] return out; } @@ -482,7 +510,7 @@ namespace Qwen { struct QwenImageRunner : public DiffusionModelRunner { public: - QwenImageParams qwen_image_params; + QwenImageConfig config; QwenImageModel qwen_image; std::vector pe_vec; std::vector modulate_index_vec; @@ -494,34 +522,10 @@ namespace Qwen { const std::string prefix = "", SDVersion version = VERSION_QWEN_IMAGE, bool zero_cond_t = false) - : DiffusionModelRunner(backend, params_backend, prefix) { - qwen_image_params.num_layers = 0; - qwen_image_params.zero_cond_t = zero_cond_t; - for (auto pair : tensor_storage_map) { - std::string tensor_name = pair.first; - if (tensor_name.find(prefix) == 
std::string::npos) - continue; - if (tensor_name.find("__index_timestep_zero__") != std::string::npos) { - qwen_image_params.zero_cond_t = true; - } - size_t pos = tensor_name.find("transformer_blocks."); - if (pos != std::string::npos) { - tensor_name = tensor_name.substr(pos); // remove prefix - auto items = split_string(tensor_name, '.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - if (block_index + 1 > qwen_image_params.num_layers) { - qwen_image_params.num_layers = block_index + 1; - } - } - continue; - } - } - LOG_INFO("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers); - if (qwen_image_params.zero_cond_t) { - LOG_INFO("use zero_cond_t"); - } - qwen_image = QwenImageModel(qwen_image_params); + : DiffusionModelRunner(backend, params_backend, prefix), + config(QwenImageConfig::detect_from_weights(tensor_storage_map, prefix)) { + config.zero_cond_t = config.zero_cond_t || zero_cond_t; + qwen_image = QwenImageModel(config); qwen_image.init(params_ctx, tensor_storage_map, prefix); } @@ -552,36 +556,36 @@ namespace Qwen { pe_vec = Rope::gen_qwen_image_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), - qwen_image_params.patch_size, + config.patch_size, static_cast(x->ne[3]), static_cast(context->ne[1]), ref_latents, increase_ref_index, - qwen_image_params.theta, + config.theta, circular_y_enabled, circular_x_enabled, - qwen_image_params.axes_dim); - int pos_len = static_cast(pe_vec.size() / qwen_image_params.axes_dim_sum / 2); + config.axes_dim); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe, true, "pe"); // pe->data = nullptr; set_backend_tensor_data(pe, pe_vec.data()); ggml_tensor* modulate_index = 
nullptr; - if (qwen_image_params.zero_cond_t) { + if (config.zero_cond_t) { modulate_index_vec.clear(); - int64_t h_len = ((x->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); - int64_t w_len = ((x->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); + int64_t h_len = ((x->ne[1] + (config.patch_size / 2)) / config.patch_size); + int64_t w_len = ((x->ne[0] + (config.patch_size / 2)) / config.patch_size); int64_t num_img_tokens = h_len * w_len; modulate_index_vec.insert(modulate_index_vec.end(), num_img_tokens, 0.f); int64_t num_ref_img_tokens = 0; for (ggml_tensor* ref : ref_latents) { - int64_t h_len = ((ref->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); - int64_t w_len = ((ref->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); + int64_t h_len = ((ref->ne[1] + (config.patch_size / 2)) / config.patch_size); + int64_t w_len = ((ref->ne[0] + (config.patch_size / 2)) / config.patch_size); num_ref_img_tokens += h_len * w_len; } diff --git a/src/t5.hpp b/src/t5.hpp index 9b2bdaef1..c6fa5375b 100644 --- a/src/t5.hpp +++ b/src/t5.hpp @@ -14,6 +14,28 @@ #include "model.h" #include "tokenizers/t5_unigram_tokenizer.h" +struct T5Config { + int64_t num_layers = 24; + int64_t model_dim = 4096; + int64_t ff_dim = 10240; + int64_t num_heads = 64; + int64_t vocab_size = 32128; + bool relative_attention = true; + + static T5Config detect_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + bool is_umt5 = false) { + (void)tensor_storage_map; + (void)prefix; + T5Config config; + if (is_umt5) { + config.vocab_size = 256384; + config.relative_attention = false; + } + return config; + } +}; + class T5LayerNorm : public UnaryBlock { protected: int64_t hidden_size; @@ -272,30 +294,21 @@ struct T5Stack : public GGMLBlock { } }; -struct T5Params { - int64_t num_layers = 24; - int64_t model_dim = 4096; - int64_t ff_dim = 10240; - int64_t 
num_heads = 64; - int64_t vocab_size = 32128; - bool relative_attention = true; -}; - struct T5 : public GGMLBlock { - T5Params params; + T5Config config; public: T5() {} - T5(T5Params params) - : params(params) { - blocks["encoder"] = std::shared_ptr(new T5Stack(params.num_layers, - params.model_dim, - params.model_dim, - params.ff_dim, - params.num_heads, - params.relative_attention)); - blocks["shared"] = std::shared_ptr(new Embedding(params.vocab_size, - params.model_dim)); + T5(T5Config config) + : config(config) { + blocks["encoder"] = std::shared_ptr(new T5Stack(config.num_layers, + config.model_dim, + config.model_dim, + config.ff_dim, + config.num_heads, + config.relative_attention)); + blocks["shared"] = std::shared_ptr(new Embedding(config.vocab_size, + config.model_dim)); } ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -316,7 +329,7 @@ struct T5 : public GGMLBlock { }; struct T5Runner : public GGMLRunner { - T5Params params; + T5Config config; T5 model; std::vector relative_position_bucket_vec; @@ -325,12 +338,9 @@ struct T5Runner : public GGMLRunner { const String2TensorStorage& tensor_storage_map, const std::string prefix, bool is_umt5 = false) - : GGMLRunner(backend, params_backend) { - if (is_umt5) { - params.vocab_size = 256384; - params.relative_attention = false; - } - model = T5(params); + : GGMLRunner(backend, params_backend), + config(T5Config::detect_from_weights(tensor_storage_map, prefix, is_umt5)) { + model = T5(config); model.init(params_ctx, tensor_storage_map, prefix); } diff --git a/src/unet.hpp b/src/unet.hpp index ef468741d..4cecb87c7 100644 --- a/src/unet.hpp +++ b/src/unet.hpp @@ -1,6 +1,9 @@ #ifndef __UNET_HPP__ #define __UNET_HPP__ +#include +#include + #include "common_block.hpp" #include "diffusion_model.hpp" #include "model.h" @@ -9,6 +12,125 @@ #define UNET_GRAPH_SIZE 102400 +struct UNetConfig { + SDVersion version = VERSION_SD1; + // network hparams + int in_channels = 4; + int out_channels = 4; + int num_res_blocks = 2; + 
std::vector attention_resolutions = {4, 2, 1}; + std::vector channel_mult = {1, 2, 4, 4}; + std::vector transformer_depth = {1, 1, 1, 1}; + int time_embed_dim = 1280; // model_channels*4 + int num_heads = 8; + int num_head_channels = -1; // channels // num_heads + int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL + bool use_linear_projection = false; + bool tiny_unet = false; + int model_channels = 320; + int adm_in_channels = 2816; // only for VERSION_SDXL/SVD + + static UNetConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + SDVersion version = VERSION_SD1) { + UNetConfig config; + config.version = version; + + if (sd_version_is_sd2(version)) { + config.context_dim = 1024; + config.num_head_channels = 64; + config.num_heads = -1; + config.use_linear_projection = true; + } else if (sd_version_is_sdxl(version)) { + config.context_dim = 2048; + config.attention_resolutions = {4, 2}; + config.channel_mult = {1, 2, 4}; + config.transformer_depth = {1, 2, 10}; + config.num_head_channels = 64; + config.num_heads = -1; + config.use_linear_projection = true; + if (version == VERSION_SDXL_VEGA) { + config.transformer_depth = {1, 1, 2}; + } + } else if (version == VERSION_SVD) { + config.in_channels = 8; + config.out_channels = 4; + config.context_dim = 1024; + config.adm_in_channels = 768; + config.num_head_channels = 64; + config.num_heads = -1; + config.use_linear_projection = true; + } + if (sd_version_is_inpaint(version)) { + config.in_channels = 9; + } else if (sd_version_is_unet_edit(version)) { + config.in_channels = 8; + } + if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { + config.num_res_blocks = 1; + config.channel_mult = {1, 2, 4}; + config.tiny_unet = true; + if (version == VERSION_SDXS_512_DS) { + config.attention_resolutions = {4, 2}; // here just like SDXL + } + } + + auto find_weight = 
[&](const std::string& suffix) -> const TensorStorage* { + std::string name = prefix.empty() ? suffix : prefix + "." + suffix; + auto it = tensor_storage_map.find(name); + if (it == tensor_storage_map.end()) { + return nullptr; + } + return &it->second; + }; + + if (const TensorStorage* input = find_weight("input_blocks.0.0.weight")) { + if (input->n_dims == 4) { + config.in_channels = static_cast(input->ne[2]); + config.model_channels = static_cast(input->ne[3]); + config.time_embed_dim = config.model_channels * 4; + } + } + if (const TensorStorage* time_embed = find_weight("time_embed.0.weight")) { + if (time_embed->n_dims == 2) { + config.model_channels = static_cast(time_embed->ne[0]); + config.time_embed_dim = static_cast(time_embed->ne[1]); + } + } + if (const TensorStorage* label_emb = find_weight("label_emb.0.0.weight")) { + if (label_emb->n_dims == 2) { + config.adm_in_channels = static_cast(label_emb->ne[0]); + config.time_embed_dim = static_cast(label_emb->ne[1]); + } + } + if (const TensorStorage* out = find_weight("out.2.weight")) { + if (out->n_dims == 4) { + config.out_channels = static_cast(out->ne[3]); + } + } + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (name.find("attn2.to_k.weight") != std::string::npos && tensor_storage.n_dims == 2) { + config.context_dim = static_cast(tensor_storage.ne[0]); + break; + } + } + + LOG_DEBUG("unet: in_channels = %d, out_channels = %d, model_channels = %d, time_embed_dim = %d, context_dim = %d, adm_in_channels = %d, num_res_blocks = %d, tiny_unet = %s", + config.in_channels, + config.out_channels, + config.model_channels, + config.time_embed_dim, + config.context_dim, + config.adm_in_channels, + config.num_res_blocks, + config.tiny_unet ? 
"true" : "false"); + return config; + } +}; + class SpatialVideoTransformer : public SpatialTransformer { protected: int64_t time_depth; @@ -166,66 +288,26 @@ class SpatialVideoTransformer : public SpatialTransformer { // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { -protected: - SDVersion version = VERSION_SD1; - // network hparams - int in_channels = 4; - int out_channels = 4; - int num_res_blocks = 2; - std::vector attention_resolutions = {4, 2, 1}; - std::vector channel_mult = {1, 2, 4, 4}; - std::vector transformer_depth = {1, 1, 1, 1}; - int time_embed_dim = 1280; // model_channels*4 - int num_heads = 8; - int num_head_channels = -1; // channels // num_heads - int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL - bool use_linear_projection = false; - bool tiny_unet = false; - public: - int model_channels = 320; - int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - - UnetModelBlock(SDVersion version = VERSION_SD1, const String2TensorStorage& tensor_storage_map = {}) - : version(version) { - if (sd_version_is_sd2(version)) { - context_dim = 1024; - num_head_channels = 64; - num_heads = -1; - use_linear_projection = true; - } else if (sd_version_is_sdxl(version)) { - context_dim = 2048; - attention_resolutions = {4, 2}; - channel_mult = {1, 2, 4}; - transformer_depth = {1, 2, 10}; - num_head_channels = 64; - num_heads = -1; - use_linear_projection = true; - if (version == VERSION_SDXL_VEGA) { - transformer_depth = {1, 1, 2}; - } - } else if (version == VERSION_SVD) { - in_channels = 8; - out_channels = 4; - context_dim = 1024; - adm_in_channels = 768; - num_head_channels = 64; - num_heads = -1; - use_linear_projection = true; - } - if (sd_version_is_inpaint(version)) { - in_channels = 9; - } else if (sd_version_is_unet_edit(version)) { - in_channels = 8; - } - if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == 
VERSION_SDXS_09) { - num_res_blocks = 1; - channel_mult = {1, 2, 4}; - tiny_unet = true; - if (version == VERSION_SDXS_512_DS) { - attention_resolutions = {4, 2}; // here just like SDXL - } - } + UNetConfig config; + + explicit UnetModelBlock(UNetConfig config = {}) + : config(config) { + const SDVersion version = this->config.version; + const int in_channels = this->config.in_channels; + const int out_channels = this->config.out_channels; + const int num_res_blocks = this->config.num_res_blocks; + const auto& attention_resolutions = this->config.attention_resolutions; + const auto& channel_mult = this->config.channel_mult; + const auto& transformer_depth = this->config.transformer_depth; + const int time_embed_dim = this->config.time_embed_dim; + const int num_heads = this->config.num_heads; + const int num_head_channels = this->config.num_head_channels; + const int context_dim = this->config.context_dim; + const bool use_linear_projection = this->config.use_linear_projection; + const bool tiny_unet = this->config.tiny_unet; + const int model_channels = this->config.model_channels; + const int adm_in_channels = this->config.adm_in_channels; // dims is always 2 // use_temporal_attention is always True for SVD @@ -398,7 +480,7 @@ class UnetModelBlock : public GGMLBlock { ggml_tensor* x, ggml_tensor* emb, int num_video_frames) { - if (version == VERSION_SVD) { + if (config.version == VERSION_SVD) { auto block = std::dynamic_pointer_cast(blocks[name]); return block->forward(ctx, x, emb, num_video_frames); @@ -414,7 +496,7 @@ class UnetModelBlock : public GGMLBlock { ggml_tensor* x, ggml_tensor* context, int timesteps) { - if (version == VERSION_SVD) { + if (config.version == VERSION_SVD) { auto block = std::dynamic_pointer_cast(blocks[name]); return block->forward(ctx, x, context, timesteps); @@ -440,6 +522,13 @@ class UnetModelBlock : public GGMLBlock { // c_concat: [N, in_channels, h, w] or [1, in_channels, h, w] // y: [N, adm_in_channels] or [1, adm_in_channels] // 
return: [N, out_channels, h, w] + const SDVersion version = config.version; + const int model_channels = config.model_channels; + const int num_res_blocks = config.num_res_blocks; + const auto& attention_resolutions = config.attention_resolutions; + const auto& channel_mult = config.channel_mult; + const bool tiny_unet = config.tiny_unet; + if (context != nullptr) { if (context->ne[2] != x->ne[3]) { context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3])); @@ -601,6 +690,7 @@ class UnetModelBlock : public GGMLBlock { }; struct UNetModelRunner : public DiffusionModelRunner { + UNetConfig config; UnetModelBlock unet; UNetModelRunner(ggml_backend_t backend, @@ -608,7 +698,9 @@ struct UNetModelRunner : public DiffusionModelRunner { const String2TensorStorage& tensor_storage_map, const std::string prefix, SDVersion version = VERSION_SD1) - : DiffusionModelRunner(backend, params_backend, prefix), unet(version, tensor_storage_map) { + : DiffusionModelRunner(backend, params_backend, prefix), + config(UNetConfig::detect_from_weights(tensor_storage_map, prefix, version)), + unet(config) { unet.init(params_ctx, tensor_storage_map, prefix); } diff --git a/src/wan.hpp b/src/wan.hpp index 68f020e25..bda635066 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -16,6 +16,77 @@ namespace WAN { constexpr int CACHE_T = 2; constexpr int WAN_GRAPH_SIZE = 10240; + struct WanConfig { + std::string model_type = "t2v"; + std::tuple patch_size = {1, 2, 2}; + int64_t text_len = 512; + int64_t in_dim = 16; + int64_t dim = 2048; + int64_t ffn_dim = 8192; + int freq_dim = 256; + int64_t text_dim = 4096; + int64_t out_dim = 16; + int64_t num_heads = 16; + int num_layers = 32; + int vace_layers = 0; + int64_t vace_in_dim = 96; + std::map vace_layers_mapping = {}; + bool qk_norm = true; + bool cross_attn_norm = true; + float eps = 1e-6f; + int64_t flf_pos_embed_token_number = 0; + int theta = 10000; + // wan2.1 1.3B: 
1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24 + std::vector axes_dim = {44, 42, 42}; + int64_t axes_dim_sum = 128; + + static WanConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + WanConfig config; + config.num_layers = 0; + for (const auto& [name, _] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + size_t pos = name.find("vace_blocks."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > config.vace_layers) { + config.vace_layers = block_index + 1; + } + } + continue; + } + pos = name.find("blocks."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > config.num_layers) { + config.num_layers = block_index + 1; + } + } + continue; + } + if (name.find("img_emb") != std::string::npos) { + config.model_type = "i2v"; + } + if (name.find("img_emb.emb_pos") != std::string::npos) { + config.flf_pos_embed_token_number = 514; + } + } + LOG_DEBUG("wan: model_type = %s, num_layers = %d, vace_layers = %d, dim = %" PRId64 ", ffn_dim = %" PRId64 ", num_heads = %" PRId64, + config.model_type.c_str(), + config.num_layers, + config.vace_layers, + config.dim, + config.ffn_dim, + config.num_heads); + return config; + } + }; + class CausalConv3d : public GGMLBlock { protected: int64_t in_channels; @@ -1799,97 +1870,72 @@ namespace WAN { } }; - struct WanParams { - std::string model_type = "t2v"; - std::tuple patch_size = {1, 2, 2}; - int64_t text_len = 512; - int64_t in_dim = 16; - int64_t dim = 2048; - int64_t ffn_dim = 8192; - int freq_dim = 256; - int64_t text_dim = 4096; - int64_t out_dim = 16; - int64_t num_heads = 16; - int num_layers = 32; - int vace_layers = 0; - int64_t vace_in_dim = 96; - std::map vace_layers_mapping = {}; - bool 
qk_norm = true; - bool cross_attn_norm = true; - float eps = 1e-6f; - int64_t flf_pos_embed_token_number = 0; - int theta = 10000; - // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24 - std::vector axes_dim = {44, 42, 42}; - int64_t axes_dim_sum = 128; - }; - class Wan : public GGMLBlock { protected: - WanParams params; + WanConfig config; public: Wan() {} - Wan(WanParams params) - : params(params) { + Wan(WanConfig config) + : config(config) { // patch_embedding - blocks["patch_embedding"] = std::shared_ptr(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size)); + blocks["patch_embedding"] = std::shared_ptr(new Conv3d(config.in_dim, config.dim, config.patch_size, config.patch_size)); // text_embedding - blocks["text_embedding.0"] = std::shared_ptr(new Linear(params.text_dim, params.dim)); + blocks["text_embedding.0"] = std::shared_ptr(new Linear(config.text_dim, config.dim)); // text_embedding.1 is nn.GELU() - blocks["text_embedding.2"] = std::shared_ptr(new Linear(params.dim, params.dim)); + blocks["text_embedding.2"] = std::shared_ptr(new Linear(config.dim, config.dim)); // time_embedding - blocks["time_embedding.0"] = std::shared_ptr(new Linear(params.freq_dim, params.dim)); + blocks["time_embedding.0"] = std::shared_ptr(new Linear(config.freq_dim, config.dim)); // time_embedding.1 is nn.SiLU() - blocks["time_embedding.2"] = std::shared_ptr(new Linear(params.dim, params.dim)); + blocks["time_embedding.2"] = std::shared_ptr(new Linear(config.dim, config.dim)); // time_projection.0 is nn.SiLU() - blocks["time_projection.1"] = std::shared_ptr(new Linear(params.dim, params.dim * 6)); + blocks["time_projection.1"] = std::shared_ptr(new Linear(config.dim, config.dim * 6)); // blocks - for (int i = 0; i < params.num_layers; i++) { - auto block = std::shared_ptr(new WanAttentionBlock(params.model_type == "t2v", - params.dim, - params.ffn_dim, - params.num_heads, - params.qk_norm, - params.cross_attn_norm, - params.eps)); + for 
(int i = 0; i < config.num_layers; i++) { + auto block = std::shared_ptr(new WanAttentionBlock(config.model_type == "t2v", + config.dim, + config.ffn_dim, + config.num_heads, + config.qk_norm, + config.cross_attn_norm, + config.eps)); blocks["blocks." + std::to_string(i)] = block; } // head - blocks["head"] = std::shared_ptr(new Head(params.dim, params.out_dim, params.patch_size, params.eps)); + blocks["head"] = std::shared_ptr(new Head(config.dim, config.out_dim, config.patch_size, config.eps)); // img_emb - if (params.model_type == "i2v") { - blocks["img_emb"] = std::shared_ptr(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number)); + if (config.model_type == "i2v") { + blocks["img_emb"] = std::shared_ptr(new MLPProj(1280, config.dim, config.flf_pos_embed_token_number)); } // vace - if (params.vace_layers > 0) { - for (int i = 0; i < params.vace_layers; i++) { - auto block = std::shared_ptr(new VaceWanAttentionBlock(params.model_type == "t2v", - params.dim, - params.ffn_dim, - params.num_heads, - params.qk_norm, - params.cross_attn_norm, - params.eps, + if (config.vace_layers > 0) { + for (int i = 0; i < config.vace_layers; i++) { + auto block = std::shared_ptr(new VaceWanAttentionBlock(config.model_type == "t2v", + config.dim, + config.ffn_dim, + config.num_heads, + config.qk_norm, + config.cross_attn_norm, + config.eps, i)); blocks["vace_blocks." 
+ std::to_string(i)] = block; } - int step = params.num_layers / params.vace_layers; + int step = config.num_layers / config.vace_layers; int n = 0; - for (int i = 0; i < params.num_layers; i += step) { - this->params.vace_layers_mapping[i] = n; + for (int i = 0; i < config.num_layers; i += step) { + this->config.vace_layers_mapping[i] = n; n++; } - blocks["vace_patch_embedding"] = std::shared_ptr(new Conv3d(params.vace_in_dim, params.dim, params.patch_size, params.patch_size)); + blocks["vace_patch_embedding"] = std::shared_ptr(new Conv3d(config.vace_in_dim, config.dim, config.patch_size, config.patch_size)); } } @@ -1899,9 +1945,9 @@ namespace WAN { int64_t H = x->ne[1]; int64_t T = x->ne[2]; - int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size); - int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size); - int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size); + int pad_t = (std::get<0>(config.patch_size) - T % std::get<0>(config.patch_size)) % std::get<0>(config.patch_size); + int pad_h = (std::get<1>(config.patch_size) - H % std::get<1>(config.patch_size)) % std::get<1>(config.patch_size); + int pad_w = (std::get<2>(config.patch_size) - W % std::get<2>(config.patch_size)) % std::get<2>(config.patch_size); ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, pad_t, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); return x; } @@ -1914,9 +1960,9 @@ namespace WAN { // x: [N, t_len*h_len*w_len, pt*ph*pw*C] // return: [N*C, t_len*pt, h_len*ph, w_len*pw] int64_t N = x->ne[3]; - int64_t pt = std::get<0>(params.patch_size); - int64_t ph = std::get<1>(params.patch_size); - int64_t pw = std::get<2>(params.patch_size); + int64_t pt = std::get<0>(config.patch_size); + int64_t ph = std::get<1>(config.patch_size); + int64_t pw = std::get<2>(config.patch_size); int64_t C = x->ne[0] / pt / ph / pw; 
GGML_ASSERT(C * pt * ph * pw == x->ne[0]); @@ -1967,7 +2013,7 @@ namespace WAN { x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] // time_embedding - auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, params.freq_dim); + auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, config.freq_dim); e = time_embedding_0->forward(ctx, e); e = ggml_silu_inplace(ctx->ggml_ctx, e); e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim] @@ -1983,7 +2029,7 @@ namespace WAN { int64_t context_img_len = 0; if (clip_fea != nullptr) { - if (params.model_type == "i2v") { + if (config.model_type == "i2v") { auto img_emb = std::dynamic_pointer_cast(blocks["img_emb"]); auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim] context = ggml_concat(ctx->ggml_ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim] @@ -1993,7 +2039,7 @@ namespace WAN { // vace_patch_embedding ggml_tensor* c = nullptr; - if (params.vace_layers > 0) { + if (config.vace_layers > 0) { auto vace_patch_embedding = std::dynamic_pointer_cast(blocks["vace_patch_embedding"]); c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len] @@ -2010,13 +2056,13 @@ namespace WAN { auto x_orig = x; - for (int i = 0; i < params.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); x = block->forward(ctx, x, e0, pe, context, context_img_len); - auto iter = params.vace_layers_mapping.find(i); - if (iter != params.vace_layers_mapping.end()) { + auto iter = config.vace_layers_mapping.find(i); + if (iter != config.vace_layers_mapping.end()) { int n = iter->second; auto vace_block = std::dynamic_pointer_cast(blocks["vace_blocks." 
+ std::to_string(n)]); @@ -2065,14 +2111,14 @@ namespace WAN { x = pad_to_patch_size(ctx, x); - int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); - int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size)); - int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size)); + int64_t t_len = ((T + (std::get<0>(config.patch_size) / 2)) / std::get<0>(config.patch_size)); + int64_t h_len = ((H + (std::get<1>(config.patch_size) / 2)) / std::get<1>(config.patch_size)); + int64_t w_len = ((W + (std::get<2>(config.patch_size) / 2)) / std::get<2>(config.patch_size)); if (time_dim_concat != nullptr) { time_dim_concat = pad_to_patch_size(ctx, time_dim_concat); x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] - t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); + t_len = ((x->ne[2] + (std::get<0>(config.patch_size) / 2)) / std::get<0>(config.patch_size)); } auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C] @@ -2092,7 +2138,7 @@ namespace WAN { struct WanRunner : public DiffusionModelRunner { public: std::string desc = "wan"; - WanParams wan_params; + WanConfig config; Wan wan; std::vector pe_vec; SDVersion version; @@ -2102,109 +2148,73 @@ namespace WAN { const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "", SDVersion version = VERSION_WAN2) - : DiffusionModelRunner(backend, params_backend, prefix) { - wan_params.num_layers = 0; - for (auto pair : tensor_storage_map) { - std::string tensor_name = pair.first; - if (tensor_name.find(prefix) == std::string::npos) - continue; - size_t pos = tensor_name.find("vace_blocks."); - if (pos != std::string::npos) { - tensor_name = tensor_name.substr(pos); // remove prefix - auto items = split_string(tensor_name, 
'.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - if (block_index + 1 > wan_params.vace_layers) { - wan_params.vace_layers = block_index + 1; - } - } - continue; - } - pos = tensor_name.find("blocks."); - if (pos != std::string::npos) { - tensor_name = tensor_name.substr(pos); // remove prefix - auto items = split_string(tensor_name, '.'); - if (items.size() > 1) { - int block_index = atoi(items[1].c_str()); - if (block_index + 1 > wan_params.num_layers) { - wan_params.num_layers = block_index + 1; - } - } - continue; - } - if (tensor_name.find("img_emb") != std::string::npos) { - wan_params.model_type = "i2v"; - } - if (tensor_name.find("img_emb.emb_pos") != std::string::npos) { - wan_params.flf_pos_embed_token_number = 514; - } - } - - if (wan_params.num_layers == 30) { + : DiffusionModelRunner(backend, params_backend, prefix), + config(WanConfig::detect_from_weights(tensor_storage_map, prefix)) { + if (config.num_layers == 30) { if (version == VERSION_WAN2_2_TI2V) { - desc = "Wan2.2-TI2V-5B"; - wan_params.dim = 3072; - wan_params.eps = 1e-06f; - wan_params.ffn_dim = 14336; - wan_params.freq_dim = 256; - wan_params.in_dim = 48; - wan_params.num_heads = 24; - wan_params.out_dim = 48; - wan_params.text_len = 512; + desc = "Wan2.2-TI2V-5B"; + config.dim = 3072; + config.eps = 1e-06f; + config.ffn_dim = 14336; + config.freq_dim = 256; + config.in_dim = 48; + config.num_heads = 24; + config.out_dim = 48; + config.text_len = 512; } else { - if (wan_params.vace_layers > 0) { - desc = "Wan2.1-VACE-1.3B"; - wan_params.in_dim = 16; - } else if (wan_params.model_type == "i2v") { - desc = "Wan2.1-I2V-1.3B"; - wan_params.in_dim = 36; + if (config.vace_layers > 0) { + desc = "Wan2.1-VACE-1.3B"; + config.in_dim = 16; + } else if (config.model_type == "i2v") { + desc = "Wan2.1-I2V-1.3B"; + config.in_dim = 36; } else { - desc = "Wan2.1-T2V-1.3B"; - wan_params.in_dim = 16; + desc = "Wan2.1-T2V-1.3B"; + config.in_dim = 16; } - wan_params.dim = 1536; - 
wan_params.eps = 1e-06f; - wan_params.ffn_dim = 8960; - wan_params.freq_dim = 256; - wan_params.num_heads = 12; - wan_params.out_dim = 16; - wan_params.text_len = 512; + config.dim = 1536; + config.eps = 1e-06f; + config.ffn_dim = 8960; + config.freq_dim = 256; + config.num_heads = 12; + config.out_dim = 16; + config.text_len = 512; } - } else if (wan_params.num_layers == 40) { - if (wan_params.model_type == "t2v") { + } else if (config.num_layers == 40) { + if (config.model_type == "t2v") { if (version == VERSION_WAN2_2_I2V) { - desc = "Wan2.2-I2V-14B"; - wan_params.in_dim = 36; + desc = "Wan2.2-I2V-14B"; + config.in_dim = 36; } else { - if (wan_params.vace_layers > 0) { + if (config.vace_layers > 0) { desc = "Wan2.x-VACE-14B"; } else { desc = "Wan2.x-T2V-14B"; } - wan_params.in_dim = 16; + config.in_dim = 16; } } else { - wan_params.in_dim = 36; - if (wan_params.flf_pos_embed_token_number > 0) { + config.in_dim = 36; + if (config.flf_pos_embed_token_number > 0) { desc = "Wan2.1-FLF2V-14B"; } else { desc = "Wan2.1-I2V-14B"; } } - wan_params.dim = 5120; - wan_params.eps = 1e-06f; - wan_params.ffn_dim = 13824; - wan_params.freq_dim = 256; - wan_params.num_heads = 40; - wan_params.out_dim = 16; - wan_params.text_len = 512; + config.dim = 5120; + config.eps = 1e-06f; + config.ffn_dim = 13824; + config.freq_dim = 256; + config.num_heads = 40; + config.out_dim = 16; + config.text_len = 512; } else { - GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers); + GGML_ABORT("invalid num_layers(%d) of wan", config.num_layers); } LOG_INFO("%s", desc.c_str()); - wan = Wan(wan_params); + wan = Wan(config); wan.init(params_ctx, tensor_storage_map, prefix); } @@ -2237,15 +2247,15 @@ namespace WAN { pe_vec = Rope::gen_wan_pe(static_cast(x->ne[2]), static_cast(x->ne[1]), static_cast(x->ne[0]), - std::get<0>(wan_params.patch_size), - std::get<1>(wan_params.patch_size), - std::get<2>(wan_params.patch_size), + std::get<0>(config.patch_size), + 
std::get<1>(config.patch_size), + std::get<2>(config.patch_size), 1, - wan_params.theta, - wan_params.axes_dim); - int pos_len = static_cast(pe_vec.size() / wan_params.axes_dim_sum / 2); + config.theta, + config.axes_dim); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = nullptr; diff --git a/src/z_image.hpp b/src/z_image.hpp index 82dbe0491..be884c752 100644 --- a/src/z_image.hpp +++ b/src/z_image.hpp @@ -20,6 +20,104 @@ namespace ZImage { constexpr int ADALN_EMBED_DIM = 256; constexpr int SEQ_MULTI_OF = 32; + struct ZImageConfig { + int patch_size = 2; + int64_t hidden_size = 3840; + int64_t in_channels = 16; + int64_t out_channels = 16; + int64_t num_layers = 30; + int64_t num_refiner_layers = 2; + int64_t head_dim = 128; + int64_t num_heads = 30; + int64_t num_kv_heads = 30; + int64_t multiple_of = 256; + float ffn_dim_multiplier = 8.0f / 3.0f; + float norm_eps = 1e-5f; + bool qk_norm = true; + int64_t cap_feat_dim = 2560; + int theta = 256; + std::vector axes_dim = {32, 48, 48}; + int64_t axes_dim_sum = 128; + + static ZImageConfig detect_from_weights(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + ZImageConfig config; + int64_t detected_layers = 0; + int64_t detected_refiner_layers = 0; + int64_t detected_context_refiner = 0; + int64_t detected_head_dim = 0; + int64_t detected_qkv_dim = 0; + + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (ends_with(name, "x_embedder.weight") && tensor_storage.n_dims == 2) { + int64_t patch_area = config.patch_size * config.patch_size; + config.in_channels = tensor_storage.ne[0] / patch_area; + 
config.hidden_size = tensor_storage.ne[1]; + } else if (ends_with(name, "cap_embedder.1.weight") && tensor_storage.n_dims == 2) { + config.cap_feat_dim = tensor_storage.ne[0]; + config.hidden_size = tensor_storage.ne[1]; + } else if (ends_with(name, "layers.0.attention.q_norm.weight") && tensor_storage.n_dims == 1) { + detected_head_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "layers.0.attention.qkv.weight") && tensor_storage.n_dims == 2) { + detected_qkv_dim = tensor_storage.ne[1]; + } else if (ends_with(name, "final_layer.linear.weight") && tensor_storage.n_dims == 2) { + int64_t patch_area = config.patch_size * config.patch_size; + config.out_channels = tensor_storage.ne[1] / patch_area; + } + + size_t pos = name.find("layers."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + detected_layers = std::max(detected_layers, block_index + 1); + } + } + pos = name.find("noise_refiner."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + detected_refiner_layers = std::max(detected_refiner_layers, block_index + 1); + } + } + pos = name.find("context_refiner."); + if (pos != std::string::npos) { + auto items = split_string(name.substr(pos), '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + detected_context_refiner = std::max(detected_context_refiner, block_index + 1); + } + } + } + if (detected_layers > 0) { + config.num_layers = detected_layers; + } + if (detected_refiner_layers > 0 || detected_context_refiner > 0) { + config.num_refiner_layers = std::max(detected_refiner_layers, detected_context_refiner); + } + if (detected_head_dim > 0) { + config.head_dim = detected_head_dim; + config.num_heads = config.hidden_size / config.head_dim; + if (detected_qkv_dim > 0) { + int64_t qkv_heads = detected_qkv_dim / 
config.head_dim; + config.num_kv_heads = std::max(1, (qkv_heads - config.num_heads) / 2); + } + } + LOG_DEBUG("z_image: num_layers = %" PRId64 ", num_refiner_layers = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %" PRId64 ", num_kv_heads = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64, + config.num_layers, + config.num_refiner_layers, + config.hidden_size, + config.num_heads, + config.num_kv_heads, + config.in_channels, + config.out_channels); + return config; + } + }; + struct JointAttention : public GGMLBlock { protected: int64_t head_dim; @@ -263,90 +361,70 @@ namespace ZImage { } }; - struct ZImageParams { - int patch_size = 2; - int64_t hidden_size = 3840; - int64_t in_channels = 16; - int64_t out_channels = 16; - int64_t num_layers = 30; - int64_t num_refiner_layers = 2; - int64_t head_dim = 128; - int64_t num_heads = 30; - int64_t num_kv_heads = 30; - int64_t multiple_of = 256; - float ffn_dim_multiplier = 8.0f / 3.0f; - float norm_eps = 1e-5f; - bool qk_norm = true; - int64_t cap_feat_dim = 2560; - int theta = 256; - std::vector axes_dim = {32, 48, 48}; - int64_t axes_dim_sum = 128; - }; - class ZImageModel : public GGMLBlock { protected: - ZImageParams z_image_params; + ZImageConfig config; void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { - params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); - params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); + params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, config.hidden_size); + params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, config.hidden_size); } public: ZImageModel() = default; - ZImageModel(ZImageParams z_image_params) - : z_image_params(z_image_params) { - blocks["x_embedder"] = std::make_shared(z_image_params.patch_size * z_image_params.patch_size * z_image_params.in_channels, 
z_image_params.hidden_size); - blocks["t_embedder"] = std::make_shared(MIN(z_image_params.hidden_size, 1024), 256, 256); - blocks["cap_embedder.0"] = std::make_shared(z_image_params.cap_feat_dim, z_image_params.norm_eps); - blocks["cap_embedder.1"] = std::make_shared(z_image_params.cap_feat_dim, z_image_params.hidden_size); - - for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + ZImageModel(ZImageConfig config) + : config(config) { + blocks["x_embedder"] = std::make_shared(config.patch_size * config.patch_size * config.in_channels, config.hidden_size); + blocks["t_embedder"] = std::make_shared(MIN(config.hidden_size, 1024), 256, 256); + blocks["cap_embedder.0"] = std::make_shared(config.cap_feat_dim, config.norm_eps); + blocks["cap_embedder.1"] = std::make_shared(config.cap_feat_dim, config.hidden_size); + + for (int i = 0; i < config.num_refiner_layers; i++) { auto block = std::make_shared(i, - z_image_params.hidden_size, - z_image_params.head_dim, - z_image_params.num_heads, - z_image_params.num_kv_heads, - z_image_params.multiple_of, - z_image_params.ffn_dim_multiplier, - z_image_params.norm_eps, - z_image_params.qk_norm, + config.hidden_size, + config.head_dim, + config.num_heads, + config.num_kv_heads, + config.multiple_of, + config.ffn_dim_multiplier, + config.norm_eps, + config.qk_norm, true); blocks["noise_refiner." + std::to_string(i)] = block; } - for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + for (int i = 0; i < config.num_refiner_layers; i++) { auto block = std::make_shared(i, - z_image_params.hidden_size, - z_image_params.head_dim, - z_image_params.num_heads, - z_image_params.num_kv_heads, - z_image_params.multiple_of, - z_image_params.ffn_dim_multiplier, - z_image_params.norm_eps, - z_image_params.qk_norm, + config.hidden_size, + config.head_dim, + config.num_heads, + config.num_kv_heads, + config.multiple_of, + config.ffn_dim_multiplier, + config.norm_eps, + config.qk_norm, false); blocks["context_refiner." 
+ std::to_string(i)] = block; } - for (int i = 0; i < z_image_params.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = std::make_shared(i, - z_image_params.hidden_size, - z_image_params.head_dim, - z_image_params.num_heads, - z_image_params.num_kv_heads, - z_image_params.multiple_of, - z_image_params.ffn_dim_multiplier, - z_image_params.norm_eps, - z_image_params.qk_norm, + config.hidden_size, + config.head_dim, + config.num_heads, + config.num_kv_heads, + config.multiple_of, + config.ffn_dim_multiplier, + config.norm_eps, + config.qk_norm, true); blocks["layers." + std::to_string(i)] = block; } - blocks["final_layer"] = std::make_shared(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels); + blocks["final_layer"] = std::make_shared(config.hidden_size, config.patch_size, config.out_channels); } ggml_tensor* forward_core(GGMLRunnerContext* ctx, @@ -393,14 +471,14 @@ namespace ZImage { auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]); auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], pe->ne[3]); - for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + for (int i = 0; i < config.num_refiner_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["context_refiner." + std::to_string(i)]); txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr); sd::ggml_graph_cut::mark_graph_cut(txt, "z_image.context_refiner." + std::to_string(i), "txt"); } - for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + for (int i = 0; i < config.num_refiner_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["noise_refiner." 
+ std::to_string(i)]); img = block->forward(ctx, img, img_pe, nullptr, t_emb); @@ -410,7 +488,7 @@ namespace ZImage { auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size] sd::ggml_graph_cut::mark_graph_cut(txt_img, "z_image.prelude", "txt_img"); - for (int i = 0; i < z_image_params.num_layers; i++) { + for (int i = 0; i < config.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb); @@ -442,7 +520,7 @@ namespace ZImage { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - int patch_size = z_image_params.patch_size; + int patch_size = config.patch_size; auto img = DiT::pad_and_patchify(ctx, x, patch_size, patch_size, false); uint64_t n_img_token = img->ne[1]; @@ -467,7 +545,7 @@ namespace ZImage { struct ZImageRunner : public DiffusionModelRunner { public: - ZImageParams z_image_params; + ZImageConfig config; ZImageModel z_image; std::vector pe_vec; std::vector timestep_vec; @@ -478,8 +556,9 @@ namespace ZImage { const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "", SDVersion version = VERSION_Z_IMAGE) - : DiffusionModelRunner(backend, params_backend, prefix) { - z_image = ZImageModel(z_image_params); + : DiffusionModelRunner(backend, params_backend, prefix), + config(ZImageConfig::detect_from_weights(tensor_storage_map, prefix)) { + z_image = ZImageModel(config); z_image.init(params_ctx, tensor_storage_map, prefix); } @@ -510,19 +589,19 @@ namespace ZImage { pe_vec = Rope::gen_z_image_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), - z_image_params.patch_size, + config.patch_size, static_cast(x->ne[3]), static_cast(context->ne[1]), SEQ_MULTI_OF, ref_latents, increase_ref_index, - z_image_params.theta, + config.theta, circular_y_enabled, circular_x_enabled, - z_image_params.axes_dim); - int pos_len = static_cast(pe_vec.size() / z_image_params.axes_dim_sum 
/ 2); + config.axes_dim); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe, true, "pe"); // pe->data = nullptr;