35 changes: 25 additions & 10 deletions convert_hf_to_gguf.py
@@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel):
     norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
-
-        self.gguf_writer.add_add_space_prefix(False)
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+            self.gguf_writer.add_add_space_prefix(False)
+        else:
+            self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
         hparams = self.hparams
@@ -5845,13 +5847,24 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
         # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        if (final_logit_softcap := hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+        if hparams.get("sliding_window_pattern") != 1:
+            self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
-            assert hparams["rope_scaling"]["rope_type"] == "linear"
-            # important: this rope_scaling is only applied for global layers, and not used by 1B model
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+            rope_scaling = hparams["rope_scaling"]
+            if rope_scaling["rope_type"] == "linear":
+                # important: this rope_scaling is only applied for global layers, and not used by 1B model
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            elif rope_scaling["rope_type"] == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+                self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
@@ -5865,8 +5878,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 
         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
-            vocab = self._create_vocab_sentencepiece()
-            tokens = vocab[0]
+            if (self.dir_model / "tokenizer.model").is_file():
+                tokens = self._create_vocab_sentencepiece()[0]
+            else:
+                tokens = self.get_vocab_base()[0]
             data_torch = data_torch[:len(tokens)]
 
         # ref code in Gemma3RMSNorm
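The converter changes above all key off properties of the checkpoint being converted: whether a SentencePiece tokenizer.model is shipped, whether the model uses sliding-window attention at all, and which rope_scaling variant (linear or YARN) the config declares. A minimal sketch of those decisions; the helper name and return format are hypothetical, but the conditions mirror the diff above:

    # Sketch only (not part of the PR): summarize which branches the updated
    # Gemma3 converter would take for a given checkpoint directory and hparams.
    from pathlib import Path

    def describe_gemma3_conversion(dir_model: Path, hparams: dict) -> dict:
        decisions = {}
        # vocab: SentencePiece if tokenizer.model is shipped, otherwise the GPT-2/BPE path
        decisions["vocab"] = "sentencepiece" if (dir_model / "tokenizer.model").is_file() else "gpt2"
        # sliding-window metadata is skipped when every layer is global (pattern == 1)
        decisions["write_sliding_window"] = hparams.get("sliding_window_pattern") != 1
        # final logit soft-capping is only written when the checkpoint defines it
        decisions["final_logit_softcapping"] = hparams.get("final_logit_softcapping")
        # rope scaling: linear (global layers only) or YARN with its extra parameters
        rope_scaling = hparams.get("rope_scaling")
        decisions["rope_scaling"] = rope_scaling["rope_type"] if rope_scaling else None
        return decisions

    # e.g. a checkpoint with YARN scaling and no tokenizer.model in the current directory:
    print(describe_gemma3_conversion(Path("."), {
        "sliding_window_pattern": 1,
        "rope_scaling": {"rope_type": "yarn", "factor": 4.0},
    }))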
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
@@ -67,7 +67,7 @@ add_library(llama
             models/gemma-embedding.cpp
             models/gemma.cpp
             models/gemma2-iswa.cpp
-            models/gemma3-iswa.cpp
+            models/gemma3.cpp
             models/gemma3n-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp
23 changes: 17 additions & 6 deletions src/llama-model.cpp
@@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            } break;
        case LLM_ARCH_GEMMA3:
            {
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.set_swa_pattern(6);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(6);
 
-                hparams.rope_freq_base_train_swa = 10000.0f;
-                hparams.rope_freq_scale_train_swa = 1.0f;
+                    hparams.rope_freq_base_train_swa = 10000.0f;
+                    hparams.rope_freq_scale_train_swa = 1.0f;
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
 
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                hparams.f_final_logit_softcapping = 0.0f;
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                switch (hparams.n_layer) {
                    case 18: type = LLM_TYPE_270M; break;
                    case 26: type = LLM_TYPE_1B; break;
+                    case 32: type = LLM_TYPE_8B; break; // Rnj-1
                    case 34: type = LLM_TYPE_4B; break;
                    case 48: type = LLM_TYPE_12B; break;
                    case 62: type = LLM_TYPE_27B; break;
@@ -7300,7 +7307,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_GEMMA3:
            {
-                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+                }
            } break;
        case LLM_ARCH_GEMMA3N:
            {
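The hparams logic above makes the interleaved attention layout conditional on sliding-window metadata being present in the GGUF: with it, Gemma 3 keeps its local/global interleave and the separate 10000.0 RoPE base for local layers; without it, every layer is global and the non-iswa graph builder is selected. A rough sketch of the resulting per-layer layout, assuming set_swa_pattern(6) marks every sixth layer as global (full attention) and the rest as sliding-window:

    # Sketch (not from the PR): per-layer attention kinds implied by the SWA pattern.
    # The "every sixth layer is global" indexing is an assumption about set_swa_pattern(6).
    def gemma3_layer_kinds(n_layer: int, has_swa: bool, pattern: int = 6) -> list[str]:
        if not has_swa:
            # no sliding-window metadata in the GGUF: all layers use full attention
            return ["global"] * n_layer
        return ["global" if (il % pattern) == pattern - 1 else "swa" for il in range(n_layer)]

    # e.g. the 26-layer 1B model with SWA enabled:
    print(gemma3_layer_kinds(26, has_swa=True))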
35 changes: 30 additions & 5 deletions src/models/gemma3-iswa.cpp → src/models/gemma3.cpp
@@ -1,6 +1,7 @@
 #include "models.h"
 
-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;
 
     ggml_tensor * cur;
@@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     ggml_tensor * inp_pos = build_inp_pos();
 
     // TODO: is causal == true correct? might need some changes
-    auto * inp_attn = build_attn_inp_kv_iswa();
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
-        const float freq_base_l = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        float freq_base_l = 0.0f;
+        float freq_scale_l = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l = freq_base;
+            freq_scale_l = freq_scale;
+        }
 
         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         cur = build_norm(cur,
                 model.layers[il].ffn_post_norm, NULL,
                 LLM_NORM_RMS, -1);
-        cb(cur, "ffn_post_norm", -1);
+        cb(cur, "ffn_post_norm", il);
 
         cur = ggml_add(ctx0, cur, sa_out);
 
@@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     // lm_head
     cur = build_lora_mm(model.output, cur);
 
+    if (hparams.f_final_logit_softcapping) {
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+    }
+
     cb(cur, "result_output", -1);
     res->t_logits = cur;
 
     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;
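The block added before result_output implements final logit soft-capping: logits are squashed into (-cap, cap) as cap * tanh(logits / cap), and since f_final_logit_softcapping now defaults to 0.0f, the branch stays disabled for checkpoints that do not define it. A small numeric sketch of the same math:

    # Sketch (not from the PR): the scale/tanh/scale sequence added above,
    # written out in NumPy. cap == 0.0 means "disabled", matching the hparams default.
    import numpy as np

    def soft_cap(logits: np.ndarray, cap: float) -> np.ndarray:
        if not cap:
            return logits
        return cap * np.tanh(logits / cap)

    print(soft_cap(np.array([-100.0, -5.0, 0.0, 5.0, 100.0]), cap=30.0))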
5 changes: 3 additions & 2 deletions src/models/models.h
@@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
     llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_gemma3_iswa : public llm_graph_context {
-    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
 };
 
 struct llm_build_gemma3n_iswa : public llm_graph_context {