# Language as an HMM

We try to learn an HMM description of English. 

In [1]:
from datasets import load_dataset

# Load the TinyStories dataset by its Hugging Face Hub id. The result is indexed
# by split name below (e.g. ds["train"]); the progress output of this cell lists
# a train and a validation split.
ds = load_dataset("roneneldan/TinyStories")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [None]:
# Peek at the first training story to see what the raw text looks like.
# A bare assignment displays nothing, so print the story explicitly.
test_text = ds["train"][0]['text']
print(test_text)


One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


## Training tokenizers

In [13]:
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast
import os

ds = load_dataset("roneneldan/TinyStories")

# Vocabulary sizes to train. The comparison cell further down loads the 1024 and
# 2048 tokenizers from disk as well as the 4096 one, so all three are produced
# here; otherwise a fresh "Restart & Run All" fails on the missing folders.
# Keep 4096 last: the names left behind by the loop (`vocab_size`, `tok`,
# `hf_tok`) are the ones the following cells use.
TOKENIZER_VOCAB_SIZES = (1024, 2048, 4096)
SPECIAL_TOKENS = ["<s>", "</s>", "<pad>", "<unk>"]
TEXT_BATCH_SIZE = 1000


def text_iterator():
    """Yield every story text from every split of `ds`, one string at a time.

    Rows are read in slices of TEXT_BATCH_SIZE so the full text column is never
    requested from the dataset in a single call.
    """
    for split in ds.values():
        for start in range(0, len(split), TEXT_BATCH_SIZE):
            yield from split[start:start + TEXT_BATCH_SIZE]["text"]


def train_and_save_tokenizer(vocab_size):
    """Train a byte-level BPE tokenizer and save it in two formats.

    Args:
        vocab_size: Target vocabulary size passed to the BPE trainer.

    Returns:
        A tuple ``(tok, hf_tok, saved_files)``: the trained
        ``ByteLevelBPETokenizer``, its ``PreTrainedTokenizerFast`` wrapper, and
        the value returned by ``hf_tok.save_pretrained``.

    Side effects:
        Writes ``tokenizers/custom_tokenizer_<vocab_size>/tokenizer.json`` and
        the Hugging Face files under ``tokenizers/HF_custom_tokenizer_<vocab_size>``.
    """
    tok_dir = os.path.join('tokenizers', 'custom_tokenizer_{}'.format(vocab_size))
    hf_tok_dir = os.path.join('tokenizers', "HF_custom_tokenizer_{}".format(vocab_size))
    # Create both output directories if they don't exist
    os.makedirs(tok_dir, exist_ok=True)
    os.makedirs(hf_tok_dir, exist_ok=True)

    tok = ByteLevelBPETokenizer()
    tok.train_from_iterator(
        text_iterator(),
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=SPECIAL_TOKENS,
    )
    # Save the full tokenizer (creates tokenizer.json)
    tokenizer_file = os.path.join(tok_dir, "tokenizer.json")
    tok.save(tokenizer_file)

    # Re-load the saved tokenizer.json through the transformers wrapper and
    # register which tokens play the unk/pad/bos/eos roles.
    hf_tok = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<s>",
        eos_token="</s>",
    )
    saved_files = hf_tok.save_pretrained(hf_tok_dir)
    return tok, hf_tok, saved_files


for vocab_size in TOKENIZER_VOCAB_SIZES:
    tok, hf_tok, saved_files = train_and_save_tokenizer(vocab_size)

# Names describing the last (4096) tokenizer, kept for the cells below.
tok_name = 'custom_tokenizer_{}'.format(vocab_size)
hf_tok_name = "HF_custom_tokenizer_{}".format(vocab_size)

# Display the files written for the last tokenizer as the cell's output.
saved_files






('tokenizers/HF_custom_tokenizer_4096/tokenizer_config.json',
 'tokenizers/HF_custom_tokenizer_4096/special_tokens_map.json',
 'tokenizers/HF_custom_tokenizer_4096/tokenizer.json')

## Initial experiments

In [44]:
# Show how the tokenizer splits a sample string: encode it once, then decode
# every id on its own so each piece is visible separately.
sample_ids = hf_tok.encode("Hello, world!")
print([hf_tok.decode([token_id]) for token_id in sample_ids])

['Hello', ',', ' world', '!']


In [40]:
# word frequencies

from collections import Counter

# Tokenize the first 10,000 training stories (joined with spaces) and count
# how often each token id occurs.
sample_text = " ".join(ds["train"][:10000]['text'])
tokens = hf_tok.encode(sample_text)
token_counts = Counter(tokens)

import plotly.graph_objects as go

# Prepare data for plotting (excluding special tokens)
special_token_ids = [
    hf_tok.bos_token_id,
    hf_tok.eos_token_id,
    hf_tok.pad_token_id,
    hf_tok.unk_token_id,
]

