# core

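> CRANE-style structured output for `llama-cpp-python`: unconstrained generation up to a delimiter, followed by `xgrammar`-constrained JSON decoding.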

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *


  import pkg_resources,importlib


In [None]:
#| export
from llama_cpp import Llama

Let's start by looking at what the regular llama-cpp flow looks like, as per their docs. I'm going to use `Tiny Llama 1.1B` as my test model. For real use, you will need to download a more capable model and point the path below at it.

In [None]:
#| hide
from pathlib import Path
# Absolute path to a locally downloaded GGUF model file; change it to wherever your own model lives.
tl1B_path=str(Path("/home/zardar/Downloads/tinyllama-1.1b-chat-v1.0.Q2_K.gguf")) ; tl1B_path

'/home/zardar/Downloads/tinyllama-1.1b-chat-v1.0.Q2_K.gguf'

In [None]:
# Load the GGUF model; only `model_path` is passed, every other Llama option keeps its default.
llm = Llama(model_path=tl1B_path)

llama_model_loader: loaded meta data with 23 key-value pairs and 201 tensors from /home/zardar/Downloads/tinyllama-1.1b-chat-v1.0.Q2_K.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = tinyllama_tinyllama-1.1b-chat-v1.0
llama_model_loader: - kv   2:                       llama.context_length u32              = 2048
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 2048
llama_model_loader: - kv   4:                          llama.block_count u32              = 22
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 5632
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 64
llama_model_loader: - kv   7:

In [None]:
# Plain text completion: at most 32 new tokens, stop on "Q:" or a newline,
# and echo the prompt back as part of the returned text.
llm("Microscopic examination reveals nests of atypical squamous cells with keratinization, diagnostic of squamous cell carcinoma. Q: What is the Histologic type? A:", 
    max_tokens=32, stop=["Q:", "\n"], echo=True)

llama_perf_context_print:        load time =     582.87 ms
llama_perf_context_print: prompt eval time =     582.53 ms /    47 tokens (   12.39 ms per token,    80.68 tokens per second)
llama_perf_context_print:        eval time =     131.71 ms /     5 runs   (   26.34 ms per token,    37.96 tokens per second)
llama_perf_context_print:       total time =     716.74 ms /    52 tokens
llama_perf_context_print:    graphs reused =          4


{'id': 'cmpl-15dbd0fc-6fe7-496f-a09b-c9e8110a8f5a',
 'object': 'text_completion',
 'created': 1763237767,
 'model': '/home/zardar/Downloads/tinyllama-1.1b-chat-v1.0.Q2_K.gguf',
 'choices': [{'text': 'Microscopic examination reveals nests of atypical squamous cells with keratinization, diagnostic of squamous cell carcinoma. Q: What is the Histologic type? A: Keratinization. ',
   'index': 0,
   'logprobs': None,
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 47, 'completion_tokens': 6, 'total_tokens': 53}}

There are different formats; here we can use the `chat_completion` approach. There are also `JSON` and `JSON Schema` modes. These use grammars to enforce structure, but they are a bit simplistic for our purposes. Below is the same task as above, expressed as `chat_completions`.

- [ ] **TODO:** expand on kwargs such as `n_ctx`, especially `n_gpu_layers=-1`   

In [None]:
# Same question as the plain completion above, expressed as a system + user chat exchange.
llm.create_chat_completion(
      messages = [
          {"role": "system", "content": "You are a medical assistant specialized in cancer reporting."},
          {
              "role": "user",
              "content": "What is the Histologic type in the following report: Microscopic examination reveals nests of atypical squamous cells with keratinization, diagnostic of squamous cell carcinoma."
          }
      ]
)


llama_perf_context_print:        load time =     582.87 ms
llama_perf_context_print: prompt eval time =     899.08 ms /    82 tokens (   10.96 ms per token,    91.20 tokens per second)
llama_perf_context_print:        eval time =     697.40 ms /    27 runs   (   25.83 ms per token,    38.72 tokens per second)
llama_perf_context_print:       total time =    1607.69 ms /   109 tokens
llama_perf_context_print:    graphs reused =         25


{'id': 'chatcmpl-3aff67c4-9f2c-405a-9340-9ccd20ec1f90',
 'object': 'chat.completion',
 'created': 1763237768,
 'model': '/home/zardar/Downloads/tinyllama-1.1b-chat-v1.0.Q2_K.gguf',
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': 'The Histologic type in the following report is "Squamo-Squamo-Keratocarcinoma."'},
   'logprobs': None,
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 82, 'completion_tokens': 27, 'total_tokens': 109}}

**Note:** Ignore the incorrect terminology; this is purely for demonstration purposes. Normally, we'd pick a much larger and more suitable model.

In [None]:
#| hide 
#| export
from fastcore.basics import patch
from tayz_decoding.types import CreateCRANEChatCompletionResponse
from typing import List, Dict, Type, Generator, Tuple
from pydantic import BaseModel
import xgrammar as xgr
from xgrammar import TokenizerInfo
import numpy as np
import json
import torch

  from .autonotebook import tqdm as notebook_tqdm


**Integrating `xgrammar` with `llama-cpp-python`'s Low-Level Sampler**
* The standard way to guide text generation in libraries like HuggingFace's transformers is with a `LogitsProcessor`. This is a high-level workflow that receives a full array of logits (unnormalized scores for every token in the vocabulary) at each step and modifies them.

**BUT**
* `llama-cpp-python` is highly optimized for performance and avoids copying the full logits tensor from its C++ core to Python at every token, as this would be computationally costly. Instead, it provides a more efficient, low-level callback mechanism: the `LlamaSampler`.

**LlamaSampler**
1. *Candidate Selection*: The C++ core first runs built-in samplers (like top-k or top-p); this reduces the number of possible next tokens to a small candidate set.
2. *Callback invocation*: It then invokes a cb function, `apply_func`, and passes it a C-level pointer to this small candidate set.
3. Our logic modifies the logits of only these few candidates 

**Solution**
* a custom sampler, `_make_xgr_sampler` that bridges `xgrammar` and `LlamaSampler`
* `xgrammar` expects a full logit vector, entire vocabulary
* `LlamaSampler` only provides a partial one
* *Scatter-Apply-Gather* pattern: we manually reconstruct the full logit tensor in Python, let `xgrammar` apply its full-vocabulary mask, and then gather the modified logits back for the original candidates.

In [None]:
#| export
@patch
def _crane_build_xgr_compiler(self: Llama) -> xgr.GrammarCompiler:
    """
    Return the xgrammar.GrammarCompiler for this model, building it on first use.

    The compiler is stored on the instance as `_xgr_compiler`, so later calls reuse it.
    It is needed to turn JSON schemas into a form the grammar matcher can use.
    """
    if not hasattr(self, "_xgr_compiler"):
        n_tokens = self.n_vocab()
        # One detokenized entry per token id (special tokens rendered too), in id order.
        # An earlier variant read the text via self._model.token_get_text(i) instead.
        vocab = []
        for tok_id in range(n_tokens):
            vocab.append(self.detokenize([tok_id], special=True))
        tokenizer_info = xgr.TokenizerInfo(
            encoded_vocab=vocab,
            vocab_type=xgr.VocabType.RAW,
            vocab_size=n_tokens,
            stop_token_ids=[self.token_eos()],
            add_prefix_space=True,
        )
        self._xgr_compiler = xgr.GrammarCompiler(tokenizer_info)
    return self._xgr_compiler

In [None]:
#| export
@patch
def _make_xgr_matcher(self: Llama, schema: Type[BaseModel]) -> tuple[xgr.GrammarMatcher, torch.Tensor]:
    """
    Build the grammar-tracking pair for one Pydantic schema.

    Returns `(matcher, bitmask)`:
      * matcher - an xgrammar.GrammarMatcher compiled from the schema's JSON schema; it follows
        the generation state against that schema.
      * bitmask - a token bitmask allocated for a batch of one over the model's vocabulary; it is
        the low-level structure used to enable/disable tokens.
    """
    compiler = self._crane_build_xgr_compiler()
    schema_json = json.dumps(schema.model_json_schema())
    compiled_grammar = compiler.compile_json_schema(
        schema_json,
        any_whitespace=True,
        strict_mode=True,
    )
    return (
        xgr.GrammarMatcher(compiled_grammar),
        xgr.allocate_token_bitmask(batch_size=1, vocab_size=self.n_vocab()),
    )

In [None]:
#| export
from llama_cpp._internals import LlamaSampler

In [None]:
#| export
# In llama-cpp-xgrammar-crane.py
@patch
def _make_xgr_sampler(self: Llama, matcher: xgr.GrammarMatcher, bitmask: torch.Tensor) -> LlamaSampler:
    """
    Build a LlamaSampler that masks candidate logits with xgrammar and then selects greedily.

    The chain has two stages: the custom callback `apply_func` defined below, then a greedy step.
    `apply_func` rewrites the candidate logits in the array llama.cpp reads from:
      * grammar already terminated -> every candidate becomes -inf, except EOS (set to 0.0)
        when EOS is one of the candidates;
      * otherwise -> candidate logits are scattered into a full-vocabulary tensor, the
        matcher's next-token bitmask is applied, and the results are gathered back;
      * if no finite logit survives -> EOS is forced, or the first candidate when EOS is
        not among the candidates.
    """
    from llama_cpp._internals import LlamaSampler

    # Tensor work runs on CUDA when available, on the CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _only_eos(cand_logits, cand_ids, fallback_to_first):
        # Wipe every candidate, then re-enable EOS if it is present in the candidate set.
        cand_logits[...] = -np.inf
        eos_hits = np.flatnonzero(cand_ids == self.token_eos())
        if eos_hits.size > 0:
            cand_logits[eos_hits[0]] = 0.0
        elif fallback_to_first and len(cand_logits) > 0:
            # EOS is not a candidate: keep the first candidate alive so something can be picked.
            cand_logits[0] = 0.0

    def apply_func(cur_p):
        # View of the (partial) candidate list handed over by llama.cpp.
        n_cand = cur_p.contents.size
        cand = np.ctypeslib.as_array(cur_p.contents.data, shape=(n_cand,))
        cand_logits = cand["logit"]
        cand_ids = cand["id"].astype(np.int32)
        idx = torch.from_numpy(cand_ids).to(device)

        # Grammar finished: only EOS may follow (no first-candidate fallback on this path).
        if matcher.is_terminated():
            _only_eos(cand_logits, cand_ids, fallback_to_first=False)
            return

        # Scatter: full-vocabulary tensor of -inf with the candidate logits dropped in place.
        n_vocab = self.n_vocab()
        dense = torch.full((n_vocab,), float("-inf"), dtype=torch.float32, device=device)
        dense[idx] = torch.from_numpy(cand_logits).to(device)

        # Apply: refresh the bitmask from the matcher and mask the dense logits with it.
        xgr.reset_token_bitmask(bitmask)
        matcher.fill_next_token_bitmask(bitmask, index=0)
        xgr.apply_token_bitmask_inplace(dense, bitmask.to(device), vocab_size=n_vocab)

        # Gather: write the masked values back over the candidate logits.
        np.copyto(cand_logits, dense[idx].cpu().numpy())

        # Backstop for a grammar mismatch that left nothing finite to choose from.
        if not np.isfinite(cand_logits).any():
            _only_eos(cand_logits, cand_ids, fallback_to_first=True)

    chain = LlamaSampler()
    chain.add_custom(apply_func)
    chain.add_greedy()
    return chain

In [None]:
#| export
@patch
def _get_ctx(self: Llama):
    """
    Return `self._ctx` after checking that it exists and exposes a `.ctx` handle.

    Raises TypeError when `_ctx` is missing, None, or has no `ctx` attribute.
    """
    ctx = getattr(self, "_ctx", None)
    if ctx is not None and hasattr(ctx, "ctx"):
        return ctx
    raise TypeError("Expected a LlamaContext as self._ctx (with .ctx handle)")

In [None]:
#| export
def _find_subseq(a: List[int], sub: List[int]) -> int:
    """
    Helper function to find a sub-sequence of tokens.
    """
    if not sub: return -1
    L,M = len(a), len(sub)
    for i in range(max(0,L-M),L):
        if a[i:i+M] == sub: return i
    return -1

In [None]:
#| export
@patch
def _crane_generate_unconstrained(self: Llama, s1: str, max_toks: int, temperature: float, stop: List[str]) -> Generator[str, None, Tuple[List[int], str]]:
    """
    Generate text token by token in unconstrained mode.

    Sampling starts immediately, so the prompt is expected to have been evaluated already.

    Args:
        s1: Delimiter whose tokenized form ends this phase.
        max_toks: Maximum number of tokens to sample.
        temperature: Sampling temperature forwarded to `self.sample`.
        stop: Stop strings; empty strings are skipped and `None` is treated as no stops.

    Yields:
        The text decoded for each sampled token. A chunk may be "" while a multi-byte
        UTF-8 character is still incomplete; the character is emitted once its last byte arrives.

    Returns:
        `(tokens, reason)` where reason is one of:
          "eos"    - EOS was sampled (EOS is not included in `tokens`);
          "stop"   - the tokens ended with a tokenized stop string (those tokens are cut off);
          "s1"     - the tokens ended with the tokenized `s1` (still included in `tokens`);
          "length" - `max_toks` tokens were sampled.
        On "eos", "stop" and "s1" the last sampled token has not been passed to `self.eval`.
    """
    import codecs

    s1_toks = self.tokenize(s1.encode("utf-8"), add_bos=False, special=True)
    stop_seq_toks = [self.tokenize(s.encode("utf-8"), add_bos=False, special=True) for s in (stop or []) if s]

    # Decoding each token on its own with errors="ignore" drops characters whose bytes are
    # split across tokens; an incremental decoder buffers the partial bytes instead.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    gen_toks: List[int] = []

    for _ in range(max_toks):
        tok = self.sample(temp=temperature)
        if tok == self.token_eos(): return gen_toks, "eos"

        gen_toks.append(tok)
        yield decoder.decode(self.detokenize([tok], special=False))

        for stop_seq in stop_seq_toks:
            seq_len = len(stop_seq)
            if len(gen_toks) >= seq_len and gen_toks[-seq_len:] == stop_seq: return gen_toks[:-seq_len], "stop"
        if len(gen_toks) >= len(s1_toks) and gen_toks[-len(s1_toks):] == s1_toks: return gen_toks, "s1"
        # Only tokens that did not end the phase are fed back to the model.
        self.eval([tok])
    return gen_toks, "length"

In [None]:
#| export
@patch
def _crane_generate_constrained(self: Llama, schema: Type[BaseModel], s2_toks: List[int], max_toks:int, prefix_toks_after_s1: List[int] | None = None) -> Generator[str, None, List[int]]:
    """
    Generate text token by token in constrained mode.

    Args:
        schema: Pydantic model whose JSON schema constrains the output.
        s2_toks: Token sequence that ends the phase when the generated tokens end with it.
        max_toks: Maximum number of tokens to sample.
        prefix_toks_after_s1: Optional tokens offered to the grammar matcher before sampling
            starts (the result of each `accept_token` call is not checked).

    Yields:
        The text decoded for each accepted token. A chunk may be "" while a multi-byte
        UTF-8 character is still incomplete; the character is emitted once its last byte arrives.

    Returns:
        The list of tokens generated in this phase (a sampled EOS is not included).
    """
    import codecs

    matcher, bitmask = self._make_xgr_matcher(schema)
    sampler = self._make_xgr_sampler(matcher, bitmask)
    ctx = self._get_ctx()

    if prefix_toks_after_s1:
        for t in prefix_toks_after_s1: matcher.accept_token(t)

    # Decoding each token on its own with errors="ignore" drops characters whose bytes are
    # split across tokens, which would corrupt the JSON text; buffer partial bytes instead.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    generated_toks_in_phase: List[int] = []

    try:
        for _ in range(max_toks):
            tok = sampler.sample(ctx)
            if tok == self.token_eos() or matcher.is_terminated(): break
            generated_toks_in_phase.append(tok)

            sampler.accept(tok)
            matcher.accept_token(tok) # advance grammar for next step
            self.eval([tok]) # advance model KV

            yield decoder.decode(self.detokenize([tok], special=False))

            if len(generated_toks_in_phase) >= len(s2_toks) and generated_toks_in_phase[-len(s2_toks):] == s2_toks: break
    finally:
        # Runs on normal completion, on error, and when the consumer closes the generator early.
        matcher.reset()
        sampler.reset()

    return generated_toks_in_phase
    

In [None]:
# #| export
# from llama_cpp import LogitsProcessorList
from llama_cpp.llama_types import ChatCompletionResponseChoice, ChatCompletionResponseMessage

In [None]:
#| export
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

In [None]:
?Jinja2ChatFormatter

[0;31mInit signature:[0m
[0mJinja2ChatFormatter[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mtemplate[0m[0;34m:[0m [0;34m'str'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meos_token[0m[0;34m:[0m [0;34m'str'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbos_token[0m[0;34m:[0m [0;34m'str'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0madd_generation_prompt[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstop_token_ids[0m[0;34m:[0m [0;34m'Optional[List[int]]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Base Protocol for a chat formatter. A chat formatter is a function that
takes a list of messages and returns a chat format response which can be used
to generate a completion. The response can also include a stop token or list
of stop tokens to use for the completion.
[0;31mInit docstring:[0m A chat formatter that uses jinja2

In [None]:
#| export
@patch
def create_crane_chat_completion(self: Llama, messages: List[Dict[str,str]], schema: Type[BaseModel],
                                 s1:str = "<<JSON>>", s2:str="</JSON>>", temperature:float=0.0,
                                 max_tokens_unconstrained:int=3072,max_tokens_constrained:int=4096,
                                 stop: List[str] | None = None) -> CreateCRANEChatCompletionResponse:
    """
    CRANE: Constrained + Unconstrained switching for structured output.

    Generates free text until the `s1` delimiter (unconstrained), then switches to
    constrained JSON generation matching the Pydantic `schema`.

    Args:
        messages: Chat messages rendered through the model's `tokenizer.chat_template`.
        schema: Pydantic model class describing the JSON to produce.
        s1: Delimiter that ends the unconstrained phase.
        s2: Closing delimiter; its tokens are passed to the constrained phase as an end marker.
        temperature: Sampling temperature for the unconstrained phase.
        max_tokens_unconstrained: Token budget for the unconstrained phase.
        max_tokens_constrained: Token budget for the constrained phase.
        stop: Optional stop strings for the unconstrained phase.

    Returns:
        A CreateCRANEChatCompletionResponse carrying the generated JSON text as the assistant
        message (`completion`) and the validated model instance (`json`).

    Raises:
        AssertionError: the unconstrained phase ended without `s1` and without reaching EOS.
        ValueError: the constrained output does not validate against `schema`.
    """
    # Imported here because the notebook cell importing these names is not marked for export,
    # so the generated module would otherwise not have them in scope.
    from llama_cpp.llama_types import ChatCompletionResponseChoice, ChatCompletionResponseMessage

    stop = list(stop) if stop else []

    # reset model state
    self.reset()
    etok = self._model.token_get_text(self.token_eos())
    btok = self._model.token_get_text(self.token_bos())
    formatter = Jinja2ChatFormatter(template=self.metadata['tokenizer.chat_template'],
                                    eos_token=etok, bos_token=btok)
    f_prmpt: str = formatter(messages=messages).prompt

    # The template writes control tokens (e.g. the EOS text after each message) as plain text;
    # special=True maps them back to their token ids. BOS is added only when the rendered
    # prompt does not already start with it, so it is never doubled.
    add_bos = not (btok and f_prmpt.startswith(btok))
    prmpt_toks = self.tokenize(f_prmpt.encode("utf-8"), add_bos=add_bos, special=True)
    self.eval(prmpt_toks) # evaluate initial prompt

    # tokenize delimiters
    s1_toks = self.tokenize(s1.encode("utf-8"), add_bos=False, special=True)
    s2_toks = self.tokenize(s2.encode("utf-8"), add_bos=False, special=True)

    # Phase 1. Unconstrained generation until s1 or max_tokens.
    # Driven with next() so the generator's return value can be read from StopIteration.
    unc_gen = self._crane_generate_unconstrained(
        s1=s1, max_toks=max_tokens_unconstrained, temperature=temperature, stop=stop)
    unc_text, unc_toks, stop_reason = "", [], "error"
    while True:
        try: chunk = next(unc_gen)
        except StopIteration as e:
            if e.value is not None: unc_toks, stop_reason = e.value
            break
        unc_text += chunk

    # Bring the KV cache to "... s1" exactly once before switching modes.
    if stop_reason == "s1":
        # The generator returns as soon as the sampled tokens end with s1, before evaluating
        # the token that completed it; only that last token is still missing from the cache.
        self.eval(unc_toks[-1:])
    elif s1 in unc_text or stop_reason == "eos":
        # s1 appeared only at text level, or the model stopped without it: append it once.
        self.eval(s1_toks)
    else:
        raise AssertionError(f"s1 delimiter '{s1}' not found in unconstrained phase.\n{unc_text}\n{stop_reason}")

    s1_pos = _find_subseq(unc_toks, s1_toks)
    prefix_toks_after_s1 = s1_toks if s1_pos == -1 else unc_toks[s1_pos:]

    # Phase 2. Constrained generation; only the yielded text is needed here.
    jtxt = "".join(self._crane_generate_constrained(
        schema=schema, s2_toks=s2_toks, max_toks=max_tokens_constrained,
        prefix_toks_after_s1=prefix_toks_after_s1))

    try:
        pjson = schema.model_validate_json(jtxt)
    except Exception as e:
        raise ValueError(f"Failed to parse generated JSON: {jtxt}\nError:{e}") from e

    cc = ChatCompletionResponseChoice(index=0,
                                      message=ChatCompletionResponseMessage(role="assistant", content=jtxt),
                                      finish_reason="stop")

    return CreateCRANEChatCompletionResponse(id="crane-"+str(id(self)), object="crane.chat.completion",
                                             completion=cc, json=pjson)

The KV cache in a decoder-only transformer is analogous to the encoder output in an encoder–decoder model — it’s the stored representation of the prefix context that subsequent tokens attend to.

Let's test this out on our similar use case from above.

In [None]:
from typing import Literal
from pydantic import BaseModel, Field

# Demo schema for the constrained phase: one required string field limited to two literal values.
# (Documented with comments rather than a docstring so the generated JSON schema stays unchanged.)
class HistologicType(BaseModel):
    # The Literal is rendered as an `enum` in the JSON schema; the spelling of the first
    # value is part of the schema as written and will appear verbatim in generated output.
    histologic_type: Literal["Keritinzation", "None"]

In [None]:
HistologicType.model_json_schema()

{'properties': {'histologic_type': {'enum': ['Keritinzation', 'None'],
   'title': 'Histologic Type',
   'type': 'string'}},
 'required': ['histologic_type'],
 'title': 'HistologicType',
 'type': 'object'}

We don't include `s1` and `s2` in the user messages anymore because they're control delimiters, not part of the natural-language exchange. The orchestrator appends `s1` at runtime right before switching to grammar-constrained decoding, and later expects `s2` to mark the end.

In [None]:
# Chat input for the CRANE demo: a system role plus the report to classify (no s1/s2 delimiters here).
messages = [
    {"role": "system", "content": "You are a pathology assistant."},
    {"role": "user", "content": "Classify the histologic type from the following report: Microscopic examination reveals nests of atypical squamous cells with keratinization, diagnostic of squamous cell carcinoma."}
]

We can see if a given model has a defined chat_template

In [None]:
llm.metadata['tokenizer.chat_template']

"{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

In [None]:
llm._model.token_get_text(llm.token_bos())

'<s>'

In [None]:
# Run the two-phase CRANE completion: up to 512 unconstrained tokens until s1,
# then up to 256 tokens constrained to the HistologicType schema.
response = llm.create_crane_chat_completion(
    messages=messages,
    schema=HistologicType,
    s1="<<JSON>>",
    s2="</JSON>>",
    temperature=0.0,
    max_tokens_unconstrained=512,
    max_tokens_constrained=256
)

In [None]:
response

{'id': 'crane-138206126811712',
 'object': 'crane.chat.completion',
 'completion': {'index': 0,
  'message': {'role': 'assistant', 'content': '{"histologic_type": "None"}'},
  'finish_reason': 'stop'},
 'json': HistologicType(histologic_type='None')}

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()