In [69]:
import os
import re
import weakref
from pathlib import Path
from typing import List, Tuple, Sequence, Dict

import numpy as np
from llama_cpp import Llama

# The notebook is in LlmStenoExplore/notebooks, so the parent of the current
# working directory is used as the repository root.
# NOTE(review): assumes the kernel's working directory is the notebooks folder.
REPO_ROOT = Path("..").resolve()

# Short model key -> path of the GGUF weights file under REPO_ROOT/models.
MODEL_REGISTRY = {
    "phi3_mini_q4": REPO_ROOT / "models/phi3/Phi-3-mini-4k-instruct-q4.gguf",
    "llama3_8b_q4_k_m": REPO_ROOT / "models/llama3_8b/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
}

def load_language_model(model_key: str) -> Llama:
    """
    Construct a Llama instance for one of the entries in MODEL_REGISTRY.

    Args:
        model_key: Key into MODEL_REGISTRY selecting the weights file.

    Returns:
        A Llama object built with a context size of 8192 tokens when the key
        contains "llama3" and 4096 otherwise.

    Raises:
        KeyError: If model_key is not a key of MODEL_REGISTRY.
        FileNotFoundError: If the registered weights file does not exist.
    """
    model_path = MODEL_REGISTRY[model_key]
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if "llama3" in model_key:
        context_window = 8192
    else:
        context_window = 4096

    # One thread per CPU reported by the OS; 4 when the count is unavailable.
    worker_threads = os.cpu_count() or 4

    llama_options = dict(
        model_path=str(model_path),
        n_ctx=context_window,
        n_gpu_layers=0,
        n_threads=worker_threads,
        n_batch=256,
        # Presumably required so that model.scores is filled for every
        # evaluated position (the rank functions read it) - TODO confirm.
        logits_all=True,
        verbose=False,
    )
    return Llama(**llama_options)

# Load the Llama-3 8B model once; later cells use this global `llm`.
llm = load_language_model("llama3_8b_q4_k_m")

In [64]:
# Two sample sentences used by the cells below: both are tokenized, and msg_e
# is additionally fed to the rank functions.
msg_e = "The bus is late again this week"
msg_k = "Here it is: the infamous British roasted boar with mint sauce. How to make it perfect."


In [65]:
def simple_space_tokenize(text: str) -> List[str]:
    """
    Split text on whitespace and lowercase every piece.

    Punctuation is not separated from words, so "is:" stays "is:".

    Args:
        text: The string to split. A falsy value (e.g. "" or None) gives [].

    Returns:
        The lowercased whitespace-delimited pieces, in their original order.
    """
    tokens: List[str] = []
    if text:
        for piece in text.strip().split():
            if piece:
                tokens.append(piece.lower())
    return tokens

# Tokenize both sample messages and print each raw string next to its word list.
# Splitting is on whitespace only, so punctuation stays attached to the words
# (the printed k_words contains 'is:' and 'sauce.').
e_words = simple_space_tokenize(msg_e)
k_words = simple_space_tokenize(msg_k)

print("msg_e:", repr(msg_e))
print("e_words:", e_words)
print()
print("msg_k:", repr(msg_k))
print("k_words:", k_words)


msg_e: 'The bus is late again this week'
e_words: ['the', 'bus', 'is', 'late', 'again', 'this', 'week']

msg_k: 'Here it is: the infamous British roasted boar with mint sauce. How to make it perfect.'
k_words: ['here', 'it', 'is:', 'the', 'infamous', 'british', 'roasted', 'boar', 'with', 'mint', 'sauce.', 'how', 'to', 'make', 'it', 'perfect.']


In [66]:
# A "word" is one or more lowercase ASCII letters and nothing else.
WORD_PATTERN = re.compile(r"[a-z]+")

def build_word_token_maps(model: "Llama"):
    """
    Scan the model vocabulary and map lowercase words to single token ids.

    A token is kept when its detokenized bytes are valid UTF-8 on their own
    and, after stripping surrounding whitespace and lowercasing, consist only
    of the letters a-z. When several tokens reduce to the same word (e.g.
    " the" and "The"), the lowest token id wins.

    Args:
        model: Object exposing n_vocab() and detokenize(list_of_ids) -> bytes.

    Returns:
        (word_to_token_id, token_id_to_word, candidate_token_ids) where the
        last item is an int32 numpy array of the kept token ids in ascending
        order.
    """
    word_to_token_id = {}
    token_id_to_word = {}
    candidate_token_ids = []

    for token_id in range(model.n_vocab()):
        token_bytes = model.detokenize([token_id])
        try:
            token_text = token_bytes.decode("utf-8")
        except UnicodeDecodeError:
            # The token is a fragment of a multi-byte character. Dropping the
            # bad bytes (errors="ignore") could leave plain letters behind and
            # register the fragment as if it were that word, so skip it.
            continue

        token_clean = token_text.strip().lower()

        # fullmatch also rejects the empty string left by whitespace-only tokens.
        if not WORD_PATTERN.fullmatch(token_clean):
            continue
        if token_clean in word_to_token_id:
            continue  # first (lowest) token id for a word is kept

        word_to_token_id[token_clean] = token_id
        token_id_to_word[token_id] = token_clean
        candidate_token_ids.append(token_id)

    return word_to_token_id, token_id_to_word, np.array(candidate_token_ids, dtype=np.int32)

