# Language as a HMM

We try to learn an HMM description of English. 

In [1]:
from datasets import load_dataset

ds = load_dataset("roneneldan/TinyStories")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [None]:
test_text = ds["train"][0]['text']


One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


## Training tokenizers

In [13]:
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast
import os

ds = load_dataset("roneneldan/TinyStories")

vocab_size = 4096
tok_name = 'custom_tokenizer_{}'.format(vocab_size)

# Create directory if it doesn't exist
os.makedirs(os.path.join('tokenizers', tok_name), exist_ok=True)

def text_iterator():
    for split in ds.values():
        for sample_text in split["text"]:
            yield sample_text

tok = ByteLevelBPETokenizer()
tok.train_from_iterator(
    text_iterator(),
    vocab_size=vocab_size,
    min_frequency=2,
    special_tokens=["<s>", "</s>", "<pad>", "<unk>"]
)
# Save the full tokenizer (creates tokenizer.json)
tok.save(os.path.join('tokenizers', tok_name, "tokenizer.json"))

# Create directory for HuggingFace tokenizer
hf_tok_name = "HF_custom_tokenizer_{}".format(vocab_size)
os.makedirs(os.path.join('tokenizers', hf_tok_name), exist_ok=True)

# Load from the saved tokenizer.json
hf_tok = PreTrainedTokenizerFast(
    tokenizer_file=os.path.join('tokenizers', tok_name, "tokenizer.json"),
    unk_token="<unk>",
    pad_token="<pad>",
    bos_token="<s>",
    eos_token="</s>",
)
hf_tok.save_pretrained(os.path.join('tokenizers', hf_tok_name))






('tokenizers/HF_custom_tokenizer_4096/tokenizer_config.json',
 'tokenizers/HF_custom_tokenizer_4096/special_tokens_map.json',
 'tokenizers/HF_custom_tokenizer_4096/tokenizer.json')

## Initial experiments

In [44]:
print([hf_tok.decode([hf_tok.encode("Hello, world!")[i]]) for i in range(len(hf_tok.encode("Hello, world!")))])

['Hello', ',', ' world', '!']


In [40]:
# word frequencies

from collections import Counter

sample_text = " ".join(ds["train"][:10000]['text'])
tokens = hf_tok.encode(sample_text)
token_counts = Counter(tokens)

import plotly.graph_objects as go

# Prepare data for plotting (excluding special tokens)
special_token_ids = [
    hf_tok.bos_token_id,
    hf_tok.eos_token_id,
    hf_tok.pad_token_id,
    hf_tok.unk_token_id,
]
token_ids = [token_id for token_id in token_counts.keys() if token_id not in special_token_ids]
frequencies = [token_counts[token_id] for token_id in token_ids]

frequencies = sorted(token_counts.values(), reverse=True)



In [39]:
# Compare token frequencies across three tokenizers with rescaling
import plotly.graph_objects as go
from transformers import PreTrainedTokenizerFast
from collections import Counter

# Load the three tokenizers
vocab_sizes = [1024, 2048, 4096]
scale_factors = {1024: 4, 2048: 2, 4096: 1}  # Rescaling factors
tokenizers = {}
for vocab_size in vocab_sizes:
    hf_tok_path = os.path.join('tokenizers', f'HF_custom_tokenizer_{vocab_size}')
    tokenizers[vocab_size] = PreTrainedTokenizerFast.from_pretrained(hf_tok_path)

# Sample text (same as before)
sample_text = " ".join(ds["train"][:10000]['text'])

# Create plotly figure
fig = go.Figure()

# Process each tokenizer
for vocab_size in vocab_sizes:
    tok = tokenizers[vocab_size]
    scale_factor = scale_factors[vocab_size]
    
    # Encode text and count tokens
    tokens = tok.encode(sample_text)
    token_counts = Counter(tokens)
    
    # Get special token IDs to exclude
    special_token_ids = {
        tok.bos_token_id,
        tok.eos_token_id,
        tok.pad_token_id,
        tok.unk_token_id,
    }
    
    # Filter out special tokens and sort by frequency
    token_freq_pairs = [
        (token_id, count) 
        for token_id, count in token_counts.items() 
        if token_id not in special_token_ids
    ]
    token_freq_pairs.sort(key=lambda x: x[1], reverse=True)
    
    # Prepare data for plotting
    token_ids = [pair[0] for pair in token_freq_pairs]
    frequencies = [pair[1] for pair in token_freq_pairs]
    
    
    # Decode tokens for hover text
    token_strings = [tok.decode([token_id]) for token_id in token_ids]
    
    # Scale x-coordinates
    x_coords = [i * scale_factor for i in range(len(frequencies))]
    
    # Create custom hover data showing token ID and decoded string
    customdata = list(zip(token_ids, token_strings))
    
    # Add trace
    fig.add_trace(go.Scatter(
        x=x_coords,
        y=frequencies,
        mode='lines',
        name=f'Vocab {vocab_size}',
        customdata=customdata,
        hovertemplate='<b>Vocab %{fullData.name}</b><br>' +
                      'Rank: %{x}<br>' +
                      'Token ID: %{customdata[0]}<br>' +
                      'Token: %{customdata[1]}<br>' +
                      'Frequency: %{y}<br>' +
                      '<extra></extra>',
    ))

