<a href="https://colab.research.google.com/github/ducndh/refusal-orthogonal-vector/blob/main/Attention_Initialization_Experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
try:
    import google.colab
    IN_COLAB = True
    !pip install einops
    !pip install git+https://github.com/neelnanda-io/Easy-Transformer@no-position-experiment
except:
    IN_COLAB = False

from easy_transformer import EasyTransformer, EasyTransformerConfig
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio
# Only force the Colab renderer when actually running on Colab; elsewhere keep
# Plotly's own default renderer so figures still display (e.g. local Jupyter).
if IN_COLAB:
    pio.renderers.default = "colab"
import tqdm.auto as tqdm
import einops
from easy_transformer.utils import to_numpy

Collecting git+https://github.com/neelnanda-io/Easy-Transformer@no-position-experiment
  Cloning https://github.com/neelnanda-io/Easy-Transformer (to revision no-position-experiment) to /tmp/pip-req-build-7o4yx1qj
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/Easy-Transformer /tmp/pip-req-build-7o4yx1qj
  Running command git checkout -b no-position-experiment --track origin/no-position-experiment
  Switched to a new branch 'no-position-experiment'
  Branch 'no-position-experiment' set up to track remote branch 'no-position-experiment' from 'origin'.
  Resolved https://github.com/neelnanda-io/Easy-Transformer to commit 2b91e7ec49854f74363e5699ae4f8f99a595e65c
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets (from easy_transformer==0.1.0)
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting wandb (from easy_transformer==0.1.0)
  Downloading wandb-0.18.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_6

Some plotting code. Wrappers around Plotly, not important to understand.

In [2]:
def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    """Draw `tensor` as a Plotly line chart and show it.

    `tensor` is converted with `to_numpy` and passed to `px.line`. `yaxis` and
    `xaxis` are supplied as the "y" / "x" entries of the `labels` mapping, and
    any extra keyword arguments are forwarded to `px.line` unchanged. When
    `line_labels` is given (and non-empty), the i-th trace of the figure is
    renamed to the i-th label.
    """
    fig = px.line(
        to_numpy(tensor),
        labels={"y": yaxis, "x": xaxis},
        **kwargs
    )
    if line_labels:
        for trace_idx, trace_name in enumerate(line_labels):
            fig.data[trace_idx].name = trace_name
    fig.show()

def imshow(tensor, yaxis="", xaxis="", **kwargs):
    """Show `tensor` as a Plotly heatmap.

    Defaults are an "RdBu" colour scale with its midpoint at 0.0 and axis
    labels taken from `xaxis` / `yaxis`. Any keyword argument passed by the
    caller overrides the matching default and is forwarded to `px.imshow`.
    """
    defaults = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    # Caller-supplied options win over the defaults.
    options = {**defaults, **kwargs}
    fig = px.imshow(to_numpy(tensor), **options)
    fig.show()

# Model Training

## Setup

### Define data + Loss function

In [73]:
def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):
    """Yield an endless stream of random token batches.

    Every batch is an integer tensor of shape (batch_size, cfg.n_ctx) whose
    entries are drawn from [1, cfg.d_vocab). When `incl_bos_token` is true,
    column 0 of each batch is overwritten with token 0.

    Because this is a generator, torch's global RNG is seeded with `seed`
    only when the first batch is requested, not when the generator is created.
    """
    torch.manual_seed(seed)
    while True:
        batch_shape = (batch_size, cfg.n_ctx)
        batch = torch.randint(low=1, high=cfg.d_vocab, size=batch_shape)
        if incl_bos_token:
            # Token 0 is never sampled above, so it is free to mark the start.
            batch[:, 0] = 0
        yield batch
# Quick look at one batch of two sequences from the generator.
# NOTE(review): `cfg` is first defined in a later cell ("Defining the Model"),
# so unless it is defined earlier this raises NameError on Restart & Run All —
# confirm the intended cell order.
data_generator = make_data_generator(cfg, 2)
print(next(data_generator))

