diff --git a/common/arg.cpp b/common/arg.cpp
index dd787290d256d..78d3bde7a6173 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2949,6 +2949,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
         [](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        {"--diffusion-threshold"}, "F",
+        string_format("confidence threshold for transfer (default: %.2f)", (double) params.diffusion.threshold),
+        [](common_params & params, const std::string & value) { params.diffusion.threshold = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        {"--diffusion-eos-early-stop"},
+        string_format("enable early EOS termination (default: %s)", params.diffusion.eos_early_stop ? "true" : "false"),
+        [](common_params & params) { params.diffusion.eos_early_stop = true; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        {"--diffusion-hybrid"},
+        string_format("enable hybrid diffusion optimization (default: %s)", params.diffusion.hybrid_diffusion ? "true" : "false"),
+        [](common_params & params) { params.diffusion.hybrid_diffusion = true; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
     add_opt(common_arg(
         { "-lr", "--learning-rate" }, "ALPHA",
         string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
diff --git a/common/common.h b/common/common.h
index 2f23d0baa830e..718727c3bbb6e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -266,6 +266,11 @@ struct common_params_diffusion {
     float cfg_scale = 0; // classifier-free guidance scale
     bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
+
+    float threshold = -1.0f; // confidence threshold for transfer
+    bool eos_early_stop = false; // enable early EOS termination
+    bool hybrid_diffusion = false; // enable hybrid diffusion optimization
+
 };
 
 // reasoning API response format (not to be confused as chat template's reasoning format)
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index daaf0bf49740f..8b711182628ea 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -8729,6 +8729,13 @@ def prepare_tensors(self):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")
 
+@ModelBase.register("LLaDA2MoeModelLM")
+class LLaDA2MoeModel(BailingMoeV2Model):
+    model_arch = gguf.MODEL_ARCH.LLADA2
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_diffusion_shift_logits(False)
 
 @ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
 class GroveMoeModel(TextModel):
diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp
index 273942a165ed0..28976ecfccd96 100644
--- a/examples/diffusion/diffusion-cli.cpp
+++ b/examples/diffusion/diffusion-cli.cpp
@@ -49,8 +49,11 @@ struct diffusion_params {
     int32_t block_length = 0; // Block size (for block scheduling)
     float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
     bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
+    float threshold = -1.0f; // Confidence threshold for transfer (-1.0 = not set, use alg_temp-based sampling)
 
-    int32_t max_length = 0; // Maximum sequence length
+    int32_t max_length = 0; // Maximum sequence length
+    bool eos_early_stop = false; // Enable early EOS termination
+    bool hybrid_diffusion = false; // Enable hybrid diffusion optimization with KV cache
 };
 
 struct callback_data {
@@ -232,6 +235,16 @@ static void diffusion_generate(llama_context * ctx,
     std::vector<int32_t> mask_positions;
     mask_positions.reserve(params.max_length);
 
+    // Get EOS token for early termination
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    llama_token eos_token_id = llama_vocab_eos(vocab);
+
+    if (params.eos_early_stop) {
+        GGML_ASSERT(eos_token_id != LLAMA_TOKEN_NULL);
+    }
+
+    LOG_DBG("DEBUG: EOS token ID = %d\n", eos_token_id);
+
     // Setup sampler chain
     struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
     if (params.top_k > 0) {
@@ -277,7 +290,26 @@ static void diffusion_generate(llama_context * ctx,
     int64_t total_time = 0;
     int64_t time_start = ggml_time_us();
 
-    for (int block_num = 0; block_num < num_blocks; block_num++) {
+    bool all_tokens_filled = false;
+
+    // Hybrid Diffusion: Pre-fill prompt if enabled and n_input > 0
+    if (params.hybrid_diffusion && n_input > 0) {
+        // Decode prompt (0..n_input) to KV cache
+        batch.n_tokens = n_input;
+        for (int32_t i = 0; i < n_input; i++) {
+            batch.token[i] = output_tokens[i];
+            batch.pos[i] = i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i] = false; // No logits needed for prompt
+        }
+
+        if (llama_decode(ctx, batch) != 0) {
+            LOG_ERR("%s: failed to decode prompt\n", __func__);
+            return;
+        }
+    }
+    for (int block_num = 0; block_num < num_blocks && !all_tokens_filled; block_num++) {
         int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
         int32_t block_end = (params.schedule == BLOCK_BASED) ?
                                 std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
@@ -305,12 +337,59 @@ static void diffusion_generate(llama_context * ctx,
             }
 
             // Setup batch
-            for (int32_t i = 0; i < params.max_length; i++) {
-                batch.token[i] = output_tokens[i];
-                batch.pos[i] = i;
+            int32_t batch_size;
+            int32_t batch_start_pos;
+
+            // Hybrid Diffusion: Commit previous block to KV cache
+            if (params.hybrid_diffusion && block_num > 0 && step == 0) {
+                int32_t prev_block_start = (params.schedule == BLOCK_BASED) ? n_input + (block_num - 1) * params.block_length : 0;
+                int32_t prev_block_end = block_start;
+
+                int32_t pb_size = prev_block_end - prev_block_start;
+                if (pb_size > 0) {
+                    batch.n_tokens = pb_size;
+                    for (int32_t i = 0; i < pb_size; i++) {
+                        int32_t pos = prev_block_start + i;
+                        batch.token[i] = output_tokens[pos];
+                        batch.pos[i] = pos;
+                        batch.n_seq_id[i] = 1;
+                        batch.seq_id[i][0] = 0;
+                        batch.logits[i] = false;
+                    }
+
+                    // Remove old KV for this range to ensure we write the fresh finalized tokens
+                    llama_memory_seq_rm(llama_get_memory(ctx), 0, prev_block_start, prev_block_end);
+
+                    if (llama_decode(ctx, batch) != 0) {
+                        LOG_ERR("%s: failed to commit previous block %d\n", __func__, block_num - 1);
+                        break;
+                    }
+                }
+            }
+
+            if (params.hybrid_diffusion) {
+                // Hybrid Diffusion: Truncate to active block only
+                batch_start_pos = block_start;
+                batch_size = block_end - block_start;
+            } else {
+                // Process full sequence
+                batch_start_pos = 0;
+                batch_size = params.max_length;
+            }
+
+            // Hybrid Diffusion: Remove old KV for the active region before re-decoding
+            if (params.hybrid_diffusion) {
+                llama_memory_seq_rm(llama_get_memory(ctx), 0, batch_start_pos, batch_start_pos + batch_size);
+            }
+
+            batch.n_tokens = batch_size;
+            for (int32_t i = 0; i < batch_size; i++) {
+                int32_t pos = batch_start_pos + i;
+                batch.token[i] = output_tokens[pos];
+                batch.pos[i] = pos;
                 batch.n_seq_id[i] = 1;
                 batch.seq_id[i][0] = 0;
-                batch.logits[i] = 1;
+                batch.logits[i] = true;
             }
 
             float * logits = nullptr;
@@ -330,8 +409,9 @@
                 un_x_buffer[i] = params.mask_token_id;
             }
 
-            for (int32_t i = 0; i < params.max_length; i++) {
-                batch.token[i] = un_x_buffer[i];
+            for (int32_t i = 0; i < batch_size; i++) {
+                int32_t pos = batch_start_pos + i;
+                batch.token[i] = un_x_buffer[pos];
             }
             ret = llama_decode(ctx, batch);
             if (ret != 0) {
@@ -361,10 +441,17 @@
            }
 
            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
+               // Hybrid Diffusion: Map absolute pos to relative pos in logits
+               int32_t rel_pos = params.hybrid_diffusion ? (pos - batch_start_pos) : pos;
+
+               if (params.hybrid_diffusion && (pos < batch_start_pos || pos >= batch_start_pos + batch_size)) {
+                   return nullptr; // Position out of active batch range
+               }
+
                if (params.shift_logits) {
-                   return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
+                   return rel_pos == 0 ? logits : logits + (rel_pos - 1) * n_vocab;
                }
-               return logits + (pos) *n_vocab;
+               return logits + (rel_pos) * n_vocab;
            };
 
            int64_t time_start_sampling = ggml_time_us();
@@ -416,6 +503,10 @@
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token> sampled_tokens(mask_positions.size());
 
+               int32_t transfer_count = calculate_transfer_count(
+                   step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
+               int32_t high_conf_count = 0;
+
                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t pos = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);
@@ -438,61 +529,148 @@
                    float conf = calculate_confidence(cur_p, params.algorithm, rng);
 
+                   if (params.threshold > 0.0f && conf > params.threshold) {
+                       high_conf_count++;
+                   }
+
                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, i);
                }
 
-               int32_t transfer_count = calculate_transfer_count(
-                   step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
-
                if (transfer_count > 0) {
-                   if (params.alg_temp == 0.0f) {
-                       std::partial_sort(confidences.begin(),
-                                         confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
-                                         confidences.end(),
-                                         [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
-                                             if (a.first != b.first) {
-                                                 return a.first > b.first;
-                                             }
-                                             return a.second < b.second;
-                                         });
-
-                       for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
-                           int32_t mask_idx = confidences[i].second;
-                           int32_t pos = mask_positions[mask_idx];
-                           output_tokens[pos] = sampled_tokens[mask_idx];
+                   int32_t actual_transfer_count;
+
+                   if (params.threshold > 0.0f) {
+                       // Threshold-based confidence approach
+                       if (high_conf_count >= transfer_count) {
+                           // If we have enough high-confidence tokens,
+                           // use stable_partition to move them to the front, preserving relative order (by position).
+                           // This avoids a full sort.
+                           std::stable_partition(confidences.begin(),
+                                                 confidences.end(),
+                                                 [threshold = params.threshold](const std::pair<float, int32_t> & item) {
+                                                     return item.first > threshold;
+                                                 });
+                           actual_transfer_count = high_conf_count;
+                       } else {
+                           // Fallback: Not enough high-confidence tokens to meet the schedule.
+                           // Sort to find the top 'transfer_count' tokens.
+                           std::partial_sort(confidences.begin(),
+                                             confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
+                                             confidences.end(),
+                                             [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
+                                                 if (a.first != b.first) {
+                                                     return a.first > b.first;
+                                                 }
+                                                 return a.second < b.second;
+                                             });
+                           actual_transfer_count = transfer_count;
                        }
+                       actual_transfer_count = std::min(actual_transfer_count, (int32_t)confidences.size());
+
                    } else {
-                       conf_candidates.clear();
-                       for (size_t i = 0; i < confidences.size(); i++) {
-                           float conf_logit = confidences[i].first / params.alg_temp;
-                           conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
+                       // alg_temp-based approach (fallback when threshold not set)
+                       if (params.alg_temp == 0.0f) {
+                           // Deterministic selection: sort and take top transfer_count
+                           std::partial_sort(confidences.begin(),
+                                             confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
+                                             confidences.end(),
+                                             [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
+                                                 if (a.first != b.first) {
+                                                     return a.first > b.first;
+                                                 }
+                                                 return a.second < b.second;
+                                             });
+                           actual_transfer_count = std::min(transfer_count, (int32_t) confidences.size());
+                       } else {
+                           // Stochastic selection using alg_temp
+                           conf_candidates.clear();
+                           for (size_t i = 0; i < confidences.size(); i++) {
+                               float conf_logit = confidences[i].first / params.alg_temp;
+                               conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
+                           }
+
+                           llama_token_data_array conf_array = {
+                               conf_candidates.data(),
+                               conf_candidates.size(),
+                               -1,
+                               false,
+                           };
+
+                           // Sample transfer_count positions stochastically
+                           actual_transfer_count = std::min(transfer_count, (int32_t) confidences.size());
+                           for (int32_t i = 0; i < actual_transfer_count; i++) {
+                               llama_sampler_apply(dist_sampler, &conf_array);
+                               int32_t selected_idx = conf_array.selected;
+                               int32_t mask_idx = selected_idx;
+                               int32_t pos = mask_positions[mask_idx];
+                               output_tokens[pos] = sampled_tokens[mask_idx];
+
+                               // Mark as used by setting p to 0
+                               conf_candidates[selected_idx].p = 0.0f;
+                               conf_array.selected = -1;
+                           }
+                           // Skip the common transfer loop below for stochastic case
+                           actual_transfer_count = 0;
                        }
+                   }
 
-                       llama_token_data_array conf_array = {
-                           conf_candidates.data(),
-                           conf_candidates.size(),
-                           -1,
-                           false,
-                       };
-
-                       for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
-                           llama_sampler_apply(dist_sampler, &conf_array);
-                           int32_t selected_idx = conf_array.selected;
-                           int32_t mask_idx = selected_idx;
-                           int32_t pos = mask_positions[mask_idx];
-                           output_tokens[pos] = sampled_tokens[mask_idx];
-
-                           conf_candidates[selected_idx].p = 0.0f;
-                           conf_array.selected = -1;
+                   // Transfer tokens (deterministic case for both models)
+                   for (int32_t i = 0; i < actual_transfer_count; i++) {
+                       int32_t mask_idx = confidences[i].second;
+                       int32_t pos = mask_positions[mask_idx];
+                       llama_token transferred_token = sampled_tokens[mask_idx];
+                       output_tokens[pos] = transferred_token;
+
+                       // EOS early stop
+                       if (params.eos_early_stop && transferred_token == eos_token_id) {
+                           // Verify all tokens from n_input to pos are filled
+                           bool all_filled_before_eos = true;
+                           for (int32_t j = n_input; j < pos; j++) {
+                               if (output_tokens[j] == params.mask_token_id) {
+                                   all_filled_before_eos = false;
+                                   break;
+                               }
+                           }
+                           if (all_filled_before_eos) {
+                               LOG_DBG("\nEOS detected at position %d, all prior tokens filled. Terminating.\n", pos);
+                               n_generated = pos + 1 - n_input;
+                               all_tokens_filled = true;
+                               break;
+                           }
                        }
                    }
+                   if (params.eos_early_stop && all_tokens_filled) break; // Exit step loop
+               } else {
+                   LOG_DBG("DEBUG: Transfer count is 0!\n");
                }
            }
 
            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
+
+       // Check for EOS after block completes
+       if (params.eos_early_stop) {
+           for (int32_t i = n_input; i < block_end; i++) {
+               if (output_tokens[i] == eos_token_id) {
+                   // Check if all tokens before EOS are filled
+                   bool all_filled = true;
+                   for (int32_t j = n_input; j < i; j++) {
+                       if (output_tokens[j] == params.mask_token_id) {
+                           all_filled = false;
+                           break;
+                       }
+                   }
+                   if (all_filled) {
+                       LOG_DBG("\nEOS found at position %d after block %d. Terminating.\n", i, block_num);
+                       n_generated = i + 1 - n_input;
+                       all_tokens_filled = true;
+                       break;
+                   }
+               }
+           }
+       }
    }
 
    int64_t time_end = ggml_time_us();
@@ -567,7 +745,14 @@ int main(int argc, char ** argv) {
         llama_model_free(model);
         return 1;
     }
-
+
+    // Compute max_length early to ensure n_ubatch is large enough
+    int32_t max_length = params.n_predict > 0 ? params.n_predict : params.n_ctx;
+
+    LOG_DBG("DEBUG: params.n_ctx = %d\n", params.n_ctx);
+    LOG_DBG("DEBUG: params.n_predict = %d\n", params.n_predict);
+    LOG_DBG("DEBUG: max_length = %d\n", max_length);
+
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx = params.n_ctx;
     ctx_params.n_batch = params.n_batch;
@@ -611,7 +796,7 @@
     bool visual_mode = params.diffusion.visual_mode;
 
     int32_t n_generated = 0;
-    std::vector<llama_token> output_tokens(params.n_ubatch);
+    std::vector<llama_token> output_tokens(max_length);
 
     struct diffusion_params diff_params;
 
@@ -622,6 +807,15 @@
         diff_params.shift_logits = true;
     }
 
+    // EOS early stop parameter from CLI
+    diff_params.eos_early_stop = params.diffusion.eos_early_stop;
+
+    // Threshold parameter from CLI
+    diff_params.threshold = params.diffusion.threshold;
+
+    // Hybrid diffusion parameter from CLI
+    diff_params.hybrid_diffusion = params.diffusion.hybrid_diffusion;
+
     //Use either eps or block length, but not both
     GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 6f5a742e04a6a..7807eabd6859b 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -436,6 +436,7 @@ class MODEL_ARCH(IntEnum):
     SMALLTHINKER = auto()
     LLADA = auto()
     LLADA_MOE = auto()
+    LLADA2 = auto()
     SEED_OSS = auto()
     GROVEMOE = auto()
     APERTUS = auto()
@@ -807,6 +808,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.SMALLTHINKER: "smallthinker",
     MODEL_ARCH.LLADA: "llada",
     MODEL_ARCH.LLADA_MOE: "llada-moe",
+    MODEL_ARCH.LLADA2: "llada2",
     MODEL_ARCH.SEED_OSS: "seed_oss",
     MODEL_ARCH.GROVEMOE: "grovemoe",
     MODEL_ARCH.APERTUS: "apertus",
@@ -2952,6 +2954,29 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
     ],
+    MODEL_ARCH.LLADA2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.GROVEMOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f7a8c9841ecab..e9ad57d4a4992 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -85,6 +85,7 @@ add_library(llama
             models/lfm2.cpp
             models/llada-moe.cpp
             models/llada.cpp
+            models/llada2.cpp
             models/llama-iswa.cpp
             models/llama.cpp
             models/mamba.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 7ef87acf1b35d..d9b4d86e61051 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -103,6 +103,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },
     { LLM_ARCH_LLADA, "llada" },
     { LLM_ARCH_LLADA_MOE, "llada-moe" },
+    { LLM_ARCH_LLADA2, "llada2" },
     { LLM_ARCH_SEED_OSS, "seed_oss" },
     { LLM_ARCH_GROVEMOE, "grovemoe" },
     { LLM_ARCH_APERTUS, "apertus" },
@@ -2070,6 +2071,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
         },
     },
+    {
+        LLM_ARCH_LLADA2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_DOTS1,
         {
@@ -2755,6 +2782,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
+        case LLM_ARCH_LLADA2:
         case LLM_ARCH_RND1:
             return true;
         default:
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 9ad3157bf67c8..9f8a48c9bb190 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -107,6 +107,7 @@ enum llm_arch {
     LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_LLADA,
     LLM_ARCH_LLADA_MOE,
+    LLM_ARCH_LLADA2,
    LLM_ARCH_SEED_OSS,
    LLM_ARCH_GROVEMOE,
    LLM_ARCH_APERTUS,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a042ea9632ce4..aa2c7b1923c7a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1984,6 +1984,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_LLADA2:
         case LLM_ARCH_BAILINGMOE2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5649,6 +5650,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
                     }
                 } break;
+            case LLM_ARCH_LLADA2:
             case LLM_ARCH_BAILINGMOE2:
                 {
                     const int64_t n_ff_exp = hparams.n_ff_exp;
@@ -5660,8 +5662,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
 
-                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
-                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
+                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2/llada2");
+                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2/llada2");
 
                     for (int i = 0; i < n_layer; ++i) {
                         int flags = 0;
@@ -6755,7 +6757,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
     }
 
-    if (arch == LLM_ARCH_BAILINGMOE2) {
+    if (arch == LLM_ARCH_BAILINGMOE2 || arch == LLM_ARCH_LLADA2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
@@ -7349,6 +7351,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique(*this, params);
             } break;
+        case LLM_ARCH_LLADA2:
+            {
+                llm = std::make_unique<llm_build_llada2>(*this, params);
+            } break;
         case LLM_ARCH_SEED_OSS:
             {
                 llm = std::make_unique(*this, params);
             } break;
@@ -7652,6 +7658,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_COGVLM:
         case LLM_ARCH_PANGU_EMBED:
         case LLM_ARCH_AFMOE:
+        case LLM_ARCH_LLADA2:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
diff --git a/src/models/llada2.cpp b/src/models/llada2.cpp
new file mode 100644
index 0000000000000..ec2195cebd9e8
--- /dev/null
+++ b/src/models/llada2.cpp
@@ -0,0 +1,132 @@
+#include "models.h"
+
+llm_build_llada2::llm_build_llada2(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self_attention
+        {
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float),
+                                              cur->nb[1], 0 * sizeof(float) * (n_embd));
+            ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                                              cur->nb[1], 1 * sizeof(float) * (n_embd));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float),
+                                              cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                 ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                 ext_factor, attn_factor, beta_fast, beta_slow);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                             model.layers[il].wo, model.layers[il].bo,
+                             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
+        cb(sa_out, "sa_out", il);
+
+        // MoE branch
+        cur = build_norm(sa_out, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+            cur = build_ffn(cur,
+                            model.layers[il].ffn_up, NULL, NULL,
+                            model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL,
+                            NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                                                  model.layers[il].ffn_gate_inp,
+                                                  model.layers[il].ffn_up_exps,
+                                                  model.layers[il].ffn_gate_exps,
+                                                  model.layers[il].ffn_down_exps,
+                                                  model.layers[il].ffn_exp_probs_b,
+                                                  n_expert, n_expert_used,
+                                                  LLM_FFN_SILU, hparams.expert_weights_norm,
+                                                  true, hparams.expert_weights_scale,
+                                                  (llama_expert_gating_func_type) hparams.expert_gating_func,
+                                                  il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            {
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur,
+                              model.layers[il].ffn_up_shexp, NULL, NULL,
+                              model.layers[il].ffn_gate_shexp, NULL, NULL,
+                              model.layers[il].ffn_down_shexp, NULL, NULL,
+                              NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+
+        cur = ggml_add(ctx0, cur, sa_out);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/src/models/models.h b/src/models/models.h
index 5f019c59be897..06d1371659ad7 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -297,6 +297,10 @@ struct llm_build_llada : public llm_graph_context {
     llm_build_llada(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_llada2 : public llm_graph_context {
+    llm_build_llada2(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_llada_moe : public llm_graph_context {
     llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
 };