Setup and helper code copied from Neel Nanda's TransformerLens notebooks, among other sources.

In [7]:
#@title Neel's Setup
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    %pip install git+https://github.com/neelnanda-io/neel-plotly
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload








In [8]:
#@title Many useful imports

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
# "colab" is used inside Colab and whenever DEVELOPMENT_MODE is off;
# "notebook_connected" is only chosen for local development runs.
use_colab_renderer = IN_COLAB or not DEVELOPMENT_MODE
pio.renderers.default = "colab" if use_colab_renderer else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

# Import stuff
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from jaxtyping import Float, Int
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from neel_plotly import line, scatter, imshow, histogram

import circuitsvis as cv

Using renderer: colab


ModuleNotFoundError: No module named 'circuitsvis'

In [None]:
#@title Load and run transformer lens model on device

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Reference prompt reused by later cells.
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = model.to_tokens(gpt2_text)
print(gpt2_tokens.device)

# One forward pass that also records the intermediate activations.
gpt2_logits, gpt2_cache = model.run_with_cache(
    gpt2_tokens,
    remove_batch_dim=True,
)


In [None]:
#@title Size-Constrained CircuitViz Attention Pattern
# Plot doesn't display when initialized in a method ?
# Parameters of the visualisation: which layer/head to show and on which prompt.
layer = 2
head = 4
text = "Access and plot the attention pattern of head L2H4 on the prompt"
# Cache the activations of `text` itself. Reusing a cache recorded for another
# prompt would pair this prompt's token labels with an attention pattern of a
# different length.
_, cache = model.run_with_cache(model.to_tokens(text), remove_batch_dim=True)

class SizeLimitedObject:
    """Wrap an HTML-renderable object in a <div> carrying max-size CSS.

    Args:
        obj: Any object exposing ``_repr_html_()``.
        max_width: CSS value written to the wrapper's ``max-width``.
        max_height: CSS value written to the wrapper's ``max-height``.
    """

    def __init__(self, obj, max_width='500px', max_height='500px'):
        self.obj, self.max_width, self.max_height = obj, max_width, max_height

    def _repr_html_(self):
        """Return the wrapped object's HTML nested inside the styled <div>."""
        css = f"max-width: {self.max_width}; max-height: {self.max_height}; padding: 20px;"
        inner_html = self.obj._repr_html_()
        # The leading empty piece and trailing indent keep the whitespace layout
        # of the surrounding markup unchanged.
        pieces = [
            "",
            f"        <div style='{css}'>",
            f"            {inner_html}",
            "        </div>",
            "        ",
        ]
        return "\n".join(pieces)

# Attention pattern of every head in `layer`.
attention_pattern = cache["pattern", layer, "attn"]
str_tokens = model.to_str_tokens(text)

print(f"Layer {layer} Head {head} Attention Pattern:")
# `head` is 0-based, exactly like `layer` above, so it indexes the pattern
# directly; subtracting 1 would plot the neighbouring head under this label.
head_attention = cv.attention.attention_pattern(tokens=str_tokens, attention=attention_pattern[head])

sized_viz = SizeLimitedObject(head_attention)
# sized_viz

In [None]:
#@title Ablation Hooks

# Other inputs for utils.get_act_name can be found here:
# https://github.com/luciaquirke/TransformerLens/blob/49edbec5424081182ef090265e2e6112153deffc/transformer_lens/utils.py

# Ablation target: layer index and the head index read by head_ablation_hook.
layer_to_ablate, head_index_to_ablate = 0, 8
# Prompt for the (commented-out) loss comparisons at the end of this cell.
text = gpt2_text
model = model  # self-assignment: has no effect beyond requiring `model` to exist

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero one attention head's slice of a per-head activation, in place.

    The head is taken from the module-level ``head_index_to_ablate``.

    Args:
        value: Activation tensor passed in by the hook point; modified in place.
        hook: The hook point that fired (unused).

    Returns:
        The same ``value`` tensor with the selected head set to zero.
    """
    # Every batch entry and position, the selected head, every d_head component.
    ablated_slice = (slice(None), slice(None), head_index_to_ablate, slice(None))
    value[ablated_slice] = 0.0
    return value

def mlp_ablation_hook(
    value: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Zero-ablate an entire 3-D activation tensor, in place.

    Args:
        value: Activation tensor passed in by the hook point; modified in place.
        hook: The hook point that fired (unused).

    Returns:
        The same ``value`` tensor, now all zeros.
    """
    # NOTE(review): the commented-out usage below hooks the "pre" activation, whose
    # last axis may be d_mlp rather than d_model - confirm the annotation.
    value[:, :, :].fill_(0.0)
    return value