tensor([[  0,  93,  34, 155, 274, 116, 114, 248,  68,   3, 298,  83, 194,  20,
           8, 133,  32,  66,  62,  73, 210, 273,  46, 243, 104, 232, 161, 125,
         123, 251,   7,   4, 115, 127,  21,   1,  89, 142,   6,  15, 298, 251,
          88, 229, 108, 114,  23,  88,   3, 265],
        [  0, 118,  46, 274, 105, 268, 131,  35,  19,  58, 226, 278,  27,  25,
         276, 180, 164,   4,  95,  27,  74, 201, 105,  65,  80, 185,  44, 258,
         105,  60,  58,  47, 126,  60, 294, 253, 258, 136,  29, 101, 258,  77,
          80, 180, 159, 169, 122, 117,  27, 194]])


In [74]:
def loss_fn(logits, tokens, return_per_token=False):
    """Negative log-likelihood of the *previous* token.

    The logits at position i (for i >= 1) are scored against the token at
    position i - 1; the logits at position 0 and the final token are unused.

    Args:
        logits: tensor of shape [batch, pos, vocab].
        tokens: integer tensor of shape [batch, pos].
        return_per_token: if true, return the per-position losses of shape
            [batch, pos - 1] instead of their mean.

    Returns:
        The per-position loss tensor, or its scalar mean.
    """
    shifted_logits = logits[:, 1:]
    targets = tokens[:, :-1]
    log_probs = torch.log_softmax(shifted_logits, dim=-1)
    # Pick out the log-probability assigned to each target token.
    target_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    if return_per_token:
        return -target_log_probs
    return -target_log_probs.mean()

In [75]:
# Test the loss function works: make the logits at position p strongly favour
# token p - 1, so the loss against the previous token should be close to zero.
test_tokens = torch.arange(5)[None, :]
test_logits = torch.randn(1, 5, 10)
for pos in range(1, 5):
    test_logits[:, pos, pos - 1] = 10.
# Per-token losses first, then the mean.
for per_token in (True, False):
    print(loss_fn(test_logits, test_tokens, return_per_token=per_token))

tensor([[0.0004, 0.0003, 0.0031, 0.0005]])
tensor(0.0011)


### Defining the Model

In [206]:
# Architecture of the small transformer used throughout this notebook.
cfg = EasyTransformerConfig(
    n_layers=2,
    n_heads=4,
    d_head=16,
    d_model=64,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LN",
)
model = EasyTransformer(cfg)

Moving model to device:  cpu


In [207]:
# Inspect the parameters of the first transformer block.
first_block = model.blocks[0]
first_block.state_dict()