# Rank the remaining token ids from most to least frequent. Sorting the ids
# (rather than re-sorting all raw counts) keeps the special-token filter in
# effect and keeps `token_ids[i]` aligned with `frequencies[i]`.
token_ids = sorted(
    (token_id for token_id in token_counts if token_id not in special_token_ids),
    key=token_counts.__getitem__,
    reverse=True,
)
frequencies = [token_counts[token_id] for token_id in token_ids]



In [39]:
# Compare token frequencies across three tokenizers with rescaling
import plotly.graph_objects as go
from transformers import PreTrainedTokenizerFast
from collections import Counter

# Read back one saved Hugging Face tokenizer per vocabulary size.
vocab_sizes = [1024, 2048, 4096]
# Horizontal stretch applied to each curve's rank axis (4096 / vocab_size).
scale_factors = dict(zip(vocab_sizes, (4, 2, 1)))
tokenizers = dict()
for vocab_size in vocab_sizes:
    hf_tok_path = os.path.join('tokenizers', 'HF_custom_tokenizer_{}'.format(vocab_size))
    tokenizers[vocab_size] = PreTrainedTokenizerFast.from_pretrained(hf_tok_path)

# Evaluation text: the first 10,000 training stories joined with single spaces.
sample_text = " ".join(ds["train"][:10000]['text'])

# Create plotly figure
fig = go.Figure()

# Add one frequency-vs-rank curve per tokenizer
for vocab_size in vocab_sizes:
    tok = tokenizers[vocab_size]
    scale_factor = scale_factors[vocab_size]

    # Encode text and count tokens
    tokens = tok.encode(sample_text)
    token_counts = Counter(tokens)

    # Special token IDs to exclude from the curve
    special_token_ids = {
        tok.bos_token_id,
        tok.eos_token_id,
        tok.pad_token_id,
        tok.unk_token_id,
    }

    # most_common() yields (token_id, count) pairs by descending count;
    # dropping the special tokens afterwards preserves that order.
    token_freq_pairs = [
        (token_id, count)
        for token_id, count in token_counts.most_common()
        if token_id not in special_token_ids
    ]

    # Split the pairs into parallel lists for plotting
    token_ids = [token_id for token_id, _count in token_freq_pairs]
    frequencies = [count for _token_id, count in token_freq_pairs]

    # Decode tokens for hover text
    token_strings = [tok.decode([token_id]) for token_id in token_ids]

    # Stretch the rank axis so curves of different vocab sizes overlay
    x_coords = [rank * scale_factor for rank in range(len(frequencies))]

    # Per-point hover payload: (token id, decoded string)
    customdata = list(zip(token_ids, token_strings))

    # Add trace. The trace name already reads "Vocab <size>", so the hover
    # header uses it as-is instead of prefixing a second "Vocab".
    fig.add_trace(go.Scatter(
        x=x_coords,
        y=frequencies,
        mode='lines',
        name=f'Vocab {vocab_size}',
        customdata=customdata,
        hovertemplate='<b>%{fullData.name}</b><br>' +
                      'Rank: %{x}<br>' +
                      'Token ID: %{customdata[0]}<br>' +
                      'Token: %{customdata[1]}<br>' +
                      'Frequency: %{y}<br>' +
                      '<extra></extra>',
    ))

# Figure-level settings: titles, log-scaled y-axis, canvas size and hover mode.
layout_options = dict(
    title='Token Frequency Distribution Across Tokenizers (Rescaled)',
    xaxis_title='Token Rank (scaled by vocab size ratio)',
    yaxis_title='Frequency',
    yaxis_type='log',
    width=1000,
    height=600,
    # One shared hover box listing every trace at the hovered x coordinate
    hovermode='x unified',
)
fig.update_layout(**layout_options)

fig.show()

## A visualization of the model

In [48]:
import os
import sys

# Make the project's `model` module importable. The checkout location defaults
# to the original machine-specific path but can be overridden by setting the
# HMM_PROJECT_ROOT environment variable, so the notebook runs on other machines.
project_root = os.environ.get(
    'HMM_PROJECT_ROOT',
    '/Users/linc/Documents/workspace/hidden-markov-chain',
)
# Re-running the cell must not stack duplicate entries on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from model import HookedTransformerModel
import torch
import circuitsvis as cv
from IPython.display import display

# Build the model from its config file.
config_path = 'config/psl7.yaml'
model = HookedTransformerModel(config_path)

# Look for saved weights in a few candidate directories; the first existing
# 'best_model.pt' wins and later directories are not checked.
import os
checkpoint_dirs = ['wandb/latest-run/files', 'checkpoints', '.']
candidate_paths = (
    os.path.join(directory, 'best_model.pt') for directory in checkpoint_dirs
)
found_checkpoint = next(
    (path for path in candidate_paths if os.path.exists(path)),
    None,
)

