guide: adding new model architectures #16770
pwilkin started this conversation in Show and tell
Intro
With my adventures with Qwen3 Next (#16095), which some of you might have seen, I think I learned a lot about the Llama.cpp / GGML internal architecture and its quirks, so I decided to write a guide for anyone who might want to try, like me, to dust off their C++ and attempt to implement a new model architecture.
Prerequisites
While nowadays a lot of people vibe-code things with the help of LLMs and it's possible to do a lot outside your area of expertise, there are nevertheless a couple of skills that are must-haves for implementing a model architecture.
Where to start
Obviously, unless you're willing to add a new model architecture from scratch in GGML (feel free, but it's a pretty challenging task!), you'll be working off some reference implementation. For various reasons, the best choice of reference is the Transformers implementation, if there is one: first, we have tooling that lets you easily run a Transformers implementation side by side with your adaptation and compare them; second, there is accumulated know-how on how to do various things that other frameworks might lack. If you have to pick a reference implementation other than Transformers, the difficulty level rises significantly. Also, be sure the reference implementation you pick is stable. Sometimes even the model developers themselves leave bugs in their implementations that they only fix later on. Make sure the reference implementation is actually stable and producing good results before you attempt the conversion.
The key elements
convert_hf_to_gguf_update.py (new tokenizer)
If your model features a new tokenizer - one that hasn't been used before - you have to update `convert_hf_to_gguf_update.py` to fetch the tokenizer from a reference model on Huggingface, then run the script. If your tokenizer is already supported, you will skip this step. Don't add a new tokenizer if it's the same tokenizer as one already used.

convert_hf_to_gguf.py
This is where any conversion starts. The conversion script. Currently a huge, barely maintainable mess.
However, this isn't the only file you need to look at before the conversion. The ones you should probably start with - before you even write a line of code in `convert_hf_to_gguf.py` - are `constants.py` and `tensor_mapping.py` inside the `gguf-py` package. They contain the currently supported tensor types in GGML as well as the parameter names. You should look at the `model.safetensors.index.json` and `config.json` files from the original model (we're assuming you're converting from safetensors here) to get the list of layers and the list of parameters, respectively. The first thing you should do is make a checklist:

- the mapping of the model's tensors to GGML tensor types (look at `tensor_mapping.py` to check if a tensor with the same name scheme wasn't previously handled in some other model)
- the hyperparameters your implementation will need (you can go through the `modeling_*.py` class from the Transformers implementation to check which parameters you will need, or you can just add parameters on the fly when needed). A lot of the parameters will be handled by default implementations, so keep that in mind - you do not have to manually handle all of the parameters.
- any tensors that only ever appear in the reference code in a transformed form (for example `exp(your_tensor)` or `1 + your_tensor` or similar) - if so, you should do that transformation in the preparation code to both simplify the graph building and optimize (since you will only do the calculation once). The reason why this is sometimes not done in the reference code is the backward pass, but as I mentioned, we're only aiming for forward-pass capability most of the time.

Once you've looked at the reference code, tensors and parameters, it's time to write the code. You will be modifying the following places:
- `constants.py` - the model architecture constant, its codename and the list of tensors that it uses
- `tensor_mapping.py` - if the names of the tensors in your model follow a non-standard pattern, i.e. don't match any of the patterns already listed, add the mappings in that file. Unless some tensors are completely specific to your model, add the mappings in the general case.
- `llama-arch.h` and `llama-arch.cpp` - the architecture and name, to mirror the one in `constants.py`
- `convert_hf_to_gguf.py` - the conversion class itself

When writing the conversion class, make sure to inherit from the closest class that mirrors your architecture (`TextModel` at the minimum). Usually, you can just copy-paste the additional parameters, expert packing and other tensor conversion code from existing examples, but just in case, the methods you want to pay attention to are:

- `set_gguf_parameters` - here's where you will convert all the nonstandard hyperparameters
- `prepare_tensors` - create any new tensors that might be needed
- `modify_tensors` - perform any processing on existing tensors (merge, omit, transform etc.)

One important thing to note is that GGML tensors that have a weight and a bias part have to follow the "X.weight" and "X.bias" naming convention, so if your model's convention is different (for example "X_weight" and "X_bias", or "X" and "X_bias"), you will have to modify the names to match.
Once you're done, you should attempt to convert the model. Pay attention to any error messages suggesting that there are tensors you're not handling. If there are tensors you do not want to handle (for example because they correspond to functionality you are not converting, e.g. MTP), you explicitly have to ignore them in `modify_tensors`.

llama-model.cpp
Now comes the hard part - the graph implementation. You will most certainly want to look inside `examples/model-conversion` - there are scripts there for running a reference implementation together with your implementation to get the logits and compare. However, logits might not be enough - if you struggle with the conversion, you might want to analyze tensor dumps. The `run-org-model` script has been modified to provide GGML-style tensor dumps, but you might want to monkey-patch some more functions specific to your model (a commented example is in the script file).

What is a graph builder?
Before we even start, a key note on how inference works in Llama.cpp. When you look at Transformers code, you can basically figure out how the processing goes - the model's main `forward` method is called, which chains the `forward` calls of the submodules for the separate layers and components within. But that is not how inference works in GGML / Llama.cpp.

In Llama.cpp, due to its support of multiple different backends, which are supposed to be transparent to the end user, a different approach is used. What the model architecture constructs is not the inference itself but the inference graph. This means a couple of things. First of all, whenever you write instructions that operate on tensors, calling them adds the respective operation to the graph instead of executing it. That means that, apart from the tensor type, dimensions and strides, you cannot obtain any information from a tensor at the stage of building the graph - notably, you cannot extract any data. Second, the graph has to be static. Its shape can depend on hyperparameters of the model or metaparameters such as the layer number, but it cannot depend on the values of the processed tensors at the time of processing. Any such dependencies will have to be encoded into the processing itself - into entities known as operations.

Also, graph building is done lazily - if a node is deemed unreachable, it will not be computed. That means you have to mark which nodes are actually important for your processing, which is done using the function `ggml_build_forward_expand`. This function marks a graph node as essential in the computation, which means the node - and all nodes along the path to that node - will be computed (note: this still does not perform the computation itself, and there is no method in the graph builder to force the computation, because the graph is essentially a template and not a singular instance of an inference).

And last but not least, you cannot perform loops. Even if you can make the loop logic independent of specific tensor values, looping over e.g. the number of tokens, when a batch can have, say, 512 tokens, will generate a subgraph with (512 * the number of nodes inside the loop) subnodes. This is not feasible and will exhaust memory instantly. If you need a loop, you will have to, again, do it inside a custom operation (we'll get back to that topic soon, I promise).
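To make the "graph, not execution" point concrete, here is a minimal hedged sketch in the spirit of the builder code in `llama-model.cpp`; the function name and its arguments are made up for illustration, only the GGML calls are real:

```cpp
#include "ggml.h"
#include <math.h>

// Minimal hedged sketch: graph building records operations, it does not run them.
static void build_example_graph(ggml_context * ctx, ggml_cgraph * gf,
                                ggml_tensor * wq, ggml_tensor * cur, int n_embd_head) {
    // Each call below only adds a node to the graph owned by `ctx`.
    ggml_tensor * q = ggml_mul_mat(ctx, wq, cur);               // queued matmul, not computed
    q = ggml_scale(ctx, q, 1.0f / sqrtf((float) n_embd_head));  // queued scale

    // Mark the node as a required output: it and all of its ancestors will be
    // computed when the graph is eventually executed by a backend; unmarked
    // side branches are simply dropped.
    ggml_build_forward_expand(gf, q);
}
```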
GGML quirks and specifics

Okay, before we go any deeper into the graph building code, there are two major differences between GGML and Transformers that you absolutely have to be aware of:

- The order of tensor dimensions is reversed compared to Transformers: a GGML tensor is described starting from the innermost dimension (so, for a typical activation tensor: embedding size aka n_embd, number of tokens in a sequence aka n_seq_tokens, number of sequences aka n_seq).
- The `matmul` operation in GGML: `matmul(A, B)` corresponds to `transpose(B) @ A` in Transformers or, in other terms, `transpose(matmul(A, B)) = transpose(A) @ B`. Most of the other operations mirror their Transformers counterparts, but matrix multiplication is the one which stands out (see the sketch below).
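To make the dimension reversal and the matmul quirk concrete, here is a hedged sketch with made-up sizes (the `quirks_example` wrapper is purely illustrative):

```cpp
#include "ggml.h"

// Hedged sketch of the two quirks. In GGML, ne[0] is the innermost (contiguous)
// dimension, so shapes read "backwards" compared to PyTorch's tensor.shape.
static void quirks_example(ggml_context * ctx) {
    // PyTorch activations of shape (n_seq=2, n_tokens=5, n_embd=8)
    // become a GGML tensor with ne = {8, 5, 2, 1}:
    ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 5, 2);

    // A linear layer with 8 inputs and 16 outputs (PyTorch weight shape (16, 8))
    // is stored with ne = {8, 16}:
    ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 16);

    // ggml_mul_mat contracts over ne[0] of both arguments, so the weight goes
    // first: the result has ne = {16, 5, 2, 1}, i.e. what PyTorch's x @ w.T
    // (an nn.Linear application) would give you, with the dimensions reversed.
    ggml_tensor * y = ggml_mul_mat(ctx, w, x);
    (void) y; // this is only a graph node; nothing is computed here
}
```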
Model preparation

Before we can do any graph building, the graph needs a model to operate on. We converted the tensors; now we need to load them into a layered model structure. First, we load the hyperparameters into the GGML `hparams` structures, which is done in `llama-model.cpp` in the `load_hparams()` method. Then we load the tensors themselves - this is done in the huge switch statement within the `load_tensors()` method. Basically, this is where you actually load the tensors you converted. All of the tensors have to actually get loaded, or the loader will throw an error. This is also where you specify the tensors' dimensions using the hyperparameters. This will be the first real test of the conversion - whether you can load the tensors you converted with dimensions that make semantic sense within the model parameters.

Finally, if your model has a novel size, make sure to update the model sizes enum!
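For orientation, here is a hedged sketch of what the two steps tend to look like, modeled on existing architectures; `LLM_ARCH_MYMODEL` is a hypothetical architecture value, and the exact helper signatures may differ between versions:

```cpp
// Inside llama_model::load_hparams(): read back the keys written by set_gguf_parameters.
case LLM_ARCH_MYMODEL: // hypothetical
    {
        ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
        // ... any other hyperparameters your graph code will need ...
    } break;

// Inside llama_model::load_tensors(): declare every converted tensor with its expected shape.
case LLM_ARCH_MYMODEL: // hypothetical
    {
        tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
        for (int i = 0; i < n_layer; ++i) {
            auto & layer = layers[i];
            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
            layer.wq        = create_tensor(tn(LLM_TENSOR_ATTN_Q,    "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
            // ... every tensor produced by the conversion must appear here ...
        }
    } break;
```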
Graph building
Now we get to the nontrivial part. Building the graph. First of all, note that GGML already has functions for some typical things that happen during an LLM's forward inference - such as RoPE, expert routing, standard attention, KV cache handling, input position embeddings. Unless you're absolutely sure that the architecture you're building for does things completely differently, assume that you will be using those functions. Don't try to mirror the Transformers code step-by-step - just understand what the building blocks are.
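For example, RoPE already exists as a single GGML operation. A hedged sketch of how the existing helper is typically invoked - `ctx0`, `Qcur`, `inp_pos` and the RoPE hyperparameters are assumed to come from the surrounding builder code:

```cpp
// Qcur is assumed to be [n_embd_head, n_head, n_tokens]; inp_pos holds the token positions.
// rope_type is GGML_ROPE_TYPE_NEOX or 0 ("normal") - see the "Which RoPE?" section below.
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor,
                     beta_fast, beta_slow);
```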
Most typical tensor operations have their counterparts in GGML. Here are a couple typical ones:
- matrix multiplication: `ggml_mul_mat` (but see the note above), also `build_lora_mm` specifically for weight projections which might use LoRAs
- element-wise multiplication: `ggml_mul` (this has broadcasting, but only in one dimension)
- broadcasting / expanding: `ggml_repeat_4d` (broadcasts the tensor to fill the given shape)
- padding: `ggml_pad` (but the semantics are different, so beware - `ggml_pad` does post-padding and takes the pad amount in each dimension, whereas the standard Transformers `pad` only does padding on the semantic dimensions, but takes pre- and post-pad arguments; if you want pre- and post-padding, see `ggml_pad_ext`)
- reshaping: `ggml_reshape_(1|2|3|4)d`
- making a tensor contiguous: `ggml_cont` (you can also use `ggml_cont_(1|2|3|4)d`, which is short for reshape + cont; reshape, permute and cont are combined in the sketch after this list)
- permuting / transposing: `ggml_permute` or `ggml_transpose` - the former takes a permutation tuple, i.e. the indices where the respective dimensions will land, so `ggml_permute(ctx, t, 3, 0, 1, 2)` means a permutation of 0 -> 3, 1 -> 0, 2 -> 1, 3 -> 2; the latter just exchanges the semantic dimensions (the first two)
- scaling: `ggml_scale`, `ggml_scale_bias`
- exponentiation: `ggml_exp`
- creating a new tensor: `ggml_new_tensor_(1|2|3|4)d` - the tensor is zeroed by default
- creating a tensor of ones: `ggml_exp(ggml_new_tensor)` (since e^0 = 1)
- views / slicing: `ggml_view_(1|2|3|4)d` - you will have to provide strides, which are the number of bytes by which you need to jump to get the next element along that dimension, and an offset. Remember to use `ggml_element_size` and `ggml_nelements` instead of assuming fixed data type sizes. Unless you are doing some really weird views, passing the original tensor strides (which are held in the `t->nb` array) will be enough. If you're looking for dimensions (the equivalent of Transformers' `.shape`), they're in `t->ne` (so the full shape of the tensor is (t->ne[0], t->ne[1], t->ne[2], t->ne[3]))
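Here is a hedged sketch combining a few of the ops above in a very common pattern - splitting attention heads; the `split_heads` wrapper and its arguments are made up for illustration:

```cpp
#include "ggml.h"

// `cur` is assumed to be the output of a Q projection with ne = {n_embd, n_tokens}.
static ggml_tensor * split_heads(ggml_context * ctx, ggml_tensor * cur,
                                 int64_t n_embd_head, int64_t n_head, int64_t n_tokens) {
    // {n_embd, n_tokens} -> {n_embd_head, n_head, n_tokens}
    ggml_tensor * q = ggml_reshape_3d(ctx, cur, n_embd_head, n_head, n_tokens);

    // move the head dimension outwards: {n_embd_head, n_tokens, n_head}
    q = ggml_permute(ctx, q, 0, 2, 1, 3);

    // permute only changes the view; make the data contiguous before ops that require it
    q = ggml_cont(ctx, q);

    return q;
}
```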
New operations

If you find an operation that isn't implemented, you can try to work around it with equivalent transformations or add the operation yourself. Adding an operation is a separate topic in itself, but the standard is to provide a CPU backend implementation as the reference one and to keep any further backends for separate PRs. Basically, the operations are the code that actually gets executed once the graph is computed and inference is done. Once you're coding operations, you're out of the world of abstractions and into the world of dirty pointer arithmetic, memory allocations, manually traversing tensor dimensions and the like. Note: there are abstractions coded into the library - for example, there are C++ templates for the case where your operation is a true unary op (operates independently on tensor elements, element by element, without any extra parameters, such as NEG or EXP) or a binary op (operates element-wise on element pairs from two tensors). So, to add an operation, you have to:
- add the public API function (`ggml_<operation>`) which actually queues the operation by preparing the result tensor and passing the arguments - either as source tensors in the `t->src` array or in the dedicated parameters array as `t->op_params` (a sketch of this step follows below)
- add the operation to the enums in `ggml.h` and `ggml.c`; don't forget to increase the static count assertions; if it's a true unary op, don't add it to the ops enum, but to the unary ops enum
- add the dispatch in `ggml-cpu.c` and implement the operation itself (likely in `ops.cpp` or `unary-ops.cpp`)
- add test cases to `test-backend-ops`
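A hedged sketch of the first step only - the public wrapper that queues a new operation; the `ggml_softclip` name and the `GGML_OP_SOFTCLIP` enum value are hypothetical:

```cpp
// Sketch of a public wrapper (living in ggml.c) for a hypothetical GGML_OP_SOFTCLIP.
// It only prepares the result tensor and records the op; nothing is computed here.
struct ggml_tensor * ggml_softclip(struct ggml_context * ctx, struct ggml_tensor * a, float limit) {
    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);   // same shape and type as the input

    ggml_set_op_params(result, &limit, sizeof(limit));       // scalar arguments go into op_params

    result->op     = GGML_OP_SOFTCLIP;                       // hypothetical new enum value
    result->src[0] = a;                                      // the input tensor becomes a source node

    return result;
}
```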
Which RoPE?

This is a short but potentially annoying topic. There are basically two types of RoPE used: normal and so-called "NeoX". However, determining which is which is counterintuitive. When you have a Transformers implementation of `apply_rotary_pos_emb`, it will usually look like this:

The critical part is these two lines:
If those lines are present, it's normal RoPE. If they aren't present, it's NeoX RoPE. Mark the appropriate case in the switch at the end of `llama-model.cpp`.

Debugging
That's it! Now comes the tedious part of making sure your implementation works well. At a minimum, the conversion should pass the logits-comparison test in `examples/model-conversion`. You should also verify that it produces coherent long output and that it can coherently read long inputs. If your conversion matches the reference perfectly on short prompt processing but diverges on generation, it's usually a sign of either (a) incorrect state management (this can happen especially in mamba / hybrid models) or (b) bad RoPE (see above; also make sure the hyperparameters for RoPE are correct, especially if using YaRN). From my experience, the biggest culprits in divergence are incorrect tensor shapes and/or transpositions, which quickly lead to diverging outputs. A good heuristic when looking at tensor dumps is to check whether the top-left and bottom-right corners match, i.e. the elements (2, 0), (1, 0), (0, 0), (0, 1), (0, 2) and their bottom-right counterparts - that lets you quickly spot obvious transposition or tensor-ordering problems. For easily gathering tensor dumps on single-token processing, use `llama-eval-callback` - if you need to debug longer sequences, you might need to copy the callback code from `llama-eval-callback` to `llama-cli` and enable the callback by passing the appropriate params (see `eval-callback.cpp` for the details).
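If you want to hook tensor dumps into your own debugging flow, the callback used by `eval-callback.cpp` follows the scheduler's eval-callback signature; a hedged sketch (the printing logic and the `debug_cb` name are purely illustrative, and the wiring into common params is omitted):

```cpp
#include "ggml.h"
#include <stdio.h>

// The scheduler calls this for every tensor: once with ask == true (should the
// data be retrieved?) and once with ask == false (the data is now available).
static bool debug_cb(struct ggml_tensor * t, bool ask, void * user_data) {
    (void) user_data;
    if (ask) {
        return true; // request the data of every node
    }
    fprintf(stderr, "%s: op=%s ne=[%lld, %lld, %lld, %lld]\n",
            t->name, ggml_op_name(t->op),
            (long long) t->ne[0], (long long) t->ne[1],
            (long long) t->ne[2], (long long) t->ne[3]);
    return true; // returning false would abort the computation
}
```

Prompt template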
So, you got a perfect match and you thought you were done, eh? Not so easy. If your model uses a non-typical thinking marker, non-typical tool-calling markers or both, you have to add chat template support in `chat.cpp` and `llama-chat.cpp` - that is, detecting and applying the template, and parsing the model's thinking and tool-call markers. Unless your model is especially nasty, this basically means copying what has already been done for a similar model. Make sure to add the necessary tests to `test-chat.cpp` to verify that your chat parser works correctly!

And that's it!
If you made it this far, congratulations. Your basic model architecture implementation is done. Now you can focus on adding those optimized CUDA kernels, Vulkan shaders and all the other things that aren't critical for correctness, but can make the model work faster :)