In [1]:
import einops

from functools import partial
from itertools import product
import numpy as np
from pathlib import Path
from plotnine import (
    ggplot,
    geom_point, 
    geom_histogram, 
    geom_line,
    geom_ribbon,
    qplot, 
    coord_fixed, 
    aes, 
    facet_wrap, 
    labs,
    scale_x_log10,
    scale_y_log10
)
import polars as pl
import torch
from scipy.sparse import coo_array, csr_array
from scipy import linalg
import scipy.sparse.linalg as sparselinalg
from tokengrams import MemmapIndex, InMemoryIndex
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer, HookedTransformerConfig
import zstandard as zstd


from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from ngram_markov.model import GPT, GPTConfig
from torch.nn.functional import softmax


import collections
from collections import defaultdict
from itertools import islice
import numpy as np
import rustworkx as rx

from typing import List, Tuple


import cola



In [2]:
import torch

def noise_operator(tokens, unigram_probabilities, rho: float, n_gram: int, num_samples: int):
    batch_size, seq_len = tokens.shape
    device = tokens.device
    
    # Create a mask for the positions we'll potentially modify
    mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    mask[:, -n_gram:] = False
    
    # Expand tokens for num_samples
    noised_tokens = tokens.unsqueeze(1).expand(-1, num_samples, -1).clone()
    
    # Create Bernoulli mask
    bernoulli_mask = torch.bernoulli(torch.full((batch_size, num_samples, seq_len), rho, device=device))
    bernoulli_mask = bernoulli_mask * mask.unsqueeze(1)
    
    # Sample from unigram distribution
    sampled_tokens = torch.multinomial(
        unigram_probabilities,
        batch_size * num_samples * seq_len,
        replacement=True
    ).view(batch_size, num_samples, seq_len)
    
    # Apply noise
    noised_tokens = torch.where(bernoulli_mask.bool(), sampled_tokens, noised_tokens)
    
    return noised_tokens

In [3]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer/tinystories512")
ts_bin_path = 'data/tinystories/train.bin'
index_path = "data/tinystories/suffix_tree.idx"

index = MemmapIndex(ts_bin_path, index_path)

In [5]:
epoch = 20_000
model_path = Path('/workspace/checkpoints')

ckpt = torch.load(model_path / f'ckpt{epoch}.pt', map_location='cpu')

In [6]:
from ngram_markov.model import GPT, GPTConfig
from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from torch.nn.functional import softmax, log_softmax


config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])

tl_weights = convert_nanogpt_weights(ckpt['model'], config)
tl_model = HookedTransformer(config)
tl_model.load_state_dict(tl_weights)

#gpt_config = GPTConfig(**ckpt['model_args'])

#gpt_model = GPT(gpt_config)
#gpt_model.load_state_dict(ckpt['model'])
#gpt_model.to('cuda')

<All keys matched successfully>

In [7]:
from ngram_markov.ngrams import create_ngrams, calculate_ngram_kl_divergence


data_dir = Path('data/tinystories')

batch_size = 16
block_size = 1024
device_type = 'cpu'
device = 'cpu'