if found_checkpoint is None:
    print("No checkpoint found, using randomly initialized model")
else:
    print(f"Loading checkpoint from {found_checkpoint}")
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects - only load checkpoint files you trust.
    state_dict = torch.load(found_checkpoint, map_location='cpu')
    model.load_state_dict(state_dict)

model.eval()

# Summarize the architecture read from the wrapped model's config.
cfg = model.model.cfg
print(f"Model has {cfg.n_layers} layers, {cfg.n_heads} heads")
print(f"d_model: {cfg.d_model}, d_head: {cfg.d_head}")

# Visualize various weight matrices
layer_idx = 0  # index of the transformer block inspected by the sections below

# 1. Token Embedding weights
print("\n=== Token Embedding Weights ===")
W_E = model.model.embed.W_E.detach()  # [vocab_size, d_model]
print(f"Shape: {W_E.shape}")

# Colour up to the first 20 vocabulary entries by the mean of their embedding
# vector; the labels are placeholder names, not decoded tokens.
embedding_rows = W_E[:20]
placeholder_labels = ["tok_" + str(i) for i in range(embedding_rows.shape[0])]
mean_per_token = embedding_rows.mean(dim=1).cpu().numpy()
display(cv.tokens.colored_tokens(tokens=placeholder_labels, values=mean_per_token))

# 2. QK and OV Circuits for each head
print(f"\n=== Layer {layer_idx} Attention Circuits ===")
W_Q = model.model.blocks[layer_idx].attn.W_Q  # [n_heads, d_model, d_head]
W_K = model.model.blocks[layer_idx].attn.W_K
W_V = model.model.blocks[layer_idx].attn.W_V
W_O = model.model.blocks[layer_idx].attn.W_O  # [n_heads, d_head, d_model]


def show_circuit_matrix(matrix):
    """Render one square [d_model, d_model] circuit matrix with circuitsvis.

    ``attention_patterns`` takes an array shaped
    ``[num_heads, dest_tokens, src_tokens]``, so exactly one leading "head"
    axis is added (no batch axis). Rows and columns are labelled ``d0, d1, ...``
    after the residual-stream dimension they index.

    Args:
        matrix: 2-D torch tensor to display.
    """
    display(cv.attention.attention_patterns(
        attention=matrix.detach().cpu().numpy()[None, :, :],
        tokens=[f"d{i}" for i in range(matrix.shape[0])],
    ))


# Compute QK circuit (W_Q @ W_K.T) for each head
for head in range(model.model.cfg.n_heads):
    W_QK = W_Q[head] @ W_K[head].T  # [d_model, d_model]
    print(f"\nHead {head} - QK Circuit shape: {W_QK.shape}")
    show_circuit_matrix(W_QK)

# 3. OV circuits
print(f"\n=== OV Circuits (Layer {layer_idx}) ===")
for head in range(model.model.cfg.n_heads):
    W_OV = W_V[head] @ W_O[head]  # [d_model, d_model]
    print(f"\nHead {head} - OV Circuit shape: {W_OV.shape}")
    show_circuit_matrix(W_OV)

# 4. MLP weights
print(f"\n=== Layer {layer_idx} MLP Weights ===")
W_in = model.model.blocks[layer_idx].mlp.W_in.detach()  # [d_model, d_mlp]
W_out = model.model.blocks[layer_idx].mlp.W_out.detach()  # [d_mlp, d_model]
print(f"W_in shape: {W_in.shape}, W_out shape: {W_out.shape}")

# Show first few neurons (at most 32), one row per neuron.
# attention_patterns takes [num_heads, dest_tokens, src_tokens], so a single
# leading "head" axis is added (no batch axis).
# NOTE(review): the displayed slice is [neurons, d_model] and is generally not
# square, while `tokens` only labels the d_model axis - confirm the widget
# renders this as intended.
first_neurons = W_in[:, :min(32, W_in.shape[1])].T
display(cv.attention.attention_patterns(
    attention=first_neurons.cpu().numpy()[None, :, :],
    tokens=[f"d{i}" for i in range(W_in.shape[0])],
))


No checkpoint found, using randomly initialized model
Model has 1 layers, 1 heads
d_model: 4, d_head: 4

=== Token Embedding Weights ===
Shape: torch.Size([64, 4])



=== Layer 0 Attention Circuits ===

Head 0 - QK Circuit shape: torch.Size([4, 4])



=== OV Circuits (Layer 0) ===

Head 0 - OV Circuit shape: torch.Size([4, 4])



=== Layer 0 MLP Weights ===
W_in shape: torch.Size([4, 16]), W_out shape: torch.Size([16, 4])
