# Mechanistic analysis of HTML tags closing in GPT-2

We already know there're from [Gurnee et al. 2023](https://arxiv.org/abs/2305.01610) that there are neurons in transformers that activates for specific code text such as Python and HTML.
In this notebook I will explore how GPT-2 processes HTML code to find how it closes tags. The analysis is divided into two sections:

1. The mechanism that allow the general closure of a tag (the one that predict the `</` token)
2. The mechanism responsible to find the right tag to close (the one that, for example, predict `div` token after the previous)

In [2]:
from transformer_lens import HookedTransformer
import torch
from tqdm import tqdm

In [3]:
# Load model
model = HookedTransformer.from_pretrained('gpt2')
model.eval()
nl = len(model.blocks)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [4]:
# Some useful functions
@torch.no_grad()
def generate(model, prompt, max_new_tokens=10):
    tokens = model.to_tokens(prompt)
    for _ in range(max_new_tokens):
        output = model(tokens)
        next_token = torch.argmax(output[0, -1, :])
        tokens = torch.cat((tokens, next_token.unsqueeze(0).unsqueeze(0)), dim=1)
    return model.tokenizer.decode(tokens[0], skip_special_tokens=True)

### Data

In [5]:
import requests

def scrape_wikipedia_html(url):
    # Send a GET request to the URL
    response = requests.get(url)
    
    # Check if the request was successful (status code 200)
    if response.status_code == 200:
        # Return the HTML content of the page
        return response.text
    else:
        print("Failed to retrieve the page. Status code:", response.status_code)
        return None

# Example usage
url = "https://en.wikipedia.org/wiki/Special:Random"
html_content = [scrape_wikipedia_html(url) for i in tqdm(range(10))]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:02<00:00,  3.41it/s]


In [6]:
from bs4 import BeautifulSoup

def clean_html(html_content):
    # Parse the HTML content
    soup = BeautifulSoup(html_content, 'html.parser')
    
    # Remove meta tags starting with "data"
    for tag in soup.find_all(lambda tag: tag.name == 'meta' and tag.has_attr('name') and tag['name'].startswith('data')):
        tag.decompose()
    
    # Replace all class and id attributes with empty strings
    for tag in soup.find_all():
        if 'class' in tag.attrs:
            tag['class'] = ''
        if 'id' in tag.attrs:
            tag['id'] = ''
        if 'rel' in tag.attrs:
            tag['rel'] = ''
        if 'title' in tag.attrs:
            tag['title'] = ''
        if 'style' in tag.attrs:
            tag['style'] = ''
        if 'href' in tag.attrs and (tag.name == 'a' or tag.name == 'link'):
            tag['href'] = 'www.wikipedia.org'
    
    # Return the cleaned HTML content
    return str(soup)

cleaned_html = [clean_html(i) for i in html_content]

In [7]:
from bs4 import Comment
def remove_footer(html_content):
    # Parse the HTML content
    soup = BeautifulSoup(html_content, 'html.parser')
    
    # Find and remove footer elements
    footer = soup.find('footer')
    if footer:
        footer.decompose()

    for style_tag in soup.find_all('style'):
        style_tag.decompose()

    for script_tag in soup.find_all('script'):
        script_tag.decompose()

    for comment in soup.find_all(text=lambda text: isinstance(text, Comment)):
        comment.extract()
    
    # Separate body and head
    body = soup.body.extract()
    head = soup.head.extract()
    
    return str(head) + '\n' + str(body)

head_and_body = [remove_footer(i) for i in cleaned_html]

  for comment in soup.find_all(text=lambda text: isinstance(text, Comment)):


In [28]:
# Text to tokens conversion
lines = head_and_body[0].split('\n')
str_tokens = [[]]
for line in lines:
    if len(str_tokens[-1] + model.to_str_tokens(line) + ['\n']) >= 512:
        str_tokens[-1] += [model.tokenizer.pad_token] * (511 - len(str_tokens[-1]))
        str_tokens.append([])

    str_tokens[-1] += model.to_str_tokens(line) + ['\n']

