In [2]:
import einops

from functools import partial
from itertools import product
import numpy as np
from pathlib import Path
from plotnine import (
    ggplot,
    geom_point, 
    geom_histogram, 
    geom_line,
    geom_ribbon,
    qplot, 
    coord_fixed, 
    aes, 
    facet_wrap, 
    labs,
    scale_x_log10,
    scale_y_log10
)
import polars as pl
import torch

from tokengrams import MemmapIndex, InMemoryIndex
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from transformer_lens import HookedTransformerConfig
import zstandard as zstd


from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from ngram_markov.model import GPT, GPTConfig
from ngram_markov.hooked_transformer import HookedTransformer
from torch.nn.functional import softmax


import collections
from collections import defaultdict
from itertools import islice
import numpy as np
import rustworkx as rx

from typing import List, Tuple


In [3]:
import torch

def noise_operator(tokens, unigram_probabilities, rho: float, n_gram: int, num_samples: int):
    batch_size, seq_len = tokens.shape
    device = tokens.device
    
    # Create a mask for the positions we'll potentially modify
    mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    mask[:, -n_gram:] = False
    
    # Expand tokens for num_samples
    noised_tokens = tokens.unsqueeze(1).expand(-1, num_samples, -1).clone()
    
    # Create Bernoulli mask
    bernoulli_mask = torch.bernoulli(torch.full((batch_size, num_samples, seq_len), rho, device=device))
    bernoulli_mask = bernoulli_mask * mask.unsqueeze(1)
    
    # Sample from unigram distribution
    sampled_tokens = torch.multinomial(
        unigram_probabilities,
        batch_size * num_samples * seq_len,
        replacement=True
    ).view(batch_size, num_samples, seq_len)
    
    # Apply noise
    noised_tokens = torch.where(bernoulli_mask.bool(), sampled_tokens, noised_tokens)
    
    return noised_tokens

In [4]:
epoch = 53_000
model_path = Path('/media/External01/ngram-checkpoints/4layer_tinystories')

#ckpt = torch.load(, map_location='cpu')

In [5]:
from ngram_markov.model import GPT, GPTConfig
from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from torch.nn.functional import softmax, log_softmax
import einops
import torch
import plotly.express as px
from nnsight import LanguageModel



def load_nnsight_model(path):
    ckpt = torch.load(path, map_location='cpu')
    config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
    tl_weights = convert_nanogpt_weights(ckpt['model'], config)
    tl_model = HookedTransformer(config)
    tl_model.load_state_dict(tl_weights)
    return  LanguageModel(tl_model)


def load_tl_model(path):
    ckpt = torch.load(path, map_location='cpu')
    config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
    tl_weights = convert_nanogpt_weights(ckpt['model'], config)
    tl_model = HookedTransformer(config)
    tl_model.load_state_dict(tl_weights)
    return tl_model

model = load_tl_model(model_path / f'ckpt{epoch}.pt')
model.eval()

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [6]:
from ngram_markov.model import GPT, GPTConfig
from ngram_markov.utils import create_ngrams, nanogpt_to_hooked_transformer_config, convert_nanogpt_weights
from torch.nn.functional import softmax, log_softmax





#gpt_config = GPTConfig(**ckpt['model_args'])

#gpt_model = GPT(gpt_config)
#gpt_model.load_state_dict(ckpt['model'])
#gpt_model.to('cuda')

In [7]:
import torch

def create_local_attention_mask(seq_len: int, window_size: int, batch_size: int = 1, device='cpu') -> torch.Tensor:
    """
    Creates a mask for local attention where each token can only attend to
    the previous window_size tokens and itself, while maintaining causality.
    
    Args:
    seq_len (int): The sequence length.
    window_size (int): The number of previous tokens to attend to.
    batch_size (int): The batch size.
    
    Returns:
    torch.Tensor: A boolean mask of shape (batch_size, 1, seq_len, seq_len).
    """
    # Create a causal mask
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))
    
    # Create a local attention mask
    local_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1-window_size)
    
    # Combine causal and local masks
    mask = causal_mask * local_mask
    
    # Add batch and head dimensions
    mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)
    mask = torch.logical_not(mask).float().masked_fill(mask == 0, float('-inf'))
    
    return mask.to(device)

