In [1]:
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np

from llama_cpp import Llama

# The notebook lives in LlmStenoExplore/notebooks, so the repository root is
# one directory up. Path("..") is resolved against the current working
# directory, so this assumes the kernel was started from notebooks/.
REPO_ROOT = Path("..").resolve()

# Short model key -> GGUF weights file under <repo>/models.
MODEL_REGISTRY = dict(
    phi3_mini_q4=REPO_ROOT / "models" / "phi3" / "Phi-3-mini-4k-instruct-q4.gguf",
    llama3_8b_q4_k_m=REPO_ROOT / "models" / "llama3_8b" / "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
)

def load_language_model(
    model_key: str,
    n_ctx: Optional[int] = None,
    n_gpu_layers: int = 0,
) -> Llama:
    """
    Load one of the GGUF models registered in MODEL_REGISTRY with llama.cpp.

    Args:
        model_key: key into MODEL_REGISTRY.
        n_ctx: context window in tokens. When None, 8192 is used for keys
            containing "llama3" and 4096 otherwise (the previous hard-coded
            behaviour).
        n_gpu_layers: number of layers to offload to the GPU; 0 (the default)
            offloads none.

    Returns:
        A llama_cpp.Llama instance configured for rank extraction.

    Raises:
        KeyError: if model_key is not registered (the message lists valid keys).
        FileNotFoundError: if the registered weights file does not exist.
    """
    try:
        model_path = MODEL_REGISTRY[model_key]
    except KeyError:
        raise KeyError(
            f"Unknown model key {model_key!r}; available keys: {sorted(MODEL_REGISTRY)}"
        ) from None
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if n_ctx is None:
        n_ctx = 8192 if "llama3" in model_key else 4096

    return Llama(
        model_path=str(model_path),
        n_ctx=n_ctx,
        n_gpu_layers=n_gpu_layers,
        n_threads=os.cpu_count() or 4,
        n_batch=256,
        # The rank encoder/decoder read model.scores[model.n_tokens - 1];
        # logits_all asks llama_cpp to keep logits for every evaluated position.
        # NOTE(review): this can be memory-heavy for large n_ctx x vocab — confirm
        # it is needed for the llama_cpp version in use.
        logits_all=True,
        verbose=False,
    )

# Shared model instance used by the encode/decode helpers in later cells.
llm = load_language_model(model_key="llama3_8b_q4_k_m")

In [2]:

def get_token_ranks_like_paper(
    text: str,
    model: Llama,
    prefix: str = "A text:",
) -> List[int]:
    """
    Compute the 1-based rank of every token of `text` under the model.

    Recipe (following the paper):
      1. Tokenize the secret e and the prefix k' with the LLM tokenizer.
      2. For each token e_i, rank it among ALL vocabulary entries by the
         logits of p(. | k', e_1, ..., e_{i-1}).

    Mirrors the authors' get_token_ranks_llama_cpp, except that an empty
    prefix (or one that yields no tokens) is replaced by a lone BOS token so
    there is always a context to condition on.

    Args:
        text: the text to rank; characters that cannot be UTF-8 encoded
            (e.g. lone surrogates) are silently dropped.
        model: llama_cpp model; it is reset and its context overwritten.
        prefix: conditioning context k'.

    Returns:
        One rank per token of " " + text (the leading space is part of the
        tokenized input, so even an empty text yields a rank).

    Raises:
        RuntimeError: if a token id does not appear in the logits row.
    """
    # Drop anything not representable in UTF-8 before tokenizing.
    clean_text = text.encode("utf-8", errors="ignore").decode("utf-8")

    # Element 0 of each tokenization is dropped; this assumes the tokenizer
    # prepends a BOS token (true for the authors' setup — confirm per model).
    context_ids = model.tokenize(prefix.encode("utf-8"))[1:] if prefix else []
    target_ids = model.tokenize((" " + clean_text).encode("utf-8"))[1:]

    model.reset()
    if not context_ids:
        # No usable prefix: seed the context with BOS alone.
        context_ids = [model.token_bos()]
    model.eval(context_ids)

    ranks: List[int] = []
    for token_id in target_ids:
        # Logits for the next position given everything evaluated so far.
        row = np.array(model.scores[model.n_tokens - 1], dtype=np.float32)
        order = np.argsort(row)[::-1]  # most probable first
        (hits,) = np.nonzero(order == token_id)
        if not hits.size:
            raise RuntimeError(f"Token id {token_id} not found in logits")
        ranks.append(int(hits[0]) + 1)
        # Commit the true token so the next rank is conditioned on it.
        model.eval([token_id])

    return ranks


