In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
from buffer import MultiModelActivationBuffer

from datasets import load_dataset
import torch as t

from nnsight import LanguageModel
from buffer import MultiModelActivationBuffer
from trainers.top_k import TopKTrainer, AutoEncoderTopK
from training import trainSAE
from einops import rearrange, einsum
import matplotlib.pyplot as plt
from tqdm import tqdm

# Use the first CUDA device when one is present; otherwise fall back to CPU so the
# notebook can still be run (slowly) on a machine without a GPU.
device = "cuda:0" if t.cuda.is_available() else "cpu"
# dtype handed to every model load below (as torch_dtype).
dtype = t.bfloat16
# This notebook only runs forward passes, so autograd is switched off globally.
t.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7b52bce1ddd0>

In [2]:
# Index into each model's `gpt_neox.layers`; that layer module is collected from every checkpoint.
layer = 4
# Forwarded to the activation buffer as `out_batch_size`.
out_batch_size = 4*4096

# Parallel lists: submodule_list[i] is the selected layer module of model_list[i].
submodule_list = []
model_list = []
# Load the same Pythia-70m architecture at a series of training-step revisions
# ("step1" ... "step143000"), one model per revision, in the order listed.
for step in [1, 128, 256, 512, 1000, 2000, 4000, 8000, 16000, 32000, 64000, 143000]:
    model = LanguageModel(
        "EleutherAI/pythia-70m", 
        revision=f"step{step}", 
        trust_remote_code=False, 
        device_map=device,
        torch_dtype=dtype,
        )
    # Mark every parameter as not requiring gradients.
    for x in model.parameters():
        x.requires_grad = False
    model_list.append(model)
    submodule_list.append(model.gpt_neox.layers[layer])
    
# Passed to the buffer as `d_submodule`.
# NOTE(review): presumably the hidden size of pythia-70m; confirm it matches the layer's output width.
activation_dim = 512

# Stream the 'train' split of OpenWebText (streaming=True) rather than loading it up front.
# NOTE(review): trust_remote_code=True presumably allows the dataset repository's own
# loading code to run; confirm this source is trusted.
dataset = load_dataset(
    'Skylion007/openwebtext', 
    split='train', 
    streaming=True,
    trust_remote_code=True
    )

class CustomData:
    """Iterator adapter that yields only the ``'text'`` field of each dataset record."""

    def __init__(self, dataset):
        # Keep one iterator over the underlying dataset; it is consumed as
        # items are requested and is never restarted.
        self.data = iter(dataset)

    def __iter__(self):
        # The adapter is its own iterator.
        return self

    def __next__(self):
        # StopIteration from the underlying iterator propagates unchanged.
        record = next(self.data)
        return record['text']

# Wrap the streamed dataset so that iterating it yields raw text strings.
data = CustomData(dataset)

# Activation buffer fed by the text stream and by the per-checkpoint models / layer
# modules collected above.
buffer = MultiModelActivationBuffer(
    data=data,
    model_list=model_list,
    submodule_list=submodule_list,
    d_submodule=activation_dim, # output dimension of the model component
    n_ctxs=256,  # you can set this higher or lower depending on your available memory
    device=device,
    refresh_batch_size=128,
    out_batch_size=out_batch_size,
    remove_bos=True
)  # buffer will yield batches of tensors of dimension = submodule's output dimension

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


In [3]:
# Load the trained top-k autoencoder (k=128) onto `device`.
# NOTE(review): the checkpoint path is an absolute, machine-specific location; adjust it
# (or lift it into a config constant) when running elsewhere.
ae = AutoEncoderTopK.from_pretrained("/root/features_over_time/checkpoints/trainer_0/ae.pt", k=128, device=device)

  state_dict = t.load(path)


