diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 42d63b7c5444c..788d7a1d10bd0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -342,7 +342,7 @@ jobs: cd build export GGML_VK_VISIBLE_DEVICES=0 # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 3600 + ctest -L main --verbose --timeout 4200 ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 diff --git a/common/arg.cpp b/common/arg.cpp index 40af7e574830f..56827a65908be 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2734,6 +2734,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.public_path = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH")); + add_opt(common_arg( + {"--api-prefix"}, "PREFIX", + string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()), + [](common_params & params, const std::string & value) { + params.api_prefix = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX")); add_opt(common_arg( {"--no-webui"}, string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"), diff --git a/common/common.h b/common/common.h index 8922090e7b10d..a5abe32859fdd 100644 --- a/common/common.h +++ b/common/common.h @@ -370,6 +370,7 @@ struct common_params { std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT + std::string api_prefix = ""; // NOLINT std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dd80a4a05d596..2419126ec4ea2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -815,6 +815,24 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" + if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664": + # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct + res = "hunyuan" + if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf": + # ref: https://huggingface.co/skt/A.X-4.0 + res = "a.x-4.0" + if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": + # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base + res = "falcon-h1" + if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86": + # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base + res = "falcon-h1" + if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896": + # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base + res = "falcon-h1" + if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b": + # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base + res = "falcon-h1" if res is None: logger.warning("\n") @@ -4896,17 +4914,19 @@ def set_vocab(self): def set_gguf_parameters(self): d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 - head_dim = self.find_hparam(["head_dim"], 
optional=True) or 64 + head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64 n_group = self.find_hparam(["n_groups"], optional=True) or 1 rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 # Fail early for models which don't have a block expansion factor of 2 # TODO: does this really matter? - assert d_inner == 2 * d_model - assert d_inner % head_dim == 0 + # skip the assertion for FalconH1 Model + if self.model_arch != gguf.MODEL_ARCH.FALCON_H1: + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default self.gguf_writer.add_embedding_length(d_model) @@ -4943,7 +4963,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = data_torch.reshape((*data_torch.shape, 1)) elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model n_group = self.hparams.get("n_groups", 1) data_torch = data_torch.reshape((n_group, d_inner // n_group)) @@ -4954,6 +4974,123 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) +@ModelBase.register("JambaForCausalLM") +class JambaModel(TextModel): + model_arch = gguf.MODEL_ARCH.JAMBA + + def get_vocab_base_pre(self, tokenizer) -> str: + del tokenizer # unused + + return "gpt-2" + + def set_vocab(self): + if (self.dir_model / "tokenizer.model").is_file(): + # Using Jamba's tokenizer.json causes errors on model load + # (something about "byte not found in vocab"), + # but there's a working tokenizer.model + self._set_vocab_sentencepiece() + else: + # Some Jamba models only have a tokenizer.json, which works. 
+ self._set_vocab_gpt2() + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) + d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 + d_inner = self.hparams["mamba_expand"] * d_model + d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 + n_kv_head = self.hparams["num_key_value_heads"] + attn_offset = self.hparams["attn_layer_offset"] + attn_period = self.hparams["attn_layer_period"] + n_kv_vec = [0 for _ in range(attn_offset)] + [ + n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) + ] + + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(n_kv_vec) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_file_type(self.ftype) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # Mini-Jamba + name = name.replace(".moe.", ".feed_forward.") + if bid is not None: + moe_offset = self.hparams["expert_layer_offset"] + moe_period = self.hparams["expert_layer_period"] + + if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0): + name = name.replace(".experts.0.", ".") + + # process the experts separately + if ".feed_forward.experts." 
in name: + n_experts = self.hparams["num_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + + # merge the experts into a single 3d tensor + for wid in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + # using the same merged name as qwen2moe + merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield new_name, data_torch + return + + new_name = self.map_tensor_name(name) + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R @@ -6535,6 +6672,277 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + +@ModelBase.register("FalconH1ForCausalLM") +class FalconH1Model(Mamba2Model): + model_arch = gguf.MODEL_ARCH.FALCON_H1 + + def __init__(self, *args, **kwargs): + # Set the hparam prefixes for Falcon Mamba2 + self.hparam_prefixes = ["mamba"] + + # Initialize the base Mamba2Model + super().__init__(*args, **kwargs) + + # Use Llama conversion for attention + self._transformer_model_class = LlamaModel + + # n_group and d_inner are used during reshape_tensors for mamaba2 + self.n_group = self.find_hparam(["n_groups"]) + self.d_inner = self.find_hparam(["mamba_d_ssm"]) + self.d_head = self.find_hparam(["d_head"]) + + # Initialize any Falcon Mamba2 specific attributes + self.has_attention = True # Falcon Mamba2 has attention components + + # Load Falcon-H1 multipliers from hyperparameters + self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True) + self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True) + self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True) + self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True) + self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True) + self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True) + self.intermediate_size = self.find_hparam(["intermediate_size"]) + self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True) + + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: + prefixed = [] + for pfx in self.hparam_prefixes: + prefixed.extend( + "_".join([pfx, k]) + for k in keys + ) + keys = list(keys) + prefixed + return super().find_hparam(keys, *args, **kwargs) + + def set_vocab(self): + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + 
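
Stepping back to the Jamba converter: `JambaModel.set_gguf_parameters` encodes which blocks are attention layers (as opposed to Mamba layers) by writing a per-layer `head_count_kv` vector, where attention blocks get `num_key_value_heads` and every other block gets 0. A minimal sketch of that computation follows; the offset/period/head values are illustrative placeholders, not taken from any specific checkpoint:

```python
# Rebuild the per-layer KV-head vector the way JambaModel.set_gguf_parameters does.
# Hypothetical hyperparameters for illustration only:
attn_offset = 4    # attn_layer_offset
attn_period = 8    # attn_layer_period
n_kv_head   = 8    # num_key_value_heads
block_count = 32   # total number of blocks

n_kv_vec = [0 for _ in range(attn_offset)] + [
    n_kv_head if (i - attn_offset) % attn_period == 0 else 0
    for i in range(attn_offset, block_count)
]

# Blocks 4, 12, 20 and 28 come out as attention layers; all other blocks are recurrent.
print(n_kv_vec)   # [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, ...]
```

A zero entry marks a block with no KV heads, which is how the hybrid attention/SSM layout ends up expressed in the GGUF metadata.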
tensors = list(super().modify_tensors(data_torch, name, bid)) + tensor = tensors[0][1] + + if "down_proj" in name: + tensor = tensor * self.mlp_multipliers[1] + elif "gate_proj" in name: + tensor = tensor * self.mlp_multipliers[0] + elif "k_proj" in name: + tensor = tensor * self.key_multiplier * self.attention_in_multiplier + elif "q_proj" in name: + tensor = tensor * self.attention_in_multiplier + elif "v_proj" in name: + tensor = tensor * self.attention_in_multiplier + elif "o_proj" in name: + tensor = tensor * self.attention_out_multiplier + elif "out_proj" in name: + tensor = tensor * self.ssm_out_multiplier + elif "in_proj" in name: + tensor = tensor * self.ssm_in_multiplier + zxbcdt_multipliers = self.hparams["ssm_multipliers"] + intermediate_size = self.hparams["mamba_d_ssm"] + groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"] + tensor[:intermediate_size, :] *= zxbcdt_multipliers[0] + tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1] + tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2] + tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3] + tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4] + elif "lm_head" in name: + tensor = tensor * self.hparams["lm_head_multiplier"] + elif "embed_tokens" in name: + tensor = tensor * self.hparams["embedding_multiplier"] + elif "mamba.norm" in name: + tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group) + + tensors = [(tensors[0][0], tensor)] + return tensors + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + ## General Params ## + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + # Override some Mamba2 defaults + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + + ## Attention params ## + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) # Override value 0 from Mamba2 + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_key_length(self.hparams["head_dim"]) + self.gguf_writer.add_value_length(self.hparams["head_dim"]) + + ## Validation ## + assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" + assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}" + + # Add any other Falcon Mamba2 specific configuration + self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) + + +@ModelBase.register("HunYuanMoEV1ForCausalLM") +class HunYuanMoEModel(TextModel): + model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # For handling tied embeddings + self._tok_embd = None + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + # 1. Get the pre-tokenizer identifier hash + tokpre = self.get_vocab_base_pre(tokenizer) + + # 2. 
Reverse-engineer the merges list from mergeable_ranks + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: # todo this is an assert in Qwen, why? + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # 3. Generate the tokens and toktypes lists + vocab_size = self.hparams["vocab_size"] + assert tokenizer.vocab_size == vocab_size + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + # 4. Write all vocab-related fields to the GGUF writer + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + # 5. Add special tokens and chat templates + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # FIX for BOS token: Overwrite incorrect id read from config.json + self.gguf_writer.add_bos_token_id(127959) # <|bos|> + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) + + moe_intermediate_size = hparams["moe_intermediate_size"] + assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) + + moe_topk = hparams["moe_topk"] + assert all(topk == moe_topk[0] for topk in moe_topk) + self.gguf_writer.add_expert_used_count(moe_topk[0]) + + moe_shared_expert = hparams["num_shared_expert"] + assert all(n == moe_shared_expert[0] for n in moe_shared_expert) + self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) + + # Rope + rope_scaling = hparams.get("rope_scaling", {}) + if rope_scaling.get("type") == "dynamic": + # HunYuan uses NTK Aware Alpha based scaling. 
Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf) + alpha = rope_scaling.get("alpha", 1000) + base = hparams.get("rope_theta", 10000.0) + dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128 + scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251 + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length + self.gguf_writer.add_context_length(256 * 1024) # 256k context length + + # if any of our assumptions about the values are wrong, something has changed and this may need to be updated + assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ + "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "model.embed_tokens.weight": + self._tok_embd = data_torch.clone() + + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + + if name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + tensors: list[tuple[str, Tensor]] = [] + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("SmolLM3ForCausalLM") +class SmolLM3Model(LlamaModel): + model_arch = gguf.MODEL_ARCH.SMOLLM3 + + def set_vocab(self): + super().set_vocab() + # remove unsupported array slicing in chat template + # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1 + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + if tokenizer.chat_template is not None: + chat_template = tokenizer.chat_template.replace("[:]", "") + self.gguf_writer.add_chat_template(chat_template) + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 2f733f0973686..b8cb6027d6de5 100755 
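
The dynamic-NTK handling above bakes the alpha scaling directly into the RoPE frequency base instead of emitting a runtime scaling type. A quick check of the arithmetic from the comment, using the exact values the converter asserts:

```python
# Scaled RoPE base as computed in HunYuanMoEModel.set_gguf_parameters
base  = 10000.0   # rope_theta
alpha = 1000      # rope_scaling.alpha
dim   = 128       # hidden_size // num_attention_heads

scaled_base = base * (alpha ** (dim / (dim - 2)))
print(scaled_base)   # ~11158839.93, written to the GGUF as rope_freq_base
```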
--- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, + {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -137,6 +138,12 @@ class TOKENIZER_TYPE(IntEnum): {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, + {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, + # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes + {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, + {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"}, + {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"}, + {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"}, ] diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 7f71e0247ddc7..51e0b0b20f58d 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv ### 2. Define the model architecture in `llama.cpp` -The model params and tensors layout must be defined in `llama.cpp`: -1. Define a new `llm_arch` -2. Define the tensors layout in `LLM_TENSOR_NAMES` -3. Add any non-standard metadata in `llm_load_hparams` -4. Create the tensors for inference in `llm_load_tensors` -5. If the model has a RoPE operation, add the rope type in `llama_rope_type` +The model params and tensors layout must be defined in `llama.cpp` source files: +1. Define a new `llm_arch` enum value in `src/llama-arch.h`. +2. In `src/llama-arch.cpp`: + - Add the architecture name to the `LLM_ARCH_NAMES` map. + - Add the tensor mappings to the `LLM_TENSOR_NAMES` map. +3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`. +4. 
If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`. NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions. ### 3. Build the GGML graph implementation -This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. - -Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`. +This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`. +Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor. +Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`. +Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct. Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR. diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 949eac9a5a0b5..8a8775be36583 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -495,7 +495,7 @@ extern "C" { GGML_OP_POOL_1D, GGML_OP_POOL_2D, GGML_OP_POOL_2D_BACK, - GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_UPSCALE, GGML_OP_PAD, GGML_OP_PAD_REFLECT_1D, GGML_OP_ROLL, @@ -1297,6 +1297,19 @@ extern "C" { struct ggml_tensor * a, float s); + // x = s * a + b + GGML_API struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + + GGML_API struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + // b -> view(a,offset,nb1,nb2,3), return modified a GGML_API struct ggml_tensor * ggml_set( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index eae575cc040cd..ccb17eb072eb2 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_CLAMP: @@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_SCALE: + float bias; + memcpy(&bias, (float*)op->op_params + 1, sizeof(float)); + return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support broadcast // ref: https://github.com/ggml-org/llama.cpp/pull/14435 diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index aaeee614ab993..fd77e9a6abad5 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32( GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - // scale factor - float v; - memcpy(&v, dst->op_params, sizeof(float)); + float s; // scale factor + float b; // bias + + memcpy(&s, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&b, (float *) dst->op_params + 1, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32( const size_t nb1 = dst->nb[1]; - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // 
src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + if (b == 0.0f) { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); + } + } else { + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_mad1_f32(nc, + (float *) ((char *) dst->data + i1*nb1), + (float *) ((char *) src0->data + i1*nb1), + s, b); } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); } } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 1f5857a23e35c..d18783a00a1a5 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int #endif } +inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) { +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmsa(x, 1, &s, &b, y, 1, n); +#elif defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + // scalar ; TODO: Write SVE code + for (int i = 0; i < n; ++i) { + y[i] = x[i]*s + b; + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = x[i]*s + b; + } + #endif +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = x[i]*s + b; + } +#endif +} + //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 954f74d408f9f..1a2708ec9dff5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -176,17 +176,20 @@ static const char * cu_get_error_str(CUresult err) { #endif #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) -#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ - do { \ - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \ - const int id = ggml_cuda_get_device(); \ - if (!shared_memory_limit_raised[id]) { \ - CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ - shared_memory_limit_raised[id] = true; \ - } \ - } while (0) +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) #else -#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0) +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + GGML_UNUSED(nbytes); \ + } while (0) #endif // !(defined(GGML_USE_HIP) && 
defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 124d5d3e89122..908c76dbdd270 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index c22baf41764d1..b2f1724c95588 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); - GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 963e4d03dd77b..f77b2629a19b0 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + 
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_BF16: get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); @@ -210,6 +214,10 @@ void get_rows_cuda( ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_F16: ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index af5ad1ed52cdc..da1e8f8f4e443 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3200,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_BF16: + case GGML_TYPE_I32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3373,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_PAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 18f691b2d3103..d058504cd6cc0 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -50,21 +50,19 @@ static __global__ void rope_norm( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0; const int ix = channel_x*s2 + row_x*s1 + i0; + if (i0 >= n_dims) { + dst[idst + 0] = x[ix + 0]; + dst[idst + 1] = x[ix + 1]; + + return; + } + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -94,21 +92,19 @@ static __global__ void rope_neox( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? 
freq_factors[i0/2] : 1.0f; @@ -138,21 +134,19 @@ static __global__ void rope_multi( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 1405e066e86a2..2ee9e588992f4 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,18 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } - dst[i] = scale * x[i]; + dst[i] = scale * x[i] + bias; } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, k); + scale_f32<<>>(x, dst, scale, bias, k); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float bias; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); - scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); + scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); } diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 524e979574266..ef48aa5f97bcd 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst, dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) ); } +static __global__ void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst 
+ pixel_offset) / sf1 - pixel_offset; + int y0_src = (int)floorf(y_src_f); + int y1_src = y0_src + 1; + + y0_src = max(0, min(y0_src, ne01_src - 1)); + y1_src = max(0, min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int)floorf(x_src_f); + int x1_src = x0_src + 1; + + x0_src = max(0, min(x0_src, ne00_src - 1)); + x1_src = max(0, min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + static void upscale_f32_cuda(const float * x, float * dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, cudaStream_t stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; upscale_f32<<>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); } +static void upscale_f32_bilinear_cuda(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset, cudaStream_t stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); +} + void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const float sf0 = (float)dst->ne[0]/src0->ne[0]; - const float sf1 = (float)dst->ne[1]/src0->ne[1]; - const float sf2 = (float)dst->ne[2]/src0->ne[2]; + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = 
(float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; const float sf3 = (float)dst->ne[3]/src0->ne[3]; - upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); + sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + pixel_offset = 0.0f; + } + upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 40fc315e82fd1..83a0739809a6e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node( GGML_ASSERT(ggml_is_contiguous(src0)); float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + float bias; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float)); int64_t n = ggml_nelements(dst); @@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + [encoder setBytes:&bias length:sizeof(bias) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 22240bab47249..239ec31fbcb58 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1014,16 +1014,18 @@ kernel void kernel_scale( device const float * src0, device float * dst, constant float & scale, + constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; + dst[tpig] = src0[tpig] * scale + bias; } kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, constant float & scale, + constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; + dst[tpig] = src0[tpig] * scale + bias; } kernel void kernel_clamp( diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a9fc039038705..43d8e5c72c937 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); + float bias; + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float)); ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons 
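
The bilinear path added to `ggml_cuda_op_upscale` above samples with a half-pixel offset by default and switches to `(dst - 1) / (src - 1)` scale factors with zero offset when the align-corners flag is set. The following scalar Python reference mirrors the kernel's coordinate mapping and blend for the 2D case; it is purely illustrative and not part of the patch:

```python
import math

def bilinear_sample(src, dst_w, dst_h, align_corners=False):
    """Scalar reference of upscale_f32_bilinear's coordinate mapping (2D case)."""
    src_h, src_w = len(src), len(src[0])
    if align_corners:
        sf0, sf1, off = (dst_w - 1) / (src_w - 1), (dst_h - 1) / (src_h - 1), 0.0
    else:
        sf0, sf1, off = dst_w / src_w, dst_h / src_h, 0.5  # half-pixel offset

    out = [[0.0] * dst_w for _ in range(dst_h)]
    for y in range(dst_h):
        yf = (y + off) / sf1 - off
        y0, y1 = math.floor(yf), math.floor(yf) + 1
        y0, y1 = max(0, min(y0, src_h - 1)), max(0, min(y1, src_h - 1))
        dy = min(max(yf - y0, 0.0), 1.0)
        for x in range(dst_w):
            xf = (x + off) / sf0 - off
            x0, x1 = math.floor(xf), math.floor(xf) + 1
            x0, x1 = max(0, min(x0, src_w - 1)), max(0, min(x1, src_w - 1))
            dx = min(max(xf - x0, 0.0), 1.0)
            out[y][x] = (src[y0][x0] * (1 - dx) * (1 - dy) + src[y0][x1] * dx * (1 - dy) +
                         src[y1][x0] * (1 - dx) * dy     + src[y1][x1] * dx * dy)
    return out

# e.g. bilinear_sample([[0.0, 1.0], [2.0, 3.0]], 4, 4) upsamples a 2x2 grid to 4x4
```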
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias)); int n = ggml_nelements(dst)/4; diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl index 8cfd518fa5a3e..aeca8a456e4fe 100644 --- a/ggml/src/ggml-opencl/kernels/scale.cl +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -8,9 +8,10 @@ kernel void kernel_scale( ulong offset0, global float4 * dst, ulong offsetd, - float scale + float scale, + float bias ) { src0 = (global float4*)((global char*)src0 + offset0); dst = (global float4*)((global char*)dst + offsetd); - dst[get_global_id(0)] = src0[get_global_id(0)] * scale; + dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 21c81e99a19aa..cd15bbdb29fa2 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } -static void scale_f32(const float * x, float * dst, const float scale, const int k, +static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int return; } - dst[i] = scale * x[i]; + dst[i] = scale * x[i] + bias; } @@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( -static void scale_f32_sycl(const float *x, float *dst, const float scale, +static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; stream->parallel_for( @@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale, sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - scale_f32(x, dst, scale, k, item_ct1); + scale_f32(x, dst, scale, bias, k, item_ct1); }); } @@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds float * dst_dd = static_cast(dst->data); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float bias; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); - scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream); + scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream); /* DPCT1010:87: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
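
The CPU, CUDA, Metal, OpenCL and SYCL changes above all implement the same extended contract for `GGML_OP_SCALE`: the op now carries two floats in `op_params` and computes `y = s*x + b`, with `b == 0` reducing to the old scale-only behaviour (which is why the CANN backend currently rejects a non-zero bias). A tiny NumPy reference of that contract, purely illustrative:

```python
import numpy as np

def scale_bias_ref(x: np.ndarray, s: float, b: float = 0.0) -> np.ndarray:
    # ggml_scale_bias / ggml_vec_mad1_f32 semantics: y[i] = x[i] * s + b
    return x * s + b

x = np.arange(4, dtype=np.float32)
assert np.allclose(scale_bias_ref(x, 2.0, 1.0), np.array([1.0, 3.0, 5.0, 7.0], dtype=np.float32))
```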
diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index e44c6b6ef8f42..1b60226dcd531 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0; const int i2 = channel0 * s2 + row0 * s1 + i0; + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0 / 2; const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + i + i0 / 2) = *reinterpret_cast *>(x + i2 + i0 / 2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const } const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = (row_dst * ne0) + (i0 / 2); const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + idst + i0 / 2) = *reinterpret_cast *>(x + i0 / 2 + ix); + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e8df00d4183ac..c36e1a6d3bfc2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -501,6 +501,8 @@ struct vk_device_struct { ggml_backend_buffer_type buffer_type; + bool disable_fusion; + #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif @@ -1091,8 +1093,8 @@ static size_t vk_skip_checks; static size_t vk_output_tensor; static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); -static void ggml_vk_check_results_0(ggml_tensor * tensor); -static void ggml_vk_check_results_1(ggml_tensor * tensor); +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); #endif typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); @@ -2704,7 +2706,7 @@ static void 
ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -3507,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->idx = idx; + device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + return device; } @@ -6248,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + if (workgroups_x == 1 && shader_core_count > 0) { // Try to run two workgroups per SM. split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. 
- split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align); split_k = CEIL_DIV(KV, split_kv); workgroups_x = split_k; } @@ -6388,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -7504,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, + op_params[0], op_params[1], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -7654,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - float * op_params = (float *)dst->op_params; +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -8885,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. @@ -9146,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr // fused rms_norm + mul ggml_tensor *mul = cgraph->nodes[node_idx + 1]; ggml_tensor *other_src = mul->src[0] == node ? 
mul->src[1] : mul->src[0]; - ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun); + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun); } else { - ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun); + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun); } break; case GGML_OP_RMS_NORM_BACK: @@ -9308,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); + bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; @@ -9323,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + GGML_UNUSED(cgraph); ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -9433,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * // Only run if ctx hasn't been submitted yet if (!subctx->seqs.empty()) { #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_0(tensor); + ggml_vk_check_results_0(ctx, cgraph, tensor_idx); use_fence = true; #endif @@ -9453,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_1(tensor); + ggml_vk_check_results_1(ctx, cgraph, tensor_idx); #endif } @@ -9900,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) { return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; } +static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + mul->src[0]->ne[1] != rms_norm->ne[1]) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + return true; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = 
(ggml_backend_vk_context *)backend->context; @@ -9913,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); @@ -9983,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } - if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; } @@ -10760,11 +10795,21 @@ void * comp_result; size_t comp_size; size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; -static void ggml_vk_check_results_0(ggml_tensor * tensor) { +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; if (tensor->op == GGML_OP_TRANSPOSE) { return; } + bool fused_rms_norm_mul = false; + int rms_norm_idx = -1; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + check_counter++; if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; @@ -10792,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { for (int i = 0; i < 6; i++) { ggml_tensor * srci = tensor->src[i]; + if (fused_rms_norm_mul) { + rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 
0 : 1; + ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; + switch (i) { + case 0: srci = rms_norm->src[0]; break; + case 1: srci = tensor->src[1 - rms_norm_idx]; break; + default: continue; + } + } if (srci == nullptr) { continue; } @@ -10849,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_SUB) { tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + if (fused_rms_norm_mul) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); + tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]); + } else { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } } else if (tensor->op == GGML_OP_DIV) { tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { @@ -11040,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { GGML_ABORT("fatal error"); } - ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); - ggml_build_forward_expand(cgraph, tensor_clone); + ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph_cpu, tensor_clone); - ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ggml_vk_print_tensor(tensor_clone, "tensor_clone"); @@ -11066,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); } -static void ggml_vk_check_results_1(ggml_tensor * tensor) { +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; if (tensor->op == GGML_OP_TRANSPOSE) { return; } + bool fused_rms_norm_mul = false; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 599cef072e931..0a17a9df23f9f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 32 +layout(constant_id = 0) const uint BLOCK_SIZE = 32; -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 1) writeonly buffer D {float data_d[];}; @@ -16,6 +16,8 @@ layout (push_constant) uniform parameter { uint k_num; } p; +shared float tmpsh[BLOCK_SIZE]; + void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; @@ -32,23 +34,51 @@ void main() { // Compute the max m value for the row float m_max = -1.0/0.0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; m_max = 
max(m_max, m); } + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + // Compute L based on m_max float L = 0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float l = data_a[l_offset + k * lm_stride]; - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; L += exp(m - m_max) * l; } + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + L = 1.0 / L; + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; // Scale and sum the O contributions based on m_max and store the result to memory - for (uint d = tid; d < D; d += BLOCK_SIZE) { + if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 26163b167c7ed..888ce79f6ec11 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -500,10 +500,9 @@ void main() { const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -512,22 +511,16 @@ void main() { const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -538,21 +531,17 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -562,63 +551,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? 
-grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? 
-grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -631,33 +638,36 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? 
-v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 4f5b1a0ecaf5d..5808710ccf998 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -14,21 +14,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; const int sec_w = p.sections[1] + p.sections[0]; const uint sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index db775c456cae8..366a7b1c47cdd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 4ad35e549d77f..9643bca96ac92 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + if (i0 >= p.n_dims) { + data_d[idst + 0] = data_a[ix + 0]; + data_d[idst + 1] = data_a[ix + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index 4663428dee0a2..f10b0a02b5076 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -18,7 +18,7 @@ void main() { continue; } - data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); idx += num_threads; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 2698522ed7101..30f78fabb3c85 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) load_vec_quant = "4"; if (tname == "bf16") { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 75fc1e7072970..5ae1c527df639 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3069,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, float s, + float b, bool inplace) { GGML_ASSERT(ggml_is_padded_1d(a)); struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_set_op_params(result, &s, sizeof(s)); + float params[2] = { s, b }; + ggml_set_op_params(result, &params, sizeof(params)); result->op = GGML_OP_SCALE; result->src[0] = a; @@ -3086,14 +3088,30 @@ struct ggml_tensor * ggml_scale( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, false); + return ggml_scale_impl(ctx, a, s, 0.0, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, true); + return ggml_scale_impl(ctx, a, s, 0.0, true); +} + +struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, false); +} + +struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, true); } // ggml_set @@ -5777,7 +5795,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MEAN: { if (src0_needs_grads) { - ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false)); } } break; case GGML_OP_REPEAT: { @@ -5854,7 +5872,7 @@ static void ggml_compute_backward( if (src0_needs_grads) { float s; memcpy(&s, tensor->op_params, sizeof(float)); - ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false)); } } break; case GGML_OP_SET: { diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 5ffd12b8b2795..53504399c57f4 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -631,7 +631,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par gguf_free(ctx); return nullptr; } - ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment); + size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment); + if (SIZE_MAX - ctx->size < padded_size) { + GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n", + __func__, ti.t.name, ctx->size, padded_size); + gguf_free(ctx); + return nullptr; + } + ctx->size += padded_size; } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c12609c6d9f99..fbe3f53273a35 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum): LLAMA4 = auto() DECI = auto() FALCON = auto() + FALCON_H1 = auto() BAICHUAN = auto() GROK = auto() GPT2 = auto() @@ -329,6 +330,7 @@ class MODEL_ARCH(IntEnum): ARWKV7 = auto() MAMBA = auto() MAMBA2 = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -357,6 +359,8 @@ class MODEL_ARCH(IntEnum): DOTS1 = auto() ARCEE = auto() ERNIE4_5 = auto() + HUNYUAN_MOE = auto() + SMOLLM3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -429,7 +433,10 @@ class MODEL_TENSOR(IntEnum): SSM_CONV1D = auto() SSM_X = auto() SSM_DT = auto() + SSM_DT_NORM = auto() SSM_A = auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() @@ -632,6 +639,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA2: "mamba2", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -660,6 +668,9 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee", MODEL_ARCH.ERNIE4_5: "ernie4_5", + MODEL_ARCH.FALCON_H1: "falcon-h1", + MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", + MODEL_ARCH.SMOLLM3: "smollm3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -732,7 +743,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", @@ -1732,6 +1746,34 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2211,6 +2253,77 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.FALCON_H1: [ + # Token embedding + MODEL_TENSOR.TOKEN_EMBD, + + # Input layernorm + MODEL_TENSOR.ATTN_NORM, + + # Attention components + MODEL_TENSOR.ATTN_Q, # Query projection + MODEL_TENSOR.ATTN_K, # Key projection + MODEL_TENSOR.ATTN_V, # Value projection + MODEL_TENSOR.ATTN_OUT, # Output projection + + # SSM components (Mamba2 specific) + MODEL_TENSOR.SSM_IN, # Input projection for SSM + MODEL_TENSOR.SSM_CONV1D, # Convolution layer + MODEL_TENSOR.SSM_DT, # Delta time projection + MODEL_TENSOR.SSM_A, # A parameter (log form) + MODEL_TENSOR.SSM_D, # D parameter + MODEL_TENSOR.SSM_NORM, # Normalization in SSM + MODEL_TENSOR.SSM_OUT, # Output projection + + # Pre-feedforward layernorm + MODEL_TENSOR.FFN_PRE_NORM, + + # Feed-forward network components + MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU) + MODEL_TENSOR.FFN_DOWN, # Down projection + MODEL_TENSOR.FFN_UP, # Up projection + + # Post-feedforward layernorm + MODEL_TENSOR.OUTPUT_NORM, # Final layer norm + MODEL_TENSOR.OUTPUT, # Output projection (lm_head) + ], + MODEL_ARCH.HUNYUAN_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.SMOLLM3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + 
MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 51634ef6bdd2e..215eb297ebcc1 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -279,6 +279,8 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.pre_ff_layernorm", # jamba + "model.layers.{bid}.pre_moe_layernorm", # mini-jamba "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert ), @@ -286,12 +288,14 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_PRE_NORM: ( "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 + "model.layers.{bid}.pre_ff_layernorm.weight", ), # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + "model.layers.{bid}.feed_forward.up_proj", ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -301,8 +305,9 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe - "model.layers.{bid}.feed_forward.router", # llama4 + "model.layers.{bid}.feed_forward.router", # llama4 jamba "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe + "model.layers.{bid}.mlp.gate.wg", # hunyuan ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -344,7 +349,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone - "model.layers.{bid}.feed_forward.up_proj", # llama4 + "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba "transformer_encoder.{bid}.ffn.w12", # neobert ), @@ -362,6 +367,8 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 + "model.layers.{bid}.feed_forward.down_proj", + "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan ), # AWQ-activation gate @@ -382,7 +389,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone - "model.layers.{bid}.feed_forward.gate_proj", # llama4 + "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -398,6 +405,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 + "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan ), # Feed-forward down @@ -427,7 +435,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone - "model.layers.{bid}.feed_forward.down_proj", # llama4 + "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba "transformer_encoder.{bid}.ffn.w3", # neobert ), @@ -447,11 +455,13 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.down_proj", # 
llama4 "model.layers.{bid}.shared_mlp.output_linear", # granitemoe + "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan ), MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon + "model.layers.{bid}.self_attn.query_layernorm", # hunyuan "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 @@ -461,6 +471,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon + "model.layers.{bid}.self_attn.key_layernorm", # hunyuan "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 @@ -545,42 +556,64 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", - "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", - "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", - "backbone.layers.{bid}.mixer.x_proj", + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", - "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.{bid}.mamba.dt_layernorm", # jamba ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", - "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", - "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba falcon-h1 ), MODEL_TENSOR.SSM_NORM: ( + "model.layers.{bid}.mamba.norm", # falcon-h1 "backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", - "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 ), MODEL_TENSOR.TIME_MIX_W0: ( diff --git a/include/llama.h b/include/llama.h index 3eda9bc68608c..dc86aea41dcbd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -117,6 +117,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, }; enum llama_rope_type { diff --git 
a/src/llama-arch.cpp b/src/llama-arch.cpp index ab24054305857..1955c03eb3d1c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -46,6 +46,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, + { LLM_ARCH_JAMBA, "jamba" }, + { LLM_ARCH_FALCON_H1, "falcon-h1" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -78,6 +80,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, + { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1022,6 +1026,61 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_JAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_FALCON_H1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1694,12 +1753,52 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_HUNYUAN_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { 
LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, }, }, + { + LLM_ARCH_SMOLLM3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, }; static const std::map LLM_TENSOR_INFOS = { @@ -1778,6 +1877,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1925,9 +2027,11 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { } bool llm_arch_is_hybrid(const llm_arch & arch) { - // TODO: There are currently no hybrid models! 
Once there are, this will be - // the place to identify them + // List all mamba-attention hybrid models here switch (arch) { + case LLM_ARCH_JAMBA: + case LLM_ARCH_FALCON_H1: + return true; default: return false; } diff --git a/src/llama-arch.h b/src/llama-arch.h index b769831dff5ec..3381b8dc4a4b7 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -50,6 +50,8 @@ enum llm_arch { LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, + LLM_ARCH_JAMBA, + LLM_ARCH_FALCON_H1, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -82,6 +84,8 @@ enum llm_arch { LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, + LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_SMOLLM3, LLM_ARCH_UNKNOWN, }; @@ -293,7 +297,10 @@ enum llm_tensor { LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 5d317f4ee62eb..cbc19d3c40c30 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -64,6 +64,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, + { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_LLAMA4; } else if (tmpl_contains("<|endofuserprompt|>")) { return LLM_CHAT_TEMPLATE_DOTS1; + } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -665,6 +668,18 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|response|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) { + // tencent/Hunyuan-A13B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|startoftext|>" << message->content << "<|extra_4|>"; + } else if (role == "assistant") { + ss << "<|startoftext|>" << message->content << "<|eos|>"; + } else { + ss << "<|startoftext|>" << message->content << "<|extra_0|>"; + } + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index 38800010ae48b..b621fda281669 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -44,6 +44,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, + LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7f0e8c67f1325..a248a7ec22350 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -336,29 +336,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch); - - mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - - const int64_t n_rs = mctx->get_recr()->get_n_rs(); - - if (s_copy) { - GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); - int32_t * data = (int32_t *) s_copy->data; - - // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n - for (uint32_t i = 0; i < n_rs; ++i) { - data[i] = mctx->get_recr()->s_copy(i); - } - } -} - -void 
llm_graph_input_one::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - GGML_ASSERT(one && ggml_nelements(one) == 1); - float f_one = 1.0f; - ggml_backend_tensor_set(one, &f_one, 0, sizeof(float)); + inp_attn->set_input(ubatch); + inp_rs->set_input(ubatch); } // @@ -992,35 +971,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t return pos_bias; } -llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { - const auto * mctx_cur = static_cast(mctx); - - auto inp = std::make_unique(hparams, cparams, mctx_cur); - - { - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - - const auto n_kv = inp->mctx->get_attn()->get_n_kv(); - - inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch); - inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch); - - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); - ggml_set_input(inp->self_kq_mask); - - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - } - - { - const auto n_rs = mctx_cur->get_recr()->get_n_rs(); - - inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); - ggml_set_input(inp->s_copy); - } - - return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); -} - ggml_tensor * llm_graph_context::build_attn_mha( ggml_cgraph * gf, ggml_tensor * q, @@ -1194,8 +1144,12 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const auto * mctx_cur = static_cast(mctx); +static std::unique_ptr build_attn_inp_kv_unified_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_context * mctx_cur) { auto inp = std::make_unique(hparams, cparams, mctx_cur); @@ -1203,6 +1157,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); @@ -1213,6 +1168,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + return inp; +} + +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); } @@ -1234,7 +1197,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * mctx_cur = static_cast(mctx); + const auto * mctx_cur = inp->mctx; // store to KV cache { @@ -1293,7 +1256,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, v_cur); } - const auto * mctx_iswa = static_cast(mctx); + const auto * mctx_iswa = inp->mctx; const bool is_swa = hparams.is_swa(il); @@ -1391,59 +1354,9 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - ggml_tensor * v_mla, - float kq_scale, - int il) const { - // these nodes are added to the graph together so that they are not reordered - // by doing so, the number of splits in the graph is reduced - ggml_build_forward_expand(gf, q_cur); - ggml_build_forward_expand(gf, k_cur); - ggml_build_forward_expand(gf, v_cur); - - const auto * mctx_cur = static_cast(mctx)->get_attn(); - - // store to KV cache - { - const auto & k_idxs = inp->get_k_idxs(); - const auto & v_idxs = inp->get_v_idxs(); - - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); - } - - const auto & kq_mask = inp->get_kq_mask(); - - ggml_tensor * q = q_cur; - ggml_tensor * k = mctx_cur->get_k(ctx0, il); - ggml_tensor * v = mctx_cur->get_v(ctx0, il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); - cb(cur, "kqv_out", il); - - if (wo) { - cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators - ggml_mul_mat_set_prec(cur, GGML_PREC_F32); - } - } - - if (wo_b) { - cur = ggml_add(ctx0, cur, wo_b); - } - - return cur; -} - +// TODO: maybe separate the inner implementation into a separate function +// like with the non-sliding window equivalent +// once sliding-window hybrid caches are a thing. 
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { const auto * mctx_cur = static_cast(mctx); @@ -1513,8 +1426,9 @@ ggml_tensor * llm_graph_context::build_rs( return output_states; } -llm_graph_input_rs * llm_graph_context::build_rs_inp() const { - const auto * mctx_cur = static_cast(mctx); +static std::unique_ptr build_rs_inp_impl( + ggml_context * ctx0, + const llama_memory_recurrent_context * mctx_cur) { auto inp = std::make_unique(mctx_cur); @@ -1523,29 +1437,25 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const { inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); - return (llm_graph_input_rs *) res->add_input(std::move(inp)); + return inp; } -ggml_tensor * llm_graph_context::build_rs( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * s, - int32_t state_size, - int32_t n_seqs, - const llm_graph_get_rows_fn & get_state_rows) const { - const auto * kv_state = static_cast(mctx); +llm_graph_input_rs * llm_graph_context::build_rs_inp() const { + const auto * mctx_cur = static_cast(mctx); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); + auto inp = build_rs_inp_impl(ctx0, mctx_cur); + + return (llm_graph_input_rs *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_rs( - llm_graph_input_mem_hybrid * inp, + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, const llm_graph_get_rows_fn & get_state_rows) const { - const auto * kv_state = static_cast(mctx)->get_recr(); + const auto * kv_state = inp->mctx; return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } @@ -1592,6 +1502,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ); } +llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_pooling( ggml_cgraph * gf, ggml_tensor * cls, diff --git a/src/llama-graph.h b/src/llama-graph.h index 7bdf656768a0c..fbf8e2889564d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -322,47 +322,25 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_memory_hybrid_context * mctx) : - hparams(hparams), - cparams(cparams), - mctx(mctx) { - } + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * s_copy; // I32 [kv_size] + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; - ggml_tensor * get_k_idxs() const { return self_k_idxs; } - ggml_tensor * get_v_idxs() const { return self_v_idxs; } - - ggml_tensor * get_kq_mask() const { return 
self_kq_mask_cnv; } - - ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] - - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] - - const llama_hparams & hparams; - const llama_cparams & cparams; + llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } const llama_memory_hybrid_context * mctx; }; -// TODO: remove this when ggml_scale_add is implemented -class llm_graph_input_one : public llm_graph_input_i { -public: - llm_graph_input_one() {} - virtual ~llm_graph_input_one() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * one = nullptr; // F32 -}; - // // llm_graph_result // @@ -579,8 +557,6 @@ struct llm_graph_context { ggml_tensor * build_inp_pos_bucket_dec() const; ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; - llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; - // // attention // @@ -656,18 +632,6 @@ struct llm_graph_context { float kq_scale, int il) const; - ggml_tensor * build_attn( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] - ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il) const; // // recurrent // @@ -700,14 +664,6 @@ struct llm_graph_context { int32_t n_seqs, const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; - ggml_tensor * build_rs( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * s, - int32_t state_size, - int32_t n_seqs, - const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; - ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, ggml_cgraph * gf, @@ -718,6 +674,11 @@ struct llm_graph_context { ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const; + // + // hybrid + // + + llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; // // pooling diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 4b90dac7a327c..2c1ae67098ca4 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent( uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; - LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n", - __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer); - head = 0; size = mem_size; used = 0; @@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent( ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); + throw std::runtime_error("failed to create ggml context for rs cache"); } ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size); @@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent( ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); + throw std::runtime_error("failed to allocate buffer for rs 
cache"); } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } @@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -377,14 +374,18 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & ubatch = balloc.split_equal(n_ubatch, false); } - if (balloc.get_n_used() < balloc.get_n_tokens()) { - // failed to find a suitable split + if (ubatch.n_tokens == 0) { break; } ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + if (!prepare(ubatches)) { break; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0573c5bcea0a4..ca094e47b6cb5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -102,6 +102,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; + case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_E2B: return "E2B"; @@ -1117,6 +1118,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogenous, so naming this is hard. 
+ case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1549,6 +1570,58 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_FALCON_H1: + { + // Common parameters + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + + switch (hparams.n_layer) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + case 72: + type = LLM_TYPE_34B; break; + default: + type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_HUNYUAN_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_A13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_SMOLLM3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.n_no_rope_layer_step = 4; + + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3178,10 +3251,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } } @@ -3208,6 +3281,87 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } } break; + case LLM_ARCH_JAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, 
"weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // Attention layers + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } else { + // FFN (no MoE) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_XVERSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4469,6 +4623,149 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case 
LLM_ARCH_FALCON_H1: + { + // Common + const int64_t hidden_size = hparams.n_embd; // hidden_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); + + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + // ssm_norm + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", 
i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_HUNYUAN_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } break; + case LLM_ARCH_SMOLLM3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, 
"weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); @@ -4714,16 +5011,6 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); - } - - if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - if (!classifier_labels.empty()) { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); @@ -4734,6 +5021,18 @@ void llama_model::print_info() const { } } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); @@ -5670,12 +5969,10 @@ struct llm_build_falcon : public llm_graph_context { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = 
ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode @@ -5952,12 +6249,10 @@ struct llm_build_dbrx : public llm_graph_context { cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(cur, "wqkv_clamped", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -6468,12 +6763,10 @@ struct llm_build_neo_bert : public llm_graph_context { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // RoPE @@ -6703,8 +6996,8 @@ struct llm_build_mpt : public llm_graph_context { cb(cur, "wqkv_clamped", il); } - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); cb(Qcur, "Qcur", il); @@ -6724,6 +7017,12 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_k_norm_b, LLM_NORM, il); cb(Kcur, "Kcur", il); + } else { + Qcur = ggml_cont(ctx0, Qcur); + cb(Qcur, "Qcur", il); + + Kcur = ggml_cont(ctx0, Kcur); + cb(Kcur, "Kcur", il); } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -6978,12 +7277,10 @@ struct llm_build_qwen : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, 
cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode @@ -7748,21 +8045,21 @@ struct llm_build_phi2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -7886,21 +8183,21 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - 
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -8256,12 +8553,10 @@ struct llm_build_codeshell : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -8677,8 +8972,6 @@ struct llm_build_minicpm3 : public llm_graph_context { ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); kv_compressed = build_norm(kv_compressed, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, il); @@ -8705,12 +8998,6 @@ struct llm_build_minicpm3 : public llm_graph_context { v_states = ggml_cont(ctx0, v_states); cb(v_states, "v_states", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -8719,7 +9006,6 @@ struct llm_build_minicpm3 : public llm_graph_context { cb(q_pe, "q_pe", il); // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this k_pe = ggml_rope_ext( ctx0, k_pe, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9199,8 +9485,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int n_layer_sparsity = 10; // number of layers using activation sparsity const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) - ggml_tensor * one; // containing single element 1.0f - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model), @@ -9212,14 +9496,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * cur; ggml_tensor * inpL; - // TODO: remove this when ggml_scale_add is implemented - one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - { - auto inp = std::make_unique(); - inp->one = one; - res->add_input(std::move(inp)); - } - inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. 
encoded image emdeddings) @@ -9609,7 +9885,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(innovation, "innovation", il); ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] - all_coefs = ggml_add(ctx0, all_coefs, one); + all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 cb(all_coefs, "all_coefs", il); all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] @@ -9752,73 +10028,22 @@ struct llm_build_starcoder2 : public llm_graph_context { } }; -struct llm_build_mamba : public llm_graph_context { - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - ggml_tensor * cur; - ggml_tensor * inpL; +struct llm_graph_context_mamba : public llm_graph_context { + llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} - // {n_embd, n_tokens} - inpL = build_inp_embd(model.tok_embd); + ggml_tensor * build_mamba_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) { - auto * rs_inp = build_rs_inp(); + const auto * mctx_cur = inp->mctx; - ggml_tensor * inp_out_ids = build_inp_out_ids(); + const auto kv_head = mctx_cur->get_head(); - for (int il = 0; il < n_layer; ++il) { - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - if (model.arch == LLM_ARCH_MAMBA2) { - cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); - } else { - cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - - // residual - cur = ggml_add(ctx0, cur, inpL); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - // final rmsnorm - cur = build_norm(inpL, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } - - ggml_tensor * build_mamba_layer( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) const { - const auto * mctx_cur = static_cast(mctx); - - const auto kv_head = mctx_cur->get_head(); + const auto & layer = model.layers[il]; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -9829,8 +10054,6 @@ struct llm_build_mamba : public llm_graph_context { const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. 
FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; - // Use the same RMS norm as the final layer norm - const float norm_rms_eps = hparams.f_norm_rms_eps; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -9848,7 +10071,7 @@ struct llm_build_mamba : public llm_graph_context { cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -9877,10 +10100,10 @@ struct llm_build_mamba : public llm_graph_context { // then permute away the ne[0] dimension, // and then you're left with the resulting x tensor. // For simultaneous sequences, all sequences need to have the same length. - x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d); // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); + x = ggml_add(ctx0, x, layer.ssm_conv1d_b); x = ggml_silu(ctx0, x); } @@ -9888,27 +10111,27 @@ struct llm_build_mamba : public llm_graph_context { // ssm { // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} - ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); + ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); - // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers - if (ssm_dt_b_c_rms) { - dt = ggml_rms_norm(ctx0, dt, norm_rms_eps); - B = ggml_rms_norm(ctx0, B, norm_rms_eps); - C = ggml_rms_norm(ctx0, C, norm_rms_eps); + // Some Mamba variants (e.g. 
FalconMamba, Jamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) { + dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il); + B = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il); } // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} - dt = build_lora_mm(model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + dt = build_lora_mm(layer.ssm_dt, dt); + dt = ggml_add(ctx0, dt, layer.ssm_dt_b); cur = x; x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * A = model.layers[il].ssm_a; + ggml_tensor * A = layer.ssm_a; // use the states and the indices provided by build_recurrent_state // (this is necessary in order to properly use the states before they are overwritten, @@ -9934,16 +10157,15 @@ struct llm_build_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens - y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d)); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(layer.ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - // cb(cur, "mamba_out", il); return cur; } @@ -9955,7 +10177,8 @@ struct llm_build_mamba : public llm_graph_context { const llama_model & model, const llama_ubatch & ubatch, int il) const { - const auto * mctx_cur = static_cast(mctx); + + const auto * mctx_cur = inp->mctx; const auto kv_head = mctx_cur->get_head(); @@ -10059,11 +10282,14 @@ struct llm_build_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + if (model.layers[il].ssm_norm) { + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + } + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -10072,12 +10298,178 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - // cb(cur, "mamba_out", il); + cb(cur, "mamba_out", il); return cur; } }; +struct llm_build_mamba : public llm_graph_context_mamba { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + auto * rs_inp = build_rs_inp(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + 
LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (model.arch == LLM_ARCH_MAMBA2) { + cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); + } else { + cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +}; + +struct llm_build_jamba : public llm_graph_context_mamba { + llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (n_head_kv == 0) { + cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + } else { + // Attention + + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // No RoPE :) + cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); + cb(cur, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // FFN + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + // residual + cur = ggml_add(ctx0, ffn_inp, cur); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = 
build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_command_r : public llm_graph_context { llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -10784,10 +11176,10 @@ struct llm_build_openelm : public llm_graph_context { cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); @@ -10909,12 +11301,10 @@ struct llm_build_gptneox : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -12159,6 +12549,8 @@ struct llm_build_chatglm : public llm_graph_context { if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -12166,13 +12558,11 @@ struct llm_build_chatglm : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } - Qcur 
= ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); @@ -12293,6 +12683,8 @@ struct llm_build_glm4 : public llm_graph_context { if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -12300,13 +12692,11 @@ struct llm_build_glm4 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -14525,13 +14915,10 @@ struct llm_build_ernie4_5 : public llm_graph_context { } }; -struct llm_build_arcee : public llm_graph_context { - llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_falcon_h1 : public llm_graph_context_mamba { + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - ggml_tensor * cur; ggml_tensor * inpL; @@ -14540,7 +14927,8 @@ struct llm_build_arcee : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + // Build the inputs in the recurrent & kv cache + auto * inp = build_inp_mem_hybrid(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -14549,61 +14937,189 @@ struct llm_build_arcee : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); // self-attention - { - // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + cb(Qcur, "Qcur-post-rope", il); + cb(Kcur, "Kcur-post-rope", il); + cb(Vcur, "Vcur-post-rope", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, model.layers[il].bo, + ggml_tensor * attn_out = build_attn(inp->get_attn(), gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(attn_out, "attn_out", il); + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + // Mamba2 layer + cb(cur, "ssm_in", il); + + ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); + cb(ssm_out, "ssm_out", il); + + // // Aggregation + cur = ggml_add(ctx0, attn_out, ssm_out); + inpSA = ggml_add(ctx0, cur, inpSA); + cb(cur, "layer_out", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = inpSA; + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = 
build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_arcee : public llm_graph_context { + llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -14660,6 +15176,304 @@ struct llm_build_arcee : public llm_graph_context { } }; +struct llm_build_hunyuan_moe : public llm_graph_context { + llm_build_hunyuan_moe(const llama_model & model, const 
llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_mlp", il); + + // MoE branch + ggml_tensor * cur_moe = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, + true, // norm_topk_prob + false, + 0.0, + 
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur_moe, "ffn_moe_out", il); + + ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp); + cb(ffn_out, "ffn_out", il); + + cur = ggml_add(ctx0, ffn_out, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_smollm3 : public llm_graph_context { + llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + 
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -14706,7 +15520,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv); + /* offload */ cparams.offload_kqv, + /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, + /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); } else { const auto padding = llama_kv_cache_unified::get_padding(cparams); @@ -14899,6 +15715,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_JAMBA: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_XVERSE: { llm = std::make_unique(*this, params, gf); @@ -15040,6 +15860,18 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_HUNYUAN_MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_SMOLLM3: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_FALCON_H1: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -15157,6 +15989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA2: + case LLM_ARCH_JAMBA: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -15191,12 +16024,14 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 case LLM_ARCH_FALCON: + case LLM_ARCH_FALCON_H1: case LLM_ARCH_GROK: case LLM_ARCH_DBRX: case LLM_ARCH_BERT: @@ -15228,6 +16063,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_EXAONE: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: + case LLM_ARCH_HUNYUAN_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 979fff62045f9..453f5af62fbc7 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -94,6 +94,7 @@ enum llm_type { LLM_TYPE_57B_A14B, LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick + LLM_TYPE_A13B, LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, LLM_TYPE_E2B, @@ -173,6 +174,9 @@ struct llama_layer { struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; struct 
ggml_tensor * ssm_norm = nullptr; + struct ggml_tensor * ssm_dt_norm = nullptr; + struct ggml_tensor * ssm_b_norm = nullptr; + struct ggml_tensor * ssm_c_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 5c9eb87566dde..6aa1d901c5e36 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" @@ -1522,6 +1523,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe"|| tokenizer_pre == "falcon3" || + tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; @@ -1554,7 +1556,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-de" || tokenizer_pre == "gigachat" || tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de") { + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "a.x-4.0") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || @@ -1656,6 +1659,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "seed-coder") { pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; clean_spaces = false; + } else if ( + tokenizer_pre == "hunyuan") { + pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0dc9c09e28ee2..1d837b4322cfa 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2368,22 +2368,24 @@ struct test_scale : public test_case { const ggml_type type; const std::array ne; float scale; + float bias; std::string vars() override { - return VARS_TO_STR3(type, ne, scale); + return VARS_TO_STR4(type, ne, scale, bias); } test_scale(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, - float scale = 2.0f) - : type(type), ne(ne), scale(scale) {} + float scale = 2.0f, + float bias = 0.0f) + : type(type), ne(ne), scale(scale), bias(bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_scale(ctx, a, scale); + ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias); ggml_set_name(out, "out"); return out; @@ -2583,10 +2585,6 @@ struct test_rms_norm_mul : public test_case { } } - double max_nmse_err() override { - return 1e-6; - } - float grad_eps() override { return 1.0f; } @@ -5048,6 +5046,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_silu_back()); for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) { @@ -5058,7 +5057,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } - for 
(float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) { + for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } @@ -5327,12 +5326,12 @@ static std::vector> make_test_cases_eval() { for (bool fw : {true, false}) { // fw == forward bool all = true; - for (float v : { 0, 1 }) { - for (float fs : { 1.0f, 1.4245f }) { - for (float ef : { 0.0f, 0.7465f }) { - for (float af : { 1.0f, 1.4245f }) { - for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (bool ff : {false, true}) { // freq_factors + for (float fs : { 1.0f, 1.4245f }) { + for (float ef : { 0.0f, 0.7465f }) { + for (float af : { 1.0f, 1.4245f }) { + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (bool ff : {false, true}) { // freq_factors + for (float v : { 0, 1 }) { test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { @@ -5345,13 +5344,21 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d3f6271931f62..57b917f2f97b3 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -4806,14 +4806,14 @@ int main(int argc, char ** argv) { // register static assets routes if (!params.public_path.empty()) { // Set the base directory for serving static files - bool is_found = svr->set_mount_point("/", params.public_path); + bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path); if (!is_found) { LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); return 1; } } else { // using embedded static index.html - svr->Get("/", [](const httplib::Request & req, httplib::Response & res) { + svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & 
res) { if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { res.set_content("Error: gzip is not supported by this browser", "text/plain"); } else { @@ -4829,37 +4829,37 @@ int main(int argc, char ** argv) { } // register API routes - svr->Get ("/health", handle_health); // public endpoint (no API key check) - svr->Get ("/metrics", handle_metrics); - svr->Get ("/props", handle_props); - svr->Post("/props", handle_props_change); - svr->Post("/api/show", handle_api_show); - svr->Get ("/models", handle_models); // public endpoint (no API key check) - svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) - svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) - svr->Post("/completion", handle_completions); // legacy - svr->Post("/completions", handle_completions); - svr->Post("/v1/completions", handle_completions_oai); - svr->Post("/chat/completions", handle_chat_completions); - svr->Post("/v1/chat/completions", handle_chat_completions); - svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint - svr->Post("/infill", handle_infill); - svr->Post("/embedding", handle_embeddings); // legacy - svr->Post("/embeddings", handle_embeddings); - svr->Post("/v1/embeddings", handle_embeddings_oai); - svr->Post("/rerank", handle_rerank); - svr->Post("/reranking", handle_rerank); - svr->Post("/v1/rerank", handle_rerank); - svr->Post("/v1/reranking", handle_rerank); - svr->Post("/tokenize", handle_tokenize); - svr->Post("/detokenize", handle_detokenize); - svr->Post("/apply-template", handle_apply_template); + svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/metrics", handle_metrics); + svr->Get (params.api_prefix + "/props", handle_props); + svr->Post(params.api_prefix + "/props", handle_props_change); + svr->Post(params.api_prefix + "/api/show", handle_api_show); + svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. 
public endpoint (no API key check) + svr->Post(params.api_prefix + "/completion", handle_completions); // legacy + svr->Post(params.api_prefix + "/completions", handle_completions); + svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai); + svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions); + svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions); + svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint + svr->Post(params.api_prefix + "/infill", handle_infill); + svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy + svr->Post(params.api_prefix + "/embeddings", handle_embeddings); + svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai); + svr->Post(params.api_prefix + "/rerank", handle_rerank); + svr->Post(params.api_prefix + "/reranking", handle_rerank); + svr->Post(params.api_prefix + "/v1/rerank", handle_rerank); + svr->Post(params.api_prefix + "/v1/reranking", handle_rerank); + svr->Post(params.api_prefix + "/tokenize", handle_tokenize); + svr->Post(params.api_prefix + "/detokenize", handle_detokenize); + svr->Post(params.api_prefix + "/apply-template", handle_apply_template); // LoRA adapters hotswap - svr->Get ("/lora-adapters", handle_lora_adapters_list); - svr->Post("/lora-adapters", handle_lora_adapters_apply); + svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list); + svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply); // Save & load slots - svr->Get ("/slots", handle_slots); - svr->Post("/slots/:id_slot", handle_slots_action); + svr->Get (params.api_prefix + "/slots", handle_slots); + svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); // // Start the server
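Note on the route changes above: every endpoint is now registered under params.api_prefix, so a server configured with a prefix such as "/llama" serves "/llama/health", "/llama/v1/models", and so on, while an empty prefix leaves the original paths unchanged. Below is a minimal client sketch using cpp-httplib, the same library whose httplib::Request/httplib::Response types appear in the handlers above; the host, port, and the "/llama" prefix value are illustrative assumptions, not values taken from this patch.

#include "httplib.h"

#include <cstdio>
#include <string>

int main() {
    // assumed server address and an example prefix; with an empty prefix the
    // paths below reduce to "/health" and "/v1/models", as before this change
    httplib::Client cli("127.0.0.1", 8080);
    const std::string prefix = "/llama";

    if (auto res = cli.Get(prefix + "/health")) {
        std::printf("GET %s/health -> %d\n", prefix.c_str(), res->status);
    }
    if (auto res = cli.Get(prefix + "/v1/models")) {
        std::printf("GET %s/v1/models -> %d\n", prefix.c_str(), res->status);
    }
    return 0;
}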