def decode_from_ranks_like_paper(
    prompt: str,
    ranks: List[int],
    model: Llama,
) -> str:
    """
    Token-level decoder matching the paper's scheme (the inverse of
    get_token_ranks_like_paper).

      - Tokenize the prompt (k or k') and drop the tokenizer's leading BOS.
        If no tokens remain (e.g. empty prompt), use BOS alone as context,
        exactly as get_token_ranks_like_paper does for its prefix.
      - Reset the model and eval the prompt tokens.
      - For each rank r_i (1-based):
          * get logits for the next token given the current context
          * pick the r_i-th most probable token
          * feed it back into the model
      - Detokenize only the generated tokens and remove the single leading
        space the encoder adds (it tokenizes " " + text).

    Args:
        prompt: context string the ranks are interpreted against.
        ranks: 1-based ranks, one per token to generate.
        model: llama_cpp model; it is reset and its context overwritten.

    Returns:
        The generated text without the prompt. Bytes that do not form valid
        UTF-8 are dropped.

    Raises:
        ValueError: if a rank is outside [1, vocabulary size].
    """
    # Drop BOS (same pattern as the authors' code). Fall back to BOS when the
    # prompt yields no tokens, mirroring the encoder's `if prefix_ids` check;
    # otherwise eval([]) would leave no valid logits row to read.
    prompt_ids = model.tokenize(prompt.encode("utf-8"))[1:] if prompt else []
    if not prompt_ids:
        prompt_ids = [model.token_bos()]

    model.reset()
    model.eval(prompt_ids)

    generated_ids: List[int] = []
    for rank in ranks:
        # logits for next token given current context
        logits = np.array(model.scores[model.n_tokens - 1], dtype=np.float32)

        # Same descending argsort as the encoder, so ties resolve identically.
        sorted_indices = np.argsort(logits)[::-1]
        if rank < 1 or rank > len(sorted_indices):
            raise ValueError(
                f"Rank {rank} out of range for vocabulary size {len(sorted_indices)}"
            )

        next_token_id = int(sorted_indices[rank - 1])
        generated_ids.append(next_token_id)

        # advance context
        model.eval([next_token_id])

    # Detokenize only the generated part. Stripping the prompt by string
    # comparison fails whenever the prompt does not detokenize back to itself.
    decoded_text = model.detokenize(generated_ids).decode("utf-8", errors="ignore")

    # The encoder prepends exactly one space; any further leading whitespace
    # belongs to the text and must survive (lstrip() would destroy it).
    if decoded_text.startswith(" "):
        decoded_text = decoded_text[1:]

    return decoded_text


def hide_text_token_level(
    secret_text: str,
    secret_prefix: str,
    secret_key: str,
    model: Optional[Llama] = None,
) -> Tuple[str, List[int]]:
    """
    Encode pipeline (e -> ranks -> stegotext):

      1. Compute ranks for secret_text e after prefix k'.
      2. Generate stegotext s from key k by following those ranks.

    Args:
        secret_text: the secret e to hide.
        secret_prefix: context k' used to rank e.
        secret_key: prompt k from which the stegotext is generated.
        model: llama_cpp model; defaults to the notebook-global `llm`,
            looked up at call time.

    Returns:
        (stegotext, ranks) where ranks has one entry per token of e.
    """
    # Resolve the default at call time. A `model=llm` default is evaluated once
    # when this cell runs; it would keep pointing at (and keep alive) the old
    # model after `llm` is reassigned.
    if model is None:
        model = llm

    ranks = get_token_ranks_like_paper(
        text=secret_text,
        model=model,
        prefix=secret_prefix,
    )
    stegotext = decode_from_ranks_like_paper(
        prompt=secret_key,
        ranks=ranks,
        model=model,
    )
    return stegotext, ranks


def reveal_text_token_level(
    stegotext: str,
    secret_prefix: str,
    secret_key: str,
    model: Optional[Llama] = None,
) -> str:
    """
    Decode pipeline (s -> ranks -> e):

      1. From stegotext s and key k, recover the same ranks.
      2. From those ranks and prefix k', reconstruct e.

    Args:
        stegotext: the stegotext s produced by hide_text_token_level.
        secret_prefix: context k' that the secret was ranked against.
        secret_key: prompt k that the stegotext was generated from.
        model: llama_cpp model; defaults to the notebook-global `llm`,
            looked up at call time.

    Returns:
        The recovered secret text.
    """
    # Resolve the default at call time (see hide_text_token_level): a
    # `model=llm` default would stay bound to the model loaded when this
    # cell ran.
    if model is None:
        model = llm

    recovered_ranks = get_token_ranks_like_paper(
        text=stegotext,
        model=model,
        prefix=secret_key,
    )
    return decode_from_ranks_like_paper(
        prompt=secret_prefix,
        ranks=recovered_ranks,
        model=model,
    )