In [13]:
def get_tokens_and_acts():
    """Fetch one sequence batch from the global ``buffer`` and encode it with the global ``ae``.

    Returns:
        tuple: ``(tokens, top_acts, top_inds)``, where ``tokens`` is the second value
        returned by ``buffer.get_seq_batch()`` and ``top_acts`` / ``top_inds`` are the
        second and third values returned by ``ae.encode(..., return_topk=True)``.
    """
    activations, tokens = buffer.get_seq_batch()
    # encode(..., return_topk=True) yields four values; only the two middle ones are used.
    _first, top_acts, top_inds, _last = ae.encode(activations, return_topk=True)
    return tokens, top_acts, top_inds

In [15]:
# Smoke-test the helper on a single batch and show the resulting shapes.
tokens, top_acts, top_inds = get_tokens_and_acts()
print(tokens.shape, top_acts.shape, top_inds.shape)

torch.Size([128, 127]) torch.Size([128, 127, 128]) torch.Size([128, 127, 128])


In [44]:
# Number of sequence batches to pull from the buffer and encode.
n_batches = 10

# Collect per-batch results in dedicated lists so that `tokens`, `top_acts` and
# `top_inds` only ever refer to the final concatenated tensors.
token_batches = []
top_acts_batches = []
top_inds_batches = []
for _ in tqdm(range(n_batches)):
    # Reuse the helper instead of repeating the fetch-and-encode steps inline.
    tokens_batch, top_acts_batch, top_inds_batch = get_tokens_and_acts()
    token_batches.append(tokens_batch)
    top_acts_batches.append(top_acts_batch)
    top_inds_batches.append(top_inds_batch)

# Concatenate along the leading (batch) dimension.
tokens = t.cat(token_batches)
top_acts = t.cat(top_acts_batches)
top_inds = t.cat(top_inds_batches)

print(tokens.shape, top_acts.shape, top_inds.shape)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:08<00:00,  1.17it/s]

torch.Size([1280, 127]) torch.Size([1280, 127, 128]) torch.Size([1280, 127, 128])