def get_word_ranks_with_vocab_map(
    words: List[str],
    model: "Llama",
    word_to_token_id: dict,
    prefix_text: str = "",
) -> List[int]:
    """
    Rank each word's token under the model while feeding the words in order.

    The model is reset, primed with BOS (or BOS + prefix_text), and then for
    every word the rank of its token among the next-token logits is recorded
    before that token is appended to the context.

    Args:
        words: Words to score; each is lowercased before the lookup.
        model: Object exposing reset(), tokenize(), token_bos(), eval(),
            scores and n_tokens as used below.
        word_to_token_id: Mapping from lowercase word to a single token id.
        prefix_text: Optional text evaluated before the first word.

    Returns:
        One 1-based rank per word (1 = highest logit). Tokens with an equal
        logit are ordered by token id, highest id first, so each token has a
        distinct and reproducible rank.

    Raises:
        KeyError: If a word is missing from word_to_token_id. This is checked
            for all words before the model is touched.
        IndexError: If a mapped token id lies outside the logits vector.
    """
    words = list(words)

    # Resolve every word first so a bad word fails fast instead of after a
    # reset and several forward passes.
    target_token_ids: List[int] = []
    for word in words:
        base = word.lower()
        if base not in word_to_token_id:
            raise KeyError(f"Word {word!r} not in word_to_token_id")
        target_token_ids.append(word_to_token_id[base])

    model.reset()
    if prefix_text:
        prefix_ids = model.tokenize(prefix_text.encode("utf-8"), add_bos=True, special=False)
    else:
        prefix_ids = [model.token_bos()]
    model.eval(prefix_ids)

    ranks: List[int] = []
    for word, target_token_id in zip(words, target_token_ids):
        # Logits for the next token given everything evaluated so far.
        logits = np.asarray(model.scores[model.n_tokens - 1], dtype=np.float32)
        if not 0 <= target_token_id < logits.shape[0]:
            raise IndexError(
                f"Token id {target_token_id} for word {word!r} not found in logits"
            )

        # Counting is O(vocab) and, unlike a full argsort with the default
        # (unstable) sort, gives a well-defined answer when logits tie.
        target_logit = logits[target_token_id]
        strictly_better = int(np.count_nonzero(logits > target_logit))
        tied_ahead = int(np.count_nonzero(logits[target_token_id + 1:] == target_logit))
        ranks.append(1 + strictly_better + tied_ahead)

        # Advance the context by this word's token.
        model.eval([target_token_id])

    return ranks


# Build the vocabulary maps for the loaded model, then print the rank of each
# word of msg_e with no prefix text.
word_to_token_id, token_id_to_word, candidate_token_ids = build_word_token_maps(llm)
# Rebinds e_words to the original-case words (the rank function lowercases
# them for the lookup), so the printout shows "The" rather than "the".
e_words = msg_e.strip().split()
ranks_e = get_word_ranks_with_vocab_map(e_words, llm, word_to_token_id, prefix_text="")
for w, r in zip(e_words, ranks_e):
    print(f"{w:>12} -> {r}")




         The -> 41941
         bus -> 1232
          is -> 1189
        late -> 752
       again -> 5
        this -> 18
        week -> 2


In [67]:
# --- Word-level rank utilities (global maps + simple API) ---

import re
from typing import List, Dict, Tuple
import numpy as np
from llama_cpp import Llama  # only needed for type hints

# Match "plain" words like: government, over, spent
WORD_PATTERN = re.compile(r"[a-z]+")

# Global caches so we only scan the vocabulary once
WORD_TO_TOKEN_ID: Dict[str, int] = {}
TOKEN_ID_TO_WORD: Dict[int, str] = {}
CANDIDATE_TOKEN_IDS: np.ndarray | None = None


