In [1]:
import einops

from functools import partial
from itertools import product
import numpy as np
from pathlib import Path
from plotnine import (
    ggplot,
    geom_point, 
    geom_histogram, 
    geom_line,
    geom_ribbon,
    qplot, 
    coord_fixed, 
    aes, 
    facet_wrap, 
    labs,
    scale_x_log10,
    scale_y_log10
)
import polars as pl
import torch
from scipy.sparse import coo_array, csr_array
from scipy import linalg
import scipy.sparse.linalg as sparselinalg
from tokengrams import MemmapIndex, InMemoryIndex
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from ngram_markov.hooked_transformer import HookedTransformer
from transformer_lens import HookedTransformerConfig


In [2]:
torch.cuda.is_available()

True

In [4]:
from ngram_markov.model import GPT, GPTConfig
from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from torch.nn.functional import softmax, log_softmax
import einops
import torch
import plotly.express as px
from nnsight import LanguageModel


def load_nnsight_model(path):
    ckpt = torch.load(path, map_location='cpu')
    config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
    tl_weights = convert_nanogpt_weights(ckpt['model'], config)
    tl_model = HookedTransformer(config)
    tl_model.load_state_dict(tl_weights)
    return  LanguageModel(tl_model)


def load_tl_model(path):
    ckpt = torch.load(path, map_location='cpu')
    config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
    tl_weights = convert_nanogpt_weights(ckpt['model'], config)
    tl_model = HookedTransformer(config)
    tl_model.load_state_dict(tl_weights)
    return tl_model

model = load_tl_model(model_path / f'ckpt{epoch}.pt')
model.eval()

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [None]:
2 + 2

In [5]:
data_dir = Path('data/tinystories')

batch_size = 2048
device_type = 'cuda'
device = 'cuda'

def get_batch(block_size=512):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x = x.pin_memory().to(device, non_blocking=True)
    else:
        x  = x.to(device)
    return x


In [6]:
ngrams = get_batch()

In [7]:
tinystories = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')

In [None]:

def get_ngram_occurrences(index, raw_data, ngrams, min_occur, ctx_len):
    n = ngrams.shape[1]
    ngram_positions = [
        index.positions(query)[:min_occur] for query in ngrams.cpu().tolist()
    ]
    mask = torch.tensor([len(pos) >= min_occur for pos in ngram_positions])
    ngram_occurrences = []
    for positions in ngram_positions:
        if len(positions) < min_occur:
            continue
        occurrences = torch.stack([
            torch.from_numpy(raw_data[(i-ctx_len):(i + n)].astype(np.int64)) for i in positions
        ])
        ngram_occurrences.append(occurrences.unsqueeze(0))
    return torch.concat(ngram_occurrences), ngrams[mask]



In [None]:
ngram_occurrences, reduced_ngrams = get_ngram_occurrences(index, tinystories, ngrams, 32, 12)

In [None]:
with torch.no_grad():
    base_pred = tl_model(reduced_ngrams)[:, -1, :].squeeze().to('cpu')

In [None]:
model_preds = []
with torch.no_grad():
    for batch in ngram_occurrences.reshape(-1, 8).split(128):
        model_preds.append(tl_model(batch.to('cuda'))[:, -1, :].to('cpu'))

In [None]:
preds = torch.concat(model_preds)
preds.reshape(32, -1, 512).shape

In [None]:
preds.mean(dim=1)

In [8]:
data_dir = Path('data/tinystories')

batch_size = 16_384
device_type = 'cpu'
device = 'cpu'

def get_batch(block_size=5):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x = x.pin_memory().to(device, non_blocking=True)
    else:
        x  = x.to(device)
    return x


In [13]:
five_grams = [get_batch().tolist() for _ in range(10)]

In [14]:
from itertools import chain
i = 0
for fg in chain.from_iterable(five_grams):
    if i > 10:
        break
    print(fg)
    i += 1
    