str_tokens[-1] += [model.tokenizer.pad_token] * (511 - len(str_tokens[-1]))
tokens = model.to_tokens([''.join(i) for i in str_tokens])

In [29]:
closing_tag_idx = [[i for i, token in enumerate(tokens) if token == '</'] for tokens in str_tokens]

In [30]:
tokens.shape

torch.Size([23, 512])

In [32]:
lines

['<head>',
 '<meta charset="utf-8"/>',
 '<title>Stits SA-7 Sky-Coupe - Wikipedia</title>',
 '',
 '',
 '<link href="www.wikipedia.org" rel=""/>',
 '',
 '<meta content="" name="ResourceLoaderDynamicStyles"/>',
 '<link href="www.wikipedia.org" rel=""/>',
 '<meta content="MediaWiki 1.42.0-wmf.25" name="generator"/>',
 '<meta content="origin" name="referrer"/>',
 '<meta content="origin-when-cross-origin" name="referrer"/>',
 '<meta content="max-image-preview:standard" name="robots"/>',
 '<meta content="telephone=no" name="format-detection"/>',
 '<meta content="https://upload.wikimedia.org/wikipedia/commons/thumb/c/c6/Empire_State_Aerosciences_Museum_-_Glenville%2C_New_York_%288158375346%29.jpg/1200px-Empire_State_Aerosciences_Museum_-_Glenville%2C_New_York_%288158375346%29.jpg" property="og:image"/>',
 '<meta content="1200" property="og:image:width"/>',
 '<meta content="900" property="og:image:height"/>',
 '<meta content="https://upload.wikimedia.org/wikipedia/commons/thumb/c/c6/Empire_Stat

## 1. General closure

Before starting all the experiment, let's spend a moment thinking about the necessary conditions for the model to produce the token `</`, which is very uncommon in natural language. They're two:

1. We are in the context of HTML text,
2. The text inside the tag is about to end.

While the first one is fairly easy to verify and we know there are neurons dedicated to that, the second is much more complex and I expect the model to not always be able to rightly predict the end-tag token. This is similar to predicting the `.` to end a phrase in natural language, it's easy in some cases but very hard in others.

In [33]:
import torch 
from sae_lens import LMSparseAutoencoderSessionloader
from huggingface_hub import hf_hub_download

layer = 8 # pick a layer you want.
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path = path
)
sparse_autoencoder.eval()

(…)2-small_blocks.8.hook_resid_pre_24576.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps


Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.33k [00:00<?, ?B/s]

: 

In [9]:
bs = 4
out = []

for i in range(0, len(tokens), bs):
    with torch.no_grad():
        output = model(tokens[i:i+bs])
    out.append(output)

out = torch.cat(out)

: 

In [8]:
from tqdm import tqdm

# Collect activations
bs = 16

activations = []
for b in tqdm(range(0, len(tokens), bs)):
    tokens_batch = tokens[b:b+bs] # [bs, p]
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens_batch)
    activations.append(torch.cat([cache[f"blocks.{i}.mlp.hook_post"][None, :, -1] for i in range(nl)]))

activations = torch.cat(activations, dim=1) # [nl, n, dm]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/63 [00:00<?, ?it/s]

100%|██████████| 63/63 [00:24<00:00,  2.57it/s]


In [8]:
# Build the classifier (I build one classifier for the last token position for each layer)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
import warnings
from sklearn.exceptions import ConvergenceWarning
import numpy as np

# Hide ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)

lr = LogisticRegression(penalty='l1', C=0.12, max_iter=500, solver='saga')

active_neurons = []
scores = []

