From e1b7d46e8c3364cad6d3d17059e9748911726881 Mon Sep 17 00:00:00 2001 From: Philip Monk <169196560+philip-essential@users.noreply.github.com> Date: Fri, 14 Nov 2025 03:09:10 +0000 Subject: [PATCH 1/3] add support for rnj1 --- convert_hf_to_gguf.py | 48 ++++++++++++++ gguf-py/gguf/constants.py | 20 ++++++ src/CMakeLists.txt | 1 + src/llama-arch.cpp | 22 +++++++ src/llama-arch.h | 1 + src/llama-model.cpp | 17 +++++ src/models/models.h | 4 ++ src/models/rnj1.cpp | 132 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 245 insertions(+) create mode 100644 src/models/rnj1.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4590b239212..ed803a77634 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6098,6 +6098,54 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +# Note: HF checkpoints use architecture "Gemma3ForCausalLM" since a separate +# model implementation is not needed for Transformers. To convert to GGUF, +# change the architecture in config.json to "Rnj1ForCausalLM". +@ModelBase.register("Rnj1ForCausalLM") +class Rnj1Model(TextModel): + model_arch = gguf.MODEL_ARCH.RNJ1 + norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + # some default values are not specified in the hparams + self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 32768)) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) + self.gguf_writer.add_key_length(hparams.get("head_dim", 128)) + self.gguf_writer.add_value_length(hparams.get("head_dim", 128)) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10_000.0)) # for global layers + assert hparams.get("attn_logit_softcapping") is None + self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if "language_model." 
in name: + name = name.replace("language_model.", "") + + # ref code in Gemma3RMSNorm + # output = output * (1.0 + self.weight.float()) + if name.endswith("norm.weight"): + data_torch = data_torch + self.norm_shift + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register("Starcoder2ForCausalLM") class StarCoder2Model(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591b..6dc874999e1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum): GEMMA3 = auto() GEMMA3N = auto() GEMMA_EMBEDDING = auto() + RNJ1 = auto() STARCODER2 = auto() RWKV6 = auto() RWKV6QWEN2 = auto() @@ -758,6 +759,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding", + MODEL_ARCH.RNJ1: "rnj1", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", @@ -1930,6 +1932,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_PRE_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.RNJ1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], MODEL_ARCH.STARCODER2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fbd538109ba..2e103ecba91 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -117,6 +117,7 @@ add_library(llama models/qwen3next.cpp models/refact.cpp models/rnd1.cpp + models/rnj1.cpp models/rwkv6-base.cpp models/rwkv6.cpp models/rwkv6qwen2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 64ad1b77690..7c3107a2a07 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -50,6 +50,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, + { LLM_ARCH_RNJ1, "rnj1" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, @@ -1225,6 +1226,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_RNJ1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { diff --git a/src/llama-arch.h b/src/llama-arch.h index e113180024d..7d3c6b5f5c6 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -54,6 +54,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, 
LLM_ARCH_GEMMA_EMBEDDING,
+ LLM_ARCH_RNJ1,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c3675dbdc41..e13b7637316 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1335,6 +1335,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
+ case LLM_ARCH_RNJ1:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 32: type = LLM_TYPE_8B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+
+ hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3899,6 +3910,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_RNJ1:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA_EMBEDDING:
{
@@ -7310,6 +7322,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique(*this, params);
} break;
+ case LLM_ARCH_RNJ1:
+ {
+ llm = std::make_unique<llm_build_rnj1>(*this, params);
+ } break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique(*this, params);
} break;
@@ -7767,6 +7783,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA_EMBEDDING:
+ case LLM_ARCH_RNJ1:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/models/models.h b/src/models/models.h
index d93601ad06a..4abb8fccf99 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -213,6 +213,10 @@ struct llm_build_gemma : public llm_graph_context {
llm_build_gemma(const llama_model & model, const llm_graph_params & params);
};
+struct llm_build_rnj1 : public llm_graph_context {
+ llm_build_rnj1(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_glm4 : public llm_graph_context {
llm_build_glm4(const llama_model & model, const llm_graph_params & params);
};
diff --git a/src/models/rnj1.cpp b/src/models/rnj1.cpp
new file mode 100644
index 00000000000..3bf5c647466
--- /dev/null
+++ b/src/models/rnj1.cpp
@@ -0,0 +1,132 @@
+#include "models.h"
+#include "../llama-impl.h"
+
+// Based off Gemma3 implementation. The main differences are:
+// - all layers are global
+// - use YaRN for long-context
+
+llm_build_rnj1::llm_build_rnj1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_k;
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
+ if (ubatch.token) {
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+ }
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_kv();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ // norm
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+ }
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = build_norm(sa_out,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+ }
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "ffn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+}
From 569a0092b98e73bd222490441234144e52afd2ef Mon Sep 17 00:00:00 2001
From: Philip Monk <169196560+philip-essential@users.noreply.github.com>
Date: Mon, 8 Dec 2025 20:57:09 +0000
Subject: [PATCH 2/3] refactor gemma3 to support rnj-1
---
convert_hf_to_gguf.py | 79 ++++--------
gguf-py/gguf/constants.py | 20 ----
src/CMakeLists.txt | 3 +-
src/llama-arch.cpp
| 22 ---- src/llama-arch.h | 1 - src/llama-model.cpp | 49 ++++---- src/models/{gemma3-iswa.cpp => gemma3.cpp} | 35 +++++- src/models/models.h | 9 +- src/models/rnj1.cpp | 132 --------------------- 9 files changed, 78 insertions(+), 272 deletions(-) rename src/models/{gemma3-iswa.cpp => gemma3.cpp} (79%) delete mode 100644 src/models/rnj1.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ed803a77634..48663a77f1e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel): norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value def set_vocab(self): - self._set_vocab_sentencepiece() - - self.gguf_writer.add_add_space_prefix(False) + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + self._set_vocab_sentencepiece() + self.gguf_writer.add_add_space_prefix(False) def set_gguf_parameters(self): hparams = self.hparams @@ -5845,13 +5847,22 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers # attn_logit_softcapping is removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None + self.gguf_writer.add_final_logit_softcapping(self.hparams.get("final_logit_softcapping", 0)) self.gguf_writer.add_sliding_window(hparams["sliding_window"]) self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) if hparams.get("rope_scaling") is not None: - assert hparams["rope_scaling"]["rope_type"] == "linear" - # important: this rope_scaling is only applied for global layers, and not used by 1B model - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams["rope_scaling"] + if rope_scaling["rope_type"] == "linear": + # important: this rope_scaling is only applied for global layers, and not used by 1B model + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + elif rope_scaling["rope_type"] == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused @@ -5865,8 +5876,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: - vocab = self._create_vocab_sentencepiece() - tokens = vocab[0] + if (self.dir_model / "tokenizer.json").is_file(): + tokens = self.get_vocab_base()[0] + else: + tokens = self._create_vocab_sentencepiece()[0] data_torch = data_torch[:len(tokens)] # ref code in Gemma3RMSNorm @@ -6098,54 +6111,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) -# Note: HF checkpoints use architecture "Gemma3ForCausalLM" since a separate -# model implementation is not needed for Transformers. 
To convert to GGUF, -# change the architecture in config.json to "Rnj1ForCausalLM". -@ModelBase.register("Rnj1ForCausalLM") -class Rnj1Model(TextModel): - model_arch = gguf.MODEL_ARCH.RNJ1 - norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value - - def set_gguf_parameters(self): - hparams = self.hparams - block_count = hparams["num_hidden_layers"] - - # some default values are not specified in the hparams - self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 32768)) - self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) - self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) - self.gguf_writer.add_key_length(hparams.get("head_dim", 128)) - self.gguf_writer.add_value_length(hparams.get("head_dim", 128)) - self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10_000.0)) # for global layers - assert hparams.get("attn_logit_softcapping") is None - self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) - rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) - self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"]) - self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"]) - self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"]) - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - if "language_model." 
in name: - name = name.replace("language_model.", "") - - # ref code in Gemma3RMSNorm - # output = output * (1.0 + self.weight.float()) - if name.endswith("norm.weight"): - data_torch = data_torch + self.norm_shift - - return [(self.map_tensor_name(name), data_torch)] - - @ModelBase.register("Starcoder2ForCausalLM") class StarCoder2Model(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6dc874999e1..2b8489c591b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -385,7 +385,6 @@ class MODEL_ARCH(IntEnum): GEMMA3 = auto() GEMMA3N = auto() GEMMA_EMBEDDING = auto() - RNJ1 = auto() STARCODER2 = auto() RWKV6 = auto() RWKV6QWEN2 = auto() @@ -759,7 +758,6 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding", - MODEL_ARCH.RNJ1: "rnj1", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", @@ -1932,24 +1930,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_PRE_NORM, MODEL_TENSOR.FFN_POST_NORM, ], - MODEL_ARCH.RNJ1: [ - MODEL_TENSOR.TOKEN_EMBD, - MODEL_TENSOR.OUTPUT, - MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.ATTN_Q, - MODEL_TENSOR.ATTN_Q_NORM, - MODEL_TENSOR.ATTN_K, - MODEL_TENSOR.ATTN_K_NORM, - MODEL_TENSOR.ATTN_V, - MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_GATE, - MODEL_TENSOR.FFN_DOWN, - MODEL_TENSOR.FFN_UP, - MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.ATTN_POST_NORM, - MODEL_TENSOR.FFN_PRE_NORM, - MODEL_TENSOR.FFN_POST_NORM, - ], MODEL_ARCH.STARCODER2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2e103ecba91..84a0c2934e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,7 +67,7 @@ add_library(llama models/gemma-embedding.cpp models/gemma.cpp models/gemma2-iswa.cpp - models/gemma3-iswa.cpp + models/gemma3.cpp models/gemma3n-iswa.cpp models/glm4-moe.cpp models/glm4.cpp @@ -117,7 +117,6 @@ add_library(llama models/qwen3next.cpp models/refact.cpp models/rnd1.cpp - models/rnj1.cpp models/rwkv6-base.cpp models/rwkv6.cpp models/rwkv6qwen2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7c3107a2a07..64ad1b77690 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -50,7 +50,6 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, - { LLM_ARCH_RNJ1, "rnj1" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, @@ -1226,27 +1225,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, - { - LLM_ARCH_RNJ1, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - }, - }, { LLM_ARCH_STARCODER2, { diff --git 
a/src/llama-arch.h b/src/llama-arch.h index 7d3c6b5f5c6..e113180024d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -54,7 +54,6 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA_EMBEDDING, - LLM_ARCH_RNJ1, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e13b7637316..f5a39d2fdcb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1264,24 +1264,32 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); - - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; - - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; case 62: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } + if (type == LLM_TYPE_8B) { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 
1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
@@ -1335,17 +1343,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
- case LLM_ARCH_RNJ1:
- {
- ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
- switch (hparams.n_layer) {
- case 32: type = LLM_TYPE_8B; break;
- default: type = LLM_TYPE_UNKNOWN;
- }
-
- hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
- } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3910,7 +3907,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
- case LLM_ARCH_RNJ1:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA_EMBEDDING:
{
@@ -7312,7 +7308,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_GEMMA3:
{
- llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
+ if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+ llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+ } else {
+ llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+ }
} break;
case LLM_ARCH_GEMMA3N:
{
@@ -7322,10 +7322,6 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique(*this, params);
} break;
- case LLM_ARCH_RNJ1:
- {
- llm = std::make_unique<llm_build_rnj1>(*this, params);
- } break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique(*this, params);
@@ -7783,7 +7779,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA_EMBEDDING:
- case LLM_ARCH_RNJ1:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/models/gemma3-iswa.cpp b/src/models/gemma3.cpp
similarity index 79%
rename from src/models/gemma3-iswa.cpp
rename to src/models/gemma3.cpp
index 839ff6d3d93..ba68c5dc1a8 100644
--- a/src/models/gemma3-iswa.cpp
+++ b/src/models/gemma3.cpp
@@ -1,6 +1,7 @@
#include "models.h"
-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
ggml_tensor * cur;
@@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
ggml_tensor * inp_pos = build_inp_pos();
// TODO: is causal == true correct? might need some changes
- auto * inp_attn = build_attn_inp_kv_iswa();
+ using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+ inp_attn_type * inp_attn = nullptr;
+
+ if constexpr (iswa) {
+ inp_attn = build_attn_inp_kv_iswa();
+ } else {
+ inp_attn = build_attn_inp_kv();
+ }
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
- const float freq_base_l = model.get_rope_freq_base (cparams, il);
- const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+ float freq_base_l = 0.0f;
+ float freq_scale_l = 0.0f;
+
+ if constexpr (iswa) {
+ freq_base_l = model.get_rope_freq_base (cparams, il);
+ freq_scale_l = model.get_rope_freq_scale(cparams, il);
+ } else {
+ freq_base_l = freq_base;
+ freq_scale_l = freq_scale;
+ }
// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
- cb(cur, "ffn_post_norm", -1);
+ cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, sa_out);
@@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
// lm_head
cur = build_lora_mm(model.output, cur);
+ if constexpr (!iswa) {
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ cur = ggml_tanh(ctx0, cur);
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+ }
+
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
+
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;
diff --git a/src/models/models.h b/src/models/models.h
index 4abb8fccf99..6494f545018 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_gemma3_iswa : public llm_graph_context {
- llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+ llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_gemma3n_iswa : public llm_graph_context {
@@ -213,10 +214,6 @@ struct llm_build_gemma : public llm_graph_context {
llm_build_gemma(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_rnj1 : public llm_graph_context {
- llm_build_rnj1(const llama_model & model, const llm_graph_params & params);
-};
-
struct llm_build_glm4 : public llm_graph_context {
llm_build_glm4(const llama_model & model, const llm_graph_params & params);
};
diff --git a/src/models/rnj1.cpp b/src/models/rnj1.cpp
deleted file mode 100644
index 3bf5c647466..00000000000
--- a/src/models/rnj1.cpp
+++ /dev/null
@@ -1,132 +0,0 @@
-#include "models.h"
-#include "../llama-impl.h"
-
-// Based off Gemma3 implementation. The main differences are:
-// - all layers are global
-// - use YaRN for long-context
-
-llm_build_rnj1::llm_build_rnj1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
- const int64_t n_embd_head = hparams.n_embd_head_k;
-
- ggml_tensor * cur;
- ggml_tensor * inpL;
-
- inpL = build_inp_embd(model.tok_embd);
-
- // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
- if (ubatch.token) {
- inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
- cb(inpL, "inp_scaled", -1);
- }
- // inp_pos - contains the positions
- ggml_tensor * inp_pos = build_inp_pos();
-
- auto * inp_attn = build_attn_inp_kv();
-
- ggml_tensor * inp_out_ids = build_inp_out_ids();
-
- for (int il = 0; il < n_layer; ++il) {
- // norm
- cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
- cb(cur, "attn_norm", il);
-
- // self-attention
- {
- // compute Q and K and RoPE them
- ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
- cb(Qcur, "Qcur", il);
-
- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
- cb(Kcur, "Kcur", il);
-
- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
- cb(Vcur, "Vcur", il);
-
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
- Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
- cb(Qcur, "Qcur_normed", il);
-
- Qcur = ggml_rope_ext(
- ctx0, Qcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow);
-
- Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
- cb(Kcur, "Kcur_normed", il);
-
- Kcur = ggml_rope_ext(
- ctx0, Kcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow);
-
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
-
- // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
- Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
-
- cur = build_attn(inp_attn,
- model.layers[il].wo, NULL,
- Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
- }
- if (il == n_layer - 1 && inp_out_ids) {
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
- }
- cur = build_norm(cur,
- model.layers[il].attn_post_norm, NULL,
- LLM_NORM_RMS, il);
- cb(cur, "attn_post_norm", il);
-
- ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
- cb(sa_out, "sa_out", il);
-
- cur = build_norm(sa_out,
- model.layers[il].ffn_norm, NULL,
- LLM_NORM_RMS, il);
- cb(cur, "ffn_norm", il);
-
- // feed-forward network
- {
- cur = build_ffn(cur,
- model.layers[il].ffn_up, NULL, NULL,
- model.layers[il].ffn_gate, NULL, NULL,
- model.layers[il].ffn_down, NULL, NULL,
- NULL,
- LLM_FFN_GELU, LLM_FFN_PAR, il);
- cb(cur, "ffn_out", il);
- }
- cur = build_norm(cur,
- model.layers[il].ffn_post_norm, NULL,
- LLM_NORM_RMS, -1);
- cb(cur, "ffn_post_norm", il);
-
- cur = ggml_add(ctx0, cur, sa_out);
-
- cur = build_cvec(cur, il);
- cb(cur, "l_out", il);
-
- // input for next layer
- inpL = cur;
- }
- cur = inpL;
-
- cur = build_norm(cur,
- model.output_norm, NULL,
- LLM_NORM_RMS, -1);
-
- cb(cur, "result_norm", -1);
- res->t_embd = cur;
-
- // lm_head
- cur = build_lora_mm(model.output, cur);
-
- cb(cur, "result_output", -1);
- res->t_logits = cur;
-
- ggml_build_forward_expand(gf, cur);
-}
From 14c977aadf10070f9f4ee1fd2aa8d92efbb1fecf Mon Sep 17 00:00:00 2001
From: Philip Monk <169196560+philip-essential@users.noreply.github.com>
Date: Tue, 9 Dec 2025 02:29:24 +0000
Subject: [PATCH 3/3] address review comments
---
convert_hf_to_gguf.py | 20 +++++++++++---------
src/llama-model.cpp | 31 +++++++++++++++----------------
src/models/gemma3.cpp |
2 +- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 48663a77f1e..6db46cc550d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5825,11 +5825,11 @@ class Gemma3Model(TextModel): norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value def set_vocab(self): - if (self.dir_model / "tokenizer.json").is_file(): - self._set_vocab_gpt2() - else: + if (self.dir_model / "tokenizer.model").is_file(): self._set_vocab_sentencepiece() self.gguf_writer.add_add_space_prefix(False) + else: + self._set_vocab_gpt2() def set_gguf_parameters(self): hparams = self.hparams @@ -5847,11 +5847,13 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers # attn_logit_softcapping is removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None - self.gguf_writer.add_final_logit_softcapping(self.hparams.get("final_logit_softcapping", 0)) - self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + if (final_logit_softcap := hparams.get("final_logit_softcapping")): + self.gguf_writer.add_final_logit_softcapping(final_logit_softcap) + if hparams.get("sliding_window_pattern") != 1: + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) if hparams.get("rope_scaling") is not None: - rope_scaling = self.hparams["rope_scaling"] + rope_scaling = hparams["rope_scaling"] if rope_scaling["rope_type"] == "linear": # important: this rope_scaling is only applied for global layers, and not used by 1B model self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) @@ -5876,10 +5878,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: - if (self.dir_model / "tokenizer.json").is_file(): - tokens = self.get_vocab_base()[0] - else: + if (self.dir_model / "tokenizer.model").is_file(): tokens = self._create_vocab_sentencepiece()[0] + else: + tokens = self.get_vocab_base()[0] data_torch = data_torch[:len(tokens)] # ref code in Gemma3RMSNorm diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f5a39d2fdcb..6b76b2f7423 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1264,6 +1264,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; @@ -1274,22 +1289,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } - if (type == LLM_TYPE_8B) { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); - - hparams.rope_freq_base_train_swa = 10000.0f; - 
hparams.rope_freq_scale_train_swa = 1.0f; - - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - } - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) diff --git a/src/models/gemma3.cpp b/src/models/gemma3.cpp index ba68c5dc1a8..ae60ef4790c 100644 --- a/src/models/gemma3.cpp +++ b/src/models/gemma3.cpp @@ -140,7 +140,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr // lm_head cur = build_lora_mm(model.output, cur); - if constexpr (!iswa) { + if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); cur = ggml_tanh(ctx0, cur); cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);