# Example usage:
# seq_len = 1024
# window_size = 256
# batch_size = 1
# local_mask = create_local_attention_mask(seq_len, window_size, batch_size)



In [8]:
create_local_attention_mask(10, 3, 1)

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [-inf, 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [-inf, -inf, 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [-inf, -inf, -inf, 0., 0., 0., -inf, -inf, -inf, -inf],
          [-inf, -inf, -inf, -inf, 0., 0., 0., -inf, -inf, -inf],
          [-inf, -inf, -inf, -inf, -inf, 0., 0., 0., -inf, -inf],
          [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., -inf],
          [-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.]]]])

In [44]:
from ngram_markov.ngrams import create_ngrams, calculate_ngram_kl_divergence


data_dir = Path('data/tinystories')

batch_size = 32
block_size = 256
device_type = 'cpu'
device = 'cpu'

def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    else:
        data = np.memmap(data_dir / 'validation.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x
    else:
        x = x.to(device)
    return x

In [10]:
batch = get_batch('train')

In [11]:
local_mask = create_local_attention_mask(block_size, 2, 16)
full_preds = model(batch)
local_preds = model(batch, additive_attention_mask=local_mask.to('cuda:0'))
(full_preds - local_preds).abs().mean()



tensor(3.3166, device='cuda:0', grad_fn=<MeanBackward0>)

In [12]:
_, cache = model.run_with_cache(batch)
cache['blocks.0.attn.hook_z'].shape

torch.Size([16, 256, 4, 128])

In [45]:
import torch
from transformer_lens import HookedTransformer, ActivationCache
from tqdm.notebook import tqdm, trange
from ngram_markov.ngrams import kl_divergence

def run_model_and_cache(model, input_ids, window_size=5, use_local_attention=False):
    if use_local_attention:
        additive_mask = create_local_attention_mask(input_ids.shape[1], window_size=window_size, batch_size=input_ids.shape[0], device='cuda:0')
        output, cache = model.run_with_cache(input_ids, additive_attention_mask=additive_mask)
    else:
        output, cache = model.run_with_cache(input_ids)
    return output.to('cpu'), cache.to('cpu')

def patch_head_output(model, input_ids, source_cache, target_cache, layer, head, use_local_attention, window_size=2):
    hook_point = f'blocks.{layer}.attn.hook_z'
    def patch_hook(value, hook):
        patched_value = value.clone()
        #print((patched_value[:, :, head, :].to('cpu') - source_cache[hook_point][:, :, head, :]).abs().max())
        patched_value[:, :, head, :] = source_cache[hook_point][:, :, head, :].to('cuda:0')
        
        return patched_value
    
    if use_local_attention:
        additive_mask = create_local_attention_mask(input_ids.shape[1], window_size=window_size, batch_size=input_ids.shape[0], device='cuda:0')
        patched_output = model.run_with_hooks(input_ids, additive_attention_mask=additive_mask, fwd_hooks=[(hook_point, patch_hook)])
    else:
        patched_output = model.run_with_hooks(input_ids, fwd_hooks=[(hook_point, patch_hook)])
    return patched_output



def compute_impact(patched_output, baseline_output, window_size):
    #print((patched_output.to('cpu') - baseline_output).abs().max())
    return kl_divergence(
        patched_output.log_softmax(dim=-1).reshape(-1, 512)[window_size:, :],
        baseline_output.log_softmax(dim=-1).to('cuda:0').reshape(-1, 512)[window_size:, :],
    )

def analyze_attention_head_differences(model, input_ids, window_size=2):
    full_output, full_cache = run_model_and_cache(model, input_ids)
    local_output, local_cache = run_model_and_cache(model, input_ids, use_local_attention=True, window_size=window_size)
    
    results = []
    
    for layer in range(model.cfg.n_layers):        
        for head in range(model.cfg.n_heads):
            record = {'layer': layer, 'head': head}
            # Patch full into local
            patched_output_full_to_local = patch_head_output(model, input_ids, full_cache, local_cache, layer, head, use_local_attention=True, window_size=window_size)
            record['f2l_vs_local'] =  compute_impact(patched_output_full_to_local, local_output, window_size).cpu().numpy()
            record['f2l_vs_full'] = compute_impact(patched_output_full_to_local, full_output, window_size).cpu().numpy()
            
            # Patch local into full
            patched_output_local_to_full = patch_head_output(model, input_ids, local_cache, full_cache, layer, head, use_local_attention=False, window_size=window_size)
            
            record['l2f_vs_full'] = compute_impact(patched_output_local_to_full, full_output, window_size).cpu().numpy()
            record['l2f_vs_local'] = compute_impact(patched_output_local_to_full, local_output, window_size).cpu().numpy()
            df = pl.DataFrame(record).with_columns(head=pl.lit(head), layer=pl.lit(layer))
            results.append(df)
    
    return pl.concat(results, how='vertical')

# Usage
input_ids = get_batch('train')
model.to('cuda:0')
all_diffs = []
num_batches = 1000
for _ in trange(num_batches):
    input_ids = get_batch('train')
    with torch.no_grad():
        all_diffs.append(analyze_attention_head_differences(model, input_ids.to('cuda'), window_size=5))
patching_df = pl.concat(all_diffs, how='vertical').melt(id_vars=['layer', 'head'])
patching_df.write_parquet('patching_results_window_5.parquet')
# Analyze results
#for direction in ['full_to_local', 'local_to_full']:
#    print(f"\nTop 10 most impactful heads ({direction}):")
#    top_differences = sorted(differences[direction].items(), key=lambda x: x[1], reverse=True)[:10]
#    for head, impact in top_differences:
#        print(f"{head}: {impact}")

Moving model to device:  cuda:0


  0%|          | 0/1000 [00:00<?, ?it/s]

In [60]:
(
    patching_df
    .filter(pl.col('layer').eq(1) & pl.col('variable').eq('f2l_vs_local'))
    .group_by('head')
    .agg(
        min_diff = pl.col('value').min(),
        mean_diff = pl.col('value').mean(),
        med_diff = pl.col('value').median(),
        max_diff = pl.col('value').max(),
        q25_diff = pl.col('value').quantile(0.25),
        q75_diff = pl.col('value').quantile(0.75), 
        q90_diff = pl.col('value').quantile(0.90),
        q95_diff = pl.col('value').quantile(0.95),
    )
    .sort('head')
)

head,min_diff,mean_diff,med_diff,max_diff,q25_diff,q75_diff,q90_diff,q95_diff
i32,f32,f32,f32,f32,f32,f32,f32,f32
0,-6.8308e-07,0.799904,0.521741,17.296215,0.251066,1.045621,1.847112,2.505945
1,-6.8308e-07,1.083738,0.801561,20.4261,0.391188,1.470269,2.367022,3.063566
2,-6.8308e-07,0.802297,0.629361,15.186031,0.356174,1.04748,1.620571,2.088233
3,-6.8308e-07,0.746013,0.583443,13.601524,0.331211,0.975181,1.514431,1.952505


In [62]:
(
    patching_df
    .filter(
        pl.col('layer').eq(1) & 
        pl.col('variable').eq('f2l_vs_local') &
        pl.col('head').eq(1) &
        pl.col('value').gt(10)
        
    )
    .sort('value', descending=True)
)

layer,head,variable,value
i32,i32,str,f32
1,1,"""f2l_vs_local""",20.4261
1,1,"""f2l_vs_local""",16.592453
1,1,"""f2l_vs_local""",15.332425
1,1,"""f2l_vs_local""",15.306915
1,1,"""f2l_vs_local""",14.842405
…,…,…,…
1,1,"""f2l_vs_local""",10.007487
1,1,"""f2l_vs_local""",10.006559
1,1,"""f2l_vs_local""",10.005215
1,1,"""f2l_vs_local""",10.004627


In [59]:
(
    patching_df
    .filter(pl.col('layer').eq(1) & pl.col('variable').eq('l2f_vs_full'))
    .group_by('head')
    .agg(
        min_diff = pl.col('value').min(),
        mean_diff = pl.col('value').mean(),
        med_diff = pl.col('value').median(),
        max_diff = pl.col('value').max(),
        q25_diff = pl.col('value').quantile(0.25),
        q75_diff = pl.col('value').quantile(0.75), 
        q90_diff = pl.col('value').quantile(0.90),
        q95_diff = pl.col('value').quantile(0.95),
    )
    .sort('head')
)

head,min_diff,mean_diff,med_diff,max_diff,q25_diff,q75_diff,q90_diff,q95_diff
i32,f32,f32,f32,f32,f32,f32,f32,f32
0,-6.8308e-07,0.624556,0.072269,33.463722,0.00974,0.279776,1.090482,3.269177
1,-6.8308e-07,1.151524,0.257083,29.727833,0.045249,1.000877,3.357298,6.25361
2,-6.8308e-07,0.399355,0.103724,28.012333,0.011397,0.342546,0.907622,1.692867
3,-6.8308e-07,0.276543,0.086751,18.841259,0.004821,0.300218,0.7069,1.156817


In [43]:
df.filter(pl.col('layer').eq(0) & pl.col('head').eq(3)).sort('variable')

layer,head,variable,mean_kl
i32,i32,str,f32
0,3,"""f2l_vs_full""",7.550286
0,3,"""f2l_vs_local""",1.474907
0,3,"""l2f_vs_full""",4.649981
0,3,"""l2f_vs_local""",2.999097


In [10]:
queries = [[i] for i in range(512)]
bigram_counts = torch.tensor(index.batch_count_next(queries, 512), dtype=torch.float32)
unigrams = bigram_counts.sum(dim=1)
unigrams /= unigrams.sum()

In [11]:
batch, _ = get_batch('train')



noised_batch = noise_operator(batch, unigrams, 0.1, 4, 1024)

In [66]:
@torch.no_grad()
def run_minibatch(model, minibatch):
    with torch.amp.autocast('cuda'):
        logits = model(minibatch.to('cuda'))[:, -1:, :]
        current_log_probs = log_softmax(logits, dim=-1)
    return current_log_probs.to('cpu')
    

def noise_operator(tokens, unigram_probabilities, rho: float, n_gram: int, num_samples: int):
    batch_size, seq_len = tokens.shape
    device = tokens.device
    
    mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    mask[:, -n_gram:] = False
    
    noised_tokens = tokens.unsqueeze(1).expand(-1, num_samples, -1).clone()
    
    bernoulli_mask = torch.bernoulli(torch.full((batch_size, num_samples, seq_len), rho, device=device))
    bernoulli_mask = bernoulli_mask * mask.unsqueeze(1)
    
    sampled_tokens = torch.multinomial(
        unigram_probabilities,
        batch_size * num_samples * seq_len,
        replacement=True
    ).view(batch_size, num_samples, seq_len)
    
    noised_tokens = torch.where(bernoulli_mask.bool(), sampled_tokens, noised_tokens)
    # Count corrupted tokens per sequence
    #corrupted_count = (noised_tokens != tokens.unsqueeze(1)).sum(dim=-1)
    
    return noised_tokens


import torch

def noise_operator_metadata(tokens, unigram_probabilities, rho: float, n_gram: int, num_samples: int):
    batch_size, seq_len = tokens.shape
    device = tokens.device
    
    mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    mask[:, -n_gram:] = False
    
    noised_tokens = tokens.unsqueeze(1).expand(-1, num_samples, -1).clone()
    
    bernoulli_mask = torch.bernoulli(torch.full((batch_size, num_samples, seq_len), rho, device=device))
    bernoulli_mask = bernoulli_mask * mask.unsqueeze(1)
    
    sampled_tokens = torch.multinomial(
        unigram_probabilities,
        batch_size * num_samples * seq_len,
        replacement=True
    ).view(batch_size, num_samples, seq_len)
    
    noised_tokens = torch.where(bernoulli_mask.bool(), sampled_tokens, noised_tokens)
    
    # Identify corrupted tokens
    corrupted_mask = (noised_tokens != tokens.unsqueeze(1))
    
    # Count corrupted tokens per sequence
    corrupted_count = corrupted_mask.sum(dim=-1)
    
    # Track positions of corrupted tokens
    position_indices = torch.arange(seq_len, device=device).expand(batch_size, num_samples, -1)
    corrupted_positions = position_indices[corrupted_mask]
    
    # Calculate distance from end for each corrupted token
    distance_from_end = seq_len - corrupted_positions - 1
    
    return noised_tokens, corrupted_count, corrupted_positions, distance_from_end

# Usage example:
# noised_tokens, corrupted_count, corrupted_positions, distance_from_end = noise_operator(tokens, unigram_probabilities, rho, n_gram, num_samples)

    
def compare_noised_predictions_batched(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size=10):
    device = batch.device
    orig_batch_size, seq_len = batch.shape

    # Get original predictions and apply log_softmax
    with torch.no_grad():
        original_logits = log_softmax(model(batch)[:, -1:, :], dim=-1)

    # Apply noise operator
    noised_batches = noise_operator(
        batch,
        unigram_probabilities,
        rho,
        n_gram,
        num_samples
    )
    
    # Initialize tensor to store noised log_softmax probabilities
    noised_log_probs = torch.zeros((orig_batch_size, num_samples, original_logits.shape[-1]), device='cpu')

    # Process noised batches
    for i in range(0, orig_batch_size):
        end_i = min(i + 1, orig_batch_size)
        for j in range(0, num_samples, batch_size):
            end_j = min(j + batch_size, num_samples)
            
            current_batch = noised_batches[i:end_i, j:end_j].reshape(-1, seq_len)            
            current_log_probs = run_minibatch(model, current_batch)
            reshaped_log_probs = current_log_probs.view(end_i - i, end_j - j, -1)
            noised_log_probs[i:end_i, j:end_j] = reshaped_log_probs.to('cpu')

    return original_logits, noised_log_probs
    

def kl_divergence_log(p_log, q_log):
    return (p_log.exp() * (p_log - q_log)).sum(dim=-1)


def entropy(p_log):
    return -1. * (softmax(p_log, dim=-1) * p_log).sum(dim=-1)


def analyze_noise_effect(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size=10):
    original_log_probs, noised_log_probs = (
        compare_noised_predictions_batched(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size)
    )

    mean_noised_log_probs = log_softmax(noised_log_probs.mean(dim=1), dim=-1)

    # Move tensors to CPU for calculations
    original_log_probs = original_log_probs.cpu()
    noised_log_probs = noised_log_probs.cpu()

    # Calculate KL divergences in log space
    kl_noised_original = kl_divergence_log(noised_log_probs, original_log_probs)
    kl_mean_noised_original = kl_divergence_log(mean_noised_log_probs, original_log_probs.squeeze())

    # Print some basic statistics for reference

    avg_kl_noised_original = kl_noised_original.mean().item()
    max_kl_noised_original =  kl_noised_original.max().item()
    avg_kl_mean_noised_original = kl_mean_noised_original.mean().item()
    max_kl_mean_noised_original =  kl_mean_noised_original.max().item()
    

    print(f"Average KL(noised || original): {avg_kl_noised_original:.4f}")
    print(f"Max KL(noised || original): {max_kl_noised_original:.4f}")
    print(f"Average KL(mean noised || original): {avg_kl_mean_noised_original:.4f}")
    print(f"Max KL(mean noised || original): {max_kl_mean_noised_original:.4f}")

    return original_log_probs.squeeze(), mean_noised_log_probs, kl_noised_original #, corrupted_positions, dist_from_end

# Usage example
# model = your_transformer_model
# batch = your_input_batch
# unigram_probabilities = your_unigram_probabilities
# rho = 0.1
# n_gram = 2
# num_samples = 1000
# batch_size = 10  # Adjust this based on your GPU memory
# per_seq_std, euclidean_distance = analyze_noise_effect(model, batch, unigram_probabilities, rho, n_gram, num_samples, batch_size)

In [67]:
batch, _ = get_batch('train')
log_probs, t_log_probs, kl_noised_original = analyze_noise_effect(tl_model, batch, unigrams, 0.1, 5, 256, 16)

Average KL(noised || original): 0.2717
Max KL(noised || original): 9.3362
Average KL(mean noised || original): 0.0358
Max KL(mean noised || original): 0.1576


In [68]:
entropy(t_log_probs)

tensor([5.3794e-04, 3.0769e-02, 2.2824e-02, 4.5360e-01, 1.0706e+00, 1.4880e+00,
        1.7449e+00, 4.5146e-03, 8.0832e-01, 3.0447e-01, 2.6097e-02, 2.4319e+00,
        2.5486e+00, 1.6996e+00, 1.3079e-04, 1.2619e+00])

In [69]:
entropy(log_probs.squeeze()) - entropy(t_log_probs)

tensor([ 1.1989e-04, -2.5578e-02, -1.9680e-02,  3.8841e-02, -1.2229e-01,
        -5.4508e-01, -3.0910e-01, -2.8948e-03, -2.7419e-01,  3.1388e-02,
         7.1690e-03, -7.3609e-01, -4.2333e-01, -6.1854e-01, -3.1693e-05,
        -4.4911e-01])

In [59]:
log_probs, _, kl_noised_original = analyze_noise_effect(tl_model, batch, unigrams, 0.1, 5, 512, 16)

Average KL(noised || original): 0.2986
Max KL(noised || original): 13.0858
Average KL(mean noised || original): 0.0897
Max KL(mean noised || original): 0.4492


In [None]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer/tinystories512")
ts_bin_path = 'data/tinystories/train.bin'
index_path = "data/tinystories/ngrams/suffix_tree.idx"

index = MemmapIndex(ts_bin_path, index_path)

In [58]:
_, _, kl_noised_original = analyze_noise_effect(tl_model, batch, unigrams, 0.1, 5, 1024, 16)

Average KL(noised || original): 0.3129
Max KL(noised || original): 12.2490
Average KL(mean noised || original): 0.0929
Max KL(mean noised || original): 0.4068


In [60]:
entropy(log_probs)

tensor([[1.0375e+00],
        [7.5132e-04],
        [7.6108e-03],
        [1.4330e+00],
        [1.6799e-01],
        [1.4628e+00],
        [1.6195e-03],
        [5.4730e-01],
        [6.8034e-01],
        [6.7074e-02],
        [1.9394e-02],
        [1.2514e+00],
        [1.2036e+00],
        [1.5110e+00],
        [1.8040e+00],
        [1.9232e-01]])

In [None]:


def get_batch(batch_size=16, seq_len=1024):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122

    data = np.memmap(data_dir / 'train.bin', dtype=np.uint16, mode='r')
    
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+seq_len]).astype(np.int64)) for i in ix])
    
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x = x.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
    return x


def noise_sensitivity_for_checkpoint(model, num_sequences, unigram_probabilities, rho, n_gram, num_samples, sample_batch_size):
    batch = get_batch('train')
    


def noise_sensitivity_over_time(epochs, num_sequences, unigram_probabilities, rho, n_gram, num_samples, sample_batch_size):
    model_path = Path('/media/External01/out')
    for step in epochs:
        ckpt = torch.load(model_path / f'ckpt{epoch}.pt', map_location='cpu')
        config = nanogpt_to_hooked_transformer_config(ckpt['model_args'])
        tl_weights = convert_nanogpt_weights(ckpt['model'], config)
        tl_model = HookedTransformer(config)
        tl_model.load_state_dict(tl_weights)