In [1]:
import json
import os
from pathlib import Path
import math

import altair as alt
import pandas as pd
from omegaconf import OmegaConf
from pydantic import BaseModel

import torch
import torch.nn.functional as F


In [2]:
class PlotEntropiesConfig(BaseModel):
    data_path: str | None
    chart_path: str
    score_override_path: str | None = None
    threshold_override: float | None = None

    class Config:
        extra = "forbid"


class PlotEntropiesData(BaseModel):
    text: str
    threshold: float = 1.335442066192627
    dataframe_json: str | None

    class Config:
        extra = "forbid"

In [3]:
def text_to_tokens(text, bos_id=257, eos_id=258):
    """
    Tokenizes the input text into UTF-8 bytes and adds a BOS token (257)
    at the beginning and an EOS token (258) at the end.
    """
    token_list = list(text.encode("utf-8"))
    return [bos_id] + token_list + [eos_id]


sample_text = "Hello, world!"
print("Tokens for 'Hello, world!':", text_to_tokens(sample_text))

Tokens for 'Hello, world!': [257, 72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33, 258]


In [4]:
def entropy(scores):
    """
    Computes per-token entropy (using natural log) from logits.
    Input:
      scores: Tensor of shape [bs, seq_len, vocab]
    Returns:
      Tensor of shape [bs, seq_len] with entropy values.
    """
    log_probs = F.log_softmax(scores, dim=-1)
    probs = torch.exp(log_probs)
    p_log_p = log_probs * probs
    return -p_log_p.sum(dim=-1)

In [5]:
def meta_style_patch_start_ids(entropies: torch.Tensor, threshold: float, include_next_token: bool = False) -> torch.Tensor:
    """
    Given a tensor of entropies (for tokens[1:] of the input), compute patch start indices as per Meta's logic:
      - Force patch starts at token indices 0 and 1.
      - For subsequent tokens, flag a patch start if the entropy > threshold.
      
    Parameters:
      entropies: Tensor of shape [seq_len-1] (corresponding to tokens[1:]).
      threshold: Threshold for flagging a patch start.
      include_next_token: Whether to include the final token.
      
    Returns:
      A tensor of patch start indices.
    """
    # Force patch starts at positions 0 and 1.
    first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
    # entropies corresponds to tokens[1:].
    patch_start_mask = (entropies > threshold)
    if not include_next_token and patch_start_mask.numel() > 1:
        patch_start_mask = patch_start_mask[:-1]  # Optionally drop the final token.
    patch_indices = torch.nonzero(patch_start_mask, as_tuple=False).squeeze(1)
    patch_indices = patch_indices + 1  # Shift indices to align with the original tokens.
    patch_start_ids = torch.cat((first_ids, patch_indices))
    patch_start_ids = torch.unique_consecutive(patch_start_ids)
    return patch_start_ids

def patch_lengths_from_start_ids(patch_start_ids: torch.Tensor, seq_len: int) -> list:
    """
    Given patch start indices and the total sequence length, compute patch lengths.
    
    For example, if patch_start_ids = [0, 5, 10] and seq_len = 15,
    then patch lengths are [5, 5, 5].
    """
    patch_ids = patch_start_ids.tolist()
    patch_ids.append(seq_len)
    lengths = []
    for i in range(len(patch_ids)-1):
        lengths.append(patch_ids[i+1] - patch_ids[i])
    return lengths