OrderedDict([('ln1.w',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])),
             ('ln1.b',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])),
             ('ln2.w',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1

### Model Editing

In [208]:
# Make heads 1 .. n_heads_duplicate-1 exact copies of head 0 (Q, K, V and O
# weights) in every layer of [layer_start, layer_end).
n_heads_duplicate = 2
layer_start, layer_end = 0, 2
for layer in range(layer_start, layer_end):
    attn = model.blocks[layer].attn
    for weight_name in ("W_Q", "W_K", "W_V", "W_O"):
        # Write through `.data` so the copy is not recorded by autograd.
        weight = getattr(attn, weight_name).data
        for head in range(1, n_heads_duplicate):
            weight[head] = weight[0]
# Check: head 0 and the last duplicated head now have identical W_Q.
model.blocks[layer_start].attn.W_Q[0] - model.blocks[layer_start].attn.W_Q[n_heads_duplicate-1]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [209]:
# Snapshot head 0's W_Q in the first edited layer before training, for later
# comparison. `detach()` keeps the snapshot out of the autograd graph, so it is
# a plain copy of the values rather than a node tied to the live parameter.
old_weight = model.blocks[layer_start].attn.W_Q[0].detach().clone()
old_weight

tensor([[ 0.0253,  0.2366, -0.1061,  ..., -0.0652,  0.0280, -0.2112],
        [ 0.0003, -0.0793, -0.0032,  ...,  0.0725, -0.1783, -0.0779],
        [ 0.0454, -0.0644,  0.0376,  ...,  0.0997,  0.2580,  0.0604],
        ...,
        [-0.0304, -0.0751, -0.0753,  ..., -0.0020, -0.0477,  0.0103],
        [ 0.0294,  0.0202,  0.0554,  ..., -0.0466,  0.0026,  0.0712],
        [ 0.2720, -0.0647, -0.1073,  ...,  0.0254,  0.0256,  0.0658]],
       grad_fn=<CloneBackward0>)

In [151]:
# Overwrite every entry of block 0's MLP input and output weights with the
# constant 0.1, then display the MLP's parameters.
# NOTE(review): this cell's recorded execution count (151) is lower than that
# of the model-construction cell (206), so the logged training run may not
# include this edit — confirm whether it is meant to run.
first_mlp = model.blocks[0].mlp
for weight in (first_mlp.W_in, first_mlp.W_out):
    weight.data[:] = 0.1
first_mlp.state_dict()

OrderedDict([('W_in',
              tensor([[0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
                      [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
                      [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
                      ...,
                      [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
                      [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
                      [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000]])),
             ('b_in',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

### Setup Optimizer


In [210]:
# Optimisation hyperparameters.
batch_size = 256
num_epochs = 100  # each "epoch" is one optimizer step on a fresh batch
lr = 1e-2
betas = (0.9, 0.99)
# max_grad_norm = 1.0
wd = 0.01

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)
# Linear warm-up: the LR multiplier grows as i/100 and is capped at 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda i: min(i/100, 1.),
)

data_loader = make_data_generator(cfg, batch_size)

## Model Training

In [211]:
# Train for num_epochs steps, one freshly generated batch per step, recording
# the loss of every step.
losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    tokens = next(data_loader)
    # tokens = tokens.cuda()
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss_value = loss.item()
    losses.append(loss_value)

    # Gradient step, then clear the gradients and advance the LR schedule.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}: {loss_value}')
px.line(losses, labels={"x":"Epoch", "y":"Loss"})

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0: 5.993844985961914
Epoch 10: 5.779531002044678
Epoch 20: 5.600656986236572
Epoch 30: 5.554670810699463
Epoch 40: 4.964794158935547
Epoch 50: 2.237194061279297
Epoch 60: 0.2767258584499359
Epoch 70: 0.030183283612132072
Epoch 80: 0.007830281741917133
Epoch 90: 0.0035802454221993685


In [212]:
# How much head 0's W_Q in the first edited layer changed relative to the
# snapshot taken before training.
new_weight = model.blocks[layer_start].attn.W_Q[0]
weight_delta = new_weight - old_weight
weight_delta

tensor([[ 0.0334,  0.1961, -0.1751,  ...,  0.2123, -0.1080, -0.1708],
        [-0.1812, -0.1723,  0.1234,  ..., -0.1161,  0.0986, -0.1645],
        [ 0.1388, -0.0331, -0.0881,  ...,  0.0861,  0.0180, -0.1824],
        ...,
        [-0.1773, -0.0933, -0.1252,  ..., -0.0760, -0.1504,  0.1488],
        [-0.0458,  0.1069, -0.0401,  ..., -0.1764,  0.0539,  0.0773],
        [ 0.0998, -0.0449,  0.0050,  ..., -0.0018,  0.0317, -0.0397]],
       grad_fn=<SubBackward0>)

In [213]:
# Difference between W_Q of head 0 and of the last duplicated head in the
# first edited layer; all zeros means the two heads are still identical.
first_layer_w_q = model.blocks[layer_start].attn.W_Q
first_layer_w_q[0] - first_layer_w_q[n_heads_duplicate-1]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [214]:
# Same comparison for the last edited layer: W_Q of head 0 minus W_Q of the
# last duplicated head.
last_layer_w_q = model.blocks[layer_end-1].attn.W_Q
last_layer_w_q[0] - last_layer_w_q[n_heads_duplicate-1]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SubBackward0>)

In [None]:
# torch.save(model.state_dict(), "no_pos_experiment_state_dict_v0.pth")