for i in range(nl):
    X_train, X_test, y_train, y_test = train_test_split(activations[i].cpu().numpy(), data['label'], test_size=0.2, random_state=42)
    lr.fit(X_train, y_train)
    active_neurons.append(torch.tensor(np.nonzero(lr.coef_[0])[0], device='mps', dtype=torch.int32))
    scores.append((
        f1_score(y_test.values, lr.predict(X_test)),
        precision_score(y_test.values, lr.predict(X_test)),
        recall_score(y_test.values, lr.predict(X_test))
    ))
    print(f"Layer {i}\nP:{scores[-1][1]:.2f}\tR: {scores[-1][2]:.2f}\tF1: {scores[-1][0]:.2f}\nActive neurons: {sum(lr.coef_[0] != 0)}/{lr.coef_.shape[1]}\n")

Layer 0
P:1.00	R: 0.92	F1: 0.96
Active neurons: 3/3072

Layer 1
P:0.99	R: 0.97	F1: 0.98
Active neurons: 3/3072

Layer 2
P:1.00	R: 0.90	F1: 0.95
Active neurons: 3/3072

Layer 3
P:0.95	R: 0.86	F1: 0.90
Active neurons: 4/3072

Layer 4
P:0.93	R: 0.84	F1: 0.88
Active neurons: 3/3072

Layer 5
P:1.00	R: 0.81	F1: 0.90
Active neurons: 2/3072

Layer 6
P:1.00	R: 0.84	F1: 0.91
Active neurons: 3/3072

Layer 7
P:0.99	R: 0.87	F1: 0.93
Active neurons: 3/3072

Layer 8
P:0.99	R: 0.80	F1: 0.88
Active neurons: 4/3072

Layer 9
P:0.95	R: 0.81	F1: 0.88
Active neurons: 2/3072

Layer 10
P:0.99	R: 0.75	F1: 0.85
Active neurons: 2/3072

Layer 11
P:0.95	R: 0.87	F1: 0.91
Active neurons: 3/3072



Only a few neurons are active among more than three thousands and they provide good scores. More specifically, we see that the F1 score decreases for later layers; this is due to the decrease in recall which means that the neurons are more specific. They don't correspond to the general HTML feature but they could represents more precise parts of the code.

A natural question to ask ourself is whether we can touch these neurons to obtain HTML text. We can do that by attacching hooks to the model components. They allow to modify the values of neurons at generation time!

In [9]:
from jaxtyping import Float
from transformer_lens import ActivationCache
from transformer_lens.hook_points import HookPoint
from functools import partial
from transformer_lens import utils

def neurons_ablation_hook(
    x: Float[torch.Tensor, "batch pos d_mlp"],
    hook: HookPoint,
    neurons: list[int],
    act_store: list = None
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    if act_store is not None:
        act_store.append(x[:, -1, neurons].cpu().numpy())
    x[:, -1, neurons] = 0
    return x

In [10]:
print(generate(model, "Title here"))
print(generate(model, "<h1>Title here"))

Title here.

The first thing I noticed about this
<h1>Title here</h1> <p>This is a


In [11]:
text_tokens = model.to_tokens('Title here')
html_tokens = model.to_tokens('<h1>Title here')

text_neurons_values = []
html_neurons_values = []

hooks = [(
    f"blocks.{l}.mlp.hook_post", 
    partial(neurons_ablation_hook, neurons=active_neurons[l], act_store=text_neurons_values)
    ) for l in range(nl)]

with torch.no_grad():
    logits = model.run_with_hooks(text_tokens, fwd_hooks=hooks)

hooks = [(
    f"blocks.{l}.mlp.hook_post", 
    partial(neurons_ablation_hook, neurons=active_neurons[l], act_store=html_neurons_values)
    ) for l in range(nl)]

with torch.no_grad():
    logits = model.run_with_hooks(html_tokens, fwd_hooks=hooks)

In [19]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=3, cols=4, subplot_titles=[f"Layer {i}" for i in range(nl)], shared_yaxes=True)