In [6]:
def compute_entropy_and_patches(text: str, threshold: float) -> pd.DataFrame:
    """
    Loads the entropy model, tokenizes the input text (adding BOS/EOS tokens),
    performs a forward pass to compute next-token logits, computes per-token
    entropy (in nats), and determines patch-start flags using Meta's patching logic.
    
    Alignment:
      - The point at row i in the DataFrame represents the entropy used to predict token i+1.
      - For i in [0, ..., N-2]: entropy is computed; for the final token (EOS, row N-1) we predict 0.
      - patch_start[0] = 1 (BOS always starts a patch).
      - If entropy[i] > threshold, then patch_start[i+1] = 1.
    
    Also, when converting tokens to strings:
      - BOS token (257) is shown as "<"
      - EOS token (258) is shown as ">"
    
    Returns a DataFrame with columns:
      'position'  : token position (0..N-1, including BOS/EOS),
      'tokens'    : token (converted to a character, with BOS/EOS replaced),
      'entropies' : the entropy for predicting the next token (0 for EOS),
      'start'     : binary flag (1 indicates a patch start).
    """
    # --- Model loading ---
    model_config = OmegaConf.load("blt/configs/entropy.yaml")
    from blt.model.entropy import EntropyModel
    model = EntropyModel(model_config)
    model.eval()
    model.to("cuda")
    
    checkpoint_path = "checkpoints/blt-entropy-pile/checkpoint_latest.pt"
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location="cuda")
    model.load_state_dict(ckpt["model_state_dict"], strict=True)
    
    # --- Tokenization ---
    tokens = text_to_tokens(text, bos_id=257, eos_id=258)
    N = len(tokens)  # total tokens (including BOS/EOS)
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to("cuda")  # Shape: [1, N]
    
    # --- Forward Pass ---
    with torch.no_grad():
        logits = model(input_ids)  # Shape: [1, N, vocab_size]
    
    # --- Next-token Predictions & Entropy ---
    # We skip the final token (no next-token prediction) so logits_next has shape [1, N-1, vocab_size].
    logits_next = logits[:, :-1, :]
    ent = entropy(logits_next).squeeze(0)  # Shape: [N-1]
    # ent[i] is the entropy used to predict token i+1.
    
    # --- Build DataFrame (N rows) ---
    token_strs = []
    for t in tokens:
        if t == 257:
            token_strs.append("<")
        elif t == 258:
            token_strs.append(">")
        else:
            try:
                token_strs.append(chr(t))
            except Exception:
                token_strs.append(str(t))
    
    entropies_list = []
    for i in range(N):
        if i < N - 1:
            e_val = float(ent[i])
        else:
            e_val = 0  # For the final token (EOS) predict 0.
        entropies_list.append(e_val)
    
    # --- Patch Start Determination ---
    # patch_start[0] = 1 (BOS always starts a patch)
    # If ent[i] > threshold, then patch_start[i+1] = 1.
    patch_start = [0] * N
    patch_start[0] = 1
    for i in range(N - 1):
        if ent[i] > threshold:
            patch_start[i + 1] = 1
    
    df = pd.DataFrame({
        "position": list(range(N)),
        "tokens": token_strs,
        "entropies": entropies_list,
        "start": patch_start,
    })
    
    patch_ids = meta_style_patch_start_ids(ent, threshold, include_next_token=False)
    patch_lengths = patch_lengths_from_start_ids(patch_ids, seq_len=N)
    print("Computed patch lengths:", patch_lengths)
    print("Patch start indices:", patch_ids.tolist())
    print("Mean patch length:", sum(patch_lengths) / len(patch_lengths))
    
    return df

In [7]:
sample_sentence = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
threshold = 2.5  # Meta's documented threshold
df = compute_entropy_and_patches(sample_sentence, threshold)
print("Computed DataFrame")

# Create x-axis tick labels combining position and token.
x_ticks = []
for row in df.itertuples():
    pos = row.position
    token = row.tokens
    x_ticks.append(f"{str(pos).zfill(3)}|{token}")
df["position_with_token"] = x_ticks

# Configure the Altair x-axis to split the label at "|" and show only the token.
x_axis = alt.Axis(
    labelExpr="split(datum.label, '|')[1]",
    grid=False,
    labelOverlap=False,
    labelAngle=0,
)

width = 1200
height = 150
base = alt.Chart(df).properties(width=width, height=height)
points = base.mark_line(point=True).encode(
    x=alt.X("position_with_token:O", title=None, axis=x_axis),
    y=alt.Y("entropies", title="Entropy of Next Token"),
)
# Draw a horizontal rule at the threshold.
rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
    y=alt.datum(threshold),
)
# Draw vertical dashed lines at patch-start positions.
patch_rules = (
    alt.Chart(df[df["start"] > 0])
    .mark_rule(color="#474747", strokeDash=[4, 2])
    .encode(x=alt.X("position_with_token:O", axis=x_axis))
)

chart = patch_rules + rule + points
chart = chart.configure_axis(labelFontSize=15, titleFontSize=15)

output_path = Path("our_chart.png")
output_path.parent.mkdir(exist_ok=True)
chart.save(str(output_path))
print(f"Chart saved to {output_path}")

print(f"threshold: {threshold}")
print(f"Mean entropy (our model): {df['entropies'].mean()}")
print(f"Total patches (our model): {sum(df['start'])}")
chart  # Display the chart inline.

