69 changes: 69 additions & 0 deletions convert_hf_to_gguf.py
@@ -8836,6 +8836,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.LFM2MOE

def set_gguf_parameters(self):
# set num_key_value_heads only for attention layers
self.hparams["num_key_value_heads"] = [
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
for layer_type in self.hparams["layer_types"]
]

super().set_gguf_parameters()

self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])

# cache for experts weights for merging
_experts_cache: dict[int, dict[str, Tensor]] = {}

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# conv op requires 2d tensor
if 'conv.conv' in name:
data_torch = data_torch.squeeze(1)

if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")

# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
assert bid is not None

expert_cache = self._experts_cache.setdefault(bid, {})
expert_cache[name] = data_torch
expert_weights = ["w1", "w2", "w3"]

# not enough expert weights to merge
if len(expert_cache) < n_experts * len(expert_weights):
return []

tensors: list[tuple[str, Tensor]] = []
for w_name in expert_weights:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
datas.append(expert_cache[ename])
del expert_cache[ename]

data_torch = torch.stack(datas, dim=0)
merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

del self._experts_cache[bid]
return tensors

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()
assert not self._experts_cache


@ModelBase.register("Lfm2VlForConditionalGeneration")
class LFM2VLModel(MmprojModel):
def __init__(self, *args, **kwargs):
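The expert-merge logic in `LFM2MoeModel.modify_tensors` above buffers per-expert `w1`/`w2`/`w3` matrices in `_experts_cache` and only emits output once every expert weight for the layer has arrived, stacking them into one 3-D tensor per projection. A minimal, self-contained sketch of that stacking step (illustrative shapes only; the real converter also maps the merged name through `map_tensor_name`):

```python
# Sketch of the expert-merge idea above, not the converter itself.
# Shapes are illustrative; LFM2-MoE's real dimensions come from the HF config.
import torch

n_experts, n_embd, n_ff_exp = 4, 8, 16
cache = {
    f"model.layers.0.feed_forward.experts.{x}.{w}.weight":
        torch.randn(n_ff_exp, n_embd) if w in ("w1", "w3") else torch.randn(n_embd, n_ff_exp)
    for x in range(n_experts)
    for w in ("w1", "w2", "w3")
}

merged = {}
for w in ("w1", "w2", "w3"):
    # stack experts 0..N-1 along a new leading dimension -> one 3-D tensor per projection
    datas = [cache[f"model.layers.0.feed_forward.experts.{x}.{w}.weight"] for x in range(n_experts)]
    merged[f"layers.0.feed_forward.experts.{w}.weight"] = torch.stack(datas, dim=0)

print({k: tuple(v.shape) for k, v in merged.items()})
# w1/w3 -> (4, 16, 8), w2 -> (4, 8, 16)
```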
25 changes: 25 additions & 0 deletions gguf-py/gguf/constants.py
@@ -407,6 +407,7 @@ class MODEL_ARCH(IntEnum):
SMOLLM3 = auto()
GPT_OSS = auto()
LFM2 = auto()
LFM2MOE = auto()
DREAM = auto()
SMALLTHINKER = auto()
LLADA = auto()
@@ -749,6 +750,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.GPT_OSS: "gpt-oss",
MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.LFM2MOE: "lfm2moe",
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
@@ -2698,6 +2700,29 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.OUTPUT,
],
MODEL_ARCH.LFM2MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.SHORTCONV_CONV,
MODEL_TENSOR.SHORTCONV_INPROJ,
MODEL_TENSOR.SHORTCONV_OUTPROJ,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_NORM, # operator_norm
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.SMALLTHINKER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
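The `constants.py` entries above register the architecture name and the set of tensor kinds an `lfm2moe` GGUF may contain; gguf-py uses that per-arch list as a whitelist when writing models. A quick check of what the registration exposes (assuming the gguf-py package from this branch is importable):

```python
# Assumes the gguf-py package from this branch is on the path; not part of the diff.
import gguf

print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LFM2MOE])                                  # "lfm2moe"
print(gguf.MODEL_TENSOR.FFN_GATE_EXP in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.LFM2MOE])   # True
```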
2 changes: 2 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -358,6 +358,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.router", # openai-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
"model.layers.{bid}.feed_forward.gate", # lfm2moe
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -367,6 +368,7 @@
MODEL_TENSOR.FFN_EXP_PROBS_B: (
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
),

# Feed-forward up
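The two `tensor_mapping.py` additions route LFM2-MoE's router gate and expert bias onto existing GGUF tensor kinds (`ffn_gate_inp` and `exp_probs_b`). A rough sketch of the `{bid}` substitution that `TensorNameMap` effectively performs; the `map_name` helper below is hypothetical and only for illustration, the real class also handles `.weight`/`.bias` suffixes and many more source names:

```python
# Minimal sketch of HF-name -> GGUF-name mapping with layer-index substitution.
hf_to_gguf = {
    "model.layers.{bid}.feed_forward.gate":        "blk.{bid}.ffn_gate_inp",
    "model.layers.{bid}.feed_forward.expert_bias": "blk.{bid}.exp_probs_b",
}

def map_name(hf_name: str, n_layers: int = 64) -> str | None:
    for bid in range(n_layers):
        for src, dst in hf_to_gguf.items():
            if hf_name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

print(map_name("model.layers.3.feed_forward.gate"))         # blk.3.ffn_gate_inp
print(map_name("model.layers.3.feed_forward.expert_bias"))  # blk.3.exp_probs_b
```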
28 changes: 28 additions & 0 deletions src/llama-arch.cpp
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_LFM2MOE, "lfm2moe" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
@@ -2104,6 +2105,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_OUTPUT, "output" },
}
},
{
LLM_ARCH_LFM2MOE,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_SMALLTHINKER,
{
@@ -2493,6 +2520,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
return true;
default:
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -97,6 +97,7 @@ enum llm_arch {
LLM_ARCH_SMOLLM3,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
LLM_ARCH_LFM2MOE,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
81 changes: 66 additions & 15 deletions src/llama-model.cpp
@@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_106B_A12B: return "106B.A12B";
@@ -1995,6 +1996,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
hparams.n_layer_dense_lead = hparams.n_layer;
switch (hparams.n_ff()) {
case 4608: type = LLM_TYPE_350M; break;
case 6912: type = LLM_TYPE_700M; break;
case 8192: type = LLM_TYPE_1_2B; break;
case 10752: type = LLM_TYPE_2_6B; break;
default: type = LLM_TYPE_UNKNOWN;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_LFM2MOE:
{
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);

for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}

type = LLM_TYPE_8B_A1B;
} break;
case LLM_ARCH_SMALLTHINKER:
{
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
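Like the dense LFM2 path above it, the LFM2-MoE branch marks a layer as recurrent (shortconv) whenever its per-layer KV-head count is zero, which is exactly why the converter rewrites `num_key_value_heads` into a per-layer list. A small sketch of that round trip, using a hypothetical layer-type list:

```python
# Sketch of how a per-layer KV-head list distinguishes shortconv from attention
# layers (mirrors the converter's set_gguf_parameters and the loader above).
layer_types = ["conv", "conv", "full_attention", "conv", "full_attention"]  # hypothetical
n_kv_heads = 8

per_layer_kv = [n_kv_heads if t == "full_attention" else 0 for t in layer_types]
recurrent_layer = [kv == 0 for kv in per_layer_kv]

print(per_layer_kv)     # [0, 0, 8, 0, 8]
print(recurrent_layer)  # [True, True, False, True, False]
```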
@@ -5814,6 +5830,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@@ -5825,11 +5842,23 @@

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// ffn is same for transformer and conv layers

const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead);

// ffn/moe is same for transformer and conv layers
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
if (is_moe_layer) {
GGML_ASSERT(n_expert && n_expert_used);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
} else { // dense
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}

// for operator_norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
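In `load_tensors`, layers before `n_layer_dense_lead` keep the dense gate/down/up weights, while later layers instead load the router, the stacked expert tensors, and the expert bias. A sketch of the per-layer FFN shapes this implies (illustrative dimensions, not the real LFM2-8B-A1B config):

```python
# Expected per-layer FFN tensor shapes, dense vs. MoE; the real values come
# from the GGUF header written by the converter.
n_embd, n_ff, n_ff_exp, n_expert, n_layer_dense_lead = 2048, 7168, 1792, 32, 2

def ffn_shapes(il: int) -> dict[str, tuple[int, ...]]:
    if il < n_layer_dense_lead:                       # dense leading block
        return {
            "ffn_gate": (n_embd, n_ff),
            "ffn_down": (n_ff, n_embd),
            "ffn_up":   (n_embd, n_ff),
        }
    return {                                          # MoE layers
        "ffn_gate_inp":  (n_embd, n_expert),
        "ffn_gate_exps": (n_embd, n_ff_exp, n_expert),
        "ffn_down_exps": (n_ff_exp, n_embd, n_expert),
        "ffn_up_exps":   (n_embd, n_ff_exp, n_expert),
        "exp_probs_b":   (n_expert,),
    }

print(sorted(ffn_shapes(0)), sorted(ffn_shapes(5)))
```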
@@ -6310,7 +6339,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}

if (arch == LLM_ARCH_SMALLTHINKER) {
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
}
@@ -18602,6 +18631,8 @@ struct llm_build_lfm2 : public llm_graph_context {
ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
const bool is_moe_layer = il >= static_cast<int>(hparams.n_layer_dense_lead);

auto * prev_cur = cur;
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.operator_norm", il);
@@ -18616,7 +18647,16 @@
}

cur = ggml_add(ctx0, prev_cur, cur);
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));

auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(ffn_norm_out, "model.layers.{}.ffn_norm", il);

ggml_tensor * ffn_out = is_moe_layer ?
build_moe_feed_forward(ffn_norm_out, il) :
build_dense_feed_forward(ffn_norm_out, il);
cb(ffn_norm_out, "model.layers.{}.ffn_out", il);

cur = ggml_add(ctx0, cur, ffn_out);
}

cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
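Each layer of the graph now runs `operator_norm` plus the token mixer (attention or shortconv), adds the residual, applies `ffn_norm`, and then dispatches to either the MoE or the dense feed-forward depending on whether the layer index is past the dense lead. A pseudocode-level sketch of that flow, with every sub-block stubbed out as an identity so the control flow stays runnable (names are descriptive, not actual llama.cpp symbols):

```python
# Identity stand-ins; each would be a real norm / mixer / FFN in the graph above.
def operator_norm(x): return x
def token_mixer(x):   return x   # attention or shortconv, depending on the layer
def ffn_norm(x):      return x
def dense_ffn(x):     return x
def moe_ffn(x):       return x

def lfm2_layer(x, il, n_layer_dense_lead):
    h = x + token_mixer(operator_norm(x))                        # first residual
    y = ffn_norm(h)
    ffn = moe_ffn(y) if il >= n_layer_dense_lead else dense_ffn(y)
    return h + ffn                                               # second residual

print(lfm2_layer(1.0, il=5, n_layer_dense_lead=2))               # 4.0 with identity stubs
```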
@@ -18631,23 +18671,32 @@
ggml_build_forward_expand(gf, cur);
}

ggml_tensor * build_feed_forward(ggml_tensor * cur,
int il) const {
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.ffn_norm", il);
ggml_tensor * build_moe_feed_forward(ggml_tensor * cur,
int il) const {
return build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
il);
}

ggml_tensor * build_dense_feed_forward(ggml_tensor * cur,
int il) const {
GGML_ASSERT(!model.layers[il].ffn_up_b);
GGML_ASSERT(!model.layers[il].ffn_gate_b);
GGML_ASSERT(!model.layers[il].ffn_down_b);
cur = build_ffn(cur,
return build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "model.layers.{}.feed_forward.w2", il);

return cur;
}

ggml_tensor * build_attn_block(ggml_tensor * cur,
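`build_moe_feed_forward` above calls `build_moe_ffn` with sigmoid gating, weight normalization enabled, and the `exp_probs_b` bias tensor. In the DeepSeek-V3-style scheme that combination implies, the bias only influences which experts are selected, while the unbiased sigmoid scores are what gets renormalized into mixing weights. A rough numeric sketch of that selection (my reading of the flags, not a transcription of the C++):

```python
# Sigmoid gating with an expert-selection bias and normalized weights.
import math

def route(logits, bias, n_used):
    scores = [1.0 / (1.0 + math.exp(-l)) for l in logits]           # sigmoid gating
    ranked = sorted(range(len(scores)),
                    key=lambda e: scores[e] + bias[e], reverse=True)
    chosen = ranked[:n_used]                                         # bias affects selection only
    total = sum(scores[e] for e in chosen)
    return [(e, scores[e] / total) for e in chosen]                  # norm_w=true -> weights sum to 1

print(route(logits=[0.2, -1.0, 1.5, 0.3], bias=[0.0, 2.0, 0.0, 0.0], n_used=2))
# expert 1 is picked because of its bias, but its mixing weight stays small
```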
@@ -19817,6 +19866,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique<llm_build_falcon_h1>(*this, params);
} break;
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
llm = std::make_unique<llm_build_lfm2>(*this, params);
} break;
@@ -20039,6 +20089,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_OPENAI_MOE:
case LLM_ARCH_HUNYUAN_DENSE:
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_SMALLTHINKER:
case LLM_ARCH_GLM4_MOE:
case LLM_ARCH_SEED_OSS:
1 change: 1 addition & 0 deletions src/llama-model.h
@@ -107,6 +107,7 @@ enum llm_type {
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air