[263, 78, 338, 403, 77]
[276, 72, 290, 68, 271]
[83, 415, 199, 199, 52]
[492, 79, 83, 264, 265]
[14, 313, 266, 324, 434]
[385, 378, 266, 359, 342]
[89, 392, 79, 282, 388]
[14, 221, 490, 440, 259]
[46, 319, 12, 275, 311]
[265, 298, 287, 65, 267]
[472, 300, 69, 413, 260]


In [15]:
def prefix_completion(starting_grams):
    # Set to store all unique n-grams (including prefixes)
    all_grams = set()

    # Add all prefixes of each 4-gram
    for gram in starting_grams:
        for i in range(1, len(gram) + 1):
            all_grams.add(tuple(gram[:i]))

    # Calculate the difference between all grams and original 4-grams
    completion = all_grams - set([tuple(gram) for gram in starting_grams])

    return sorted(list(completion))

# Example usage
result = prefix_completion(chain.from_iterable(five_grams))
print(f"Prefix completion: {len(result)}")

Prefix completion: 287452


In [19]:
from torch.nn.functional import softmax, log_softmax

#preds = tl_model(ngram_occurrences)[:, -1, :]
#preds -= preds.mean(dim=1, keepdims=True)

In [36]:
from ngram_markov.ngrams import kl_divergence


base_logits = log_softmax(base_pred, dim=-1)
mean_logits = log_softmax(preds.reshape(32, -1, 512).mean(dim=0), dim=-1)


kl_divergence(mean_logits, base_logits)

tensor([0.6264, 0.6268, 0.6962,  ..., 0.5244, 0.8024, 0.9430])

In [20]:
ngram_occurrences.shape

NameError: name 'ngram_occurrences' is not defined

In [1]:
512 ** 3


134217728

In [28]:
epoch = 53_000
model_path = Path('/media/External01/ngram-checkpoints/4layer_tinystories')

vocab_size = 512


model = load_tl_model(model_path / f'ckpt{epoch}.pt')
model.eval()

model_preds = []

pred_list = []

with torch.no_grad():
    unigram_preds = model(torch.arange(512).unsqueeze(1))[:, -1, :].cpu()

all_bigrams = torch.cartesian_prod(torch.arange(vocab_size), torch.arange(vocab_size))

with torch.no_grad():
    for batch in all_bigrams.split(2048):
        preds = model(batch.to('cuda'))[:, -1, :].cpu()
        pred_list.append(preds)
all_trigram_preds = torch.concat(pred_list)
all_preds = torch.concat([unigram_preds, all_trigram_preds], dim=0)
U, S, V = torch.svd(all_preds)

In [29]:
torch.linalg.matrix_rank(all_preds)

tensor(14)

In [37]:
class SimpleWFA:
    def __init__(self, U, S, V, all_preds, vocab_size, n_states):
        self.U = U[:, :n_states]
        self.S = S[:n_states]
        self.V = V[:, :n_states]
        self.vocab_size = vocab_size
        self.n_states = n_states

        # Initialize WFA parameters
        self.alpha = softmax(self.U[:vocab_size, :].sum(0), dim=0)  # Initial state distribution
        self.omega = softmax(self.V[:vocab_size, :].sum(0), dim=0)  # Final state distribution
        
        # Project the original predictions into the lower-dimensional space
        projected_preds = all_preds @ self.V
        
        # Separate unigram and bigram predictions
        unigram_preds = projected_preds[:vocab_size]
        bigram_preds = projected_preds[vocab_size:].reshape(vocab_size, vocab_size, n_states)

        # Construct transition matrices
        self.A = torch.zeros((vocab_size, n_states, n_states))
        for i in range(vocab_size):
            # bigram_preds[i] has shape (vocab_size, n_states)
            # We need to transform this into (n_states, n_states)
            state_transitions = self.V.T @ bigram_preds[i]  # Shape: (n_states, n_states)
            self.A[i] = softmax(state_transitions, dim=1)

    def evaluate(self, sequence):
        state = self.alpha  # Initial state distribution
        for symbol in sequence:
            state = state @ self.A[symbol]  # Update state distribution
        return state @ self.omega  # Final probability

    def sample(self, max_length=100):
        sequence = []
        state = self.alpha  # Initial state distribution

        for _ in range(max_length):
            # Compute symbol probabilities
            symbol_probs = state @ self.A.sum(2).T  # Sum over next states
            symbol = torch.multinomial(symbol_probs, 1).item()
            sequence.append(symbol)
            if symbol == 0:  # Assuming 0 is the end-of-sequence token
                break
            state = state @ self.A[symbol]  # Update state distribution

        return sequence