In [40]:
def display_top_contexts(
    model,
    tokens,
    top_acts,
    top_inds,
    feature_id,
    num_to_show=10,
    left_context=10,
    right_context=10,
    min_opacity=0.2
):
    """
    Display the top activating contexts for a feature using HTML with translucent highlighting.

    Every occurrence of ``feature_id`` in ``top_inds`` is located, the occurrences are
    ranked by their value in ``top_acts``, and for the ``num_to_show`` strongest ones a
    window of surrounding tokens is rendered. Tokens inside the window on which the
    feature also fires are shaded relative to the window's peak activation and carry a
    hover tooltip with their own activation value.

    Args:
        model: object exposing ``model.tokenizer.decode(list_of_token_ids)``
        tokens: tensor of shape [batch, n_ctx] containing token ids
        top_acts: tensor of shape [batch, n_ctx, k] containing activation values
        top_inds: tensor of shape [batch, n_ctx, k] containing feature indices
        feature_id: which feature to analyze
        num_to_show: number of top contexts to display
        left_context: number of tokens to show before the max activation
        right_context: number of tokens to show after the max activation
        min_opacity: minimum opacity for highlighting

    Returns:
        IPython.display.HTML: the rendered panel. If the feature never appears in
        ``top_inds`` the panel contains a short notice instead of context boxes.
    """
    import torch
    from IPython.display import HTML
    import html

    transparent = "rgba(0, 0, 0, 0)"

    def get_color(act_value, max_act):
        """Generate translucent red color based on activation value."""
        # A zero peak would otherwise divide by zero; treat it like "no activation".
        if act_value == 0 or max_act == 0:
            return transparent
        opacity = min(1.0, max(min_opacity, act_value / max_act))
        return f"rgba(255, 0, 0, {opacity})"

    # `.token` uses `white-space: pre-wrap` so the whitespace that belongs to a token
    # (e.g. a leading space) is shown exactly, while the zero font size on
    # `.text-container` hides any whitespace between the spans themselves.
    parts = ["""
    <div style="font-family: monospace; background-color: #1a1a1a; padding: 20px;">
        <style>
            .token {
                display: inline;
                padding: 0;
                margin: 0;
                position: relative;
                white-space: pre-wrap;
            }
            .token:hover .tooltip {
                display: block;
            }
            .tooltip {
                display: none;
                position: absolute;
                background: #333;
                color: white;
                padding: 0px 0px;
                border-radius: 0px;
                font-size: 12px;
                bottom: 100%;
                left: 50%;
                transform: translateX(-50%);
                white-space: nowrap;
                z-index: 1;
            }
            .context-box {
                margin: 20px 0;
                padding: 10px;
                border: 1px solid #333;
                border-radius: 4px;
                background: #252525;
            }
            .text-container {
                font-size: 0;
                word-spacing: 0;
                letter-spacing: 0;
            }
            .activation-label {
                color: #888;
                margin-bottom: 8px;
            }
        </style>
    """]

    # Find positions where this feature appears in top_inds
    batch_idxs, pos_idxs, k_idxs = torch.where(top_inds == feature_id)

    # Get the activation values for this feature
    acts = top_acts[batch_idxs, pos_idxs, k_idxs]

    if acts.numel() == 0:
        # Nothing to rank: say so explicitly rather than returning a blank panel.
        parts.append(
            '<div class="activation-label">No activations found for feature '
            f'{html.escape(str(feature_id))}.</div>'
        )
        parts.append("</div>")
        return HTML("".join(parts))

    # Sort by activation value and keep the strongest occurrences.
    ranked = torch.argsort(acts, descending=True)[:num_to_show].tolist()

    # Move the hit coordinates to plain Python values once, so the rendering loops
    # below do no tensor work per token.
    batch_list = batch_idxs.tolist()
    pos_list = pos_idxs.tolist()
    act_list = acts.float().tolist()

    # (batch, position) -> activation of this feature. `setdefault` keeps the first
    # occurrence if a position were ever listed more than once.
    act_at = {}
    for b, p, a in zip(batch_list, pos_list, act_list):
        act_at.setdefault((b, p), a)

    seq_len = tokens.shape[1]

    for hit in ranked:
        batch_idx = batch_list[hit]
        tok_pos = pos_list[hit]
        activation = act_list[hit]

        # Context window, clipped to the sequence boundaries.
        start_idx = max(0, tok_pos - left_context)
        end_idx = min(seq_len, tok_pos + right_context + 1)
        context_tokens = tokens[batch_idx, start_idx:end_idx].tolist()

        parts.append(f"""
        <div class="context-box">
            <div class="activation-label">activation: {activation:.4f}</div>
            <div class="text-container">
        """)

        for pos, token in enumerate(context_tokens, start=start_idx):
            # Show newlines as a literal "\n" so one context stays on one visual line.
            text = model.tokenizer.decode([token]).replace("\n", "\\n")

            if pos == tok_pos:
                # Max activation token
                local_act = activation
            else:
                # Other activations of this feature (None when it does not fire here)
                local_act = act_at.get((batch_idx, pos))

            if local_act is None:
                color = transparent
                tooltip = ""
            else:
                color = get_color(local_act, activation)
                tooltip = f'<span class="tooltip">Activation: {local_act:.4f}</span>'

            # No whitespace is emitted around the token text: anything inside the
            # span would be rendered as part of the token.
            parts.append(
                f'<span class="token" style="background-color: {color}; '
                f'color: #fff; font-size: 16px;">{html.escape(text)}{tooltip}</span>'
            )

        parts.append("""
            </div>
        </div>
        """)

    parts.append("</div>")
    return HTML("".join(parts))

In [None]:
# Render the strongest contexts for feature 10, decoding tokens with the tokenizer of
# the first loaded checkpoint. The function defined in this notebook is
# `display_top_contexts`.
display_top_contexts(
    model_list[0],
    tokens,
    top_acts,
    top_inds,
    feature_id=10,
    num_to_show=10,
    left_context=20,
    right_context=5
)