Computed patch lengths: [1, 3, 1, 2, 3, 2, 1, 2, 1, 1, 3, 6, 2, 3, 3, 9, 2, 8, 5, 3, 7, 5, 8]
Patch start indices: [0, 1, 4, 5, 7, 10, 12, 13, 15, 16, 17, 20, 26, 28, 31, 34, 43, 45, 53, 58, 61, 68, 73]
Mean patch length: 3.5217391304347827
Computed DataFrame
Chart saved to our_chart.png
threshold: 2.5
Mean entropy (our model): 1.653603274697139
Total patches (our model): 23


## Expected Results
#### Meta's Documented Patch Example

In [8]:
import json
import pandas as pd
import altair as alt
from pathlib import Path

meta_json_str = r'''{"text":"Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.","threshold":1.335442066192627,"dataframe_json":"{\"position\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":6,\"7\":7,\"8\":8,\"9\":9,\"10\":10,\"11\":11,\"12\":12,\"13\":13,\"14\":14,\"15\":15,\"16\":16,\"17\":17,\"18\":18,\"19\":19,\"20\":20,\"21\":21,\"22\":22,\"23\":23,\"24\":24,\"25\":25,\"26\":26,\"27\":27,\"28\":28,\"29\":29,\"30\":30,\"31\":31,\"32\":32,\"33\":33,\"34\":34,\"35\":35,\"36\":36,\"37\":37,\"38\":38,\"39\":39,\"40\":40,\"41\":41,\"42\":42,\"43\":43,\"44\":44,\"45\":45,\"46\":46,\"47\":47,\"48\":48,\"49\":49,\"50\":50,\"51\":51,\"52\":52,\"53\":53,\"54\":54,\"55\":55,\"56\":56,\"57\":57,\"58\":58,\"59\":59,\"60\":60,\"61\":61,\"62\":62,\"63\":63,\"64\":64,\"65\":65,\"66\":66,\"67\":67,\"68\":68,\"69\":69,\"70\":70,\"71\":71,\"72\":72,\"73\":73,\"74\":74,\"75\":75,\"76\":76,\"77\":77,\"78\":78,\"79\":79,\"80\":80},\"tokens\":{\"0\":\"<\",\"1\":\"D\",\"2\":\"a\",\"3\":\"e\",\"4\":\"n\",\"5\":\"e\",\"6\":\"r\",\"7\":\"y\",\"8\":\"s\",\"9\":\"_\",\"10\":\"T\",\"11\":\"a\",\"12\":\"r\",\"13\":\"g\",\"14\":\"a\",\"15\":\"r\",\"16\":\"y\",\"17\":\"e\",\"18\":\"n\",\"19\":\"_\",\"20\":\"i\",\"21\":\"s\",\"22\":\"_\",\"23\":\"i\",\"24\":\"n\",\"25\":\"_\",\"26\":\"G\",\"27\":\"a\",\"28\":\"m\",\"29\":\"e\",\"30\":\"_\",\"31\":\"o\",\"32\":\"f\",\"33\":\"_\",\"34\":\"T\",\"35\":\"h\",\"36\":\"r\",\"37\":\"o\",\"38\":\"n\",\"39\":\"e\",\"40\":\"s\",\"41\":\",\",\"42\":\"_\",\"43\":\"a\",\"44\":\"_\",\"45\":\"f\",\"46\":\"a\",\"47\":\"n\",\"48\":\"t\",\"49\":\"a\",\"50\":\"s\",\"51\":\"y\",\"52\":\"_\",\"53\":\"e\",\"54\":\"p\",\"55\":\"i\",\"56\":\"c\",\"57\":\"_\",\"58\":\"b\",\"59\":\"y\",\"60\":\"_\",\"61\":\"G\",\"62\":\"e\",\"63\":\"o\",\"64\":\"r\",\"65\":\"g\",\"66\":\"e\",\"67\":\"_\",\"68\":\"R\",\"69\":\".\",\"70\":\"R\",\"71\":\".\",\"72\":\"_\",\"73\":\"M\",\"74\":\"a\",\"75\":\"r\",\"76\":\"t\",\"77\":\"i\",\"78\":\"n\",\"79\":\".\",\"80\":\">\"},\"token_ids\":{\"0\":1,\"1\":72,\"2\":101,\"3\":105,\"4\":114,\"5\":105,\"6\":118,\"7\":125,\"8\":119,\"9\":36,\"10\":88,\"11\":101,\"12\":118,\"13\":107,\"14\":101,\"15\":118,\"16\":125,\"17\":105,\"18\":114,\"19\":36,\"20\":109,\"21\":119,\"22\":36,\"23\":109,\"24\":114,\"25\":36,\"26\":75,\"27\":101,\"28\":113,\"29\":105,\"30\":36,\"31\":115,\"32\":106,\"33\":36,\"34\":88,\"35\":108,\"36\":118,\"37\":115,\"38\":114,\"39\":105,\"40\":119,\"41\":48,\"42\":36,\"43\":101,\"44\":36,\"45\":106,\"46\":101,\"47\":114,\"48\":120,\"49\":101,\"50\":119,\"51\":125,\"52\":36,\"53\":105,\"54\":116,\"55\":109,\"56\":103,\"57\":36,\"58\":102,\"59\":125,\"60\":36,\"61\":75,\"62\":105,\"63\":115,\"64\":118,\"65\":107,\"66\":105,\"67\":36,\"68\":86,\"69\":50,\"70\":86,\"71\":50,\"72\":36,\"73\":81,\"74\":101,\"75\":118,\"76\":120,\"77\":109,\"78\":114,\"79\":50,\"80\":2},\"entropies\":{\"0\":3.3949158192,\"1\":2.1656451225,\"2\":2.3216569424,\"3\":2.8214058876,\"4\":1.5249242783,\"5\":0.0401624143,\"6\":0.0981037766,\"7\":0.0544578359,\"8\":0.3430138826,\"9\":1.0546212196,\"10\":0.25252828,\"11\":0.1494535804,\"12\":0.0624754503,\"13\":0.001355894,\"14\":0.0050173439,\"15\":0.0052358187,\"16\":0.0011725067,\"17\":0.0010307421,\"18\":1.0241208076,\"19\":3.6867966652,\"20\":0.4502205253,\"21\":0.0484119244,\"22\":2.2572875023,\"23\":0.3789347112,\"24\":1.0042934418,\"25\":2.9090054035,\"26\":1.8933598995,\"27\":1.3859074116,\"28\":0.3827198744,\"29\":0.2646365762,\"30\":1.7742085457,\"31\":0.0136727821,\"32\":0.0053820172,\"33\":0.5485631227,\"34\":0.2064044327,\"35\":0.0049266233,\"36\":0.0005439016,\"37\":0.0007023578,\"38\":0.0004170335,\"39\":0.0054524317,\"40\":1.1938130856,\"41\":0.0238215197,\"42\":3.1279797554,\"43\":1.3883389235,\"44\":3.0503094196,\"45\":1.695879817,\"46\":1.8551058769,\"47\":1.4570231438,\"48\":0.0047810897,\"49\":0.026396824,\"50\":0.6633765101,\"51\":0.3141393065,\"52\":2.8411159515,\"53\":1.143143177,\"54\":0.0520330966,\"55\":0.3398066461,\"56\":0.4140175879,\"57\":2.5563707352,\"58\":1.3370712996,\"59\":0.0227173548,\"60\":3.4447185993,\"61\":1.8576486111,\"62\":0.8189754486,\"63\":0.6776530743,\"64\":0.0677763447,\"65\":0.212713033,\"66\":0.1003480032,\"67\":0.1746164262,\"68\":0.4123829603,\"69\":0.5507118702,\"70\":0.1047425047,\"71\":0.0194335245,\"72\":0.001482119,\"73\":0.0009310447,\"74\":0.0002176317,\"75\":0.0076908777,\"76\":0.0003866984,\"77\":0.0008008487,\"78\":1.2395234108,\"79\":0.4564163089,\"80\":0.0000461392},\"patch\":{\"0\":0,\"1\":1,\"2\":2,\"3\":3,\"4\":4,\"5\":5,\"6\":5,\"7\":5,\"8\":5,\"9\":5,\"10\":5,\"11\":5,\"12\":5,\"13\":5,\"14\":5,\"15\":5,\"16\":5,\"17\":5,\"18\":5,\"19\":5,\"20\":6,\"21\":6,\"22\":6,\"23\":7,\"24\":7,\"25\":7,\"26\":8,\"27\":9,\"28\":10,\"29\":10,\"30\":10,\"31\":11,\"32\":11,\"33\":11,\"34\":11,\"35\":11,\"36\":11,\"37\":11,\"38\":11,\"39\":11,\"40\":11,\"41\":11,\"42\":11,\"43\":12,\"44\":13,\"45\":14,\"46\":15,\"47\":16,\"48\":17,\"49\":17,\"50\":17,\"51\":17,\"52\":17,\"53\":18,\"54\":18,\"55\":18,\"56\":18,\"57\":18,\"58\":19,\"59\":20,\"60\":20,\"61\":21,\"62\":22,\"63\":22,\"64\":22,\"65\":22,\"66\":22,\"67\":22,\"68\":22,\"69\":22,\"70\":22,\"71\":22,\"72\":22,\"73\":22,\"74\":22,\"75\":22,\"76\":22,\"77\":22,\"78\":22,\"79\":22,\"80\":22},\"start\":{\"0\":1,\"1\":1,\"2\":1,\"3\":1,\"4\":1,\"5\":1,\"6\":0,\"7\":0,\"8\":0,\"9\":0,\"10\":0,\"11\":0,\"12\":0,\"13\":0,\"14\":0,\"15\":0,\"16\":0,\"17\":0,\"18\":0,\"19\":0,\"20\":1,\"21\":0,\"22\":0,\"23\":1,\"24\":0,\"25\":0,\"26\":1,\"27\":1,\"28\":1,\"29\":0,\"30\":0,\"31\":1,\"32\":0,\"33\":0,\"34\":0,\"35\":0,\"36\":0,\"37\":0,\"38\":0,\"39\":0,\"40\":0,\"41\":0,\"42\":0,\"43\":1,\"44\":1,\"45\":1,\"46\":1,\"47\":1,\"48\":1,\"49\":0,\"50\":0,\"51\":0,\"52\":0,\"53\":1,\"54\":0,\"55\":0,\"56\":0,\"57\":0,\"58\":1,\"59\":1,\"60\":0,\"61\":1,\"62\":1,\"63\":0,\"64\":0,\"65\":0,\"66\":0,\"67\":0,\"68\":0,\"69\":0,\"70\":0,\"71\":0,\"72\":0,\"73\":0,\"74\":0,\"75\":0,\"76\":0,\"77\":0,\"78\":0,\"79\":0,\"80\":0}}"}'''