In [6]:
def run_example(
    secret_text: str,
    secret_prefix: str,
    secret_key: str,
    model: Optional[Llama] = None,
) -> None:
    """
    Run one full encode/decode example and log everything consistently:
      - secret text
      - ranks_e and their length
      - stegotext
      - token counts for secret and stego (same tokenization as get_token_ranks_like_paper)
      - recovered text and equality check

    Args:
        secret_text: the secret e.
        secret_prefix: context k'.
        secret_key: prompt k.
        model: llama_cpp model; defaults to the notebook-global `llm`,
            looked up at call time.

    Raises:
        AssertionError: if the secret or stego token count differs from
            len(ranks_e), i.e. the round trip cannot be token-aligned.
    """
    # Resolve at call time so reassigning `llm` is picked up.
    if model is None:
        model = llm

    def _token_ids(text: str) -> List[int]:
        # Same scheme as get_token_ranks_like_paper: drop non-UTF-8-encodable
        # characters, prepend one space, drop the leading BOS. Without
        # errors="ignore", a lone surrogate would raise here even though the
        # encoder accepted it.
        return model.tokenize((" " + text).encode("utf-8", errors="ignore"))[1:]

    print("Secret text e:")
    print(secret_text)
    print()

    # Encode: e -> (ranks_e) -> stegotext
    stegotext, ranks_e = hide_text_token_level(
        secret_text=secret_text,
        secret_prefix=secret_prefix,
        secret_key=secret_key,
        model=model,
    )

    print("ranks_e (len = {}):".format(len(ranks_e)))
    print(ranks_e)
    print()

    print("Stegotext s:")
    print(stegotext)
    print()

    secret_token_ids = _token_ids(secret_text)
    stego_token_ids = _token_ids(stegotext)

    print("Secret tokens:", len(secret_token_ids))
    print("Stego tokens :", len(stego_token_ids))
    print("len(ranks_e) :", len(ranks_e))
    print()

    # Sanity checks (will raise if something is inconsistent)
    assert len(secret_token_ids) == len(ranks_e), "Token count for e does not match len(ranks_e)"
    assert len(stego_token_ids) == len(ranks_e), "Token count for s does not match len(ranks_e)"

    # Decode: s -> (ranks) -> e
    recovered_text = reveal_text_token_level(
        stegotext=stegotext,
        secret_prefix=secret_prefix,
        secret_key=secret_key,
        model=model,
    )

    print("Recovered e:")
    print(recovered_text)
    print("Recovered == secret_text:", recovered_text == secret_text)
    print("-" * 60)


In [7]:
# Example 1: an all-caps secret hidden behind a cooking-themed key.
secret_text_1 = "THE CURRENT SYSTEM HAS REPEATEDLY FAILED"  # e
secret_prefix_1 = "A text:"  # k'
secret_key_1 = "Here it is: the infamous British roasted boar with mint sauce. How to make it perfect."  # k

run_example(
    secret_text=secret_text_1,
    secret_prefix=secret_prefix_1,
    secret_key=secret_key_1,
    model=llm,
)

# Example 2: a two-sentence secret with a short geographic key.
secret_text_2 = "The cats like to meow all the time. It is annoying."  # e
secret_prefix_2 = "Text:"  # k'
secret_key_2 = "There is a big jungle in Brazil."  # k

run_example(
    secret_text=secret_text_2,
    secret_prefix=secret_prefix_2,
    secret_key=secret_key_2,
    model=llm,
)


Secret text e:
THE CURRENT SYSTEM HAS REPEATEDLY FAILED

ranks_e (len = 9):
[164, 639, 21, 10, 10, 17, 1, 1, 1]

Stegotext s:
Get sufficient roas tting time. The

Secret tokens: 9
Stego tokens : 9
len(ranks_e) : 9

Recovered e:
THE CURRENT SYSTEM HAS REPEATEDLY FAILED
Recovered == secret_text: True
------------------------------------------------------------
Secret text e:
The cats like to meow all the time. It is annoying.

ranks_e (len = 14):
[2, 16879, 67, 1, 81, 1, 25, 1, 1, 1, 3, 2, 6, 1]

Stegotext s:
Thepurpose…
ThereOnceWasASmallJungleInTheAmazonRiverBas

Secret tokens: 14
Stego tokens : 14
len(ranks_e) : 14

Recovered e:
The cats like to meow all the time. It is annoying.
Recovered == secret_text: True
------------------------------------------------------------


In [17]:
#BPE