for i, (text, html) in enumerate(zip(text_neurons_values, html_neurons_values)):
    df = pd.DataFrame({
        'delta': text[0] - html[0],
        'neuron': active_neurons[i].cpu().numpy().astype(str)
    })
    fig.add_trace(go.Bar(x=df['neuron'], y=df['delta'], name='text', marker_color='blue'), row=i//4+1, col=i%4+1)

fig.update_layout(height=1000, width=1000, title_text="Neurons activations")
fig.show()

In [20]:
@torch.no_grad()
def generate_with_hooks(model, prompt, hooks, max_new_tokens=10):
    tokens = model.to_tokens(prompt)
    for _ in range(max_new_tokens):
        output = model.run_with_hooks(tokens, fwd_hooks=hooks)
        next_token = torch.argmax(output[0, -1, :])
        tokens = torch.cat((tokens, next_token.unsqueeze(0).unsqueeze(0)), dim=1)
    return model.tokenizer.decode(tokens[0], skip_special_tokens=True)

In [21]:
hooks = [(
    f"blocks.{l}.mlp.hook_post", 
    partial(neurons_ablation_hook, neurons=active_neurons[l])
    ) for l in range(nl)]

generate_with_hooks(model, "Title here", hooks)

'Title here.\n\nThis is a very good book.'

In [22]:
def neurons_patching_hook(
    x: Float[torch.Tensor, "batch pos d_mlp"],
    hook: HookPoint,
    neurons: list[int],
    patch: torch.Tensor,
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    x[:, -1, neurons] = patch[None]
    return x

In [23]:
text_idx = torch.tensor((data['label'] == 0).values, dtype=bool)
html_idx = torch.tensor((data['label'] == 1).values, dtype=bool)

In [24]:
mean_text_activations_values = [activations[l, text_idx][:, active_neurons[l]].mean(0) for l in range(nl)]
mean_html_activations_values = [activations[l, html_idx][:, active_neurons[l]].mean(0) for l in range(nl)]

In [27]:
fig = make_subplots(rows=3, cols=4, subplot_titles=[f"Layer {i}" for i in range(nl)], shared_yaxes=True)

for i, (text, html) in enumerate(zip(mean_text_activations_values, mean_html_activations_values)):
    df = pd.DataFrame({
        'delta': text.cpu() - html.cpu(),
        'neuron': active_neurons[i].cpu().numpy().astype(str)
    })
    fig.add_trace(go.Bar(x=df['neuron'], y=df['delta'], name='text', marker_color='blue'), row=i//4+1, col=i%4+1)

fig.update_layout(height=1000, width=1000, title_text="Neurons activations")
fig.show()

In [28]:
hooks = [(
    f"blocks.{l}.mlp.hook_post", 
    partial(neurons_patching_hook, neurons=active_neurons[l], patch=mean_text_activations_values[l])
    ) for l in range(nl)]

generate_with_hooks(model, "<p>Title here", hooks, max_new_tokens=10)

'<p>Title here</p>TheWrap</p>The'

### Activation Patching

In  this part I'll explore the 

In [29]:
from jaxtyping import Float
from transformer_lens import ActivationCache
from transformer_lens.hook_points import HookPoint
from functools import partial
from transformer_lens import utils

def residual_stream_patching_hook(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    position: int,
    clean_cache: ActivationCache
) -> Float[torch.Tensor, "batch pos d_model"]:
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre

def logits_to_logit_diff(model, logits, correct_answer, incorrect_answer):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

@torch.no_grad()
def activation_patching(model, clean_prompt, corrupted_prompt, correct_answer, incorrect_answer, component="hook_resid_post", run_corrupted=True):

    # By default this function runs the corrupted prompt and substitute the clean activations!
    if not run_corrupted:
        clean_prompt, corrupted_prompt = corrupted_prompt, clean_prompt
        correct_answer, incorrect_answer = incorrect_answer, correct_answer

    clean_tokens = model.to_tokens(clean_prompt)
    corrupted_tokens = model.to_tokens(corrupted_prompt)

    # We run on the clean prompt with the cache so we store activations to patch in later.
    clean_logits, clean_cache = model.run_with_cache(clean_tokens, device="mps")
    clean_logit_diff = logits_to_logit_diff(model, clean_logits, correct_answer=correct_answer, incorrect_answer=incorrect_answer)
    print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

    # We don't need to cache on the corrupted prompt.
    corrupted_logits = model(corrupted_tokens)
    corrupted_logit_diff = logits_to_logit_diff(model, corrupted_logits, correct_answer=correct_answer, incorrect_answer=incorrect_answer)
    print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

    # Run patching
    print("Running patching...")
     # We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
    num_positions = len(clean_tokens[0])
    patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

    for layer in tqdm(range(model.cfg.n_layers)):
        for position in range(num_positions):
            # Use functools.partial to create a temporary hook function with the position fixed
            temp_hook_fn = partial(residual_stream_patching_hook, position=position, clean_cache=clean_cache)
            # Run the model with the patching hook
            patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
                (f"blocks.{layer}.{component}", temp_hook_fn)
            ])
            # Calculate the logit difference
            patched_logit_diff = logits_to_logit_diff(model, patched_logits, correct_answer, incorrect_answer).detach()
            # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
            patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)

    return patching_result

