From 246e7eed10c2e729f0fa337c0077711dcaf29510 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Fri, 24 Apr 2026 01:43:24 +0200 Subject: [PATCH 1/2] LTX-2 first version --- CMakeLists.txt | 1 + examples/common/common.cpp | 7 + examples/common/common.h | 1 + include/stable-diffusion.h | 12 + src/conditioner.hpp | 290 ++++++ src/denoiser.hpp | 91 ++ src/diffusion_model.hpp | 75 ++ src/ggml_extend.hpp | 53 +- src/llm.hpp | 343 ++++++- src/ltx.hpp | 764 ++++++++++++++++ src/ltx_connector.hpp | 623 +++++++++++++ src/ltx_rope.hpp | 350 +++++++ src/ltxv.hpp | 110 ++- src/ltxvae.hpp | 913 +++++++++++++++++++ src/ltxvae_primitives.hpp | 212 +++++ src/model.cpp | 3 + src/model.h | 11 +- src/name_conversion.cpp | 35 + src/stable-diffusion.cpp | 336 ++++++- src/tokenizers/gemma_tokenizer.cpp | 254 ++++++ src/tokenizers/gemma_tokenizer.h | 50 + src/vae.hpp | 3 + tests/ltx_parity/CMakeLists.txt | 62 ++ tests/ltx_parity/README.md | 36 + tests/ltx_parity/dump_connector.py | 293 ++++++ tests/ltx_parity/dump_gemma.py | 256 ++++++ tests/ltx_parity/dump_reference.py | 623 +++++++++++++ tests/ltx_parity/dump_s2d.py | 176 ++++ tests/ltx_parity/dump_vae.py | 341 +++++++ tests/ltx_parity/test_connector_parity.cpp | 297 ++++++ tests/ltx_parity/test_gemma_parity.cpp | 287 ++++++ tests/ltx_parity/test_gemma_tokenizer.cpp | 88 ++ tests/ltx_parity/test_ltx2_vae_roundtrip.cpp | 221 +++++ tests/ltx_parity/test_ltx_parity.cpp | 438 +++++++++ tests/ltx_parity/test_s2d_primitives.cpp | 185 ++++ tests/ltx_parity/test_vae_parity.cpp | 378 ++++++++ 36 files changed, 8135 insertions(+), 83 deletions(-) create mode 100644 src/ltx.hpp create mode 100644 src/ltx_connector.hpp create mode 100644 src/ltx_rope.hpp create mode 100644 src/ltxvae.hpp create mode 100644 src/ltxvae_primitives.hpp create mode 100644 src/tokenizers/gemma_tokenizer.cpp create mode 100644 src/tokenizers/gemma_tokenizer.h create mode 100644 tests/ltx_parity/CMakeLists.txt create mode 100644 tests/ltx_parity/README.md create mode 
100644 tests/ltx_parity/dump_connector.py create mode 100644 tests/ltx_parity/dump_gemma.py create mode 100644 tests/ltx_parity/dump_reference.py create mode 100644 tests/ltx_parity/dump_s2d.py create mode 100644 tests/ltx_parity/dump_vae.py create mode 100644 tests/ltx_parity/test_connector_parity.cpp create mode 100644 tests/ltx_parity/test_gemma_parity.cpp create mode 100644 tests/ltx_parity/test_gemma_tokenizer.cpp create mode 100644 tests/ltx_parity/test_ltx2_vae_roundtrip.cpp create mode 100644 tests/ltx_parity/test_ltx_parity.cpp create mode 100644 tests/ltx_parity/test_s2d_primitives.cpp create mode 100644 tests/ltx_parity/test_vae_parity.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a9fb1041..538b173b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,6 +265,7 @@ target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) if (SD_BUILD_EXAMPLES) add_subdirectory(examples) + add_subdirectory(tests/ltx_parity) endif() set(SD_PUBLIC_HEADERS include/stable-diffusion.h) diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 0235c53de..7ac0a0d30 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -319,6 +319,10 @@ ArgOptions SDContextParams::get_options() { "--qwen2vl_vision", "alias of --llm_vision. Deprecated.", &llm_vision_path}, + {"", + "--gemma-tokenizer", + "path to Gemma's tokenizer.json (HF format). 
Required for LTX-2 text conditioning.", + &gemma_tokenizer_path}, {"", "--diffusion-model", "path to the standalone diffusion model", @@ -638,6 +642,7 @@ std::string SDContextParams::to_string() const { << " t5xxl_path: \"" << t5xxl_path << "\",\n" << " llm_path: \"" << llm_path << "\",\n" << " llm_vision_path: \"" << llm_vision_path << "\",\n" + << " gemma_tokenizer_path: \"" << gemma_tokenizer_path << "\",\n" << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" << " vae_path: \"" << vae_path << "\",\n" @@ -693,6 +698,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f t5xxl_path.c_str(), llm_path.c_str(), llm_vision_path.c_str(), + gemma_tokenizer_path.c_str(), diffusion_model_path.c_str(), high_noise_diffusion_model_path.c_str(), vae_path.c_str(), @@ -2012,6 +2018,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { params.strength = strength; params.seed = seed; params.video_frames = video_frames; + params.fps = static_cast(fps); params.vace_strength = vace_strength; params.vae_tiling_params = vae_tiling_params; params.cache = cache_params; diff --git a/examples/common/common.h b/examples/common/common.h index 5afe89b34..6e405b050 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -90,6 +90,7 @@ struct SDContextParams { std::string t5xxl_path; std::string llm_path; std::string llm_vision_path; + std::string gemma_tokenizer_path; std::string diffusion_model_path; std::string high_noise_diffusion_model_path; std::string vae_path; diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index a99b10450..9ab335627 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -76,6 +76,7 @@ enum prediction_t { FLOW_PRED, FLUX_FLOW_PRED, FLUX2_FLOW_PRED, + LTX2_FLOW_PRED, PREDICTION_COUNT }; @@ -169,6 +170,11 @@ typedef struct { const char* t5xxl_path; const char* 
llm_path; const char* llm_vision_path; + // Path to a HuggingFace-format tokenizer.json file. Currently only read by the + // LTX-2 Gemma 3 conditioner, which requires Gemma's tokenizer for BPE + metaspace + // encoding of prompts. If empty for LTX-2, the conditioner aborts with a clear + // message. Non-LTX-2 pipelines ignore this field. + const char* gemma_tokenizer_path; const char* diffusion_model_path; const char* high_noise_diffusion_model_path; const char* vae_path; @@ -332,6 +338,12 @@ typedef struct { float strength; int64_t seed; int video_frames; + // Output video fps. Carried through to models that use it for temporal + // positional embeddings — LTX-2's RoPE divides the time axis by fps + // (ltx_core/tools.py::VideoLatentTools.create_initial_state), so the + // default 24 on LTXRunner silently produces wrong positions at any + // other target fps. 0 means "don't override runner default". + float fps; float vace_strength; sd_tiling_params_t vae_tiling_params; sd_cache_params_t cache; diff --git a/src/conditioner.hpp b/src/conditioner.hpp index 9f4d45524..c38342212 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -5,8 +5,10 @@ #include "clip.hpp" #include "llm.hpp" +#include "ltx_connector.hpp" #include "t5.hpp" #include "tensor_ggml.hpp" +#include "tokenizers/gemma_tokenizer.h" struct SDCondition { sd::Tensor c_crossattn; @@ -1958,4 +1960,292 @@ struct LLMEmbedder : public Conditioner { } }; +// LTX-2 conditioner: Gemma 3 text encoder → feature extractor → 1D connector → +// DiT cross-attention context. Supports both V1 (19B) and V2 (22B) feature +// extractor variants, auto-detected from the tensor map. 
+// +// Key prefixes (native LTX-2 checkpoint layout, no name-conversion applied): +// text_encoder.model.* Gemma weights +// text_embedding_projection.aggregate_embed.* V1 FeatureExtractorV1 (19B) +// text_embedding_projection.video_aggregate_embed.* V2 FeatureExtractorV2 video branch (22B) +// text_embedding_projection.audio_aggregate_embed.* V2 audio branch (22B, currently unused) +// model.diffusion_model.embeddings_connector.* V1 Embeddings1DConnector (19B) +// model.diffusion_model.video_embeddings_connector.* V2 video connector (22B) +// model.diffusion_model.caption_projection.* V1 PixArt caption_projection (on DiT) +// (V2 has no caption_projection — feature +// extractor already outputs DiT's inner_dim) +// +// If neither V1 nor V2 connector weights are present (e.g. Gemma-only test +// checkpoints), the conditioner falls back to returning the final post-norm +// hidden state — the same cheap path we had before Phase 9 landed. +struct LTX2GemmaConditioner : public Conditioner { + std::shared_ptr llm; + std::shared_ptr tokenizer; + std::shared_ptr connector_runner; + std::string prefix; + std::string tokenizer_path; + int64_t gemma_hidden_size = 0; + int gemma_num_hidden_layers = 0; + // True when using the V2 (22B) feature extractor; used by get_learned_condition + // to pick the right CPU normalization path. 
+ bool use_v2_feature_extractor = false; + + LTX2GemmaConditioner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix = "text_encoder", + const std::string tokenizer_path = "", + const std::string feat_ext_prefix = "text_embedding_projection", + const std::string connector_prefix_arg = "") + : prefix(prefix), tokenizer_path(tokenizer_path) { + llm = std::make_shared(LLM::LLMArch::GEMMA3, + backend, + offload_params_to_cpu, + tensor_storage_map, + prefix, + /*enable_vision=*/false); + gemma_hidden_size = llm->params.hidden_size; + gemma_num_hidden_layers = static_cast(llm->params.num_layers); + + if (!tokenizer_path.empty()) { + tokenizer = std::make_shared(); + if (!tokenizer->load_from_file(tokenizer_path)) { + LOG_WARN("LTX2GemmaConditioner: failed to load Gemma tokenizer from '%s'", tokenizer_path.c_str()); + tokenizer.reset(); + } + } + + // Auto-detect V1 vs V2 feature extractor + connector prefix variant. + // V2 (22B): text_embedding_projection.video_aggregate_embed.{weight,bias} + + // model.diffusion_model.video_embeddings_connector.* + // V1 (19B): text_embedding_projection.aggregate_embed.weight + + // model.diffusion_model.embeddings_connector.* + // `connector_prefix_arg` is honored when non-empty, otherwise we probe both. 
+ const std::string& feat_ext_pre = feat_ext_prefix; + + auto agg_v1_it = tensor_storage_map.find(feat_ext_pre + ".aggregate_embed.weight"); + auto agg_v2_it = tensor_storage_map.find(feat_ext_pre + ".video_aggregate_embed.weight"); + + std::string connector_pre; + LTXConnector::FeatureExtractorVersion fe_version = LTXConnector::FeatureExtractorVersion::V1; + int64_t flat_dim = 0; + int64_t inner_dim = 0; + + if (agg_v2_it != tensor_storage_map.end()) { + fe_version = LTXConnector::FeatureExtractorVersion::V2; + flat_dim = agg_v2_it->second.ne[0]; + inner_dim = agg_v2_it->second.ne[1]; + use_v2_feature_extractor = true; + connector_pre = connector_prefix_arg.empty() + ? "model.diffusion_model.video_embeddings_connector" + : connector_prefix_arg; + } else if (agg_v1_it != tensor_storage_map.end()) { + fe_version = LTXConnector::FeatureExtractorVersion::V1; + flat_dim = agg_v1_it->second.ne[0]; + inner_dim = agg_v1_it->second.ne[1]; + connector_pre = connector_prefix_arg.empty() + ? "model.diffusion_model.embeddings_connector" + : connector_prefix_arg; + } else { + LOG_INFO("LTX2GemmaConditioner: no feature_extractor weights found — falling back to " + "last_hidden_state pass-through (Gemma-only mode)"); + return; + } + + auto conn0_it = tensor_storage_map.find(connector_pre + ".transformer_1d_blocks.0.attn1.to_q.weight"); + if (conn0_it == tensor_storage_map.end()) { + LOG_WARN("LTX2GemmaConditioner: feature_extractor weights present but connector at '%s' is missing; " + "falling back to last_hidden_state", + connector_pre.c_str()); + return; + } + if (conn0_it->second.ne[1] != inner_dim) { + LOG_WARN("LTX2GemmaConditioner: connector to_q out_features=%lld does not match " + "feature_extractor inner_dim=%lld; skipping connector.", + (long long)conn0_it->second.ne[1], (long long)inner_dim); + return; + } + + // Count connector layers by probing to_q presence. + int num_layers = 0; + while (tensor_storage_map.find(connector_pre + ".transformer_1d_blocks." 
+ + std::to_string(num_layers) + ".attn1.to_q.weight") != + tensor_storage_map.end()) { + num_layers++; + } + + // num_registers from learnable_registers.ne (ne[0]=inner_dim, ne[1]=num_registers). + int num_registers = 0; + auto reg_it = tensor_storage_map.find(connector_pre + ".learnable_registers"); + if (reg_it != tensor_storage_map.end() && reg_it->second.n_dims >= 2) { + num_registers = static_cast(reg_it->second.ne[1]); + } + + // Detect gated attention inside the connector (V2 / 22B has this). + bool apply_gated = tensor_storage_map.find( + connector_pre + ".transformer_1d_blocks.0.attn1.to_gate_logits.weight") != + tensor_storage_map.end(); + + // LTX-2 fixes head_dim=128 across both variants. + int head_dim = 128; + int num_heads = static_cast(inner_dim / head_dim); + + // We do NOT include caption_projection here — V1 has it on the DiT side, + // V2 has none. Pass source_dim=Gemma hidden so V2's sqrt(target/source) + // rescale is applied correctly. + connector_runner = std::make_shared( + backend, offload_params_to_cpu, + flat_dim, num_heads, head_dim, num_layers, num_registers, + /*caption_channels=*/0, /*caption_hidden=*/0, /*caption_out=*/0, + /*theta=*/10000.0f, /*max_pos=*/std::vector{1}, + tensor_storage_map, + /*include_caption_projection=*/false, + feat_ext_pre, connector_pre, /*caption_proj_prefix=*/"", + fe_version, /*source_dim=*/gemma_hidden_size, apply_gated); + LOG_INFO("LTX2GemmaConditioner: wired %s connector (flat_dim=%lld inner_dim=%lld " + "num_layers=%d num_registers=%d gated=%d)", + fe_version == LTXConnector::FeatureExtractorVersion::V2 ? "V2" : "V1", + (long long)flat_dim, (long long)inner_dim, num_layers, num_registers, + apply_gated ? 
1 : 0); + } + + void get_param_tensors(std::map& tensors) override { + llm->get_param_tensors(tensors, prefix); + if (connector_runner) { + connector_runner->get_param_tensors(tensors); + } + } + void alloc_params_buffer() override { + llm->alloc_params_buffer(); + if (connector_runner) connector_runner->alloc_params_buffer(); + } + void free_params_buffer() override { + llm->free_params_buffer(); + if (connector_runner) connector_runner->free_params_buffer(); + } + size_t get_params_buffer_size() override { + size_t s = llm->get_params_buffer_size(); + if (connector_runner) s += connector_runner->get_params_buffer_size(); + return s; + } + void set_flash_attention_enabled(bool enabled) override { + llm->set_flash_attention_enabled(enabled); + if (connector_runner) connector_runner->set_flash_attention_enabled(enabled); + } + + SDCondition get_learned_condition(int n_threads, + const ConditionerParams& p) override { + if (!tokenizer) { + LOG_ERROR("LTX2GemmaConditioner: no tokenizer loaded. Construct the conditioner " + "with a path to Gemma's tokenizer.json."); + GGML_ABORT("Gemma tokenizer missing"); + } + // HuggingFace Gemma tokenizer always prepends ; we replicate that here + // so the encoder sees the same sequence the Python reference does. + std::vector real_ids = tokenizer->tokenize(p.text, nullptr, /*padding=*/false); + real_ids.insert(real_ids.begin(), tokenizer->BOS_TOKEN_ID); + const int64_t T_real = static_cast(real_ids.size()); + LOG_DEBUG("LTX2GemmaConditioner: tokenized prompt '%s' -> %lld real tokens", + p.text.c_str(), (long long)T_real); + sd::Tensor empty_mask; + + if (!connector_runner) { + // No connector weights: behave like before Phase 9 landed (no padding). 
+ sd::Tensor ids_tensor({T_real, 1}); + for (int64_t i = 0; i < T_real; ++i) ids_tensor.data()[i] = real_ids[i]; + auto last_hidden = llm->compute(n_threads, ids_tensor, empty_mask, {}, {}); + SDCondition cond; + cond.c_crossattn = last_hidden; + return cond; + } + + // Python LTX-2 tokenizer pads to max_length=1024 with padding_side="left" + // and pad_token = EOS: + // ltx_core/text_encoders/gemma/tokenizer.py:21-24 (padding_side="left", + // pad_token=EOS) and ltx_core/text_encoders/gemma/encoders/base_encoder.py:182 + // (`LTXVGemmaTokenizer(tokenizer_root, 1024)`). + // Gemma processes the full max_length, and the connector then sees a + // max_length-long sequence with learnable_registers tiled max_length/num_reg + // times (8× on the 22B V2 path, where num_reg=128). Padding only to + // num_registers produces the wrong Gemma RoPE positions for the real tokens + // and cuts the DiT cross-attention context by the same factor; both regress + // output quality from recognisable subjects to colored-blob textures. + const int num_registers = connector_runner->num_registers; + const int64_t max_length = 1024; + int64_t T_pad = 0; + int64_t T = T_real; + if (T_real < max_length) { + T_pad = max_length - T_real; + T = max_length; + } else { + // Prompt already exceeds max_length — truncate to match tokenizer + // behaviour (`truncation=True` in LTXVGemmaTokenizer). + LOG_WARN("LTX2GemmaConditioner: prompt tokenised to %lld >= max_length=%lld; truncating.", + (long long)T_real, (long long)max_length); + real_ids.resize(static_cast(max_length)); + T = max_length; + T_pad = 0; + } + sd::Tensor input_ids({T, 1}); + for (int64_t i = 0; i < T_pad; ++i) input_ids.data()[i] = tokenizer->EOS_TOKEN_ID; + const int64_t real_to_write = std::min(T_real, max_length); + for (int64_t i = 0; i < real_to_write; ++i) input_ids.data()[T_pad + i] = real_ids[i]; + // num_registers must divide max_length (Embeddings1DConnector tiles). 
+ GGML_ASSERT(num_registers == 0 || max_length % num_registers == 0); + connector_runner->set_target_seq_len(static_cast(max_length)); + + // 1. Gemma: compute all N+1 hidden states on the padded sequence. + // Layout returned by compute_all_hidden_states: ne [N+1, H, T, B] = + // PyTorch [B, T, H, N+1] (stack of per-layer hidden states). + auto stacked = llm->compute_all_hidden_states(n_threads, input_ids, empty_mask); + const int64_t B = 1; + const int64_t D = gemma_hidden_size; + const int64_t L = gemma_num_hidden_layers + 1; + GGML_ASSERT(stacked.numel() == L * D * T * B); + + // 2. CPU normalize → [B, T, D*L]. seq_lens=[T_real_eff] + left-padding tells + // the normalizer to zero out the pad positions (which live at [0, T_pad)). + // T_real_eff caps at max_length to handle the truncated-prompt branch above. + const int64_t T_real_eff = std::min(T_real, max_length); + std::vector seq_lens(B, static_cast(T_real_eff)); + sd::Tensor normed({D * L, T, B}); + if (use_v2_feature_extractor) { + LTXConnector::feature_extractor_normalize_v2( + stacked.data(), seq_lens.data(), normed.data(), + static_cast(B), static_cast(T), static_cast(D), static_cast(L), + "left", 1e-6f); + } else { + LTXConnector::feature_extractor_normalize( + stacked.data(), seq_lens.data(), normed.data(), + static_cast(B), static_cast(T), static_cast(D), static_cast(L), + "left", 1e-6f); + } + + // Python's Embeddings1DConnector._replace_padded_with_learnable_registers moves + // the real-token rows from [T_pad, T) to the START of the sequence and replaces + // the now-empty tail with learnable_registers[T_real:]. Equivalent CPU-side shift: + // after normalize, [0,T_pad) holds zeros (masked pad), [T_pad,T) holds real. + // Slide the real rows down to [0,T_real) and re-zero the tail — the connector + // runner then tiles/slices learnable_registers[T_real:max_length] and concats. 
+ if (T_pad > 0) { + const int64_t flat_dim = D * L; + sd::Tensor reals({flat_dim, T_real_eff, B}); + for (int64_t b = 0; b < B; ++b) { + std::memcpy(reals.data() + b * T_real_eff * flat_dim, + normed.data() + b * T * flat_dim + T_pad * flat_dim, + static_cast(T_real_eff * flat_dim) * sizeof(float)); + } + normed = std::move(reals); + } + + // 3. Run connector (stage 3 = after all transformer blocks + final rms_norm, + // before caption_projection — the DiT owns caption_projection). + auto context = connector_runner->compute(n_threads, normed, /*stage=*/3); + + SDCondition cond; + cond.c_crossattn = context; + return cond; + } +}; + #endif diff --git a/src/denoiser.hpp b/src/denoiser.hpp index a6e81d597..4613bffd2 100644 --- a/src/denoiser.hpp +++ b/src/denoiser.hpp @@ -720,6 +720,97 @@ struct FluxFlowDenoiser : public DiscreteFlowDenoiser { } }; +// LTX-2 flow-match denoiser. +// +// Reference: /devel/tools/diffusion/LTX-2/packages/ltx-core/src/ltx_core/components/schedulers.py +// +// Key differences from FluxFlowDenoiser: +// - sigma_to_t(σ) = σ * 1000 (Flux passes raw σ; LTX's TransformerArgsPreprocessor scales by +// 1000 in Python, but we externalise that to the denoiser so the +// DiT's AdaLayerNormSingle doesn't double-multiply). +// - Token-count-dependent shift: mu = linear_interp(tokens, 1024→0.95, 4096→2.05), log-space. +// - Terminal stretch: after flux_time_shift, rescale non-zero sigmas so the last non-zero lands +// at `terminal` (default 0.1). This is what the LTX-2 distilled LoRAs expect. +// - scheduler_t is ignored — LTX2Scheduler is fixed; a non-default value would give wrong +// behaviour for the trained weights. +struct LTX2FlowDenoiser : public DiscreteFlowDenoiser { + static constexpr int BASE_SHIFT_ANCHOR = 1024; + static constexpr int MAX_SHIFT_ANCHOR = 4096; + + // Log-space shift anchors; get exponentiated in compute_mu. 
+ float max_shift = 2.05f; + float base_shift = 0.95f; + float terminal = 0.1f; + bool stretch = true; + + LTX2FlowDenoiser() = default; + + // Compute the shift `mu` used inside flux_time_shift. Python: + // mm = (max_shift - base_shift) / (MAX_ANCHOR - BASE_ANCHOR) + // b = base_shift - mm * BASE_ANCHOR + // sigma_shift = tokens * mm + b + float compute_mu(int tokens) const { + float mm = (max_shift - base_shift) / static_cast(MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR); + float b = base_shift - mm * static_cast(BASE_SHIFT_ANCHOR); + return static_cast(tokens) * mm + b; + } + + // t_to_sigma uses the base-shift mapping as a best-effort inverse. The real inverse depends on + // the terminal stretch, which needs the full schedule context — sampling never actually inverts + // t_to_sigma at arbitrary points, so this is only here to satisfy the virtual interface. + float t_to_sigma(float t) override { + return flux_time_shift(base_shift, 1.0f, (t + 1.0f) / TIMESTEPS); + } + + std::vector get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion /*version*/) override { + if (scheduler_type != DISCRETE_SCHEDULER) { + LOG_WARN("LTX2FlowDenoiser: ignoring scheduler_type=%d; LTX-2 uses a fixed schedule", + static_cast(scheduler_type)); + } + + int tokens = image_seq_len > 0 ? image_seq_len : MAX_SHIFT_ANCHOR; + float mu = compute_mu(tokens); + float exp_mu = std::exp(mu); + LOG_DEBUG("LTX2FlowDenoiser: tokens=%d mu=%.4f stretch=%d terminal=%.3f", + tokens, mu, stretch ? 1 : 0, terminal); + + std::vector sigmas(n + 1); + // linspace(1.0, 0.0, n+1) then apply flux_time_shift (power=1) to non-zero entries. + for (uint32_t i = 0; i <= n; ++i) { + float t = 1.0f - static_cast(i) / static_cast(n); + if (t <= 0.0f) { + sigmas[i] = 0.0f; + } else { + sigmas[i] = exp_mu / (exp_mu + (1.0f / t - 1.0f)); + } + } + + // Terminal stretch: rescale `1 - σ` so that the last non-zero σ lands at `terminal`. 
+ if (stretch) { + int last_nonzero = -1; + for (int i = static_cast(n); i >= 0; --i) { + if (sigmas[i] > 0.0f) { + last_nonzero = i; + break; + } + } + if (last_nonzero > 0) { + float one_minus_last = 1.0f - sigmas[last_nonzero]; + float scale_factor = one_minus_last / (1.0f - terminal); + if (scale_factor > 0.0f) { + for (uint32_t i = 0; i <= n; ++i) { + if (sigmas[i] > 0.0f) { + sigmas[i] = 1.0f - (1.0f - sigmas[i]) / scale_factor; + } + } + } + } + } + + return sigmas; + } +}; + struct Flux2FlowDenoiser : public FluxFlowDenoiser { Flux2FlowDenoiser() = default; diff --git a/src/diffusion_model.hpp b/src/diffusion_model.hpp index c0a2a11c0..4faf7460b 100644 --- a/src/diffusion_model.hpp +++ b/src/diffusion_model.hpp @@ -5,6 +5,7 @@ #include "anima.hpp" #include "ernie_image.hpp" #include "flux.hpp" +#include "ltx.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" #include "tensor_ggml.hpp" @@ -50,6 +51,9 @@ struct DiffusionModel { virtual int64_t get_adm_in_channels() = 0; virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; + // Overridden only by models whose spatial / temporal embeddings depend on the + // output fps (currently LTX-2). Image-only models ignore the value. 
+ virtual void set_fps(float fps) {} }; struct UNetModel : public DiffusionModel { @@ -517,6 +521,77 @@ struct ZImageModel : public DiffusionModel { } }; +struct LTXDiffusionModel : public DiffusionModel { + std::string prefix; + LTX::LTXRunner ltx; + + LTXDiffusionModel(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_LTX2) + : prefix(prefix), ltx(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) { + } + + std::string get_desc() override { + return ltx.get_desc(); + } + + void alloc_params_buffer() override { + ltx.alloc_params_buffer(); + } + + void free_params_buffer() override { + ltx.free_params_buffer(); + } + + void free_compute_buffer() override { + ltx.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) override { + ltx.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() override { + return ltx.get_params_buffer_size(); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + ltx.set_weight_adapter(adapter); + } + + int64_t get_adm_in_channels() override { + return 0; + } + + void set_flash_attention_enabled(bool enabled) override { + ltx.set_flash_attention_enabled(enabled); + } + + void set_circular_axes(bool circular_x, bool circular_y) override { + ltx.set_circular_axes(circular_x, circular_y); + } + + void set_fps(float fps) override { + if (fps > 0.f) { + ltx.set_fps(fps); + } + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + static const sd::Tensor empty; + return ltx.compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context), + empty); + } +}; + struct ErnieImageModel : public DiffusionModel { std::string prefix; 
ErnieImage::ErnieImageRunner ernie_image; diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 859270cbd..8275f26e0 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -2048,6 +2048,27 @@ struct GGMLRunner { params_buffer_size / (1024.f * 1024.f), ggml_backend_is_cpu(params_backend) ? "RAM" : "VRAM", num_tensors); + // Per-type tensor-size breakdown to make silent F32/F16 upcasts visible. + // Only emit at INFO when the buffer is > 1 GB — avoids spamming for small + // runners like the VAE scale head or connector. + if (params_buffer_size >= size_t(1) << 30) { + std::map> per_type; // bytes, count + for (ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != nullptr; + t = ggml_get_next_tensor(params_ctx, t)) { + auto& entry = per_type[t->type]; + entry.first += ggml_nbytes(t); + entry.second += 1; + } + std::string breakdown; + for (const auto& kv : per_type) { + char buf[96]; + std::snprintf(buf, sizeof(buf), "%s %zu/%6.1fMB ", + ggml_type_name(kv.first), kv.second.second, + kv.second.first / (1024.f * 1024.f)); + breakdown += buf; + } + LOG_INFO("%s param breakdown: %s", get_desc().c_str(), breakdown.c_str()); + } return true; } @@ -2354,7 +2375,37 @@ class Linear : public UnaryBlock { }; __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { - std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; + // ggml-cpu implements get_rows for the full quant set in + // ggml_compute_forward_get_rows (ggml-cpu/ops.cpp) — both the legacy + // Q{4,5,8}_{0,1} formats AND the K-quants / IQ-quants. Historically this + // allowlist only contained legacy types, which forced the LTX-2 Gemma-3 + // token_embd weight (IQ4_XS in the 12B checkpoint) to fall back to F32 + // during Embedding::init_params and cost an extra ~3.5 GB of RAM per + // encode. Keep the list tight — only what ggml-cpu get_rows actually + // dispatches — so an unsupported type still trips the F32 fallback. 
+ std::set allow_types = { + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q5_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_0, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + GGML_TYPE_IQ2_XXS, + GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + GGML_TYPE_IQ1_S, + GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, + GGML_TYPE_IQ4_XS, + }; if (allow_types.find(wtype) != allow_types.end()) { return true; } diff --git a/src/llm.hpp b/src/llm.hpp index 4afaa3ba6..0a12d1958 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -22,13 +22,19 @@ #include "tokenizers/qwen2_tokenizer.h" namespace LLM { - constexpr int LLM_GRAPH_SIZE = 10240; + // Bumped aggressively for the 22B LTX-2 smoke test where Gemma 3 12B runs with + // compute_all_hidden_states (49-layer concat stack over 48 layers of sandwich- + // norm + attn + MLP). The assert at ggml.c:6877 fired at 40960; 200000 leaves + // ample headroom while we diagnose whether real op count or hash dedup is the + // issue. + constexpr int LLM_GRAPH_SIZE = 200000; enum class LLMArch { QWEN2_5_VL, QWEN3, MISTRAL_SMALL_3_2, MINISTRAL_3_3B, + GEMMA3, ARCH_COUNT, }; @@ -37,6 +43,7 @@ namespace LLM { "qwen3", "mistral_small3.2", "ministral3.3b", + "gemma3", }; struct LLMVisionParams { @@ -65,12 +72,63 @@ namespace LLM { bool qk_norm = false; int64_t vocab_size = 152064; float rms_norm_eps = 1e-06f; + + // Gemma 3 additions (unused by other archs). + // Pattern: layers where (idx % sliding_window_pattern == 0) use global attention + // with rope_theta_global; other layers use sliding-window attention of size + // sliding_window with rope_theta_local. has_post_norms adds a second RMSNorm after + // attn and after MLP inside each block. embed_scale multiplies token embeddings + // once before the first layer. 
+ int sliding_window = 0; // 0 = disabled + int sliding_window_pattern = 0; // 0 = disabled + float rope_theta_global = 0.f; // 0 = use legacy hardcoded theta + float rope_theta_local = 0.f; + // Gemma 3 rope_scaling: linear RoPE scaling applied only to full-attention + // (global) layers. HuggingFace config.json: rope_scaling={factor: F, rope_type: linear}. + // Sliding layers are unscaled. 1.0 = disabled. For the 12B model this is 8.0. + float rope_scaling_factor_global = 1.0f; + bool has_post_norms = false; + float embed_scale = 1.0f; + LLMVisionParams vision; }; + // Gemma 3 RMSNorm variant: scale by (1 + w) instead of w. Weights are stored with + // zero-mean convention so at init effective scale is 1.0. Not used by other archs. + class RMSNormPlus1 : public UnaryBlock { + protected: + int64_t hidden_size; + float eps; + std::string prefix; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + this->prefix = prefix; + params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + + public: + RMSNormPlus1(int64_t hidden_size, float eps = 1e-06f) + : hidden_size(hidden_size), eps(eps) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + ggml_tensor* w = params["weight"]; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + } + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + auto scaled = ggml_mul(ctx->ggml_ctx, x, w); // rms(x) * w + x = ggml_add_inplace(ctx->ggml_ctx, x, scaled); // rms(x) * (1 + w) + return x; + } + }; + struct MLP : public GGMLBlock { + protected: + bool use_gelu_tanh; + public: - MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) { + MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false, bool use_gelu_tanh = false) + : use_gelu_tanh(use_gelu_tanh) { blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); 
blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias)); @@ -83,9 +141,13 @@ namespace LLM { auto down_proj = std::dynamic_pointer_cast(blocks["down_proj"]); auto h = gate_proj->forward(ctx, x); - h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); - h = down_proj->forward(ctx, h); + if (use_gelu_tanh) { + h = ggml_gelu_inplace(ctx->ggml_ctx, h); + } else { + h = ggml_silu_inplace(ctx->ggml_ctx, h); + } + h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); + h = down_proj->forward(ctx, h); return h; } }; @@ -376,24 +438,50 @@ namespace LLM { int64_t num_heads; int64_t num_kv_heads; bool qk_norm; + int layer_idx; + int sliding_window_pattern; + float rope_theta_global; + float rope_theta_local; + float rope_scaling_factor_global; public: - Attention(const LLMParams& params) - : arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) { + Attention(const LLMParams& params, int layer_idx = 0) + : arch(params.arch), + num_heads(params.num_heads), + num_kv_heads(params.num_kv_heads), + head_dim(params.head_dim), + qk_norm(params.qk_norm), + layer_idx(layer_idx), + sliding_window_pattern(params.sliding_window_pattern), + rope_theta_global(params.rope_theta_global), + rope_theta_local(params.rope_theta_local), + rope_scaling_factor_global(params.rope_scaling_factor_global) { blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias); blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, false); if (params.qk_norm) { - blocks["q_norm"] = 
std::make_shared(head_dim, params.rms_norm_eps); - blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps); + if (arch == LLMArch::GEMMA3) { + blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps); + blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps); + } else { + blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps); + blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps); + } } } + bool is_gemma_sliding_layer() const { + return arch == LLMArch::GEMMA3 + && sliding_window_pattern > 0 + && ((layer_idx + 1) % sliding_window_pattern) != 0; + } + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* input_pos, - ggml_tensor* attention_mask = nullptr) { + ggml_tensor* attention_mask = nullptr, + ggml_tensor* attention_mask_sliding = nullptr) { // x: [N, n_token, hidden_size] int64_t n_token = x->ne[1]; int64_t N = x->ne[2]; @@ -411,8 +499,8 @@ namespace LLM { v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] if (qk_norm) { - auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); - auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); q = q_norm->forward(ctx, q); k = k_norm->forward(ctx, k); @@ -427,12 +515,28 @@ namespace LLM { } else if (arch == LLMArch::QWEN3) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + } else if (arch == LLMArch::GEMMA3) { + // Per-layer theta: global (full attention) layers use rope_theta_global, + // sliding layers use rope_theta_local. Pattern: is_global = ((l+1)%p == 0). 
+ // Real Gemma 3 12B config also sets linear rope_scaling with factor=8.0 + // on full_attention only. HuggingFace divides inv_freq by factor, which + // ggml_rope_ext expresses as freq_scale = 1 / factor. + bool is_sliding = is_gemma_sliding_layer(); + float theta = is_sliding ? rope_theta_local : rope_theta_global; + float freq_scale = is_sliding ? 1.0f : (1.0f / rope_scaling_factor_global); + q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f); } else { int sections[4] = {16, 24, 24, 0}; q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_multi(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); } + // Gemma 3: pick the sliding-window mask for local layers. 
+ if (is_gemma_sliding_layer() && attention_mask_sliding != nullptr) { + attention_mask = attention_mask_sliding; + } + q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim] q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim] @@ -447,21 +551,58 @@ namespace LLM { }; struct TransformerBlock : public GGMLBlock { + protected: + bool has_post_norms; + public: - TransformerBlock(const LLMParams& params) { - blocks["self_attn"] = std::make_shared(params); - blocks["mlp"] = std::make_shared(params.hidden_size, params.intermediate_size); - blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); - blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + TransformerBlock(const LLMParams& params, int layer_idx = 0) + : has_post_norms(params.has_post_norms) { + bool gemma = (params.arch == LLMArch::GEMMA3); + blocks["self_attn"] = std::make_shared(params, layer_idx); + blocks["mlp"] = std::make_shared(params.hidden_size, params.intermediate_size, false, gemma); + + if (gemma) { + blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["pre_feedforward_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["post_feedforward_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + } else { + blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + } } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* input_pos, - ggml_tensor* attention_mask = nullptr) { + ggml_tensor* attention_mask = nullptr, + ggml_tensor* attention_mask_sliding = nullptr) { // x: 
[N, n_token, hidden_size] - auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); - auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + if (has_post_norms) { + // Gemma 3 sandwich: pre-attn-norm → attn → post-attn-norm → +res + // → pre-ff-norm → mlp → post-ff-norm → +res. + auto input_ln = std::dynamic_pointer_cast(blocks["input_layernorm"]); + auto post_attn_ln = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + auto pre_ff_ln = std::dynamic_pointer_cast(blocks["pre_feedforward_layernorm"]); + auto post_ff_ln = std::dynamic_pointer_cast(blocks["post_feedforward_layernorm"]); + + auto residual = x; + x = input_ln->forward(ctx, x); + x = self_attn->forward(ctx, x, input_pos, attention_mask, attention_mask_sliding); + x = post_attn_ln->forward(ctx, x); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + + residual = x; + x = pre_ff_ln->forward(ctx, x); + x = mlp->forward(ctx, x); + x = post_ff_ln->forward(ctx, x); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + return x; + } + auto input_layernorm = std::dynamic_pointer_cast(blocks["input_layernorm"]); auto post_attention_layernorm = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); @@ -482,15 +623,23 @@ namespace LLM { struct TextModel : public GGMLBlock { protected: int64_t num_layers; + float embed_scale; + bool has_post_norms; public: TextModel(const LLMParams& params) - : num_layers(params.num_layers) { + : num_layers(params.num_layers), + embed_scale(params.embed_scale), + has_post_norms(params.has_post_norms) { blocks["embed_tokens"] = std::shared_ptr(new Embedding(params.vocab_size, params.hidden_size)); for (int i = 0; i < num_layers; i++) { - blocks["layers." + std::to_string(i)] = std::shared_ptr(new TransformerBlock(params)); + blocks["layers." 
+ std::to_string(i)] = std::shared_ptr(new TransformerBlock(params, i)); + } + if (params.arch == LLMArch::GEMMA3) { + blocks["norm"] = std::shared_ptr(new RMSNormPlus1(params.hidden_size, params.rms_norm_eps)); + } else { + blocks["norm"] = std::shared_ptr(new RMSNorm(params.hidden_size, params.rms_norm_eps)); } - blocks["norm"] = std::shared_ptr(new RMSNorm(params.hidden_size, params.rms_norm_eps)); } ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -498,14 +647,22 @@ namespace LLM { ggml_tensor* input_pos, ggml_tensor* attention_mask, std::vector> image_embeds, - std::set out_layers) { + std::set out_layers, + ggml_tensor* attention_mask_sliding = nullptr, + std::vector* all_hidden_states = nullptr) { // input_ids: [N, n_token] // return: [N, n_token, hidden_size] auto embed_tokens = std::dynamic_pointer_cast(blocks["embed_tokens"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto x = embed_tokens->forward(ctx, input_ids); + if (embed_scale != 1.0f) { + x = ggml_scale(ctx->ggml_ctx, x, embed_scale); + } + if (all_hidden_states) { + all_hidden_states->push_back(x); + } std::vector intermediate_outputs; @@ -551,7 +708,10 @@ namespace LLM { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); - x = block->forward(ctx, x, input_pos, attention_mask); + x = block->forward(ctx, x, input_pos, attention_mask, attention_mask_sliding); + if (all_hidden_states) { + all_hidden_states->push_back(x); + } if (out_layers.find(i + 1) != out_layers.end()) { intermediate_outputs.push_back(x); } @@ -565,6 +725,12 @@ namespace LLM { } else { x = norm->forward(ctx, x); } + // HF Gemma 3 (and most HF causal-LM models): hidden_states[-1] is the + // POST-final-norm state. Replace the last pre-norm entry we stored with + // the normed version so downstream stacking matches exactly. 
+ if (all_hidden_states && !all_hidden_states->empty()) { + all_hidden_states->back() = x; + } return x; } }; @@ -599,11 +765,14 @@ namespace LLM { ggml_tensor* input_pos, ggml_tensor* attention_mask, std::vector> image_embeds, - std::set out_layers) { + std::set out_layers, + ggml_tensor* attention_mask_sliding = nullptr, + std::vector* all_hidden_states = nullptr) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); + auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers, + attention_mask_sliding, all_hidden_states); return x; } @@ -652,13 +821,37 @@ namespace LLM { params.qkv_bias = false; params.qk_norm = true; params.rms_norm_eps = 1e-6f; + } else if (arch == LLMArch::GEMMA3) { + // Gemma 3 12B (LTX-2 text encoder). See memory file + // .opencode/memories/2026-04-22_1000_gemma3-delta-note.md for derivation. + params.head_dim = 256; + params.num_heads = 16; + params.num_kv_heads = 8; + params.qkv_bias = false; + params.qk_norm = true; + params.rms_norm_eps = 1e-6f; + params.sliding_window = 1024; + params.sliding_window_pattern = 6; + params.rope_theta_global = 1000000.f; + params.rope_theta_local = 10000.f; + // Real Gemma 3 12B config.json sets rope_scaling={factor: 8.0, + // rope_type: linear} on full_attention layers. HuggingFace divides + // inv_freq by factor, which corresponds to ggml_rope_ext freq_scale + // = 1/factor. Sliding-attention layers stay unscaled. + params.rope_scaling_factor_global = 8.f; + params.has_post_norms = true; + // embed_scale is sqrt(hidden_size); hidden_size is autodetected below, + // so defer setting embed_scale until after the tensor-storage scan. 
} bool have_vision_weight = false; bool llama_cpp_style = false; params.num_layers = 0; for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; - if (tensor_name.find(prefix) == std::string::npos) + // Use prefix-boundary match (must be followed by '.') rather than bare + // substring: otherwise e.g. prefix "text_encoder" would also match + // "text_encoder_deep.*" tensors and inflate auto-detected num_layers. + if (tensor_name.rfind(prefix + ".", 0) != 0) continue; size_t pos = tensor_name.find("visual."); if (pos != std::string::npos) { @@ -686,10 +879,32 @@ namespace LLM { if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) { params.intermediate_size = pair.second.ne[1]; } + if (arch == LLMArch::GEMMA3) { + // Gemma 3 has configurable head_dim (256 for 12B, 32 in our tiny test). + // q_norm.weight has shape [head_dim]; q_proj.weight is [hidden_size, num_heads*head_dim] + // and stored in GGML with ne[1]=num_heads*head_dim; likewise k_proj gives num_kv_heads. + if (contains(tensor_name, "layers.0.self_attn.q_norm.weight")) { + params.head_dim = (int)pair.second.ne[0]; + } + } } if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B params.num_heads = 16; } + if (arch == LLMArch::GEMMA3) { + // Second pass: derive num_heads / num_kv_heads once head_dim is known. 
+ for (auto pair : tensor_storage_map) { + std::string tn = pair.first; + if (tn.rfind(prefix + ".", 0) != 0) continue; + if (contains(tn, "layers.0.self_attn.q_proj.weight") && params.head_dim > 0) { + params.num_heads = (int)(pair.second.ne[1] / params.head_dim); + } + if (contains(tn, "layers.0.self_attn.k_proj.weight") && params.head_dim > 0) { + params.num_kv_heads = (int)(pair.second.ne[1] / params.head_dim); + } + } + params.embed_scale = sqrtf((float)params.hidden_size); + } LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64, params.num_layers, params.vocab_size, @@ -722,8 +937,11 @@ namespace LLM { ggml_tensor* input_pos, ggml_tensor* attention_mask, std::vector> image_embeds, - std::set out_layers) { - auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size] + std::set out_layers, + ggml_tensor* attention_mask_sliding = nullptr, + std::vector* all_hidden_states = nullptr) { + auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers, + attention_mask_sliding, all_hidden_states); // [N, n_token, hidden_size] return hidden_states; } @@ -737,11 +955,15 @@ namespace LLM { return hidden_states; } + // Scratch storage for the Gemma sliding-window mask. 
+ std::vector sliding_attention_mask_vec; + ggml_cgraph* build_graph(const sd::Tensor& input_ids_tensor, const sd::Tensor& attention_mask_tensor, const std::vector>>& image_embeds_tensor, - std::set out_layers) { - ggml_cgraph* gf = ggml_new_graph(compute_ctx); + std::set out_layers, + std::vector* all_hidden_states = nullptr) { + ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE); ggml_tensor* input_ids = make_input(input_ids_tensor); std::vector> image_embeds; image_embeds.reserve(image_embeds_tensor.size()); @@ -751,7 +973,7 @@ namespace LLM { } int64_t n_tokens = input_ids->ne[0]; - if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) { + if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3 || params.arch == LLMArch::GEMMA3) { input_pos_vec.resize(n_tokens); for (int i = 0; i < n_tokens; ++i) { input_pos_vec[i] = i; @@ -789,9 +1011,27 @@ namespace LLM { set_backend_tensor_data(attention_mask, attention_mask_vec.data()); } + // Gemma 3 sliding-window mask: causal AND (q - k < window_size). 
+ ggml_tensor* attention_mask_sliding = nullptr; + if (params.arch == LLMArch::GEMMA3 && params.sliding_window > 0) { + sliding_attention_mask_vec.resize(n_tokens * n_tokens); + for (int q = 0; q < n_tokens; q++) { + for (int k = 0; k < n_tokens; k++) { + float value = 0.f; + if (k > q || (q - k) >= params.sliding_window) { + value = -INFINITY; + } + sliding_attention_mask_vec[q * n_tokens + k] = value; + } + } + attention_mask_sliding = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens); + set_backend_tensor_data(attention_mask_sliding, sliding_attention_mask_vec.data()); + } + auto runner_ctx = get_context(); - ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); + ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers, + attention_mask_sliding, all_hidden_states); ggml_build_forward_expand(gf, hidden_states); @@ -809,6 +1049,35 @@ namespace LLM { return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); } + // Returns all N+1 hidden states (embedding + each transformer layer, with + // final layer's output post-model.norm). Stacked along a new innermost axis, + // shape in GGML: ne=[num_layers+1, hidden_size, n_tokens, batch] which matches + // PyTorch `torch.stack(hidden_states, dim=-1)` layout of [B, T, H, N+1]. + sd::Tensor compute_all_hidden_states(const int n_threads, + const sd::Tensor& input_ids, + const sd::Tensor& attention_mask) { + auto get_graph = [&]() -> ggml_cgraph* { + std::vector hidden_states; + ggml_cgraph* gf = build_graph(input_ids, attention_mask, {}, {}, &hidden_states); + + GGML_ASSERT(!hidden_states.empty()); + // Reshape each [H, T, B] -> [1, H, T, B] so we can concat along axis 0. 
+ ggml_tensor* stacked = nullptr; + for (auto* h : hidden_states) { + auto h_cont = ggml_cont(compute_ctx, h); + auto h_4d = ggml_reshape_4d(compute_ctx, h_cont, 1, h_cont->ne[0], h_cont->ne[1], h_cont->ne[2]); + if (stacked == nullptr) { + stacked = h_4d; + } else { + stacked = ggml_concat(compute_ctx, stacked, h_4d, 0); + } + } + ggml_build_forward_expand(gf, stacked); + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + } + int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) { int64_t grid_t = 1; int64_t grid_h = h / params.vision.patch_size; @@ -989,6 +1258,10 @@ namespace LLM { : model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) { if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { tokenizer = std::make_shared(); + } else if (arch == LLMArch::GEMMA3) { + // Gemma 3 uses SentencePiece (vocab 262208). A SentencePiece loader is + // not yet implemented in this repo; tokenization path lands in task #25. + GGML_ABORT("Gemma 3 SentencePiece tokenizer not implemented yet"); } else { tokenizer = std::make_shared(); } diff --git a/src/ltx.hpp b/src/ltx.hpp new file mode 100644 index 000000000..00995f1c2 --- /dev/null +++ b/src/ltx.hpp @@ -0,0 +1,764 @@ +#ifndef __LTX_HPP__ +#define __LTX_HPP__ + +#include +#include +#include +#include + +#include "ggml_extend.hpp" +#include "ltx_rope.hpp" +#include "model.h" + +// LTX-2 video DiT. +// Reference: /devel/tools/diffusion/LTX-2/packages/ltx-core/src/ltx_core/model/transformer/ +// +// Scope (first landing): text-conditioned video-only (LTXModelType.VideoOnly), rope_type=INTERLEAVED, +// cross_attention_adaln=false, apply_gated_attention=false. Audio pathway and AV cross-attention are +// deferred (stubbed out) — the weights are just not instantiated. + +namespace LTX { + // 32768 was enough for the 2-layer parity-test DiT. 
The 22B V2 has 48 layers + // + cross_attention_adaln + prompt_adaln_single, roughly 2-3× the op count + // per block vs. V1. Bump generously so graph construction never fails the + // `cgraph->n_nodes < cgraph->size` assert in ggml's append path. + constexpr int LTX_GRAPH_SIZE = 131072; + constexpr int TIME_PROJ_DIM = 256; + constexpr int ADALN_BASE = 6; + constexpr int ADALN_WITH_CA = 9; + + // Python: ltx_core.model.transformer.rope.LTXRopeType. Real LTX-2.3 config uses + // SPLIT; earlier LTX variants (and our parity test's old default) were INTERLEAVED. + enum class RopeType { INTERLEAVED, SPLIT }; + + // Parameter-free RMSNorm helper. + __STATIC_INLINE__ ggml_tensor* parameterless_rms_norm(ggml_context* ctx, ggml_tensor* x, float eps = 1e-6f) { + return ggml_rms_norm(ctx, x, eps); + } + + struct AdaLayerNormSingle : public GGMLBlock { + protected: + int embedding_dim; + int embedding_coefficient; + + public: + AdaLayerNormSingle() = default; + AdaLayerNormSingle(int embedding_dim, int embedding_coefficient) + : embedding_dim(embedding_dim), embedding_coefficient(embedding_coefficient) { + // Python: self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim, size_emb_dim=embedding_dim // 3) + // -> time_proj: sinusoidal (no weights) + // -> timestep_embedder.linear_1: Linear(256, embedding_dim) + // -> timestep_embedder.linear_2: Linear(embedding_dim, embedding_dim) + // Python: self.linear = Linear(embedding_dim, coefficient * embedding_dim) + blocks["emb.timestep_embedder.linear_1"] = std::make_shared(TIME_PROJ_DIM, embedding_dim, true); + blocks["emb.timestep_embedder.linear_2"] = std::make_shared(embedding_dim, embedding_dim, true); + blocks["linear"] = std::make_shared(embedding_dim, embedding_coefficient * embedding_dim, true); + } + + // timestep: [B] — caller MUST pass the pre-scaled timestep (σ * timestep_scale_multiplier). 
+ // Python applies the scaling in TransformerArgsPreprocessor._prepare_timestep; we mirror that + // boundary so the denoiser (sigma_to_t) is the single place that owns the 1000× factor. + // Double-scaling (denoiser + AdaLN) would drive sinusoidal embedding args to σ·1e6, which is + // numerical nonsense and was a real risk before this refactor. + // + // Returns {modulation, embedded_timestep}. + // modulation ne: [embedding_dim, coefficient, B] + // embedded_timestep ne: [embedding_dim, B] + std::pair forward(GGMLRunnerContext* ctx, + ggml_tensor* timestep) { + auto l1 = std::dynamic_pointer_cast(blocks["emb.timestep_embedder.linear_1"]); + auto l2 = std::dynamic_pointer_cast(blocks["emb.timestep_embedder.linear_2"]); + auto proj = std::dynamic_pointer_cast(blocks["linear"]); + + auto t_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, TIME_PROJ_DIM, 10000, 1.0f); + auto hidden = l1->forward(ctx, t_proj); + hidden = ggml_silu_inplace(ctx->ggml_ctx, hidden); + auto embedded = l2->forward(ctx, hidden); // [embedding_dim, B] + + auto modulation = ggml_silu(ctx->ggml_ctx, embedded); + modulation = proj->forward(ctx, modulation); // [coeff*embedding_dim, B] + + int64_t B = modulation->ne[1]; + modulation = ggml_reshape_3d(ctx->ggml_ctx, modulation, embedding_dim, embedding_coefficient, B); + return {modulation, embedded}; + } + }; + + // GELUApprox block: Linear(dim_in → dim_out) + gelu(tanh approximation). + // Python: GELUApprox uses torch.nn.functional.gelu(..., approximate="tanh") which matches ggml_gelu. 
+ struct GELUApprox : public GGMLBlock { + public: + GELUApprox() = default; + GELUApprox(int64_t dim_in, int64_t dim_out) { + blocks["proj"] = std::make_shared(dim_in, dim_out, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto proj = std::dynamic_pointer_cast(blocks["proj"]); + x = proj->forward(ctx, x); + return ggml_ext_gelu(ctx->ggml_ctx, x, true); + } + }; + + struct FeedForward : public GGMLBlock { + public: + FeedForward() = default; + FeedForward(int64_t dim, int64_t dim_out, int mult = 4) { + int64_t inner = dim * mult; + // Python: self.net = Sequential(GELUApprox(dim, inner), Identity(), Linear(inner, dim_out)) + blocks["net.0"] = std::make_shared(dim, inner); + blocks["net.2"] = std::make_shared(inner, dim_out, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto gelu_proj = std::dynamic_pointer_cast(blocks["net.0"]); + auto out_proj = std::dynamic_pointer_cast(blocks["net.2"]); + x = gelu_proj->forward(ctx, x); + x = out_proj->forward(ctx, x); + return x; + } + }; + + struct LTXAttention : public GGMLBlock { + protected: + int64_t query_dim; + int64_t context_dim; + int num_heads; + int head_dim; + int64_t inner_dim; + float norm_eps; + bool apply_gated_attention; + RopeType rope_type; + + public: + LTXAttention() = default; + LTXAttention(int64_t query_dim, int64_t context_dim, int num_heads, int head_dim, + bool apply_gated_attention = false, float norm_eps = 1e-6f, + RopeType rope_type = RopeType::SPLIT) + : query_dim(query_dim), context_dim(context_dim), num_heads(num_heads), + head_dim(head_dim), inner_dim(static_cast(num_heads) * head_dim), + norm_eps(norm_eps), apply_gated_attention(apply_gated_attention), + rope_type(rope_type) { + blocks["to_q"] = std::make_shared(query_dim, inner_dim, true); + blocks["to_k"] = std::make_shared(context_dim, inner_dim, true); + blocks["to_v"] = std::make_shared(context_dim, inner_dim, true); + blocks["q_norm"] = std::make_shared(inner_dim, 
norm_eps); + blocks["k_norm"] = std::make_shared(inner_dim, norm_eps); + blocks["to_out.0"] = std::make_shared(inner_dim, query_dim, true); + if (apply_gated_attention) { + blocks["to_gate_logits"] = std::make_shared(query_dim, num_heads, true); + } + } + + // x: [query_dim, L_q, B] + // context: [context_dim, L_kv, B] (defaults to x for self-attn) + // pe: optional packed cos/sin [inner_dim, L_q, 2] (shared between q and k) + // mask: optional additive attention mask + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + ggml_tensor* pe, + ggml_tensor* mask = nullptr) { + if (context == nullptr) { + context = x; + } + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); + auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + auto to_out = std::dynamic_pointer_cast(blocks["to_out.0"]); + + auto q = to_q->forward(ctx, x); // [inner_dim, L_q, B] + auto k = to_k->forward(ctx, context); // [inner_dim, L_kv, B] + auto v = to_v->forward(ctx, context); // [inner_dim, L_kv, B] + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + + if (pe != nullptr) { + if (rope_type == RopeType::SPLIT) { + auto cos_sin = LTXRope::split_pe_split(ctx->ggml_ctx, pe); + q = LTXRope::apply_rotary_emb_split(ctx->ggml_ctx, q, cos_sin.first, cos_sin.second, num_heads); + k = LTXRope::apply_rotary_emb_split(ctx->ggml_ctx, k, cos_sin.first, cos_sin.second, num_heads); + } else { + auto cos_sin = LTXRope::split_pe(ctx->ggml_ctx, pe); + q = LTXRope::apply_rotary_emb_interleaved(ctx->ggml_ctx, q, cos_sin.first, cos_sin.second); + k = LTXRope::apply_rotary_emb_interleaved(ctx->ggml_ctx, k, cos_sin.first, cos_sin.second); + } + } + + auto out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, + num_heads, mask, false, ctx->flash_attn_enabled); + // out: 
[inner_dim, L_q, B] + + if (apply_gated_attention) { + auto gate_proj = std::dynamic_pointer_cast(blocks["to_gate_logits"]); + auto gate_logits = gate_proj->forward(ctx, x); // [num_heads, L_q, B] + auto gates = ggml_sigmoid(ctx->ggml_ctx, gate_logits); + gates = ggml_scale(ctx->ggml_ctx, gates, 2.f); + // out is [inner_dim, L_q, B]; reshape to [head_dim, num_heads, L_q, B], multiply gates as [1, num_heads, L_q, B] broadcast. + int64_t L_q = out->ne[1]; + int64_t B = out->ne[2]; + auto out4 = ggml_reshape_4d(ctx->ggml_ctx, out, head_dim, num_heads, L_q, B); + auto g4 = ggml_reshape_4d(ctx->ggml_ctx, gates, 1, num_heads, L_q, B); + out4 = ggml_mul(ctx->ggml_ctx, out4, g4); + out = ggml_reshape_3d(ctx->ggml_ctx, out4, inner_dim, L_q, B); + } + + out = to_out->forward(ctx, out); // [query_dim, L_q, B] + return out; + } + }; + + // PixArtAlphaTextProjection — caption_projection inside the DiT. + // Python: ltx_core/model/transformer/text_projection.py. + // linear_1 (caption_channels → hidden) → GELU(tanh) → linear_2 (hidden → out). + // Used in V1 / 19B to bring the connector's 3840-dim output up to the DiT's + // 4096-dim inner space. In config the `caption_proj_before_connector` flag + // distinguishes V1 (True, used here) from V2 (False, handled separately). + struct PixArtAlphaTextProjection : public GGMLBlock { + protected: + int64_t in_features; + int64_t hidden_size; + int64_t out_features; + + public: + PixArtAlphaTextProjection() = default; + PixArtAlphaTextProjection(int64_t in_features, int64_t hidden_size, int64_t out_features = 0) + : in_features(in_features), hidden_size(hidden_size), + out_features(out_features == 0 ? 
hidden_size : out_features) { + blocks["linear_1"] = std::make_shared(in_features, hidden_size, true); + blocks["linear_2"] = std::make_shared(hidden_size, this->out_features, true); + } + + int64_t get_in_features() const { return in_features; } + int64_t get_out_features() const { return out_features; } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto l1 = std::dynamic_pointer_cast(blocks["linear_1"]); + auto l2 = std::dynamic_pointer_cast(blocks["linear_2"]); + x = l1->forward(ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, /*approximate_tanh=*/true); + x = l2->forward(ctx, x); + return x; + } + }; + + struct LTXTransformerBlock : public GGMLBlock { + protected: + int64_t dim; + int num_heads; + int head_dim; + int64_t context_dim; + bool cross_attention_adaln; + bool apply_gated_attention; + float norm_eps; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string prefix = "") override { + int num_params = cross_attention_adaln ? ADALN_WITH_CA : ADALN_BASE; + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, num_params); + if (cross_attention_adaln) { + params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 2); + } + } + + public: + LTXTransformerBlock() = default; + LTXTransformerBlock(int64_t dim, int num_heads, int head_dim, int64_t context_dim, + bool cross_attention_adaln = false, bool apply_gated_attention = false, + float norm_eps = 1e-6f, + RopeType rope_type = RopeType::SPLIT) + : dim(dim), num_heads(num_heads), head_dim(head_dim), context_dim(context_dim), + cross_attention_adaln(cross_attention_adaln), + apply_gated_attention(apply_gated_attention), norm_eps(norm_eps) { + blocks["attn1"] = std::make_shared(dim, dim, num_heads, head_dim, apply_gated_attention, norm_eps, rope_type); + blocks["attn2"] = std::make_shared(dim, context_dim, num_heads, head_dim, apply_gated_attention, norm_eps, rope_type); + blocks["ff"] = std::make_shared(dim, dim); + } + 
+ // Helper — returns a triple (a, b, c) from scale_shift_table[start:start+3] + modulation[:, start:start+3, :] + // scale_shift_table: ne [dim, num_params] + // modulation: ne [dim, num_params, B] + // Returns three tensors each ne [dim, 1, B]. + std::tuple extract_triple(ggml_context* ctx, + ggml_tensor* sst, + ggml_tensor* modulation, + int start) { + int64_t B = modulation->ne[2]; + + // Slice scale_shift_table rows [start, start+3). + auto sst_slice = ggml_ext_slice(ctx, sst, 1, start, start + 3); // ne [dim, 3] + + // Slice modulation along dim 1 [start, start+3). + auto mod_slice = ggml_ext_slice(ctx, modulation, 1, start, start + 3); // ne [dim, 3, B] + + // Broadcast add: sst_slice [dim, 3] + mod_slice [dim, 3, B] → [dim, 3, B]. + auto combined = ggml_add(ctx, mod_slice, sst_slice); + + auto chunks = ggml_ext_chunk(ctx, combined, 3, 1); + // Each chunk ne [dim, 1, B] + return std::make_tuple(chunks[0], chunks[1], chunks[2]); + } + + // Extract (shift, scale) from prompt_scale_shift_table [dim, 2] + prompt_modulation [dim, 2, B]. + // Python: `(prompt_scale_shift_table[None, None] + prompt_timestep.reshape(...,2,-1)).unbind(2)`. 
+ std::pair extract_kv_pair(ggml_context* ctx, + ggml_tensor* psst, + ggml_tensor* prompt_mod) { + auto combined = ggml_add(ctx, prompt_mod, psst); // [dim, 2, B] + auto chunks = ggml_ext_chunk(ctx, combined, 2, 1); + return {chunks[0], chunks[1]}; // (shift_kv, scale_kv), each [dim, 1, B] + } + + // x: [dim, L_q, B] + // context: [context_dim, L_kv, B] + // modulation: [dim, num_params, B] (num_params = 6 for V1, 9 for V2) + // pe: packed cos/sin tensor [dim, L_q, 2] + // prompt_modulation: [dim, 2, B] — required when cross_attention_adaln=true, else nullptr + // context_mask: [L_kv, L_q, 1, B] additive mask (or nullptr) + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + ggml_tensor* modulation, + ggml_tensor* pe, + ggml_tensor* prompt_modulation = nullptr, + ggml_tensor* context_mask = nullptr) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + auto sst = params["scale_shift_table"]; + + // --- Self-attention path (modulation slice 0:3 → shift, scale, gate) --- + auto triple1 = extract_triple(ctx->ggml_ctx, sst, modulation, 0); + auto shift_msa = std::get<0>(triple1); + auto scale_msa = std::get<1>(triple1); + auto gate_msa = std::get<2>(triple1); + + auto norm_x = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto scaled = ggml_add(ctx->ggml_ctx, norm_x, ggml_mul(ctx->ggml_ctx, norm_x, scale_msa)); + auto modulated = ggml_add(ctx->ggml_ctx, scaled, shift_msa); + auto attn_out = attn1->forward(ctx, modulated, nullptr, pe, nullptr); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + + // --- Cross-attention --- + // V1 (cross_attention_adaln=false): plain rms_norm → attn2 → residual. 
+ // V2 (cross_attention_adaln=true): + // modulation[6:9] → (q_shift, q_scale, q_gate) for the query path + // prompt_scale_shift_table + prompt_modulation → (kv_shift, kv_scale) for the context + // attn_input = rms_norm(x) * (1 + q_scale) + q_shift + // context_mod = context * (1 + kv_scale) + kv_shift + // x = x + attn2(attn_input, context_mod) * q_gate + if (cross_attention_adaln) { + GGML_ASSERT(prompt_modulation != nullptr && "cross_attention_adaln requires prompt_modulation"); + auto triple_ca = extract_triple(ctx->ggml_ctx, sst, modulation, 6); + auto shift_q = std::get<0>(triple_ca); + auto scale_q = std::get<1>(triple_ca); + auto gate_q = std::get<2>(triple_ca); + + auto psst = params["prompt_scale_shift_table"]; // [dim, 2] + auto kv_pair = extract_kv_pair(ctx->ggml_ctx, psst, prompt_modulation); + auto shift_kv = kv_pair.first; + auto scale_kv = kv_pair.second; + + auto norm_x_ca = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto q_scaled = ggml_add(ctx->ggml_ctx, norm_x_ca, ggml_mul(ctx->ggml_ctx, norm_x_ca, scale_q)); + auto q_modulated = ggml_add(ctx->ggml_ctx, q_scaled, shift_q); + auto ctx_scaled = ggml_add(ctx->ggml_ctx, context, ggml_mul(ctx->ggml_ctx, context, scale_kv)); + auto ctx_modulated = ggml_add(ctx->ggml_ctx, ctx_scaled, shift_kv); + + auto ca_out = attn2->forward(ctx, q_modulated, ctx_modulated, nullptr, context_mask); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, ca_out, gate_q)); + } else { + auto norm_x_ca = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto ca_out = attn2->forward(ctx, norm_x_ca, context, nullptr, context_mask); + x = ggml_add(ctx->ggml_ctx, x, ca_out); + } + + // --- FeedForward path (modulation slice 3:6 → shift, scale, gate) --- + auto triple2 = extract_triple(ctx->ggml_ctx, sst, modulation, 3); + auto shift_mlp = std::get<0>(triple2); + auto scale_mlp = std::get<1>(triple2); + auto gate_mlp = std::get<2>(triple2); + + auto norm_x2 = parameterless_rms_norm(ctx->ggml_ctx, x, 
norm_eps); + auto scaled_mlp = ggml_add(ctx->ggml_ctx, norm_x2, ggml_mul(ctx->ggml_ctx, norm_x2, scale_mlp)); + auto modulated_mlp = ggml_add(ctx->ggml_ctx, scaled_mlp, shift_mlp); + auto ff_out = ff->forward(ctx, modulated_mlp); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, ff_out, gate_mlp)); + + return x; + } + }; + + struct LTXParams { + int64_t in_channels = 128; + int64_t out_channels = 128; + int64_t inner_dim = 4096; + int num_heads = 32; + int head_dim = 128; + int num_layers = 48; + int64_t cross_attention_dim = 4096; + bool cross_attention_adaln = false; + bool apply_gated_attention = false; + float norm_eps = 1e-6f; + float positional_embedding_theta = 10000.f; + std::vector positional_embedding_max_pos = {20, 2048, 2048}; + float timestep_scale_multiplier = 1000.f; + bool use_middle_indices_grid = true; + RopeType rope_type = RopeType::SPLIT; // real LTX-2.3 default + // Optional caption_projection sitting on the DiT side (V1 / 19B); absent for + // tiny parity tests that feed context in DiT inner_dim already. When enabled, + // `caption_channels` is the input dim (connector output) and `caption_hidden` + // / `caption_out` follow the PixArtAlphaTextProjection defaults. + bool has_caption_projection = false; + int64_t caption_channels = 0; + int64_t caption_hidden = 0; + int64_t caption_out = 0; + }; + + struct LTXModel : public GGMLBlock { + LTXParams p; + + protected: + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string prefix = "") override { + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, p.inner_dim, 2); + } + + public: + LTXModel() = default; + LTXModel(LTXParams p) : p(p) { + blocks["patchify_proj"] = std::make_shared(p.in_channels, p.inner_dim, true); + int coeff = p.cross_attention_adaln ? 
ADALN_WITH_CA : ADALN_BASE; + blocks["adaln_single"] = std::make_shared(p.inner_dim, coeff); + blocks["proj_out"] = std::make_shared(p.inner_dim, p.out_channels, true); + + // V2: a second AdaLayerNormSingle that generates modulation for the + // context path inside cross-attention. Python: + // `prompt_adaln_single = AdaLayerNormSingle(inner_dim, embedding_coefficient=2)`. + if (p.cross_attention_adaln) { + blocks["prompt_adaln_single"] = std::make_shared(p.inner_dim, 2); + } + + for (int i = 0; i < p.num_layers; ++i) { + blocks["transformer_blocks." + std::to_string(i)] = + std::make_shared(p.inner_dim, p.num_heads, p.head_dim, + p.cross_attention_dim, + p.cross_attention_adaln, + p.apply_gated_attention, + p.norm_eps, + p.rope_type); + } + + if (p.has_caption_projection) { + blocks["caption_projection"] = std::make_shared( + p.caption_channels, p.caption_hidden, p.caption_out); + } + } + + // latent: ne [in_channels, T*H*W, B] (already patchified by caller) + // timestep: ne [B] + // context: ne [cross_attention_dim, S, B] + // pe: ne [inner_dim, T*H*W, 2] (interleaved cos/sin) + // context_mask: ne [S, T*H*W, 1, B] or nullptr + // Returns: ne [out_channels, T*H*W, B] + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* timestep, + ggml_tensor* context, + ggml_tensor* pe, + ggml_tensor* context_mask = nullptr) { + auto patchify_proj = std::dynamic_pointer_cast(blocks["patchify_proj"]); + auto adaln_single = std::dynamic_pointer_cast(blocks["adaln_single"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + + // Apply caption_projection (V1 / 19B) to lift context from connector dim + // to DiT inner_dim. Python: TransformerArgs._prepare_context. 
+ if (p.has_caption_projection && context != nullptr) { + auto caption_proj = std::dynamic_pointer_cast(blocks["caption_projection"]); + context = caption_proj->forward(ctx, context); + } + + auto x = patchify_proj->forward(ctx, latent); // [inner_dim, T*H*W, B] + + // Caller must feed the already-scaled timestep (σ * 1000). The LTX2 denoiser's sigma_to_t + // is the single source of truth for that scaling — see LTXParams::timestep_scale_multiplier + // which is kept as documentation/config only, not applied here. + auto adaln_res = adaln_single->forward(ctx, timestep); + auto modulation = adaln_res.first; // [inner_dim, coeff, B] (coeff = 6 or 9) + auto embedded_t = adaln_res.second; // [inner_dim, B] + + // V2: prompt_adaln_single takes the same σ (raw timestep before AdaLN-scaling) + // and emits a [inner_dim, 2, B] modulation that's shared across all blocks' + // cross-attention kv path. In Python video_args_preprocessor passes + // `modality.sigma`; for our single-prompt inference sigma == timestep. We reuse + // the same timestep tensor here. + ggml_tensor* prompt_modulation = nullptr; + if (p.cross_attention_adaln) { + auto prompt_adaln = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); + auto prompt_res = prompt_adaln->forward(ctx, timestep); + prompt_modulation = prompt_res.first; // [inner_dim, 2, B] + } + + for (int i = 0; i < p.num_layers; ++i) { + auto block = std::dynamic_pointer_cast( + blocks["transformer_blocks." + std::to_string(i)]); + x = block->forward(ctx, x, context, modulation, pe, prompt_modulation, context_mask); + } + + // Output modulation: python has `sst[None,None] + embedded[:,:,None]` giving (B, 1, 2, dim). + // In ggml ne that's [dim, 2, 1, B]. For B>1 we'd need to broadcast sst over B explicitly; + // current parity test uses B=1 so we pick the direct add path here and rely on ggml's + // ggml_can_repeat(b, a) — `a` must be >= `b` in every dim so we put sst first. 
+ // sst ne: [inner_dim, 2, 1, 1] + // embedded: [inner_dim, 1, 1, B] (after reshape_4d from [inner_dim, B]) + // sum: [inner_dim, 2, 1, B] (provided B == 1; see TODO for B>1) + int64_t B = x->ne[2]; + GGML_ASSERT(B == 1 && "LTXModel output modulation currently assumes batch=1"); + auto sst = params["scale_shift_table"]; // ne [inner_dim, 2] + auto emb_view = ggml_reshape_4d(ctx->ggml_ctx, embedded_t, p.inner_dim, 1, 1, B); // ne [inner_dim, 1, 1, B] + auto ss_sum = ggml_add(ctx->ggml_ctx, sst, emb_view); // ne [inner_dim, 2, 1, 1] + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, ss_sum, 2, 1); // 2× ne [inner_dim, 1, 1, 1] + auto shift = ggml_reshape_3d(ctx->ggml_ctx, chunks[0], p.inner_dim, 1, 1); // ne [inner_dim, 1, 1] + auto scale = ggml_reshape_3d(ctx->ggml_ctx, chunks[1], p.inner_dim, 1, 1); // ne [inner_dim, 1, 1] + + x = ggml_ext_layer_norm(ctx->ggml_ctx, x, nullptr, nullptr, p.norm_eps); // param-less LN + + // x ne: [inner_dim, T, 1]; scale/shift ne: [inner_dim, 1, 1] — second arg broadcasts ok. + x = ggml_add(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x, scale)), shift); + x = proj_out->forward(ctx, x); // [out_channels, T*H*W, B] + return x; + } + }; + + struct LTXRunner : public GGMLRunner { + public: + LTXParams ltx_params; + LTXModel ltx; + std::vector pe_vec; + SDVersion version; + // fps used for temporal RoPE normalisation — see LTXRope::gen_video_positions. + // Defaults to 24 (LTX-2's canonical output fps); callers can override before compute(). + float fps = 24.0f; + // VAE spatiotemporal compression factors (time, height, width) applied to latent + // coordinates to reconstruct the pixel-space positions used for RoPE. Defaults match + // the LTX-2 22B VAE: 8× temporal, 32× spatial. The parity tests feed the Python model + // simplified positions (f/fps, h, w) — set scale_factors={1,1,1} and causal_fix=false + // in that path to keep parity assertions valid. 
+ std::vector scale_factors = {8, 32, 32}; + bool causal_fix = true; + + void set_fps(float new_fps) { fps = new_fps; } + void set_scale_factors(int time, int height, int width) { + scale_factors = {time, height, width}; + } + void set_causal_fix(bool enable) { causal_fix = enable; } + + // params_override forces the given LTXParams instead of auto-detecting from the tensor map. + // Useful for parity tests and for cases where metadata pins the head_dim / num_heads to + // values that can't be inferred from weight shapes alone (q_norm etc. are inner_dim-wide). + LTXRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_LTX2, + const LTXParams* params_override = nullptr) + : GGMLRunner(backend, offload_params_to_cpu), version(version) { + if (params_override != nullptr) { + ltx_params = *params_override; + } else { + detect_params(tensor_storage_map, prefix); + } + ltx = LTXModel(ltx_params); + ltx.init(params_ctx, tensor_storage_map, prefix); + } + + void detect_params(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + std::string pre = prefix.empty() ? "" : prefix + "."; + + auto patchify_it = tensor_storage_map.find(pre + "patchify_proj.weight"); + if (patchify_it != tensor_storage_map.end()) { + const auto& ts = patchify_it->second; + if (ts.n_dims >= 2) { + ltx_params.in_channels = ts.ne[0]; + ltx_params.inner_dim = ts.ne[1]; + } + } + + auto proj_out_it = tensor_storage_map.find(pre + "proj_out.weight"); + if (proj_out_it != tensor_storage_map.end()) { + const auto& ts = proj_out_it->second; + if (ts.n_dims >= 2) { + ltx_params.out_channels = ts.ne[1]; + } + } + + // Infer num_layers from highest transformer_blocks index. 
+ int max_layer = -1; + std::string block_prefix = pre + "transformer_blocks."; + for (auto& pair : tensor_storage_map) { + const std::string& name = pair.first; + if (name.rfind(block_prefix, 0) != 0) { + continue; + } + size_t start = block_prefix.size(); + size_t end = name.find('.', start); + if (end == std::string::npos) { + continue; + } + try { + int idx = std::stoi(name.substr(start, end - start)); + max_layer = std::max(max_layer, idx); + } catch (...) { + } + } + if (max_layer >= 0) { + ltx_params.num_layers = max_layer + 1; + } + + // Detect cross_attention_adaln from the size of scale_shift_table (9 if CA-AdaLN, 6 otherwise). + auto sst_it = tensor_storage_map.find(pre + "transformer_blocks.0.scale_shift_table"); + if (sst_it != tensor_storage_map.end()) { + const auto& ts = sst_it->second; + if (ts.n_dims >= 2 && ts.ne[1] == ADALN_WITH_CA) { + ltx_params.cross_attention_adaln = true; + } + } + + // Infer head_dim × num_heads from attn1.to_q.weight shape. + auto q_it = tensor_storage_map.find(pre + "transformer_blocks.0.attn1.to_q.weight"); + if (q_it != tensor_storage_map.end()) { + const auto& ts = q_it->second; + if (ts.n_dims >= 2) { + ltx_params.inner_dim = ts.ne[1]; + } + } + // head_dim is a fixed LTX-2 hyperparam (128) unless a config tensor overrides. + ltx_params.head_dim = 128; + ltx_params.num_heads = static_cast(ltx_params.inner_dim / ltx_params.head_dim); + + // Infer cross_attention_dim from attn2.to_k weight shape. + auto k_it = tensor_storage_map.find(pre + "transformer_blocks.0.attn2.to_k.weight"); + if (k_it != tensor_storage_map.end()) { + const auto& ts = k_it->second; + if (ts.n_dims >= 2) { + ltx_params.cross_attention_dim = ts.ne[0]; + } + } + + // Detect gated attention from presence of to_gate_logits. 
+ auto gate_it = tensor_storage_map.find(pre + "transformer_blocks.0.attn1.to_gate_logits.weight"); + if (gate_it != tensor_storage_map.end()) { + ltx_params.apply_gated_attention = true; + } + + // Detect optional caption_projection (V1 / 19B). + // linear_1 weight shape [in_features, hidden_size]; linear_2 shape [hidden_size, out_features]. + // (ggml ne[0] = innermost dim = PyTorch's in_features / hidden_size.) + auto cap1_it = tensor_storage_map.find(pre + "caption_projection.linear_1.weight"); + auto cap2_it = tensor_storage_map.find(pre + "caption_projection.linear_2.weight"); + if (cap1_it != tensor_storage_map.end() && cap2_it != tensor_storage_map.end()) { + const auto& l1 = cap1_it->second; + const auto& l2 = cap2_it->second; + if (l1.n_dims >= 2 && l2.n_dims >= 2) { + ltx_params.has_caption_projection = true; + ltx_params.caption_channels = l1.ne[0]; + ltx_params.caption_hidden = l1.ne[1]; + ltx_params.caption_out = l2.ne[1]; + } + } + } + + std::string get_desc() override { + return "ltx2"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + ltx.get_param_tensors(tensors, prefix); + } + + // Build the diffusion graph. + // x_tensor layout (ggml ne order): [W, H, T, in_channels] — follows the Wan / video convention with implicit batch N=1. + // timesteps: ne [N] + // context: ne [cross_attention_dim, S, N] + // context_mask: empty (not yet wired through) + ggml_cgraph* build_graph(const sd::Tensor& x_tensor, + const sd::Tensor& timesteps_tensor, + const sd::Tensor& context_tensor, + const sd::Tensor& context_mask_tensor) { + ggml_cgraph* gf = new_graph_custom(LTX_GRAPH_SIZE); + + ggml_tensor* x = make_input(x_tensor); + ggml_tensor* timesteps = make_input(timesteps_tensor); + ggml_tensor* context = make_input(context_tensor); + ggml_tensor* ctx_mask = context_mask_tensor.empty() ? 
nullptr : make_input(context_mask_tensor); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t T = x->ne[2]; + int64_t C = x->ne[3]; + + LOG_DEBUG("LTX build_graph: x=[%lld,%lld,%lld,%lld] timesteps=[%lld] context=[%lld,%lld,%lld] inner_dim=%lld cross_attn_dim=%lld has_cap_proj=%d ca_adaln=%d gated=%d", + (long long)x->ne[0], (long long)x->ne[1], (long long)x->ne[2], (long long)x->ne[3], + (long long)timesteps->ne[0], + (long long)context->ne[0], (long long)context->ne[1], (long long)context->ne[2], + (long long)ltx_params.inner_dim, (long long)ltx_params.cross_attention_dim, + ltx_params.has_caption_projection ? 1 : 0, + ltx_params.cross_attention_adaln ? 1 : 0, + ltx_params.apply_gated_attention ? 1 : 0); + + // Flatten spatiotemporal dims into a sequence and move channels to ne[0]. + auto latent = ggml_reshape_3d(compute_ctx, x, W * H * T, C, 1); // [W*H*T, C, 1] + latent = ggml_cont(compute_ctx, ggml_permute(compute_ctx, latent, 1, 0, 2, 3)); // [C, W*H*T, 1] + + auto positions = LTXRope::gen_video_positions(static_cast(T), static_cast(H), static_cast(W), + ltx_params.use_middle_indices_grid, fps, + scale_factors, causal_fix); + ggml_tensor* pe = nullptr; + if (ltx_params.rope_type == RopeType::SPLIT) { + pe_vec = LTXRope::precompute_freqs_cis_split(positions, + static_cast(ltx_params.inner_dim), + ltx_params.num_heads, + ltx_params.positional_embedding_theta, + ltx_params.positional_embedding_max_pos); + // Split layout ne: [head_dim/2, num_heads, T*H*W, 2]. 
+ int64_t half = ltx_params.inner_dim / 2; + int64_t per_head_half = half / ltx_params.num_heads; + pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, + per_head_half, ltx_params.num_heads, T * H * W, 2); + } else { + pe_vec = LTXRope::precompute_freqs_cis_interleaved(positions, + static_cast(ltx_params.inner_dim), + ltx_params.positional_embedding_theta, + ltx_params.positional_embedding_max_pos); + pe = ggml_new_tensor_3d(compute_ctx, GGML_TYPE_F32, ltx_params.inner_dim, T * H * W, 2); + } + set_backend_tensor_data(pe, pe_vec.data()); + + auto runner_ctx = get_context(); + ggml_tensor* out = ltx.forward(&runner_ctx, latent, timesteps, context, pe, ctx_mask); + + // out: [out_channels, T*H*W, 1] → [W, H, T, out_channels] to match Wan-style output. + out = ggml_cont(compute_ctx, ggml_permute(compute_ctx, out, 1, 0, 2, 3)); // [T*H*W, out_channels, 1] + out = ggml_reshape_4d(compute_ctx, out, W, H, T, ltx_params.out_channels); + + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context, + const sd::Tensor& context_mask) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x, timesteps, context, context_mask); + }; + return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + } + }; + +} // namespace LTX + +#endif // __LTX_HPP__ diff --git a/src/ltx_connector.hpp b/src/ltx_connector.hpp new file mode 100644 index 000000000..3641f9f21 --- /dev/null +++ b/src/ltx_connector.hpp @@ -0,0 +1,623 @@ +#ifndef __LTX_CONNECTOR_HPP__ +#define __LTX_CONNECTOR_HPP__ + +#include +#include +#include +#include + +#include "ggml_extend.hpp" +#include "ltx.hpp" +#include "ltx_rope.hpp" +#include "model.h" + +// 1D position generator for the connector's RoPE (n_pos_dims=1, max_pos=[1], +// positions[t] = t). Lives here so it sits next to its only caller, but stays +// in the LTXRope namespace. 
+namespace LTXRope { + __STATIC_INLINE__ std::vector> gen_1d_positions(int T) { + std::vector> pos(1, std::vector(T, 0.f)); + for (int t = 0; t < T; ++t) pos[0][t] = static_cast(t); + return pos; + } +} // namespace LTXRope + +// LTX-2 text connector (Phase 9.1, V1 / 19B). +// +// Python reference: +// ltx_core/text_encoders/gemma/feature_extractor.py (FeatureExtractorV1) +// ltx_core/text_encoders/gemma/embeddings_connector.py (Embeddings1DConnector) +// ltx_core/model/transformer/text_projection.py (PixArtAlphaTextProjection) +// +// Pipeline (Gemma 49-layer stack → DiT cross-attention context): +// stacked[B, T, D, L] → feature_extractor_normalize() (CPU, per-(B,L) masked +// mean/range → normed[B, T, D*L]) +// normed[B, T, D*L] → FeatureExtractorV1::forward (aggregate_embed Linear) +// → video_features[B, T, inner_dim] +// video_features → Embeddings1DConnector::forward (2× BasicTransformerBlock1D +// + final rms_norm) → [B, T, inner_dim] +// connector_out → PixArtAlphaTextProjection::forward (linear, gelu_tanh, +// linear) → [B, T, caption_out_dim] (= DiT inner_dim) + +namespace LTXConnector { + + // Compute FeatureExtractorV1's _norm_and_concat_padded_batch on the CPU. + // Python reference: _norm_and_concat_padded_batch in feature_extractor.py. + // + // Input: + // stacked: [B*T*D*L] contiguous, logical shape [B, T, D, L] + // seq_lengths: [B] — valid (non-pad) token count per batch + // padding_side: "left" or "right" + // Output: + // normed: [B*T*(D*L)] contiguous, logical shape [B, T, D*L] + // + // Padded positions (outside [0, seq_len) for "right", outside [T - seq_len, T) for "left") + // are zero'd after the normalization. 
+ __STATIC_INLINE__ void feature_extractor_normalize(const float* stacked, + const int* seq_lengths, + float* normed, + int B, int T, int D, int L, + const std::string& padding_side = "left", + float eps = 1e-6f) { + const float FINF = std::numeric_limits::infinity(); + const float NINF = -FINF; + const bool is_left = (padding_side == "left"); + + for (int b = 0; b < B; ++b) { + int seq_len = seq_lengths[b]; + int valid_start = is_left ? (T - seq_len) : 0; + int valid_end = is_left ? T : seq_len; + + for (int l = 0; l < L; ++l) { + // Compute per-(b,l) masked mean, min, max over (t, d) where mask == 1. + double sum = 0.0; + float vmin = FINF; + float vmax = NINF; + for (int t = valid_start; t < valid_end; ++t) { + for (int d = 0; d < D; ++d) { + // Python layout: encoded[b, t, d, l] + // Flat index with ne [L, D, T, B] order would be ((b*T + t)*D + d)*L + l. + int64_t idx = ((static_cast(b) * T + t) * D + d) * L + l; + float v = stacked[idx]; + sum += v; + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + } + double denom = static_cast(seq_len) * D; + float mean = static_cast(sum / (denom + eps)); + float range = vmax - vmin; + float inv = 8.0f / (range + eps); + + // Apply normalization over all T positions; zero out padded ones. + for (int t = 0; t < T; ++t) { + bool in_valid = (t >= valid_start && t < valid_end); + for (int d = 0; d < D; ++d) { + int64_t src_idx = ((static_cast(b) * T + t) * D + d) * L + l; + // normed layout: [B, T, D*L] with flat index (b*T + t)*(D*L) + (d*L + l). + int64_t dst_idx = (static_cast(b) * T + t) * (D * L) + (d * L + l); + if (in_valid) { + normed[dst_idx] = (stacked[src_idx] - mean) * inv; + } else { + normed[dst_idx] = 0.0f; + } + } + } + } + } + } + + // Per-token RMSNorm used by FeatureExtractorV2 (22B / V2 text path). Mirrors + // norm_and_concat_per_token_rms in Python feature_extractor.py. 
+ // + // Input layout (ggml ne order, matches llm->compute_all_hidden_states): + // stacked[l + L*(d + D*(t + T*b))] — logical shape [B, T, D, L] + // + // Output layout (ggml ne order): + // normed[k + (D*L)*(t + T*b)] — logical shape [B, T, D*L] with k = d*L + l + // + // Per-(B, T, L) variance is computed over D; every entry is scaled by the + // corresponding rsqrt(var + eps). Padded positions (per `attention_mask`) get + // zeroed out post-reshape, matching Python's `torch.where(mask_3d, normed, 0)`. + // + // The result is NOT yet rescaled by sqrt(target/source) — that's applied as a + // `ggml_scale` in the graph immediately before the aggregate_embed Linear so + // video and audio branches (with different target dims) can share this buffer. + __STATIC_INLINE__ void feature_extractor_normalize_v2(const float* stacked, + const int* seq_lengths, + float* normed, + int B, int T, int D, int L, + const std::string& padding_side = "left", + float eps = 1e-6f) { + const bool is_left = (padding_side == "left"); + for (int b = 0; b < B; ++b) { + int seq_len = seq_lengths[b]; + int valid_start = is_left ? (T - seq_len) : 0; + int valid_end = is_left ? T : seq_len; + + for (int t = 0; t < T; ++t) { + bool in_valid = (t >= valid_start && t < valid_end); + // Per-layer rsqrt factor for the (b, t, *, l) row. 
+ for (int l = 0; l < L; ++l) { + double sum_sq = 0.0; + for (int d = 0; d < D; ++d) { + int64_t idx = ((static_cast(b) * T + t) * D + d) * L + l; + double v = stacked[idx]; + sum_sq += v * v; + } + double variance = sum_sq / static_cast(D); + float rsq = static_cast(1.0 / std::sqrt(variance + eps)); + + for (int d = 0; d < D; ++d) { + int64_t src_idx = ((static_cast(b) * T + t) * D + d) * L + l; + int64_t dst_idx = (static_cast(b) * T + t) * (D * L) + (d * L + l); + if (in_valid) { + normed[dst_idx] = stacked[src_idx] * rsq; + } else { + normed[dst_idx] = 0.0f; + } + } + } + } + } + } + + // FeatureExtractorV1 block — just wraps the aggregate_embed Linear + // (feature_extractor.aggregate_embed.weight). + // + // The CPU-side normalization lives in feature_extractor_normalize(); this block + // expects an already-normalized [B, T, D*L] tensor as input. + struct FeatureExtractorV1 : public GGMLBlock { + protected: + int64_t flat_dim; + int64_t inner_dim; + + public: + FeatureExtractorV1() = default; + FeatureExtractorV1(int64_t flat_dim, int64_t inner_dim) + : flat_dim(flat_dim), inner_dim(inner_dim) { + // Python: aggregate_embed = Linear(flat_dim, inner_dim, bias=False). + blocks["aggregate_embed"] = std::make_shared(flat_dim, inner_dim, /*bias=*/false); + } + + // x: ne [flat_dim, T, B] (already normalized via feature_extractor_normalize). + // returns: ne [inner_dim, T, B]. + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto agg = std::dynamic_pointer_cast(blocks["aggregate_embed"]); + return agg->forward(ctx, x); + } + }; + + // FeatureExtractorV2 block — V2 / 22B text path. Two parallel Linears + // (video + optional audio) WITH bias on a per-token RMS-normalized input. + // Python: ltx_core/text_encoders/gemma/feature_extractor.py::FeatureExtractorV2. 
+ // + // The CPU-side normalization lives in feature_extractor_normalize_v2(); this + // block applies the in-graph rescale factor sqrt(target/source_dim) and the + // video_aggregate_embed Linear. Audio path is declared optional — if audio + // weights are absent the block skips it and is video-only. + struct FeatureExtractorV2 : public GGMLBlock { + protected: + int64_t flat_dim; // D * L (Gemma hidden × num_layers) + int64_t source_dim; // Gemma hidden size (D) + int64_t video_out_dim; // DiT inner_dim + int64_t audio_out_dim; // optional; 0 when no audio aggregate_embed + float video_scale; // sqrt(video_out_dim / source_dim) + float audio_scale; // sqrt(audio_out_dim / source_dim) + + public: + FeatureExtractorV2() = default; + FeatureExtractorV2(int64_t flat_dim, int64_t source_dim, + int64_t video_out_dim, + int64_t audio_out_dim = 0) + : flat_dim(flat_dim), source_dim(source_dim), + video_out_dim(video_out_dim), audio_out_dim(audio_out_dim) { + video_scale = std::sqrt(static_cast(video_out_dim) / static_cast(source_dim)); + audio_scale = audio_out_dim > 0 + ? std::sqrt(static_cast(audio_out_dim) / static_cast(source_dim)) + : 0.f; + blocks["video_aggregate_embed"] = std::make_shared(flat_dim, video_out_dim, /*bias=*/true); + if (audio_out_dim > 0) { + blocks["audio_aggregate_embed"] = std::make_shared(flat_dim, audio_out_dim, /*bias=*/true); + } + } + + bool has_audio() const { return audio_out_dim > 0; } + int64_t get_video_out_dim() const { return video_out_dim; } + int64_t get_audio_out_dim() const { return audio_out_dim; } + + // x: ne [flat_dim, T, B] (already per-token RMS-normalized via feature_extractor_normalize_v2). + // Returns video_features ne [video_out_dim, T, B]. Audio branch unused for video-only smoke tests. 
+ ggml_tensor* forward_video(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto agg = std::dynamic_pointer_cast(blocks["video_aggregate_embed"]); + auto scaled = ggml_scale(ctx->ggml_ctx, x, video_scale); + return agg->forward(ctx, scaled); + } + + ggml_tensor* forward_audio(GGMLRunnerContext* ctx, ggml_tensor* x) { + GGML_ASSERT(has_audio() && "FeatureExtractorV2: audio_aggregate_embed not allocated"); + auto agg = std::dynamic_pointer_cast(blocks["audio_aggregate_embed"]); + auto scaled = ggml_scale(ctx->ggml_ctx, x, audio_scale); + return agg->forward(ctx, scaled); + } + }; + + // A single 1D transformer block in the connector. + // Python: _BasicTransformerBlock1D in embeddings_connector.py. + // + // Self-attention only (no cross-attention, no AdaLN). Parameter-free rms_norm + // before attention and before the feed-forward. + struct BasicTransformerBlock1D : public GGMLBlock { + protected: + int64_t dim; + int num_heads; + int head_dim; + bool apply_gated_attention; + float norm_eps; + + public: + BasicTransformerBlock1D() = default; + BasicTransformerBlock1D(int64_t dim, int num_heads, int head_dim, + bool apply_gated_attention = false, + float norm_eps = 1e-6f) + : dim(dim), num_heads(num_heads), head_dim(head_dim), + apply_gated_attention(apply_gated_attention), norm_eps(norm_eps) { + // Self-attention: context_dim = query_dim = dim. The connector's 1D RoPE + // uses INTERLEAVED layout (Python embeddings_connector.py calls + // precompute_freqs_cis with default rope_type=INTERLEAVED); only the DiT + // was switched to SPLIT in LTX-2.3. 
+ blocks["attn1"] = std::make_shared(dim, dim, num_heads, head_dim, + apply_gated_attention, norm_eps, + LTX::RopeType::INTERLEAVED); + blocks["ff"] = std::make_shared(dim, dim); + } + + // hidden_states: ne [dim, T, B] + // pe: ne [dim, T, 2] packed cos/sin (or nullptr) + // mask: additive attention mask (or nullptr) + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states, + ggml_tensor* pe, + ggml_tensor* mask = nullptr) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + // Pre-norm + self-attention + residual. + auto norm1 = LTX::parameterless_rms_norm(ctx->ggml_ctx, hidden_states, norm_eps); + auto a_out = attn1->forward(ctx, norm1, /*context=*/nullptr, pe, mask); + hidden_states = ggml_add(ctx->ggml_ctx, hidden_states, a_out); + + // Pre-norm + feed-forward + residual. + auto norm2 = LTX::parameterless_rms_norm(ctx->ggml_ctx, hidden_states, norm_eps); + auto f_out = ff->forward(ctx, norm2); + hidden_states = ggml_add(ctx->ggml_ctx, hidden_states, f_out); + + return hidden_states; + } + }; + + // Embeddings1DConnector: 2-layer 1D transformer with learnable registers + + // final parameter-free rms_norm. 1D RoPE with max_pos=[1], theta=10000.0. + struct Embeddings1DConnector : public GGMLBlock { + protected: + int num_heads; + int head_dim; + int64_t inner_dim; + int num_layers; + int num_registers; // 0 disables the learnable-registers path. + float theta; + std::vector max_pos; + bool apply_gated_attention; + float norm_eps; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string prefix = "") override { + if (num_registers > 0) { + // Python: learnable_registers = Parameter(rand(num_registers, inner_dim) * 2 - 1) + // ggml ne layout: innermost = inner_dim, so [inner_dim, num_registers]. 
+ params["learnable_registers"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, inner_dim, num_registers); + } + } + + public: + Embeddings1DConnector() = default; + Embeddings1DConnector(int num_heads, int head_dim, int num_layers, + int num_registers = 128, + float theta = 10000.0f, + const std::vector& max_pos = {1}, + bool apply_gated_attention = false, + float norm_eps = 1e-6f) + : num_heads(num_heads), head_dim(head_dim), + inner_dim(static_cast(num_heads) * head_dim), + num_layers(num_layers), num_registers(num_registers), + theta(theta), max_pos(max_pos), + apply_gated_attention(apply_gated_attention), norm_eps(norm_eps) { + for (int i = 0; i < num_layers; ++i) { + blocks["transformer_1d_blocks." + std::to_string(i)] = + std::make_shared(inner_dim, num_heads, head_dim, + apply_gated_attention, norm_eps); + } + } + + int64_t get_inner_dim() const { return inner_dim; } + int get_num_registers() const { return num_registers; } + int get_num_layers() const { return num_layers; } + + ggml_tensor* get_learnable_registers() { + auto it = params.find("learnable_registers"); + return it == params.end() ? nullptr : it->second; + } + + std::shared_ptr get_block(int i) { + return std::dynamic_pointer_cast( + blocks["transformer_1d_blocks." + std::to_string(i)]); + } + + // hidden_states: ne [inner_dim, T, B] + // pe: ne [inner_dim, T, 2] packed cos/sin + // mask: additive attention mask (or nullptr) + // + // NOTE: this currently skips `_replace_padded_with_learnable_registers` — + // callers must guarantee the input is already register-substituted (or no + // padding is present). Handling the register replacement in ggml requires + // boolean indexing/scatter semantics that we defer. + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states, + ggml_tensor* pe, + ggml_tensor* mask = nullptr) { + for (int i = 0; i < num_layers; ++i) { + auto block = std::dynamic_pointer_cast( + blocks["transformer_1d_blocks." 
+ std::to_string(i)]); + hidden_states = block->forward(ctx, hidden_states, pe, mask); + } + hidden_states = LTX::parameterless_rms_norm(ctx->ggml_ctx, hidden_states, norm_eps); + return hidden_states; + } + }; + + // Which feature-extractor flavor the runner uses. V1 (19B) has a single + // Linear(flat_dim → inner_dim, bias=False) named `aggregate_embed.weight`, with + // CPU pre-norm via _norm_and_concat_padded_batch. V2 (22B) has two parallel + // Linears with bias (`video_aggregate_embed`, `audio_aggregate_embed`) on a + // per-token RMS-normalized input; we currently wire only the video path. + enum class FeatureExtractorVersion { V1, V2 }; + + // Runner that bundles feature_extractor + Embeddings1DConnector (and optionally + // caption_projection for end-to-end parity testing). Used both by the parity + // test (default ctor args match dump_connector.py) and by LTX2GemmaConditioner + // (which passes real-checkpoint prefixes and sets include_caption_projection=false + // because the DiT owns caption_projection). + // + // Input is the already-normalized [B, T, flat_dim] tensor (see + // feature_extractor_normalize[_v2] for the CPU pre-processing). + struct LTX2ConnectorRunner : public GGMLRunner { + int64_t flat_dim; + int64_t connector_inner_dim; + int num_heads; + int head_dim; + int num_layers; + int num_registers; + int64_t caption_channels; + int64_t caption_hidden; + int64_t caption_out; + float theta; + std::vector max_pos; + bool include_caption_projection; + FeatureExtractorVersion fe_version; + int64_t source_dim; // V2 only: Gemma hidden_size used for rescale + + std::string feat_ext_prefix; + std::string connector_prefix; + std::string caption_proj_prefix; + + FeatureExtractorV1 feature_extractor_v1; + FeatureExtractorV2 feature_extractor_v2; + Embeddings1DConnector connector; + LTX::PixArtAlphaTextProjection caption_projection; + + std::vector pe_vec; + + // probe_stage selects the returned tensor. 
Stages <1 and >2 are shared + // between V1 and V2; 1 and 2 are legacy V1 parity probes (after block 0/1) + // and only work when num_layers >= 2. For V2 (production use), stage 3 + // (final rms_norm) is what the conditioner calls. + // 0 = after feature_extractor (+ graph-side rescale for V2) + // 1 = after connector block 0 + // 2 = after connector block 1 + // 3 = after all blocks + final rms_norm (connector output) + // 4 = after caption_projection (requires include_caption_projection) + int probe_stage = 3; + + // Target sequence length fed into the 1D connector. Python's + // LTXVGemmaTokenizer pads to max_length=1024 so the connector always sees + // 1024 tokens with learnable_registers tiled max_length/num_registers times. + // A value of 0 falls back to num_registers (the old, compact behaviour used + // by the parity dumper). Real inference MUST set this to match the Python + // tokenizer max_length (1024) — see LTX-2 ti2vid pipelines. + int target_seq_len = 0; + void set_target_seq_len(int len) { target_seq_len = len; } + + LTX2ConnectorRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + int64_t flat_dim, + int num_heads, + int head_dim, + int num_layers, + int num_registers, + int64_t caption_channels = 0, + int64_t caption_hidden = 0, + int64_t caption_out = 0, + float theta = 10000.0f, + const std::vector& max_pos = {1}, + const String2TensorStorage& tsm = {}, + bool include_caption_projection = true, + const std::string& feat_ext_prefix = "feature_extractor", + const std::string& connector_prefix = "connector", + const std::string& caption_proj_prefix = "caption_projection", + FeatureExtractorVersion fe_version = FeatureExtractorVersion::V1, + int64_t source_dim = 0, + bool apply_gated_attention = false) + : GGMLRunner(backend, offload_params_to_cpu), + flat_dim(flat_dim), + connector_inner_dim(static_cast(num_heads) * head_dim), + num_heads(num_heads), head_dim(head_dim), num_layers(num_layers), + num_registers(num_registers), + 
caption_channels(caption_channels), + caption_hidden(caption_hidden), + caption_out(caption_out), + theta(theta), max_pos(max_pos), + include_caption_projection(include_caption_projection), + fe_version(fe_version), + source_dim(source_dim), + feat_ext_prefix(feat_ext_prefix), + connector_prefix(connector_prefix), + caption_proj_prefix(caption_proj_prefix) { + if (fe_version == FeatureExtractorVersion::V2) { + GGML_ASSERT(source_dim > 0 && "FeatureExtractorV2 needs Gemma source_dim for the sqrt-rescale"); + feature_extractor_v2 = FeatureExtractorV2(flat_dim, source_dim, connector_inner_dim); + feature_extractor_v2.init(params_ctx, tsm, feat_ext_prefix); + } else { + feature_extractor_v1 = FeatureExtractorV1(flat_dim, connector_inner_dim); + feature_extractor_v1.init(params_ctx, tsm, feat_ext_prefix); + } + connector = Embeddings1DConnector(num_heads, head_dim, num_layers, + num_registers, theta, max_pos, + apply_gated_attention); + connector.init(params_ctx, tsm, connector_prefix); + if (include_caption_projection) { + caption_projection = LTX::PixArtAlphaTextProjection(caption_channels, caption_hidden, caption_out); + caption_projection.init(params_ctx, tsm, caption_proj_prefix); + } + } + + std::string get_desc() override { return "ltx2-connector"; } + + void get_param_tensors(std::map& tensors, + const std::string /*unused*/ = "") { + if (fe_version == FeatureExtractorVersion::V2) { + feature_extractor_v2.get_param_tensors(tensors, feat_ext_prefix); + } else { + feature_extractor_v1.get_param_tensors(tensors, feat_ext_prefix); + } + connector.get_param_tensors(tensors, connector_prefix); + if (include_caption_projection) { + caption_projection.get_param_tensors(tensors, caption_proj_prefix); + } + } + + // Build the full graph. 
probe_stage selects the final returned tensor: + // 0: after feature_extractor (shape [connector_inner_dim, T, B]) + // 1: after connector block 0 (V1 parity probe, legacy) + // 2: after connector block 1 (V1 parity probe, legacy) + // 3: after all connector blocks + final rms_norm + // 4: after caption_projection (needs include_caption_projection=true) + ggml_cgraph* build_graph(const sd::Tensor& normed_in) { + ggml_cgraph* gf = new_graph_custom(LTX::LTX_GRAPH_SIZE); + + ggml_tensor* x = make_input(normed_in); // ne [flat_dim, T, B] + int64_t T = x->ne[1]; + + auto runner_ctx = get_context(); + + // Step 1: feature_extractor → [inner_dim, T, B]. + ggml_tensor* feat = nullptr; + if (fe_version == FeatureExtractorVersion::V2) { + feat = feature_extractor_v2.forward_video(&runner_ctx, x); + } else { + feat = feature_extractor_v1.forward(&runner_ctx, x); + } + + // Step 1.5: Pad to the target length by filling the tail with + // learnable_registers (tiled when target > num_registers). + // + // Python reference: `_replace_padded_with_learnable_registers` in + // ltx_core/text_encoders/gemma/embeddings_connector.py. It: + // 1. tiles learnable_registers by (seq_len / num_registers) so the tiled + // buffer covers the whole sequence (seq_len == tokenizer max_length), + // 2. moves real tokens to [0, T_real), + // 3. fills [T_real, seq_len) with tiled_registers[T_real, seq_len). + // + // The caller (conditioner.hpp) already does step 2 on CPU and passes feat + // as [inner_dim, T_real, B]. We pick the target length in this order of + // preference: (a) explicit target_seq_len (set by the conditioner to + // Gemma's max_length), (b) num_registers (legacy/parity default). + // + // Tiling is implemented with a ggml_repeat into a [inner_dim, target, B] + // destination — cheap on GPU and matches torch.tile semantics for the + // innermost tiling axis. + const int num_registers = connector.get_num_registers(); + int64_t target_len = + target_seq_len > 0 ? 
static_cast(target_seq_len) + : static_cast(num_registers); + if (num_registers > 0 && target_len > 0 && T < target_len) { + GGML_ASSERT(target_len % num_registers == 0 && + "target_seq_len must be a multiple of num_registers " + "(Embeddings1DConnector tiles learnable_registers)."); + auto regs = connector.get_learnable_registers(); // [inner_dim, num_registers] + GGML_ASSERT(regs != nullptr && "learnable_registers not initialized"); + + // Build the tiled registers tensor [inner_dim, target_len] by + // repeating learnable_registers along axis 1. + ggml_tensor* tiled = regs; + if (target_len > num_registers) { + auto repeat_tgt = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, + connector_inner_dim, target_len); + tiled = ggml_repeat(compute_ctx, regs, repeat_tgt); + } + + // Slice rows [T : target_len] along axis 1 to get the padding tail. + auto regs_slice = ggml_ext_slice(compute_ctx, tiled, 1, + static_cast(T), + static_cast(target_len)); // [inner_dim, target-T] + regs_slice = ggml_reshape_3d(compute_ctx, ggml_cont(compute_ctx, regs_slice), + connector_inner_dim, + target_len - T, + 1); + feat = ggml_concat(compute_ctx, feat, regs_slice, 1); // [inner_dim, target, B] + T = target_len; + } + + // Build only the subgraph up to the selected probe stage. The final + // named result is the LAST node added (GGMLRunner::get_compute_graph + // picks `ggml_graph_node(gf, -1)`). + ggml_tensor* out = feat; + if (probe_stage >= 1) { + // Precompute 1D RoPE for connector. + auto positions = LTXRope::gen_1d_positions(static_cast(T)); + pe_vec = LTXRope::precompute_freqs_cis_interleaved(positions, + static_cast(connector_inner_dim), + theta, max_pos); + auto pe = ggml_new_tensor_3d(compute_ctx, GGML_TYPE_F32, connector_inner_dim, T, 2); + set_backend_tensor_data(pe, pe_vec.data()); + + // Stages 1, 2: legacy V1 parity probes — stop after block 0/1. + // Stages 3+: production path — run all blocks and the final rms_norm. 
+ if (probe_stage == 1 || probe_stage == 2) { + int blocks_to_run = probe_stage; // 1 → block 0 only; 2 → blocks 0 and 1 + for (int i = 0; i < blocks_to_run && i < num_layers; ++i) { + out = connector.get_block(i)->forward(&runner_ctx, out, pe, nullptr); + } + } else { + for (int i = 0; i < num_layers; ++i) { + out = connector.get_block(i)->forward(&runner_ctx, out, pe, nullptr); + } + out = LTX::parameterless_rms_norm(compute_ctx, out, 1e-6f); + } + } + if (probe_stage >= 4 && include_caption_projection) { + out = caption_projection.forward(&runner_ctx, out); + } + + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, const sd::Tensor& normed_in, int stage = 4) { + probe_stage = stage; + auto get_graph = [&]() -> ggml_cgraph* { return build_graph(normed_in); }; + return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + } + }; + +} // namespace LTXConnector + +#endif // __LTX_CONNECTOR_HPP__ diff --git a/src/ltx_rope.hpp b/src/ltx_rope.hpp new file mode 100644 index 000000000..c058e7452 --- /dev/null +++ b/src/ltx_rope.hpp @@ -0,0 +1,350 @@ +#ifndef __LTX_ROPE_HPP__ +#define __LTX_ROPE_HPP__ + +#include +#include +#include "ggml_extend.hpp" + +namespace LTXRope { + // Generate a log-spaced frequency grid from 1 to theta, scaled by pi/2. + // Returns num_freqs = inner_dim / (2 * n_pos_dims) values. + // + // Python reference: generate_freq_grid_pytorch in ltx_core/model/transformer/rope.py. + // We mirror the fp32 linspace path byte-exactly: torch.linspace(0., 1., N, fp32) + // produces indices computed as `i * (1/(N-1))` in fp32 (start + step*i), so we + // replicate that order of operations rather than `(double)i / (N-1)` which + // differs by ~1 ULP at the tail. That 1-ULP freq drift becomes ~3-5 ULPs in + // the freq value and ~5e-2 cos/sin error once the angle hits 1e5 radians at + // T=8. `pow(theta, v)` is then computed in fp32 (std::powf) to match. 
+ __STATIC_INLINE__ std::vector generate_freq_grid(float theta, + int n_pos_dims, + int inner_dim) { + int n_elem = 2 * n_pos_dims; + int num_freqs = inner_dim / n_elem; + std::vector indices(num_freqs); + // Compute in fp64 then cast. For the video DiT (3D RoPE, max_pos normalizes + // to [0, 1]) fp32 would be fine, but the connector's 1D RoPE uses max_pos=[1] + // so raw integer positions feed into the angle → arguments reach ~2e5 radians + // at T=8. At that scale, fp32 libm `exp(t*log(theta))` drifts ~1 ULP in + // the freq value, cascading to ~5e-2 cos/sin diffs vs the numpy-fp64 reference + // used by the connector dumper (`double_precision_rope=True`). fp64 pow matches + // numpy closely enough to land connector parity at ~2e-3 max_abs. + constexpr double pi_half = 1.57079632679489661923; + double theta_d = static_cast(theta); + for (int i = 0; i < num_freqs; ++i) { + double t = num_freqs == 1 ? 0.0 : static_cast(i) / (num_freqs - 1); + indices[i] = static_cast(std::pow(theta_d, t) * pi_half); + } + return indices; + } + + // Build a 3D indices grid for a video latent of shape (F, H, W). + // + // Mirrors the real LTX-2 pipeline: VideoLatentTools.create_initial_state -> + // get_patch_grid_bounds -> get_pixel_coords (ltx_core/components/patchifiers.py and + // ltx_core/tools.py). Per-axis behaviour: + // latent_coords[axis] = [f, f+1] (integer latent indices per patch) + // pixel_coords[axis] = latent_coords * scale_factors[axis] + // if causal_fix: pixel_coords[0] = clamp(pixel_coords[0] + 1 - scale_factors[0], 0, +) + // positions[0] /= fps (temporal axis only) + // if use_middle_indices_grid: pos = midpoint(start, end); else pos = start + // + // Defaults ({1,1,1}, causal_fix=false, fps=1) preserve the parity-test flow, which + // feeds the Python model the simplified (f, h, w) positions directly. Real inference + // MUST pass scale_factors={8, 32, 32} and causal_fix=true (the LTX-2 VAE scale). 
+ // + // Returns a 3×(F*H*W) matrix with layout [axis][token_idx]. + __STATIC_INLINE__ std::vector> gen_video_positions(int F, + int H, + int W, + bool use_middle_indices_grid = true, + float fps = 1.0f, + const std::vector& scale_factors = {1, 1, 1}, + bool causal_fix = false) { + GGML_ASSERT(fps > 0.0f); + GGML_ASSERT(scale_factors.size() == 3); + int total = F * H * W; + std::vector> pos(3, std::vector(total, 0.f)); + const float s0 = static_cast(scale_factors[0]); + const float s1 = static_cast(scale_factors[1]); + const float s2 = static_cast(scale_factors[2]); + for (int f = 0; f < F; ++f) { + float t_s = static_cast(f) * s0; + float t_e = static_cast(f + 1) * s0; + if (causal_fix) { + const float shift = 1.f - s0; + t_s = std::max(0.f, t_s + shift); + t_e = std::max(0.f, t_e + shift); + } + t_s /= fps; + t_e /= fps; + for (int h = 0; h < H; ++h) { + float h_s = static_cast(h) * s1; + float h_e = static_cast(h + 1) * s1; + for (int w = 0; w < W; ++w) { + float w_s = static_cast(w) * s2; + float w_e = static_cast(w + 1) * s2; + int idx = (f * H + h) * W + w; + if (use_middle_indices_grid) { + pos[0][idx] = (t_s + t_e) * 0.5f; + pos[1][idx] = (h_s + h_e) * 0.5f; + pos[2][idx] = (w_s + w_e) * 0.5f; + } else { + pos[0][idx] = t_s; + pos[1][idx] = h_s; + pos[2][idx] = w_s; + } + } + } + } + return pos; + } + + // Precompute interleaved cos/sin freqs for LTX-2 RoPE. + // positions[axis][token]: fractional-ready float positions, size n_pos_dims * T. + // max_pos: normalisation per axis, e.g. {20, 2048, 2048}. + // Returns a packed [2, T, inner_dim] vector: slice [0] = cos, slice [1] = sin. 
+ __STATIC_INLINE__ std::vector precompute_freqs_cis_interleaved(const std::vector>& positions, + int inner_dim, + float theta = 10000.f, + const std::vector& max_pos = {20, 2048, 2048}) { + int n_pos_dims = static_cast(positions.size()); + GGML_ASSERT(n_pos_dims > 0); + GGML_ASSERT(static_cast(max_pos.size()) == n_pos_dims); + int T = static_cast(positions[0].size()); + + int n_elem = 2 * n_pos_dims; + int num_freqs = inner_dim / n_elem; + int pad_size = inner_dim - (num_freqs * n_pos_dims * 2); + + std::vector freq_grid = generate_freq_grid(theta, n_pos_dims, inner_dim); // [num_freqs] + + std::vector pe(2 * T * inner_dim, 0.f); + // Slice 0 (cos) starts at offset 0, slice 1 (sin) starts at T * inner_dim. + size_t cos_off = 0; + size_t sin_off = static_cast(T) * inner_dim; + + // Initialise the pad region: cos = 1.0, sin = 0.0. + for (int t = 0; t < T; ++t) { + for (int i = 0; i < pad_size; ++i) { + pe[cos_off + static_cast(t) * inner_dim + i] = 1.f; + } + } + + for (int t = 0; t < T; ++t) { + std::vector frac_pos(n_pos_dims); + for (int d = 0; d < n_pos_dims; ++d) { + frac_pos[d] = positions[d][t] / static_cast(max_pos[d]); + } + // Freq layout after flatten is [f * n_pos_dims + d], so pair index p = f*n_pos_dims + d. + // After repeat_interleave(2), each pair p corresponds to slots (2p, 2p+1) in the [pad_size:] region. + // + // Note: compute cos/sin in double precision then cast to float. At high frequencies + // (theta^1 * pi/2 ≈ 15708) times (2*t - 1), the angle reaches hundreds of thousands of + // radians — fp32 argument reduction in std::cosf/sinf loses enough precision to drift + // ~5e-2 from PyTorch's tensor-level cos/sin. Python's torch.cos does the reduction + // against a more precise modulus internally (matching fp64 behavior closely enough). 
+ for (int f = 0; f < num_freqs; ++f) { + for (int d = 0; d < n_pos_dims; ++d) { + double angle = static_cast(freq_grid[f]) * + (static_cast(frac_pos[d]) * 2.0 - 1.0); + float c = static_cast(std::cos(angle)); + float s = static_cast(std::sin(angle)); + int pair_i = f * n_pos_dims + d; + int slot0 = pad_size + 2 * pair_i; + int slot1 = pad_size + 2 * pair_i + 1; + pe[cos_off + static_cast(t) * inner_dim + slot0] = c; + pe[cos_off + static_cast(t) * inner_dim + slot1] = c; + pe[sin_off + static_cast(t) * inner_dim + slot0] = s; + pe[sin_off + static_cast(t) * inner_dim + slot1] = s; + } + } + } + return pe; + } + + // Apply LTX-2 interleaved rotary embedding to x. + // x: [inner_dim, T, B] (ggml ne order; logical shape [B, T, inner_dim]) + // cos, sin: [inner_dim, T, 1] (broadcast across batch) + // Returns x rotated, same shape as x. + __STATIC_INLINE__ ggml_tensor* apply_rotary_emb_interleaved(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* cos_freq, + ggml_tensor* sin_freq) { + int64_t inner_dim = x->ne[0]; + int64_t T = x->ne[1]; + int64_t B = x->ne[2]; + GGML_ASSERT(inner_dim % 2 == 0); + + // Reshape to pairs: [2, inner_dim/2, T, B]. + auto x_pairs = ggml_reshape_4d(ctx, x, 2, inner_dim / 2, T, B); + + // Views: x_even (offset 0) and x_odd (offset nb[0]) each shape [1, inner_dim/2, T, B]. + auto x_even = ggml_view_4d(ctx, x_pairs, 1, inner_dim / 2, T, B, + x_pairs->nb[1], x_pairs->nb[2], x_pairs->nb[3], 0); + auto x_odd = ggml_view_4d(ctx, x_pairs, 1, inner_dim / 2, T, B, + x_pairs->nb[1], x_pairs->nb[2], x_pairs->nb[3], x_pairs->nb[0]); + x_even = ggml_cont(ctx, x_even); + x_odd = ggml_cont(ctx, x_odd); + + // Rotated pair (−x_odd, x_even) → concat along dim 0 → [2, inner_dim/2, T, B]. 
+ auto neg_x_odd = ggml_scale(ctx, x_odd, -1.f); + auto rotated = ggml_concat(ctx, neg_x_odd, x_even, 0); + rotated = ggml_reshape_3d(ctx, rotated, inner_dim, T, B); + + // out = x * cos + rotated * sin + auto out = ggml_add(ctx, ggml_mul(ctx, x, cos_freq), ggml_mul(ctx, rotated, sin_freq)); + return out; + } + + // Precompute SPLIT cos/sin freqs for LTX-2.3 DiT. Python reference: + // `precompute_freqs_cis(..., rope_type=LTXRopeType.SPLIT)`. + // - Unlike the interleaved variant, freqs are NOT repeat_interleaved; each of + // the inner_dim/2 frequencies is broadcast once across the corresponding + // position in the first AND second halves of head_dim. + // - cos/sin are reshaped to per-head: shape [B, T, H, head_dim/2]. + // - We pack both into a single buffer of ne [head_dim/2, num_heads, T, 2] + // (slice 0 = cos, slice 1 = sin), matching the interleaved helper's + // single-buffer convention. split_pe_split() below slices that back. + // + // freqs flattened length is num_freqs * n_pos_dims; when it's less than + // inner_dim/2, the leading (pad_size) slots are filled cos=1, sin=0, matching + // Python's `split_freqs_cis`. 
+ __STATIC_INLINE__ std::vector precompute_freqs_cis_split(const std::vector>& positions, + int inner_dim, + int num_heads, + float theta = 10000.f, + const std::vector& max_pos = {20, 2048, 2048}) { + int n_pos_dims = static_cast(positions.size()); + GGML_ASSERT(n_pos_dims > 0); + GGML_ASSERT(static_cast(max_pos.size()) == n_pos_dims); + GGML_ASSERT(inner_dim % (2 * num_heads) == 0); + int T = static_cast(positions[0].size()); + int half_dim = inner_dim / 2; // per-token freq count + int head_dim2 = half_dim / num_heads; // per-head freq count + + int n_elem = 2 * n_pos_dims; + int num_freqs = inner_dim / n_elem; + int current = num_freqs * n_pos_dims; // pre-pad flat freq count + int pad_size = half_dim - current; + GGML_ASSERT(pad_size >= 0); + + std::vector freq_grid = generate_freq_grid(theta, n_pos_dims, inner_dim); + + // Output layout (ne): [head_dim/2, num_heads, T, 2]. Flat index: + // (slice=cos/sin)*T*num_heads*head_dim2 + t*num_heads*head_dim2 + h*head_dim2 + k + std::vector pe(2 * T * num_heads * head_dim2, 0.f); + size_t cos_off = 0; + size_t sin_off = static_cast(T) * num_heads * head_dim2; + + // Pad region (first `pad_size` columns of the per-token freq vector): cos=1, sin=0. + // Per-head reshape means pad_size slots at the start of the head-major flat + // vector. Since cos/sin for a token are stored as [h=0 head_dim2, h=1 head_dim2, …], + // the pad falls in the first pad_size consecutive positions across the head groups. 
+ for (int t = 0; t < T; ++t) { + for (int p = 0; p < pad_size; ++p) { + int h = p / head_dim2; + int k = p % head_dim2; + size_t dst = static_cast(t) * num_heads * head_dim2 + h * head_dim2 + k; + pe[cos_off + dst] = 1.f; + pe[sin_off + dst] = 0.f; + } + } + + constexpr double pi_half = 1.57079632679489661923; + (void)pi_half; + for (int t = 0; t < T; ++t) { + std::vector frac_pos(n_pos_dims); + for (int d = 0; d < n_pos_dims; ++d) { + frac_pos[d] = positions[d][t] / static_cast(max_pos[d]); + } + // Non-pad slots start at column `pad_size` in the flat per-token freq vector. + // Python layout: freqs = (indices * (fractional*2-1)).transpose(-1,-2).flatten(2). + // With indices shape [num_freqs] and fractional [n_pos_dims], after broadcast + // and transpose the order is [f * n_pos_dims + d]. Slot index in the padded + // per-token vector = pad_size + f*n_pos_dims + d. + for (int f = 0; f < num_freqs; ++f) { + for (int d = 0; d < n_pos_dims; ++d) { + double angle = static_cast(freq_grid[f]) * + (static_cast(frac_pos[d]) * 2.0 - 1.0); + float c = static_cast(std::cos(angle)); + float s = static_cast(std::sin(angle)); + int flat_slot = pad_size + f * n_pos_dims + d; + int h = flat_slot / head_dim2; + int k = flat_slot % head_dim2; + size_t dst = static_cast(t) * num_heads * head_dim2 + h * head_dim2 + k; + pe[cos_off + dst] = c; + pe[sin_off + dst] = s; + } + } + } + return pe; + } + + // Split-half rotary embedding. Python: apply_split_rotary_emb. + // first_half = x[..., 0:head_dim/2] + // second_half = x[..., head_dim/2:head_dim] + // out = concat(first*cos - second*sin, second*cos + first*sin, dim=last) + // Operates per-head. x ne=[inner_dim, T, B]; pe tensors (cos/sin) ne=[head_dim/2, num_heads, T, 1]. 
__STATIC_INLINE__ ggml_tensor* apply_rotary_emb_split(ggml_context* ctx,
                                                      ggml_tensor* x,
                                                      ggml_tensor* cos_freq,
                                                      ggml_tensor* sin_freq,
                                                      int num_heads) {
    int64_t inner_dim = x->ne[0];
    int64_t T = x->ne[1];
    int64_t B = x->ne[2];
    // head_dim must be even so it can be split into two rotated halves.
    GGML_ASSERT(inner_dim % (2 * num_heads) == 0);
    int64_t head_dim = inner_dim / num_heads;
    int64_t half = head_dim / 2;

    // Reshape x [inner_dim, T, B] → [head_dim, num_heads, T, B], then split halves.
    auto x4 = ggml_reshape_4d(ctx, x, head_dim, num_heads, T, B);

    // first_half view: offset 0, shape [half, num_heads, T, B].
    auto first = ggml_view_4d(ctx, x4, half, num_heads, T, B,
                              x4->nb[1], x4->nb[2], x4->nb[3], 0);
    // second_half view: offset = half * sizeof(el).
    auto second = ggml_view_4d(ctx, x4, half, num_heads, T, B,
                               x4->nb[1], x4->nb[2], x4->nb[3], half * x4->nb[0]);
    // Materialise the strided views before feeding them to elementwise mul.
    first = ggml_cont(ctx, first);
    second = ggml_cont(ctx, second);

    // cos/sin ne [half, num_heads, T, 1] broadcast on B axis with first/second [half, num_heads, T, B].
    // first_out  = first*cos - second*sin; second_out = second*cos + first*sin
    // (standard rotation applied half-against-half rather than pairwise).
    auto first_out = ggml_sub(ctx, ggml_mul(ctx, first, cos_freq),
                              ggml_mul(ctx, second, sin_freq));
    auto second_out = ggml_add(ctx, ggml_mul(ctx, second, cos_freq),
                               ggml_mul(ctx, first, sin_freq));

    // Re-concat along dim 0 (head_dim) → [head_dim, num_heads, T, B].
    auto joined = ggml_concat(ctx, first_out, second_out, 0);
    joined = ggml_reshape_3d(ctx, joined, inner_dim, T, B);
    return joined;
}

// Slice a packed split pe buffer of ne [half, num_heads, T, 2] into cos (slice 0)
// and sin (slice 1) views, each ne=[half, num_heads, T, 1].
+ __STATIC_INLINE__ std::pair split_pe_split(ggml_context* ctx, ggml_tensor* pe) { + int64_t half = pe->ne[0]; + int64_t num_heads = pe->ne[1]; + int64_t T = pe->ne[2]; + auto cos_freq = ggml_view_4d(ctx, pe, half, num_heads, T, 1, + pe->nb[1], pe->nb[2], pe->nb[3], 0); + auto sin_freq = ggml_view_4d(ctx, pe, half, num_heads, T, 1, + pe->nb[1], pe->nb[2], pe->nb[3], pe->nb[3]); + return {cos_freq, sin_freq}; + } + + // Convenience: split a packed [2, T, inner_dim] pe tensor (slice 0 = cos, slice 1 = sin) + // into two views usable as cos/sin operands. + __STATIC_INLINE__ std::pair split_pe(ggml_context* ctx, ggml_tensor* pe) { + // pe: [inner_dim, T, 2] in ggml ne order. + int64_t inner_dim = pe->ne[0]; + int64_t T = pe->ne[1]; + auto cos_freq = ggml_view_3d(ctx, pe, inner_dim, T, 1, pe->nb[1], pe->nb[2], 0); + auto sin_freq = ggml_view_3d(ctx, pe, inner_dim, T, 1, pe->nb[1], pe->nb[2], pe->nb[2]); + return {cos_freq, sin_freq}; + } +}; // namespace LTXRope + +#endif // __LTX_ROPE_HPP__ diff --git a/src/ltxv.hpp b/src/ltxv.hpp index fb37dbe02..0e493443d 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -5,9 +5,13 @@ namespace LTXV { + enum class SpatialPadding { ZEROS, REFLECT }; + class CausalConv3d : public GGMLBlock { protected: int time_kernel_size; + int spatial_kernel_size; + SpatialPadding spatial_padding; public: CausalConv3d(int64_t in_channels, @@ -15,52 +19,98 @@ namespace LTXV { int kernel_size = 3, std::tuple stride = {1, 1, 1}, int dilation = 1, - bool bias = true) { - time_kernel_size = kernel_size / 2; - blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, + bool bias = true, + SpatialPadding padding_mode = SpatialPadding::ZEROS) { + // Python reference: self.time_kernel_size = kernel_size[0] — the full temporal kernel. + // Earlier revisions of this file used `kernel_size / 2` which under-padded by a factor of 2 for k>=3 + // and padded 1 frame when k=1/2 where no padding was expected. Match Python verbatim. 
+ time_kernel_size = kernel_size; + spatial_kernel_size = kernel_size; + spatial_padding = padding_mode; + // When using reflect padding we do it manually in forward(), so the inner Conv3d + // must run with spatial padding=0. For zeros mode the Conv3d handles padding itself. + int conv_pad_hw = (padding_mode == SpatialPadding::ZEROS) ? (kernel_size / 2) : 0; + blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, out_channels, {kernel_size, kernel_size, kernel_size}, stride, - {0, kernel_size / 2, kernel_size / 2}, + {0, conv_pad_hw, conv_pad_hw}, {dilation, 1, 1}, bias)); } + // Helper: replicate the given single-frame tensor `count` times along the depth axis. + // Returns a [IW, IH, count, N*IC] tensor. count must be >= 1. + static ggml_tensor* repeat_frame(ggml_context* ctx, ggml_tensor* frame, int count) { + auto out = frame; + for (int i = 1; i < count; i++) { + out = ggml_concat(ctx, out, frame, 2); + } + return out; + } + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { - // x: [N*IC, ID, IH, IW] - // result: [N*OC, OD, OH, OW] - auto conv = std::dynamic_pointer_cast(blocks["conv"]); + // x logical shape: [N*IC, ID, IH, IW] (Python order); ggml ne: [IW, IH, ID, N*IC] + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto ggml_cx = ctx->ggml_ctx; + + int pad_front = 0; + int pad_back = 0; if (causal) { - auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < time_kernel_size - 1; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } - x = ggml_concat(ctx, first_frame_pad, x, 2); + pad_front = time_kernel_size - 1; } else { - auto h = ggml_cont(ctx, 
ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - int64_t offset = h->nb[2] * h->ne[2]; + pad_front = (time_kernel_size - 1) / 2; + pad_back = (time_kernel_size - 1) / 2; + } - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } + if (pad_front > 0 || pad_back > 0) { + // Extract first frame as a [IW, IH, 1, N*IC] view on x along the depth axis (ne[2]). + auto first_frame = ggml_view_4d(ggml_cx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], 0); + first_frame = ggml_cont(ggml_cx, first_frame); - auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1)); // [N*IC, IH, IW] - last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]); // [N*IC, 1, IH, IW] - auto last_frame_pad = last_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2); + if (pad_front > 0) { + auto front_pad = repeat_frame(ggml_cx, first_frame, pad_front); + x = ggml_concat(ggml_cx, front_pad, x, 2); + } + if (pad_back > 0) { + auto last_frame = ggml_view_4d(ggml_cx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], (x->ne[2] - 1) * x->nb[2]); + last_frame = ggml_cont(ggml_cx, last_frame); + auto back_pad = repeat_frame(ggml_cx, last_frame, pad_back); + x = ggml_concat(ggml_cx, x, back_pad, 2); } + } - x = ggml_concat(ctx, first_frame_pad, x, 2); - x = ggml_concat(ctx, x, last_frame_pad, 2); + // Spatial reflect padding (H, W by k/2 each side). 
nn.Conv3d with padding_mode='reflect' + // mirrors the edge rows/cols: [a,b,c,d] with pad=1 → [b,a,b,c,d,c]. + if (spatial_padding == SpatialPadding::REFLECT) { + int pad = spatial_kernel_size / 2; + if (pad > 0) { + GGML_ASSERT(pad == 1 && "reflect padding only implemented for kernel=3 (pad=1)"); + x = ggml_cont(ggml_cx, x); + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + // H-axis reflect: top = row 1, bottom = row H-2. + auto row_top = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, W, 1, T, C, + x->nb[1], x->nb[2], x->nb[3], 1 * x->nb[1])); + auto row_bot = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, W, 1, T, C, + x->nb[1], x->nb[2], x->nb[3], (H - 2) * x->nb[1])); + x = ggml_concat(ggml_cx, row_top, x, 1); + x = ggml_concat(ggml_cx, x, row_bot, 1); + x = ggml_cont(ggml_cx, x); + W = x->ne[0]; H = x->ne[1]; T = x->ne[2]; C = x->ne[3]; + // W-axis reflect: left = col 1, right = col W-2. + auto col_left = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, 1, H, T, C, + x->nb[1], x->nb[2], x->nb[3], 1 * x->nb[0])); + auto col_right = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, 1, H, T, C, + x->nb[1], x->nb[2], x->nb[3], (W - 2) * x->nb[0])); + x = ggml_concat(ggml_cx, col_left, x, 0); + x = ggml_concat(ggml_cx, x, col_right, 0); + } } x = conv->forward(ctx, x); diff --git a/src/ltxvae.hpp b/src/ltxvae.hpp new file mode 100644 index 000000000..198d59fc7 --- /dev/null +++ b/src/ltxvae.hpp @@ -0,0 +1,913 @@ +#ifndef __LTXVAE_HPP__ +#define __LTXVAE_HPP__ + +#include "common_block.hpp" +#include "ltxv.hpp" // CausalConv3d +#include "ltxvae_primitives.hpp" // space/depth, pixel_norm, pcs_* +#include "vae.hpp" // VAE base class + +// LTX-2 video VAE. Companion to src/ltxvae_primitives.hpp (pure ggml ops) — +// this file adds the parameterized composition blocks (ResnetBlock3D, +// UNetMidBlock3D, SpaceToDepthDownsample, DepthToSpaceUpsample) and the +// VideoEncoder / VideoDecoder top-levels. 
//
// Tensor convention throughout: B=1 collapsed; ggml ne=[W, H, T, C].
// Weight naming mirrors the Python reference verbatim — see the tensor-name
// dump emitted by the parity tooling under tests/ltx_parity for the canonical
// prefix layout.

namespace LTXVAE {

    // ---------- TimestepEmbedder ----------
    //
    // PixArtAlphaCombinedTimestepSizeEmbeddings with size_emb_dim=0. Python
    // structure: `.timestep_embedder.linear_{1,2}`. Sinusoidal projection into
    // TIME_PROJ_DIM (256) fed to a two-Linear MLP with SiLU between.

    struct TimestepEmbedder : public GGMLBlock {
    protected:
        int embedding_dim = 0;
        // Width of the sinusoidal projection fed to linear_1 (Python: 256).
        static constexpr int TIME_PROJ_DIM = 256;

    public:
        TimestepEmbedder() = default;
        TimestepEmbedder(int embedding_dim) : embedding_dim(embedding_dim) {
            blocks["timestep_embedder.linear_1"] = std::make_shared<Linear>(TIME_PROJ_DIM, embedding_dim, true);
            blocks["timestep_embedder.linear_2"] = std::make_shared<Linear>(embedding_dim, embedding_dim, true);
        }

        // timestep: ne=[B]. Returns ne=[embedding_dim, B].
        // Pipeline: sinusoidal embed → linear_1 → SiLU → linear_2.
        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) {
            auto l1 = std::dynamic_pointer_cast<Linear>(blocks["timestep_embedder.linear_1"]);
            auto l2 = std::dynamic_pointer_cast<Linear>(blocks["timestep_embedder.linear_2"]);
            auto proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, TIME_PROJ_DIM, 10000, 1.0f);
            auto h    = l1->forward(ctx, proj);
            h         = ggml_silu_inplace(ctx->ggml_ctx, h);
            return l2->forward(ctx, h);
        }
    };

    // ---------- ResnetBlock3D ----------
    //
    // Python forward (when timestep_conditioning=True):
    //   h = norm1(x)  [PixelNorm]
    //   ada = scale_shift_table + time_embed.reshape(B, 4, in_channels, 1,1,1)
    //   shift1, scale1, shift2, scale2 = ada.unbind(dim=1)
    //   h = h * (1 + scale1) + shift1
    //   h = silu(h); h = conv1(h)
    //   h = norm2(h); h = h * (1 + scale2) + shift2
    //   h = silu(h); h = conv2(h)
    //   return input + h
    //
    // When in_channels != out_channels, the skip path goes through
    // norm3 = GroupNorm(num_groups=1, ...)
    // + conv_shortcut (1×1×1 Conv3d).
    // Our parity config keeps in==out, so we hard-disable that path until
    // we land a use case that needs it.
    //
    // inject_noise is not yet supported (would require a seeded randn in ggml).

    struct ResnetBlock3D : public GGMLBlock {
    protected:
        int in_channels            = 0;
        int out_channels           = 0;
        bool timestep_conditioning = false;
        bool has_shortcut          = false;
        float eps                  = 1e-6f;

        void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string /*prefix*/ = "") override {
            if (timestep_conditioning) {
                // Python ne: [4, in_channels] → GGML ne [in_channels, 4].
                params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, in_channels, 4);
            }
        }

    public:
        ResnetBlock3D() = default;
        // NOTE(review): ctor default eps_ = 1e-8f differs from the member default
        // (1e-6f); UNetMidBlock3D always passes 1e-8f explicitly — confirm which
        // value the Python reference uses before relying on the default.
        ResnetBlock3D(int in_ch, int out_ch, bool timestep_cond, float eps_ = 1e-8f,
                      LTXV::SpatialPadding pad = LTXV::SpatialPadding::ZEROS)
            : in_channels(in_ch),
              out_channels(out_ch),
              timestep_conditioning(timestep_cond),
              has_shortcut(in_ch != out_ch),
              eps(eps_) {
            blocks["conv1"] = std::make_shared<LTXV::CausalConv3d>(
                in_ch, out_ch, 3, std::tuple{1, 1, 1}, 1, true, pad);
            blocks["conv2"] = std::make_shared<LTXV::CausalConv3d>(
                out_ch, out_ch, 3, std::tuple{1, 1, 1}, 1, true, pad);
            if (has_shortcut) {
                GGML_ABORT("ResnetBlock3D with in != out not yet implemented (norm3 + conv_shortcut)");
            }
        }

        // x: ne=[W, H, T, C_in]. time_embed (optional): ne=[4*in_channels, B=1].
        // `causal` propagates to the inner CausalConv3d.forward calls.
        // If traces is non-null, pushes intermediates in order (with timestep
        // conditioning): 0 post_norm1, 1 shift1, 2 scale1, 3 post_adaln1,
        // 4 post_conv1, 5 post_norm2. Without conditioning: 0 post_norm1,
        // 1 post_conv1, 2 post_norm2. NOTE(review): the second AdaLN, post_conv2
        // and the final residual are NOT traced here — dumpers must not expect them.
        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* time_embed = nullptr,
                             std::vector<ggml_tensor*>* traces = nullptr, bool causal = true) {
            auto conv1 = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv1"]);
            auto conv2 = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv2"]);

            auto input = x;
            auto h     = pixel_norm(ctx->ggml_ctx, x, eps);
            if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h));

            ggml_tensor *shift1 = nullptr, *scale1 = nullptr;
            ggml_tensor *shift2 = nullptr, *scale2 = nullptr;
            if (timestep_conditioning) {
                GGML_ASSERT(time_embed != nullptr);
                auto sst = params["scale_shift_table"];  // ne [in_channels, 4]
                // time_embed has ne [4*in_channels, B=1]. Reshape to ne [in_channels, 4] (implicit B=1).
                auto te  = ggml_reshape_2d(ctx->ggml_ctx, time_embed, in_channels, 4);
                auto ada = ggml_add(ctx->ggml_ctx, te, sst);  // [in_channels, 4]

                shift1 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 0, 1);
                scale1 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 1, 2);
                shift2 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 2, 3);
                scale2 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 3, 4);
                if (traces) {
                    traces->push_back(ggml_cont(ctx->ggml_ctx, shift1));
                    traces->push_back(ggml_cont(ctx->ggml_ctx, scale1));
                }
                // Reshape each [in_channels, 1] → [1, 1, 1, in_channels] so they broadcast
                // over (W, H, T) when added/multiplied with h [W, H, T, in_channels].
                shift1 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, shift1), 1, 1, 1, in_channels);
                scale1 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale1), 1, 1, 1, in_channels);
                shift2 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, shift2), 1, 1, 1, in_channels);
                scale2 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale2), 1, 1, 1, in_channels);

                // h * (1 + scale1) + shift1, expressed as h + h*scale1 + shift1.
                auto h_scaled = ggml_mul(ctx->ggml_ctx, h, scale1);
                h             = ggml_add(ctx->ggml_ctx, h, h_scaled);
                h             = ggml_add(ctx->ggml_ctx, h, shift1);
                if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h));
            }

            h = ggml_silu(ctx->ggml_ctx, h);
            h = conv1->forward(ctx, h, causal);
            if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h));

            h = pixel_norm(ctx->ggml_ctx, h, eps);
            if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h));

            if (timestep_conditioning) {
                // Second AdaLN: h * (1 + scale2) + shift2.
                auto h_scaled = ggml_mul(ctx->ggml_ctx, h, scale2);
                h             = ggml_add(ctx->ggml_ctx, h, h_scaled);
                h             = ggml_add(ctx->ggml_ctx, h, shift2);
            }

            h = ggml_silu(ctx->ggml_ctx, h);
            h = conv2->forward(ctx, h, causal);

            // in_channels == out_channels so skip is Identity (the `has_shortcut` path aborts above).
            return ggml_add(ctx->ggml_ctx, h, input);
        }
    };

    // ---------- UNetMidBlock3D ----------
    //
    // A stack of `num_layers` ResnetBlock3D at constant channel count, with an
    // optional shared TimestepEmbedder feeding every res block.

    struct UNetMidBlock3D : public GGMLBlock {
    protected:
        int in_channels            = 0;
        int num_layers             = 0;
        bool timestep_conditioning = false;

    public:
        UNetMidBlock3D() = default;
        UNetMidBlock3D(int in_ch, int num_layers, bool timestep_cond,
                       LTXV::SpatialPadding pad = LTXV::SpatialPadding::ZEROS)
            : in_channels(in_ch), num_layers(num_layers), timestep_conditioning(timestep_cond) {
            for (int i = 0; i < num_layers; i++) {
                blocks["res_blocks." + std::to_string(i)] = std::make_shared<ResnetBlock3D>(in_ch, in_ch, timestep_cond, 1e-8f, pad);
            }
            if (timestep_cond) {
                blocks["time_embedder"] = std::make_shared<TimestepEmbedder>(in_ch * 4);
            }
        }

        // timestep: ne=[B=1] if conditioning enabled, else null.
        // Embeds the timestep once and threads it through every res block.
        // Res blocks also push their internal intermediates into `traces`.
        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* timestep = nullptr,
                             std::vector<ggml_tensor*>* traces = nullptr, bool causal = true) {
            ggml_tensor* time_embed = nullptr;
            if (timestep_conditioning) {
                GGML_ASSERT(timestep != nullptr);
                auto te    = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["time_embedder"]);
                time_embed = te->forward(ctx, timestep);  // ne=[4*in_channels, 1]
                if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, time_embed));
            }
            for (int i = 0; i < num_layers; i++) {
                auto res = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
                x        = res->forward(ctx, x, time_embed, traces, causal);
            }
            return x;
        }
    };

    // ---------- SpaceToDepthDownsample (encoder) ----------
    //
    // Python forward:
    //   if stride[0]==2: x = cat([x[:,:,:1], x], dim=2)   # duplicate first frame
    //   x_in = rearrange(x, "b c (d p1)(h p2)(w p3) -> b (c p1 p2 p3) d h w", ...)
    //   x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=group_size).mean(dim=2)
    //   x = self.conv(x, causal); x = rearrange(x, ...s2d...); return x + x_in

    struct SpaceToDepthDownsample : public GGMLBlock {
    protected:
        int in_channels  = 0;
        int out_channels = 0;
        int p1 = 1, p2 = 1, p3 = 1;  // temporal, height, width stride
        int group_size = 1;

        // Helper: collapse group-size consecutive channels via mean along axis 2
        // (after reshaping [W,H,T,C_exp] → [W*H, T, g, C_new]).
        ggml_tensor* group_mean_channel(ggml_context* ctx, ggml_tensor* x, int g) const {
            if (g == 1) return x;
            int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], Cexp = x->ne[3];
            GGML_ASSERT(Cexp % g == 0);
            int64_t C_new = Cexp / g;
            // Reshape: merge W,H; split Cexp into [g, C_new] (g innermost-of-that-group,
            // matching einops "(c g)" with g innermost).
            auto y = ggml_reshape_4d(ctx, x, W * H, T, g, C_new);
            // Move g to innermost (axis 0) for ggml_mean.
            y = ggml_cont(ctx, ggml_permute(ctx, y, 1, 2, 0, 3));  // ne=[g, W*H, T, C_new]
            y = ggml_mean(ctx, y);                                 // ne=[1, W*H, T, C_new]
            // Permute back & reshape to [W, H, T, C_new].
            y = ggml_cont(ctx, ggml_permute(ctx, y, 3, 0, 1, 2));  // ne=[W*H, T, C_new, 1]
            y = ggml_reshape_4d(ctx, y, W, H, T, C_new);
            return y;
        }

    public:
        SpaceToDepthDownsample() = default;
        SpaceToDepthDownsample(int in_ch, int out_ch, std::tuple<int, int, int> stride)
            : in_channels(in_ch), out_channels(out_ch),
              p1(std::get<0>(stride)), p2(std::get<1>(stride)), p3(std::get<2>(stride)) {
            int prod = p1 * p2 * p3;
            GGML_ASSERT((in_ch * prod) % out_ch == 0);
            group_size = in_ch * prod / out_ch;
            GGML_ASSERT(out_ch % prod == 0);
            // conv emits out_ch/prod channels; space_to_depth multiplies by prod → out_ch.
            blocks["conv"] = std::make_shared<LTXV::CausalConv3d>(in_ch, out_ch / prod, 3);
        }

        // x: ne=[W, H, T, C_in] → returns ne=[W/p3, H/p2, T'/p1, out_channels],
        // sum of the conv path and the parameter-free group-mean skip path.
        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) {
            auto conv = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv"]);

            // Duplicate first frame if temporal stride is 2.
            if (p1 == 2) {
                auto first = ggml_view_4d(ctx->ggml_ctx, x,
                                          x->ne[0], x->ne[1], 1, x->ne[3],
                                          x->nb[1], x->nb[2], x->nb[3], 0);
                first = ggml_cont(ctx->ggml_ctx, first);
                x     = ggml_concat(ctx->ggml_ctx, first, x, 2);  // prepend along T
            }

            // Skip: s2d → group-mean.
            auto x_in = space_to_depth(ctx->ggml_ctx, x, p1, p2, p3);
            x_in      = group_mean_channel(ctx->ggml_ctx, x_in, group_size);

            // Main: conv (preserves T because of causal padding, stride=1), then s2d.
            auto y = conv->forward(ctx, x, causal);
            y      = space_to_depth(ctx->ggml_ctx, y, p1, p2, p3);

            return ggml_add(ctx->ggml_ctx, y, x_in);
        }
    };

    // ---------- DepthToSpaceUpsample (decoder) ----------
    //
    // For the parity test we only need residual=False (compress_time, compress_space).
    // `compress_all` blocks with residual=True have a repeat-based skip path; we'll
    // add that when a decoder config needs it.
+ + struct DepthToSpaceUpsample : public GGMLBlock { + protected: + int in_channels = 0; + int p1 = 1, p2 = 1, p3 = 1; + int reduction_factor = 1; + + public: + DepthToSpaceUpsample() = default; + DepthToSpaceUpsample(int in_ch, std::tuple stride, int reduction_factor = 1, + LTXV::SpatialPadding pad = LTXV::SpatialPadding::ZEROS) + : in_channels(in_ch), + p1(std::get<0>(stride)), p2(std::get<1>(stride)), p3(std::get<2>(stride)), + reduction_factor(reduction_factor) { + int prod = p1 * p2 * p3; + int conv_out = prod * in_ch / reduction_factor; + blocks["conv"] = std::make_shared( + in_ch, conv_out, 3, std::tuple{1,1,1}, 1, true, pad); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + x = conv->forward(ctx, x, causal); + x = depth_to_space(ctx->ggml_ctx, x, p1, p2, p3); + if (p1 == 2) { + // Drop first frame along T to match Python x[:, :, 1:, ...]. + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + auto sliced = ggml_view_4d(ctx->ggml_ctx, x, + W, H, T - 1, C, + x->nb[1], x->nb[2], x->nb[3], x->nb[2]); // skip frame 0 + x = ggml_cont(ctx->ggml_ctx, sliced); + } + return x; + } + }; + + // ---------- PerChannelStatistics wrapper ---------- + // + // Python uses register_buffer("std-of-means", ...) and ("mean-of-means", ...) — + // dashed names which don't appear elsewhere in this codebase. We register them + // as tensors via init_params and carry the dashed names verbatim so loader + // name matching finds them. 
+ + struct PerChannelStatisticsBlock : public GGMLBlock { + protected: + int latent_channels = 0; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string /*prefix*/ = "") override { + params["std-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, latent_channels); + params["mean-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, latent_channels); + } + + public: + PerChannelStatisticsBlock() = default; + explicit PerChannelStatisticsBlock(int latent_channels) : latent_channels(latent_channels) {} + + ggml_tensor* normalize(GGMLRunnerContext* ctx, ggml_tensor* x) { + return pcs_normalize(ctx->ggml_ctx, x, params["mean-of-means"], params["std-of-means"]); + } + ggml_tensor* un_normalize(GGMLRunnerContext* ctx, ggml_tensor* x) { + return pcs_unnormalize(ctx->ggml_ctx, x, params["mean-of-means"], params["std-of-means"]); + } + }; + + // ---------- VideoEncoder ---------- + // + // The encoder config is a list of (block_name, block_config) tuples — we keep + // that shape in C++ via an EncoderBlockSpec. Only `res_x`, `compress_space_res`, + // `compress_time_res`, `compress_all_res` are handled here; more variants can + // be added as their use-cases land. `norm_layer` is pixel_norm only (group_norm + // would require new primitives). `latent_log_var` is UNIFORM only. 

    enum class EncoderBlockKind {
        RES_X,
        COMPRESS_SPACE_RES,  // stride=(1,2,2)
        COMPRESS_TIME_RES,   // stride=(2,1,1)
        COMPRESS_ALL_RES,    // stride=(2,2,2)
    };

    struct EncoderBlockSpec {
        EncoderBlockKind kind;
        int num_layers = 1;  // used for RES_X
        int multiplier = 2;  // used for compress_*_res
    };

    struct VideoEncoder : public GGMLBlock {
    protected:
        int in_channels     = 3;
        int latent_channels = 128;
        int patch_size      = 4;
        std::vector<EncoderBlockSpec> encoder_blocks;
        float eps = 1e-6f;

    public:
        VideoEncoder() = default;
        VideoEncoder(int in_ch, int latent_ch, int patch,
                     const std::vector<EncoderBlockSpec>& enc_blocks)
            : in_channels(in_ch), latent_channels(latent_ch), patch_size(patch),
              encoder_blocks(enc_blocks) {
            int feature_ch = latent_ch;
            int cur_in     = in_ch * patch * patch;  // after patchify

            blocks["conv_in"] = std::make_shared<LTXV::CausalConv3d>(cur_in, feature_ch, 3);

            int cur_c = feature_ch;
            for (size_t i = 0; i < encoder_blocks.size(); ++i) {
                const auto& b   = encoder_blocks[i];
                std::string key = "down_blocks." + std::to_string(i);
                switch (b.kind) {
                    case EncoderBlockKind::RES_X:
                        blocks[key] = std::make_shared<UNetMidBlock3D>(cur_c, b.num_layers, /*timestep_cond=*/false);
                        break;
                    case EncoderBlockKind::COMPRESS_SPACE_RES:
                        blocks[key] = std::make_shared<SpaceToDepthDownsample>(cur_c, cur_c * b.multiplier, std::tuple{1, 2, 2});
                        cur_c *= b.multiplier;
                        break;
                    case EncoderBlockKind::COMPRESS_TIME_RES:
                        blocks[key] = std::make_shared<SpaceToDepthDownsample>(cur_c, cur_c * b.multiplier, std::tuple{2, 1, 1});
                        cur_c *= b.multiplier;
                        break;
                    case EncoderBlockKind::COMPRESS_ALL_RES:
                        blocks[key] = std::make_shared<SpaceToDepthDownsample>(cur_c, cur_c * b.multiplier, std::tuple{2, 2, 2});
                        cur_c *= b.multiplier;
                        break;
                }
            }

            // UNIFORM log-var: conv_out gets one extra channel for the shared logvar.
            int conv_out_ch    = latent_ch + 1;
            blocks["conv_out"] = std::make_shared<LTXV::CausalConv3d>(cur_c, conv_out_ch, 3);
            blocks["per_channel_statistics"] = std::make_shared<PerChannelStatisticsBlock>(latent_ch);
        }

        // sample: ne=[W, H, T, C=3] (B=1). Returns normalized latent ne=[W', H', T', latent_ch].
        // If trace_outputs is non-null, intermediates are pushed in this order:
        //   0: post_patchify, 1: post_conv_in, 2..K-1: per down_block output,
        //   K: post_norm, K+1: post_conv_out, K+2: means_preNorm, K+3: latent.
        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* sample,
                             std::vector<ggml_tensor*>* trace_outputs = nullptr) {
            // patchify (distinct channel ordering from the SpaceToDepthDownsample blocks;
            // see `patchify` comment in ltxvae_primitives.hpp).
            auto x = patchify(ctx->ggml_ctx, sample, 1, patch_size, patch_size);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            auto conv_in = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv_in"]);
            x            = conv_in->forward(ctx, x);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            for (size_t i = 0; i < encoder_blocks.size(); ++i) {
                std::string key = "down_blocks." + std::to_string(i);
                switch (encoder_blocks[i].kind) {
                    case EncoderBlockKind::RES_X: {
                        auto b = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks[key]);
                        x      = b->forward(ctx, x, nullptr);
                        break;
                    }
                    default: {
                        auto b = std::dynamic_pointer_cast<SpaceToDepthDownsample>(blocks[key]);
                        x      = b->forward(ctx, x);
                        break;
                    }
                }
                if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));
            }

            x = pixel_norm(ctx->ggml_ctx, x, eps);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));
            x = ggml_silu(ctx->ggml_ctx, x);

            auto conv_out = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv_out"]);
            x             = conv_out->forward(ctx, x);  // ne=[W', H', T', latent_ch+1]
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            // UNIFORM log_var handling: means = x[:, :-1], we skip logvar entirely (it would be
            // expanded then discarded after the chunk(2) split). Take the first latent_ch channels.
            auto means = ggml_view_4d(ctx->ggml_ctx, x,
                                      x->ne[0], x->ne[1], x->ne[2], latent_channels,
                                      x->nb[1], x->nb[2], x->nb[3], 0);
            means = ggml_cont(ctx->ggml_ctx, means);
            if (trace_outputs) trace_outputs->push_back(means);

            auto pcs    = std::dynamic_pointer_cast<PerChannelStatisticsBlock>(blocks["per_channel_statistics"]);
            auto latent = pcs->normalize(ctx, means);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, latent));
            return latent;
        }
    };

    // ---------- VideoDecoder ----------

    enum class DecoderBlockKind {
        RES_X,
        COMPRESS_SPACE,  // stride=(1,2,2), residual=False
        COMPRESS_TIME,   // stride=(2,1,1), residual=False
        COMPRESS_ALL,    // stride=(2,2,2), residual configurable (default False here)
    };

    struct DecoderBlockSpec {
        DecoderBlockKind kind;
        int num_layers = 1;  // RES_X
        int multiplier = 1;  // channel reduction factor for compress_*
    };

    struct VideoDecoder : public GGMLBlock {
    protected:
        int latent_channels        = 128;
        int out_channels           = 3;
        int patch_size             = 4;
        int base_channels          = 128;
        bool timestep_conditioning = true;
        std::vector<DecoderBlockSpec> decoder_blocks;  // stored in ENCODER-side order; forward reverses
        float eps            = 1e-6f;
        int feature_channels = 0;
        // Decoder uses `reflect` spatial padding by default per the Python reference
        // (VideoDecoderConfigurator.from_config default). All CausalConv3d instances we
        // construct below are handed this padding mode.
        static constexpr LTXV::SpatialPadding PAD = LTXV::SpatialPadding::REFLECT;
        // Python configurator defaults: `causal_decoder=False`. All our CausalConv3d.forward
        // calls within the decoder should therefore use causal=False. (Encoder uses True.)
        static constexpr bool DECODER_CAUSAL = false;

        void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string /*prefix*/ = "") override {
            if (timestep_conditioning) {
                // Python: last_scale_shift_table = Parameter(torch.empty(2, feature_channels)).
                params["last_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, feature_channels, 2);
                // timestep_scale_multiplier: scalar.
                params["timestep_scale_multiplier"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
            }
        }

    public:
        VideoDecoder() = default;
        VideoDecoder(int latent_ch, int out_ch, int patch, int base_ch, bool timestep_cond,
                     const std::vector<DecoderBlockSpec>& dec_blocks)
            : latent_channels(latent_ch), out_channels(out_ch), patch_size(patch),
              base_channels(base_ch), timestep_conditioning(timestep_cond),
              decoder_blocks(dec_blocks) {
            // Decoder's feature_channels = base_channels * 8 per LTX-2 default (3 upsample blocks × 2).
            feature_channels = base_ch * 8;

            blocks["conv_in"] = std::make_shared<LTXV::CausalConv3d>(
                latent_ch, feature_channels, 3, std::tuple{1, 1, 1}, 1, true, PAD);

            // Decoder config is stored in encoder-side order; construct up_blocks in REVERSED order
            // (matching the Python `list(reversed(decoder_blocks))`).
            int cur_c = feature_channels;
            for (size_t i = 0; i < decoder_blocks.size(); ++i) {
                const auto& b   = decoder_blocks[decoder_blocks.size() - 1 - i];
                std::string key = "up_blocks." + std::to_string(i);
                switch (b.kind) {
                    case DecoderBlockKind::RES_X:
                        blocks[key] = std::make_shared<UNetMidBlock3D>(cur_c, b.num_layers, timestep_conditioning, PAD);
                        break;
                    case DecoderBlockKind::COMPRESS_SPACE:
                        blocks[key] = std::make_shared<DepthToSpaceUpsample>(cur_c, std::tuple{1, 2, 2}, b.multiplier, PAD);
                        cur_c = cur_c / b.multiplier;
                        break;
                    case DecoderBlockKind::COMPRESS_TIME:
                        blocks[key] = std::make_shared<DepthToSpaceUpsample>(cur_c, std::tuple{2, 1, 1}, b.multiplier, PAD);
                        cur_c = cur_c / b.multiplier;
                        break;
                    case DecoderBlockKind::COMPRESS_ALL:
                        blocks[key] = std::make_shared<DepthToSpaceUpsample>(cur_c, std::tuple{2, 2, 2}, b.multiplier, PAD);
                        cur_c = cur_c / b.multiplier;
                        break;
                }
            }

            int final_out_ch   = out_ch * patch * patch;
            blocks["conv_out"] = std::make_shared<LTXV::CausalConv3d>(
                cur_c, final_out_ch, 3, std::tuple{1, 1, 1}, 1, true, PAD);

            if (timestep_conditioning) {
                blocks["last_time_embedder"] = std::make_shared<TimestepEmbedder>(feature_channels * 2);
            }
            blocks["per_channel_statistics"] = std::make_shared<PerChannelStatisticsBlock>(latent_ch);
        }

        // Trace stage order (for parity debugging):
        //   0 post_unnorm, 1 post_conv_in, 2..K-1 per up_block output,
        //   K post_pixel_norm (pre-ada), K+1 post_ada, K+2 post_conv_out, K+3 video_out.
        // NOTE(review): RES_X up_blocks receive `trace_outputs` as their res-block
        // trace vector too, so their internal intermediates are interleaved between
        // the per-up_block outputs when tracing with timestep conditioning.
        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* latent, ggml_tensor* timestep = nullptr,
                             std::vector<ggml_tensor*>* trace_outputs = nullptr) {
            auto pcs = std::dynamic_pointer_cast<PerChannelStatisticsBlock>(blocks["per_channel_statistics"]);
            auto x   = pcs->un_normalize(ctx, latent);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            auto conv_in = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv_in"]);
            // Earlier dump_vae.py used default causal=True for conv_in, but the real Python
            // decoder.forward uses self.causal which is False — the dumper is now aligned.
            x = conv_in->forward(ctx, x, DECODER_CAUSAL);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            for (size_t i = 0; i < decoder_blocks.size(); ++i) {
                const auto& b   = decoder_blocks[decoder_blocks.size() - 1 - i];
                std::string key = "up_blocks." + std::to_string(i);
                if (b.kind == DecoderBlockKind::RES_X) {
                    auto blk = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks[key]);
                    x        = blk->forward(ctx, x, timestep_conditioning ? timestep : nullptr, trace_outputs, DECODER_CAUSAL);
                } else {
                    auto blk = std::dynamic_pointer_cast<DepthToSpaceUpsample>(blocks[key]);
                    x        = blk->forward(ctx, x, DECODER_CAUSAL);
                }
                if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));
            }

            // Final norm + AdaLN + SiLU + conv_out.
            x = pixel_norm(ctx->ggml_ctx, x, eps);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            if (timestep_conditioning) {
                GGML_ASSERT(timestep != nullptr);
                auto te = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["last_time_embedder"]);
                // Python multiplies the timestep by timestep_scale_multiplier BEFORE the embed.
                auto tsm        = params["timestep_scale_multiplier"];  // scalar [1]
                auto t_scaled   = ggml_mul(ctx->ggml_ctx, timestep, tsm);
                auto time_embed = te->forward(ctx, t_scaled);  // ne=[2*feature_channels, 1]

                auto sst = params["last_scale_shift_table"];  // ne=[feature_channels, 2]
                auto te2 = ggml_reshape_2d(ctx->ggml_ctx, time_embed, feature_channels, 2);
                auto ada = ggml_add(ctx->ggml_ctx, te2, sst);  // ne=[feature_channels, 2]

                auto shift = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 0, 1);
                auto scale = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 1, 2);
                shift = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, shift), 1, 1, 1, feature_channels);
                scale = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale), 1, 1, 1, feature_channels);

                // x * (1 + scale) + shift.
                auto x_scaled = ggml_mul(ctx->ggml_ctx, x, scale);
                x             = ggml_add(ctx->ggml_ctx, x, x_scaled);
                x             = ggml_add(ctx->ggml_ctx, x, shift);
            }
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            x = ggml_silu(ctx->ggml_ctx, x);

            auto conv_out = std::dynamic_pointer_cast<LTXV::CausalConv3d>(blocks["conv_out"]);
            x             = conv_out->forward(ctx, x, DECODER_CAUSAL);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));

            x = unpatchify(ctx->ggml_ctx, x, 1, patch_size, patch_size);
            if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x));
            return x;
        }
    };

    // ---------- GGMLRunner wrappers ----------

    struct VAEEncoderRunner : public GGMLRunner {
        VideoEncoder encoder;
        int in_channels;
        int latent_channels;
        int patch_size;

        VAEEncoderRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map,
                         const std::string& prefix,
                         int in_ch,
                         int latent_ch,
                         int patch,
                         const std::vector<EncoderBlockSpec>& specs)
            : GGMLRunner(backend, offload_params_to_cpu),
              encoder(in_ch, latent_ch, patch, specs),
              in_channels(in_ch), latent_channels(latent_ch), patch_size(patch) {
            encoder.init(params_ctx, tensor_storage_map, prefix);
        }

        std::string get_desc() override { return "ltx2_vae_encoder"; }

        void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) {
            encoder.get_param_tensors(tensors, prefix);
        }

        // stage_index==-1 returns the final latent; >=0 returns the matching trace.
        // Full forward is always built so buffer allocation covers every declared input.
        // NOTE(review): an out-of-range stage_index silently falls back to the final
        // output rather than failing.
        sd::Tensor compute(int n_threads, const sd::Tensor& video_tensor,
                           int stage_index = -1) {
            auto get_g = [&]() -> ggml_cgraph* {
                ggml_cgraph* gf = ggml_new_graph(compute_ctx);
                ggml_tensor* x  = make_input(video_tensor);
                auto runner_ctx = get_context();
                std::vector<ggml_tensor*> traces;
                ggml_tensor* final_out = encoder.forward(&runner_ctx, x, &traces);
                ggml_build_forward_expand(gf, final_out);
                if (stage_index >= 0 && stage_index < (int)traces.size()) {
                    ggml_build_forward_expand(gf, traces[stage_index]);
                }
                return gf;
            };
            return take_or_empty(GGMLRunner::compute(get_g, n_threads, true));
        }
    };

    struct VAEDecoderRunner : public GGMLRunner {
        VideoDecoder decoder;

        VAEDecoderRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map,
                         const std::string& prefix,
                         int latent_ch,
                         int out_ch,
                         int patch,
                         int base_ch,
                         bool timestep_cond,
                         const std::vector<DecoderBlockSpec>& specs)
            : GGMLRunner(backend, offload_params_to_cpu),
              decoder(latent_ch, out_ch, patch, base_ch, timestep_cond, specs) {
            decoder.init(params_ctx, tensor_storage_map, prefix);
        }

        std::string get_desc() override { return "ltx2_vae_decoder"; }

        void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) {
            decoder.get_param_tensors(tensors, prefix);
        }

        // stage_index==-1 returns the final video output; >=0 returns the matching trace.
        // We always build the FULL forward graph so every declared input has a backend
        // buffer; when stage_index is set we just re-expand that trace last so it becomes
        // the final-result tensor that GGMLRunner::compute extracts.
        sd::Tensor compute(int n_threads,
                           const sd::Tensor& latent_tensor,
                           const sd::Tensor& timestep_tensor,
                           int stage_index = -1) {
            auto get_g = [&]() -> ggml_cgraph* {
                ggml_cgraph* gf = ggml_new_graph(compute_ctx);
                ggml_tensor* z  = make_input(latent_tensor);
                ggml_tensor* t  = timestep_tensor.empty() ? nullptr : make_input(timestep_tensor);
                auto runner_ctx = get_context();
                std::vector<ggml_tensor*> traces;
                ggml_tensor* final_out = decoder.forward(&runner_ctx, z, t, &traces);
                ggml_build_forward_expand(gf, final_out);
                if (stage_index >= 0 && stage_index < (int)traces.size()) {
                    ggml_build_forward_expand(gf, traces[stage_index]);
                }
                return gf;
            };
            return take_or_empty(GGMLRunner::compute(get_g, n_threads, true));
        }
    };

    // ---------- Combined VAE runner ----------
    //
    // Plumbs both VideoEncoder and VideoDecoder into the shared VAE interface so
    // create_vae() in stable-diffusion.cpp can treat LTX-2 like any other VAE.
    //
    // Prefix convention matches the real LTX-2 checkpoint: `vae.encoder.*`,
    // `vae.decoder.*`, `vae.per_channel_statistics.*`. Since our VideoEncoder and
    // VideoDecoder each register a PerChannelStatisticsBlock under their own
    // sub-prefix, we need the state dict to have nested PCS copies (which our
    // parity dumper provides). Real LTX-2 checkpoints only ship the top-level
    // `vae.per_channel_statistics.*` — see FUTURE note below.

    struct LTX2VAERunner : public VAE {
        VideoEncoder encoder;
        VideoDecoder decoder;
        float decode_timestep = 0.05f;  // Python default.
        bool uses_timestep_conditioning = true;

        LTX2VAERunner(ggml_backend_t backend,
                      bool offload_params_to_cpu,
                      const String2TensorStorage& tensor_storage_map,
                      const std::string& prefix,
                      SDVersion version_,
                      int in_ch = 3,
                      int latent_ch = 128,
                      int patch = 4,
                      int decoder_base_ch = 128,
                      bool timestep_cond = true,
                      std::vector<EncoderBlockSpec> enc_specs = {},
                      std::vector<DecoderBlockSpec> dec_specs = {})
            : VAE(version_, backend, offload_params_to_cpu),
              encoder(in_ch, latent_ch, patch, enc_specs.empty() ? default_enc_specs() : enc_specs),
              decoder(latent_ch, in_ch, patch, decoder_base_ch, timestep_cond,
                      dec_specs.empty() ? default_dec_specs() : dec_specs),
              uses_timestep_conditioning(timestep_cond) {
            // The VAE base class' scale_input=true would re-scale [0,1]→[-1,1] and back;
            // LTX-2 expects inputs already in [-1, 1], so disable.
            scale_input = false;
            encoder.init(params_ctx, tensor_storage_map, prefix + ".encoder");
            decoder.init(params_ctx, tensor_storage_map, prefix + ".decoder");
        }

        // Production default: 1× compress_space_res, 1× compress_time_res, 2× compress_all_res,
        // per the LTXV paper "Standard LTX Video configuration" docstring.
        static std::vector<EncoderBlockSpec> default_enc_specs() {
            return {
                {EncoderBlockKind::COMPRESS_SPACE_RES, 1, 2},
                {EncoderBlockKind::COMPRESS_TIME_RES, 1, 2},
                {EncoderBlockKind::COMPRESS_ALL_RES, 1, 2},
                {EncoderBlockKind::COMPRESS_ALL_RES, 1, 2},
            };
        }
        static std::vector<DecoderBlockSpec> default_dec_specs() {
            // Stored in encoder-side order; VideoDecoder reverses.
            return {
                {DecoderBlockKind::COMPRESS_SPACE, 1, 1},
                {DecoderBlockKind::COMPRESS_TIME, 1, 1},
                {DecoderBlockKind::COMPRESS_ALL, 1, 1},
                {DecoderBlockKind::COMPRESS_ALL, 1, 1},
            };
        }

        // Real 22B LTX-2 video VAE spec, reverse-engineered from the checkpoint's
        // weight shapes (encoder ch progression: 128 → 256 → 512 → 1024 → 1024):
        //   idx  kind                    cur_c after
        //   0    RES_X(4 layers)         128
        //   1    COMPRESS_SPACE_RES(m=2) 128 → 256
        //   2    RES_X(6 layers)         256
        //   3    COMPRESS_TIME_RES(m=2)  256 → 512
        //   4    RES_X(4 layers)         512
        //   5    COMPRESS_ALL_RES(m=2)   512 → 1024
        //   6    RES_X(2 layers)         1024
        //   7    COMPRESS_ALL_RES(m=1)   1024 → 1024 (spatial/temporal compress only)
        //   8    RES_X(2 layers)         1024
        // Final conv_out: 1024 → 129 (128 latent + 1 logvar).
        // Decoder mirrors in encoder-side order; VideoDecoder reverses at construct.
        static std::vector<EncoderBlockSpec> ltx2_22b_enc_specs() {
            return {
                {EncoderBlockKind::RES_X, 4, 1},
                {EncoderBlockKind::COMPRESS_SPACE_RES, 1, 2},
                {EncoderBlockKind::RES_X, 6, 1},
                {EncoderBlockKind::COMPRESS_TIME_RES, 1, 2},
                {EncoderBlockKind::RES_X, 4, 1},
                {EncoderBlockKind::COMPRESS_ALL_RES, 1, 2},
                {EncoderBlockKind::RES_X, 2, 1},
                {EncoderBlockKind::COMPRESS_ALL_RES, 1, 1},
                {EncoderBlockKind::RES_X, 2, 1},
            };
        }
        static std::vector<DecoderBlockSpec> ltx2_22b_dec_specs() {
            // Encoder-side order; VideoDecoder iterates in reverse at construct.
            // Reverse iteration maps decoder_blocks[i] → up_blocks.[N-1-i], so the
            // last entry here becomes up_blocks.0 (innermost, 1024 channels).
            //
            // Decoder channel progression (verified against real weight shapes):
            //   up_blocks.0  RES_X(2)           @ 1024
            //   up_blocks.1  COMPRESS_ALL(m=2)  1024 → 512 (conv:4096, d2s/8)
            //   up_blocks.2  RES_X(2)           @ 512
            //   up_blocks.3  COMPRESS_ALL(m=1)  512 → 512  (conv:4096, d2s/8)
            //   up_blocks.4  RES_X(4)           @ 512
            //   up_blocks.5  COMPRESS_TIME(m=2) 512 → 256  (conv:512, d2s/2)
            //   up_blocks.6  RES_X(6)           @ 256
            //   up_blocks.7  COMPRESS_SPACE(m=2) 256 → 128 (conv:512, d2s/4)
            //   up_blocks.8  RES_X(4)           @ 128
            // Decoder's compress multipliers are NOT a mirror of the encoder's
            // — the model is architecturally asymmetric (different res counts, different
            // compress kinds at each level). Enc vs dec must each be traced separately.
            return {
                {DecoderBlockKind::RES_X, 4, 1},
                {DecoderBlockKind::COMPRESS_SPACE, 1, 2},
                {DecoderBlockKind::RES_X, 6, 1},
                {DecoderBlockKind::COMPRESS_TIME, 1, 2},
                {DecoderBlockKind::RES_X, 4, 1},
                {DecoderBlockKind::COMPRESS_ALL, 1, 1},
                {DecoderBlockKind::RES_X, 2, 1},
                {DecoderBlockKind::COMPRESS_ALL, 1, 2},
                {DecoderBlockKind::RES_X, 2, 1},
            };
        }

        std::string get_desc() override { return "ltx2_vae"; }

        void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
            encoder.get_param_tensors(tensors, prefix + ".encoder");
            decoder.get_param_tensors(tensors, prefix + ".decoder");
        }

        int get_encoder_output_channels(int /*input_channels*/) override {
            return 128;  // latent_channels
        }

        // LTX-2 latents are deterministic (means only), so these are identity
        // pass-throughs of the base-class hooks.
        // NOTE(review): the shared_ptr element type was lost to extraction; RNG is
        // reconstructed from the VAE base interface — confirm against vae.hpp.
        sd::Tensor vae_output_to_latents(const sd::Tensor& vae_output,
                                         std::shared_ptr<RNG> /*rng*/) override {
            return vae_output;
        }
        sd::Tensor diffusion_to_vae_latents(const sd::Tensor& latents) override {
            return latents;
        }
        sd::Tensor vae_to_diffusion_latents(const sd::Tensor& latents) override {
            return latents;
        }

        ggml_cgraph* build_graph(const sd::Tensor& z_tensor, bool decode_graph) {
            // 10240 fit the 4-block parity test.
The 22B VAE has 9 encoder + 9 + // decoder blocks with up to 6 res_blocks each, plus per-channel stats + // and conv_in/out. Bumped for safety. + ggml_cgraph* gf = new_graph_custom(65536); + ggml_tensor* z = make_input(z_tensor); + auto runner_ctx = get_context(); + ggml_tensor* out; + if (decode_graph) { + ggml_tensor* t = nullptr; + if (uses_timestep_conditioning) { + // Build a scalar timestep tensor inline (no external input needed). + t = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1); + ggml_set_name(t, "ltx2_vae_decode_timestep"); + decode_timestep_backing.resize(1); + decode_timestep_backing[0] = decode_timestep; + set_backend_tensor_data(t, decode_timestep_backing.data()); + } + out = decoder.forward(&runner_ctx, z, t); + } else { + out = encoder.forward(&runner_ctx, z); + } + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor _compute(const int n_threads, + const sd::Tensor& z, + bool decode_graph) override { + auto get_g = [&]() -> ggml_cgraph* { return build_graph(z, decode_graph); }; + auto out = take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + // Decoder output is [W, H, T, C]; decode_video_outputs + tensor_to_sd_image + // expect 5D [W, H, T, C, B] to pick the video-shaped index path. Add the + // trailing batch axis so the conversion uses the (iw, ih, frame, ic, 0) + // accessor (the 4D path assumes [W, H, C, F] which is the wrong layout). 
+ if (decode_graph && !out.empty() && out.shape().size() == 4) { + auto s = out.shape(); + out.reshape_({s[0], s[1], s[2], s[3], 1}); + } + return out; + } + + private: + std::vector decode_timestep_backing; + }; + +} // namespace LTXVAE + +#endif // __LTXVAE_HPP__ diff --git a/src/ltxvae_primitives.hpp b/src/ltxvae_primitives.hpp new file mode 100644 index 000000000..05c301ecf --- /dev/null +++ b/src/ltxvae_primitives.hpp @@ -0,0 +1,212 @@ +#ifndef __LTXVAE_PRIMITIVES_HPP__ +#define __LTXVAE_PRIMITIVES_HPP__ + +#include "ggml.h" + +// Space-to-depth / depth-to-space helpers for the LTX-2 VAE. +// +// The VAE's `SpaceToDepthDownsample` and `DepthToSpaceUpsample` blocks compress +// or expand one or more of the (T, H, W) axes into/out of the channel axis. In +// einops notation (with B=1 elided): +// +// rearrange(x, "c (t p1) (h p2) (w p3) -> (c p1 p2 p3) t h w", ...) # space-to-depth +// rearrange(x, "(c p1 p2 p3) t h w -> c (t p1) (h p2) (w p3)", ...) # depth-to-space +// +// The einops grouping "(c p1 p2 p3)" puts p3 innermost (fastest-varying) within +// the merged channel axis, so c_new = c*p1*p2*p3 + i1*p2*p3 + i2*p3 + i3. +// +// GGML caps tensors at 4-D, which prevents a single reshape from representing the +// natural 5-D/6-D intermediate. We achieve the same result by folding the three +// strided axes ONE AT A TIME, composing three 4-D rearranges. The fold order +// matters: because each single-axis fold puts the just-folded factor innermost +// within the merged channel axis, folding in the order T→H→W produces p3 as the +// innermost factor in the final output — matching einops "(c p1 p2 p3)". +// +// Convention: all tensors use GGML ne=[W, H, T, C] (B=1 collapsed). A "factor" +// of 1 is a no-op; single-axis folds require the target axis to be divisible +// by factor. +// +// The primitives are verified byte-exact against PyTorch einops in the +// standalone test sd-s2d-primitives-test. 
+
+namespace LTXVAE {
+
+// ---------- SpaceToDepth ----------
+
+inline ggml_tensor* space_to_depth_axisW(ggml_context* ctx, ggml_tensor* x, int factor) {
+    int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3];
+    GGML_ASSERT(W % factor == 0);
+    int64_t W_out = W / factor;
+    // Split innermost axis into [factor (innermost), W_out]. Merge H,T to stay 4D.
+    auto y = ggml_reshape_4d(ctx, x, factor, W_out, H * T, C);
+    // Move "factor" from axis 0 to axis 2 (adjacent to C).
+    // ggml_permute(a, p0, p1, p2, p3) says "old axis i goes to new position p_i".
+    // Here old→new: 0→2, 1→0, 2→1, 3→3.
+    y = ggml_cont(ctx, ggml_permute(ctx, y, 2, 0, 1, 3));  // ne=[W_out, H*T, factor, C]
+    // Merge (factor, C) with factor innermost of the new channel axis, matching einops (c p3).
+    y = ggml_reshape_4d(ctx, y, W_out, H, T, C * factor);
+    return y;
+}
+
+inline ggml_tensor* space_to_depth_axisH(ggml_context* ctx, ggml_tensor* x, int factor) {
+    int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3];
+    GGML_ASSERT(H % factor == 0);
+    int64_t H_out = H / factor;
+    auto y = ggml_reshape_4d(ctx, x, W, factor, H_out * T, C);
+    y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3));  // ne=[W, H_out*T, factor, C]
+    y = ggml_reshape_4d(ctx, y, W, H_out, T, C * factor);
+    return y;
+}
+
+inline ggml_tensor* space_to_depth_axisT(ggml_context* ctx, ggml_tensor* x, int factor) {
+    int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3];
+    GGML_ASSERT(T % factor == 0);
+    int64_t T_out = T / factor;
+    auto y = ggml_reshape_4d(ctx, x, W * H, factor, T_out, C);
+    y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3));  // ne=[W*H, T_out, factor, C]
+    y = ggml_reshape_4d(ctx, y, W, H, T_out, C * factor);
+    return y;
+}
+
+// Compose: fold T first (so p1 ends up outer), then H (p2), then W (p3 innermost)
+// — matching einops "(c p1 p2 p3)" channel ordering.
+inline ggml_tensor* space_to_depth(ggml_context* ctx, ggml_tensor* x, + int p1, int p2, int p3) { + if (p1 > 1) x = space_to_depth_axisT(ctx, x, p1); + if (p2 > 1) x = space_to_depth_axisH(ctx, x, p2); + if (p3 > 1) x = space_to_depth_axisW(ctx, x, p3); + return x; +} + +// ---------- DepthToSpace (inverse) ---------- +// +// Each single-axis depth-to-space splits the last axis (C_in = C_out * factor) +// with factor innermost, moves factor to the strided spatial axis, then merges. +// To invert space_to_depth's T→H→W fold order, we unfold in reverse: W→H→T. + +inline ggml_tensor* depth_to_space_axisW(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(C % factor == 0); + int64_t C_out = C / factor; + // Split last axis into [factor (innermost), C_out]. + auto y = ggml_reshape_4d(ctx, x, W, H * T, factor, C_out); + // Inverse of the S2D-axisW permute (2,0,1,3). Inverse of that map is (1,2,0,3): + // old 0→new 1, old 1→new 2, old 2→new 0, old 3→new 3. + y = ggml_cont(ctx, ggml_permute(ctx, y, 1, 2, 0, 3)); // ne=[factor, W, H*T, C_out] + y = ggml_reshape_4d(ctx, y, W * factor, H, T, C_out); + return y; +} + +inline ggml_tensor* depth_to_space_axisH(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(C % factor == 0); + int64_t C_out = C / factor; + auto y = ggml_reshape_4d(ctx, x, W, H * T, factor, C_out); + // Inverse of S2D-axisH's (0,2,1,3) is itself (0,2,1,3). 
+ y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); // ne=[W, factor, H*T, C_out] + y = ggml_reshape_4d(ctx, y, W, H * factor, T, C_out); + return y; +} + +inline ggml_tensor* depth_to_space_axisT(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(C % factor == 0); + int64_t C_out = C / factor; + auto y = ggml_reshape_4d(ctx, x, W * H, T, factor, C_out); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); // ne=[W*H, factor, T, C_out] + y = ggml_reshape_4d(ctx, y, W, H, T * factor, C_out); + return y; +} + +// Inverse of space_to_depth: unfold in reverse order (W first, then H, then T) +// because S2D folded T first, H, W. +inline ggml_tensor* depth_to_space(ggml_context* ctx, ggml_tensor* x, + int p1, int p2, int p3) { + if (p3 > 1) x = depth_to_space_axisW(ctx, x, p3); + if (p2 > 1) x = depth_to_space_axisH(ctx, x, p2); + if (p1 > 1) x = depth_to_space_axisT(ctx, x, p1); + return x; +} + +// ---------- patchify / unpatchify ---------- +// +// The VAE's patchify op uses a DIFFERENT channel ordering from the Downsample/Upsample +// blocks: einops `"b c (f p) (h q) (w r) -> b (c p r q) f h w"` — innermost within the +// merged channel axis is q (H-patch), NOT p3/W as elsewhere. To match, we fold in the +// order T (p), W (r), H (q) — last fold ends up innermost. 
+ +inline ggml_tensor* patchify(ggml_context* ctx, ggml_tensor* x, int pt, int ph, int pw) { + if (pt > 1) x = space_to_depth_axisT(ctx, x, pt); + if (pw > 1) x = space_to_depth_axisW(ctx, x, pw); + if (ph > 1) x = space_to_depth_axisH(ctx, x, ph); + return x; +} + +inline ggml_tensor* unpatchify(ggml_context* ctx, ggml_tensor* x, int pt, int ph, int pw) { + if (ph > 1) x = depth_to_space_axisH(ctx, x, ph); + if (pw > 1) x = depth_to_space_axisW(ctx, x, pw); + if (pt > 1) x = depth_to_space_axisT(ctx, x, pt); + return x; +} + +// ---------- PixelNorm ---------- +// +// Python (ltx_core.model.common.normalization.PixelNorm, dim=1): +// y = x / sqrt(mean(x^2, dim=1, keepdim=True) + eps) +// PyTorch dim=1 is the channel axis. In our GGML layout ne=[W, H, T, C] that's +// ne[3] (outermost). ggml_rms_norm normalizes along ne[0] (innermost), so we +// permute C to innermost, rms-normalize, then permute back. +// +// This has NO learnable parameters — the Python PixelNorm is parameter-free. + +inline ggml_tensor* pixel_norm(ggml_context* ctx, ggml_tensor* x, float eps) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + // Move C to innermost. old→new: 0→1 (W to pos 1), 1→2 (H to 2), 2→3 (T to 3), 3→0 (C to 0). + auto y = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 3, 0)); // ne=[C, W, H, T] + y = ggml_rms_norm(ctx, y, eps); // normalize along ne[0]=C + // Permute back: C to outermost. old→new: 0→3, 1→0, 2→1, 3→2. + y = ggml_cont(ctx, ggml_permute(ctx, y, 3, 0, 1, 2)); // ne=[W, H, T, C] + (void)W; (void)H; (void)T; (void)C; + return y; +} + +// ---------- PerChannelStatistics ---------- +// +// Python: buffers `mean-of-means` [C] and `std-of-means` [C]. 
+// normalize(x) = (x - mean) / std
+// un_normalize(x) = x * std + mean
+// In GGML with ne=[W, H, T, C] and a 1D buffer of shape [C] (ne=[C, 1, 1, 1]),
+// we broadcast over W*H*T by using the asymmetric-broadcast ggml_add/mul:
+// ggml_mul(a, b) requires a->ne[i] % b->ne[i] == 0, so we pass x as `a` and the
+// [C] buffer reshaped to ne=[1, 1, 1, C] as `b` — same outermost-axis shape.
+
+inline ggml_tensor* pcs_normalize(ggml_context* ctx, ggml_tensor* x,
+                                  ggml_tensor* mean_of_means,
+                                  ggml_tensor* std_of_means) {
+    int64_t C = x->ne[3];
+    // Reshape both buffers to ne=[1, 1, 1, C] so they broadcast along W/H/T.
+    auto mu = ggml_reshape_4d(ctx, mean_of_means, 1, 1, 1, C);
+    auto sigma = ggml_reshape_4d(ctx, std_of_means, 1, 1, 1, C);
+    // (x - mu) / sigma: ggml_div supports the same asymmetric broadcast as
+    // ggml_mul (a->ne[i] % b->ne[i] == 0), so dividing by the [1, 1, 1, C]
+    // sigma tensor directly is valid — no host-side reciprocal is needed;
+    // sigma is a loaded buffer that is constant at inference time.
+ auto x_shifted = ggml_sub(ctx, x, mu); + auto x_norm = ggml_div(ctx, x_shifted, sigma); + return x_norm; +} + +inline ggml_tensor* pcs_unnormalize(ggml_context* ctx, ggml_tensor* x, + ggml_tensor* mean_of_means, + ggml_tensor* std_of_means) { + int64_t C = x->ne[3]; + auto mu = ggml_reshape_4d(ctx, mean_of_means, 1, 1, 1, C); + auto sigma = ggml_reshape_4d(ctx, std_of_means, 1, 1, 1, C); + auto y = ggml_mul(ctx, x, sigma); + y = ggml_add(ctx, y, mu); + return y; +} + +} // namespace LTXVAE + +#endif diff --git a/src/model.cpp b/src/model.cpp index 3479a0bea..8a5fe3617 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -471,6 +471,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) { return VERSION_ERNIE_IMAGE; } + if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.scale_shift_table") != std::string::npos) { + return VERSION_LTX2; + } if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { is_wan = true; } diff --git a/src/model.h b/src/model.h index 65bc6c367..a6049c8a8 100644 --- a/src/model.h +++ b/src/model.h @@ -45,6 +45,7 @@ enum SDVersion { VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, VERSION_ERNIE_IMAGE, + VERSION_LTX2, VERSION_COUNT, }; @@ -139,6 +140,13 @@ static inline bool sd_version_is_ernie_image(SDVersion version) { return false; } +static inline bool sd_version_is_ltx2(SDVersion version) { + if (version == VERSION_LTX2) { + return true; + } + return false; +} + static inline bool sd_version_uses_flux2_vae(SDVersion version) { if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) { return true; @@ -165,7 +173,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_z_image(version) || - sd_version_is_ernie_image(version)) { + sd_version_is_ernie_image(version) || + 
sd_version_is_ltx2(version)) { return true; } return false; diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index 618c7f6e9..fb91ab346 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -653,6 +653,39 @@ std::string convert_diffusers_dit_to_original_lumina2(std::string name) { return name; } +std::string convert_diffusers_dit_to_original_ltx2(std::string name) { + // Maps diffusers' LTX Video Transformer naming → original LTX-2 naming. + // The GGML block tree mirrors the original Python class attribute names, so anything matching the + // original naming passes through. Only the few diffusers-specific renames are listed here. + static std::unordered_map ltx2_name_map; + + if (ltx2_name_map.empty()) { + // Input projection: diffusers names it x_embedder; original uses patchify_proj. + ltx2_name_map["x_embedder.weight"] = "patchify_proj.weight"; + ltx2_name_map["x_embedder.bias"] = "patchify_proj.bias"; + + // Timestep head: diffusers sometimes puts these under time_embed / time_text_embed while the + // original LTX-2 uses adaln_single.emb.timestep_embedder.linear_{1,2} and adaln_single.linear. 
+ ltx2_name_map["time_embed.timestep_embedder.linear_1.weight"] = "adaln_single.emb.timestep_embedder.linear_1.weight"; + ltx2_name_map["time_embed.timestep_embedder.linear_1.bias"] = "adaln_single.emb.timestep_embedder.linear_1.bias"; + ltx2_name_map["time_embed.timestep_embedder.linear_2.weight"] = "adaln_single.emb.timestep_embedder.linear_2.weight"; + ltx2_name_map["time_embed.timestep_embedder.linear_2.bias"] = "adaln_single.emb.timestep_embedder.linear_2.bias"; + ltx2_name_map["time_embed.linear.weight"] = "adaln_single.linear.weight"; + ltx2_name_map["time_embed.linear.bias"] = "adaln_single.linear.bias"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "adaln_single.emb.timestep_embedder.linear_1.weight"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "adaln_single.emb.timestep_embedder.linear_1.bias"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_2.weight"] = "adaln_single.emb.timestep_embedder.linear_2.weight"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_2.bias"] = "adaln_single.emb.timestep_embedder.linear_2.bias"; + + // Transformer block names typically match (attn1/attn2/ff/scale_shift_table), so nothing to rewrite. + // Output projection & scale_shift_table pass through. 
+ } + + replace_with_prefix_map(name, ltx2_name_map); + + return name; +} + std::string convert_other_dit_to_original_anima(std::string name) { static const std::string anima_net_prefix = "net."; if (!starts_with(name, anima_net_prefix)) { @@ -672,6 +705,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); + } else if (sd_version_is_ltx2(version)) { + name = convert_diffusers_dit_to_original_ltx2(name); } else if (sd_version_is_anima(version)) { name = convert_other_dit_to_original_anima(name); } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index b9d3e9af1..149f07b72 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -16,6 +16,7 @@ #include "lora.hpp" #include "pmid.hpp" #include "sample-cache.h" +#include "ltxvae.hpp" #include "tae.hpp" #include "vae.hpp" @@ -54,6 +55,7 @@ const char* model_version_to_str[] = { "Z-Image", "Ovis Image", "Ernie Image", + "LTX-2", }; const char* sampling_methods_str[] = { @@ -170,10 +172,93 @@ class StableDiffusionGGML { ggml_backend_free(backend); } + // Read an integer environment variable, returning `def` if unset or malformed. + static int get_env_int(const char* name, int def) { + const char* v = getenv(name); + if (v == nullptr || *v == '\0') return def; + try { + return std::stoi(v); + } catch (...) { + LOG_WARN("env %s: '%s' is not a valid integer, using default %d", name, v, def); + return def; + } + } + + // Initialize a GPU backend for the given device id, or fall back to CPU. + // For CUDA, device_id < 0 means "CPU only"; otherwise clamp to available count. + // `component_name` is used for log messages (e.g. "DiT", "Gemma", "VAE"). 
+ static ggml_backend_t init_device_backend(int device_id, const char* component_name) { + if (device_id < 0) { + LOG_INFO("%s: using CPU backend (device=-1)", component_name); + return ggml_backend_cpu_init(); + } +#ifdef SD_USE_CUDA + int count = ggml_backend_cuda_get_device_count(); + if (count <= 0) { + LOG_WARN("%s: no CUDA devices available, falling back to CPU", component_name); + return ggml_backend_cpu_init(); + } + if (device_id >= count) { + LOG_WARN("%s: CUDA device %d requested but only %d available, falling back to device 0", + component_name, device_id, count); + device_id = 0; + } + auto b = ggml_backend_cuda_init(device_id); + if (b != nullptr) { + LOG_INFO("%s: using CUDA device %d", component_name, device_id); + return b; + } + LOG_WARN("%s: CUDA device %d init failed, falling back to CPU", component_name, device_id); + return ggml_backend_cpu_init(); +#elif defined(SD_USE_VULKAN) + int count = ggml_backend_vk_get_device_count(); + if (count <= 0) { + LOG_WARN("%s: no Vulkan devices available, falling back to CPU", component_name); + return ggml_backend_cpu_init(); + } + if (device_id >= count) { + LOG_WARN("%s: Vulkan device %d requested but only %d available, falling back to device 0", + component_name, device_id, count); + device_id = 0; + } + auto b = ggml_backend_vk_init((size_t)device_id); + if (b != nullptr) { + LOG_INFO("%s: using Vulkan device %d", component_name, device_id); + return b; + } + LOG_WARN("%s: Vulkan device %d init failed, falling back to CPU", component_name, device_id); + return ggml_backend_cpu_init(); +#elif defined(SD_USE_SYCL) + auto b = ggml_backend_sycl_init(device_id); + if (b != nullptr) { + LOG_INFO("%s: using SYCL device %d", component_name, device_id); + return b; + } + LOG_WARN("%s: SYCL init failed, falling back to CPU", component_name); + return ggml_backend_cpu_init(); +#else + (void)device_id; + LOG_INFO("%s: using CPU backend", component_name); + return ggml_backend_cpu_init(); +#endif + } + + // Main 
backend init. Honours these env vars for per-component device placement + // (used by the init path below): + // SD_CUDA_DEVICE default CUDA device id (default 0) — also used for DiT + // SD_CUDA_DEVICE_CLIP text encoder / conditioner (falls back to SD_CUDA_DEVICE) + // SD_CUDA_DEVICE_VAE VAE (falls back to SD_CUDA_DEVICE) + // SD_CUDA_DEVICE_CONTROL ControlNet (falls back to SD_CUDA_DEVICE) + // SD_VK_DEVICE same pattern for the Vulkan build + // Setting any of these to -1 forces CPU for that component. + // + // `keep_clip_on_cpu` / `keep_vae_on_cpu` still take precedence and force CPU. + // For weights too big even for a dedicated device, use offload_params_to_cpu + // (keeps weights on CPU and streams per-step to GPU). void init_backend() { #ifdef SD_USE_CUDA - LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(0); + int main_dev = get_env_int("SD_CUDA_DEVICE", 0); + backend = init_device_backend(main_dev, "main"); #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); @@ -227,6 +312,36 @@ class StableDiffusionGGML { } } + // Resolve the backend for a sub-component by reading its env override (if set), + // otherwise reusing the main backend. Returns the main `backend` unchanged if + // the override matches the main device; otherwise creates a new backend (which + // the caller is responsible for freeing via the existing `!= backend` dtor check). + // `force_cpu` short-circuits to CPU regardless of the env var. 
+ ggml_backend_t resolve_component_backend(const char* env_name, + const char* component_name, + bool force_cpu) { + if (force_cpu) { + if (ggml_backend_is_cpu(backend)) { + return backend; + } + LOG_INFO("%s: forced CPU backend", component_name); + return ggml_backend_cpu_init(); + } +#if defined(SD_USE_CUDA) || defined(SD_USE_VULKAN) || defined(SD_USE_SYCL) + int main_dev = get_env_int("SD_CUDA_DEVICE", 0); + int override_dev = get_env_int(env_name, main_dev); + if (override_dev == main_dev && !ggml_backend_is_cpu(backend)) { + // Same device as main — reuse the main backend to save GPU memory/context. + return backend; + } + return init_device_backend(override_dev, component_name); +#else + (void)env_name; + (void)component_name; + return backend; +#endif + } + std::shared_ptr get_rng(rng_type_t rng_type) { if (rng_type == STD_DEFAULT_RNG) { return std::make_shared(); @@ -352,6 +467,115 @@ class StableDiffusionGGML { auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + // LTX-2 prefix + Gemma sandwich-norm fixup: the conditioner expects Gemma at + // `text_encoder.model.*`, but `--llm-path` prepends `text_encoders.llm.*` + // (convert_tensors_name then maps gguf llama names to HF names, yielding + // `text_encoders.llm.model.*`). + // + // Additionally, Gemma 3 has 4 layernorms per block (sandwich norms) that the + // shared llm_name_map only partly translates. The raw GGUF names blk.N.{attn_norm, + // post_attention_norm, ffn_norm, post_ffw_norm} end up as HF-style + // input_layernorm + post_attention_norm + post_attention_layernorm + post_ffw_norm + // after the generic map (where ffn_norm→post_attention_layernorm is Qwen-correct + // but wrong for Gemma). 
We rename here once version is LTX-2: + // post_attention_layernorm → pre_feedforward_layernorm (was actually ffn_norm) + // post_attention_norm → post_attention_layernorm (append _layernorm) + // post_ffw_norm → post_feedforward_layernorm + // Order matters: do the first rename first so the second can safely write to + // the now-vacated post_attention_layernorm slot. + if (sd_version_is_ltx2(version)) { + // Step 1: prefix rewrite text_encoders.llm. → text_encoder. + const std::string from = "text_encoders.llm."; + const std::string to = "text_encoder."; + { + String2TensorStorage renamed; + size_t renames = 0; + for (auto& kv : tensor_storage_map) { + const std::string& k = kv.first; + std::string new_k = k; + if (k.rfind(from, 0) == 0) { + new_k = to + k.substr(from.size()); + kv.second.name = new_k; + renames++; + } + renamed[new_k] = std::move(kv.second); + } + if (renames > 0) { + tensor_storage_map.swap(renamed); + LOG_INFO("LTX-2: renamed %zu '%s*' tensors → '%s*' (Gemma text encoder path)", + renames, from.c_str(), to.c_str()); + } + } + + // Step 2: Gemma 3 sandwich-norm renames, applied in the order documented + // above. Each pass rebuilds the storage map because std::map keys are const. + auto rename_suffix = [&](const std::string& old_suffix, const std::string& new_suffix) -> size_t { + String2TensorStorage renamed; + size_t renames = 0; + for (auto& kv : tensor_storage_map) { + const std::string& k = kv.first; + std::string new_k = k; + size_t p = k.rfind(old_suffix); + if (p != std::string::npos && p + old_suffix.size() == k.size()) { + // Only rename if prefix looks like a Gemma layer key. 
+ if (k.find("text_encoder.model.layers.") != std::string::npos) { + new_k = k.substr(0, p) + new_suffix; + kv.second.name = new_k; + renames++; + } + } + renamed[new_k] = std::move(kv.second); + } + tensor_storage_map.swap(renamed); + return renames; + }; + size_t r1 = rename_suffix(".post_attention_layernorm.weight", ".pre_feedforward_layernorm.weight"); + size_t r2 = rename_suffix(".post_attention_norm.weight", ".post_attention_layernorm.weight"); + size_t r3 = rename_suffix(".post_ffw_norm.weight", ".post_feedforward_layernorm.weight"); + if (r1 + r2 + r3 > 0) { + LOG_INFO("LTX-2: Gemma sandwich-norm rename: %zu pre_ff, %zu post_attn, %zu post_ff", + r1, r2, r3); + } + + // Step 3: Duplicate `first_stage_model.per_channel_statistics.*` into the + // `first_stage_model.encoder.per_channel_statistics.*` path expected by + // VideoEncoder's child block tree. VideoDecoder also expects these under + // its `decoder.per_channel_statistics` subprefix. Real LTX-2 checkpoints + // only ship the top-level buffer (mean-of-means, std-of-means). + { + const std::string top_pre = "first_stage_model.per_channel_statistics."; + size_t copied = 0; + // Snapshot keys with top_pre first (iteration + insertion is unsafe). + std::vector> to_copy; + for (auto& kv : tensor_storage_map) { + const std::string& k = kv.first; + if (k.rfind(top_pre, 0) == 0) { + std::string suffix = k.substr(top_pre.size()); + to_copy.push_back({k, suffix}); + } + } + for (auto& pair : to_copy) { + const std::string& src_key = pair.first; + const std::string& suffix = pair.second; + auto src_it = tensor_storage_map.find(src_key); + if (src_it == tensor_storage_map.end()) continue; + for (const char* sub : {"encoder", "decoder"}) { + std::string dst_key = "first_stage_model." + std::string(sub) + + ".per_channel_statistics." 
+ suffix; + if (tensor_storage_map.find(dst_key) != tensor_storage_map.end()) continue; + TensorStorage dup = src_it->second; + dup.name = dst_key; + tensor_storage_map[dst_key] = dup; + copied++; + } + } + if (copied > 0) { + LOG_INFO("LTX-2: duplicated %zu PerChannelStatistics entries to encoder/decoder subprefixes", + copied); + } + } + } + LOG_INFO("Version: %s ", model_version_to_str[version]); ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) ? (ggml_type)sd_ctx_params->wtype @@ -428,11 +652,9 @@ class StableDiffusionGGML { bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; { - clip_backend = backend; - if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { - LOG_INFO("CLIP: Using CPU backend"); - clip_backend = ggml_backend_cpu_init(); - } + // Pick a device for the text-encoder stack. SD_CUDA_DEVICE_CLIP overrides + // (set to -1 for CPU); `keep_clip_on_cpu` still forces CPU regardless. + clip_backend = resolve_component_backend("SD_CUDA_DEVICE_CLIP", "CLIP/TextEncoder", clip_on_cpu); if (sd_version_is_sd3(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -563,6 +785,23 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map, "model.diffusion_model"); + } else if (sd_version_is_ltx2(version)) { + // LTX-2: Gemma 3 text encoder (Phase 8), 1D embeddings connector + DiT + // caption_projection (Phase 9), and LTX-2 causal 3D VAE (Phase 11) are all + // landed. LTX2GemmaConditioner auto-detects connector presence from the + // tensor map; if absent it falls back to Gemma's last_hidden_state. + // The tokenizer.json path is required — prompts can't be encoded without + // it. Any HuggingFace-format `tokenizer.json` for Gemma 3 works. 
+ cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + "text_encoder", + SAFE_STR(sd_ctx_params->gemma_tokenizer_path)); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -607,12 +846,10 @@ class StableDiffusionGGML { high_noise_diffusion_model->get_param_tensors(tensors); } - if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { - LOG_INFO("VAE Autoencoder: Using CPU backend"); - vae_backend = ggml_backend_cpu_init(); - } else { - vae_backend = backend; - } + // Pick a device for the VAE. SD_CUDA_DEVICE_VAE overrides (set to -1 for CPU); + // `keep_vae_on_cpu` still forces CPU regardless. + vae_backend = resolve_component_backend("SD_CUDA_DEVICE_VAE", "VAE", + sd_ctx_params->keep_vae_on_cpu); auto create_tae = [&]() -> std::shared_ptr { if (sd_version_is_wan(version) || @@ -646,6 +883,30 @@ class StableDiffusionGGML { "first_stage_model", vae_decode_only, version); + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE: in the real checkpoint after convert_tensors_name, + // the `vae.` → `first_stage_model.` rename from name_conversion.cpp + // puts weights under the standard `first_stage_model.` prefix. The + // sd-vae-parity test uses a pre-named `vae.` state dict directly so + // it can run on the parity dumper's output without going through the + // conversion pass. + // + // The 22B checkpoint (see `ltx2_22b_{enc,dec}_specs`) has a 9-block + // encoder/decoder with mixed RES_X and COMPRESS_* blocks — much deeper + // than the 4-block tiny-test default. We hardcode the 22B spec here for + // the smoke test; a proper auto-detect from tensor shapes is a follow-up. 
+ return std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + version, + /*in_ch=*/3, + /*latent_ch=*/128, + /*patch=*/4, + /*decoder_base_ch=*/128, + /*timestep_cond=*/false, + LTXVAE::LTX2VAERunner::ltx2_22b_enc_specs(), + LTXVAE::LTX2VAERunner::ltx2_22b_dec_specs()); } else { auto model = std::make_shared(vae_backend, offload_params_to_cpu, @@ -960,6 +1221,8 @@ class StableDiffusionGGML { } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; + } else if (sd_version_is_ltx2(version)) { + pred_type = LTX2_FLOW_PRED; } else { pred_type = EPS_PRED; } @@ -992,6 +1255,11 @@ class StableDiffusionGGML { denoiser = std::make_shared(); break; } + case LTX2_FLOW_PRED: { + LOG_INFO("running in LTX-2 FLOW mode"); + denoiser = std::make_shared(); + break; + } default: { LOG_ERROR("Unknown predition type %i", pred_type); ggml_free(ctx); @@ -1865,6 +2133,9 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_uses_flux2_vae(version)) { latent_channel = 128; + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE latent dim (matches DiT patchify_proj in_channels). + latent_channel = 128; } else { latent_channel = 16; } @@ -1872,9 +2143,24 @@ class StableDiffusionGGML { return latent_channel; } - int get_image_seq_len(int h, int w) { + int get_image_seq_len(int h, int w, int frames = 1) { int vae_scale_factor = get_vae_scale_factor(); - return (h / vae_scale_factor) * (w / vae_scale_factor); + int spatial_tokens = (h / vae_scale_factor) * (w / vae_scale_factor); + // For video flow-match schedulers (LTX-2, Wan), `tokens` in the shift + // formula is math.prod(latent.shape[2:]) = T_latent * H_latent * W_latent. + // Earlier we only passed the spatial count (H*W), which under-shifted + // the LTX-2 schedule because the 22B run has 25-frame inputs → + // T_latent = 4, so the real token count is 4× the spatial count. 
+ // Python reference: ltx_core/components/schedulers.py::LTX2Scheduler.execute. + if (frames > 1 && sd_version_is_ltx2(version)) { + int T_latent = ((frames - 1) / 8) + 1; // LTX-2 VAE: 8× temporal compression. + return spatial_tokens * T_latent; + } + if (frames > 1 && sd_version_is_wan(version)) { + int T_latent = ((frames - 1) / 4) + 1; // Wan VAE: 4× temporal compression. + return spatial_tokens * T_latent; + } + return spatial_tokens; } sd::Tensor generate_init_latent(int width, @@ -1887,6 +2173,9 @@ class StableDiffusionGGML { int T = frames; if (sd_version_is_wan(version)) { T = ((T - 1) / 4) + 1; + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE: 8× temporal compression. + T = ((T - 1) / 8) + 1; } int C = get_latent_channel(); if (video) { @@ -2050,6 +2339,7 @@ const char* prediction_to_str[] = { "sd3_flow", "flux_flow", "flux2_flow", + "ltx2_flow", }; const char* sd_prediction_name(enum prediction_t prediction) { @@ -2178,6 +2468,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "t5xxl_path: %s\n" "llm_path: %s\n" "llm_vision_path: %s\n" + "gemma_tokenizer_path: %s\n" "diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n" "vae_path: %s\n" @@ -2210,6 +2501,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->t5xxl_path), SAFE_STR(sd_ctx_params->llm_path), SAFE_STR(sd_ctx_params->llm_vision_path), + SAFE_STR(sd_ctx_params->gemma_tokenizer_path), SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->vae_path), @@ -2382,6 +2674,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; + sd_vid_gen_params->fps = 24.f; sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_cache_params_init(&sd_vid_gen_params->cache); } @@ -2535,6 +2828,7 @@ struct 
GenerationRequest { sd_guidance_params_t high_noise_guidance = {}; sd_pm_params_t pm_params = {}; int frames = -1; + float fps = 0.f; // 0 = keep diffusion model's default float vace_strength = 1.f; GenerationRequest(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { @@ -2565,6 +2859,7 @@ struct GenerationRequest { width = sd_vid_gen_params->width; height = sd_vid_gen_params->height; frames = (sd_vid_gen_params->video_frames - 1) / 4 * 4 + 1; + fps = sd_vid_gen_params->fps; clip_skip = sd_vid_gen_params->clip_skip; vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); @@ -2716,7 +3011,7 @@ struct SamplePlan { sample_params->scheduler, sample_method); sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, - sd_ctx->sd->get_image_seq_len(request->height, request->width), + sd_ctx->sd->get_image_seq_len(request->height, request->width, request->frames), scheduler, sd_ctx->sd->version); } @@ -3527,6 +3822,15 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_ctx->sd->set_flow_shift(sd_vid_gen_params->sample_params.flow_shift); sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count); + // Propagate output fps to diffusion models that need it for temporal RoPE + // (LTX-2 divides time positions by fps; see LTXRope::gen_video_positions). 
+    if (request.fps > 0.f && sd_ctx->sd->diffusion_model) {
+        sd_ctx->sd->diffusion_model->set_fps(request.fps);
+    }
+    if (request.fps > 0.f && sd_ctx->sd->high_noise_diffusion_model) {
+        sd_ctx->sd->high_noise_diffusion_model->set_fps(request.fps);
+    }
+
     SamplePlan plan(sd_ctx, sd_vid_gen_params, request);
     auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request);
     if (!latent_inputs_opt.has_value()) {
diff --git a/src/tokenizers/gemma_tokenizer.cpp b/src/tokenizers/gemma_tokenizer.cpp
new file mode 100644
index 000000000..43c772abe
--- /dev/null
+++ b/src/tokenizers/gemma_tokenizer.cpp
@@ -0,0 +1,254 @@
+#include "gemma_tokenizer.h"
+
+#include <array>
+#include <climits>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "json.hpp"
+#include "util.h"
+
+namespace {
+
+// Parse "<0xAB>" -> byte value 0xAB. Returns -1 if not a byte token.
+int parse_byte_token(const std::string& piece) {
+    if (piece.size() != 6) return -1;
+    if (piece[0] != '<' || piece[1] != '0' || piece[2] != 'x' || piece[5] != '>') return -1;
+    auto hex = [](char c) -> int {
+        if (c >= '0' && c <= '9') return c - '0';
+        if (c >= 'A' && c <= 'F') return 10 + c - 'A';
+        if (c >= 'a' && c <= 'f') return 10 + c - 'a';
+        return -1;
+    };
+    int hi = hex(piece[3]), lo = hex(piece[4]);
+    if (hi < 0 || lo < 0) return -1;
+    return (hi << 4) | lo;
+}
+
+}  // namespace
+
+GemmaTokenizer::GemmaTokenizer() {
+    byte_fallback_ids_.fill(-1);
+    BOS_TOKEN = "<bos>";
+    EOS_TOKEN = "<eos>";
+    PAD_TOKEN = "<pad>";
+    UNK_TOKEN = "<unk>";
+    BOS_TOKEN_ID = 2;
+    EOS_TOKEN_ID = 1;
+    PAD_TOKEN_ID = 0;
+    UNK_TOKEN_ID = 3;
+    add_bos_token = true;  // Gemma post-processor prepends <bos>.
+    add_eos_token = false;
+    pad_left = true;  // Gemma uses left padding.
+}
+
+std::string GemmaTokenizer::decode_token(int token_id) const {
+    if (token_id >= 0 && token_id < (int)id_to_piece_.size()) {
+        return id_to_piece_[token_id];
+    }
+    return "";
+}
+
+// HF normalizer: Replace " " -> "▁" (U+2581). All other chars untouched.
+std::string GemmaTokenizer::normalize(const std::string& text) const {
+    static const std::string metaspace = "\xe2\x96\x81";  // UTF-8 for U+2581
+    std::string out;
+    out.reserve(text.size() + text.size() / 8);
+    for (char c : text) {
+        if (c == ' ') {
+            out.append(metaspace);
+        } else {
+            out.push_back(c);
+        }
+    }
+    return out;
+}
+
+std::vector<std::string> GemmaTokenizer::split_utf8_chars(const std::string& s) {
+    std::vector<std::string> out;
+    size_t i = 0;
+    while (i < s.size()) {
+        unsigned char b = (unsigned char)s[i];
+        size_t len;
+        if (b < 0x80) len = 1;
+        else if (b < 0xC0) len = 1;  // malformed continuation; treat as 1-byte
+        else if (b < 0xE0) len = 2;
+        else if (b < 0xF0) len = 3;
+        else len = 4;
+        if (i + len > s.size()) len = s.size() - i;
+        out.emplace_back(s.substr(i, len));
+        i += len;
+    }
+    return out;
+}
+
+void GemmaTokenizer::byte_fallback(const std::string& ch, std::vector<std::string>& out) const {
+    for (unsigned char b : ch) {
+        int id = byte_fallback_ids_[b];
+        if (id >= 0) {
+            out.push_back(id_to_piece_[id]);  // "<0xNN>"
+        } else {
+            out.push_back(UNK_TOKEN);
+        }
+    }
+}
+
+std::vector<std::string> GemmaTokenizer::bpe(std::vector<std::string> pieces) const {
+    // Greedy BPE: at each step find the adjacent pair with lowest merge rank and apply it.
+    // O(N^2 * merges_lookup) per encode. N here is chars in a single chunk — a few hundred
+    // at most for our use. Good enough.
+    while (pieces.size() > 1) {
+        int best_rank = INT_MAX;
+        int best_i = -1;
+        for (size_t i = 0; i + 1 < pieces.size(); ++i) {
+            std::string key = pieces[i];
+            key.push_back('\t');
+            key.append(pieces[i + 1]);
+            auto it = merge_ranks_.find(key);
+            if (it != merge_ranks_.end() && it->second < best_rank) {
+                best_rank = it->second;
+                best_i = (int)i;
+            }
+        }
+        if (best_i < 0) break;
+        pieces[best_i] = pieces[best_i] + pieces[best_i + 1];
+        pieces.erase(pieces.begin() + best_i + 1);
+    }
+    return pieces;
+}
+
+std::vector<int> GemmaTokenizer::encode(const std::string& text, on_new_token_cb_t /*cb*/) {
+    if (!loaded_) {
+        LOG_ERROR("GemmaTokenizer::encode called before load_from_file()");
+        return {};
+    }
+
+    std::string normalized = normalize(text);
+
+    // ignore_merges=true: if the entire (post-normalization) chunk is directly in vocab,
+    // emit it as a single token without running BPE.
+    if (ignore_merges_) {
+        auto it = vocab_.find(normalized);
+        if (it != vocab_.end()) {
+            return {it->second};
+        }
+    }
+
+    std::vector<std::string> pieces;
+    pieces.reserve(normalized.size());
+    for (const auto& ch : split_utf8_chars(normalized)) {
+        auto it = vocab_.find(ch);
+        if (it != vocab_.end()) {
+            pieces.push_back(ch);
+        } else {
+            byte_fallback(ch, pieces);
+        }
+    }
+
+    pieces = bpe(std::move(pieces));
+
+    std::vector<int> ids;
+    ids.reserve(pieces.size());
+    for (const auto& p : pieces) {
+        auto it = vocab_.find(p);
+        if (it != vocab_.end()) {
+            ids.push_back(it->second);
+        } else {
+            ids.push_back(UNK_TOKEN_ID);
+        }
+    }
+    return ids;
+}
+
+bool GemmaTokenizer::load_from_file(const std::string& path) {
+    std::ifstream f(path);
+    if (!f.is_open()) {
+        LOG_ERROR("GemmaTokenizer: cannot open %s", path.c_str());
+        return false;
+    }
+    nlohmann::json j;
+    try {
+        f >> j;
+    } catch (const nlohmann::json::parse_error& e) {
+        LOG_ERROR("GemmaTokenizer: JSON parse error in %s: %s", path.c_str(), e.what());
+        return false;
+    }
+
+    if (!j.contains("model") || !j["model"].contains("vocab") || !j["model"].contains("merges")) {
+        LOG_ERROR("GemmaTokenizer: JSON missing model.vocab or model.merges");
+        return false;
+    }
+    const auto& model = j["model"];
+
+    // Vocab: HF tokenizer.json for BPE stores vocab as an object {piece: id}.
+    const auto& vocab = model["vocab"];
+    id_to_piece_.clear();
+    id_to_piece_.resize(vocab.size());
+    vocab_.reserve(vocab.size() * 2);
+    for (auto it = vocab.begin(); it != vocab.end(); ++it) {
+        const std::string piece = it.key();
+        int id = it.value().get<int>();
+        vocab_.emplace(piece, id);
+        if (id >= 0 && id < (int)id_to_piece_.size()) {
+            id_to_piece_[id] = piece;
+        }
+    }
+
+    // Merges: ordered list; earlier entries have higher priority (lower rank).
+    const auto& merges = model["merges"];
+    merge_ranks_.reserve(merges.size() * 2);
+    int rank = 0;
+    for (const auto& m : merges) {
+        // tokenizers >=0.20 stores each merge as a [left, right] array; older versions used a
+        // single space-separated string. Accept both for robustness.
+        std::string left, right;
+        if (m.is_array() && m.size() == 2) {
+            left = m[0].get<std::string>();
+            right = m[1].get<std::string>();
+        } else if (m.is_string()) {
+            const std::string s = m.get<std::string>();
+            auto pos = s.find(' ');
+            if (pos == std::string::npos) continue;
+            left = s.substr(0, pos);
+            right = s.substr(pos + 1);
+        } else {
+            continue;
+        }
+        std::string key = left;
+        key.push_back('\t');
+        key.append(right);
+        merge_ranks_.emplace(std::move(key), rank++);
+    }
+
+    // Locate byte-fallback IDs. Every byte value 0..255 should have a "<0xNN>" entry.
+    for (int id = 0; id < (int)id_to_piece_.size(); ++id) {
+        int b = parse_byte_token(id_to_piece_[id]);
+        if (b >= 0) {
+            byte_fallback_ids_[b] = id;
+        }
+    }
+
+    // Special token IDs: honor what's actually in the JSON if model is unusual; otherwise
+    // keep the Gemma-3 defaults from the ctor.
+    if (model.contains("unk_token") && model["unk_token"].is_string()) {
+        auto it = vocab_.find(model["unk_token"].get<std::string>());
+        if (it != vocab_.end()) UNK_TOKEN_ID = it->second;
+    }
+    if (j.contains("added_tokens")) {
+        for (const auto& at : j["added_tokens"]) {
+            if (!at.contains("content") || !at.contains("id")) continue;
+            const std::string c = at["content"].get<std::string>();
+            int id = at["id"].get<int>();
+            if (c == "<bos>") BOS_TOKEN_ID = id;
+            else if (c == "<eos>") EOS_TOKEN_ID = id;
+            else if (c == "<pad>") PAD_TOKEN_ID = id;
+            else if (c == "<unk>") UNK_TOKEN_ID = id;
+        }
+    }
+
+    ignore_merges_ = model.value("ignore_merges", true);
+    loaded_ = true;
+    LOG_DEBUG("GemmaTokenizer loaded: vocab=%zu merges=%zu", vocab_.size(), merge_ranks_.size());
+    return true;
+}
diff --git a/src/tokenizers/gemma_tokenizer.h b/src/tokenizers/gemma_tokenizer.h
new file mode 100644
index 000000000..8753cbd9f
--- /dev/null
+++ b/src/tokenizers/gemma_tokenizer.h
@@ -0,0 +1,50 @@
+#ifndef __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
+#define __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
+
+#include <array>
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tokenizer.h"
+
+// Gemma 3 tokenizer. BPE with byte-fallback + Metaspace-style normalization
+// (space → U+2581 "▁"). Loads a HuggingFace tokenizer.json produced by
+// `AutoTokenizer.from_pretrained("google/gemma-3-12b-it").backend_tokenizer.save()`.
+//
+// Not embeddable as a header like the other tokenizers — the raw JSON is ~33 MB
+// and the vocab alone is 262144 pieces plus 514906 merges. Expected workflow:
+// ship the tokenizer.json file alongside the weights, pass its path at runtime.
+class GemmaTokenizer : public Tokenizer {
+protected:
+    std::unordered_map<std::string, int> vocab_;        // piece -> id
+    std::vector<std::string> id_to_piece_;              // id -> piece
+    std::unordered_map<std::string, int> merge_ranks_;  // "left\tright" -> rank (lower = higher priority)
+    std::array<int, 256> byte_fallback_ids_{};          // byte value -> piece id for <0xXX>
+    bool loaded_ = false;
+    bool ignore_merges_ = true;
+
+    std::string decode_token(int token_id) const override;
+    std::string normalize(const std::string& text) const override;
+
+    // Split a UTF-8 string into its individual code-point-sized chunks.
+    static std::vector<std::string> split_utf8_chars(const std::string& s);
+
+    // Byte-fallback a character that isn't in vocab: produce UTF-8 byte tokens.
+    void byte_fallback(const std::string& ch, std::vector<std::string>& out) const;
+
+    // Run BPE merging until no more merges apply.
+    std::vector<std::string> bpe(std::vector<std::string> pieces) const;
+
+public:
+    GemmaTokenizer();
+
+    bool load_from_file(const std::string& path);
+    bool is_loaded() const { return loaded_; }
+    int vocab_size() const { return (int)id_to_piece_.size(); }
+
+    std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override;
+};
+
+#endif  // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
diff --git a/src/vae.hpp b/src/vae.hpp
index dc69535e8..416bc9b81 100644
--- a/src/vae.hpp
+++ b/src/vae.hpp
@@ -73,6 +73,9 @@ struct VAE : public GGMLRunner {
         scale_factor = 16;
     } else if (version == VERSION_CHROMA_RADIANCE) {
         scale_factor = 1;
+    } else if (sd_version_is_ltx2(version)) {
+        // LTX-2 VAE: 32× spatial compression (256×256 → 8×8 latent).
+ scale_factor = 32; } return scale_factor; } diff --git a/tests/ltx_parity/CMakeLists.txt b/tests/ltx_parity/CMakeLists.txt new file mode 100644 index 000000000..edd7771c4 --- /dev/null +++ b/tests/ltx_parity/CMakeLists.txt @@ -0,0 +1,62 @@ +set(TARGET sd-ltx-parity) + +add_executable(${TARGET} + test_ltx_parity.cpp +) + +target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17) + +set(GEMMA_TARGET sd-gemma-parity) + +add_executable(${GEMMA_TARGET} + test_gemma_parity.cpp +) + +target_link_libraries(${GEMMA_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${GEMMA_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(GEMMA_TOK_TARGET sd-gemma-tokenizer-test) + +add_executable(${GEMMA_TOK_TARGET} + test_gemma_tokenizer.cpp +) + +target_link_libraries(${GEMMA_TOK_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${GEMMA_TOK_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(S2D_TARGET sd-s2d-primitives-test) + +add_executable(${S2D_TARGET} + test_s2d_primitives.cpp +) + +target_link_libraries(${S2D_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${S2D_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(VAE_TARGET sd-vae-parity) + +add_executable(${VAE_TARGET} + test_vae_parity.cpp +) + +target_link_libraries(${VAE_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${VAE_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(CONN_TARGET sd-connector-parity) + +add_executable(${CONN_TARGET} + test_connector_parity.cpp +) + +target_link_libraries(${CONN_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${CONN_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(VAE_RT_TARGET sd-ltx2-vae-roundtrip) + +add_executable(${VAE_RT_TARGET} + test_ltx2_vae_roundtrip.cpp +) + +target_link_libraries(${VAE_RT_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) 
+target_compile_features(${VAE_RT_TARGET} PUBLIC c_std_11 cxx_std_17) diff --git a/tests/ltx_parity/README.md b/tests/ltx_parity/README.md new file mode 100644 index 000000000..8f5f090fa --- /dev/null +++ b/tests/ltx_parity/README.md @@ -0,0 +1,36 @@ +# LTX-2 parity tests + +Block-by-block numerical parity between the C++/GGML LTX-2 port and the reference +PyTorch implementation in `/devel/tools/diffusion/LTX-2/packages/ltx-core/`. + +## How it works + +`dump_reference.py` instantiates a **tiny, deterministic** LTX-2 transformer with fixed +random weights (seed=0) and runs a forward pass on a fixed input. It writes: + +- `/tmp/ltx_ref/manifest.json` — catalogue of every dumped tensor (name, shape, dtype, offset) +- `/tmp/ltx_ref/state_dict.safetensors` — all model weights in a standard format +- `/tmp/ltx_ref/tensors/*.bin` — each intermediate tensor as raw float32 bytes + +The "tiny" model is small enough (2 layers, inner_dim=128) to run in milliseconds on CPU +and make it easy to dump every intermediate without filling the disk. That scope is +deliberate: parity at tiny dims transfers to full-size models because every block is +tested exhaustively. + +A matching C++ test (to be written) loads `state_dict.safetensors`, replays the same +input, and diffs every intermediate tensor against the reference. Tolerances: +- F32: 1e-5 absolute, 1e-4 relative +- BF16/FP16 C++ path: 1e-2 absolute, 5e-3 relative + +## Run + +```bash +/home/ilintar/venv/bin/python dump_reference.py +``` + +## What's NOT covered (yet) + +- **Gemma 3 text encoder** — needs a Gemma 3 checkpoint. Deferred; we dump a synthetic + `context` tensor (random but fixed) as a placeholder. +- **VAE** — separate dumper planned once the C++ VAE is building. +- **Sampler loop** — a separate script, not this one. This one tests a single forward call. 
diff --git a/tests/ltx_parity/dump_connector.py b/tests/ltx_parity/dump_connector.py new file mode 100644 index 000000000..76aa2f695 --- /dev/null +++ b/tests/ltx_parity/dump_connector.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +"""Dump tiny LTX-2 Connector V1 reference tensors for C++/GGML parity testing. + +Covers: +- FeatureExtractorV1 (masked norm + aggregate_embed Linear) +- Embeddings1DConnector (2× BasicTransformerBlock1D + final rms_norm, with + num_learnable_registers weights present but unused on all-ones mask path) +- PixArtAlphaTextProjection (caption_projection inside the DiT) + +Usage: + /home/ilintar/venv/bin/python dump_connector.py +""" + +from __future__ import annotations + +import json +import math +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from safetensors.torch import save_file + +from ltx_core.text_encoders.gemma.embeddings_connector import Embeddings1DConnector +from ltx_core.text_encoders.gemma.embeddings_processor import ( + EmbeddingsProcessor, + convert_to_additive_mask, +) +from ltx_core.text_encoders.gemma.feature_extractor import FeatureExtractorV1 +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.text_projection import PixArtAlphaTextProjection + + +SEED = 0 + +# Two variants to exercise different connector paths: +# - "nopad" (default): SEQ_LEN=8, NUM_REGISTERS=4, mask all-ones. Register +# replacement is a no-op (reals fill everything) — +# covers the "skip concat" branch in C++. +# - "padded" (env CONNECTOR_VARIANT=padded): SEQ_LEN=8, NUM_REGISTERS=8, +# T_REAL=3 with left-padded mask [0,0,0,0,0,1,1,1]. +# Register replacement moves reals to the front and +# fills positions [T_REAL, num_reg) with the trailing +# slice of learnable_registers — this is the path the +# production conditioner/LTX2ConnectorRunner now takes +# when T_real < num_registers. 
+import os +VARIANT = os.environ.get("CONNECTOR_VARIANT", "nopad") +assert VARIANT in ("nopad", "padded") + +OUT_DIR = pathlib.Path("/tmp/connector_ref" if VARIANT == "nopad" else "/tmp/connector_ref_padded") +TENSOR_DIR = OUT_DIR / "tensors" + +# Tiny config (mirrors real LTX-2 head_dim=128 for fp16-stable attention; +# 2 heads keeps inner_dim small enough for fast parity). +NUM_HEADS = 2 +HEAD_DIM = 32 +INNER_DIM = NUM_HEADS * HEAD_DIM # 64 +NUM_LAYERS = 2 +ROPE_THETA = 10_000.0 +ROPE_MAX_POS = [1] + +FEAT_NUM_LAYERS = 5 # fake "embed + 4 transformer layers" +FLAT_DIM = INNER_DIM * FEAT_NUM_LAYERS # 80 + +CAPTION_CHANNELS = INNER_DIM # 64 +CAPTION_HIDDEN = 128 # DiT inner dim (larger than connector) +CAPTION_OUT = CAPTION_HIDDEN # default: = hidden_size + +BATCH = 1 + +if VARIANT == "nopad": + NUM_REGISTERS = 4 + SEQ_LEN = 8 # > num_reg so register replacement is a no-op + T_REAL = 8 # entire SEQ_LEN is real tokens +else: # padded + NUM_REGISTERS = 8 + SEQ_LEN = 8 # == num_reg (Python requires SEQ_LEN % NUM_REGISTERS == 0) + T_REAL = 3 # left-padded: only last 3 positions are real + +assert SEQ_LEN % NUM_REGISTERS == 0 + + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + + def add(self, name: str, t: torch.Tensor): + self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"}) + + def dump(self, path: pathlib.Path): + path.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def save_tensor(t: torch.Tensor, name: str, manifest: Manifest): + safe_name = name.replace("/", "__") + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + arr.tofile(TENSOR_DIR / f"{safe_name}.bin") + manifest.add(name, t) + + +def tame_(model: torch.nn.Module): + g = torch.Generator().manual_seed(SEED) + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.dim() == 1: + # RMSNorm weights (standard, not Gemma's (1+w)) etc.: keep at 1 so + # effective scale is identity at init. 
For plain biases zero is + # also fine; we just keep the default shape. + if "norm" in name.lower() or "weight" == name.split(".")[-1] and p.shape[0] == INNER_DIM: + p.fill_(1.0) + else: + p.zero_() + elif p.dim() == 2: + fan_in = p.shape[1] + std = 1.0 / math.sqrt(fan_in) + p.normal_(mean=0.0, std=std, generator=g) + else: + p.normal_(mean=0.0, std=0.02, generator=g) + + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + torch.manual_seed(SEED) + + # --- Build modules (tiny). --- + aggregate_embed = torch.nn.Linear(FLAT_DIM, INNER_DIM, bias=False) + feature_extractor = FeatureExtractorV1(aggregate_embed=aggregate_embed, is_av=False) + + connector = Embeddings1DConnector( + attention_head_dim=HEAD_DIM, + num_attention_heads=NUM_HEADS, + num_layers=NUM_LAYERS, + positional_embedding_theta=ROPE_THETA, + positional_embedding_max_pos=ROPE_MAX_POS, + causal_temporal_positioning=False, + num_learnable_registers=NUM_REGISTERS, + rope_type=LTXRopeType.INTERLEAVED, + # True = numpy fp64 linspace + pow cast to fp32 at the end. Matches our + # C++ fp64 path byte-exactly. With False, torch's fp32 pow drifts 1 ULP + # at the tail of the grid, causing ~5e-2 cos/sin diffs we can't reproduce. + double_precision_rope=True, + apply_gated_attention=False, + ) + + caption_projection = PixArtAlphaTextProjection( + in_features=CAPTION_CHANNELS, + hidden_size=CAPTION_HIDDEN, + out_features=CAPTION_OUT, + act_fn="gelu_tanh", + ) + + # Tame weights deterministically. + tame_(feature_extractor) + tame_(connector) + tame_(caption_projection) + + # Cast to float32 (the tame() doesn't touch registers which default to bfloat16). 
+ with torch.no_grad(): + if hasattr(connector, "learnable_registers"): + g = torch.Generator().manual_seed(SEED + 1) + connector.learnable_registers.data = ( + torch.rand(NUM_REGISTERS, INNER_DIM, generator=g) * 2.0 - 1.0 + ).to(torch.float32) + + feature_extractor.eval() + connector.eval() + caption_projection.eval() + + # --- Build inputs. --- + rng = np.random.default_rng(SEED + 2) + # Pretend 49-layer stack (tiny): [B, T, D=INNER_DIM, L=FEAT_NUM_LAYERS] + stacked = torch.tensor( + rng.normal(loc=0.0, scale=1.0, size=(BATCH, SEQ_LEN, INNER_DIM, FEAT_NUM_LAYERS)), + dtype=torch.float32, + ) + # Binary attention mask. Left-padded when VARIANT="padded": first + # (SEQ_LEN - T_REAL) positions are pad (0), last T_REAL are real (1). + attention_mask = torch.ones((BATCH, SEQ_LEN), dtype=torch.int64) + if VARIANT == "padded": + attention_mask[:, : SEQ_LEN - T_REAL] = 0 + # Zero-out the padded positions in the stacked input too, matching what + # the real HF pipeline feeds (padded tokens have zero embeddings after + # feature extraction since Gemma's pad_token embedding is unused in the + # text-to-video pipeline — FeatureExtractor masks them out anyway). + stacked[:, : SEQ_LEN - T_REAL, :, :] = 0 + + manifest = Manifest() + save_tensor(stacked, "stacked_in", manifest) + save_tensor(attention_mask.to(torch.float32), "attention_mask", manifest) + + # --- 1. Feature extractor. --- + with torch.no_grad(): + feat_out, _ = feature_extractor(stacked, attention_mask, padding_side="left") + save_tensor(feat_out, "feat_ext_out", manifest) + print(f" feat_ext_out shape={tuple(feat_out.shape)} " + f"mean={feat_out.mean().item():.4f} std={feat_out.std().item():.4f}") + + # --- 2. Connector. --- + additive_mask = convert_to_additive_mask(attention_mask, feat_out.dtype) + # Run connector piece-by-piece to capture intermediates. + with torch.no_grad(): + hs = feat_out + am = additive_mask + # Register replacement (no-op for all-ones mask, but exercises the path). 
+ if connector.num_learnable_registers: + hs, am = connector._replace_padded_with_learnable_registers(hs, am) + save_tensor(hs, "after_registers", manifest) + + indices_grid = torch.arange(hs.shape[1], dtype=torch.float32) + indices_grid = indices_grid[None, None, :] + from ltx_core.model.transformer.rope import ( + generate_freq_grid_np, + generate_freq_grid_pytorch, + precompute_freqs_cis, + ) + freq_gen = generate_freq_grid_np if connector.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=connector.inner_dim, + out_dtype=hs.dtype, + theta=connector.positional_embedding_theta, + max_pos=connector.positional_embedding_max_pos, + num_attention_heads=connector.num_attention_heads, + rope_type=connector.rope_type, + freq_grid_generator=freq_gen, + ) + cos_f, sin_f = freqs_cis + save_tensor(cos_f, "rope_cos", manifest) + save_tensor(sin_f, "rope_sin", manifest) + + for i, block in enumerate(connector.transformer_1d_blocks): + hs = block(hs, attention_mask=am, pe=freqs_cis) + save_tensor(hs, f"conn_block_{i}_out", manifest) + print(f" conn_block_{i}_out shape={tuple(hs.shape)} " + f"mean={hs.mean().item():.4f} std={hs.std().item():.4f}") + + from ltx_core.utils import rms_norm + hs = rms_norm(hs) + save_tensor(hs, "conn_final_out", manifest) + print(f" conn_final_out shape={tuple(hs.shape)} " + f"mean={hs.mean().item():.4f} std={hs.std().item():.4f}") + + # --- 3. Caption projection. --- + with torch.no_grad(): + caption_out = caption_projection(hs) + save_tensor(caption_out, "caption_proj_out", manifest) + print(f" caption_proj_out shape={tuple(caption_out.shape)} " + f"mean={caption_out.mean().item():.4f} std={caption_out.std().item():.4f}") + + # --- Save state dict under C++-friendly keys. 
--- + state: Dict[str, torch.Tensor] = {} + # Feature extractor + state["feature_extractor.aggregate_embed.weight"] = ( + feature_extractor.aggregate_embed.weight.detach().to(torch.float32).contiguous() + ) + # Connector parameters + for key, value in connector.state_dict().items(): + state[f"connector.{key}"] = value.detach().to(torch.float32).contiguous() + # Caption projection + for key, value in caption_projection.state_dict().items(): + state[f"caption_projection.{key}"] = value.detach().to(torch.float32).contiguous() + + save_file(state, str(OUT_DIR / "state_dict.safetensors")) + (OUT_DIR / "tensor_names.txt").write_text("\n".join(sorted(state.keys())) + "\n") + manifest.dump(OUT_DIR / "manifest.json") + + (OUT_DIR / "config.json").write_text(json.dumps({ + "num_heads": NUM_HEADS, + "head_dim": HEAD_DIM, + "inner_dim": INNER_DIM, + "num_layers": NUM_LAYERS, + "num_registers": NUM_REGISTERS, + "rope_theta": ROPE_THETA, + "rope_max_pos": ROPE_MAX_POS, + "feat_num_layers": FEAT_NUM_LAYERS, + "flat_dim": FLAT_DIM, + "caption_channels": CAPTION_CHANNELS, + "caption_hidden": CAPTION_HIDDEN, + "caption_out": CAPTION_OUT, + "seq_len": SEQ_LEN, + "batch": BATCH, + }, indent=2)) + + print(f"\nDone. {len(manifest.entries)} tensors → {OUT_DIR}") + print(f"State dict: {len(state)} keys → {OUT_DIR}/state_dict.safetensors") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_gemma.py b/tests/ltx_parity/dump_gemma.py new file mode 100644 index 000000000..b295f9d04 --- /dev/null +++ b/tests/ltx_parity/dump_gemma.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +"""Dump tiny Gemma 3 reference tensors for C++/GGML parity testing. 
+ +Strategy mirrors dump_reference.py: instantiate a tiny Gemma3TextModel (6 layers so +sliding-window pattern triggers at layer 5, small dims) with deterministic tamed +weights, run one forward pass on fixed input_ids, and write every intermediate tensor +(embedding-post-scale, per-layer output, all-layer stack, final norm) to +/tmp/gemma_ref/tensors/ as raw fp32 bytes. Also write the state_dict as safetensors +so the C++ side can load identical weights. + +Usage: + /home/ilintar/venv/bin/python dump_gemma.py +""" + +from __future__ import annotations + +import json +import math +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from safetensors.torch import save_file +from transformers import Gemma3TextConfig, Gemma3TextModel + +# -------- Config -------- + +SEED = 0 +OUT_DIR = pathlib.Path("/tmp/gemma_ref") +TENSOR_DIR = OUT_DIR / "tensors" + +# Select config via GEMMA_PARITY_VARIANT env var: "tiny" (default) or "deep". +# The tiny variant is fast but only exercises 6 stacked layers with tamed weights. +# The deep variant scales to 24 layers × 512 hidden to stress-test accumulated drift +# and the full sliding/global interleave pattern (same sliding_window_pattern=6 as +# the real Gemma 3 12B). Shared code path in both — differences are pure scaling. +import os +VARIANT = os.environ.get("GEMMA_PARITY_VARIANT", "tiny") + +if VARIANT == "deep": + NUM_LAYERS = 24 + HIDDEN_SIZE = 512 + NUM_HEADS = 8 + NUM_KV_HEADS = 4 + HEAD_DIM = 64 + INTERMEDIATE_SIZE = 1024 + VOCAB_SIZE = 1024 + SLIDING_WINDOW = 16 + SLIDING_WINDOW_PATTERN = 6 + SEQ_LEN = 32 # > sliding_window so the sliding mask actually bites + # Each "deep" run reuses /tmp/gemma_ref but under a distinct tensor prefix so + # test_gemma_parity.cpp can load both files without key collision. + TENSOR_PREFIX_MODEL = "text_encoder_deep.model" + TENSOR_TAG_PREFIX = "deep_" # applied to tensor output filenames +else: + # Tiny Gemma 3 config. 
6 layers so layer index 5 is the first (and only) global + # layer under the (i+1)%6 rule — exercises both sliding and full paths. + NUM_LAYERS = 6 + HIDDEN_SIZE = 128 + NUM_HEADS = 4 + NUM_KV_HEADS = 2 + HEAD_DIM = 32 # NOTE: != HIDDEN_SIZE / NUM_HEADS. Matches Gemma's non-standard head_dim. + INTERMEDIATE_SIZE = 256 + VOCAB_SIZE = 512 + SLIDING_WINDOW = 4 + SLIDING_WINDOW_PATTERN = 6 + SEQ_LEN = 8 + TENSOR_PREFIX_MODEL = "text_encoder.model" + TENSOR_TAG_PREFIX = "" + +RMS_EPS = 1e-6 +ROPE_THETA = 1_000_000.0 +ROPE_LOCAL_THETA = 10_000.0 +BATCH = 1 + +TENSOR_PREFIX = TENSOR_PREFIX_MODEL # Our LLM wrapper stores TextModel under .model, + # so the full key is prefix.model.. + + +# -------- Utility -------- + + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + + def add(self, name: str, t: torch.Tensor): + self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"}) + + def dump(self, path: pathlib.Path): + path.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def save_tensor(t: torch.Tensor, name: str, manifest: Manifest): + safe_name = name.replace("/", "__") + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + arr.tofile(TENSOR_DIR / f"{safe_name}.bin") + manifest.add(name, t) + + +def tame_(model: torch.nn.Module): + """Apply deterministic, finite weights. Mirrors dump_reference.py's approach: + RMSNorm weights = 1, linears ~= Kaiming with a smaller gain. + """ + g = torch.Generator().manual_seed(SEED) + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.dim() == 1: + # All 1D params: RMS weights, LN weights, biases. Gemma uses RMS with + # `weight = zeros` + `(1 + weight)` pattern; see below — keep at 0 so + # effective scale is 1.0 at init. 
+ p.zero_() + elif p.dim() == 2: + fan_in = p.shape[1] + std = 1.0 / math.sqrt(fan_in) + p.normal_(mean=0.0, std=std, generator=g) + else: + p.normal_(mean=0.0, std=0.02, generator=g) + + +# -------- Main -------- + + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + torch.manual_seed(SEED) + + # Real Gemma 3 12B config has rope_scaling={"rope_type": "linear", "factor": 8.0} + # applied to full_attention layers only (HuggingFace gemma3 config.json). Mirror that + # here in the deep variant so C++ parity actually exercises the scaling path. The + # tiny variant keeps scaling disabled (factor=1) for faster iteration / backward compat. + rope_scaling = {"rope_type": "linear", "factor": 8.0} if VARIANT == "deep" else None + config = Gemma3TextConfig( + vocab_size=VOCAB_SIZE, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + rms_norm_eps=RMS_EPS, + rope_theta=ROPE_THETA, + rope_local_base_freq=ROPE_LOCAL_THETA, + rope_scaling=rope_scaling, + sliding_window=SLIDING_WINDOW, + sliding_window_pattern=SLIDING_WINDOW_PATTERN, + max_position_embeddings=1024, + attention_bias=False, + attn_logit_softcapping=None, + final_logit_softcapping=None, + query_pre_attn_scalar=HEAD_DIM, # 1/sqrt(head_dim) scaling + hidden_activation="gelu_pytorch_tanh", + ) + + print("Config summary:") + print(f" layer_types: {config.layer_types}") + print(f" hidden_size: {config.hidden_size}") + print(f" head_dim: {config.head_dim}") + print(f" sliding_window: {config.sliding_window}") + + model = Gemma3TextModel(config) + model.eval() + tame_(model) + + # Fixed input ids. 
+ rng = np.random.default_rng(SEED) + input_ids = torch.tensor(rng.integers(low=0, high=VOCAB_SIZE, size=(BATCH, SEQ_LEN)), dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + + manifest = Manifest() + save_tensor(input_ids.to(torch.float32), f"{TENSOR_TAG_PREFIX}input_ids", manifest) # store as f32 for simplicity + + # Forward with output_hidden_states=True to capture every layer. + with torch.no_grad(): + out = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # out.hidden_states is a tuple: (embedding_out, layer_0_out, layer_1_out, ..., layer_N_out) + # So len == num_layers + 1. First element is post-embed-scale. + hidden_states = out.hidden_states + assert len(hidden_states) == NUM_LAYERS + 1, f"got {len(hidden_states)} hidden states" + + for i, h in enumerate(hidden_states): + tag = f"{TENSOR_TAG_PREFIX}hs_{i:02d}" if i > 0 else f"{TENSOR_TAG_PREFIX}hs_embed" + save_tensor(h, tag, manifest) + if VARIANT == "tiny" or i % 4 == 0 or i == len(hidden_states) - 1: + # Keep logs short for the deep variant. + print(f" {tag}: shape={tuple(h.shape)} mean={h.mean().item():.4f} std={h.std().item():.4f}") + + # Final norm'd output (post model.norm, which is Gemma's output rms). + save_tensor(out.last_hidden_state, f"{TENSOR_TAG_PREFIX}last_hidden_state", manifest) + + # Stacked all-(N+1)-layer tensor as LTX-2 consumes it: + # torch.stack(hidden_states, dim=-1) -> [B, T, H, N+1] + stacked = torch.stack(hidden_states, dim=-1) + save_tensor(stacked, f"{TENSOR_TAG_PREFIX}all_layers_stacked", manifest) + print(f" all_layers_stacked: shape={tuple(stacked.shape)}") + + # Write state dict with our C++-side prefix convention. + # If a prior run (e.g. "tiny" → then "deep") already wrote a state_dict with a + # different prefix, merge instead of overwriting so both variants live in one + # safetensors file and the C++ test can load either config path on demand. 
+ state_dict = model.state_dict() + prefixed = {f"{TENSOR_PREFIX}.{k}": v.to(torch.float32).contiguous() for k, v in state_dict.items()} + sd_path = OUT_DIR / "state_dict.safetensors" + if sd_path.exists(): + try: + from safetensors.torch import load_file + existing = load_file(str(sd_path)) + for k, v in existing.items(): + if k.startswith(f"{TENSOR_PREFIX}."): + continue # replace our own prefix on re-run + prefixed[k] = v + print(f" merged {len(existing)} existing tensors into new state_dict") + except Exception as e: + print(f" warning: could not merge existing state_dict ({e}); overwriting") + save_file(prefixed, str(sd_path)) + (OUT_DIR / "tensor_names.txt").write_text("\n".join(sorted(prefixed.keys())) + "\n") + manifest.dump(OUT_DIR / f"manifest{'_deep' if VARIANT == 'deep' else ''}.json") + + # Also dump config JSON so C++ side can cross-check shapes if needed. + (OUT_DIR / "config.json").write_text(json.dumps({ + "num_layers": NUM_LAYERS, + "hidden_size": HIDDEN_SIZE, + "num_heads": NUM_HEADS, + "num_kv_heads": NUM_KV_HEADS, + "head_dim": HEAD_DIM, + "intermediate_size": INTERMEDIATE_SIZE, + "vocab_size": VOCAB_SIZE, + "rms_norm_eps": RMS_EPS, + "sliding_window": SLIDING_WINDOW, + "sliding_window_pattern": SLIDING_WINDOW_PATTERN, + "rope_theta_global": ROPE_THETA, + "rope_theta_local": ROPE_LOCAL_THETA, + "seq_len": SEQ_LEN, + "batch": BATCH, + "embed_scale": math.sqrt(HIDDEN_SIZE), + "layer_types": config.layer_types, + "tensor_prefix": TENSOR_PREFIX, + }, indent=2)) + + print(f"\nDone. 
Wrote {len(manifest.entries)} tensors under {OUT_DIR}.") + print(f"State dict: {len(prefixed)} keys → {OUT_DIR}/state_dict.safetensors") + print(f"Manifest: {OUT_DIR}/manifest.json") + print(f"Name list: {OUT_DIR}/tensor_names.txt") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_reference.py b/tests/ltx_parity/dump_reference.py new file mode 100644 index 000000000..97f8476b9 --- /dev/null +++ b/tests/ltx_parity/dump_reference.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python3 +"""Dump LTX-2 reference tensors for C++/GGML parity testing. + +Strategy: instantiate a TINY LTX-2 model (2 layers, small dims) with deterministic +random weights, run a single forward pass on fixed inputs, and write every intermediate +tensor (post-each-block, post-AdaLN, post-patchify, final output) to +/tmp/ltx_ref/tensors/ as raw fp32 bytes. Also dump the state_dict as safetensors so the +C++ side can load the exact same weights. + +Usage: + /home/ilintar/venv/bin/python dump_reference.py + +Outputs: + /tmp/ltx_ref/manifest.json -- catalogue of every dumped tensor + /tmp/ltx_ref/state_dict.safetensors -- model weights + /tmp/ltx_ref/tensors/*.bin -- raw fp32 bytes, one file per tensor + /tmp/ltx_ref/tensor_names.txt -- state_dict.keys() for name-mapping verification +""" + +from __future__ import annotations + +import json +import os +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from safetensors.torch import save_file + +from ltx_core.components.schedulers import LTX2Scheduler +from ltx_core.model.transformer.adaln import AdaLayerNormSingle +from ltx_core.model.transformer.attention import Attention, AttentionFunction +from ltx_core.model.transformer.feed_forward import FeedForward +from ltx_core.model.transformer.model import LTXModel, LTXModelType +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import ( + LTXRopeType, + apply_rotary_emb, 
+ generate_freq_grid_pytorch, + precompute_freqs_cis, +) +from ltx_core.model.transformer.timestep_embedding import ( + PixArtAlphaCombinedTimestepSizeEmbeddings, +) +from ltx_core.guidance.perturbations import BatchedPerturbationConfig + +# -------- Config -------- + +SEED = 0 +OUT_DIR = pathlib.Path("/tmp/ltx_ref") +TENSOR_DIR = OUT_DIR / "tensors" + +# Tiny model config — deliberately small so every tensor is cheap to dump. +INNER_DIM = 128 +NUM_HEADS = 4 +HEAD_DIM = 32 # NUM_HEADS * HEAD_DIM = INNER_DIM +NUM_LAYERS = 2 +IN_CHANNELS = 16 +OUT_CHANNELS = 16 +CROSS_ATTN_DIM = 128 # keep == INNER_DIM to avoid needing caption_projection for now +NORM_EPS = 1e-6 + +# Toy latent (F, H, W) — small but with at least 2 frames to exercise temporal axis. +F_LAT, H_LAT, W_LAT = 2, 4, 6 +BATCH = 1 +FPS = 24.0 + +# Synthetic text context. +CONTEXT_LEN = 8 + + +# -------- Utility -------- + + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + + def add(self, name: str, tensor: torch.Tensor, notes: str = ""): + t = tensor.detach().to(torch.float32).contiguous().cpu() + # Flatten name → filename by replacing '/' with '__' so everything lives in one dir. + fname = name.replace("/", "__") + ".bin" + path = TENSOR_DIR / fname + path.write_bytes(t.numpy().tobytes()) + self.entries.append( + { + "name": name, + "shape": list(t.shape), + "dtype": "float32", + "nbytes": t.numel() * 4, + "path": str(path.relative_to(OUT_DIR)), + "notes": notes, + } + ) + + def dump(self, path: pathlib.Path): + path.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def seeded_randn(shape, seed_offset=0): + g = torch.Generator().manual_seed(SEED + seed_offset) + return torch.randn(shape, generator=g, dtype=torch.float32) + + +# -------- Dumpers -------- + + +def dump_rope(): + """Dump RoPE freqs_cis + apply_rotary_emb result for a known grid.""" + # 3D positions, middle-grid form: shape [B, n_pos_dims, T, 2] with (start, end) pairs. 
+ F, H, W = F_LAT, H_LAT, W_LAT + T = F * H * W + positions = torch.zeros(BATCH, 3, T, 2, dtype=torch.float32) + idx = 0 + for f in range(F): + for h in range(H): + for w in range(W): + # Time axis divided by fps per ltx_pipelines/utils/tools.py:135. + positions[0, 0, idx, 0] = f / FPS + positions[0, 0, idx, 1] = (f + 1) / FPS + positions[0, 1, idx, 0] = h + positions[0, 1, idx, 1] = h + 1 + positions[0, 2, idx, 0] = w + positions[0, 2, idx, 1] = w + 1 + idx += 1 + + cos, sin = precompute_freqs_cis( + positions, + dim=INNER_DIM, + out_dtype=torch.float32, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=NUM_HEADS, + rope_type=LTXRopeType.SPLIT, + freq_grid_generator=generate_freq_grid_pytorch, + ) + + # Apply to a known q tensor so we can diff both the pe itself and the post-rotation output. + q = seeded_randn((BATCH, T, INNER_DIM), seed_offset=100) + q_rot = apply_rotary_emb(q, (cos, sin), LTXRopeType.SPLIT) + + m = {} + m["rope/positions"] = positions + m["rope/cos"] = cos + m["rope/sin"] = sin + m["rope/q_in"] = q + m["rope/q_rotated"] = q_rot + return m + + +def dump_scheduler(): + """LTX2Scheduler output for a few representative configurations. + Keys: 'schedule/tokens_{N}_steps_{S}' → sigma array of length S+1. + """ + scheduler = LTX2Scheduler() + cases = [ + # (tokens, steps, stretch, terminal) + (1024, 10, True, 0.1), # small latent (BASE_SHIFT anchor) + (1024, 30, True, 0.1), + (4096, 10, True, 0.1), # MAX_SHIFT anchor + (4096, 40, True, 0.1), # typical LTX-2 default + (2560, 30, True, 0.1), # interpolated + (4096, 8, False, 0.1), # no stretch path + ] + out = {} + for tokens, steps, stretch, terminal in cases: + # LTX2Scheduler expects a `latent` tensor to derive tokens from shape[2:]. + # Fake one with product(shape[2:]) == tokens. 
+ fake_latent = torch.zeros(1, 1, tokens) + sigmas = scheduler.execute( + steps=steps, latent=fake_latent, + max_shift=2.05, base_shift=0.95, + stretch=stretch, terminal=terminal, + ) + key = f"schedule/tokens{tokens}_steps{steps}_stretch{int(stretch)}" + out[key] = sigmas.detach().float() + return out + + +def dump_adaln(): + """AdaLayerNormSingle: t → (modulation[B, coeff, dim], embedded[B, dim]).""" + torch.manual_seed(SEED + 2) + adaln = AdaLayerNormSingle(embedding_dim=INNER_DIM, embedding_coefficient=6).eval() + + # Fixed timestep σ ∈ (0, 1). Python applies *1000 externally; mirror that here. + sigma = torch.tensor([0.42], dtype=torch.float32) + t_scaled = sigma * 1000.0 + + with torch.no_grad(): + modulation, embedded = adaln(t_scaled, hidden_dtype=torch.float32) + + # Extract sub-weights for loading into C++. The isolated AdaLN test weights are not loaded into + # the full LTXRunner, so the prefix only needs to be unique w.r.t. the full-model weights. + sd = {f"adaln_standalone.{k}": v.detach().float() for k, v in adaln.state_dict().items()} + + return { + "adaln/sigma": sigma, + "adaln/t_scaled": t_scaled, + "adaln/modulation": modulation, + "adaln/embedded": embedded, + }, sd + + +def dump_full_model(): + """Tiny LTXModel (VideoOnly) forward, dumping per-block outputs.""" + torch.manual_seed(SEED + 3) + + # Stash a helper to tame magnitudes for parity testing. With default init, scale_shift_table is + # torch.empty(...) (uninitialised memory — random garbage) and many Linears have Kaiming init + # which, compounded across blocks with AdaLN * (1 + scale) + shift modulation, produces values + # that overflow fp32 (output becomes NaN). We don't care about the semantics of the weights — + # only that C++ and Python compute the SAME function on the SAME weights — so we replace them + # with bounded random values post-construction. 
+ + model = LTXModel( + model_type=LTXModelType.VideoOnly, + num_attention_heads=NUM_HEADS, + attention_head_dim=HEAD_DIM, + in_channels=IN_CHANNELS, + out_channels=OUT_CHANNELS, + num_layers=NUM_LAYERS, + cross_attention_dim=CROSS_ATTN_DIM, + norm_eps=NORM_EPS, + attention_type=AttentionFunction.PYTORCH, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + timestep_scale_multiplier=1000, + use_middle_indices_grid=True, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=False, + apply_gated_attention=False, + caption_projection=None, # cross_attention_dim == inner_dim so no projection needed + cross_attention_adaln=False, + ).eval() + + # Tame weights: small random gaussians with scale 1/sqrt(dim), scale_shift_table zeroed so the + # forward is well-conditioned even with two randomly-initialised blocks stacked. + with torch.no_grad(): + for name, p in model.named_parameters(): + if "scale_shift_table" in name: + p.zero_() + continue + # q_norm / k_norm RMSNorm weights must be ~1 to act as normalisers (not kill signal). + if name.endswith("q_norm.weight") or name.endswith("k_norm.weight"): + p.fill_(1.0) + continue + if p.dim() == 1: + # biases → zero + p.zero_() + else: + # Kaiming-ish: std ~ 1/sqrt(fan_in) + fan_in = p.shape[1] if p.dim() >= 2 else p.numel() + p.normal_(0.0, 1.0 / (fan_in ** 0.5)) + + # Synthetic inputs. + F, H, W = F_LAT, H_LAT, W_LAT + T = F * H * W + latent = seeded_randn((BATCH, IN_CHANNELS, F, H, W), seed_offset=200) + sigma = torch.tensor([0.5], dtype=torch.float32) + context = seeded_randn((BATCH, CONTEXT_LEN, CROSS_ATTN_DIM), seed_offset=300) + + # Build positions in (B, n_pos_dims, T, 2) middle-grid form. 
+ positions = torch.zeros(BATCH, 3, T, 2, dtype=torch.float32) + idx = 0 + for f in range(F): + for h in range(H): + for w in range(W): + positions[0, 0, idx, 0] = f / FPS + positions[0, 0, idx, 1] = (f + 1) / FPS + positions[0, 1, idx, 0] = h + positions[0, 1, idx, 1] = h + 1 + positions[0, 2, idx, 0] = w + positions[0, 2, idx, 1] = w + 1 + idx += 1 + + # LTX's Modality carries the latent pre-patchify (shape [B, C, F, H, W] → flat [B, T, C]). + # patchify_proj is Linear(in_channels, inner_dim) so we need [B, T, C] for input. + latent_flat = latent.permute(0, 2, 3, 4, 1).reshape(BATCH, T, IN_CHANNELS) + + # For pure T2V, per-token timesteps = sigma broadcast. + timesteps = sigma.view(BATCH, 1).expand(BATCH, T).contiguous() + + # Positions shape the preprocessor wants is [B, 3, T] (no middle-grid pair dim) when + # use_middle_indices_grid=False, or [B, 3, T, 2] when True. Our positions are already the + # [B, 3, T, 2] form. Good. + modality = Modality( + latent=latent_flat, + sigma=sigma, + timesteps=timesteps, + positions=positions, + context=context, + enabled=True, + context_mask=None, + attention_mask=None, + ) + + # Instrument: intercept transformer_blocks outputs so we can dump per-block. + per_block_outputs = {} + orig_forwards = [] + for i, blk in enumerate(model.transformer_blocks): + orig = blk.forward + orig_forwards.append(orig) + + def make_capture(idx, original): + def capture(video=None, audio=None, perturbations=None): + out_video, out_audio = original(video=video, audio=audio, perturbations=perturbations) + per_block_outputs[f"block_{idx:02d}_out"] = out_video.x.detach().float().clone() + return out_video, out_audio + return capture + + blk.forward = make_capture(i, orig) + + with torch.no_grad(): + vx, _ = model(video=modality, audio=None, perturbations=BatchedPerturbationConfig.empty(BATCH)) + + # Also capture post-patchify result by running patchify_proj manually (same computation). 
+ with torch.no_grad(): + patchified = model.patchify_proj(latent_flat) + tm_mod, tm_embedded = model.adaln_single(timesteps.flatten() * 1000.0, hidden_dtype=torch.float32) + tm_mod = tm_mod.view(BATCH, -1, tm_mod.shape[-1]) + tm_embedded = tm_embedded.view(BATCH, -1, tm_embedded.shape[-1]) + + # Also save the unflattened latent in [C, F, H, W] order (batch=1 squeezed). + # Memory layout: W innermost → matches ggml ne=[W, H, F, C] which is what LTXRunner::build_graph + # expects at its entry point. Convert by squeezing batch dim from the original [B=1, C, F, H, W]. + latent_unflat = latent.squeeze(0) # [C, F, H, W] + + # Velocity output comes out of the Python model as [B, T, C=out_channels]. Also save the + # unflattened [C, F, H, W] form so C++ can compare without reshaping. + vx_unflat = vx.reshape(BATCH, F, H, W, OUT_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) # [C, F, H, W] + + tensors = { + "model/latent_in": latent_flat, + "model/latent_unflat": latent_unflat, + "model/sigma": sigma, + "model/timesteps_per_token": timesteps, + "model/context_in": context, + "model/positions": positions, + "model/patchify_out": patchified, + "model/adaln_modulation": tm_mod, + "model/adaln_embedded_timestep": tm_embedded, + "model/velocity_out": vx, + "model/velocity_out_unflat": vx_unflat, + } + for k, v in per_block_outputs.items(): + tensors[f"model/{k}"] = v + + # Use the sd.cpp convention: DiT weights live under "model.diffusion_model.". + # Pairs with LTXRunner's default prefix so the C++ loader reads names verbatim. 
+ sd = {f"model.diffusion_model.{k}": v.detach().float() for k, v in model.state_dict().items()} + + # --- Single-step Euler parity ---------------------------------------------------------------- + # Starting from the noisy latent + the same velocity we just computed, run ONE deterministic + # Euler step using the LTX2Scheduler with 10 steps at σ=0.5 (which falls between sigmas[k] + # and sigmas[k+1] for some k — we pick the step endpoints manually so C++ gets exact inputs). + # This validates the (σ_next - σ) * v formula through the denoiser↔DiT integration boundary. + sched = LTX2Scheduler() + sched_sigmas = sched.execute(steps=10, latent=torch.zeros(1, 1, T), stretch=True, terminal=0.1) + + # Pick one adjacent sigma pair. sigmas[4] is reasonably mid-trajectory for 10 steps. + step_idx = 4 + sigma_cur = sched_sigmas[step_idx].item() + sigma_next = sched_sigmas[step_idx + 1].item() + + # The model was just run at σ=0.5; for the Euler test, re-run at σ_cur (a schedule value). + # The `vx` we already have is at σ=0.5 which doesn't match; redo the forward with sigma_cur. + timesteps_step = torch.tensor([sigma_cur], dtype=torch.float32).view(BATCH, 1).expand(BATCH, T).contiguous() + modality_step = Modality( + latent=latent_flat, + sigma=torch.tensor([sigma_cur], dtype=torch.float32), + timesteps=timesteps_step, + positions=positions, + context=context, + enabled=True, + context_mask=None, + attention_mask=None, + ) + with torch.no_grad(): + v_step, _ = model(video=modality_step, audio=None, perturbations=BatchedPerturbationConfig.empty(BATCH)) + + # Euler step: x_next = x + (σ_next - σ) * v (LTX-2 predicts velocity directly). + x_next = latent_flat + (sigma_next - sigma_cur) * v_step + + # Also dump the unflattened form for C++ convenience. 
+ x_next_unflat = x_next.reshape(BATCH, F, H, W, IN_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) # [C, F, H, W] + v_step_unflat = v_step.reshape(BATCH, F, H, W, OUT_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) + + tensors["euler/sigma_cur"] = torch.tensor([sigma_cur], dtype=torch.float32) + tensors["euler/sigma_next"] = torch.tensor([sigma_next], dtype=torch.float32) + tensors["euler/v_step"] = v_step + tensors["euler/v_step_unflat"] = v_step_unflat + tensors["euler/x_next"] = x_next + tensors["euler/x_next_unflat"] = x_next_unflat + + return tensors, sd + + +def dump_full_model_v2(num_layers: int = NUM_LAYERS, + zero_scale_shift_table: bool = True, + prefix: str = "model.diffusion_model_v2", + tensor_prefix: str = "v2model", + seed_offset: int = 4): + """Tiny LTXModel (VideoOnly) with V2 features enabled: + - cross_attention_adaln=True (adds prompt_scale_shift_table, prompt_adaln_single, + extends scale_shift_table to 9 coeffs, routes CA through apply_cross_attention_adaln) + - apply_gated_attention=True (adds to_gate_logits on attn1 and attn2) + State-dict is saved under `prefix` so multiple variants can coexist in the same file. + + Args: + num_layers: how many transformer blocks to stack. Deeper values exercise + cross-layer drift (e.g. the real 22B DiT has 48). + zero_scale_shift_table: if False, initialise all scale_shift_table / + prompt_scale_shift_table weights with bounded random values so the + modulation path (AdaLN multiply/shift + CA mod) is actually exercised + — the default True path is too well-conditioned to surface sign/layout + bugs in the (1+scale) and shift-kv branches. 
+ """ + torch.manual_seed(SEED + seed_offset) + + model = LTXModel( + model_type=LTXModelType.VideoOnly, + num_attention_heads=NUM_HEADS, + attention_head_dim=HEAD_DIM, + in_channels=IN_CHANNELS, + out_channels=OUT_CHANNELS, + num_layers=num_layers, + cross_attention_dim=CROSS_ATTN_DIM, + norm_eps=NORM_EPS, + attention_type=AttentionFunction.PYTORCH, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + timestep_scale_multiplier=1000, + use_middle_indices_grid=True, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=False, + apply_gated_attention=True, + caption_projection=None, + cross_attention_adaln=True, + ).eval() + + # Tame weights, same recipe as V1. The sst branch is optional: leaving it zero + # means AdaLN modulation degenerates to identity, which hides bugs in the + # (1 + scale) path and the CA-AdaLN shift_kv / scale_kv broadcast. + with torch.no_grad(): + for name, p in model.named_parameters(): + if "scale_shift_table" in name: + if zero_scale_shift_table: + p.zero_() + else: + # Keep magnitudes small so the stacked modulation doesn't explode + # across layers. scale_shift_table rows are added to a (0, 1]-ish + # AdaLN output; 0.05 keeps the post-modulation scale in ~[0.95, 1.05]. 
+ p.normal_(0.0, 0.05) + continue + if name.endswith("q_norm.weight") or name.endswith("k_norm.weight"): + p.fill_(1.0) + continue + if p.dim() == 1: + p.zero_() + else: + fan_in = p.shape[1] if p.dim() >= 2 else p.numel() + p.normal_(0.0, 1.0 / (fan_in ** 0.5)) + + F, H, W = F_LAT, H_LAT, W_LAT + T = F * H * W + latent = seeded_randn((BATCH, IN_CHANNELS, F, H, W), seed_offset=400) + sigma = torch.tensor([0.5], dtype=torch.float32) + context = seeded_randn((BATCH, CONTEXT_LEN, CROSS_ATTN_DIM), seed_offset=500) + + positions = torch.zeros(BATCH, 3, T, 2, dtype=torch.float32) + idx = 0 + for f in range(F): + for h in range(H): + for w in range(W): + positions[0, 0, idx, 0] = f / FPS + positions[0, 0, idx, 1] = (f + 1) / FPS + positions[0, 1, idx, 0] = h + positions[0, 1, idx, 1] = h + 1 + positions[0, 2, idx, 0] = w + positions[0, 2, idx, 1] = w + 1 + idx += 1 + + latent_flat = latent.permute(0, 2, 3, 4, 1).reshape(BATCH, T, IN_CHANNELS) + timesteps = sigma.view(BATCH, 1).expand(BATCH, T).contiguous() + modality = Modality( + latent=latent_flat, + sigma=sigma, + timesteps=timesteps, + positions=positions, + context=context, + enabled=True, + context_mask=None, + attention_mask=None, + ) + + per_block_outputs = {} + for i, blk in enumerate(model.transformer_blocks): + orig = blk.forward + + def make_capture(idx, original): + def capture(video=None, audio=None, perturbations=None): + out_video, out_audio = original(video=video, audio=audio, perturbations=perturbations) + per_block_outputs[f"block_{idx:02d}_out"] = out_video.x.detach().float().clone() + return out_video, out_audio + return capture + + blk.forward = make_capture(i, orig) + + with torch.no_grad(): + vx, _ = model(video=modality, audio=None, perturbations=BatchedPerturbationConfig.empty(BATCH)) + + with torch.no_grad(): + patchified = model.patchify_proj(latent_flat) + tm_mod, tm_embedded = model.adaln_single(timesteps.flatten() * 1000.0, hidden_dtype=torch.float32) + tm_mod = tm_mod.view(BATCH, -1, 
tm_mod.shape[-1]) + tm_embedded = tm_embedded.view(BATCH, -1, tm_embedded.shape[-1]) + # V2 extra: prompt_adaln output driven by sigma (× scale_mult = 1000). + p_mod, _ = model.prompt_adaln_single( + (sigma * 1000.0).flatten(), hidden_dtype=torch.float32 + ) + p_mod = p_mod.view(BATCH, -1, p_mod.shape[-1]) + + latent_unflat = latent.squeeze(0) + vx_unflat = vx.reshape(BATCH, F, H, W, OUT_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) + + tensors = { + f"{tensor_prefix}/latent_in": latent_flat, + f"{tensor_prefix}/latent_unflat": latent_unflat, + f"{tensor_prefix}/sigma": sigma, + f"{tensor_prefix}/timesteps_per_token": timesteps, + f"{tensor_prefix}/context_in": context, + f"{tensor_prefix}/positions": positions, + f"{tensor_prefix}/patchify_out": patchified, + f"{tensor_prefix}/adaln_modulation": tm_mod, + f"{tensor_prefix}/adaln_embedded_timestep": tm_embedded, + f"{tensor_prefix}/prompt_modulation": p_mod, + f"{tensor_prefix}/velocity_out": vx, + f"{tensor_prefix}/velocity_out_unflat": vx_unflat, + } + for k, v in per_block_outputs.items(): + tensors[f"{tensor_prefix}/{k}"] = v + + sd = {f"{prefix}.{k}": v.detach().float() for k, v in model.state_dict().items()} + return tensors, sd + + +# -------- Main -------- + + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + + torch.use_deterministic_algorithms(False) # some ops (layernorm) aren't deterministic + torch.manual_seed(SEED) + + manifest = Manifest() + state_dict: Dict[str, torch.Tensor] = {} + + print("[1/4] RoPE …") + for name, t in dump_rope().items(): + manifest.add(name, t) + + print("[2/4] LTX2Scheduler …") + for name, t in dump_scheduler().items(): + manifest.add(name, t) + + print("[3/4] AdaLayerNormSingle …") + adaln_tensors, adaln_sd = dump_adaln() + for name, t in adaln_tensors.items(): + manifest.add(name, t) + state_dict.update(adaln_sd) + + print("[4/5] Full LTXModel (tiny, V1) …") + model_tensors, model_sd = dump_full_model() + for name, t 
in model_tensors.items(): + manifest.add(name, t) + state_dict.update(model_sd) + + print("[5/6] Full LTXModel (tiny, V2: cross_attention_adaln + apply_gated_attention) …") + model_v2_tensors, model_v2_sd = dump_full_model_v2() + for name, t in model_v2_tensors.items(): + manifest.add(name, t) + state_dict.update(model_v2_sd) + + # Deep V2: 8 layers + non-zero scale_shift_table so accumulated modulation drift + # surfaces. The original V2 dump is too gentle (only 2 layers, zeroed sst) to + # catch bugs that only matter when modulation is non-trivial. + print("[6/6] Full LTXModel (tiny, V2-deep: 8 layers, non-zero scale_shift_table) …") + v2_deep_tensors, v2_deep_sd = dump_full_model_v2( + num_layers=8, + zero_scale_shift_table=False, + prefix="model.diffusion_model_v2_deep", + tensor_prefix="v2deep", + seed_offset=7, + ) + for name, t in v2_deep_tensors.items(): + manifest.add(name, t) + state_dict.update(v2_deep_sd) + + # Safetensors requires contiguous CPU tensors. + sd_contig = {k: v.contiguous().cpu() for k, v in state_dict.items()} + save_file(sd_contig, str(OUT_DIR / "state_dict.safetensors")) + + manifest_path = OUT_DIR / "manifest.json" + manifest.dump(manifest_path) + + with (OUT_DIR / "tensor_names.txt").open("w") as f: + for name in sorted(state_dict.keys()): + t = state_dict[name] + f.write(f"{name}\t{list(t.shape)}\t{t.dtype}\n") + + print(f"Done. 
Wrote {len(manifest.entries)} tensors under {OUT_DIR}.") + print(f"State dict: {len(state_dict)} keys → {OUT_DIR}/state_dict.safetensors") + print(f"Manifest: {manifest_path}") + print(f"Name inventory: {OUT_DIR}/tensor_names.txt") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_s2d.py b/tests/ltx_parity/dump_s2d.py new file mode 100644 index 000000000..6a2f395c9 --- /dev/null +++ b/tests/ltx_parity/dump_s2d.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +"""Dump reference space-to-depth / depth-to-space outputs along each of the three +stride axes (W, H, T) as standalone test vectors. + +Each case applies a single-axis stride=2 split (the building block that will be +composed to give full 3D SpaceToDepth). We dump both the input and the expected +output so the C++ side can verify its ggml reshape+permute chain byte-exact. + +Output: /tmp/s2d_ref/tensors/*.bin + manifest.json + config.json. +Usage: /home/ilintar/venv/bin/python dump_s2d.py +""" + +from __future__ import annotations + +import json +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from einops import rearrange + +OUT_DIR = pathlib.Path("/tmp/s2d_ref") +TENSOR_DIR = OUT_DIR / "tensors" + +# Distinct primes where possible so any mis-axis mixup shows up immediately. 
+B, C, T, H, W = 1, 3, 4, 6, 8 # after: (T/2, H, W), (T, H/2, W), (T, H, W/2) per case +FACTOR = 2 + + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + def add(self, name, t): self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"}) + def dump(self, p): p.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def save(t: torch.Tensor, name: str, mf: Manifest): + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + arr.tofile(TENSOR_DIR / f"{name}.bin") + mf.add(name, t) + + +def s2d_W(x: torch.Tensor, p3: int) -> torch.Tensor: + # [B, C, T, H, W*p3] -> [B, C*p3, T, H, W] + return rearrange(x, "b c t h (w p3) -> b (c p3) t h w", p3=p3) + + +def s2d_H(x: torch.Tensor, p2: int) -> torch.Tensor: + # [B, C, T, H*p2, W] -> [B, C*p2, T, H, W] + return rearrange(x, "b c t (h p2) w -> b (c p2) t h w", p2=p2) + + +def s2d_T(x: torch.Tensor, p1: int) -> torch.Tensor: + # [B, C, T*p1, H, W] -> [B, C*p1, T, H, W] + return rearrange(x, "b c (t p1) h w -> b (c p1) t h w", p1=p1) + + +def s2d_full(x: torch.Tensor, p1: int, p2: int, p3: int) -> torch.Tensor: + # [B, C, T*p1, H*p2, W*p3] -> [B, C*p1*p2*p3, T, H, W] + return rearrange(x, "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=p1, p2=p2, p3=p3) + + +def d2s_W(x: torch.Tensor, p3: int) -> torch.Tensor: + return rearrange(x, "b (c p3) t h w -> b c t h (w p3)", p3=p3) + + +def d2s_H(x: torch.Tensor, p2: int) -> torch.Tensor: + return rearrange(x, "b (c p2) t h w -> b c t (h p2) w", p2=p2) + + +def d2s_T(x: torch.Tensor, p1: int) -> torch.Tensor: + return rearrange(x, "b (c p1) t h w -> b c (t p1) h w", p1=p1) + + +def d2s_full(x: torch.Tensor, p1: int, p2: int, p3: int) -> torch.Tensor: + return rearrange(x, "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=p1, p2=p2, p3=p3) + + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + mf = Manifest() + torch.manual_seed(0) + + # --- Axis-W 
primitive --- + x_w = torch.randn(B, C, T, H, W * FACTOR) + save(x_w, "input_axisW", mf) + save(s2d_W(x_w, FACTOR), "expected_axisW", mf) + + # --- Axis-H primitive --- + x_h = torch.randn(B, C, T, H * FACTOR, W) + save(x_h, "input_axisH", mf) + save(s2d_H(x_h, FACTOR), "expected_axisH", mf) + + # --- Axis-T primitive --- + x_t = torch.randn(B, C, T * FACTOR, H, W) + save(x_t, "input_axisT", mf) + save(s2d_T(x_t, FACTOR), "expected_axisT", mf) + + # --- Full 3D (stride=(2,2,2)) composition --- + x_all = torch.randn(B, C, T * FACTOR, H * FACTOR, W * FACTOR) + save(x_all, "input_full222", mf) + save(s2d_full(x_all, FACTOR, FACTOR, FACTOR), "expected_full222", mf) + + # --- Stride=(1,2,2) (what compress_space_res uses) --- + x_122 = torch.randn(B, C, T, H * FACTOR, W * FACTOR) + save(x_122, "input_full122", mf) + save(s2d_full(x_122, 1, FACTOR, FACTOR), "expected_full122", mf) + + # --- Stride=(2,1,1) (compress_time_res) --- + x_211 = torch.randn(B, C, T * FACTOR, H, W) + save(x_211, "input_full211", mf) + save(s2d_full(x_211, FACTOR, 1, 1), "expected_full211", mf) + + # --- DepthToSpace (single-axis + composed) --- + # Input for axis primitives: [B, C_large, T, H, W] where C_large = C * factor. 
+ dx_w = torch.randn(B, C * FACTOR, T, H, W) + save(dx_w, "dinput_axisW", mf) + save(d2s_W(dx_w, FACTOR), "dexpected_axisW", mf) + + dx_h = torch.randn(B, C * FACTOR, T, H, W) + save(dx_h, "dinput_axisH", mf) + save(d2s_H(dx_h, FACTOR), "dexpected_axisH", mf) + + dx_t = torch.randn(B, C * FACTOR, T, H, W) + save(dx_t, "dinput_axisT", mf) + save(d2s_T(dx_t, FACTOR), "dexpected_axisT", mf) + + dx_222 = torch.randn(B, C * (FACTOR ** 3), T, H, W) + save(dx_222, "dinput_full222", mf) + save(d2s_full(dx_222, FACTOR, FACTOR, FACTOR), "dexpected_full222", mf) + + dx_122 = torch.randn(B, C * (FACTOR ** 2), T, H, W) + save(dx_122, "dinput_full122", mf) + save(d2s_full(dx_122, 1, FACTOR, FACTOR), "dexpected_full122", mf) + + dx_211 = torch.randn(B, C * FACTOR, T, H, W) + save(dx_211, "dinput_full211", mf) + save(d2s_full(dx_211, FACTOR, 1, 1), "dexpected_full211", mf) + + # --- PixelNorm (dim=1 RMS) --- + eps = 1e-8 + pn_in = torch.randn(B, 5, T, H, W) # C=5 to exercise a non-power-of-2 channel + pn_out = pn_in / torch.sqrt((pn_in ** 2).mean(dim=1, keepdim=True) + eps) + save(pn_in, "pn_input", mf) + save(pn_out, "pn_expected", mf) + + # --- PerChannelStatistics --- + # Random mu and sigma (sigma > 0). Buffers shape [C] as in the real VAE. 
+ c_pcs = 6 + pcs_in = torch.randn(B, c_pcs, T, H, W) + pcs_mu = torch.randn(c_pcs) + pcs_sigma = torch.rand(c_pcs) + 0.5 # keep away from zero + save(pcs_in, "pcs_input", mf) + save(pcs_mu, "pcs_mu", mf) + save(pcs_sigma, "pcs_sigma", mf) + save((pcs_in - pcs_mu.view(1, c_pcs, 1, 1, 1)) / pcs_sigma.view(1, c_pcs, 1, 1, 1), + "pcs_normalize_expected", mf) + save((pcs_in * pcs_sigma.view(1, c_pcs, 1, 1, 1)) + pcs_mu.view(1, c_pcs, 1, 1, 1), + "pcs_unnormalize_expected", mf) + + mf.dump(OUT_DIR / "manifest.json") + (OUT_DIR / "config.json").write_text(json.dumps({ + "B": B, "C": C, "T": T, "H": H, "W": W, "FACTOR": FACTOR, + "pn_C": 5, "pn_eps": eps, + "pcs_C": c_pcs, + }, indent=2)) + print(f"wrote {len(mf.entries)} tensors under {OUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_vae.py b/tests/ltx_parity/dump_vae.py new file mode 100644 index 000000000..f8f48a9dd --- /dev/null +++ b/tests/ltx_parity/dump_vae.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +"""Dump tiny LTX-2 VAE reference tensors for C++/GGML parity testing. + +Strategy mirrors dump_reference.py / dump_gemma.py: instantiate a tiny VideoEncoder +and VideoDecoder with deterministic tamed weights, run one forward pass each on +fixed inputs, and save per-block intermediate outputs and the state_dict. + +Tiny config exercises one of each encoder block type (compress_space_res, +compress_time_res, res_x) and a matching decoder (res_x, compress_time, compress_space) +plus the output AdaLN + PerChannelStatistics (un)normalize. 
+
+Usage:
+    python dump_vae.py
+"""
+
+from __future__ import annotations
+
+import json
+import math
+import pathlib
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+import numpy as np
+import torch
+
+from safetensors.torch import save_file
+
+from ltx_core.model.video_vae.video_vae import VideoEncoder, VideoDecoder
+from ltx_core.model.video_vae.enums import NormLayerType, LogVarianceType, PaddingModeType
+
+# -------- Config --------
+
+SEED = 0
+OUT_DIR = pathlib.Path("/tmp/vae_ref")
+TENSOR_DIR = OUT_DIR / "tensors"
+
+# Tiny VAE config. patch_size=2 (vs standard 4) to keep spatial dims small.
+# encoder: compress_space_res(×2 ch) then compress_time_res(×2 ch) then res_x(1 layer).
+# decoder: res_x(1 layer), compress_time, compress_space (reversed during construction).
+IN_CHANNELS = 3
+LATENT_CHANNELS = 8
+DECODER_BASE_CH = 8  # decoder conv_in goes 128 -> 8 * 8 = 64 (with *8 multiplier)
+PATCH_SIZE = 2
+NORM_LAYER = NormLayerType.PIXEL_NORM
+LOG_VAR = LogVarianceType.UNIFORM
+PADDING_ENC = PaddingModeType.ZEROS
+PADDING_DEC = PaddingModeType.REFLECT
+
+# Video shape: 1 + 8*k frames required by encoder's validator. 1 + 8*1 = 9 → F=9.
+# Spatial must divide by (patch_size * 2 * 2) = 8 for one compress_space_res + one compress_time_res.
+# H = W = 16 is the minimum that divides 8 cleanly after patchify.
+BATCH, F_IN, H_IN, W_IN = 1, 9, 16, 16 + +DECODE_TIMESTEP = 0.05 # Gemma/LTX-2 conventional decoder timestep + + +# -------- Utility -------- + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + + def add(self, name: str, t: torch.Tensor): + self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"}) + + def dump(self, path: pathlib.Path): + path.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def save_tensor(t: torch.Tensor, name: str, manifest: Manifest): + safe = name.replace("/", "__") + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + arr.tofile(TENSOR_DIR / f"{safe}.bin") + manifest.add(name, t) + + +def tame_(model: torch.nn.Module): + """Deterministic, finite weights. Reuses the pattern from dump_reference.py. + + - 1D params (biases, norm weights, scale_shift_tables, per_channel_scale*): + zero-initialized. Gemma-style convention where (1+w) is used as the effective + scale works the same for VAE's ResnetBlock3D AdaLN (hidden * (1 + scale) + shift). + - 2D/3D/4D/5D params (linears, convs): Kaiming-ish with std=1/sqrt(fan_in). + - PerChannelStatistics buffers: std-of-means = 1.0, mean-of-means = 0.0 so + normalize/un_normalize become identity + scale. + """ + g = torch.Generator().manual_seed(SEED) + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.dim() <= 1: + # 0-D (scalars like timestep_scale_multiplier) or 1-D (biases, norm weights, + # scale_shift_tables, per_channel_scale*) — zero init. + p.zero_() + # `timestep_scale_multiplier` needs to be exactly the trained value (1000.0) + # to match the denoiser scale, but a zero multiplier just zeroes the timestep + # embedding; the parity path still exercises the rest of the block math. 
+ else: + fan_in = max(1, p.numel() // p.shape[0]) + std = 1.0 / math.sqrt(fan_in) + p.normal_(mean=0.0, std=std, generator=g) + # PerChannelStatistics is a buffer (not a parameter), init separately: + for name, buf in model.named_buffers(): + if "std-of-means" in name: + buf.fill_(1.0) + elif "mean-of-means" in name: + buf.fill_(0.0) + + +def build_encoder() -> VideoEncoder: + return VideoEncoder( + convolution_dimensions=3, + in_channels=IN_CHANNELS, + out_channels=LATENT_CHANNELS, + encoder_blocks=[ + ("compress_space_res", {"multiplier": 2}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ], + patch_size=PATCH_SIZE, + norm_layer=NORM_LAYER, + latent_log_var=LOG_VAR, + encoder_spatial_padding_mode=PADDING_ENC, + ) + + +def build_decoder() -> VideoDecoder: + # Encoder reduces temporal by 2 and spatial by patch_size*2 = 4. Decoder matches inverse. + return VideoDecoder( + convolution_dimensions=3, + in_channels=LATENT_CHANNELS, + out_channels=IN_CHANNELS, + decoder_blocks=[ + # order is reversed inside VideoDecoder ctor — this is the encoder-side order: + ("compress_space", {"multiplier": 1}), + ("compress_time", {"multiplier": 1}), + ("res_x", {"num_layers": 1}), + ], + patch_size=PATCH_SIZE, + norm_layer=NORM_LAYER, + causal=False, + timestep_conditioning=True, + decoder_spatial_padding_mode=PADDING_DEC, + base_channels=DECODER_BASE_CH, + ) + + +# -------- Main -------- + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + torch.manual_seed(SEED) + + encoder = build_encoder().eval() + decoder = build_decoder().eval() + + tame_(encoder) + tame_(decoder) + + # Input video: (B=1, C=3, F=9, H=16, W=16). 
+ rng = np.random.default_rng(SEED) + video_np = rng.standard_normal((BATCH, IN_CHANNELS, F_IN, H_IN, W_IN), dtype=np.float32) + video = torch.from_numpy(video_np).clone() + + manifest = Manifest() + save_tensor(video, "video_in", manifest) + print(f"input: shape={tuple(video.shape)}") + + # --- Encoder forward --- + with torch.no_grad(): + x = video + # Replicate VideoEncoder.forward manually so we can cache intermediates. + from ltx_core.model.video_vae.ops import patchify + x = patchify(x, patch_size_hw=PATCH_SIZE, patch_size_t=1) + save_tensor(x, "enc_post_patchify", manifest) + + x = encoder.conv_in(x) + save_tensor(x, "enc_post_conv_in", manifest) + + for i, blk in enumerate(encoder.down_blocks): + x = blk(x) + save_tensor(x, f"enc_block_{i}", manifest) + + x = encoder.conv_norm_out(x) + save_tensor(x, "enc_post_norm", manifest) + x = encoder.conv_act(x) + x = encoder.conv_out(x) + save_tensor(x, "enc_post_conv_out", manifest) + + # Replicate UNIFORM latent_log_var path: means = x[:, :-1], logvar = x[:, -1:]. + if LOG_VAR == LogVarianceType.UNIFORM: + means = x[:, :-1, ...] + logvar = x[:, -1:, ...] + # (We save just the means and the final normalized latent; don't need logvar for parity.) + save_tensor(means, "enc_means_preNorm", manifest) + latent = encoder.per_channel_statistics.normalize(means) + save_tensor(latent, "latent", manifest) + else: + raise RuntimeError("only UNIFORM supported in dumper") + + print(f"latent: shape={tuple(latent.shape)} mean={latent.mean().item():.4f} std={latent.std().item():.4f}") + + # --- Decoder forward (deterministic path; no noise, fixed timestep) --- + timestep = torch.full((BATCH,), DECODE_TIMESTEP, dtype=torch.float32) + with torch.no_grad(): + y = decoder.per_channel_statistics.un_normalize(latent) + save_tensor(y, "dec_post_unnorm", manifest) + + # Match the real decoder.forward: self.causal=False is set by the configurator, + # so every conv call uses causal=False. 
Earlier versions of this dumper relied + # on the default causal=True which diverged from actual behavior and masked a + # conv1/conv2 mismatch in the C++ port. + y = decoder.conv_in(y, causal=False) + save_tensor(y, "dec_post_conv_in", manifest) + + # TimestepEmbedder probe: feed the exact `timestep` used below, save the 256-dim + # result so the C++ side can byte-diff its TimestepEmbedder output against Python's. + # Uses the inner time_embedder (embedding_dim=256) from the res_x block. + te_probe = decoder.up_blocks[0].time_embedder( + timestep=timestep.flatten(), hidden_dtype=y.dtype) + save_tensor(te_probe, "te_probe_up0", manifest) + + # Intermediate after each up_block (reversed decoder config). + # Probe INSIDE the first res_x block: dump the pixel_norm(conv_in_output) to verify + # the norm path is byte-exact. This is the Python `hidden_states = self.norm1(x)` + # inside the first ResnetBlock3D of up_blocks[0]. + from ltx_core.model.common.normalization import PixelNorm + probe_block = decoder.up_blocks[0].res_blocks[0] + y_pre = y # still conv_in output here; save a copy. + y_norm1 = probe_block.norm1(y_pre) + save_tensor(y_norm1, "dec_resblock0_post_norm1", manifest) + # Also save post_adaln1 (just the modulation, no silu/conv yet). 
+ ts_embed_block = decoder.up_blocks[0].time_embedder( + timestep=timestep.flatten(), hidden_dtype=y_pre.dtype + ).view(BATCH, -1, 1, 1, 1) + ada_probe = probe_block.scale_shift_table[None, ..., None, None, None] + ts_embed_block.reshape( + BATCH, 4, -1, 1, 1, 1 + ) + sh1, sc1, sh2, sc2 = ada_probe.unbind(dim=1) + y_adaln1 = y_norm1 * (1 + sc1) + sh1 + save_tensor(y_adaln1, "dec_resblock0_post_adaln1", manifest) + y_silu1 = probe_block.non_linearity(y_adaln1) + y_conv1 = probe_block.conv1(y_silu1, causal=False) + save_tensor(y_conv1, "dec_resblock0_post_conv1", manifest) + y_norm2 = probe_block.norm2(y_conv1) + save_tensor(y_norm2, "dec_resblock0_post_norm2", manifest) + + # Build the timestep embedding that UNetMidBlock3D would use internally for + # the res_x block. The scale multiplier is a learned scalar (we zero-inited it). + # The parity comparison only verifies the *forward* math; if the multiplier is + # 0 then the time embedding collapses. That's fine since we verify tracewise. + # Inject a timestep only when calling the block (passed through forward()). + + # Replicate VideoDecoder.forward partial path. + # Important: the decoder's up_blocks list is REVERSED of the config list. + # Our config: [compress_space, compress_time, res_x]. After reversing: + # [res_x, compress_time, compress_space]. So up_blocks[0] is the res_x. + + # Timestep scale is used inside last_time_embedder; but res_x UNetMidBlock3D + # has its own time_embedder. Pass raw timestep; each block handles scaling. + + for i, blk in enumerate(decoder.up_blocks): + # Only res_x (UNetMidBlock3D) accepts timestep; up/down sample blocks don't. + from ltx_core.model.video_vae.resnet import UNetMidBlock3D + if isinstance(blk, UNetMidBlock3D): + y = blk(y, causal=False, timestep=timestep) + else: + y = blk(y, causal=False) + save_tensor(y, f"dec_block_{i}", manifest) + + # Final AdaLN output + conv_norm_out: this is the `last_scale_shift_table` + time_embedder path. 
+ ada = decoder.last_scale_shift_table[None, ..., None, None, None] + decoder.last_time_embedder( + timestep=(timestep * decoder.timestep_scale_multiplier).flatten(), + hidden_dtype=y.dtype, + ).view(BATCH, 2, -1, 1, 1, 1) + shift, scale = ada.unbind(dim=1) + y = decoder.conv_norm_out(y) + save_tensor(y, "dec_post_pixel_norm", manifest) + y = y * (1 + scale) + shift + save_tensor(y, "dec_post_ada", manifest) + y = decoder.conv_act(y) + y = decoder.conv_out(y, causal=False) + save_tensor(y, "dec_post_conv_out", manifest) + + from ltx_core.model.video_vae.ops import unpatchify + y = unpatchify(y, patch_size_hw=PATCH_SIZE, patch_size_t=1) + save_tensor(y, "video_out", manifest) + + print(f"decoded: shape={tuple(y.shape)} mean={y.mean().item():.4f} std={y.std().item():.4f}") + + # --- State dict: concatenate encoder + decoder + per_channel_statistics under "vae." prefix. --- + prefixed = {} + for k, v in encoder.state_dict().items(): + prefixed[f"vae.encoder.{k}"] = v.to(torch.float32).contiguous() + for k, v in decoder.state_dict().items(): + prefixed[f"vae.decoder.{k}"] = v.to(torch.float32).contiguous() + # PerChannelStatistics is registered inside both encoder & decoder AND also dumped under + # a top-level `vae.per_channel_statistics.*` path (matching the real checkpoint convention, + # per VAE_ENCODER_COMFY_KEYS_FILTER). We keep all three copies so encoder/decoder + # blocks can load from either the nested or the canonical path. + pcs = encoder.per_channel_statistics + for bufname, buf in pcs.named_buffers(): + # .clone() to sever storage sharing with the nested copies — safetensors + # refuses to dump multiple keys pointing at the same underlying buffer. 
+ prefixed[f"vae.per_channel_statistics.{bufname}"] = buf.detach().to(torch.float32).clone().contiguous() + + save_file(prefixed, str(OUT_DIR / "state_dict.safetensors")) + (OUT_DIR / "tensor_names.txt").write_text("\n".join(sorted(prefixed.keys())) + "\n") + manifest.dump(OUT_DIR / "manifest.json") + + (OUT_DIR / "config.json").write_text(json.dumps({ + "in_channels": IN_CHANNELS, + "latent_channels": LATENT_CHANNELS, + "decoder_base_ch": DECODER_BASE_CH, + "patch_size": PATCH_SIZE, + "norm_layer": NORM_LAYER.value, + "log_var": LOG_VAR.value, + "batch": BATCH, + "frames": F_IN, + "height": H_IN, + "width": W_IN, + "decode_timestep": DECODE_TIMESTEP, + "encoder_blocks": [ + ["compress_space_res", {"multiplier": 2}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 1}], + ], + "decoder_blocks": [ + ["compress_space", {"multiplier": 1}], + ["compress_time", {"multiplier": 1}], + ["res_x", {"num_layers": 1}], + ], + }, indent=2)) + + print(f"\nDone. Wrote {len(manifest.entries)} tensors under {OUT_DIR}.") + print(f"State dict: {len(prefixed)} keys → {OUT_DIR}/state_dict.safetensors") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/test_connector_parity.cpp b/tests/ltx_parity/test_connector_parity.cpp new file mode 100644 index 000000000..f364dae5c --- /dev/null +++ b/tests/ltx_parity/test_connector_parity.cpp @@ -0,0 +1,297 @@ +// LTX-2 text connector parity test (V1 / 19B). +// +// Loads /tmp/connector_ref/{state_dict.safetensors, tensors/*.bin} produced by +// dump_connector.py, runs: +// 1. CPU feature_extractor_normalize on the stacked input +// 2. LTX2ConnectorRunner::compute through each probe stage +// and diffs against the Python reference. 
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <map>
+#include <set>
+#include <string>
+
+#include "ggml-backend.h"
+#include "ggml-cpu.h"
+#include "ltx_connector.hpp"
+#include "model.h"
+#include "tensor.hpp"
+
+namespace {
+
+sd::Tensor load_raw_bin(const std::string& path, const std::vector<int64_t>& shape) {
+    sd::Tensor t(shape);
+    std::ifstream f(path, std::ios::binary);
+    if (!f.is_open()) {
+        std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str());
+        std::exit(2);
+    }
+    f.read(reinterpret_cast<char*>(t.data()),
+           static_cast<std::streamsize>(t.numel() * sizeof(float)));
+    if (!f.good()) {
+        std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n",
+                     path.c_str(), t.numel());
+        std::exit(2);
+    }
+    return t;
+}
+
+struct DiffStats {
+    float max_abs = 0.f;
+    float mean_abs = 0.f;
+    float max_rel = 0.f;
+    int64_t max_abs_idx = -1;
+};
+
+DiffStats diff_fp32(const float* a, const float* b, int64_t n) {
+    DiffStats s;
+    double sum_abs = 0.0;
+    for (int64_t i = 0; i < n; ++i) {
+        float abs_err = std::fabs(a[i] - b[i]);
+        float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f);
+        if (abs_err > s.max_abs) {
+            s.max_abs = abs_err;
+            s.max_abs_idx = i;
+        }
+        s.max_rel = std::max(s.max_rel, rel_err);
+        sum_abs += abs_err;
+    }
+    s.mean_abs = static_cast<float>(sum_abs / (n > 0 ? n : 1));
+    return s;
+}
+
+} // namespace
+
+int main() {
+    const std::string ref_dir = "/tmp/connector_ref";
+    const std::string state_path = ref_dir + "/state_dict.safetensors";
+
+    // Tiny config (must match dump_connector.py).
+ const int64_t B = 1; + const int64_t T = 8; + const int NUM_HEADS = 2; + const int HEAD_DIM = 32; + const int64_t D = NUM_HEADS * HEAD_DIM; // connector inner_dim = 64 + const int64_t L = 5; // stacked layers + const int64_t FLAT_DIM = D * L; // 320 + const int NUM_LAYERS = 2; + const int NUM_REGISTERS = 4; + const int64_t CAPTION_CHANNELS = D; // 64 + const int64_t CAPTION_HIDDEN = 128; + const int64_t CAPTION_OUT = 128; + const float THETA = 10000.0f; + const std::vector MAX_POS = {1}; + + // --- 1. Load state dict. + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // --- 2. Construct runner. + ggml_backend_t backend = ggml_backend_cpu_init(); + LTXConnector::LTX2ConnectorRunner runner( + backend, /*offload_params_to_cpu=*/false, + FLAT_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS, NUM_REGISTERS, + CAPTION_CHANNELS, CAPTION_HIDDEN, CAPTION_OUT, + THETA, MAX_POS, tsm, /*prefix=*/""); + + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, ""); + std::printf("[load] %zu param tensors…\n", param_tensors.size()); + + // Diagnose any missing tensors. + int missing_shown = 0; + std::set tsm_keys; + for (const auto& kv : tsm) tsm_keys.insert(kv.first); + for (const auto& pt : param_tensors) { + if (tsm_keys.find(pt.first) == tsm_keys.end()) { + if (missing_shown < 5) { + std::printf("[load] missing in file: %s\n", pt.first.c_str()); + missing_shown++; + } + } + } + + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed\n"); + return 1; + } + + // --- 3. Load stacked input (ref layout: [B, T, D, L]). + auto stacked_in = load_raw_bin(ref_dir + "/tensors/stacked_in.bin", {L, D, T, B}); + + // --- 4. 
CPU-side feature extractor normalization. + std::vector seq_lens(B, static_cast(T)); // all-ones mask + sd::Tensor normed({FLAT_DIM, T, B}); + LTXConnector::feature_extractor_normalize( + stacked_in.data(), seq_lens.data(), normed.data(), + static_cast(B), static_cast(T), static_cast(D), static_cast(L), + "left", 1e-6f); + + // --- 5. Run each probe stage and diff. + struct Probe { + int stage; + const char* name; + std::vector shape; // ne order (innermost first) + float tol_max_abs; + float tol_mean_abs; + }; +// Tolerances reflect: (1) fp16 K/V cast in ggml_ext_attention_ext (~1e-3 per + // attention layer), (2) residual fp32 cos/sin divergence between torch and + // libm at the tail of the freq grid (~6e-3 PE max diff → ~1e-3 per q/k + // rotation). Two attention layers → ~2-3e-3 max_abs cap end-to-end. + const Probe probes[] = { + {0, "feat_ext_out", {D, T, B}, 1e-4f, 5e-5f}, + {1, "conn_block_0_out", {D, T, B}, 3e-3f, 5e-4f}, + {2, "conn_block_1_out", {D, T, B}, 4e-3f, 1e-3f}, + {3, "conn_final_out", {D, T, B}, 3e-3f, 5e-4f}, + {4, "caption_proj_out", {CAPTION_OUT, T, B}, 3e-3f, 1e-3f}, + }; + + bool all_pass = true; + std::printf("\n=== LTX-2 Connector parity ===\n"); + std::printf("%-20s %11s %11s %11s %s\n", "tag", "max_abs", "mean_abs", "max_rel", "result"); + + for (const auto& p : probes) { + auto out = runner.compute(/*n_threads=*/1, normed, p.stage); + auto ref = load_raw_bin(ref_dir + "/tensors/" + p.name + ".bin", p.shape); + if (out.numel() != ref.numel()) { + std::fprintf(stderr, "[%s] size mismatch got=%ld want=%ld\n", + p.name, out.numel(), ref.numel()); + return 1; + } + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + bool pass = s.max_abs < p.tol_max_abs && s.mean_abs < p.tol_mean_abs; + std::printf(" %-18s %.3e %.3e %.3e %s\n", + p.name, s.max_abs, s.mean_abs, s.max_rel, pass ? 
"PASS" : "FAIL"); + if (!pass && s.max_abs_idx >= 0) { + int64_t i = s.max_abs_idx; + std::printf(" max-diff @ idx=%ld: got=%+.6f want=%+.6f diff=%+.6f\n", + i, out.data()[i], ref.data()[i], out.data()[i] - ref.data()[i]); + } + all_pass &= pass; + } + + std::printf("\n%s\n", all_pass ? "Connector parity: PASS" : "Connector parity: FAIL"); + + // ---------- Padded variant: T_REAL < NUM_REGISTERS ---------- + // This section exercises the learnable-register concat path in + // LTX2ConnectorRunner::build_graph that the primary run above skips (there + // T=8 > NUM_REGISTERS=4). The reference is dumped by + // `CONNECTOR_VARIANT=padded dump_connector.py` with NUM_REGISTERS=8, + // SEQ_LEN=8 and a left-padded attention_mask making only the last 3 tokens + // real. Python runs the full pipeline (feature_extractor → replace_padded + // → connector); C++ feeds only the 3 real tokens (slide-to-front done in + // the conditioner on the production path) and the runner's concat-with- + // registers path must reconstruct the same 8-token sequence internally. + const std::string padded_dir = "/tmp/connector_ref_padded"; + std::ifstream padded_check(padded_dir + "/state_dict.safetensors"); + if (!padded_check.is_open()) { + std::printf("\n[padded] %s not found — skip. Run " + "`CONNECTOR_VARIANT=padded dump_connector.py` to enable.\n", + padded_dir.c_str()); + return all_pass ? 
0 : 3; + } + padded_check.close(); + + std::printf("\n=== LTX-2 Connector parity (padded: T_real=3 < num_reg=8) ===\n"); + + const int64_t PAD_T_REAL = 3; + const int64_t PAD_T_FULL = 8; + const int NUM_REGISTERS_PAD = 8; + + ModelLoader pad_loader; + if (!pad_loader.init_from_file(padded_dir + "/state_dict.safetensors")) { + std::fprintf(stderr, "fatal: padded init_from_file failed\n"); + return 1; + } + const auto& pad_tsm = pad_loader.get_tensor_storage_map(); + std::printf("[padded state_dict] loaded %zu tensors\n", pad_tsm.size()); + + LTXConnector::LTX2ConnectorRunner pad_runner( + backend, /*offload_params_to_cpu=*/false, + FLAT_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS, NUM_REGISTERS_PAD, + CAPTION_CHANNELS, CAPTION_HIDDEN, CAPTION_OUT, + THETA, MAX_POS, pad_tsm, /*prefix=*/""); + pad_runner.alloc_params_buffer(); + + std::map pad_params; + pad_runner.get_param_tensors(pad_params, ""); + if (!pad_loader.load_tensors(pad_params)) { + std::fprintf(stderr, "fatal: padded load_tensors failed\n"); + return 1; + } + + // Load the full padded stacked input (padded positions at the START), then + // slice to only the T_REAL real tokens at the tail — this is what the + // production conditioner passes to the connector runner after sliding the + // real rows to the front. + auto pad_stacked_full = load_raw_bin(padded_dir + "/tensors/stacked_in.bin", + {L, D, PAD_T_FULL, B}); + // Ref layout [B, T, D, L] → ggml ne [L, D, T, B]. Real tokens occupy + // indices [PAD_T_FULL - PAD_T_REAL .. PAD_T_FULL) along axis T (ne[2]). 
+ sd::Tensor pad_stacked_real({L, D, PAD_T_REAL, B}); + for (int64_t b = 0; b < B; ++b) { + for (int64_t t = 0; t < PAD_T_REAL; ++t) { + for (int64_t d = 0; d < D; ++d) { + for (int64_t l = 0; l < L; ++l) { + int64_t src = ((b * PAD_T_FULL + (PAD_T_FULL - PAD_T_REAL + t)) * D + d) * L + l; + int64_t dst = ((b * PAD_T_REAL + t) * D + d) * L + l; + pad_stacked_real.data()[dst] = pad_stacked_full.data()[src]; + } + } + } + } + + // CPU normalize the real-only stacked input (no padding). + std::vector pad_seq_lens(B, static_cast(PAD_T_REAL)); + sd::Tensor pad_normed({FLAT_DIM, PAD_T_REAL, B}); + LTXConnector::feature_extractor_normalize( + pad_stacked_real.data(), pad_seq_lens.data(), pad_normed.data(), + static_cast(B), static_cast(PAD_T_REAL), static_cast(D), static_cast(L), + "left", 1e-6f); + + // Connector should internally concat learnable_registers[T_real:num_reg] + // → output shape at the final stage is [D, num_reg, B]. + bool pad_pass = true; + const Probe pad_probes[] = { + // Feature-extractor output is just the T_REAL real tokens (shape + // [D, T_REAL, B]); Python's feat_ext_out covers T_FULL padded and we + // only check the real-token tail. + {3, "conn_final_out", {D, PAD_T_FULL, B}, 6e-3f, 2e-3f}, + {4, "caption_proj_out", {CAPTION_OUT, PAD_T_FULL, B}, 6e-3f, 2e-3f}, + }; + + for (const auto& p : pad_probes) { + auto out = pad_runner.compute(/*n_threads=*/1, pad_normed, p.stage); + auto ref = load_raw_bin(padded_dir + "/tensors/" + p.name + ".bin", p.shape); + if (out.numel() != ref.numel()) { + std::fprintf(stderr, "[padded %s] size mismatch got=%ld want=%ld\n", + p.name, out.numel(), ref.numel()); + return 1; + } + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + bool pass = s.max_abs < p.tol_max_abs && s.mean_abs < p.tol_mean_abs; + std::printf(" %-18s %.3e %.3e %.3e %s\n", + p.name, s.max_abs, s.mean_abs, s.max_rel, pass ? 
"PASS" : "FAIL"); + if (!pass && s.max_abs_idx >= 0) { + int64_t i = s.max_abs_idx; + std::printf(" max-diff @ idx=%ld: got=%+.6f want=%+.6f diff=%+.6f\n", + i, out.data()[i], ref.data()[i], out.data()[i] - ref.data()[i]); + } + pad_pass &= pass; + } + + std::printf("\n%s\n", pad_pass ? "Connector padded parity: PASS" : "Connector padded parity: FAIL"); + return (all_pass && pad_pass) ? 0 : 3; +} diff --git a/tests/ltx_parity/test_gemma_parity.cpp b/tests/ltx_parity/test_gemma_parity.cpp new file mode 100644 index 000000000..c43c2e567 --- /dev/null +++ b/tests/ltx_parity/test_gemma_parity.cpp @@ -0,0 +1,287 @@ +// Gemma 3 C++ parity test. +// +// Loads /tmp/gemma_ref/{state_dict.safetensors, tensors/*.bin} produced by +// tests/ltx_parity/dump_gemma.py, runs one LLMRunner forward pass on the same +// input_ids, and diffs each of the N+1 hidden states (embedding + per-layer + +// post-final-norm last) against the Python reference. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "llm.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + float max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / 
(std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { + s.max_abs = abs_err; + s.max_abs_idx = i; + } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? n : 1)); + return s; +} + +} // namespace + +int main() { + const std::string ref_dir = "/tmp/gemma_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + // --- 1. Load state dict. + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + // Skip convert_tensors_name(): the "text_encoder" prefix is remapped to + // "cond_stage_model.transformer." by the conversion table (see name_conversion.cpp:1112), + // which would break our direct-load parity test. We match key names exactly. + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // --- 2. Construct LLMRunner with GEMMA3 arch. Hyperparams auto-detect from tensor shapes. + ggml_backend_t backend = ggml_backend_cpu_init(); + LLM::LLMRunner runner(LLM::LLMArch::GEMMA3, backend, /*offload_params_to_cpu=*/false, + tsm, /*prefix=*/"text_encoder", /*enable_vision=*/false); + + const auto& p = runner.params; + std::printf("[config] layers=%ld hidden=%ld heads=%d kv_heads=%d head_dim=%d " + "ff=%ld vocab=%ld sw=%d pattern=%d\n", + p.num_layers, p.hidden_size, p.num_heads, p.num_kv_heads, p.head_dim, + p.intermediate_size, p.vocab_size, p.sliding_window, p.sliding_window_pattern); + + // --- 3. Load params buffer and weights. + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, "text_encoder"); + std::printf("[load] loading %zu param tensors…\n", param_tensors.size()); + + // Collect tsm keys for diffing. 
+ std::set tsm_keys; + for (const auto& kv : tsm) tsm_keys.insert(kv.first); + + // Dump any param_tensor keys not present in the file (for diagnosing name mismatches). + int missing_shown = 0; + for (const auto& pt : param_tensors) { + if (tsm_keys.find(pt.first) == tsm_keys.end()) { + if (missing_shown < 5) { + std::printf("[load] missing in file: %s\n", pt.first.c_str()); + missing_shown++; + } + } + } + + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed (some names unmatched?)\n"); + return 1; + } + + // --- 4. Load reference inputs. Dumper saved input_ids as f32 for simplicity. + const int64_t B = 1, T = 8, H = p.hidden_size; + auto input_ids_f32 = load_raw_bin(ref_dir + "/tensors/input_ids.bin", {T, B}); + sd::Tensor input_ids({T, B}); + for (int64_t i = 0; i < T; ++i) { + input_ids.data()[i] = (int32_t)input_ids_f32.data()[i]; + } + sd::Tensor empty_mask; + + std::printf("[input] input_ids: "); + for (int i = 0; i < T; i++) std::printf("%d ", input_ids.data()[i]); + std::printf("\n"); + + // Override the window size to match the tiny-config dump. + runner.params.sliding_window = 4; + // The tiny-config Python dump does NOT apply linear rope_scaling (that's only + // wired for the "deep" variant which mirrors the real 12B). Disable scaling + // on the C++ side so we compare apples-to-apples. + runner.params.rope_scaling_factor_global = 1.0f; + + // --- 5a. First test the basic forward path (returns just last_hidden_state after norm). + std::printf("[compute] basic forward (last_hidden_state only)…\n"); + std::fflush(stdout); + auto basic = runner.compute(/*n_threads=*/1, input_ids, empty_mask, {}, {}); + std::printf("[compute] basic forward done, numel=%ld first=%.4f\n", basic.numel(), basic.numel() > 0 ? basic.data()[0] : 0.f); + std::fflush(stdout); + + // --- 5b. Compute all N+1 hidden states. 
+ std::printf("[compute] running forward pass with all-hidden-states path…\n"); + std::fflush(stdout); + auto stacked = runner.compute_all_hidden_states(/*n_threads=*/1, input_ids, empty_mask); + std::printf("[compute] done, stacked shape=[%zu dims] numel=%ld\n", stacked.shape().size(), stacked.numel()); + std::fflush(stdout); + + // stacked has shape (sd::Tensor layout = innermost-first): {N+1, H, T, B}. + const int64_t N_plus_1 = p.num_layers + 1; + if (stacked.numel() != N_plus_1 * H * T * B) { + std::fprintf(stderr, "fatal: stacked numel mismatch got=%ld expected=%ld\n", + stacked.numel(), N_plus_1 * H * T * B); + return 1; + } + std::printf("[output] stacked shape=[%ld,%ld,%ld,%ld] numel=%ld\n", + N_plus_1, H, T, B, stacked.numel()); + + // --- 6. Slice each layer out of the stacked tensor and diff. + // Memory layout: innermost=N+1, so all layers for one (h, t, b) are adjacent. + // For a given layer_idx l: layer_data[b][t][h] = stacked[((b*T + t)*H + h)*(N+1) + l]. + // Ref layer is stored with innermost=H, shape [B, T, H]. So we reconstruct layer l by + // scattering. + // Tolerances reflect the fp16 cast inside ggml_ext_attention_ext (K/V go through + // GGML_TYPE_F16 before the softmax). Reference Python stays in fp32, so ~1e-3 abs + // drift per attention layer is baked in. For 6 stacked layers we budget ~6× that. + // max_rel is skipped — small reference values blow up relative error even when + // absolute agreement is fine. + const float tol_max_abs = 1e-2f; + const float tol_mean_abs = 2e-3f; + + bool all_pass = true; + std::printf("\n=== Gemma hidden-state parity ===\n"); + std::printf("%-18s %11s %11s %11s\n", "tag", "max_abs", "mean_abs", "max_rel"); + + std::vector layer_buf(B * T * H); + for (int l = 0; l < N_plus_1; l++) { + // Gather layer l from the stacked tensor. 
+ const float* src = stacked.data(); + for (int64_t b = 0; b < B; b++) { + for (int64_t t = 0; t < T; t++) { + for (int64_t h = 0; h < H; h++) { + int64_t stacked_idx = ((b * T + t) * H + h) * N_plus_1 + l; + int64_t ref_idx = (b * T + t) * H + h; + layer_buf[ref_idx] = src[stacked_idx]; + } + } + } + std::string tag = (l == 0) ? "hs_embed" : ("hs_" + std::string(l < 10 ? "0" : "") + std::to_string(l)); + auto ref = load_raw_bin(ref_dir + "/tensors/" + tag + ".bin", {H, T, B}); + auto s = diff_fp32(layer_buf.data(), ref.data(), (int64_t)layer_buf.size()); + bool pass = s.max_abs < tol_max_abs && s.mean_abs < tol_mean_abs; + std::printf(" %-16s %.3e %.3e %.3e %s\n", + tag.c_str(), s.max_abs, s.mean_abs, s.max_rel, pass ? "PASS" : "FAIL"); + all_pass &= pass; + } + std::printf("\n%s (tol: max_abs<%.1e mean_abs<%.1e)\n", + all_pass ? "Gemma parity: PASS" : "Gemma parity: FAIL", + tol_max_abs, tol_mean_abs); + + // --- Deep variant parity: 24 layers × 512 hidden, seq=32 with 16-wide sliding --- + // Mirrors the real Gemma 3 12B's sliding_window_pattern=6 (so ~every 6th layer does + // global attention) at scaled-down dims. Catches drift patterns that only appear + // across many layers / real hidden-size, without requiring the full 12B download. + bool deep_present = false; + for (const auto& k : tsm_keys) { + if (k.rfind("text_encoder_deep.", 0) == 0) { + deep_present = true; + break; + } + } + if (!deep_present) { + std::printf("\n[deep] no text_encoder_deep.* tensors found (run " + "`GEMMA_PARITY_VARIANT=deep dump_gemma.py` to enable); skipping\n"); + return all_pass ? 
0 : 3; + } + + std::printf("\n=== Gemma deep parity (24L × 512H, sliding=16, seq=32) ===\n"); + LLM::LLMRunner deep_runner(LLM::LLMArch::GEMMA3, backend, /*offload=*/false, + tsm, /*prefix=*/"text_encoder_deep", /*enable_vision=*/false); + const auto& dp = deep_runner.params; + std::printf("[deep config] layers=%ld hidden=%ld heads=%d kv_heads=%d head_dim=%d ff=%ld\n", + dp.num_layers, dp.hidden_size, dp.num_heads, dp.num_kv_heads, dp.head_dim, + dp.intermediate_size); + + deep_runner.alloc_params_buffer(); + std::map deep_params; + deep_runner.get_param_tensors(deep_params, "text_encoder_deep"); + if (!loader.load_tensors(deep_params)) { + std::fprintf(stderr, "fatal: deep load_tensors failed\n"); + return 1; + } + + const int64_t Td = 32; + const int64_t Hd = dp.hidden_size; + auto deep_input_ids_f32 = load_raw_bin(ref_dir + "/tensors/deep_input_ids.bin", {Td, 1}); + sd::Tensor deep_input_ids({Td, 1}); + for (int64_t i = 0; i < Td; ++i) deep_input_ids.data()[i] = (int32_t)deep_input_ids_f32.data()[i]; + sd::Tensor deep_empty_mask; + + // Override sliding window to match the deep variant's config (tiny: 4, deep: 16). 
+ deep_runner.params.sliding_window = 16; + + auto deep_stacked = deep_runner.compute_all_hidden_states(/*n_threads=*/1, + deep_input_ids, + deep_empty_mask); + const int64_t deep_N_plus_1 = dp.num_layers + 1; + GGML_ASSERT(deep_stacked.numel() == deep_N_plus_1 * Hd * Td * 1); + + std::printf("[deep output] stacked shape=[%ld,%ld,%ld,1] numel=%ld\n", + deep_N_plus_1, Hd, Td, deep_stacked.numel()); + + const float deep_tol_max_abs = 5e-2f; // 24 layers → ~4× baseline drift budget + const float deep_tol_mean_abs = 1e-2f; + bool deep_all_pass = true; + std::printf("%-22s %11s %11s %11s\n", "tag", "max_abs", "mean_abs", "max_rel"); + + std::vector deep_layer_buf(Td * Hd); + for (int l = 0; l < deep_N_plus_1; ++l) { + const float* src = deep_stacked.data(); + for (int64_t t = 0; t < Td; ++t) { + for (int64_t h = 0; h < Hd; ++h) { + int64_t stacked_idx = (t * Hd + h) * deep_N_plus_1 + l; + deep_layer_buf[t * Hd + h] = src[stacked_idx]; + } + } + std::string tag = (l == 0) ? "deep_hs_embed" : ("deep_hs_" + std::string(l < 10 ? "0" : "") + std::to_string(l)); + auto ref = load_raw_bin(ref_dir + "/tensors/" + tag + ".bin", {Hd, Td, 1}); + auto s = diff_fp32(deep_layer_buf.data(), ref.data(), (int64_t)deep_layer_buf.size()); + bool pass = s.max_abs < deep_tol_max_abs && s.mean_abs < deep_tol_mean_abs; + std::printf(" %-20s %.3e %.3e %.3e %s\n", + tag.c_str(), s.max_abs, s.mean_abs, s.max_rel, pass ? "PASS" : "FAIL"); + deep_all_pass &= pass; + } + std::printf("\n%s (tol: max_abs<%.1e mean_abs<%.1e)\n", + deep_all_pass ? "Gemma deep parity: PASS" : "Gemma deep parity: FAIL", + deep_tol_max_abs, deep_tol_mean_abs); + + return (all_pass && deep_all_pass) ? 0 : 3; +} diff --git a/tests/ltx_parity/test_gemma_tokenizer.cpp b/tests/ltx_parity/test_gemma_tokenizer.cpp new file mode 100644 index 000000000..6ce61b28f --- /dev/null +++ b/tests/ltx_parity/test_gemma_tokenizer.cpp @@ -0,0 +1,88 @@ +// Tokenizer parity test for Gemma 3. 
+// +// Encodes a handful of fixed strings with our GemmaTokenizer and compares to the +// token IDs produced by transformers' AutoTokenizer (google/gemma-3-12b-it). The +// expected IDs below are hard-coded from a Python reference run — if the HF vocab +// ever changes they must be regenerated. + +#include +#include +#include +#include + +#include "tokenizers/gemma_tokenizer.h" + +namespace { + +struct Case { + std::string input; + std::vector expected_no_bos; +}; + +bool run_case(GemmaTokenizer& tk, const Case& c, int idx) { + std::vector got = tk.encode(c.input); + bool ok = (got == c.expected_no_bos); + std::printf(" [%2d] ", idx); + if (ok) { + std::printf("PASS %zu tokens: ", got.size()); + } else { + std::printf("FAIL got=%zu exp=%zu ", got.size(), c.expected_no_bos.size()); + } + for (size_t i = 0; i < got.size() && i < 8; i++) std::printf("%d ", got[i]); + if (got.size() > 8) std::printf("..."); + std::printf("\n"); + if (!ok) { + std::printf(" input : %s\n", c.input.c_str()); + std::printf(" expected : "); + for (int x : c.expected_no_bos) std::printf("%d ", x); + std::printf("\n got : "); + for (int x : got) std::printf("%d ", x); + std::printf("\n"); + } + return ok; +} + +} // namespace + +int main(int argc, char** argv) { + const char* default_path = + "/home/ilintar/.cache/huggingface/hub/models--google--gemma-3-12b-it/" + "snapshots/96b6f1eccf38110c56df3a15bffe176da04bfd80/tokenizer.json"; + std::string path = (argc > 1) ? argv[1] : default_path; + + GemmaTokenizer tk; + std::printf("[load] %s\n", path.c_str()); + if (!tk.load_from_file(path)) { + std::fprintf(stderr, "fatal: could not load tokenizer\n"); + return 1; + } + std::printf("[load] vocab=%d bos=%d eos=%d pad=%d unk=%d\n", + tk.vocab_size(), tk.BOS_TOKEN_ID, tk.EOS_TOKEN_ID, tk.PAD_TOKEN_ID, tk.UNK_TOKEN_ID); + + // Ground truth from transformers.AutoTokenizer("google/gemma-3-12b-it") with + // add_special_tokens=False. 
+ std::vector cases = { + {"hello", {23391}}, + {"hello world", {23391, 1902}}, + {" a b", {138, 236746, 138, 236763}}, + {"naïve", {1789, 238527, 560}}, + {"你好", {144626}}, + {"→ a", {238183, 496}}, + {"The quick brown fox jumps over the lazy dog.", + {818, 3823, 8864, 37423, 38167, 1024, 506, 31770, 4799, 236761}}, + {"", {}}, + {" ", {236743}}, + {"\n\n\ttabs and\nnewlines", + {108, 255968, 39218, 532, 107, 208697}}, + {"mixed: ABCdef 123 !@# UNK char: \xe2\x80\x8b", + {63258, 236787, 21593, 2063, 236743, 236770, 236778, 236800, + 1717, 236940, 236865, 7866, 236855, 1577, 236787, 36504}}, + }; + + int pass = 0; + for (size_t i = 0; i < cases.size(); i++) { + if (run_case(tk, cases[i], (int)i)) pass++; + } + std::printf("\n%d / %zu cases passed.\n", pass, cases.size()); + return (pass == (int)cases.size()) ? 0 : 3; +} diff --git a/tests/ltx_parity/test_ltx2_vae_roundtrip.cpp b/tests/ltx_parity/test_ltx2_vae_roundtrip.cpp new file mode 100644 index 000000000..cc93cf112 --- /dev/null +++ b/tests/ltx_parity/test_ltx2_vae_roundtrip.cpp @@ -0,0 +1,221 @@ +// LTX-2 VAE encode→decode round-trip sanity check on a real 22B VAE checkpoint. +// +// Purpose: rule out whether the blocky output from the end-to-end LTX-2 pipeline +// is caused by a broken VAE decoder. Constructs a simple synthetic video +// (color gradient ramps), runs the real LTX-2 VAE through encode→decode, and +// reports the reconstruction MSE + dumps the first output frame's values. +// +// If MSE is small (<0.05 for bounded [-1,1] input), the VAE is sound and the +// structural issue must live upstream (DiT / conditioning). If MSE is high, +// the VAE itself is miscomputing and that explains the pipeline output. 
+// +// Usage: sd-ltx2-vae-roundtrip [WIDTH [HEIGHT [FRAMES]]] + +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" + +#include "ltxvae.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +void apply_pcs_duplication(String2TensorStorage& tsm) { + // Mirror stable-diffusion.cpp's LTX-2-specific duplication so the nested + // encoder.per_channel_statistics.* / decoder.per_channel_statistics.* paths + // that our VideoEncoder/Decoder blocks look up exist. + const std::string top_pre = "first_stage_model.per_channel_statistics."; + std::vector> to_copy; + for (const auto& kv : tsm) { + const std::string& k = kv.first; + if (k.rfind(top_pre, 0) == 0) { + to_copy.push_back({k, k.substr(top_pre.size())}); + } + } + size_t copied = 0; + for (auto& pair : to_copy) { + auto src_it = tsm.find(pair.first); + if (src_it == tsm.end()) continue; + for (const char* sub : {"encoder", "decoder"}) { + std::string dst = "first_stage_model." + std::string(sub) + + ".per_channel_statistics." + pair.second; + if (tsm.find(dst) != tsm.end()) continue; + TensorStorage dup = src_it->second; + dup.name = dst; + tsm[dst] = dup; + copied++; + } + } + std::printf("[pcs] duplicated %zu entries\n", copied); +} + +// Builds a 9-frame [W, H, T, 3] synthetic video in [-1, 1]. Produces a +// spatial gradient ramp in R/G/B channels so reconstruction is easy to eyeball. 
+sd::Tensor make_synthetic_video(int W, int H, int T) { + sd::Tensor v({W, H, T, 3}); + for (int t = 0; t < T; ++t) { + float tphase = static_cast(t) / std::max(T - 1, 1); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + float r = static_cast(w) / std::max(W - 1, 1) * 2.0f - 1.0f; + float g = static_cast(h) / std::max(H - 1, 1) * 2.0f - 1.0f; + float b = tphase * 2.0f - 1.0f; + v.index(w, h, t, 0) = r; + v.index(w, h, t, 1) = g; + v.index(w, h, t, 2) = b; + } + } + } + return v; +} + +struct DiffStats { + float max_abs = 0.f, mean_abs = 0.f, mse = 0.f; +}; +DiffStats diff_stats(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0, sum_sq = 0.0; + for (int64_t i = 0; i < n; ++i) { + float d = a[i] - b[i]; + float ad = std::fabs(d); + s.max_abs = std::max(s.max_abs, ad); + sum_abs += ad; + sum_sq += static_cast(d) * d; + } + s.mean_abs = static_cast(sum_abs / std::max(n, 1)); + s.mse = static_cast(sum_sq / std::max(n, 1)); + return s; +} + +} // namespace + +int main(int argc, char** argv) { + sd_set_log_callback( + [](enum sd_log_level_t /*level*/, const char* text, void* /*data*/) { + std::fputs(text, stderr); + }, + nullptr); + + const char* vae_path = (argc >= 2) + ? argv[1] + : "/media/ilintar/D_SSD/models/ltx-2/ltx-2.3-22b-dev_video_vae.safetensors"; + int W = (argc >= 3) ? std::atoi(argv[2]) : 128; + int H = (argc >= 4) ? std::atoi(argv[3]) : 128; + int T = (argc >= 5) ? std::atoi(argv[4]) : 9; + + std::printf("[cfg] vae_path = %s\n", vae_path); + std::printf("[cfg] input video = %dx%d, %d frames\n", W, H, T); + + ModelLoader loader; + // The raw 22B video_vae.safetensors ships tensors as `encoder.*`, `decoder.*`, + // `per_channel_statistics.*` with no top-level prefix. Passing prefix="vae." on + // init adds that so the subsequent convert_tensors_name() remaps `vae.` → + // `first_stage_model.` via name_conversion.cpp. 
+ if (!loader.init_from_file(vae_path, "vae.")) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", vae_path); + return 1; + } + loader.convert_tensors_name(); + auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors\n", tsm.size()); + apply_pcs_duplication(tsm); + + ggml_backend_t backend = ggml_backend_cpu_init(); + LTXVAE::LTX2VAERunner vae(backend, /*offload=*/false, tsm, + /*prefix=*/"first_stage_model", + VERSION_LTX2, + /*in_ch=*/3, /*latent_ch=*/128, /*patch=*/4, + /*decoder_base_ch=*/128, /*timestep_cond=*/false, + LTXVAE::LTX2VAERunner::ltx2_22b_enc_specs(), + LTXVAE::LTX2VAERunner::ltx2_22b_dec_specs()); + vae.alloc_params_buffer(); + std::map vae_params; + vae.get_param_tensors(vae_params, "first_stage_model"); + std::printf("[vae] requesting %zu param tensors\n", vae_params.size()); + if (!loader.load_tensors(vae_params)) { + std::fprintf(stderr, "fatal: vae load_tensors failed (weights unmatched?)\n"); + return 1; + } + + // Build synthetic [W, H, T, 3] video in [-1, 1]. + auto video = make_synthetic_video(W, H, T); + std::printf("[input] shape = [W=%d, H=%d, T=%d, C=3] min=%.3f max=%.3f mean=%.3f\n", + W, H, T, *std::min_element(video.data(), video.data() + video.numel()), + *std::max_element(video.data(), video.data() + video.numel()), + [&]() { + double s = 0; + for (int64_t i = 0; i < video.numel(); ++i) s += video.data()[i]; + return static_cast(s / video.numel()); + }()); + + // --- Encode --- + std::printf("[encode] running…\n"); + auto latent = vae._compute(/*n_threads=*/1, video, /*decode_graph=*/false); + std::printf("[encode] latent shape = ["); + for (size_t i = 0; i < latent.shape().size(); ++i) + std::printf("%s%lld", (i ? ", " : ""), (long long)latent.shape()[i]); + std::printf("] numel=%lld\n", (long long)latent.numel()); + if (latent.empty()) { + std::fprintf(stderr, "fatal: encode produced empty output\n"); + return 2; + } + + // Latent first 8 values for eyeballing. 
+ std::printf("[encode] first 8 latent values: "); + for (int i = 0; i < 8 && i < latent.numel(); ++i) + std::printf("%+.3f ", latent.data()[i]); + std::printf("\n"); + + // Encoder's output layout is [W_lat, H_lat, T_lat, C_lat]. The decoder's + // expected input is the same layout. + // --- Decode --- + std::printf("[decode] running…\n"); + auto recon = vae._compute(/*n_threads=*/1, latent, /*decode_graph=*/true); + std::printf("[decode] recon shape = ["); + for (size_t i = 0; i < recon.shape().size(); ++i) + std::printf("%s%lld", (i ? ", " : ""), (long long)recon.shape()[i]); + std::printf("] numel=%lld\n", (long long)recon.numel()); + if (recon.empty()) { + std::fprintf(stderr, "fatal: decode produced empty output\n"); + return 3; + } + + if (recon.numel() != video.numel()) { + std::fprintf(stderr, "fatal: recon numel %lld != input numel %lld " + "(enc/dec changed element count)\n", + (long long)recon.numel(), (long long)video.numel()); + return 4; + } + + std::printf("[decode] first 8 recon values: "); + for (int i = 0; i < 8 && i < recon.numel(); ++i) + std::printf("%+.3f ", recon.data()[i]); + std::printf("\n[input ] first 8 input values: "); + for (int i = 0; i < 8 && i < video.numel(); ++i) + std::printf("%+.3f ", video.data()[i]); + std::printf("\n"); + + // Diff. + auto s = diff_stats(recon.data(), video.data(), recon.numel()); + std::printf("\n=== round-trip diff ===\n"); + std::printf(" max_abs = %.3e\n", s.max_abs); + std::printf(" mean_abs = %.3e\n", s.mean_abs); + std::printf(" mse = %.3e\n", s.mse); + + // Loose pass thresholds. LTX-2 VAE is lossy but mean_abs <0.1 for a smooth + // gradient is a reasonable ceiling. Anything much worse means structural + // divergence, not just compression. + const float tol_mse = 0.05f; + bool pass = s.mse < tol_mse; + std::printf("%s (tol: mse < %.1e)\n", + pass ? "VAE round-trip: PASS" : "VAE round-trip: FAIL", + tol_mse); + return pass ? 
0 : 5; +} diff --git a/tests/ltx_parity/test_ltx_parity.cpp b/tests/ltx_parity/test_ltx_parity.cpp new file mode 100644 index 000000000..3bfe14920 --- /dev/null +++ b/tests/ltx_parity/test_ltx_parity.cpp @@ -0,0 +1,438 @@ +// LTX-2 C++ parity test. +// +// Loads the state dict + reference intermediate tensors dumped by +// tests/ltx_parity/dump_reference.py, runs one forward pass of LTXRunner on the same inputs, +// and diffs the output against the Python reference. +// +// Tolerances: F32 backend is expected to match to ~1e-4 abs / ~1e-3 rel. Larger drift points to +// a block-level bug — rerun with --intermediate to capture per-block outputs via the cache API. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "denoiser.hpp" +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ltx.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + float max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { + s.max_abs = abs_err; + s.max_abs_idx = i; + } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? 
n : 1)); + return s; +} + +void print_shape(const char* label, const std::vector& shape) { + std::printf("%s[", label); + for (size_t i = 0; i < shape.size(); ++i) { + std::printf("%s%ld", i == 0 ? "" : ", ", shape[i]); + } + std::printf("]\n"); +} + +} // namespace + +// Returns true if all schedule cases agree within absolute tolerance. +bool check_schedule(const std::string& ref_dir) { + struct Case { + int tokens; + int steps; + bool stretch; + const char* label; + }; + // Must match the cases in dump_reference.py::dump_scheduler. + const std::vector cases = { + {1024, 10, true, "tokens1024_steps10_stretch1"}, + {1024, 30, true, "tokens1024_steps30_stretch1"}, + {4096, 10, true, "tokens4096_steps10_stretch1"}, + {4096, 40, true, "tokens4096_steps40_stretch1"}, + {2560, 30, true, "tokens2560_steps30_stretch1"}, + {4096, 8, false, "tokens4096_steps8_stretch0"}, + }; + + bool all_pass = true; + std::printf("\n=== LTX2FlowDenoiser::get_sigmas parity ===\n"); + for (const auto& c : cases) { + LTX2FlowDenoiser denoiser; + denoiser.stretch = c.stretch; + auto cpp_sigmas = denoiser.get_sigmas(static_cast(c.steps), c.tokens, + DISCRETE_SCHEDULER, VERSION_LTX2); + auto ref = load_raw_bin(ref_dir + "/tensors/schedule__" + c.label + ".bin", + {static_cast(c.steps + 1)}); + if (static_cast(cpp_sigmas.size()) != ref.numel()) { + std::fprintf(stderr, "[sched %s] size mismatch cpp=%zu ref=%ld\n", + c.label, cpp_sigmas.size(), ref.numel()); + all_pass = false; + continue; + } + auto s = diff_fp32(cpp_sigmas.data(), ref.data(), ref.numel()); + // Schedules are small floats (≤1), 1e-5 abs tolerance is reasonable; mu/exp arithmetic + // differs negligibly between libm and Python math. + const float tol = 5e-5f; + bool pass = s.max_abs < tol; + std::printf(" %-32s max_abs=%.2e mean_abs=%.2e %s\n", + c.label, s.max_abs, s.mean_abs, pass ? 
"PASS" : "FAIL"); + if (!pass) { + std::printf(" cpp[0..3] = %.6f %.6f %.6f %.6f\n", cpp_sigmas[0], + cpp_sigmas[1], cpp_sigmas[2], cpp_sigmas[3]); + std::printf(" ref[0..3] = %.6f %.6f %.6f %.6f\n", ref.data()[0], + ref.data()[1], ref.data()[2], ref.data()[3]); + all_pass = false; + } + } + return all_pass; +} + +// Runs one Euler step in C++ using LTX2FlowDenoiser's scheduler values + a DiT velocity output, +// then diffs against the Python reference. +bool check_euler_step(const std::string& ref_dir, LTX::LTXRunner& runner, + const sd::Tensor& latent, const sd::Tensor& context) { + auto sigma_cur_ref = load_raw_bin(ref_dir + "/tensors/euler__sigma_cur.bin", {1}); + auto sigma_next_ref = load_raw_bin(ref_dir + "/tensors/euler__sigma_next.bin", {1}); + auto v_ref = load_raw_bin(ref_dir + "/tensors/euler__v_step_unflat.bin", {6, 4, 2, 16}); + auto x_next_ref = load_raw_bin(ref_dir + "/tensors/euler__x_next_unflat.bin", {6, 4, 2, 16}); + + const float sigma_cur = sigma_cur_ref.data()[0]; + const float sigma_next = sigma_next_ref.data()[0]; + + std::printf("\n=== Euler step parity (σ=%.4f → %.4f) ===\n", sigma_cur, sigma_next); + + // Run the DiT at σ_cur (pre-scaled by 1000 for AdaLN, via LTX2FlowDenoiser::sigma_to_t). + LTX2FlowDenoiser denoiser; + sd::Tensor t_in({1}); + t_in.data()[0] = denoiser.sigma_to_t(sigma_cur); + + sd::Tensor empty_mask; + auto v_cpp = runner.compute(/*n_threads=*/1, latent, t_in, context, empty_mask); + if (v_cpp.numel() != v_ref.numel()) { + std::fprintf(stderr, "fatal: velocity size mismatch\n"); + return false; + } + + // Compute x_next = latent + (σ_next - σ) * v (element-wise). 
+ sd::Tensor x_next_cpp(latent.shape()); + const float dt = sigma_next - sigma_cur; + for (int64_t i = 0; i < latent.numel(); ++i) { + x_next_cpp.data()[i] = latent.data()[i] + dt * v_cpp.data()[i]; + } + + auto sv = diff_fp32(v_cpp.data(), v_ref.data(), v_cpp.numel()); + auto sx = diff_fp32(x_next_cpp.data(), x_next_ref.data(), x_next_cpp.numel()); + + std::printf(" velocity@σ_cur: max_abs=%.2e mean_abs=%.2e max_rel=%.2e\n", + sv.max_abs, sv.mean_abs, sv.max_rel); + std::printf(" x_next: max_abs=%.2e mean_abs=%.2e max_rel=%.2e\n", + sx.max_abs, sx.mean_abs, sx.max_rel); + + // x_next is (latent + dt * v). dt is ~0.09, v drift ~1e-4 → x_next drift ~1e-5. Tolerances + // are roughly the same as the base DiT test since the Euler step doesn't amplify. + const float tol_abs = 1e-3f; + const float tol_rel = 5e-2f; + return sv.max_abs < tol_abs && sv.max_rel < tol_rel && + sx.max_abs < tol_abs && sx.max_rel < tol_rel; +} + +int main() { + const std::string ref_dir = "/tmp/ltx_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + // --- 1. Load the reference state dict. Weights are dumped with prefix "model.diffusion_model." + // which matches sd.cpp's default DiT location, so init_from_file with empty prefix passes names through. + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + loader.convert_tensors_name(); // no-op for LTX-2 — names already match + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // --- 2. Construct LTXRunner on CPU with explicit tiny-model params + // (the real LTX-2 hyperparams num_heads=32/head_dim=128 are auto-detected from weight shapes, + // but the tiny test uses num_heads=4/head_dim=32 which can't be inferred from q_norm alone). 
+ LTX::LTXParams tiny_params; + tiny_params.in_channels = 16; + tiny_params.out_channels = 16; + tiny_params.inner_dim = 128; + tiny_params.num_heads = 4; + tiny_params.head_dim = 32; + tiny_params.num_layers = 2; + tiny_params.cross_attention_dim = 128; + tiny_params.cross_attention_adaln = false; + tiny_params.apply_gated_attention = false; + + ggml_backend_t backend = ggml_backend_cpu_init(); + LTX::LTXRunner runner(backend, /*offload_params_to_cpu=*/false, tsm, + "model.diffusion_model", VERSION_LTX2, &tiny_params); + runner.set_fps(24.0f); + // Parity dump uses simplified (f, h, w) positions without VAE scale factors or + // causal_fix — mirror that here so positions match the Python reference. + runner.set_scale_factors(1, 1, 1); + runner.set_causal_fix(false); + + const auto& p = runner.ltx_params; + std::printf("[config] layers=%d inner=%ld heads=%d head_dim=%d " + "in=%ld out=%ld ca_dim=%ld\n", + p.num_layers, p.inner_dim, p.num_heads, p.head_dim, + p.in_channels, p.out_channels, p.cross_attention_dim); + + // --- 3. Allocate & load weights into the GGML graph. + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, "model.diffusion_model"); + std::printf("[load] loading %zu param tensors…\n", param_tensors.size()); + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed (some names unmatched?)\n"); + return 1; + } + + // --- 4. Load reference inputs. + // latent_unflat is dumped as [C=16, F=2, H=4, W=6] (C outermost, W innermost in memory). + // LTXRunner::build_graph expects ggml ne=[W, H, T=F, C], so sd::Tensor shape is {6, 4, 2, 16} + // (sd shape[0] = innermost dim). Raw memory layout is identical. 
+ auto latent = load_raw_bin(ref_dir + "/tensors/model__latent_unflat.bin", {6, 4, 2, 16}); + auto sigma_in = load_raw_bin(ref_dir + "/tensors/model__sigma.bin", {1}); + + // The C++ AdaLN now expects pre-scaled σ (see src/ltx.hpp:AdaLayerNormSingle docstring); + // the denoiser's sigma_to_t(σ)=σ*1000 will own this scaling in production. For the test + // we do it inline. + sd::Tensor timesteps({1}); + timesteps.data()[0] = sigma_in.data()[0] * 1000.0f; + + // context: Python shape [B=1, S=8, D=128] → ggml ne [128, 8, 1] → sd::Tensor shape {128, 8, 1}. + auto context = load_raw_bin(ref_dir + "/tensors/model__context_in.bin", {128, 8, 1}); + + sd::Tensor empty_mask; + + std::printf("[input] "); + print_shape("latent=", latent.shape()); + std::printf("[input] σ = %.6f → t = %.3f\n", sigma_in.data()[0], timesteps.data()[0]); + + // --- 5. Run forward. + std::printf("[compute] running single forward pass…\n"); + auto out = runner.compute(/*n_threads=*/1, latent, timesteps, context, empty_mask); + + print_shape("[output] out.shape = ", out.shape()); + + // Dump first & last few values to catch silent NaN / zeros before diffing. + std::printf("[output] first 8: "); + for (int i = 0; i < 8 && i < out.numel(); ++i) std::printf("%+.4f ", out.data()[i]); + std::printf("\n"); + std::printf("[output] last 8: "); + for (int64_t i = std::max(0, out.numel() - 8); i < out.numel(); ++i) std::printf("%+.4f ", out.data()[i]); + std::printf("\n"); + + // --- 6. Diff vs reference. 
+ auto ref = load_raw_bin(ref_dir + "/tensors/model__velocity_out_unflat.bin", {6, 4, 2, 16}); + if (out.numel() != ref.numel()) { + std::fprintf(stderr, "fatal: element count mismatch cpp=%ld ref=%ld\n", + out.numel(), ref.numel()); + return 1; + } + std::printf("[ref] first 8: "); + for (int i = 0; i < 8 && i < ref.numel(); ++i) std::printf("%+.4f ", ref.data()[i]); + std::printf("\n"); + + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + std::printf("\n=== velocity_out parity ===\n"); + std::printf(" max_abs = %.3e (at index %ld: cpp=%.6f ref=%.6f)\n", + s.max_abs, s.max_abs_idx, + s.max_abs_idx >= 0 ? out.data()[s.max_abs_idx] : 0.f, + s.max_abs_idx >= 0 ? ref.data()[s.max_abs_idx] : 0.f); + std::printf(" mean_abs = %.3e\n", s.mean_abs); + std::printf(" max_rel = %.3e\n", s.max_rel); + std::printf(" n = %ld\n\n", out.numel()); + + // FP32 tolerances realistic for multi-layer DiT: accumulation order (ggml's mat-mul vs + // torch.matmul), softmax + rope + rms_norm order-of-ops, and bf16 casts in flash-attn paths + // all add ~1e-4 abs / ~1e-2 rel drift per block. Mean_abs is the more stable indicator. + // + // max_rel is only meaningful when every |ref[i]| is comfortably above the expected noise + // floor. The V1 reference happens to contain a single element with |ref| ≈ 4e-5, so a + // 1e-5 abs drift (far below our max_abs tolerance) alone pushes max_rel to ~0.3. Skip + // the max_rel check here for the same reason V2-deep does — abs/mean catch real drift + // and the near-zero rel spike is noise. + const float tol_max_abs = 1e-3f; + const float tol_mean_abs = 2e-4f; + bool pass_dit = s.max_abs < tol_max_abs && s.mean_abs < tol_mean_abs; + std::printf("%s (tol: max_abs<%.1e mean_abs<%.1e; max_rel ignored due to near-zero divisors)\n", + pass_dit ? "DiT parity: PASS" : "DiT parity: FAIL", + tol_max_abs, tol_mean_abs); + + bool pass_sched = check_schedule(ref_dir); + std::printf("%s\n", pass_sched ? 
"Scheduler parity: PASS" : "Scheduler parity: FAIL"); + + bool pass_euler = check_euler_step(ref_dir, runner, latent, context); + std::printf("%s\n", pass_euler ? "Euler step parity: PASS" : "Euler step parity: FAIL"); + + // --- V2 parity (cross_attention_adaln=true + apply_gated_attention=true) ----------------- + // The V1 check above validates the base path with both V2 features disabled. The production + // 22B checkpoint uses both. This block reloads the same state_dict with a V2-flagged runner + // and compares against Python's `v2model/velocity_out_unflat` dump. + std::printf("\n=== V2 parity (cross_attention_adaln + apply_gated_attention) ===\n"); + LTX::LTXParams v2_params; + v2_params.in_channels = 16; + v2_params.out_channels = 16; + v2_params.inner_dim = 128; + v2_params.num_heads = 4; + v2_params.head_dim = 32; + v2_params.num_layers = 2; + v2_params.cross_attention_dim = 128; + v2_params.cross_attention_adaln = true; + v2_params.apply_gated_attention = true; + + LTX::LTXRunner v2_runner(backend, /*offload_params_to_cpu=*/false, tsm, + "model.diffusion_model_v2", VERSION_LTX2, &v2_params); + v2_runner.set_fps(24.0f); + v2_runner.set_scale_factors(1, 1, 1); + v2_runner.set_causal_fix(false); + v2_runner.alloc_params_buffer(); + + std::map v2_param_tensors; + v2_runner.get_param_tensors(v2_param_tensors, "model.diffusion_model_v2"); + std::printf("[v2] loading %zu param tensors under model.diffusion_model_v2\n", v2_param_tensors.size()); + if (!loader.load_tensors(v2_param_tensors)) { + std::fprintf(stderr, "fatal: V2 load_tensors failed\n"); + return 1; + } + + auto v2_latent = load_raw_bin(ref_dir + "/tensors/v2model__latent_unflat.bin", {6, 4, 2, 16}); + auto v2_sigma = load_raw_bin(ref_dir + "/tensors/v2model__sigma.bin", {1}); + sd::Tensor v2_timesteps({1}); + v2_timesteps.data()[0] = v2_sigma.data()[0] * 1000.0f; + auto v2_context = load_raw_bin(ref_dir + "/tensors/v2model__context_in.bin", {128, 8, 1}); + sd::Tensor v2_empty_mask; + + auto v2_out 
= v2_runner.compute(/*n_threads=*/1, v2_latent, v2_timesteps, v2_context, v2_empty_mask); + auto v2_ref = load_raw_bin(ref_dir + "/tensors/v2model__velocity_out_unflat.bin", {6, 4, 2, 16}); + + std::printf("[v2 output] first 8: "); + for (int i = 0; i < 8 && i < v2_out.numel(); ++i) std::printf("%+.4f ", v2_out.data()[i]); + std::printf("\n[v2 ref] first 8: "); + for (int i = 0; i < 8 && i < v2_ref.numel(); ++i) std::printf("%+.4f ", v2_ref.data()[i]); + std::printf("\n"); + + auto sv2 = diff_fp32(v2_out.data(), v2_ref.data(), v2_out.numel()); + std::printf(" max_abs = %.3e (at index %ld: cpp=%.6f ref=%.6f)\n", + sv2.max_abs, sv2.max_abs_idx, + sv2.max_abs_idx >= 0 ? v2_out.data()[sv2.max_abs_idx] : 0.f, + sv2.max_abs_idx >= 0 ? v2_ref.data()[sv2.max_abs_idx] : 0.f); + std::printf(" mean_abs = %.3e\n", sv2.mean_abs); + std::printf(" max_rel = %.3e\n", sv2.max_rel); + + // Same max_rel skip as the V1 block above: the reference can contain a handful of + // near-zero elements whose tiny abs drift blows the relative error up without being + // a real parity regression. abs/mean catch actual drift. + bool pass_v2 = sv2.max_abs < tol_max_abs && sv2.mean_abs < tol_mean_abs; + std::printf("%s (tol: max_abs<%.1e mean_abs<%.1e; max_rel ignored due to near-zero divisors)\n", + pass_v2 ? "V2 DiT parity: PASS" : "V2 DiT parity: FAIL", + tol_max_abs, tol_mean_abs); + + // --- V2-deep parity: 8 layers + non-zero scale_shift_table ------------------------------- + // The V2 check above uses 2 layers with zeroed sst, so modulation is effectively identity + // and can hide sign/broadcast bugs in the (1+scale) and shift_kv/scale_kv branches. This + // block loads an 8-layer variant with randomised sst weights so any cross-layer drift in + // the V2 path surfaces. 
+ std::printf("\n=== V2-deep parity (8 layers + non-zero scale_shift_table) ===\n"); + LTX::LTXParams v2_deep_params = v2_params; + v2_deep_params.num_layers = 8; + + LTX::LTXRunner v2_deep_runner(backend, /*offload_params_to_cpu=*/false, tsm, + "model.diffusion_model_v2_deep", VERSION_LTX2, &v2_deep_params); + v2_deep_runner.set_fps(24.0f); + v2_deep_runner.set_scale_factors(1, 1, 1); + v2_deep_runner.set_causal_fix(false); + v2_deep_runner.alloc_params_buffer(); + + std::map v2_deep_param_tensors; + v2_deep_runner.get_param_tensors(v2_deep_param_tensors, "model.diffusion_model_v2_deep"); + std::printf("[v2-deep] loading %zu param tensors\n", v2_deep_param_tensors.size()); + if (!loader.load_tensors(v2_deep_param_tensors)) { + std::fprintf(stderr, "fatal: V2-deep load_tensors failed\n"); + return 1; + } + + auto v2d_latent = load_raw_bin(ref_dir + "/tensors/v2deep__latent_unflat.bin", {6, 4, 2, 16}); + auto v2d_sigma = load_raw_bin(ref_dir + "/tensors/v2deep__sigma.bin", {1}); + sd::Tensor v2d_timesteps({1}); + v2d_timesteps.data()[0] = v2d_sigma.data()[0] * 1000.0f; + auto v2d_context = load_raw_bin(ref_dir + "/tensors/v2deep__context_in.bin", {128, 8, 1}); + sd::Tensor v2d_empty_mask; + + auto v2d_out = v2_deep_runner.compute(/*n_threads=*/1, v2d_latent, v2d_timesteps, v2d_context, v2d_empty_mask); + auto v2d_ref = load_raw_bin(ref_dir + "/tensors/v2deep__velocity_out_unflat.bin", {6, 4, 2, 16}); + + std::printf("[v2-deep output] first 8: "); + for (int i = 0; i < 8 && i < v2d_out.numel(); ++i) std::printf("%+.4f ", v2d_out.data()[i]); + std::printf("\n[v2-deep ref] first 8: "); + for (int i = 0; i < 8 && i < v2d_ref.numel(); ++i) std::printf("%+.4f ", v2d_ref.data()[i]); + std::printf("\n"); + + auto sv2d = diff_fp32(v2d_out.data(), v2d_ref.data(), v2d_out.numel()); + std::printf(" max_abs = %.3e (at index %ld: cpp=%.6f ref=%.6f)\n", + sv2d.max_abs, sv2d.max_abs_idx, + sv2d.max_abs_idx >= 0 ? v2d_out.data()[sv2d.max_abs_idx] : 0.f, + sv2d.max_abs_idx >= 0 ? 
v2d_ref.data()[sv2d.max_abs_idx] : 0.f); + std::printf(" mean_abs = %.3e\n", sv2d.mean_abs); + std::printf(" max_rel = %.3e\n", sv2d.max_rel); + + // Tolerance: max_rel is dropped here because per-element rel_err with b_i in the 1e-4 + // range produces meaningless blow-ups (100% rel for 1e-4 abs). max_abs and mean_abs are + // the reliable signals — both on the order of the 2-layer V2 test confirms no accumulated + // drift across 8 layers × non-zero sst modulation. + const float tol_max_abs_deep = 5e-3f; + const float tol_mean_abs_deep = 1e-3f; + bool pass_v2_deep = sv2d.max_abs < tol_max_abs_deep && sv2d.mean_abs < tol_mean_abs_deep; + std::printf("%s (tol: max_abs<%.1e mean_abs<%.1e; max_rel ignored due to near-zero divisors)\n", + pass_v2_deep ? "V2-deep DiT parity: PASS" : "V2-deep DiT parity: FAIL", + tol_max_abs_deep, tol_mean_abs_deep); + + bool pass = pass_dit && pass_sched && pass_euler && pass_v2 && pass_v2_deep; + std::printf("\n%s\n", pass ? "ALL PARITY: PASS" : "ALL PARITY: FAIL"); + return pass ? 0 : 3; +} diff --git a/tests/ltx_parity/test_s2d_primitives.cpp b/tests/ltx_parity/test_s2d_primitives.cpp new file mode 100644 index 000000000..a2758ffd1 --- /dev/null +++ b/tests/ltx_parity/test_s2d_primitives.cpp @@ -0,0 +1,185 @@ +// Standalone test: verify our axis-W / axis-H / axis-T SpaceToDepth and +// DepthToSpace ggml recipes against Python einops `rearrange(...)` outputs +// dumped by dump_s2d.py. Composition tests cover the stride patterns used +// by the VAE: (2,2,2), (1,2,2), (2,1,1). 
+ +#include +#include +#include +#include +#include +#include + +#include "ggml-cpu.h" +#include "ggml.h" +#include "ltxvae_primitives.hpp" + +namespace { + +constexpr int B = 1; +constexpr int C = 3; +constexpr int T = 4; +constexpr int H = 6; +constexpr int W = 8; +constexpr int FACTOR = 2; + +std::vector load_bin(const std::string& path, size_t expected_numel) { + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { std::fprintf(stderr, "cannot open %s\n", path.c_str()); std::exit(2); } + std::vector buf(expected_numel); + f.read(reinterpret_cast(buf.data()), expected_numel * sizeof(float)); + if (!f.good()) { std::fprintf(stderr, "short read on %s\n", path.c_str()); std::exit(2); } + return buf; +} + +enum Kind { S2D_W, S2D_H, S2D_T, S2D_222, S2D_122, S2D_211, + D2S_W, D2S_H, D2S_T, D2S_222, D2S_122, D2S_211, + PIXEL_NORM, PCS_NORMALIZE, PCS_UNNORMALIZE }; + +struct CaseSpec { + const char* name; + std::vector in_shape_ne; + std::vector expected_shape_ne; + Kind kind; +}; + +bool run_case(const CaseSpec& cs, const std::string& ref_dir) { + size_t in_numel = 1, out_numel = 1; + for (auto d : cs.in_shape_ne) in_numel *= d; + for (auto d : cs.expected_shape_ne) out_numel *= d; + + const bool is_d2s = (cs.kind >= D2S_W && cs.kind <= D2S_211); + const bool is_pn = (cs.kind == PIXEL_NORM); + const bool is_pcs = (cs.kind == PCS_NORMALIZE || cs.kind == PCS_UNNORMALIZE); + std::string in_file, exp_file; + if (is_pn) { + in_file = ref_dir + "/tensors/pn_input.bin"; + exp_file = ref_dir + "/tensors/pn_expected.bin"; + } else if (is_pcs) { + in_file = ref_dir + "/tensors/pcs_input.bin"; + exp_file = ref_dir + "/tensors/" + + std::string(cs.kind == PCS_NORMALIZE ? "pcs_normalize_expected.bin" + : "pcs_unnormalize_expected.bin"); + } else { + in_file = ref_dir + "/tensors/" + (is_d2s ? "dinput_" : "input_") + cs.name + ".bin"; + exp_file = ref_dir + "/tensors/" + (is_d2s ? 
"dexpected_" : "expected_") + cs.name + ".bin"; + } + auto in_data = load_bin(in_file, in_numel); + auto expected = load_bin(exp_file, out_numel); + + size_t mem_size = 128 * 1024 * 1024; + std::vector mem_buf(mem_size); + ggml_init_params params{mem_size, mem_buf.data(), false}; + ggml_context* ctx = ggml_init(params); + + ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, + cs.in_shape_ne[0], cs.in_shape_ne[1], + cs.in_shape_ne[2], cs.in_shape_ne[3]); + std::memcpy(x->data, in_data.data(), in_numel * sizeof(float)); + + ggml_tensor* y = nullptr; + ggml_tensor* mu_t = nullptr; + ggml_tensor* sigma_t = nullptr; + std::vector mu_data, sigma_data; + switch (cs.kind) { + case S2D_W: y = LTXVAE::space_to_depth_axisW(ctx, x, FACTOR); break; + case S2D_H: y = LTXVAE::space_to_depth_axisH(ctx, x, FACTOR); break; + case S2D_T: y = LTXVAE::space_to_depth_axisT(ctx, x, FACTOR); break; + case S2D_222: y = LTXVAE::space_to_depth(ctx, x, FACTOR, FACTOR, FACTOR); break; + case S2D_122: y = LTXVAE::space_to_depth(ctx, x, 1, FACTOR, FACTOR); break; + case S2D_211: y = LTXVAE::space_to_depth(ctx, x, FACTOR, 1, 1); break; + case D2S_W: y = LTXVAE::depth_to_space_axisW(ctx, x, FACTOR); break; + case D2S_H: y = LTXVAE::depth_to_space_axisH(ctx, x, FACTOR); break; + case D2S_T: y = LTXVAE::depth_to_space_axisT(ctx, x, FACTOR); break; + case D2S_222: y = LTXVAE::depth_to_space(ctx, x, FACTOR, FACTOR, FACTOR); break; + case D2S_122: y = LTXVAE::depth_to_space(ctx, x, 1, FACTOR, FACTOR); break; + case D2S_211: y = LTXVAE::depth_to_space(ctx, x, FACTOR, 1, 1); break; + case PIXEL_NORM: y = LTXVAE::pixel_norm(ctx, x, 1e-8f); break; + case PCS_NORMALIZE: + case PCS_UNNORMALIZE: { + int64_t C = cs.in_shape_ne[3]; + mu_data = load_bin(ref_dir + "/tensors/pcs_mu.bin", (size_t)C); + sigma_data = load_bin(ref_dir + "/tensors/pcs_sigma.bin", (size_t)C); + mu_t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, C); + sigma_t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, C); + std::memcpy(mu_t->data, 
mu_data.data(), C * sizeof(float)); + std::memcpy(sigma_t->data, sigma_data.data(), C * sizeof(float)); + y = (cs.kind == PCS_NORMALIZE) + ? LTXVAE::pcs_normalize(ctx, x, mu_t, sigma_t) + : LTXVAE::pcs_unnormalize(ctx, x, mu_t, sigma_t); + } break; + } + + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, y); + ggml_graph_compute_with_ctx(ctx, gf, 1); + + bool shape_ok = true; + for (int i = 0; i < 4; i++) if (y->ne[i] != cs.expected_shape_ne[i]) { shape_ok = false; break; } + if (!shape_ok) { + std::printf(" %-18s SHAPE_FAIL got=[%lld,%lld,%lld,%lld] exp=[%lld,%lld,%lld,%lld]\n", + cs.name, + (long long)y->ne[0], (long long)y->ne[1], (long long)y->ne[2], (long long)y->ne[3], + (long long)cs.expected_shape_ne[0], (long long)cs.expected_shape_ne[1], + (long long)cs.expected_shape_ne[2], (long long)cs.expected_shape_ne[3]); + ggml_free(ctx); + return false; + } + + const float* got = (const float*)y->data; + float max_abs = 0.f; + int64_t first_diff = -1; + for (size_t i = 0; i < out_numel; i++) { + float d = std::abs(got[i] - expected[i]); + if (d > max_abs) { max_abs = d; if (first_diff < 0) first_diff = (int64_t)i; } + } + // PixelNorm / PCS involve f32 divides & rms; relax tolerance slightly. + float tol = (cs.kind >= PIXEL_NORM) ? 5e-6f : 1e-6f; + bool pass = max_abs < tol; + std::printf(" %-18s %s max_abs=%.3e", cs.name, pass ? 
"PASS" : "FAIL", max_abs); + if (!pass && first_diff >= 0) { + std::printf(" first_diff_idx=%lld got=%.6f exp=%.6f", + (long long)first_diff, got[first_diff], expected[first_diff]); + } + std::printf("\n"); + + ggml_free(ctx); + return pass; +} + +} // namespace + +int main() { + const std::string ref_dir = "/tmp/s2d_ref"; + + std::vector cases = { + // SpaceToDepth + {"axisW", {W*FACTOR, H, T, C}, {W, H, T, C*FACTOR}, S2D_W}, + {"axisH", {W, H*FACTOR, T, C}, {W, H, T, C*FACTOR}, S2D_H}, + {"axisT", {W, H, T*FACTOR, C}, {W, H, T, C*FACTOR}, S2D_T}, + {"full222", {W*FACTOR, H*FACTOR, T*FACTOR, C}, {W, H, T, C*8}, S2D_222}, + {"full122", {W*FACTOR, H*FACTOR, T, C}, {W, H, T, C*4}, S2D_122}, + {"full211", {W, H, T*FACTOR, C}, {W, H, T, C*2}, S2D_211}, + // DepthToSpace (input has extra channels) + {"axisW", {W, H, T, C*FACTOR}, {W*FACTOR, H, T, C}, D2S_W}, + {"axisH", {W, H, T, C*FACTOR}, {W, H*FACTOR, T, C}, D2S_H}, + {"axisT", {W, H, T, C*FACTOR}, {W, H, T*FACTOR, C}, D2S_T}, + {"full222", {W, H, T, C*8}, {W*FACTOR, H*FACTOR, T*FACTOR, C}, D2S_222}, + {"full122", {W, H, T, C*4}, {W*FACTOR, H*FACTOR, T, C}, D2S_122}, + {"full211", {W, H, T, C*2}, {W, H, T*FACTOR, C}, D2S_211}, + // PixelNorm (dim=channel) and PerChannelStatistics + {"pn", {W, H, T, 5}, {W, H, T, 5}, PIXEL_NORM}, + {"pcs_norm", {W, H, T, 6}, {W, H, T, 6}, PCS_NORMALIZE}, + {"pcs_unnorm", {W, H, T, 6}, {W, H, T, 6}, PCS_UNNORMALIZE}, + }; + + std::printf("SpaceToDepth primitive parity:\n"); + int pass = 0; + for (size_t i = 0; i < cases.size(); i++) { + if (i == 6) std::printf("\nDepthToSpace primitive parity:\n"); + if (i == 12) std::printf("\nNorm primitives parity:\n"); + if (run_case(cases[i], ref_dir)) pass++; + } + std::printf("\n%d / %zu cases passed.\n", pass, cases.size()); + return (pass == (int)cases.size()) ? 
0 : 3; +} diff --git a/tests/ltx_parity/test_vae_parity.cpp b/tests/ltx_parity/test_vae_parity.cpp new file mode 100644 index 000000000..f4005683d --- /dev/null +++ b/tests/ltx_parity/test_vae_parity.cpp @@ -0,0 +1,378 @@ +// LTX-2 VAE C++ parity test. +// +// Loads /tmp/vae_ref/state_dict.safetensors (from dump_vae.py) plus the per-stage +// reference trace tensors, runs our C++ VideoEncoder + VideoDecoder, and diffs +// each stage against the Python reference. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ltxvae.hpp" + +// Standalone GGMLRunner that wraps a single LTXVAE::TimestepEmbedder block so we can +// isolate the sinusoidal + 2-linear path from the full VAE pipeline. +struct TERunner : public GGMLRunner { + LTXVAE::TimestepEmbedder te; + + TERunner(ggml_backend_t backend, bool offload, const String2TensorStorage& tsm, + const std::string& prefix, int embedding_dim) + : GGMLRunner(backend, offload), te(embedding_dim) { + te.init(params_ctx, tsm, prefix); + } + std::string get_desc() override { return "ltx2_vae_te_probe"; } + void get_param_tensors(std::map& tensors, const std::string& prefix) { + te.get_param_tensors(tensors, prefix); + } + sd::Tensor compute(int n_threads, const sd::Tensor& timestep) { + auto get_g = [&]() -> ggml_cgraph* { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* t = make_input(timestep); + auto runner_ctx = get_context(); + auto out = te.forward(&runner_ctx, t); + ggml_build_forward_expand(gf, out); + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + } +}; + +// Runs JUST the ada-values reshape+slice on a pre-computed time_embed. Returns one of +// the 4 slices (chosen by `which` in 0..3 → shift1, scale1, shift2, scale2). 
This +// isolates the PyTorch `timestep.reshape(B, 4, -1, 1, 1, 1)` → unbind(dim=1) path +// in pure GGML ops to verify memory-order correctness. +struct ShiftProbeRunner : public GGMLRunner { + int in_channels; + int which; + ShiftProbeRunner(ggml_backend_t backend, bool offload, int in_ch, int which) + : GGMLRunner(backend, offload), in_channels(in_ch), which(which) {} + std::string get_desc() override { return "ltx2_vae_shift_probe"; } + sd::Tensor compute(int n_threads, const sd::Tensor& time_embed) { + auto get_g = [&]() -> ggml_cgraph* { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* te = make_input(time_embed); // ne=[4*C, 1] + auto re = ggml_reshape_2d(compute_ctx, te, in_channels, 4); // [C, 4] + auto out = ggml_ext_slice(compute_ctx, re, 1, which, which + 1); // [C, 1] + out = ggml_cont(compute_ctx, out); + ggml_build_forward_expand(gf, out); + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + } +}; +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f, mean_abs = 0.f, max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { s.max_abs = abs_err; s.max_abs_idx = i; } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += 
abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? n : 1)); + return s; +} + +bool compare(const std::string& tag, const sd::Tensor& got, + const std::string& ref_path, const std::vector& ref_shape, + float tol_max_abs, float tol_mean_abs) { + auto ref = load_raw_bin(ref_path, ref_shape); + if (got.numel() != ref.numel()) { + std::printf(" %-20s SHAPE_FAIL got_numel=%ld ref_numel=%ld\n", + tag.c_str(), got.numel(), ref.numel()); + return false; + } + auto s = diff_fp32(got.data(), ref.data(), got.numel()); + bool pass = s.max_abs < tol_max_abs && s.mean_abs < tol_mean_abs; + std::printf(" %-20s %s max_abs=%.3e mean_abs=%.3e n=%ld\n", + tag.c_str(), pass ? "PASS" : "FAIL", s.max_abs, s.mean_abs, got.numel()); + return pass; +} + +} // namespace + +int main() { + // Enable library logging so load_tensors shape mismatches surface on stderr. + sd_set_log_callback( + [](enum sd_log_level_t /*level*/, const char* text, void* /*data*/) { + std::fputs(text, stderr); + }, + nullptr); + + const std::string ref_dir = "/tmp/vae_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // Tiny config from dump_vae.py: in=3, latent=8, patch=2, base_ch=8. + // Encoder: compress_space_res(×2), compress_time_res(×2), res_x(1 layer). + // Decoder: compress_space(m=1), compress_time(m=1), res_x(1 layer, timestep_cond=True). 
+ const int in_ch = 3; + const int latent_ch = 8; + const int base_ch = 8; + const int patch = 2; + const int B = 1, F = 9, H = 16, W_ = 16; + + std::vector enc_specs = { + {LTXVAE::EncoderBlockKind::COMPRESS_SPACE_RES, 1, 2}, + {LTXVAE::EncoderBlockKind::COMPRESS_TIME_RES, 1, 2}, + {LTXVAE::EncoderBlockKind::RES_X, 1, 1}, + }; + std::vector dec_specs = { + {LTXVAE::DecoderBlockKind::COMPRESS_SPACE, 1, 1}, + {LTXVAE::DecoderBlockKind::COMPRESS_TIME, 1, 1}, + {LTXVAE::DecoderBlockKind::RES_X, 1, 1}, + }; + + ggml_backend_t backend = ggml_backend_cpu_init(); + + // --- Encoder --- + LTXVAE::VAEEncoderRunner enc_runner(backend, /*offload=*/false, tsm, + /*prefix=*/"vae.encoder", + in_ch, latent_ch, patch, enc_specs); + enc_runner.alloc_params_buffer(); + std::map enc_params; + enc_runner.get_param_tensors(enc_params, "vae.encoder"); + std::printf("[enc] requesting %zu param tensors\n", enc_params.size()); + if (!loader.load_tensors(enc_params)) { + std::fprintf(stderr, "fatal: encoder load_tensors failed\n"); + return 1; + } + + // Load video input. Python shape (1, 3, 9, 16, 16) → GGML ne=[W=16, H=16, T=9, C=3]. + auto video_in = load_raw_bin(ref_dir + "/tensors/video_in.bin", {W_, H, F, in_ch}); + std::printf("[enc] running encoder (traced)\n"); + + bool pass = true; + struct Stage { int idx; const char* name; std::vector shape; float abs_tol, mean_tol; }; + // Dump order & shapes (PyTorch-majored): + // 0 post_patchify (1,12,9,8,8) → ne=[8,8,9,12] + // 1 post_conv_in (1,8,9,8,8) → ne=[8,8,9,8] + // 2 down_block[0] (cs) (1,16,9,4,4) → ne=[4,4,9,16] + // 3 down_block[1] (ct) (1,32,5,4,4) → ne=[4,4,5,32] + // 4 down_block[2] (res)(1,32,5,4,4) → ne=[4,4,5,32] + // 5 post_norm (1,32,5,4,4) → ne=[4,4,5,32] + // 6 post_conv_out (1,9,5,4,4) → ne=[4,4,5,9] + // 7 means_preNorm (1,8,5,4,4) → ne=[4,4,5,8] + // 8 latent (1,8,5,4,4) → ne=[4,4,5,8] + // Conv3d weights are stored f16 in the block — each conv boundary introduces a + // fp16-quantization step (~1e-3 abs per layer). 
Tolerances are set accordingly. + std::vector stages = { + {0, "enc_post_patchify", {8, 8, F, 12}, 1e-6f, 1e-7f}, // pure rearrange + {1, "enc_post_conv_in", {8, 8, F, 8}, 2e-3f, 3e-4f}, + {2, "enc_block_0", {4, 4, F, 16}, 3e-3f, 5e-4f}, + {3, "enc_block_1", {4, 4, 5, 32}, 5e-3f, 8e-4f}, + {4, "enc_block_2", {4, 4, 5, 32}, 5e-3f, 1e-3f}, + {5, "enc_post_norm", {4, 4, 5, 32}, 5e-3f, 1e-3f}, + {6, "enc_post_conv_out", {4, 4, 5, 9}, 5e-3f, 1e-3f}, + {7, "enc_means_preNorm", {4, 4, 5, 8}, 5e-3f, 1e-3f}, + {8, "latent", {4, 4, 5, 8}, 5e-3f, 1e-3f}, + }; + for (const auto& s : stages) { + auto got = enc_runner.compute(1, video_in, s.idx); + pass &= compare(s.name, got, ref_dir + "/tensors/" + s.name + ".bin", s.shape, + s.abs_tol, s.mean_tol); + } + + // --- Decoder --- + LTXVAE::VAEDecoderRunner dec_runner(backend, /*offload=*/false, tsm, + /*prefix=*/"vae.decoder", + latent_ch, in_ch, patch, base_ch, + /*timestep_cond=*/true, dec_specs); + dec_runner.alloc_params_buffer(); + std::map dec_params; + dec_runner.get_param_tensors(dec_params, "vae.decoder"); + std::printf("[dec] requesting %zu param tensors\n", dec_params.size()); + + // Diagnose any name/shape mismatches. + std::set file_keys; + for (const auto& kv : tsm) file_keys.insert(kv.first); + int missing = 0; + for (const auto& pt : dec_params) { + auto it = file_keys.find(pt.first); + if (it == file_keys.end()) { + if (missing < 10) std::printf("[dec] missing: %s\n", pt.first.c_str()); + missing++; + } + } + std::printf("[dec] %d / %zu tensors missing from file\n", missing, dec_params.size()); + + if (!loader.load_tensors(dec_params)) { + std::fprintf(stderr, "fatal: decoder load_tensors failed\n"); + return 1; + } + + // Feed the Python reference latent to the decoder so its diffs are independent of + // encoder errors. Once encoder parity is green we can chain them. 
+ auto latent_ref = load_raw_bin(ref_dir + "/tensors/latent.bin", {4, 4, 5, latent_ch}); + sd::Tensor timestep_t({1}); + timestep_t.data()[0] = 0.05f; + // TimestepEmbedder micro-probe: bypass the full decoder and run just the + // up_blocks[0].time_embedder on timestep=0.05 to verify the sinusoidal + linear path. + { + TERunner te_runner(backend, false, tsm, "vae.decoder.up_blocks.0.time_embedder", 256); + te_runner.alloc_params_buffer(); + std::map te_params; + te_runner.get_param_tensors(te_params, "vae.decoder.up_blocks.0.time_embedder"); + if (!loader.load_tensors(te_params)) { + std::fprintf(stderr, "fatal: TE load failed\n"); + return 1; + } + auto te_out = te_runner.compute(1, timestep_t); + // Python dumps shape [B=1, 256] → innermost 256. sd::Tensor stores innermost-first, + // so shape is {256, 1}. Same numel. + pass &= compare("TimestepEmbedder", te_out, ref_dir + "/tensors/te_probe_up0.bin", + {256, 1}, 1e-4f, 1e-5f); + + // Verify the ada-values reshape+slice: Python does `te.reshape(B,4,-1,1,1,1)` → + // unbind(dim=1). The four unbound slices should be te[0:64], te[64:128], te[128:192], + // te[192:256]. Run each slice through the C++ reshape+slice path and byte-compare. 
+ auto te_ref = load_raw_bin(ref_dir + "/tensors/te_probe_up0.bin", {256, 1}); + const char* which_names[] = {"shift1", "scale1", "shift2", "scale2"}; + for (int w = 0; w < 4; w++) { + ShiftProbeRunner sp(backend, false, /*in_ch=*/64, w); + auto slice = sp.compute(1, te_ref); + float maxd = 0.f; + for (int i = 0; i < 64; i++) { + float d = std::fabs(slice.data()[i] - te_ref.data()[w * 64 + i]); + if (d > maxd) maxd = d; + } + std::printf(" shift-probe %-7s max_abs vs te[%d:%d]=%.3e\n", + which_names[w], w * 64, (w + 1) * 64, maxd); + pass &= (maxd < 1e-6f); + } + } + + // Per-stage trace now includes intermediates pushed INSIDE the first res_x block: + // 0 post_unnorm, 1 post_conv_in, 2 time_embed, 3 post_norm1, 4 shift1, 5 scale1, + // 6 post_adaln1, 7 post_conv1, 8 post_norm2, 9 up_block[0] out, ... + auto got_norm1 = dec_runner.compute(1, latent_ref, timestep_t, 3); + pass &= compare("resblock0 post_norm1", got_norm1, + ref_dir + "/tensors/dec_resblock0_post_norm1.bin", + {4, 4, 5, 64}, 2e-3f, 5e-4f); + auto got_adaln1 = dec_runner.compute(1, latent_ref, timestep_t, 6); + pass &= compare("resblock0 post_adaln1", got_adaln1, + ref_dir + "/tensors/dec_resblock0_post_adaln1.bin", + {4, 4, 5, 64}, 5e-3f, 1e-3f); + auto got_conv1 = dec_runner.compute(1, latent_ref, timestep_t, 7); + pass &= compare("resblock0 post_conv1", got_conv1, + ref_dir + "/tensors/dec_resblock0_post_conv1.bin", + {4, 4, 5, 64}, 5e-3f, 1e-3f); + auto got_norm2 = dec_runner.compute(1, latent_ref, timestep_t, 8); + pass &= compare("resblock0 post_norm2", got_norm2, + ref_dir + "/tensors/dec_resblock0_post_norm2.bin", + {4, 4, 5, 64}, 1e-2f, 2e-3f); + + // After causal=false + reflect-padding fixes, trace indices in the decoder have shifted. 
+ // New layout: + // 0 post_unnorm 1 post_conv_in 2 time_embed 3 post_norm1 4 shift1 + // 5 scale1 6 post_adaln1 7 post_conv1 8 post_norm2 9 up_block[0] out + // 10 up_block[1] 11 up_block[2] 12 post_pixel_norm 13 post_ada + // 14 post_conv_out 15 video_out + struct Stage2 { int idx; const char* name; std::vector shape; float atol, mtol; }; + std::vector stages2 = { + // Shapes in ne-order (W, H, T, C) after each decoder block. + // Compress_time expands T 5→9; compress_space expands spatial 4→8 (patch=2 still + // to apply at the very end via unpatchify). + { 9, "dec_block_0", {4, 4, 5, 64}, 1e-2f, 2e-3f}, + {10, "dec_block_1", {4, 4, F, 64}, 1e-2f, 2e-3f}, + {11, "dec_block_2", {8, 8, F, 64}, 2e-2f, 4e-3f}, + {12, "dec_post_pixel_norm", {8, 8, F, 64}, 2e-2f, 4e-3f}, + {13, "dec_post_ada", {8, 8, F, 64}, 2e-2f, 4e-3f}, + {14, "dec_post_conv_out", {8, 8, F, 12}, 2e-2f, 4e-3f}, + }; + for (const auto& s : stages2) { + auto got = dec_runner.compute(1, latent_ref, timestep_t, s.idx); + pass &= compare(s.name, got, + ref_dir + "/tensors/" + std::string(s.name) + ".bin", + s.shape, s.atol, s.mtol); + } + + auto decoded = dec_runner.compute(1, latent_ref, timestep_t); + pass &= compare("dec video", decoded, ref_dir + "/tensors/video_out.bin", {W_, H, F, in_ch}, 1e-2f, 2e-3f); + + // Per-stage probe on the same runner (since GGMLRunner can be reused across + // multiple computes, as the encoder path does 9 times without issue). 
+ if (!pass) { + std::printf("\n[dec] per-stage probe:\n"); + const char* stage_names[] = { + "dec_post_unnorm", "dec_post_conv_in", "dec_block_0", "dec_block_1", "dec_block_2", + "dec_post_pixel_norm", "dec_post_ada", "dec_post_conv_out", "video_out" + }; + for (int idx = 0; idx < 9; idx++) { + std::printf(" [%d] stage=%s computing...\n", idx, stage_names[idx]); std::fflush(stdout); + auto out = dec_runner.compute(1, latent_ref, timestep_t, idx); + std::string tag = stage_names[idx]; + std::string ref_path = ref_dir + "/tensors/" + tag + ".bin"; + std::ifstream check(ref_path); + if (check.good()) { + check.close(); + std::vector shape = {out.shape()[0], out.shape()[1], out.shape()[2], out.shape()[3]}; + auto ref = load_raw_bin(ref_path, shape); + if (ref.numel() != out.numel()) { + std::printf(" [%d] %-20s SHAPE_MISMATCH got=%ld ref=%ld (shape=%ld,%ld,%ld,%ld)\n", + idx, tag.c_str(), out.numel(), ref.numel(), + shape[0], shape[1], shape[2], shape[3]); + continue; + } + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + std::printf(" [%d] %-20s max_abs=%.3e mean_abs=%.3e\n", + idx, tag.c_str(), s.max_abs, s.mean_abs); + } else { + float m0 = 0.f, m1 = 0.f; + for (int64_t i = 0; i < out.numel(); i++) { float a = std::fabs(out.data()[i]); m0 = std::max(m0, a); m1 += a; } + m1 /= out.numel() > 0 ? out.numel() : 1; + std::printf(" [%d] %-20s (no ref) shape=[%ld,%ld,%ld,%ld] max_abs=%.3f mean_abs=%.3f\n", + idx, tag.c_str(), + out.shape()[0], out.shape()[1], out.shape()[2], out.shape()[3], m0, m1); + } + } + } + + std::printf("\n%s\n", pass ? "ALL VAE PARITY: PASS" : "ALL VAE PARITY: FAIL"); + (void)B; + return pass ? 
0 : 3; +} From 26ea8ea4439031d5fa9ae79e4f49ef9a358a1ab9 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Fri, 24 Apr 2026 22:19:03 +0200 Subject: [PATCH 2/2] Add backend fitting, some fixes --- examples/common/common.cpp | 35 ++ examples/common/common.h | 8 + include/stable-diffusion.h | 14 + src/backend_fit.hpp | 525 ++++++++++++++++++++ src/conditioner.hpp | 20 + src/ggml_extend.hpp | 75 ++- src/llm.hpp | 150 +++++- src/model.h | 2 + src/stable-diffusion.cpp | 223 +++++++-- tests/ltx_parity/CMakeLists.txt | 9 + tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp | 309 ++++++++++++ 11 files changed, 1287 insertions(+), 83 deletions(-) create mode 100644 src/backend_fit.hpp create mode 100644 tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 7ac0a0d30..bef4b0e96 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -380,6 +380,25 @@ ArgOptions SDContextParams::get_options() { "--chroma-t5-mask-pad", "t5 mask pad size of chroma", &chroma_t5_mask_pad}, + {"", + "--fit-target", + "auto-fit: MiB of free memory to leave on each GPU (default: 512)", + &auto_fit_target_mb}, + {"", + "--fit-compute-reserve-dit", + "auto-fit: MiB reserved on the DiT's GPU for its compute buffer " + "(default: 2048, 0 keeps the built-in default)", + &auto_fit_compute_reserve_dit_mb}, + {"", + "--fit-compute-reserve-vae", + "auto-fit: MiB reserved on the VAE's GPU for its compute buffer " + "(default: 1024, 0 keeps the built-in default)", + &auto_fit_compute_reserve_vae_mb}, + {"", + "--fit-compute-reserve-cond", + "auto-fit: MiB reserved on the conditioner's GPU for its compute " + "buffer (default: 512, 0 keeps the built-in default)", + &auto_fit_compute_reserve_cond_mb}, }; options.float_options = {}; @@ -449,6 +468,16 @@ ArgOptions SDContextParams::get_options() { "--chroma-enable-t5-mask", "enable t5 mask for chroma", true, &chroma_use_t5_mask}, + {"", + "--auto-fit", + "automatically pick 
DiT/VAE/Conditioner device placements based on " + "free GPU memory (priority: DiT+compute > VAE > Conditioner; " + "overflow goes to CPU or DiT-params-offload mode)", + true, &auto_fit}, + {"", + "--fit-dry-run", + "auto-fit: print the computed plan and exit without loading models", + true, &auto_fit_dry_run}, }; auto on_type_arg = [&](int argc, const char** argv, int index) { @@ -733,6 +762,12 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f chroma_use_t5_mask, chroma_t5_mask_pad, qwen_image_zero_cond_t, + auto_fit, + auto_fit_target_mb, + auto_fit_dry_run, + auto_fit_compute_reserve_dit_mb, + auto_fit_compute_reserve_vae_mb, + auto_fit_compute_reserve_cond_mb, }; return sd_ctx_params; } diff --git a/examples/common/common.h b/examples/common/common.h index 6e405b050..ab9e6864d 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -128,6 +128,14 @@ struct SDContextParams { bool qwen_image_zero_cond_t = false; + // Auto-fit: pick DiT/VAE/Conditioner device placements from free GPU memory. + bool auto_fit = false; + int auto_fit_target_mb = 512; + bool auto_fit_dry_run = false; + int auto_fit_compute_reserve_dit_mb = 0; // 0 = use header default + int auto_fit_compute_reserve_vae_mb = 0; + int auto_fit_compute_reserve_cond_mb = 0; + prediction_t prediction = PREDICTION_COUNT; lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 9ab335627..370852f13 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -209,6 +209,20 @@ typedef struct { bool chroma_use_t5_mask; int chroma_t5_mask_pad; bool qwen_image_zero_cond_t; + + // Auto-fit: pick DiT/VAE/Conditioner devices based on free GPU memory. + // When `auto_fit` is true, the CLI placement overrides (env vars, + // keep_*_on_cpu) are ignored and the plan is computed automatically. + // `auto_fit_target_mb` is the memory to leave free per GPU (default 512). 
+ // `auto_fit_dry_run` prints the plan and aborts init before loading. + // `auto_fit_compute_reserve_{dit,vae,cond}_mb` let the user tune the + // per-component compute-buffer reserve; 0 means use the built-in default. + bool auto_fit; + int auto_fit_target_mb; + bool auto_fit_dry_run; + int auto_fit_compute_reserve_dit_mb; + int auto_fit_compute_reserve_vae_mb; + int auto_fit_compute_reserve_cond_mb; } sd_ctx_params_t; typedef struct { diff --git a/src/backend_fit.hpp b/src/backend_fit.hpp new file mode 100644 index 000000000..4524ae93a --- /dev/null +++ b/src/backend_fit.hpp @@ -0,0 +1,525 @@ +#ifndef __SD_BACKEND_FIT_HPP__ +#define __SD_BACKEND_FIT_HPP__ + +// Auto-fit algorithm for distributing DiT, VAE, and conditioner (LLM + +// connector) across available GPU devices and system RAM. +// +// Inspired by llama.cpp's common_fit_params (tools/fit-params), but much +// coarser: sd.cpp treats each of {DiT, VAE, Conditioner} as a single atomic +// unit that lives entirely on one device (plus the DiT's compute buffer on +// the same GPU). There is no per-layer tensor_buft_overrides mechanism in +// sd.cpp today — the existing `offload_params_to_cpu` knob is the only way to +// "split" a model (it keeps params in RAM and streams them to the runtime +// backend per forward pass). +// +// Placement priority: DiT + compute buffer → VAE → Conditioner (+connector). +// Overflow falls back to CPU (or GPU_OFFLOAD_PARAMS for DiT). 
+ +#include +#include +#include +#include +#include +#include + +#include "ggml.h" + +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif +#if defined(SD_USE_VULKAN) +#include "ggml-backend.h" +#endif + +#include "model.h" +#include "util.h" + +namespace backend_fit { + +constexpr int64_t MiB = 1024 * 1024; +constexpr int DEVICE_ID_CPU = -1; + +enum class ComponentKind { + DIT, + VAE, + CONDITIONER, // LLM + connector (share a backend) +}; + +enum class Placement { + CPU, + GPU, + GPU_OFFLOAD_PARAMS, // params in RAM, compute on GPU (DiT-only) +}; + +struct Component { + ComponentKind kind; + std::string name; + int64_t params_bytes = 0; // weight memory for this component + int64_t compute_bytes = 0; // reserved compute buffer on the chosen device + bool supports_offload = false; // true only for DiT +}; + +struct Device { + int id = DEVICE_ID_CPU; + std::string name; + std::string description; + int64_t free_bytes = 0; + int64_t total_bytes = 0; +}; + +struct Decision { + ComponentKind kind; + std::string name; + Placement placement = Placement::CPU; + int device_id = DEVICE_ID_CPU; + int64_t on_device_bytes = 0; // contribution to device_id's device memory + int64_t on_host_bytes = 0; // contribution to host RAM +}; + +struct Plan { + std::vector decisions; + std::map device_bytes; // device_id -> bytes used + int64_t host_bytes = 0; + bool any_changes = false; // true if a non-default placement was chosen +}; + +// Defaults chosen to leave enough headroom for typical diffusion/video models. +// Configurable via the CLI (--fit-compute-reserve-* in MiB). +struct ComputeReserves { + int64_t dit_bytes = int64_t(2048) * MiB; // video DiT compute buffer + int64_t vae_bytes = int64_t(1024) * MiB; // video VAE compute buffer + int64_t conditioner_bytes = int64_t(512) * MiB; // LLM + connector combined +}; + +// --- Classification ------------------------------------------------------- + +// Classify a tensor name into a ComponentKind. 
Returns false if the tensor is +// unused / not a primary weight we should count. +inline bool classify_tensor(const std::string& name, ComponentKind& out) { + // Connector lives inside `model.diffusion_model.*` by prefix but runs on + // the conditioner's backend, so it gets charged to CONDITIONER. + auto contains = [&](const char* s) { return name.find(s) != std::string::npos; }; + + // LTX-2 specific: the checkpoint carries audio-to-video branch weights + // (`.audio_*`, `.audio_to_video_*`, `.video_to_audio_*`, `audio_patchify_*`, + // `audio_scale_shift_*`, `audio_prompt_*`) that the video-only LTX2 + // diffusion module does NOT wire in. They're logged as "unknown tensor" + // warnings at load time and skipped. Excluding them here keeps the DiT + // params estimate honest (~9 GB) instead of including ~4 GB of audio + // tensors that never touch the GPU. + if (contains(".audio_") || + contains("audio_patchify") || + contains("audio_aggregate") || + contains("audio_scale_shift") || + contains("audio_prompt") || + contains("a2v_ca_audio") || + contains("a2v_ca_video")) { + return false; + } + + if (contains("embeddings_connector") || + contains("aggregate_embed") || + contains("text_embedding_projection")) { + out = ComponentKind::CONDITIONER; + return true; + } + + if (contains("model.diffusion_model.") || contains("unet.")) { + out = ComponentKind::DIT; + return true; + } + + if (contains("first_stage_model.") || + name.rfind("vae.", 0) == 0 || + name.rfind("tae.", 0) == 0) { + out = ComponentKind::VAE; + return true; + } + + if (contains("text_encoders") || + contains("cond_stage_model") || + contains("te.text_model.") || + contains("conditioner") || + name.rfind("text_encoder.", 0) == 0) { + out = ComponentKind::CONDITIONER; + return true; + } + + return false; +} + +// --- Memory estimation ---------------------------------------------------- + +// Sum params bytes per component using the same alignment padding and +// dtype-conversion rules as 
ModelLoader::get_params_mem_size.
+inline std::vector<Component> estimate_components(ModelLoader& loader,
+                                                  ggml_type override_wtype,
+                                                  int64_t alignment,
+                                                  const ComputeReserves& reserves) {
+    auto& storage = loader.get_tensor_storage_map();
+
+    int64_t bytes[3] = {0, 0, 0};  // DIT, VAE, CONDITIONER
+    int counts[3] = {0, 0, 0};
+
+    for (auto& [name, ts_const] : storage) {
+        // Work on a copy so we can apply the dtype override without mutating.
+        TensorStorage ts = ts_const;
+        if (is_unused_tensor(ts.name)) {
+            continue;
+        }
+
+        ComponentKind k;
+        if (!classify_tensor(ts.name, k)) {
+            continue;
+        }
+
+        if (override_wtype != GGML_TYPE_COUNT &&
+            loader.tensor_should_be_converted(ts, override_wtype)) {
+            ts.type = override_wtype;
+        }
+
+        int idx = int(k);
+        bytes[idx] += ts.nbytes() + alignment;
+        counts[idx] += 1;
+    }
+
+    std::vector<Component> out;
+    out.reserve(3);
+
+    out.push_back(Component{
+        ComponentKind::DIT, "DiT",
+        bytes[int(ComponentKind::DIT)], reserves.dit_bytes,
+        /*supports_offload=*/true,
+    });
+    out.push_back(Component{
+        ComponentKind::VAE, "VAE",
+        bytes[int(ComponentKind::VAE)], reserves.vae_bytes,
+        /*supports_offload=*/false,
+    });
+    out.push_back(Component{
+        ComponentKind::CONDITIONER, "Conditioner",
+        bytes[int(ComponentKind::CONDITIONER)], reserves.conditioner_bytes,
+        /*supports_offload=*/true,  // Gemma/etc.
can stream params to GPU per encode
+    });
+
+    (void)counts;
+    return out;
+}
+
+// --- Device enumeration ---------------------------------------------------
+
+inline std::vector<Device> enumerate_gpu_devices() {
+    std::vector<Device> out;
+
+#if defined(SD_USE_CUDA)
+    int count = ggml_backend_cuda_get_device_count();
+    for (int i = 0; i < count; i++) {
+        Device d;
+        d.id = i;
+        char desc_buf[256] = {0};
+        ggml_backend_cuda_get_device_description(i, desc_buf, sizeof(desc_buf));
+        d.description = desc_buf;
+        d.name = "CUDA" + std::to_string(i);
+        size_t free_b = 0, total_b = 0;
+        ggml_backend_cuda_get_device_memory(i, &free_b, &total_b);
+        d.free_bytes = int64_t(free_b);
+        d.total_bytes = int64_t(total_b);
+        out.push_back(d);
+    }
+#elif defined(SD_USE_VULKAN)
+    int count = ggml_backend_vk_get_device_count();
+    for (int i = 0; i < count; i++) {
+        Device d;
+        d.id = i;
+        d.name = "Vulkan" + std::to_string(i);
+        // Vulkan backend does not expose a direct free-memory API; enumerate
+        // via ggml_backend_dev so we can reuse ggml_backend_dev_memory.
+        ggml_backend_dev_t dev = nullptr;
+        for (size_t j = 0; j < ggml_backend_dev_count(); j++) {
+            ggml_backend_dev_t candidate = ggml_backend_dev_get(j);
+            if (ggml_backend_dev_type(candidate) == GGML_BACKEND_DEVICE_TYPE_GPU &&
+                std::string(ggml_backend_dev_name(candidate)).find("Vulkan") != std::string::npos) {
+                if (int(j) == i) { dev = candidate; break; }
+            }
+        }
+        if (dev) {
+            d.description = ggml_backend_dev_description(dev);
+            size_t free_b = 0, total_b = 0;
+            ggml_backend_dev_memory(dev, &free_b, &total_b);
+            d.free_bytes = int64_t(free_b);
+            d.total_bytes = int64_t(total_b);
+        }
+        out.push_back(d);
+    }
+#endif
+
+    return out;
+}
+
+// --- Core algorithm -------------------------------------------------------
+
+// Peak VRAM per GPU is computed from two contributions:
+//   1. `nonoffload_sum` — sum of params of every non-offload component on
+//      that GPU. These live on VRAM from LOAD through their free-after-use
+//      point, overlapping during the load window.
+//   2. `max_active_footprint` — the largest per-phase compute footprint,
+//      where a non-offload component's phase contributes just its compute
+//      buffer, and an offload component's phase contributes params+compute
+//      (its runtime buffer is full-size while active, freed by
+//      `free_compute_buffer_immediately=true` between phases).
+// peak = nonoffload_sum + max_active_footprint. This is conservative: it
+// assumes the load-time accumulation overlaps with an active compute phase
+// of the worst-case component. In practice load finishes before any compute
+// starts so this over-counts by max_active_footprint during load — safe.
+inline int64_t gpu_peak(int gpu_idx,
+                        const std::vector<Placement>& pl,
+                        const std::vector<int>& dev,
+                        const std::vector<Component>& components) {
+    int64_t nonoffload_sum = 0;
+    int64_t max_active_footprint = 0;
+    for (size_t i = 0; i < components.size(); i++) {
+        if (dev[i] != gpu_idx) continue;
+        const Component& c = components[i];
+        if (pl[i] == Placement::GPU) {
+            nonoffload_sum += c.params_bytes;
+            max_active_footprint = std::max(max_active_footprint, c.compute_bytes);
+        } else if (pl[i] == Placement::GPU_OFFLOAD_PARAMS) {
+            max_active_footprint = std::max(max_active_footprint,
+                                            c.params_bytes + c.compute_bytes);
+        }
+    }
+    return nonoffload_sum + max_active_footprint;
+}
+
+inline Plan compute_plan(const std::vector<Component>& components,
+                         const std::vector<Device>& devices,
+                         int64_t margin_bytes) {
+    // Enumeration approach: for each component we have up to (1 + 2 * nGPU)
+    // placement options — CPU, or non-offload / offload on each GPU (offload
+    // only when the component supports it). We try all combinations, filter
+    // infeasible ones (any GPU's computed peak exceeds its free-margin cap),
+    // and pick the combination with the best score.
+    //
+    // Score rewards GPU placement (heavily), non-offload over offload
+    // (avoids per-step stream cost), and GPU diversity (use multiple GPUs
+    // when possible instead of packing onto one). Priority runtime hot
+    // components are weighted higher: DiT >> Conditioner > VAE.
+    const size_t nC = components.size();
+    const size_t nG = devices.size();
+
+    std::vector<int64_t> cap(nG, 0);
+    for (size_t g = 0; g < nG; g++) {
+        cap[g] = devices[g].free_bytes - margin_bytes;
+        if (cap[g] < 0) cap[g] = 0;
+    }
+
+    struct OptionSlot {
+        Placement placement;
+        int device_idx;  // index into devices, or -1 for CPU
+    };
+
+    auto build_options = [&](const Component& c) {
+        std::vector<OptionSlot> opts;
+        for (size_t g = 0; g < nG; g++) {
+            opts.push_back({Placement::GPU, int(g)});
+            if (c.supports_offload) {
+                opts.push_back({Placement::GPU_OFFLOAD_PARAMS, int(g)});
+            }
+        }
+        opts.push_back({Placement::CPU, -1});
+        return opts;
+    };
+
+    std::vector<std::vector<OptionSlot>> options;
+    options.reserve(nC);
+    for (const Component& c : components) {
+        options.push_back(build_options(c));
+    }
+
+    auto priority_weight = [](ComponentKind k) -> int {
+        switch (k) {
+            case ComponentKind::DIT: return 300;          // runs N times per generation
+            case ComponentKind::CONDITIONER: return 120;  // one large forward per prompt
+            case ComponentKind::VAE: return 60;           // one decode per generation
+        }
+        return 1;
+    };
+
+    auto score = [&](const std::vector<Placement>& pl,
+                     const std::vector<int>& dev) {
+        int64_t s = 0;
+        std::set<int> gpus_used;
+        for (size_t i = 0; i < nC; i++) {
+            const int pw = priority_weight(components[i].kind);
+            if (pl[i] == Placement::GPU) {
+                s += 10 * pw;
+                gpus_used.insert(dev[i]);
+            } else if (pl[i] == Placement::GPU_OFFLOAD_PARAMS) {
+                s += 5 * pw;  // still on GPU but with per-step stream overhead
+                gpus_used.insert(dev[i]);
+            } else {
+                s -= 10 * pw;
+            }
+        }
+        s += 2 * int64_t(gpus_used.size());  // mild spread bonus
+        return s;
+    };
+
+    std::vector<size_t> idx(nC, 0);
+    std::vector<Placement> best_pl;
+    std::vector<int> best_dev;
+    int64_t best_score = std::numeric_limits<int64_t>::min();
+    bool found_any = false;
+
+    // Iterate the cartesian product of options.
+    while (true) {
+        std::vector<Placement> pl(nC);
+        std::vector<int> dev(nC);
+        for (size_t i = 0; i < nC; i++) {
+            pl[i] = options[i][idx[i]].placement;
+            dev[i] = options[i][idx[i]].device_idx;
+        }
+        // Feasibility check: peak on each GPU vs cap.
+        bool feasible = true;
+        for (size_t g = 0; g < nG; g++) {
+            if (gpu_peak(int(g), pl, dev, components) > cap[g]) {
+                feasible = false;
+                break;
+            }
+        }
+        if (feasible) {
+            int64_t sc = score(pl, dev);
+            if (sc > best_score) {
+                best_score = sc;
+                best_pl = pl;
+                best_dev = dev;
+                found_any = true;
+            }
+        }
+
+        // Advance mixed-radix counter.
+        size_t pos = 0;
+        while (pos < nC) {
+            idx[pos]++;
+            if (idx[pos] < options[pos].size()) break;
+            idx[pos] = 0;
+            pos++;
+        }
+        if (pos >= nC) break;
+    }
+
+    Plan plan;
+    if (!found_any) {
+        // Degenerate: no feasible solution (even all-CPU must be feasible by
+        // construction; but guard anyway). Fall back to CPU for everything.
+        best_pl.assign(nC, Placement::CPU);
+        best_dev.assign(nC, -1);
+    }
+
+    for (size_t i = 0; i < nC; i++) {
+        const Component& c = components[i];
+        Decision d;
+        d.kind = c.kind;
+        d.name = c.name;
+        d.placement = best_pl[i];
+        if (best_pl[i] == Placement::CPU) {
+            d.device_id = DEVICE_ID_CPU;
+            d.on_host_bytes = c.params_bytes + c.compute_bytes;
+            plan.any_changes = true;
+        } else {
+            d.device_id = devices[best_dev[i]].id;
+            if (best_pl[i] == Placement::GPU) {
+                d.on_device_bytes = c.params_bytes + c.compute_bytes;
+            } else {  // GPU_OFFLOAD_PARAMS
+                d.on_device_bytes = c.params_bytes + c.compute_bytes;  // peak during its compute
+                d.on_host_bytes = c.params_bytes;
+                plan.any_changes = true;
+            }
+        }
+        plan.decisions.push_back(d);
+        plan.host_bytes += d.on_host_bytes;
+    }
+
+    // Report per-device peak using the same model as feasibility check.
+    for (size_t g = 0; g < nG; g++) {
+        plan.device_bytes[devices[g].id] = gpu_peak(int(g), best_pl, best_dev, components);
+    }
+    return plan;
+}
+
+inline const char* placement_str(Placement p) {
+    switch (p) {
+        case Placement::CPU: return "CPU";
+        case Placement::GPU: return "GPU";
+        case Placement::GPU_OFFLOAD_PARAMS: return "GPU(params->RAM)";
+    }
+    return "?";
+}
+
+inline void print_plan(const Plan& plan,
+                       const std::vector<Component>& components,
+                       const std::vector<Device>& devices,
+                       int64_t margin_bytes) {
+    LOG_INFO("auto-fit plan (margin=%lld MiB per GPU):",
+             (long long)(margin_bytes / MiB));
+    LOG_INFO("  available devices:");
+    if (devices.empty()) {
+        LOG_INFO("    (no GPU devices detected — all components will run on CPU)");
+    }
+    for (const Device& d : devices) {
+        LOG_INFO("    %-8s %-32s free %6lld / %6lld MiB",
+                 d.name.c_str(), d.description.c_str(),
+                 (long long)(d.free_bytes / MiB),
+                 (long long)(d.total_bytes / MiB));
+    }
+    LOG_INFO("  components:");
+    for (const Component& c : components) {
+        LOG_INFO("    %-12s params %6lld MiB, compute reserve %6lld MiB",
+                 c.name.c_str(),
+                 (long long)(c.params_bytes / MiB),
+                 (long long)(c.compute_bytes / MiB));
+    }
+    LOG_INFO("  decisions:");
+    for (const Decision& d : plan.decisions) {
+        if (d.placement == Placement::CPU) {
+            LOG_INFO("    %-12s -> CPU (RAM %lld MiB)",
+                     d.name.c_str(), (long long)(d.on_host_bytes / MiB));
+        } else if (d.placement == Placement::GPU) {
+            LOG_INFO("    %-12s -> GPU %d (VRAM %lld MiB)",
+                     d.name.c_str(), d.device_id,
+                     (long long)(d.on_device_bytes / MiB));
+        } else {
+            LOG_INFO("    %-12s -> GPU %d (params RAM) (VRAM %lld MiB, RAM %lld MiB)",
+                     d.name.c_str(), d.device_id,
+                     (long long)(d.on_device_bytes / MiB),
+                     (long long)(d.on_host_bytes / MiB));
+        }
+    }
+    LOG_INFO("  projected per-device peak (MAX of assigned components, "
+             "since free_params_immediately lets components time-share VRAM):");
+    for (const Device& d : devices) {
+        int64_t peak = 0;
+        auto it = plan.device_bytes.find(d.id);
+        if (it !=
plan.device_bytes.end()) peak = it->second; + const int64_t remaining = d.free_bytes - peak; + LOG_INFO(" %-8s peak %6lld / %6lld MiB free (remaining %lld MiB)", + d.name.c_str(), + (long long)(peak / MiB), + (long long)(d.free_bytes / MiB), + (long long)(remaining / MiB)); + } + LOG_INFO(" %-8s host RAM additional %lld MiB", "CPU", + (long long)(plan.host_bytes / MiB)); +} + +// Convenience: look up the decision for a specific component. +inline const Decision* find_decision(const Plan& plan, ComponentKind kind) { + for (const Decision& d : plan.decisions) { + if (d.kind == kind) return &d; + } + return nullptr; +} + +} // namespace backend_fit + +#endif // __SD_BACKEND_FIT_HPP__ diff --git a/src/conditioner.hpp b/src/conditioner.hpp index c38342212..3e6d63d38 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -2203,6 +2203,16 @@ struct LTX2GemmaConditioner : public Conditioner { const int64_t L = gemma_num_hidden_layers + 1; GGML_ASSERT(stacked.numel() == L * D * T * B); + if (const char* dump_path = std::getenv("SD_DUMP_COND_STACKED")) { + FILE* f = std::fopen(dump_path, "wb"); + if (f) { + std::fwrite(stacked.data(), sizeof(float), stacked.numel(), f); + std::fclose(f); + LOG_INFO("SD_DUMP_COND_STACKED: wrote %ld floats to %s (ne=[%ld,%ld,%ld,%ld])", + (long)stacked.numel(), dump_path, (long)L, (long)D, (long)T, (long)B); + } + } + // 2. CPU normalize → [B, T, D*L]. seq_lens=[T_real_eff] + left-padding tells // the normalizer to zero out the pad positions (which live at [0, T_pad)). // T_real_eff caps at max_length to handle the truncated-prompt branch above. @@ -2242,6 +2252,16 @@ struct LTX2GemmaConditioner : public Conditioner { // before caption_projection — the DiT owns caption_projection). 
auto context = connector_runner->compute(n_threads, normed, /*stage=*/3); + if (const char* dump_path = std::getenv("SD_DUMP_COND_CONTEXT")) { + FILE* f = std::fopen(dump_path, "wb"); + if (f) { + std::fwrite(context.data(), sizeof(float), context.numel(), f); + std::fclose(f); + LOG_INFO("SD_DUMP_COND_CONTEXT: wrote %ld floats to %s", + (long)context.numel(), dump_path); + } + } + SDCondition cond; cond.c_crossattn = context; return cond; diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 8275f26e0..09d242004 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1684,6 +1684,12 @@ struct GGMLRunnerContext { std::shared_ptr weight_adapter = nullptr; }; +// Forward declaration — defined near support_get_rows() below. Used by +// GGMLRunner's ctor to publish its params backend so Embedding::init_params +// can pick the right get_rows allowlist without plumbing backend through +// every init_params override. +__STATIC_INLINE__ ggml_backend_t& current_params_backend(); + struct GGMLRunner { protected: typedef std::function get_graph_cb_t; @@ -2003,6 +2009,13 @@ struct GGMLRunner { } else { params_backend = runtime_backend; } + // Publish the RUNTIME backend (not params) so block init_params() can + // reach it from support_get_rows() without plumbing backend through + // every init_params override. Runtime matters for get_rows: when + // offload_params_to_cpu is true, params live on CPU but the actual + // get_rows executes on the runtime backend (the GPU), which requires + // the weight dtype to be CUDA-supported. + current_params_backend() = runtime_backend; } virtual ~GGMLRunner() { @@ -2073,6 +2086,12 @@ struct GGMLRunner { } void free_params_buffer() { + // When offload_params_to_cpu is in effect, the tensors currently point + // at `runtime_params_buffer` (on the runtime backend). 
Restore them to + // `params_buffer` and free the runtime copy before freeing the params + // buffer itself — otherwise subsequent offloads would re-copy freed + // memory and the runtime buffer would leak on teardown. + offload_params_to_params_backend(); if (params_buffer != nullptr) { ggml_backend_buffer_free(params_buffer); params_buffer = nullptr; @@ -2096,7 +2115,14 @@ struct GGMLRunner { ggml_gallocr_free(compute_allocr); compute_allocr = nullptr; } - offload_params_to_params_backend(); + // Intentionally do NOT call offload_params_to_params_backend() here. + // For offload mode, keeping runtime_params_buffer resident across + // compute() calls of the same runner is the whole point — otherwise + // a DiT sampling loop re-uploads ~9 GB from RAM to GPU every CFG pass + // (observed ~5.7 min of pure upload overhead on 60-step 720p runs). + // free_params_buffer() handles the offload teardown when the caller + // is actually done with the runner (sd.cpp triggers it via + // free_params_immediately between components). } // do copy after alloc graph @@ -2374,16 +2400,34 @@ class Linear : public UnaryBlock { } }; +// Set by GGMLRunner's constructor to the params backend of the most recently +// constructed runner. Read by support_get_rows() below. Defined as a +// function-local static so the header stays single-definition. +__STATIC_INLINE__ ggml_backend_t& current_params_backend() { + static ggml_backend_t b = nullptr; + return b; +} + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { // ggml-cpu implements get_rows for the full quant set in // ggml_compute_forward_get_rows (ggml-cpu/ops.cpp) — both the legacy - // Q{4,5,8}_{0,1} formats AND the K-quants / IQ-quants. Historically this - // allowlist only contained legacy types, which forced the LTX-2 Gemma-3 - // token_embd weight (IQ4_XS in the 12B checkpoint) to fall back to F32 - // during Embedding::init_params and cost an extra ~3.5 GB of RAM per - // encode. 
Keep the list tight — only what ggml-cpu get_rows actually - // dispatches — so an unsupported type still trips the F32 fallback. - std::set allow_types = { + // Q{4,5,8}_{0,1} formats AND the K-quants / IQ-quants. The CUDA kernel + // in ggml-cuda/getrows.cu only supports F16/BF16/F32/I32 + legacy + // Q4_{0,1}/Q5_{0,1}/Q8_0 — calling it with a K- or IQ-quant aborts. + // So the allowlist must match the BACKEND that will actually hold the + // Embedding weight. When the current runner's params backend is CUDA + // we fall back to F32 for non-legacy quants (costs ~3.5 GB VRAM for a + // Gemma 12B IQ4_XS token_embd, but is the only option that runs). + static const std::set allow_legacy = { + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q5_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_0, + }; + static const std::set allow_full = { GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, @@ -2406,10 +2450,17 @@ __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_XS, }; - if (allow_types.find(wtype) != allow_types.end()) { - return true; - } - return false; + + ggml_backend_t b = current_params_backend(); + const bool on_cpu = + (b == nullptr) || ggml_backend_is_cpu(b); + // Debug knob: set SD_FORCE_LEGACY_GETROWS=1 to apply the CUDA-safe + // allowlist on both CPU and CUDA. Useful for isolating whether CPU and + // CUDA agree on the F32 fallback path. + static const bool force_legacy = + std::getenv("SD_FORCE_LEGACY_GETROWS") != nullptr; + const auto& allow = (on_cpu && !force_legacy) ? allow_full : allow_legacy; + return allow.count(wtype) > 0; } class Embedding : public UnaryBlock { diff --git a/src/llm.hpp b/src/llm.hpp index 0a12d1958..8c4dc5b2f 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -21,6 +21,26 @@ #include "tokenizers/mistral_tokenizer.h" #include "tokenizers/qwen2_tokenizer.h" +// Debug tap: when non-null, Gemma layer-0 forward paths push intermediate +// tensors here (tagged "DBG:"). 
Definition lives as `inline` to keep +// this file header-only. Set from LLMRunner::compute_all_hidden_states when +// the SD_DUMP_LAYER0 env var is present. +inline std::vector* g_layer0_taps = nullptr; + +// Helper: preserve a tap's value by routing the graph THROUGH a ggml_cont +// copy. Returning the cont'd tensor (instead of the original) means the +// next op in the graph consumes the cont, so the allocator has to keep the +// cont's buffer live. Mathematically a bitwise copy — no graph change. +// The cont's name starts with "DBG:" so the dumper can find it. +inline ggml_tensor* tap_tensor(ggml_context* ctx, ggml_tensor* t, const char* name) { + if (::g_layer0_taps == nullptr) return t; + ggml_tensor* keep = ggml_cont(ctx, t); + ggml_set_output(keep); // tell allocator: don't reuse my buffer + ggml_set_name(keep, (std::string("DBG:") + name).c_str()); + ::g_layer0_taps->push_back(keep); + return keep; +} + namespace LLM { // Bumped aggressively for the 22B LTX-2 smoke test where Gemma 3 12B runs with // compute_all_hidden_states (49-layer concat stack over 48 layers of sandwich- @@ -90,6 +110,15 @@ namespace LLM { bool has_post_norms = false; float embed_scale = 1.0f; + // When true, Linear layers inside this model force GGML_PREC_F32 on + // their mul_mat ops. ggml-cuda defaults to F16 accumulation for + // quantized matmul, which drifts ~2% per layer vs the CPU/F32 path. + // For Gemma 3 used as a fixed embedding encoder (LTX-2) the compound + // drift across 48 layers corrupts the final embedding to uselessness + // on CUDA. Set true for Gemma 3; leave false for generative LLMs + // where the drift is acceptable and speed matters more. 
+ bool force_matmul_prec_f32 = false; + LLMVisionParams vision; }; @@ -127,11 +156,12 @@ namespace LLM { bool use_gelu_tanh; public: - MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false, bool use_gelu_tanh = false) + MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false, + bool use_gelu_tanh = false, bool force_prec_f32 = false) : use_gelu_tanh(use_gelu_tanh) { - blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); - blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); - blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias)); + blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias, /*force_f32=*/false, force_prec_f32)); + blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias, /*force_f32=*/false, force_prec_f32)); + blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias, /*force_f32=*/false, force_prec_f32)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { @@ -456,10 +486,11 @@ namespace LLM { rope_theta_global(params.rope_theta_global), rope_theta_local(params.rope_theta_local), rope_scaling_factor_global(params.rope_scaling_factor_global) { - blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias); - blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); - blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); - blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, false); + const bool fp = params.force_matmul_prec_f32; + blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias, /*force_f32=*/false, fp); + blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias, /*force_f32=*/false, fp); + 
blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias, /*force_f32=*/false, fp); + blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, false, /*force_f32=*/false, fp); if (params.qk_norm) { if (arch == LLMArch::GEMMA3) { blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps); @@ -490,9 +521,14 @@ namespace LLM { auto v_proj = std::dynamic_pointer_cast(blocks["v_proj"]); auto out_proj = std::dynamic_pointer_cast(blocks["o_proj"]); - auto q = q_proj->forward(ctx, x); // [N, n_token, num_heads*head_dim] - auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] - auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] + const bool trace = (layer_idx == 0); + auto tag = [&](ggml_tensor* t, const char* name) { + return trace ? tap_tensor(ctx->ggml_ctx, t, name) : t; + }; + + auto q = tag(q_proj->forward(ctx, x), "q_proj"); + auto k = tag(k_proj->forward(ctx, x), "k_proj"); + auto v = tag(v_proj->forward(ctx, x), "v_proj"); q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim] k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] @@ -502,8 +538,8 @@ namespace LLM { auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); - q = q_norm->forward(ctx, q); - k = k_norm->forward(ctx, k); + q = tag(q_norm->forward(ctx, q), "q_norm"); + k = tag(k_norm->forward(ctx, k), "k_norm"); } if (arch == LLMArch::MISTRAL_SMALL_3_2) { @@ -524,8 +560,8 @@ namespace LLM { bool is_sliding = is_gemma_sliding_layer(); float theta = is_sliding ? rope_theta_local : rope_theta_global; float freq_scale = is_sliding ? 
1.0f : (1.0f / rope_scaling_factor_global); - q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f); - k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f); + q = tag(ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f), "q_rope"); + k = tag(ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f), "k_rope"); } else { int sections[4] = {16, 24, 24, 0}; q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); @@ -543,9 +579,9 @@ namespace LLM { k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size] + x = tag(ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false), "attn_out"); // [N, n_token, hidden_size] - x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] + x = tag(out_proj->forward(ctx, x), "o_proj"); // [N, n_token, hidden_size] return x; } }; @@ -553,13 +589,14 @@ namespace LLM { struct TransformerBlock : public GGMLBlock { protected: bool has_post_norms; + int layer_idx; public: TransformerBlock(const LLMParams& params, int layer_idx = 0) - : has_post_norms(params.has_post_norms) { + : has_post_norms(params.has_post_norms), layer_idx(layer_idx) { bool gemma = (params.arch == LLMArch::GEMMA3); blocks["self_attn"] = std::make_shared(params, layer_idx); - blocks["mlp"] = 
std::make_shared(params.hidden_size, params.intermediate_size, false, gemma); + blocks["mlp"] = std::make_shared(params.hidden_size, params.intermediate_size, false, gemma, params.force_matmul_prec_f32); if (gemma) { blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); @@ -590,16 +627,25 @@ namespace LLM { auto post_ff_ln = std::dynamic_pointer_cast(blocks["post_feedforward_layernorm"]); auto residual = x; - x = input_ln->forward(ctx, x); + const bool trace_block = (layer_idx == 0); + auto tag = [&](ggml_tensor* t, const char* name) { + return trace_block ? tap_tensor(ctx->ggml_ctx, t, name) : t; + }; + if (trace_block) { + x = tag(x, "x_embed_in"); + residual = x; // residual must match the post-tap tensor + } + + x = tag(input_ln->forward(ctx, x), "input_ln"); x = self_attn->forward(ctx, x, input_pos, attention_mask, attention_mask_sliding); - x = post_attn_ln->forward(ctx, x); - x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + x = tag(post_attn_ln->forward(ctx, x), "post_attn_ln"); + x = tag(ggml_add_inplace(ctx->ggml_ctx, x, residual), "after_attn_res"); residual = x; - x = pre_ff_ln->forward(ctx, x); - x = mlp->forward(ctx, x); - x = post_ff_ln->forward(ctx, x); - x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + x = tag(pre_ff_ln->forward(ctx, x), "pre_ff_ln"); + x = tag(mlp->forward(ctx, x), "mlp_out"); + x = tag(post_ff_ln->forward(ctx, x), "post_ff_ln"); + x = tag(ggml_add_inplace(ctx->ggml_ctx, x, residual), "after_ff_res"); return x; } @@ -657,8 +703,10 @@ namespace LLM { auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto x = embed_tokens->forward(ctx, input_ids); + x = tap_tensor(ctx->ggml_ctx, x, "embed_raw"); if (embed_scale != 1.0f) { x = ggml_scale(ctx->ggml_ctx, x, embed_scale); + x = tap_tensor(ctx->ggml_ctx, x, "embed_scaled"); } if (all_hidden_states) { all_hidden_states->push_back(x); @@ -840,6 +888,13 @@ namespace LLM { // = 1/factor. Sliding-attention layers stay unscaled. 
params.rope_scaling_factor_global = 8.f; params.has_post_norms = true; + // Gemma 3 has narrow weight scales; the CUDA mmvq/mmq kernels + // quantize activations to q8_1 (block-32 fp16 scale) while the + // CPU iq4_xs kernel uses q8_K (block-256 fp32 scale). That + // format mismatch causes ~5% per-layer drift and ruins the + // embedding. Requesting GGML_PREC_F32 routes matmul through + // cuBLAS dequant+GEMM, which matches CPU bit-for-bit. + params.force_matmul_prec_f32 = true; // embed_scale is sqrt(hidden_size); hidden_size is autodetected below, // so defer setting embed_scale until after the tensor-storage scan. } @@ -1056,10 +1111,26 @@ namespace LLM { sd::Tensor compute_all_hidden_states(const int n_threads, const sd::Tensor& input_ids, const sd::Tensor& attention_mask) { + // Debug hook: capture layer-0 intermediates via the global tap vector. + // Forward paths push tensors here when ::g_layer0_taps != nullptr. + std::vector taps; + const char* dump_dir = std::getenv("SD_DUMP_LAYER0"); + if (dump_dir != nullptr) ::g_layer0_taps = &taps; + struct TapGuard { + ~TapGuard() { ::g_layer0_taps = nullptr; } + } guard; + auto get_graph = [&]() -> ggml_cgraph* { std::vector hidden_states; ggml_cgraph* gf = build_graph(input_ids, attention_mask, {}, {}, &hidden_states); + // Keep taps alive through the allocator: mark each as an output + // (prevents buffer aliasing) and expand into the graph. + for (auto* t : taps) { + ggml_set_output(t); + ggml_build_forward_expand(gf, t); + } + GGML_ASSERT(!hidden_states.empty()); // Reshape each [H, T, B] -> [1, H, T, B] so we can concat along axis 0. 
ggml_tensor* stacked = nullptr; @@ -1075,7 +1146,32 @@ namespace LLM { ggml_build_forward_expand(gf, stacked); return gf; }; - return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + auto result = take_or_empty(GGMLRunner::compute(get_graph, n_threads, /*free_compute_buffer_immediately=*/false)); + + if (dump_dir != nullptr && !taps.empty()) { + LOG_INFO("SD_DUMP_LAYER0: dumping %zu tensors to %s/", taps.size(), dump_dir); + for (auto* t : taps) { + const char* full_name = ggml_get_name(t); + if (std::strncmp(full_name, "DBG:", 4) != 0) continue; + const char* name = full_name + 4; + size_t nbytes = ggml_nbytes(t); + std::vector buf(nbytes); + ggml_backend_tensor_get(t, buf.data(), 0, nbytes); + std::string path = std::string(dump_dir) + "/" + name + ".bin"; + FILE* f = std::fopen(path.c_str(), "wb"); + if (f) { + std::fwrite(buf.data(), 1, nbytes, f); + std::fclose(f); + LOG_INFO(" %-22s ne=[%ld,%ld,%ld,%ld] type=%s bytes=%zu -> %s", + name, (long)t->ne[0], (long)t->ne[1], (long)t->ne[2], (long)t->ne[3], + ggml_type_name(t->type), nbytes, path.c_str()); + } + } + // Free now so we don't leak the compute buffer. 
+ free_compute_buffer(); + } + + return result; } int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) { diff --git a/src/model.h b/src/model.h index a6049c8a8..fd8ba6f21 100644 --- a/src/model.h +++ b/src/model.h @@ -202,6 +202,8 @@ using TensorTypeRules = std::vector>; TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); +bool is_unused_tensor(const std::string& name); + class ModelLoader { protected: SDVersion version_ = VERSION_COUNT; diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 149f07b72..89d7d8998 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -8,6 +8,7 @@ #include "util.h" #include "auto_encoder_kl.hpp" +#include "backend_fit.hpp" #include "conditioner.hpp" #include "control.hpp" #include "denoiser.hpp" @@ -113,6 +114,12 @@ class StableDiffusionGGML { ggml_backend_t clip_backend = nullptr; ggml_backend_t control_net_backend = nullptr; ggml_backend_t vae_backend = nullptr; + // Actual device id that `backend` points at. `SD_CUDA_DEVICE` can go stale + // when auto-fit re-initialises `backend` onto a different GPU. We track the + // live value so per-component resolution can decide "same device as main" + // correctly. -1 means the main backend is CPU. + int backend_device_id = -1; + static constexpr int BACKEND_DEVICE_CPU = -1; SDVersion version; bool vae_decode_only = false; @@ -150,6 +157,20 @@ class StableDiffusionGGML { bool is_using_v_parameterization = false; bool is_using_edm_v_parameterization = false; + // Populated by auto-fit (when --auto-fit is passed). When enabled, this + // overrides env-var based per-component placement. device_id == -1 means + // "no override" (fall through to env vars / defaults). 
+ struct FitOverride { + bool enabled = false; + int dit_device_id = -1; // -1 = keep main backend + int vae_device_id = -2; // -2 = no override (distinguishes from "force CPU") + int cond_device_id = -2; + bool dit_offload_params = false; // force offload_params_to_cpu for DiT only + bool vae_on_cpu = false; + bool cond_on_cpu = false; + }; + FitOverride fit_override; + std::map tensors; // lora_name => multiplier @@ -257,8 +278,9 @@ class StableDiffusionGGML { // (keeps weights on CPU and streams per-step to GPU). void init_backend() { #ifdef SD_USE_CUDA - int main_dev = get_env_int("SD_CUDA_DEVICE", 0); - backend = init_device_backend(main_dev, "main"); + int main_dev = get_env_int("SD_CUDA_DEVICE", 0); + backend = init_device_backend(main_dev, "main"); + backend_device_id = ggml_backend_is_cpu(backend) ? BACKEND_DEVICE_CPU : main_dev; #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); @@ -317,9 +339,12 @@ class StableDiffusionGGML { // the override matches the main device; otherwise creates a new backend (which // the caller is responsible for freeing via the existing `!= backend` dtor check). // `force_cpu` short-circuits to CPU regardless of the env var. + // `fit_device_id` is the auto-fit override: -2 means "no override", -1 means + // "force CPU", >=0 names a specific GPU. ggml_backend_t resolve_component_backend(const char* env_name, const char* component_name, - bool force_cpu) { + bool force_cpu, + int fit_device_id = -2) { if (force_cpu) { if (ggml_backend_is_cpu(backend)) { return backend; @@ -328,16 +353,24 @@ class StableDiffusionGGML { return ggml_backend_cpu_init(); } #if defined(SD_USE_CUDA) || defined(SD_USE_VULKAN) || defined(SD_USE_SYCL) - int main_dev = get_env_int("SD_CUDA_DEVICE", 0); - int override_dev = get_env_int(env_name, main_dev); - if (override_dev == main_dev && !ggml_backend_is_cpu(backend)) { - // Same device as main — reuse the main backend to save GPU memory/context. 
+ // Reuse the main backend iff this component resolves to the same + // physical device. After auto-fit re-initialises the main backend + // onto a different GPU, `SD_CUDA_DEVICE` no longer reflects reality, + // so we compare against `backend_device_id` instead. + int override_dev; + if (fit_override.enabled && fit_device_id != -2) { + override_dev = fit_device_id; + } else { + override_dev = get_env_int(env_name, get_env_int("SD_CUDA_DEVICE", 0)); + } + if (override_dev == backend_device_id && !ggml_backend_is_cpu(backend)) { return backend; } return init_device_backend(override_dev, component_name); #else (void)env_name; (void)component_name; + (void)fit_device_id; return backend; #endif } @@ -611,6 +644,78 @@ class StableDiffusionGGML { LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); + // ------------------------------------------------------------------ + // Auto-fit: compute per-component GPU/CPU placement plan based on + // currently free device memory. Runs before backend resolution so we + // can redirect the DiT backend and set per-component placement flags. + // Only affects the run when sd_ctx_params->auto_fit is true. 
+ if (sd_ctx_params->auto_fit) { + backend_fit::ComputeReserves reserves; + if (sd_ctx_params->auto_fit_compute_reserve_dit_mb > 0) { + reserves.dit_bytes = + int64_t(sd_ctx_params->auto_fit_compute_reserve_dit_mb) * backend_fit::MiB; + } + if (sd_ctx_params->auto_fit_compute_reserve_vae_mb > 0) { + reserves.vae_bytes = + int64_t(sd_ctx_params->auto_fit_compute_reserve_vae_mb) * backend_fit::MiB; + } + if (sd_ctx_params->auto_fit_compute_reserve_cond_mb > 0) { + reserves.conditioner_bytes = + int64_t(sd_ctx_params->auto_fit_compute_reserve_cond_mb) * backend_fit::MiB; + } + + const int64_t alignment_guess = 256; + auto components = backend_fit::estimate_components( + model_loader, wtype, alignment_guess, reserves); + auto devices = backend_fit::enumerate_gpu_devices(); + int64_t margin_bytes = + int64_t(std::max(0, sd_ctx_params->auto_fit_target_mb)) * backend_fit::MiB; + auto plan = backend_fit::compute_plan(components, devices, margin_bytes); + backend_fit::print_plan(plan, components, devices, margin_bytes); + + if (sd_ctx_params->auto_fit_dry_run) { + LOG_INFO("auto-fit: --fit-dry-run set, aborting init before loading models"); + return false; + } + + // Apply plan to fit_override. + fit_override.enabled = true; + auto dit_d = backend_fit::find_decision(plan, backend_fit::ComponentKind::DIT); + auto vae_d = backend_fit::find_decision(plan, backend_fit::ComponentKind::VAE); + auto cond_d = backend_fit::find_decision(plan, backend_fit::ComponentKind::CONDITIONER); + + if (dit_d) { + fit_override.dit_device_id = dit_d->device_id; + fit_override.dit_offload_params = + (dit_d->placement == backend_fit::Placement::GPU_OFFLOAD_PARAMS); + // Re-init the main backend if the chosen DiT device differs from + // whatever init_backend() picked. Keep `backend_device_id` in + // sync — it's what resolve_component_backend compares against. 
+ const int current_dev = backend_device_id; + if (!ggml_backend_is_cpu(backend) && dit_d->placement == backend_fit::Placement::CPU) { + LOG_INFO("auto-fit: switching DiT backend from GPU %d to CPU", current_dev); + ggml_backend_free(backend); + backend = ggml_backend_cpu_init(); + backend_device_id = BACKEND_DEVICE_CPU; + } else if (dit_d->placement != backend_fit::Placement::CPU && + dit_d->device_id != current_dev) { + LOG_INFO("auto-fit: switching DiT backend from GPU %d to GPU %d", + current_dev, dit_d->device_id); + ggml_backend_free(backend); + backend = init_device_backend(dit_d->device_id, "DiT (auto-fit)"); + backend_device_id = dit_d->device_id; + } + } + if (vae_d) { + fit_override.vae_device_id = vae_d->device_id; + fit_override.vae_on_cpu = (vae_d->placement == backend_fit::Placement::CPU); + } + if (cond_d) { + fit_override.cond_device_id = cond_d->device_id; + fit_override.cond_on_cpu = (cond_d->placement == backend_fit::Placement::CPU); + } + } + if (sd_ctx_params->lora_apply_mode == LORA_APPLY_AUTO) { bool have_quantized_weight = false; if (wtype != GGML_TYPE_COUNT && ggml_is_quantized(wtype)) { @@ -650,17 +755,34 @@ class StableDiffusionGGML { } bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; + if (fit_override.enabled && fit_override.cond_on_cpu) { + clip_on_cpu = true; + } + + // Per-component offload flags. `offload_params_to_cpu` (the user's + // global --offload-to-cpu) applies to every component. Auto-fit may + // additionally force DiT-only offload when the DiT doesn't fit in + // VRAM; that MUST NOT be propagated to the Conditioner/VAE, otherwise + // their weights get pinned in RAM and the system can OOM (e.g. an + // LTX-2 run pinning Gemma 9.5 GB + DiT 13 GB + VAE 1.4 GB in 32 GB RAM). 
+ const bool dit_offload = offload_params_to_cpu || + (fit_override.enabled && fit_override.dit_offload_params); + const bool cond_offload = offload_params_to_cpu; + const bool vae_offload = offload_params_to_cpu; { // Pick a device for the text-encoder stack. SD_CUDA_DEVICE_CLIP overrides // (set to -1 for CPU); `keep_clip_on_cpu` still forces CPU regardless. - clip_backend = resolve_component_backend("SD_CUDA_DEVICE_CLIP", "CLIP/TextEncoder", clip_on_cpu); + // When auto-fit is active, fit_override.cond_device_id wins. + clip_backend = resolve_component_backend( + "SD_CUDA_DEVICE_CLIP", "CLIP/TextEncoder", clip_on_cpu, + fit_override.enabled ? fit_override.cond_device_id : -2); if (sd_version_is_sd3(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; @@ -681,53 +803,53 @@ class StableDiffusionGGML { } cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, sd_ctx_params->chroma_use_t5_mask, sd_ctx_params->chroma_t5_mask_pad); } else if (version == VERSION_OVIS_IMAGE) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version, "", false); } else { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map); } diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, version, sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_flux2(version)) { bool is_chroma = false; cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, version, 
sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, true, 0, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.high_noise_diffusion_model", version); @@ -736,7 +858,7 @@ class StableDiffusionGGML { diffusion_model->get_desc() == "Wan2.1-FLF2V-14B" || diffusion_model->get_desc() == "Wan2.1-I2V-1.3B") { clip_vision = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); @@ -747,42 +869,42 @@ class StableDiffusionGGML { enable_vision = true; } cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version, "", enable_vision); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version, sd_ctx_params->qwen_image_zero_cond_t); } else if (sd_version_is_anima(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model"); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version); } else if (sd_version_is_ernie_image(version)) { cond_stage_model = std::make_shared(clip_backend, - 
offload_params_to_cpu, + cond_offload, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model"); } else if (sd_version_is_ltx2(version)) { @@ -793,12 +915,12 @@ class StableDiffusionGGML { // The tokenizer.json path is required — prompts can't be encoded without // it. Any HuggingFace-format `tokenizer.json` for Gemma 3 works. cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, "text_encoder", SAFE_STR(sd_ctx_params->gemma_tokenizer_path)); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version); @@ -809,20 +931,20 @@ class StableDiffusionGGML { } if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, embbeding_map, version, PM_VERSION_2); } else { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, embbeding_map, version); } diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, version); if (sd_ctx_params->diffusion_conv_direct) { @@ -847,16 +969,22 @@ class StableDiffusionGGML { } // Pick a device for the VAE. SD_CUDA_DEVICE_VAE overrides (set to -1 for CPU); - // `keep_vae_on_cpu` still forces CPU regardless. - vae_backend = resolve_component_backend("SD_CUDA_DEVICE_VAE", "VAE", - sd_ctx_params->keep_vae_on_cpu); + // `keep_vae_on_cpu` still forces CPU regardless. Auto-fit, when active, + // supplies fit_override.vae_device_id which takes precedence over env. + bool vae_on_cpu = sd_ctx_params->keep_vae_on_cpu; + if (fit_override.enabled && fit_override.vae_on_cpu) { + vae_on_cpu = true; + } + vae_backend = resolve_component_backend( + "SD_CUDA_DEVICE_VAE", "VAE", vae_on_cpu, + fit_override.enabled ? 
fit_override.vae_device_id : -2); auto create_tae = [&]() -> std::shared_ptr { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { return std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "decoder", vae_decode_only, @@ -864,7 +992,7 @@ class StableDiffusionGGML { } else { auto model = std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "decoder.layers", vae_decode_only, @@ -878,7 +1006,7 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { return std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "first_stage_model", vae_decode_only, @@ -896,7 +1024,7 @@ class StableDiffusionGGML { // than the 4-block tiny-test default. We hardcode the 22B spec here for // the smoke test; a proper auto-detect from tensor shapes is a follow-up. return std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "first_stage_model", version, @@ -909,7 +1037,7 @@ class StableDiffusionGGML { LTXVAE::LTX2VAERunner::ltx2_22b_dec_specs()); } else { auto model = std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "first_stage_model", vae_decode_only, @@ -932,7 +1060,7 @@ class StableDiffusionGGML { LOG_INFO("using FakeVAE"); first_stage_model = std::make_shared(version, vae_backend, - offload_params_to_cpu); + vae_offload); } else if (use_tae && !tae_preview_only) { LOG_INFO("using TAE for encoding / decoding"); first_stage_model = create_tae(); @@ -979,7 +1107,7 @@ class StableDiffusionGGML { if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { pmid_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "pmid", version, @@ -987,7 +1115,7 @@ class StableDiffusionGGML { LOG_INFO("using PhotoMaker Version 2"); } else { pmid_model = std::make_shared(backend, - 
offload_params_to_cpu, + dit_offload, tensor_storage_map, "pmid", version); @@ -2452,6 +2580,13 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->chroma_use_dit_mask = true; sd_ctx_params->chroma_use_t5_mask = false; sd_ctx_params->chroma_t5_mask_pad = 1; + + sd_ctx_params->auto_fit = false; + sd_ctx_params->auto_fit_target_mb = 512; + sd_ctx_params->auto_fit_dry_run = false; + sd_ctx_params->auto_fit_compute_reserve_dit_mb = 0; + sd_ctx_params->auto_fit_compute_reserve_vae_mb = 0; + sd_ctx_params->auto_fit_compute_reserve_cond_mb = 0; } char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { diff --git a/tests/ltx_parity/CMakeLists.txt b/tests/ltx_parity/CMakeLists.txt index edd7771c4..4af843f91 100644 --- a/tests/ltx_parity/CMakeLists.txt +++ b/tests/ltx_parity/CMakeLists.txt @@ -60,3 +60,12 @@ add_executable(${VAE_RT_TARGET} target_link_libraries(${VAE_RT_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${VAE_RT_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(GEMMA_CC_TARGET sd-gemma-cpu-vs-cuda) + +add_executable(${GEMMA_CC_TARGET} + test_gemma_cpu_vs_cuda.cpp +) + +target_link_libraries(${GEMMA_CC_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${GEMMA_CC_TARGET} PUBLIC c_std_11 cxx_std_17) diff --git a/tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp b/tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp new file mode 100644 index 000000000..64a8a6150 --- /dev/null +++ b/tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp @@ -0,0 +1,309 @@ +// Layer-0 CPU vs CUDA parity for Gemma 3. +// +// Loads the user's real Gemma GGUF twice — once with a CPU backend, once with +// a CUDA backend — runs compute_all_hidden_states on the same tokens with +// g_layer0_taps set, and diffs each intermediate. This lets us pinpoint which +// Gemma layer-0 op first diverges between CPU and CUDA without pulling in the +// DiT (which would push a 32 GB system into swap/OOM). 
+// +// Usage: +// sd-gemma-cpu-vs-cuda [cuda_device] + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif + +#include "llm.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + int64_t argmax = -1; +}; + +DiffStats diff_f32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum = 0.0; + for (int64_t i = 0; i < n; ++i) { + float d = std::fabs(a[i] - b[i]); + if (d > s.max_abs) { + s.max_abs = d; + s.argmax = i; + } + sum += d; + } + s.mean_abs = static_cast(sum / (n > 0 ? n : 1)); + return s; +} + +std::vector fetch(ggml_tensor* t) { + size_t n = ggml_nbytes(t); + std::vector out(n); + ggml_backend_tensor_get(t, out.data(), 0, n); + return out; +} + +struct TapDump { + std::string name; + std::vector ne; + ggml_type type; + std::vector data; +}; + +std::vector run_and_dump(const std::string& model_path, + ggml_backend_t backend, + const std::vector& tokens) { + ModelLoader loader; + if (!loader.init_from_file(model_path, "text_encoders.llm.")) { + std::fprintf(stderr, "fatal: init_from_file failed: %s\n", model_path.c_str()); + std::exit(1); + } + loader.convert_tensors_name(); + + // SD_GEMMA_FORCE_TYPE=f16|bf16|f32|q8_0|q4_0 — on-load retype so both + // backends take the same matmul path (avoids iq4_xs×q8_K-vs-q8_1 drift). + if (const char* t = std::getenv("SD_GEMMA_FORCE_TYPE")) { + std::string s = t; + ggml_type tgt = GGML_TYPE_F16; + if (s == "f32") tgt = GGML_TYPE_F32; + else if (s == "bf16") tgt = GGML_TYPE_BF16; + else if (s == "q8_0" || s == "q8") tgt = GGML_TYPE_Q8_0; + else if (s == "q4_0") tgt = GGML_TYPE_Q4_0; + std::printf("[retype] forcing weights to %s\n", ggml_type_name(tgt)); + loader.set_wtype_override(tgt); + } + + // Rename text_encoders.llm.* -> text_encoder.* (matches LTX-2 flow). 
+ auto& tsm = loader.get_tensor_storage_map(); + { + const std::string from = "text_encoders.llm."; + const std::string to = "text_encoder."; + String2TensorStorage out; + for (auto& kv : tsm) { + std::string k = kv.first; + if (k.rfind(from, 0) == 0) { + k = to + k.substr(from.size()); + kv.second.name = k; + } + out[k] = std::move(kv.second); + } + tsm.swap(out); + } + // Gemma sandwich-norm renames (mirrored from stable-diffusion.cpp init). + auto rename_suffix = [&](const std::string& old_suffix, const std::string& new_suffix) { + String2TensorStorage out; + for (auto& kv : tsm) { + std::string k = kv.first; + size_t p = k.rfind(old_suffix); + if (p != std::string::npos && p + old_suffix.size() == k.size() && + k.find("text_encoder.model.layers.") != std::string::npos) { + k = k.substr(0, p) + new_suffix; + kv.second.name = k; + } + out[k] = std::move(kv.second); + } + tsm.swap(out); + }; + rename_suffix(".post_attention_layernorm.weight", ".pre_feedforward_layernorm.weight"); + rename_suffix(".post_attention_norm.weight", ".post_attention_layernorm.weight"); + rename_suffix(".post_ffw_norm.weight", ".post_feedforward_layernorm.weight"); + + LLM::LLMRunner runner(LLM::LLMArch::GEMMA3, backend, /*offload=*/false, + tsm, /*prefix=*/"text_encoder", /*enable_vision=*/false); + + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, "text_encoder"); + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed\n"); + std::exit(1); + } + + // Dump token_embd weight rows to compare storage and sanity. 
+ auto it = param_tensors.find("text_encoder.model.embed_tokens.weight"); + if (it != param_tensors.end()) { + ggml_tensor* w = it->second; + std::printf("[weight] embed_tokens.weight: type=%s ne=[%ld,%ld,%ld,%ld] nbytes=%zu\n", + ggml_type_name(w->type), (long)w->ne[0], (long)w->ne[1], (long)w->ne[2], (long)w->ne[3], ggml_nbytes(w)); + if (w->type == GGML_TYPE_F32) { + int64_t hidden = w->ne[0]; + for (int64_t row_idx : {0, 1, 2, 100, 106, 262207}) { + std::vector row(hidden); + ggml_backend_tensor_get(w, row.data(), (size_t)row_idx * hidden * sizeof(float), hidden * sizeof(float)); + double sum_abs = 0; + for (float v : row) sum_abs += std::fabs(v); + std::printf("[weight] row %6ld first 4: %+.4e %+.4e %+.4e %+.4e mean_abs=%.3e\n", + (long)row_idx, row[0], row[1], row[2], row[3], sum_abs / hidden); + } + } + } + + const int64_t T = static_cast(tokens.size()); + sd::Tensor input_ids({T, 1}); + for (int64_t i = 0; i < T; ++i) input_ids.data()[i] = tokens[i]; + sd::Tensor empty_mask; + + std::vector taps; + ::g_layer0_taps = &taps; + auto stacked = runner.compute_all_hidden_states(/*n_threads=*/4, input_ids, empty_mask); + + // Collect tap dumps immediately while compute buffer is still alive. + std::vector tap_dumps; + for (auto* t : taps) { + const char* nm = ggml_get_name(t); + if (!nm || std::strncmp(nm, "DBG:", 4) != 0) continue; + if (!t->buffer) { + std::fprintf(stderr, "[tap] %s: no buffer (allocator aliased)\n", nm); + continue; + } + TapDump td; + td.name = nm + 4; + td.type = t->type; + for (int i = 0; i < 4; ++i) td.ne.push_back(t->ne[i]); + td.data = fetch(t); + tap_dumps.push_back(std::move(td)); + } + ::g_layer0_taps = nullptr; + + // Slice each layer out of the stacked tensor. stacked layout (innermost + // first): ne=[N+1, H, T, B]. For layer l: value at (b,t,h,l). 
+ const int64_t L = runner.params.num_layers + 1; + const int64_t H = runner.params.hidden_size; + const int64_t Tdim = T; + const int64_t B = 1; + const int64_t per_layer = H * Tdim * B; + + std::vector dumps; + dumps.reserve(L); + for (int64_t l = 0; l < L; ++l) { + TapDump d; + d.name = (l == 0) ? "stacked_L00" : ("stacked_L" + std::to_string(l)); + d.type = GGML_TYPE_F32; + d.ne = {H, Tdim, B, 1}; + d.data.resize(per_layer * sizeof(float)); + float* out = reinterpret_cast(d.data.data()); + const float* src = stacked.data(); + for (int64_t b = 0; b < B; ++b) { + for (int64_t t = 0; t < Tdim; ++t) { + for (int64_t h = 0; h < H; ++h) { + int64_t idx_stacked = ((b * Tdim + t) * H + h) * L + l; + out[(b * Tdim + t) * H + h] = src[idx_stacked]; + } + } + } + dumps.push_back(std::move(d)); + } + // Append tap dumps so the caller can diff per-op. + for (auto& td : tap_dumps) dumps.push_back(std::move(td)); + return dumps; +} + +} // namespace + +int main(int argc, char** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s [cuda_device]\n", argv[0]); + return 2; + } + const std::string model_path = argv[1]; + int cuda_device = argc >= 3 ? std::atoi(argv[2]) : 0; + + // Short prompt so a layer-0 graph stays tiny. + std::vector tokens = {2, 106, 108, 1055, 674, 25148, 110, 107}; // ~"user\n..." 
style filler + + std::printf("[run] CPU forward...\n"); + std::fflush(stdout); + ggml_backend_t cpu_backend = ggml_backend_cpu_init(); + auto cpu_dumps = run_and_dump(model_path, cpu_backend, tokens); + ggml_backend_free(cpu_backend); + std::printf("[run] CPU done, %zu taps\n", cpu_dumps.size()); + +#ifdef SD_USE_CUDA + std::printf("[run] CUDA (device %d) forward...\n", cuda_device); + std::fflush(stdout); + ggml_backend_t cuda_backend = ggml_backend_cuda_init(cuda_device); + if (!cuda_backend) { + std::fprintf(stderr, "fatal: CUDA backend init failed for device %d\n", cuda_device); + return 1; + } + auto cuda_dumps = run_and_dump(model_path, cuda_backend, tokens); + ggml_backend_free(cuda_backend); + std::printf("[run] CUDA done, %zu taps\n", cuda_dumps.size()); +#else + std::fprintf(stderr, "fatal: built without SD_USE_CUDA\n"); + return 1; +#endif + + // Diff by name. + std::map cpu_idx; + for (const auto& d : cpu_dumps) cpu_idx[d.name] = &d; + + std::printf("\n%-22s %-5s %12s %12s %12s %6s\n", + "tap", "type", "max_abs", "mean_abs", "cpu_mean_mag", "shape"); + int fail_count = 0; + for (const auto& c : cuda_dumps) { + auto it = cpu_idx.find(c.name); + if (it == cpu_idx.end()) { + std::printf(" %-20s [missing on CPU side]\n", c.name.c_str()); + continue; + } + const TapDump* p = it->second; + if (p->type != c.type || p->ne != c.ne) { + std::printf(" %-20s type/shape mismatch\n", c.name.c_str()); + continue; + } + if (c.type != GGML_TYPE_F32) { + // Cast to F32 for diffing if needed. For simplicity we only handle F32 here. + std::printf(" %-20s type=%s skipped\n", c.name.c_str(), ggml_type_name(c.type)); + continue; + } + int64_t n = int64_t(p->data.size() / sizeof(float)); + auto s = diff_f32( + reinterpret_cast(p->data.data()), + reinterpret_cast(c.data.data()), + n); + double cpu_mag = 0.0; + const float* cp = reinterpret_cast(p->data.data()); + for (int64_t i = 0; i < n; ++i) cpu_mag += std::fabs(cp[i]); + cpu_mag /= (n > 0 ? 
n : 1); + bool fail = (s.max_abs > 1e-3f * (float)cpu_mag + 1e-4f); + std::printf(" %-20s %-5s %12.3e %12.3e %12.3e [%ld,%ld,%ld,%ld] %s\n", + c.name.c_str(), ggml_type_name(c.type), + s.max_abs, s.mean_abs, (double)cpu_mag, + (long)c.ne[0], (long)c.ne[1], (long)c.ne[2], (long)c.ne[3], + fail ? "FAIL" : "ok"); + if (fail) { + fail_count++; + // First-fail detail dump: first 8 values from each side. + if (fail_count == 1) { + const float* cp = reinterpret_cast(p->data.data()); + const float* cu = reinterpret_cast(c.data.data()); + std::printf(" first 8 floats: CPU vs CUDA\n"); + for (int64_t i = 0; i < 8 && i < n; ++i) { + std::printf(" [%ld] %+.6e vs %+.6e (diff %+.3e)\n", + (long)i, cp[i], cu[i], cu[i] - cp[i]); + } + std::printf(" argmax element: CPU=%+.6e CUDA=%+.6e idx=%ld\n", + cp[s.argmax], cu[s.argmax], (long)s.argmax); + } + } + } + std::printf("\n%d taps diverged (max_abs > 1e-3 × mean(|cpu|) + 1e-4).\n", fail_count); + return fail_count == 0 ? 0 : 3; +}