# all_preds is a Hankel matrix for a function over sequences
#U, S, V = torch.svd(all_preds)
vocab_size = 512
n_states = torch.linalg.matrix_rank(all_preds).item()
wfa = SimpleWFA(U, S, V, all_preds, vocab_size, n_states)

In [63]:
model.set_tokenizer(tokenizer)

In [61]:
bos = torch.tensor([0], device='cuda')
bos.unsqueeze(0)

tensor([[0]], device='cuda:0')

In [69]:
tokens = model.generate(bos.unsqueeze(0), max_new_tokens=400)

  0%|          | 0/400 [00:00<?, ?it/s]

torch.Size([1, 279])

In [71]:
tokenizer.decode(tokens[0].cpu().tolist())

'<|endoftext|>Once upon a time, there was a little blue potato. She had a shiny green dress, and she she wanted to show it to all her friends.\n\nSo, the potato put the potato on the floor and told everyone about her hogse. Everyone was so excited to see the potato on the freezer and wait, but the little girl kept taking it out.\n\nFinally, the freezer was cold and sparkly. All the sparkles were turning the potato out of the fridge. Everyone was so happy and they kept taking more and more.\n\nThe little blue potato was no longer short or tired. From then on, she was even more curious about a special fat potato. The potato was the happiest she had ever been.<|endoftext|>'

In [41]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer/tinystories512")


In [55]:
tokenizer.decode(wfa.sample(max_length=20))

'veray isg!3chF\x12o� J� B rǭe shel'

torch.Size([60, 512])

In [55]:
(wfa.A.transpose(1, 2) @ wfa.omega).shape

torch.Size([512, 512])

In [None]:
import os
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from transformer_lens import HookedTransformer

def process_checkpoints(model_path, vocab_size, batch_size=2048):
    """
    Process all checkpoints in the given path, calculate predictions, perform SVD,
    and save results as numpy files.

    Args:
    model_path (str): Path to the directory containing model checkpoints
    config: Model configuration
    vocab_size (int): Size of the vocabulary
    batch_size (int): Batch size for processing
    n_components (int): Number of components to keep in SVD

    Returns:
    None
    """
    model_path = Path(model_path)
    
    # Get all checkpoint files
    checkpoint_files = sorted([f for f in os.listdir(model_path) if f.startswith('ckpt') and f.endswith('.pt')])

    for ckpt_file in tqdm(checkpoint_files, desc="Processing checkpoints"):
        epoch = int(ckpt_file.split('ckpt')[1].split('.pt')[0])
        
        # Load the model
        ckpt = torch.load(model_path / ckpt_file)
        config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
        tl_weights = convert_nanogpt_weights(ckpt['model'], config)
        model = HookedTransformer(config)
        model.load_state_dict(tl_weights)
        model.eval()
        model.to('cuda')

        # Generate predictions
        pred_list = []
        unigram_inputs = torch.arange(vocab_size).unsqueeze(1).to('cuda')
        with torch.no_grad():
            unigram_preds = model(unigram_inputs)[:, -1, :].cpu()
            
            all_bigrams = torch.cartesian_prod(torch.arange(vocab_size), torch.arange(vocab_size))
            for batch in all_bigrams.split(batch_size):
                preds = model(batch.to('cuda'))[:, -1, :].cpu()
                pred_list.append(preds)

        all_trigram_preds = torch.cat(pred_list)
        all_preds = torch.cat([unigram_preds, all_trigram_preds], dim=0)

        # Perform SVD
        U, S, V = torch.svd(all_preds)
    
        # Save results
        save_dir = model_path / f'processed_ckpt{epoch}'
        save_dir.mkdir(exist_ok=True)

        np.save(save_dir / 'all_preds.npy', all_preds.numpy())
        np.save(save_dir / 'U.npy', U.numpy())
        np.save(save_dir / 'S.npy', S.numpy())
        np.save(save_dir / 'V.npy', V.numpy())

        print(f"Processed and saved results for checkpoint {epoch}")