In [30]:
clean_prompt = "\nTitle here"
corrupted_prompt = "<h1>Title here"

In [31]:
mlp_patching = activation_patching(model, clean_prompt, corrupted_prompt, ".", "</", component="hook_mlp_out")
attn_patching = activation_patching(model, clean_prompt, corrupted_prompt, ".", "</", component="hook_attn_out")
rs_patching = activation_patching(model, clean_prompt, corrupted_prompt, ".", "</", component="hook_resid_post")

Clean logit difference: 9.583
Corrupted logit difference: -4.099
Running patching...


100%|██████████| 12/12 [00:02<00:00,  4.15it/s]


Clean logit difference: 9.583
Corrupted logit difference: -4.099
Running patching...


100%|██████████| 12/12 [00:02<00:00,  4.08it/s]


Clean logit difference: 9.583
Corrupted logit difference: -4.099
Running patching...


100%|██████████| 12/12 [00:02<00:00,  4.35it/s]


In [32]:
model.to_str_tokens(clean_prompt)

['<|endoftext|>', '\n', 'Title', ' here']

In [33]:
import plotly.express as px
from plotly.subplots import make_subplots

fig = make_subplots(rows=1, cols=3, subplot_titles=("MLP", "Attention", "Residual Stream"))

fig.add_trace(px.imshow(
    mlp_patching.cpu().numpy()[:, 1:],
    x=model.to_str_tokens(clean_prompt)[1:],
    y=[f"Layer {i}" for i in range(model.cfg.n_layers)], 
    title="Patching result", width=800, height=600).data[0], row=1, col=1)

fig.add_trace(px.imshow(
    attn_patching.cpu().numpy()[:, 1:],
    x=model.to_str_tokens(clean_prompt)[1:],
    y=[f"Layer {i}" for i in range(model.cfg.n_layers)],
    title="Patching result", width=800, height=600).data[0], row=1, col=2)

fig.add_trace(px.imshow(
    rs_patching.cpu().numpy()[:, 1:],
    x=model.to_str_tokens(clean_prompt)[1:],
    y=[f"Layer {i}" for i in range(model.cfg.n_layers)],
    title="Patching result", width=800, height=600).data[0], row=1, col=3)

fig.update_layout(coloraxis=dict(colorscale='Blues'))
fig.update_traces(zmin=-1, zmax=1)
fig.update_layout(title_text="Patching results")
fig.show()

## 2. Retrieving the right tag

In [53]:
# Building the dataset
code_split = []
for c in code:
    split = c.split('</')
    split = ['</'.join(split[:i+1]) for i in range(len(split))]
    code_split.extend(split)

In [None]:
tqdm.pandas()
(pd.Series(code_split).apply(lambda x: x.replace('<!DOCTYPE html>\n<html lang="en">\n', '')).progress_apply(lambda x: generate(model, x, max_new_tokens=1)[len(x):]) == '</').mean()

### Sparse autoencoders