def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    else:
        data = np.memmap(data_dir / 'validation.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [10]:
queries = [[i] for i in range(512)]
bigram_counts = torch.tensor(index.batch_count_next(queries, 512), dtype=torch.float32)
unigrams = bigram_counts.sum(dim=1)
unigrams /= unigrams.sum()

In [11]:
batch, _ = get_batch('train')



noised_batch = noise_operator(batch, unigrams, 0.1, 4, 1024)

In [66]:
@torch.no_grad()
def run_minibatch(model, minibatch):
    with torch.amp.autocast('cuda'):
        logits = model(minibatch.to('cuda'))[:, -1:, :]
        current_log_probs = log_softmax(logits, dim=-1)
    return current_log_probs.to('cpu')
    

def noise_operator(tokens, unigram_probabilities, rho: float, n_gram: int, num_samples: int):
    batch_size, seq_len = tokens.shape
    device = tokens.device
    
    mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    mask[:, -n_gram:] = False
    
    noised_tokens = tokens.unsqueeze(1).expand(-1, num_samples, -1).clone()
    
    bernoulli_mask = torch.bernoulli(torch.full((batch_size, num_samples, seq_len), rho, device=device))
    bernoulli_mask = bernoulli_mask * mask.unsqueeze(1)
    
    sampled_tokens = torch.multinomial(
        unigram_probabilities,
        batch_size * num_samples * seq_len,
        replacement=True
    ).view(batch_size, num_samples, seq_len)
    
    noised_tokens = torch.where(bernoulli_mask.bool(), sampled_tokens, noised_tokens)
    # Count corrupted tokens per sequence
    #corrupted_count = (noised_tokens != tokens.unsqueeze(1)).sum(dim=-1)
    
    return noised_tokens


import torch

def noise_operator_metadata(tokens, unigram_probabilities, rho: float, n_gram: int, num_samples: int):
    batch_size, seq_len = tokens.shape
    device = tokens.device
    
    mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    mask[:, -n_gram:] = False
    
    noised_tokens = tokens.unsqueeze(1).expand(-1, num_samples, -1).clone()
    
    bernoulli_mask = torch.bernoulli(torch.full((batch_size, num_samples, seq_len), rho, device=device))
    bernoulli_mask = bernoulli_mask * mask.unsqueeze(1)
    
    sampled_tokens = torch.multinomial(
        unigram_probabilities,
        batch_size * num_samples * seq_len,
        replacement=True
    ).view(batch_size, num_samples, seq_len)
    
    noised_tokens = torch.where(bernoulli_mask.bool(), sampled_tokens, noised_tokens)
    
    # Identify corrupted tokens
    corrupted_mask = (noised_tokens != tokens.unsqueeze(1))
    
    # Count corrupted tokens per sequence
    corrupted_count = corrupted_mask.sum(dim=-1)
    
    # Track positions of corrupted tokens
    position_indices = torch.arange(seq_len, device=device).expand(batch_size, num_samples, -1)
    corrupted_positions = position_indices[corrupted_mask]
    
    # Calculate distance from end for each corrupted token
    distance_from_end = seq_len - corrupted_positions - 1
    
    return noised_tokens, corrupted_count, corrupted_positions, distance_from_end

# Usage example:
# noised_tokens, corrupted_count, corrupted_positions, distance_from_end = noise_operator(tokens, unigram_probabilities, rho, n_gram, num_samples)

    
def compare_noised_predictions_batched(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size=10):
    device = batch.device
    orig_batch_size, seq_len = batch.shape

    # Get original predictions and apply log_softmax
    with torch.no_grad():
        original_logits = log_softmax(model(batch)[:, -1:, :], dim=-1)

    # Apply noise operator
    noised_batches = noise_operator(
        batch,
        unigram_probabilities,
        rho,
        n_gram,
        num_samples
    )
    
    # Initialize tensor to store noised log_softmax probabilities
    noised_log_probs = torch.zeros((orig_batch_size, num_samples, original_logits.shape[-1]), device='cpu')

    # Process noised batches
    for i in range(0, orig_batch_size):
        end_i = min(i + 1, orig_batch_size)
        for j in range(0, num_samples, batch_size):
            end_j = min(j + batch_size, num_samples)
            
            current_batch = noised_batches[i:end_i, j:end_j].reshape(-1, seq_len)            
            current_log_probs = run_minibatch(model, current_batch)
            reshaped_log_probs = current_log_probs.view(end_i - i, end_j - j, -1)
            noised_log_probs[i:end_i, j:end_j] = reshaped_log_probs.to('cpu')

    return original_logits, noised_log_probs
    

def kl_divergence_log(p_log, q_log):
    return (p_log.exp() * (p_log - q_log)).sum(dim=-1)


def entropy(p_log):
    return -1. * (softmax(p_log, dim=-1) * p_log).sum(dim=-1)


def analyze_noise_effect(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size=10):
    original_log_probs, noised_log_probs = (
        compare_noised_predictions_batched(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size)
    )

    mean_noised_log_probs = log_softmax(noised_log_probs.mean(dim=1), dim=-1)

    # Move tensors to CPU for calculations
    original_log_probs = original_log_probs.cpu()
    noised_log_probs = noised_log_probs.cpu()

    # Calculate KL divergences in log space
    kl_noised_original = kl_divergence_log(noised_log_probs, original_log_probs)
    kl_mean_noised_original = kl_divergence_log(mean_noised_log_probs, original_log_probs.squeeze())

    # Print some basic statistics for reference

    avg_kl_noised_original = kl_noised_original.mean().item()
    max_kl_noised_original =  kl_noised_original.max().item()
    avg_kl_mean_noised_original = kl_mean_noised_original.mean().item()
    max_kl_mean_noised_original =  kl_mean_noised_original.max().item()
    

    print(f"Average KL(noised || original): {avg_kl_noised_original:.4f}")
    print(f"Max KL(noised || original): {max_kl_noised_original:.4f}")
    print(f"Average KL(mean noised || original): {avg_kl_mean_noised_original:.4f}")
    print(f"Max KL(mean noised || original): {max_kl_mean_noised_original:.4f}")

    return original_log_probs.squeeze(), mean_noised_log_probs, kl_noised_original #, corrupted_positions, dist_from_end

# Usage example
# model = your_transformer_model
# batch = your_input_batch
# unigram_probabilities = your_unigram_probabilities
# rho = 0.1
# n_gram = 2
# num_samples = 1000
# batch_size = 10  # Adjust this based on your GPU memory
# per_seq_std, euclidean_distance = analyze_noise_effect(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size)

In [67]:
batch, _ = get_batch('train')
log_probs, t_log_probs, kl_noised_original = analyze_noise_effect(tl_model, batch, unigrams, 0.1, 5, 256, 16)

Average KL(noised || original): 0.2717
Max KL(noised || original): 9.3362
Average KL(mean noised || original): 0.0358
Max KL(mean noised || original): 0.1576


In [68]:
entropy(t_log_probs)

tensor([5.3794e-04, 3.0769e-02, 2.2824e-02, 4.5360e-01, 1.0706e+00, 1.4880e+00,
        1.7449e+00, 4.5146e-03, 8.0832e-01, 3.0447e-01, 2.6097e-02, 2.4319e+00,
        2.5486e+00, 1.6996e+00, 1.3079e-04, 1.2619e+00])

In [69]:
entropy(log_probs.squeeze()) - entropy(t_log_probs)

tensor([ 1.1989e-04, -2.5578e-02, -1.9680e-02,  3.8841e-02, -1.2229e-01,
        -5.4508e-01, -3.0910e-01, -2.8948e-03, -2.7419e-01,  3.1388e-02,
         7.1690e-03, -7.3609e-01, -4.2333e-01, -6.1854e-01, -3.1693e-05,
        -4.4911e-01])

In [72]:
index.batch_count_next(batch[:, -5:].tolist(), 512)

[[0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  25597,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
 

In [59]:
log_probs, _, kl_noised_original = analyze_noise_effect(tl_model, batch, unigrams, 0.1, 5, 512, 16)

Average KL(noised || original): 0.2986
Max KL(noised || original): 13.0858
Average KL(mean noised || original): 0.0897
Max KL(mean noised || original): 0.4492


In [58]:
_, _, kl_noised_original = analyze_noise_effect(tl_model, batch, unigrams, 0.1, 5, 1024, 16)

Average KL(noised || original): 0.3129
Max KL(noised || original): 12.2490
Average KL(mean noised || original): 0.0929
Max KL(mean noised || original): 0.4068


In [60]:
entropy(log_probs)

tensor([[1.0375e+00],
        [7.5132e-04],
        [7.6108e-03],
        [1.4330e+00],
        [1.6799e-01],
        [1.4628e+00],
        [1.6195e-03],
        [5.4730e-01],
        [6.8034e-01],
        [6.7074e-02],
        [1.9394e-02],
        [1.2514e+00],
        [1.2036e+00],
        [1.5110e+00],
        [1.8040e+00],
        [1.9232e-01]])

In [None]:


def get_batch(batch_size=16, seq_len=1024):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122

    data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+seq_len]).astype(np.int64)) for i in ix])
    
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x = x.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
    return x


def noise_sensitivity_for_checkpoint(model, num_sequences, unigram_probabilities, rho, n_gram, num_samples, sample_batch_size):
    batch = get_batch('train')
    


def noise_sensitivity_over_time(epochs, num_sequences, unigram_probabilities, rho, n_gram, num_samples, sample_batch_size):
    model_path = Path('/media/External01/out')
    for step in epochs:
        ckpt = torch.load(model_path / f'ckpt{epoch}.pt', map_location='cpu')
        config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
        tl_weights = convert_nanogpt_weights(ckpt['model'], config)
        tl_model = HookedTransformer(config)
        tl_model.load_state_dict(tl_weights)