# No position transformer experiment

## Setup

In [1]:
%pip install setuptools transformer_lens matplotlib ipywidgets plotly ipykernel nbformat

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
from transformer_lens import EasyTransformer, EasyTransformerConfig
from transformer_lens.utils import to_numpy
import tqdm.auto as tqdm
import plotly.express as px

## Defining the model

In [3]:
torch.set_default_device("mps")
cfg = EasyTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
)
model = EasyTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

### Ablation

We are deactivating the position embedding to see how it affects the attention mechanism and next token prediction by the transformer.

In [4]:
def deactivate_position(model: EasyTransformer):
    model.pos_embed.W_pos.data[:] = 0.
    model.pos_embed.W_pos.requires_grad = False

deactivate_position(model)

### Plotting utilities

In [5]:
def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    labels = {"y":yaxis, "x":xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()
    
def imshow(tensor, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    plot_kwargs = {"color_continuous_scale":"RdBu", "color_continuous_midpoint":0.0, "labels":{"x":xaxis, "y":yaxis}}
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

## Define synthetic data generator and loss function

In [6]:
def data_generator(cfg: EasyTransformerConfig, seed: int, batch_size: int, inc_bos_token: bool=True):
    torch.manual_seed(seed)
    while True:
        x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))
        if inc_bos_token:
            x[:,0] = 0
        yield x

In [7]:
def loss_fn(logits, tokens, return_per_token=False):
    # logits: [batch, position, d_vocab]
    # tokens: [batch, position]
    logits = logits[:, 1:]
    tokens = tokens[:, :-1]
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if return_per_token:
        return -correct_log_probs
    return -correct_log_probs.mean()

In [8]:
test_tokens = torch.arange(5)[None, :]
test_logits = torch.randn(1, 5, 10)
test_logits[:,1,0] = 10.
test_logits[:,2,1] = 10.
test_logits[:,3,2] = 10.
test_logits[:,4,3] = 10.

print(loss_fn(test_logits, test_tokens, True))
print(loss_fn(test_logits, test_tokens, False))

tensor([[0.0005, 0.0004, 0.0007, 0.0005]], device='mps:0')
tensor(0.0005, device='mps:0')


## Model training

In [9]:
batch_size = 256
epochs = 4000
lr = 1e-4
max_grad_norm = 1.0
wd = 0.1
betas = (0.9, 0.95)
optim = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda i: max(i/100, 1.))

data_loader = data_generator(cfg=cfg, seed=137, batch_size=batch_size)

In [10]:
losses = []
model.zero_grad()
for epoch in tqdm.tqdm(range(epochs)):
    tokens = next(data_loader)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optim.step()
    optim.zero_grad()
    scheduler.step()
    losses.append(loss.item())
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}: {loss.item()}")

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch: 0: 6.0289130210876465
Epoch: 100: 5.647429466247559
Epoch: 200: 5.4475250244140625
Epoch: 300: 5.0785651206970215
Epoch: 400: 4.129546642303467
Epoch: 500: 2.7287404537200928
Epoch: 600: 1.4616104364395142
Epoch: 700: 0.6579561829566956
Epoch: 800: 0.27724456787109375
Epoch: 900: 0.12798823416233063
Epoch: 1000: 0.07009649276733398
Epoch: 1100: 0.06959736347198486
Epoch: 1200: 0.05661126598715782
Epoch: 1300: 0.0629037469625473
Epoch: 1400: 0.03376340866088867
Epoch: 1500: 0.027074169367551804
Epoch: 1600: 0.02489548549056053
Epoch: 1700: 0.01888277754187584
Epoch: 1800: 0.02196997031569481
Epoch: 1900: 0.015225996263325214
Epoch: 2000: 0.047929421067237854
Epoch: 2100: 0.032715655863285065
Epoch: 2200: 0.02097431942820549
Epoch: 2300: 0.03473001718521118
Epoch: 2400: 0.0059524280950427055
Epoch: 2500: 0.007289866916835308
Epoch: 2600: 0.021471310406923294
Epoch: 2700: 0.015927957370877266
Epoch: 2800: 0.012854551896452904
Epoch: 2900: 0.025960108265280724
Epoch: 3000: 0.0103785

In [11]:
px.line(losses, labels={"x": "Epoch", "y": "Loss"})

## Model interpretability

In [12]:
model.pos_embed.W_pos.norm()

tensor(0., device='mps:0')

### Look at attention patterns

In [13]:
big_data_loader = data_generator(cfg, 137, 4000)
big_tokens = next(big_data_loader)
logits, cache = model.run_with_cache(big_tokens)
print(f"Logits: {loss_fn(logits, big_tokens).item()}")

Logits: 0.00855222623795271


In [14]:
batch_index = 0
tokens = big_tokens[batch_index]
imshow(cache["attn", 0].mean([0, 1]), title="Layer 0 Attention Pattern", height=500, width=500)
imshow(cache["attn", 1].mean([0, 1]), title="Layer 1 Attention Pattern", height=500, width=500)

### Look at how different bits of the model directly contribute to the logits

In [15]:
resid_components = [cache["embed"], cache["attn_out", 0], cache["mlp_out", 0], cache["attn_out", 1], cache["mlp_out", 1]]
labels = ["E", "A0", "M0", "A1", "M1"]
resid_stack = torch.stack(resid_components, 0)
resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)
print(resid_stack.shape)

torch.Size([5, 4000, 50, 64])


In [16]:
batch_index = 0
fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U
logit_components = resid_stack[:,batch_index] @ fold_W_U/cache["ln_final.hook_scale"][batch_index]
print(logit_components.shape)

torch.Size([5, 50, 300])


In [17]:
logit_components = logit_components - logit_components.mean(-1, keepdim=True)
data = logit_components[:, torch.arange(1, model.cfg.n_ctx), tokens[:-1]]
line(data.T, line_labels=labels)

## Analysis

### Folding in layer norm

### Understand Attention layer 0

In [18]:
analysis_cfg = EasyTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LNPre",
    init_weights=False,
)
analysis_model = EasyTransformer(analysis_cfg)
state_dict = model.state_dict()
analysis_model.load_and_process_state_dict(state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True)
deactivate_position(analysis_model)

In [19]:
logits, analysis_logits = model(big_tokens), analysis_model(big_tokens)

print(loss_fn(analysis_logits, big_tokens))
print(loss_fn(logits, big_tokens))

tensor(0.0086, device='mps:0', grad_fn=<NegBackward0>)
tensor(0.0086, device='mps:0', grad_fn=<NegBackward0>)


In [20]:
QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T # analysis_model.W_E @ analysis_model.blocks[0].attn.W_Q[0] @ analysis_model.blocks[0].attn.W_K[0].T @ analysis_model.W_E.T

imshow(QK, xaxis="Key", yaxis="Query", title="Full QK circuit for attn 0")

In [21]:
OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0] #analysis_model.W_E @ analysis_model.blocks[0].attn.W_V[0] @ analysis_model.blocks[0].attn.W_O[0].T @ analysis_model.blocks[0].mlp.W_in

imshow(OV, xaxis="Neuron", yaxis="Vocab", title="Full OV circuit for attn 0")

### Understand MLP from layer 0

### Understand Attention layer 1

### Understand MLP from layer 1