# Usage
model_path = '/media/External01/ngram-checkpoints/4layer_tinystories'
vocab_size = 512

process_checkpoints(model_path, vocab_size)

Processing checkpoints:   0%|                                                                       | 0/77 [00:00<?, ?it/s]

Moving model to device:  cuda


Processing checkpoints:   1%|▊                                                              | 1/77 [00:14<17:52, 14.11s/it]

Processed and saved results for checkpoint 100
Moving model to device:  cuda


In [75]:
U_wfa = U[:, :n_states]
S_wfa = S[:n_states]
V_wfa = V[:, :n_states]
       


    # Initialize WFA parameters
alpha = U[0, :]  # Initial state
omega = V_wfa[0, :] * S_wfa  # Final state
        
        # Transition matrices (one per symbol)


In [82]:
V.shape

torch.Size([512, 512])

In [65]:
A = torch.stack([
    U[i*vocab_size:(i+1)*vocab_size, :] @ # 512 x 10
    torch.diag(S_wfa) @ # 10 x 10
    V_wfa[i*vocab_size:(i+1)*vocab_size, :].T
            for i in range(vocab_size)
])

In [66]:
S

tensor([1.1200e+04, 3.4420e+03, 1.9294e+03, 1.3880e+03, 1.2805e+03, 1.2543e+03,
        1.1475e+03, 1.0503e+03, 1.0072e+03, 9.9615e+02, 9.0286e+02, 8.8500e+02,
        8.6497e+02, 8.3656e+02, 8.0855e+02, 7.8499e+02, 7.7931e+02, 7.2183e+02,
        7.1082e+02, 7.0032e+02, 6.8252e+02, 6.5800e+02, 6.4312e+02, 6.3882e+02,
        6.1603e+02, 6.1202e+02, 6.0251e+02, 5.9163e+02, 5.8398e+02, 5.7652e+02,
        5.6503e+02, 5.5836e+02, 5.3562e+02, 5.3421e+02, 5.2170e+02, 5.0934e+02,
        5.0336e+02, 4.9327e+02, 4.8829e+02, 4.7580e+02, 4.6953e+02, 4.5649e+02,
        4.5277e+02, 4.4559e+02, 4.4255e+02, 4.3890e+02, 4.3068e+02, 4.2564e+02,
        4.1645e+02, 4.0704e+02, 4.0103e+02, 3.9427e+02, 3.9204e+02, 3.8407e+02,
        3.8080e+02, 3.7383e+02, 3.7059e+02, 3.6594e+02, 3.6220e+02, 3.5108e+02,
        3.4595e+02, 3.3800e+02, 3.3532e+02, 3.2680e+02, 3.2173e+02, 3.1639e+02,
        3.1121e+02, 3.0854e+02, 3.0516e+02, 3.0272e+02, 2.9451e+02, 2.9147e+02,
        2.8783e+02, 2.8378e+02, 2.8226e+

In [109]:
(base_pred - preds[6]).pow(2).sum()

tensor(210.9469, device='cuda:0', grad_fn=<SumBackward0>)

In [71]:
torch.allclose(softmax(preds[0] - preds[0].mean(), dim=0), softmax(preds[0], dim=0), atol=1.e-4)

True

In [75]:
(preds - preds.mean(dim=0, keepdims=True)).mean(dim=0).min()

tensor(-1.3447e-06, device='cuda:0', grad_fn=<MinBackward1>)