def build_word_token_maps(model: Llama) -> Tuple[Dict[str, int], Dict[int, str], np.ndarray]:
    """
    Build (once) a mapping between lowercase words and model token ids.

    We keep only tokens whose detokenized form, after stripping whitespace
    and lowercasing, matches [a-z]+ exactly.

    Returns the three globals so you can grab them if you want, but the
    results are also cached in WORD_TO_TOKEN_ID, TOKEN_ID_TO_WORD, and
    CANDIDATE_TOKEN_IDS for reuse.
    """
    global WORD_TO_TOKEN_ID, TOKEN_ID_TO_WORD, CANDIDATE_TOKEN_IDS

    # Already built: just return cached values
    if WORD_TO_TOKEN_ID and CANDIDATE_TOKEN_IDS is not None:
        return WORD_TO_TOKEN_ID, TOKEN_ID_TO_WORD, CANDIDATE_TOKEN_IDS

    word_to_token_id: Dict[str, int] = {}
    token_id_to_word: Dict[int, str] = {}
    candidate_token_ids: List[int] = []

    vocabulary_size = model.n_vocab()
    print(f"Scanning vocabulary of size {vocabulary_size}...")

    for token_id in range(vocabulary_size):
        token_text = model.detokenize([token_id]).decode("utf-8", errors="ignore")
        token_clean = token_text.strip().lower()

        if not token_clean:
            continue
        if not WORD_PATTERN.fullmatch(token_clean):
            continue

        if token_clean not in word_to_token_id:
            word_to_token_id[token_clean] = token_id
            token_id_to_word[token_id] = token_clean
            candidate_token_ids.append(token_id)

    candidate_token_ids_array = np.array(candidate_token_ids, dtype=np.int32)

    WORD_TO_TOKEN_ID = word_to_token_id
    TOKEN_ID_TO_WORD = token_id_to_word
    CANDIDATE_TOKEN_IDS = candidate_token_ids_array

    print(f"Kept {len(candidate_token_ids_array)} word-like tokens.")

    return WORD_TO_TOKEN_ID, TOKEN_ID_TO_WORD, CANDIDATE_TOKEN_IDS


def get_word_ranks(
    words: List[str],
    prefix_text: str = "",
    model: "Llama | None" = None,
) -> List[int]:
    """
    Given a list of space-delimited words, compute the rank of each word's
    single-token representation under the model, feeding words sequentially.

    - Uses the word -> token id map returned by build_word_token_maps(model).
    - prefix_text is the k' context in the paper (can be empty).
    - When model is None, the notebook-level global `llm` is used.

    Returns:
        A list of integer ranks, one per word (1 = most probable). Tokens
        with an equal logit are ordered by token id, highest id first, so
        each token has a distinct and reproducible rank.

    Raises:
        KeyError: If a word (lowercased) has no entry in the map. This is
            checked for all words before the model is reset or evaluated.
        RuntimeError: If a mapped token id lies outside the logits vector.
    """
    if model is None:
        # fall back to the global llm defined earlier in the notebook
        model = llm

    # Ensure vocabulary maps are ready and use the ones built for this model
    word_to_token_id, _, _ = build_word_token_maps(model)

    words = list(words)

    # Resolve every word first so a bad word fails fast instead of after a
    # reset and several forward passes.
    target_token_ids: List[int] = []
    for word in words:
        base = word.lower()
        if base not in word_to_token_id:
            raise KeyError(f"Word {word!r} not in WORD_TO_TOKEN_ID map")
        target_token_ids.append(word_to_token_id[base])

    # Reset model state and prime with BOS or BOS+prefix
    model.reset()
    if prefix_text:
        prefix_ids = model.tokenize(
            prefix_text.encode("utf-8"),
            add_bos=True,
            special=False,
        )
    else:
        prefix_ids = [model.token_bos()]

    model.eval(prefix_ids)

    ranks: List[int] = []

    for word, target_token_id in zip(words, target_token_ids):
        # Logits for the next token given current context
        logits = np.asarray(model.scores[model.n_tokens - 1], dtype=np.float32)
        if not 0 <= target_token_id < logits.shape[0]:
            raise RuntimeError(
                f"Token id {target_token_id} for word {word!r} not found in logits"
            )

        # Rank = 1 + number of tokens placed ahead of the target. Counting is
        # O(vocab) and, unlike a full argsort with the default (unstable)
        # sort, gives a well-defined answer when logits tie.
        target_logit = logits[target_token_id]
        strictly_better = int(np.count_nonzero(logits > target_logit))
        tied_ahead = int(np.count_nonzero(logits[target_token_id + 1:] == target_logit))
        ranks.append(1 + strictly_better + tied_ahead)  # 1-based

        # Advance the context by feeding this word's token
        model.eval([target_token_id])

    return ranks


In [68]:
# Re-run the rank printout for msg_e using the cached build_word_token_maps
# defined in the previous cell (it prints the vocabulary scan summary).
word_to_token_id, token_id_to_word, candidate_token_ids = build_word_token_maps(llm)
e_words = msg_e.strip().split()
# NOTE(review): this still calls the older get_word_ranks_with_vocab_map helper;
# get_word_ranks(e_words) from the previous cell looks like the intended
# replacement - confirm.
ranks_e = get_word_ranks_with_vocab_map(e_words, llm, word_to_token_id, prefix_text="")
for w, r in zip(e_words, ranks_e):
    print(f"{w:>12} -> {r}")


Scanning vocabulary of size 128256...
Kept 42661 word-like tokens.
         The -> 41941
         bus -> 1232
          is -> 1189
        late -> 752
       again -> 5
        this -> 18
        week -> 2