# --- Parse the JSON ---
data = json.loads(meta_json_str)
df = pd.read_json(data["dataframe_json"])
print("Loaded DataFrame:")

# --- Set the threshold ---
threshold = data["threshold"]

# --- Create x-axis tick labels combining position and token ---
x_ticks = []
for row in df.itertuples():
    pos = row.position
    token = row.tokens
    x_ticks.append(f"{str(pos).zfill(3)}|{token}")
df["position_with_token"] = x_ticks

# --- Define the Altair x-axis ---
x_axis = alt.Axis(
    labelExpr="split(datum.label, '|')[1]",
    grid=False,
    labelOverlap=False,
    labelAngle=0,
)

# --- Create the chart ---
width = 1200
height = 150
base = alt.Chart(df).properties(width=width, height=height)

points = base.mark_line(point=True).encode(
    x=alt.X("position_with_token:O", title=None, axis=x_axis),
    y=alt.Y("entropies", title="Entropy of Next Byte"),
)

rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
    y=alt.datum(threshold),
)

patch_rules = (
    alt.Chart(df[df["start"] > 0])
    .mark_rule(color="#474747", strokeDash=[4, 2])
    .encode(x=alt.X("position_with_token:O", axis=x_axis))
)

chart = patch_rules + rule + points
chart = chart.configure_axis(labelFontSize=15, titleFontSize=15)

output_path = Path("meta_chart.png")
output_path.parent.mkdir(exist_ok=True)
chart.save(str(output_path))
print(f"Chart saved to {output_path}")

print(f"threshold: {threshold}")
print(f"Mean entropy (metas model): {df['entropies'].mean()}")
print(f"Total patches (metas model): {df['patch'].max()}")
chart

Loaded DataFrame:
Chart saved to meta_chart.png
threshold: 1.335442066192627
Mean entropy (metas model): 0.8172790294580247
Total patches (metas model): 22


  df = pd.read_json(data["dataframe_json"])