# Update layout
fig.update_layout(
    title='Token Frequency Distribution Across Tokenizers (Rescaled)',
    xaxis_title='Token Rank (scaled by vocab size ratio)',
    yaxis_title='Frequency',
    yaxis_type='log',
    width=1000,
    height=600,
    hovermode='x unified',  # Show all traces at the same x coordinate
)

fig.show()

## A visualization of the model

In [None]:
from model import HookedTransformerModel
import torch
import plotly.graph_objects as go
import plotly.express as px
import os

In [65]:
folder_timestamp = "20251124_125523"

if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
else:
    device = torch.device('cpu')
    print("Using CPU")

# Getting through the directory for the final model
checkpoint_model_paths = []
for root, dirs, files in os.walk(os.path.join('records', folder_timestamp)):
    for file in files:
        if file.endswith('.pt') or file.endswith('.pth'):
            if 'final_model' in file:
                best_model_path = os.path.join(root, file)
            else:
                checkpoint_model_paths.append(os.path.join(root, file))

config_path = os.path.join('records', folder_timestamp, 'config.yaml')
model = HookedTransformerModel(config_path)
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

Using MPS


HookedTransformerModel(
  (model): HookedTransformer(
    (embed): Embed()
    (hook_embed): HookPoint()
    (pos_embed): PosEmbed()
    (hook_pos_embed): HookPoint()
    (blocks): ModuleList(
      (0): TransformerBlock(
        (ln1): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (ln2): LayerNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (attn): Attention(
          (hook_k): HookPoint()
          (hook_q): HookPoint()
          (hook_v): HookPoint()
          (hook_z): HookPoint()
          (hook_attn_scores): HookPoint()
          (hook_pattern): HookPoint()
          (hook_result): HookPoint()
        )
        (mlp): MLP(
          (hook_pre): HookPoint()
          (hook_post): HookPoint()
        )
        (hook_attn_in): HookPoint()
        (hook_q_input): HookPoint()
        (hook_k_input): HookPoint()
        (hook_v_input): HookPoint()
        (hook_mlp_in): Hook

In [90]:
import numpy as np
W_E = model.model.embed.W_E.detach().cpu().numpy()
W_E_normalized = W_E / np.linalg.norm(W_E, axis=1, keepdims=True)
W_U = model.model.W_U.detach().cpu().numpy()
px.imshow(W_E_normalized @ W_E_normalized.T, color_continuous_scale='balance', zmin=-1, zmax=1)
# px.imshow(W_U.T @ W_U)

In [None]:
W_K = model.model.blocks[0].attn.W_K.detach().cpu().numpy()[0]
W_Q = model.model.blocks[0].attn.W_Q.detach().cpu().numpy()[0]
W_O = model.model.blocks[0].attn.W_O.detach().cpu().numpy()[0]
W_V = model.model.blocks[0].attn.W_V.detach().cpu().numpy()[0]

px.imshow(W_O @ W_V.T)


In [98]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

W_E = model.model.embed.W_E.detach().cpu().numpy()
W_E_normalized = W_E / np.linalg.norm(W_E, axis=1, keepdims=True)

# Compute the full matrix
full_matrix = W_E_normalized @ W_E_normalized.T

# Coarse-grain to 8x8 by taking block averages (excluding diagonal)
n = full_matrix.shape[0]
block_size = n // 8
coarse_matrix = np.zeros((8, 8))

for i in range(8):
    for j in range(8):
        i_start, i_end = i * block_size, (i + 1) * block_size
        j_start, j_end = j * block_size, (j + 1) * block_size
        block = full_matrix[i_start:i_end, j_start:j_end]
        
        # Create mask to exclude diagonal elements
        if i == j:
            # For diagonal blocks, exclude the diagonal
            mask = ~np.eye(block.shape[0], dtype=bool)
            coarse_matrix[i, j] = block[mask].mean()
        else:
            # For off-diagonal blocks, use all elements
            coarse_matrix[i, j] = block.mean()

# Create side-by-side subplots
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=('Full Resolution', '8x8 Coarse-Grained (excl. diagonal)'),
    horizontal_spacing=0.15
)

# Add full resolution heatmap
fig.add_trace(
    go.Heatmap(
        z=full_matrix,
        colorscale='balance',
        zmin=-1,
        zmax=1,
        showscale=True,
        colorbar=dict(x=0.45, len=0.9)
    ),
    row=1, col=1
)

# Add coarse-grained heatmap
fig.add_trace(
    go.Heatmap(
        z=coarse_matrix,
        colorscale='balance',
        zmin=-1,
        zmax=1,
        showscale=True,
        colorbar=dict(x=1.02, len=0.9)
    ),
    row=1, col=2
)

# Update layout
fig.update_layout(
    width=1200,
    height=600,
    title_text="Embedding Similarity Matrix"
)

# Update axes
fig.update_xaxes(title_text="Token", row=1, col=1)
fig.update_yaxes(title_text="Token", row=1, col=1)
fig.update_xaxes(title_text="Block", row=1, col=2)
fig.update_yaxes(title_text="Block", row=1, col=2)

fig.show()

## PCA for residual stream