def get_mlp_mean_ablation_hook(batched_cache, act_name):
    """Build a hook that replaces an activation with its mean over a cached batch.

    Args:
        batched_cache: Mapping from activation name to a tensor whose first
            axis is the batch axis.
        act_name: Key of the activation to average.

    Returns:
        A hook ``(value, hook) -> value`` that overwrites ``value`` in place with
        the per-position batch mean and returns it.
    """
    # mean over all batches -> one row per position
    mean = batched_cache[act_name].mean(dim=0)

    def mlp_mean_ablation_hook(
        value: Float[torch.Tensor, "batch pos d_model"],
        hook: HookPoint
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # The hooked prompt may be shorter than the cached one, and may live on
        # another device/dtype, so align the mean with `value` before the
        # broadcasted assignment over the batch axis.
        n_pos = value.shape[1]
        value[:, :, :] = mean[:n_pos].to(device=value.device, dtype=value.dtype)
        return value

    return mlp_mean_ablation_hook

# original_loss = model(gpt2_tokens, return_type="loss")
# with model.hooks(fwd_hooks=[(utils.get_act_name("v", layer_to_ablate), head_ablation_hook)]):
#     ablated_head_loss = model(text, return_type="loss")

# with model.hooks(fwd_hooks=[(utils.get_act_name("pre", layer_to_ablate), mlp_ablation_hook)]):
#     ablated_mlp_loss = model(text, return_type="loss")

# print(f"Original Loss: {original_loss.item():.3f}")
# print(f"Ablated Head Loss: {ablated_head_loss.item():.3f}")
# print(f"Ablated MLP Loss: {ablated_mlp_loss.item():.3f}")

In [None]:
#@title Print Tensor First Few Elements

def head(tensor):
    """Print the leading corner of ``tensor``: at most 3 entries along every axis.

    Args:
        tensor: Any array-like exposing ``.shape`` and slice indexing.
    """
    # Multidimensional indexing needs a tuple of slices. A list is rejected by
    # NumPy and deprecated by PyTorch, which may read it as an index array.
    corner = tuple(slice(0, 3) for _ in tensor.shape)
    print(tensor[corner])

def upper(tensor):
    """Print the triangular entries of the last two axes of ``tensor``.

    NOTE(review): despite the name, the indices come from ``torch.tril_indices``,
    i.e. the lower triangle including the diagonal - confirm which is intended.

    Args:
        tensor: Tensor with at least two dimensions.
    """
    tri = torch.tril_indices(tensor.size(-2), tensor.size(-1))
    # tri[0] holds row coordinates, tri[1] the matching column coordinates.
    print(tensor[..., tri[0], tri[1]])

# head(torch.rand(100, 100, 100, 100))

In [None]:
#@title Line, Imshow/Heatmap, Scatter, Histogram

def line(x, y, line_labels=None, xaxis="", yaxis="", title="", **kwargs):
    """Show a Plotly line chart of ``y`` against ``x``.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis (same length as ``x``).
        line_labels: Optional names assigned to the figure's traces, in order.
        xaxis: Horizontal axis title; empty keeps Plotly's default.
        yaxis: Vertical axis title; empty keeps Plotly's default.
        title: Figure title.
        **kwargs: Extra keyword arguments forwarded to ``px.line``.
    """
    df = pd.DataFrame({'x': x, 'y': y})
    # Only pass non-empty titles so the default call still shows "x"/"y".
    axis_titles = {col: name for col, name in (("x", xaxis), ("y", yaxis)) if name}
    plot_kwargs = {"title": title, "labels": axis_titles}
    plot_kwargs.update(kwargs)
    fig = px.line(df, x='x', y='y', **plot_kwargs)
    if line_labels:
        # zip stops at the shorter sequence, so surplus labels are ignored
        # instead of raising IndexError.
        for trace, label in zip(fig.data, line_labels):
            trace.name = label
    fig.show()
  
def imshow(tensor, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a heatmap on a diverging colour scale centred at 0.

    Args:
        tensor: Image-like data; torch tensors are converted to NumPy first.
        xaxis: Horizontal axis title.
        yaxis: Vertical axis title.
        **kwargs: Extra keyword arguments for ``px.imshow``; they override the
            defaults set here.
    """
    if isinstance(tensor, torch.Tensor):
        # Plotly cannot read tensors that live on the GPU or carry autograd
        # state, so hand it a plain CPU array.
        tensor = tensor.detach().cpu().numpy()
    plot_kwargs = {"color_continuous_scale":"RdBu", "color_continuous_midpoint":0.0, "labels":{"x":xaxis, "y":yaxis}}
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

def scatter(x, y, labels=None, xaxis="", yaxis="", title="", **kwargs):
    """Show a Plotly scatter plot of ``y`` against ``x``.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis (same length as ``x``).
        labels: Optional names assigned to the figure's traces, in order.
        xaxis: Horizontal axis title.
        yaxis: Vertical axis title.
        title: Figure title.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.
    """
    # Fixed column names: keying the frame by the axis titles collapses x and y
    # into one column whenever the titles are equal (e.g. both left at "").
    df = pd.DataFrame({"x": x, "y": y})
    plot_kwargs = {"title": title, "labels": {"x": xaxis, "y": yaxis}}
    plot_kwargs.update(kwargs)
    fig = px.scatter(df, x="x", y="y", **plot_kwargs)
    if labels:
        # zip stops at the shorter sequence, so surplus labels are ignored
        # instead of raising IndexError.
        for trace, label in zip(fig.data, labels):
            trace.name = label
    fig.show()

def histogram(data):
    """Show a Plotly histogram of the values in ``data``.

    Args:
        data: One-dimensional collection of values to bin.
    """
    frame = pd.DataFrame({'Data': data})
    px.histogram(frame, x='Data').show()

# # histogram example
# histogram(np.random.randn(1000))

# # line examples
# x = np.linspace(0, 50, 50, dtype="int")
# y = np.random.rand(50)
# line(x, y, xaxis="x values", yaxis="y values", title="Gaussian Random Variables")

# x = np.linspace(0, 10, 100)
# y = np.sin(x)
# line(x, y, xaxis="x values", yaxis="y values", title="Sine Wave")

# # imshow example
# tensor = np.random.rand(10, 10)
# imshow(tensor, xaxis="X-axis", yaxis="Y-axis")

# scatter example
# x = np.linspace(0, 50, 50, dtype="int")
# y = np.random.rand(50)
# scatter(x, y, xaxis="x values", yaxis="y values", title="Scatter Plot")

In [None]:
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import load_dataset
from torch.utils.data import DataLoader

# Each example's 'text' field is treated as one prompt.
dataset = load_dataset('ag_news', split='test')

# Number of positions tracked. 260 was measured once as the longest prompt:
#     max(model.to_tokens(t).shape[1] for t in dataset['text'])
# Longer prompts are truncated to this many positions.
max_length = 260

act_name = utils.get_act_name("pre", 0)

# Sample prompt kept for ad-hoc experiments.
test_text = dataset[0]['text']

# Number of examples read from the dataset per outer step.
batch_size = 10

# pos_counts[p]  = number of prompts that have a token at position p
# sum_acts[0, p] = sum over those prompts of the activation at position p
# Both stay on the CPU so they can be combined directly afterwards.
pos_counts = torch.zeros(max_length)
sum_acts = torch.zeros(1, max_length, model.cfg.d_mlp) # batch, pos, d_mlp

for start in tqdm.trange(0, len(dataset), batch_size):
    batch_texts = dataset[start : start + batch_size]['text']
    # Prompts are run one at a time, so no padding tokens ever reach the sums
    # and every prompt is counted exactly once in both accumulators.
    for example_text in batch_texts:
        example_tokens = model.to_tokens(example_text)
        # Only the activation being averaged is cached.
        _, example_cache = model.run_with_cache(example_tokens, names_filter=act_name)
        acts = example_cache[act_name][0]  # drop the size-1 batch axis

        n_pos = min(acts.shape[0], max_length)
        pos_counts[:n_pos] += 1
        sum_acts[0, :n_pos] += acts[:n_pos].cpu()

print(pos_counts.shape)

    # another strategy: divide each act by the number of tokens at that position in the dataset, then add to average

In [None]:
# Reshape the per-position counts so they broadcast against sum_acts.
pos_counts_expanded = pos_counts.unsqueeze(0).unsqueeze(-1)  # [1, pos, 1]
print(pos_counts[:5])
print(pos_counts_expanded.shape)
# print(pos_counts_expanded)

print(sum_acts.shape) # [batch, pos, d_mlp]
print(sum_acts.mean(0))
print(sum_acts[0][0][:5])
# Positions that no prompt reached have count 0 and sum 0; clamping the divisor
# to 1 reports them as 0 instead of NaN.
print(sum_acts / pos_counts_expanded.clamp(min=1))


# # gpt-small
# gpt2_tokens = gpt2_model.to_tokens(test_text)
# gpt_loss, gpt_cache = gpt2_model.run_with_cache(gpt2_tokens, return_type="loss")
# mean_ablated_gpt_loss = gpt2_model.run_with_hooks(
#     gpt2_tokens,
#     return_type="loss",
#     fwd_hooks=[(utils.get_act_name("pre", 0), get_mlp_mean_ablation_hook(gpt_cache, utils.get_act_name("pre", 0)))])

# print(mean_ablated_gpt_loss)