## Set up and import 

In [101]:
import io
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM # AutoModelForSeq2SeqLM, AutoModelForMaskedLM, AutoModelForCausalLM, 

# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU so
# the notebook can still be executed on a machine without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
# Read the Jigsaw toxic-comment training split from the `data/jigsaw_toxic`
# folder that sits next to the current working directory, then preview it.
train_csv_path = f'{os.path.dirname(os.getcwd())}/data/jigsaw_toxic/train.csv'
df = pd.read_csv(train_csv_path)
df.head(10)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0
5,00025465d4725e87,"""\n\nCongratulations from me as well, use the ...",0,0,0,0,0,0
6,0002bcb3da6cb337,COCKSUCKER BEFORE YOU PISS AROUND ON MY WORK,1,1,1,0,1,0
7,00031b1e95af7921,Your vandalism to the Matt Shirvington article...,0,0,0,0,0,0
8,00037261f536c51d,Sorry if the word 'nonsense' was offensive to ...,0,0,0,0,0,0
9,00040093b2687caa,alignment on this subject and which are contra...,0,0,0,0,0,0


## Config

In [3]:
# Number of comments taken from each class.
n = 1000

# Filter each class once; the original re-filtered the whole frame on every
# iteration of both comprehensions.
toxic_comments = df[df['toxic'] == 1]['comment_text']
non_toxic_comments = df[df['toxic'] == 0]['comment_text']

# Each entry is the i-th comment of its class rendered with
# `Series.to_string(index=False)`, exactly as before.
# NOTE(review): to_string goes through pandas' text formatter rather than
# returning the raw cell value -- confirm this is the intended text.
toxic = [toxic_comments.iloc[i:i + 1].to_string(index=False) for i in range(n)]
non_toxic = [non_toxic_comments.iloc[i:i + 1].to_string(index=False) for i in range(n)]

In [4]:
# Checkpoint name and layer index. `layer` is interpolated into the
# "layer{layer}" sub-folder name when the saved hidden states are loaded.
# NOTE: the hidden-state loading loop reuses `model_name` as its loop
# variable, so the value set here is overwritten once that cell has run.
model_name = "eleuther-pythia2.8b-hh-sft"
layer = -1


## Load hidden states for base, sft and dpo

In [5]:
# Load the saved hidden states of the three checkpoints (base, SFT, DPO) and
# stack them into one feature matrix `x` with a matching label vector `y`.
# Within each checkpoint the non-toxic rows come first, then the toxic rows.
# Labels are 2*idx for non-toxic and 2*idx+1 for toxic, where idx is the
# checkpoint's position in the list below.
x = []
y = []
for idx, model_name in enumerate(["pythia-2.8b","eleuther-pythia2.8b-hh-sft","eleuther-pythia2.8b-hh-dpo"]):
  model_SAVEFOLD0 = f"{os.path.dirname(os.getcwd())}/outputs/{model_name}"
  model_SAVEFOLD = f"{model_SAVEFOLD0}/layer{layer}/"

  toxic_f = f"{model_SAVEFOLD}toxic_hs.npy"
  non_toxic_f = f"{model_SAVEFOLD}non_toxic_hs.npy"
  toxic_hs = np.load(toxic_f, mmap_mode = 'r')
  non_toxic_hs = np.load(non_toxic_f, mmap_mode = 'r')

  x.append(np.concatenate([non_toxic_hs, toxic_hs], 0))
  # Size the labels from the arrays that were actually loaded (not from the
  # text lists built in another cell) so that x and y always stay aligned.
  y.append(np.concatenate([np.full(len(non_toxic_hs), 2*idx), np.full(len(toxic_hs), 2*idx+1)], 0))

# Concatenate however many checkpoints were loaded, in loading order.
x = np.concatenate(x, 0)
y = np.concatenate(y, 0)

In [6]:
# First row of the stacked hidden states (the first checkpoint loaded, pythia-2.8b).
x[0]

array([ 0.34286177,  0.72510064,  1.0699983 , ...,  0.29686835,
        1.97181   , -0.20555224], dtype=float32)

In [7]:
# First row of the second checkpoint's block. The offset is written as 2*n
# (instead of the literal 2000) so it follows the configured sample size;
# assumes every checkpoint contributes n non-toxic + n toxic rows -- TODO confirm.
x[2 * n]

array([ 0.51742744,  0.92089754,  1.2923046 , ...,  0.40341443,
        2.0513492 , -0.16610062], dtype=float32)

In [8]:
# First row of the third checkpoint's block. The offset is written as 4*n
# (instead of the literal 4000) so it follows the configured sample size;
# assumes every checkpoint contributes n non-toxic + n toxic rows -- TODO confirm.
x[4 * n]

array([ 0.255926  ,  1.0711516 ,  1.2434642 , ...,  0.5368451 ,
        2.0956845 , -0.24333812], dtype=float32)

In [9]:
def _rowwise_dot_and_cosine(a, b):
    """Return the dot product and cosine similarity of matching rows.

    Parameters
    ----------
    a, b : 2-D arrays with the same shape; row i of `a` is compared with
        row i of `b`.

    Returns
    -------
    (dot, sim) : two 1-D float64 arrays with one entry per row.
    """
    # Row-wise dot product without materialising the full a @ b.T matrix.
    dot = np.einsum("ij,ij->i", a, b).astype(np.float64)
    sim = dot / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
    return dot, sim


direction = 0  # not read in this cell; kept in case another cell relies on it

# `x` stacks three equally sized blocks in loading order (base, SFT, DPO), so
# row i of each block refers to the same input text. Splitting into thirds
# keeps the row count and the offsets consistent with each other (the original
# took the count from len(x)/3 but the offsets from 2*n and 4*n).
base_hs, sft_hs, dpo_hs = np.split(x, 3)

base_sft_dot, base_sft_sim = _rowwise_dot_and_cosine(base_hs, sft_hs)
base_dpo_dot, base_dpo_sim = _rowwise_dot_and_cosine(base_hs, dpo_hs)
sft_dpo_dot, sft_dpo_sim = _rowwise_dot_and_cosine(sft_hs, dpo_hs)

In [10]:
# Mean cosine similarity per checkpoint pair, in the order
# base/SFT, base/DPO, SFT/DPO.
for pair_sim in (base_sft_sim, base_dpo_sim, sft_dpo_sim):
    print(np.mean(pair_sim))

0.998023691110947
0.996511646943349
0.9972572725990647


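They're almost exactly the same: the mean cosine similarity between corresponding hidden states is above 0.996 for every pair of models (base vs. SFT, base vs. DPO, SFT vs. DPO).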

## Transformer lens

Forked TransformerLens and added functionality for HF lomahony pythia models. 

git clone https://github.com/lomahony/TransformerLens

cd TransformerLens

pip3 install -e .

In [11]:
!pip list

Package                   Version              Editable project location
------------------------- -------------------- ---------------------------
absl-py                   1.4.0
accelerate                0.21.0
aiohttp                   3.8.4
aiosignal                 1.3.1
anthropic                 0.3.11
antlr4-python3-runtime    4.9.3
anyio                     3.7.1
apex                      0.1
appdirs                   1.4.4
argon2-cffi               21.3.0
argon2-cffi-bindings      21.2.0
arrow                     1.2.3
asttokens                 2.2.1
async-lru                 2.0.4
async-timeout             4.0.2
attrs                     23.1.0
Babel                     2.12.1
backcall                  0.2.0
beartype                  0.14.1
beautifulsoup4            4.12.2
best-download             0.0.9
bleach                    6.0.0
blessed                   1.20.0
cachetools                5.3.1
certifi                   2019.11.28
cffi                      1.15.1
chardet

In [12]:
import sys
# Make the TransformerLens checkout importable: it is looked up in a
# "TransformerLens" folder two directory levels above the working directory.
sys.path.append(f'{os.path.dirname( os.path.dirname(os.getcwd()))}/TransformerLens')

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from fancy_einsum import einsum
from torchtyping import TensorType as TT
from typing import List, Optional, Callable, Tuple, Union
import functools
from tqdm import tqdm
from IPython.display import display

from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens 
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f93659343d0>

In [97]:
# Tokenizer of the base Pythia-70m checkpoint.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

# TransformerLens wrappers for the base, SFT and DPO checkpoints,
# loaded one after another in that order.
tl_model_base, tl_model_sft, tl_model_dpo = [
    transformer_lens.HookedTransformer.from_pretrained(checkpoint)
    for checkpoint in (
        "pythia-70m",
        "lomahony/eleuther-pythia70m-hh-sft",
        "lomahony/eleuther-pythia70m-hh-dpo",
    )
]


Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m into HookedTransformer


Using pad_token, but it is not set yet.


Loaded pretrained model lomahony/eleuther-pythia70m-hh-sft into HookedTransformer


Using pad_token, but it is not set yet.


Loaded pretrained model lomahony/eleuther-pythia70m-hh-dpo into HookedTransformer


In [None]:
# hf_model_sft = GPTNeoXForCausalLM.from_pretrained("lomahony/eleuther-pythia70m-hh-sft").to("cpu")
# tl_model_sft = HookedTransformer.from_pretrained("lomahony/eleuther-pythia70m-hh-sft",
#     hf_model=hf_model_sft,
#     device="cpu",
#     tokenizer=tokenizer,
# ).to("cuda" if torch.cuda.is_available() else "cpu")

# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
#     model = HookedTransformer.from_pretrained(
#     name,
#     hf_model=hf_model,
#     fold_value_biases=False,
#     fold_ln=False,
#     tokenizer=tokenizer,
#     n_device=2,
#     move_to_device=False,
#     center_writing_weights=False,
# )

## Model weights

In [20]:
# Optional Hugging Face cache directory (None selects the library default).
cache_dir = None

# The base checkpoint's tokenizer is the only one loaded in this cell; the
# SFT/DPO tokenizer loads were already commented out.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m", cache_dir=cache_dir)

# Load the three Hugging Face checkpoints and move each to the configured
# `device` instead of a hard-coded `.cuda()`, so the device setting is honoured.
model_base = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m", cache_dir=cache_dir)
model_base.to(device)

model_sft = GPTNeoXForCausalLM.from_pretrained("lomahony/eleuther-pythia70m-hh-sft", cache_dir=cache_dir)
model_sft.to(device)

model_dpo = GPTNeoXForCausalLM.from_pretrained("lomahony/eleuther-pythia70m-hh-dpo", cache_dir=cache_dir)
# Last expression: moves the model and displays its module tree.
model_dpo.to(device)

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attention): GPTNeoXAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (embed_out): Linear(in_features=512, out_features=50304, bias=False)
)

### Look at how much the layers and parameters change on average

In [None]:
# Print type, name and size of every parameter of the base model that
# requires gradients, followed by how many such parameters there are.
i = 0
for name, param in model_base.named_parameters():
    if not param.requires_grad:
        continue
    print(type(param), name, param.size())
    i += 1
print(i)

In [None]:
# Display the second-to-last parameter tensor of the base model.
list(model_base.parameters())[-2]

In [None]:
def _mean_param_diffs(model_a, model_b):
    """Return the mean of (a - b) for every pair of matching parameter tensors.

    Parameters are paired by position in `model_a.parameters()` and
    `model_b.parameters()`; iteration stops at the shorter of the two.

    NOTE(review): this is a signed mean, so positive and negative changes
    cancel each other out; a mean absolute difference may be closer to
    "how much the parameters change" -- confirm which one is wanted.
    """
    return [
        np.mean((p_a - p_b).detach().cpu().numpy())
        for p_a, p_b in zip(model_a.parameters(), model_b.parameters())
    ]


# One entry per parameter tensor (76 per model, per the original note).
# Iterating the parameters once avoids rebuilding list(model.parameters())
# twice on every loop iteration.
base_sft_diff = _mean_param_diffs(model_base, model_sft)
base_dpo_diff = _mean_param_diffs(model_base, model_dpo)
sft_dpo_diff = _mean_param_diffs(model_sft, model_dpo)

In [None]:
# Mean weight difference per parameter tensor for each pair of checkpoints.
fig, ax = plt.subplots()
ax.plot(base_sft_diff, label="base_sft_diff")
ax.plot(base_dpo_diff, label="base_dpo_diff")
ax.plot(sft_dpo_diff, label="sft_dpo_diff")
ax.set_xlabel("parameter tensor index")
ax.set_ylabel("mean weight difference")
ax.set_title("Mean per-tensor weight difference between checkpoints")
ax.legend();

In [None]:
# Rank the parameter tensors by |mean difference| between base and SFT and
# report the ten largest: their indices, their values, then their names.
base_sft_diff_abs = list(map(abs, base_sft_diff))
descending_order = np.argsort(base_sft_diff_abs)[::-1]
base_sft_diff_abs_max = list(descending_order[:10])
print(base_sft_diff_abs_max)
names = [param_name for param_name, _tensor in model_base.named_parameters()]
top10_values = [base_sft_diff_abs[idx] for idx in base_sft_diff_abs_max]
print(top10_values)
[names[idx] for idx in base_sft_diff_abs_max]

### What proportion of model weights haven't changed?

In [None]:
def _changed_fraction(model_a, model_b):
    """Return, per parameter tensor, the fraction of entries that differ.

    Parameters are paired by position in `model_a.parameters()` and
    `model_b.parameters()`; iteration stops at the shorter of the two.
    An entry counts as changed when the element-wise difference is not
    exactly zero, so 0.0 means the tensor is identical in both models.
    """
    fractions = []
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        # Compute the difference mask once and reuse it for count and size.
        changed = (p_a - p_b).detach().cpu().numpy() != 0
        fractions.append(changed.sum() / changed.size)
    return fractions


# One entry per parameter tensor (76 per model, per the original note).
base_sft_nonzero = _changed_fraction(model_base, model_sft)
base_dpo_nonzero = _changed_fraction(model_base, model_dpo)
sft_dpo_nonzero = _changed_fraction(model_sft, model_dpo)

In [None]:
# Parameter tensors ordered by the fraction of entries that differ between
# base and SFT, smallest first; report the ten least-changed ones.
# (The original also built a top-10 list that was never displayed or stored.)
base_sft_nonzero_sorted = np.argsort(base_sft_nonzero)
names = [name for name, param in model_base.named_parameters()]
least_changed = base_sft_nonzero_sorted[:10]
print("Layers least affected by fine tuning: ")
print([names[i] for i in least_changed])
print("How much these layers are affected by fine tuning: ")
print([base_sft_nonzero[i] for i in least_changed])

In [None]:
# Parameter tensors ordered by the fraction of entries that differ between
# base and DPO, smallest first; report the ten least-changed ones.
# (The original also built a top-10 list that was never displayed or stored.)
base_dpo_nonzero_sorted = np.argsort(base_dpo_nonzero)
names = [name for name, param in model_base.named_parameters()]
least_changed = base_dpo_nonzero_sorted[:10]
print("Layers least affected by fine tuning: ")
print([names[i] for i in least_changed])
print("How much these layers are affected by fine tuning: ")
print([base_dpo_nonzero[i] for i in least_changed])

In [None]:
# Parameter tensors ordered by the fraction of entries that differ between
# SFT and DPO, smallest first; report the ten least-changed ones.
# Names are read from model_base, as in the original cell.
# (The original also built a top-10 list that was never displayed or stored.)
sft_dpo_nonzero_sorted = np.argsort(sft_dpo_nonzero)
names = [name for name, param in model_base.named_parameters()]
least_changed = sft_dpo_nonzero_sorted[:10]
print("Layers least affected by fine tuning: ")
print([names[i] for i in least_changed])
print("How much these layers are affected by fine tuning: ")
print([sft_dpo_nonzero[i] for i in least_changed])

## Look at model output

In [99]:
# Use the first entry of the toxic comment list as the prompt.
text = toxic[0]

In [135]:
# See what the HF base model generates for the prompt.
inputs = tokenizer(text, return_tensors="pt").to(device)
print(f"inputs: {inputs}")
# Pass pad_token_id explicitly: it is the same value generate() otherwise falls
# back to (eos_token_id), but stating it avoids the "Setting `pad_token_id` to
# `eos_token_id`" message on every call. max_length=50 caps prompt + new tokens.
tokens = model_base.generate(**inputs, max_length=50, pad_token_id=tokenizer.eos_token_id)
print(f"tokens: {tokens}")
tokenizer.decode(tokens[0])

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


inputs: {'input_ids': tensor([[   36,  9466,  6971,  7519,   947,  8728, 31966,  8702,   367, 16808,
          6647, 15289,  8160, 17450, 37051]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
tokens: tensor([[   36,  9466,  6971,  7519,   947,  8728, 31966,  8702,   367, 16808,
          6647, 15289,  8160, 17450, 37051,    15,   187,   187,     3,    42,
          1353,   417,  1469,   281,   320,  2104,   281,   513,   326,    13,
           533,   309,  1353,   417,  1469,   281,   320,  2104,   281,   513,
           326,   449,   187,   187,     3,    42,  1353,   417,  1469,   281]],
       device='cuda:0')


'COCKSUCKER BEFORE YOU PISS AROUND ON MY WORK.\n\n"I\'m not going to be able to do that, but I\'m not going to be able to do that."\n\n"I\'m not going to'

In [145]:
# Shape of element [0] returned by run_with_cache (presumably the logits: the
# last dimension, 50304, matches the embedding vocabulary size) for the
# generated sequence and for the prompt alone. Only the last expression is
# displayed; the first call's result is discarded, with its shape recorded in
# the trailing comment.
tl_model_base.run_with_cache(tokens)[0].shape # torch.Size([1, 50, 50304])
tl_model_base.run_with_cache(inputs['input_ids'])[0].shape # torch.Size([1, 15, 50304])

torch.Size([1, 15, 50304])

In [146]:
# Element [1] returned by run_with_cache is the ActivationCache; displaying it lists the available hook names.
tl_model_base.run_with_cache(inputs['input_ids'])[1] 

ActivationCache with keys ['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'bloc

In [143]:
# Shape of the cached 'ln_final.hook_normalized' activation for the generated sequence.
tl_model_base.run_with_cache(tokens)[1]['ln_final.hook_normalized'].shape # torch.Size([1, 50, 512])


torch.Size([1, 50, 512])

## Looking at stuff...

#### Check out layers...

In [None]:
# list(model_dpo.children())

In [None]:
# i=0
# for name, module in model_dpo.named_modules():
#     # print(name, sep = " ")
#     if list(module.children()) == []:
#         print(name, end = " ")
#         print(module)
#         i+=1
# print(i)

In [None]:
import torch

def get_children(model: torch.nn.Module):
    """Collect the leaf modules of `model` in traversal order.

    A leaf is a module whose `children()` is empty.

    Parameters
    ----------
    model : torch.nn.Module
        Module to flatten.

    Returns
    -------
    `model` itself when it has no children (unchanged from the original
    behaviour); otherwise a flat list of all leaf modules below it.
    """
    children = list(model.children())
    if not children:
        # No children: the module is itself a leaf.
        return model

    flatt_children = []
    for child in children:
        found = get_children(child)
        # A list means `child` had children of its own, so splice its leaves
        # in. Anything else is a single leaf module and is appended as one
        # item. The original relied on `extend` raising TypeError, which
        # mishandled iterable leaves (e.g. an empty nn.Sequential was silently
        # dropped and an nn.ParameterList was replaced by its parameters).
        if isinstance(found, list):
            flatt_children.extend(found)
        else:
            flatt_children.append(found)
    return flatt_children

In [None]:
# get_children(model_dpo) # len(get_children(model_dpo)) 70

#### Compare TransformerLens model to HF...

In [71]:
# Number of parameter tensors in the TransformerLens and HF versions of the
# base and SFT models. Only the last expression is displayed; the trailing
# comments record the value of each line.
len(list(tl_model_base.parameters())) # 75
len(list(tl_model_sft.parameters())) # 75
len(list(model_base.parameters())) # 76
len(list(model_sft.parameters())) # 76
#.shape

76

In [81]:
# Print every TransformerLens SFT parameter name with its shape. The shape is
# read from the tensor already yielded by named_parameters() instead of
# rebuilding list(parameters()) and indexing it on every iteration.
for name, param in tl_model_sft.named_parameters():
    print(name, param.shape)

embed.W_E torch.Size([50304, 512])
blocks.0.attn.W_Q torch.Size([8, 512, 64])
blocks.0.attn.W_K torch.Size([8, 512, 64])
blocks.0.attn.W_V torch.Size([8, 512, 64])
blocks.0.attn.W_O torch.Size([8, 64, 512])
blocks.0.attn.b_Q torch.Size([8, 64])
blocks.0.attn.b_K torch.Size([8, 64])
blocks.0.attn.b_V torch.Size([8, 64])
blocks.0.attn.b_O torch.Size([512])
blocks.0.mlp.W_in torch.Size([512, 2048])
blocks.0.mlp.b_in torch.Size([2048])
blocks.0.mlp.W_out torch.Size([2048, 512])
blocks.0.mlp.b_out torch.Size([512])
blocks.1.attn.W_Q torch.Size([8, 512, 64])
blocks.1.attn.W_K torch.Size([8, 512, 64])
blocks.1.attn.W_V torch.Size([8, 512, 64])
blocks.1.attn.W_O torch.Size([8, 64, 512])
blocks.1.attn.b_Q torch.Size([8, 64])
blocks.1.attn.b_K torch.Size([8, 64])
blocks.1.attn.b_V torch.Size([8, 64])
blocks.1.attn.b_O torch.Size([512])
blocks.1.mlp.W_in torch.Size([512, 2048])
blocks.1.mlp.b_in torch.Size([2048])
blocks.1.mlp.W_out torch.Size([2048, 512])
blocks.1.mlp.b_out torch.Size([512])
blo

In [82]:
# Print every HF SFT parameter name with its shape. The shape is read from the
# tensor already yielded by named_parameters() instead of rebuilding
# list(parameters()) and indexing it on every iteration.
for name, param in model_sft.named_parameters():
    print(name, param.shape)

gpt_neox.embed_in.weight torch.Size([50304, 512])
gpt_neox.layers.0.input_layernorm.weight torch.Size([512])
gpt_neox.layers.0.input_layernorm.bias torch.Size([512])
gpt_neox.layers.0.post_attention_layernorm.weight torch.Size([512])
gpt_neox.layers.0.post_attention_layernorm.bias torch.Size([512])
gpt_neox.layers.0.attention.query_key_value.weight torch.Size([1536, 512])
gpt_neox.layers.0.attention.query_key_value.bias torch.Size([1536])
gpt_neox.layers.0.attention.dense.weight torch.Size([512, 512])
gpt_neox.layers.0.attention.dense.bias torch.Size([512])
gpt_neox.layers.0.mlp.dense_h_to_4h.weight torch.Size([2048, 512])
gpt_neox.layers.0.mlp.dense_h_to_4h.bias torch.Size([2048])
gpt_neox.layers.0.mlp.dense_4h_to_h.weight torch.Size([512, 2048])
gpt_neox.layers.0.mlp.dense_4h_to_h.bias torch.Size([512])
gpt_neox.layers.1.input_layernorm.weight torch.Size([512])
gpt_neox.layers.1.input_layernorm.bias torch.Size([512])
gpt_neox.layers.1.post_attention_layernorm.weight torch.Size([512])

In [72]:
# Dump all (name, tensor) pairs of the TransformerLens SFT model to compare values with the HF model by eye.
print( list( tl_model_sft.named_parameters() ) )
# list( tl_model_sft.named_parameters() )[-2]

[('embed.W_E', Parameter containing:
tensor([[-1.0292e-02, -6.1875e-03,  8.4457e-03,  ..., -9.7581e-03,
          3.2333e-02, -1.7990e-02],
        [-1.3339e-05,  5.1984e-06, -9.7028e-06,  ...,  4.3043e-06,
         -2.6511e-05, -2.4246e-05],
        [-4.0108e-02,  1.3055e-02, -5.5338e-02,  ...,  3.1266e-02,
          5.2706e-02, -1.3561e-02],
        ...,
        [-3.0224e-05,  3.6906e-06, -4.3874e-05,  ..., -1.2522e-05,
          1.8651e-05, -5.0713e-06],
        [-8.6227e-06,  3.7765e-07,  3.4591e-05,  ..., -2.3643e-05,
         -1.3749e-05,  2.0703e-05],
        [ 3.4700e-05,  1.7295e-05,  1.8845e-05,  ..., -2.4936e-06,
         -3.6261e-06, -5.4738e-06]], device='cuda:0', requires_grad=True)), ('blocks.0.attn.W_Q', Parameter containing:
tensor([[[-1.8906e-02, -4.6398e-03, -5.6957e-03,  ...,  1.8453e-02,
           4.5274e-02,  5.2218e-03],
         [-1.5819e-02,  1.4255e-03,  1.3373e-02,  ..., -2.1503e-02,
           6.9732e-03, -3.2739e-02],
         [ 4.3338e-03, -3.2973e-02, -8

In [73]:
# Dump all (name, tensor) pairs of the HF SFT model to compare values with the TransformerLens model by eye.
print( list( model_sft.named_parameters() ) )
# print( list( model_sft.parameters() )[-3].detach().cpu().shape )

[('gpt_neox.embed_in.weight', Parameter containing:
tensor([[-9.8190e-03, -5.7144e-03,  8.9188e-03,  ..., -9.2850e-03,
          3.2806e-02, -1.7517e-02],
        [-1.4424e-05,  4.1127e-06, -1.0788e-05,  ...,  3.2187e-06,
         -2.7597e-05, -2.5332e-05],
        [-4.0737e-02,  1.2425e-02, -5.5967e-02,  ...,  3.0637e-02,
          5.2076e-02, -1.4191e-02],
        ...,
        [-3.0756e-05,  3.1590e-06, -4.4405e-05,  ..., -1.3053e-05,
          1.8120e-05, -5.6028e-06],
        [-9.2983e-06, -2.9802e-07,  3.3915e-05,  ..., -2.4319e-05,
         -1.4424e-05,  2.0027e-05],
        [ 3.3855e-05,  1.6451e-05,  1.8001e-05,  ..., -3.3379e-06,
         -4.4703e-06, -6.3181e-06]], device='cuda:0', requires_grad=True)), ('gpt_neox.layers.0.input_layernorm.weight', Parameter containing:
tensor([0.9737, 1.0712, 0.9863, 1.0640, 0.9192, 0.9982, 1.0652, 0.9435, 0.7604,
        0.7984, 0.9029, 0.8993, 0.9902, 0.9900, 1.0164, 1.0555, 1.0446, 0.9249,
        1.0220, 0.9027, 0.8915, 1.0156, 1.0350, 1.

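The two libraries use different parameter structures. E.g., TransformerLens stores attn.W_Q, attn.W_K and attn.W_V as separate tensors, whereas the Hugging Face GPT-NeoX model keeps them concatenated in a single `attention.query_key_value` weight.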

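Some of the extracted weights match approximately, but transposed, e.g. TransformerLens `blocks.5.mlp.W_out` and Hugging Face `gpt_neox.layers.5.mlp.dense_4h_to_h.weight` below.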

In [86]:
# Fourth-from-last named parameter of the TransformerLens SFT model.
print( list( tl_model_sft.named_parameters() )[-4] )


('blocks.5.mlp.W_out', Parameter containing:
tensor([[ 0.0368,  0.0445,  0.0370,  ..., -0.0435, -0.0058,  0.0598],
        [-0.0159,  0.0496, -0.0270,  ..., -0.0132, -0.0217,  0.0325],
        [ 0.0061, -0.0561, -0.0089,  ...,  0.0161,  0.0554, -0.0443],
        ...,
        [ 0.0836, -0.0677,  0.0185,  ..., -0.1025,  0.0264,  0.0230],
        [-0.0167,  0.0339,  0.0315,  ...,  0.0043, -0.0054, -0.0226],
        [ 0.0115, -0.0161, -0.0157,  ..., -0.0131,  0.0186,  0.0502]],
       device='cuda:0', requires_grad=True))


In [87]:
# Fifth-from-last named parameter of the HF SFT model, shown for comparison with the TransformerLens tensor printed in the previous cell.
print( list( model_sft.named_parameters() )[-5] )


('gpt_neox.layers.5.mlp.dense_4h_to_h.weight', Parameter containing:
tensor([[ 0.0368, -0.0155,  0.0066,  ...,  0.0833, -0.0166,  0.0116],
        [ 0.0444,  0.0500, -0.0556,  ..., -0.0680,  0.0340, -0.0160],
        [ 0.0369, -0.0266, -0.0084,  ...,  0.0182,  0.0316, -0.0155],
        ...,
        [-0.0435, -0.0128,  0.0166,  ..., -0.1028,  0.0044, -0.0129],
        [-0.0059, -0.0213,  0.0559,  ...,  0.0261, -0.0053,  0.0187],
        [ 0.0598,  0.0329, -0.0438,  ...,  0.0227, -0.0225,  0.0503]],
       device='cuda:0', requires_grad=True))


In [90]:
# Print every parameter tensor of the TransformerLens base model, then display
# its module tree. (The bare `tl_model_base` expression that used to open the
# cell was never displayed and has been dropped.)
for parameters in tl_model_base.parameters():
    print(parameters)
tl_model_base

Parameter containing:
tensor([[-1.0292e-02, -6.1875e-03,  8.4457e-03,  ..., -9.7581e-03,
          3.2333e-02, -1.7990e-02],
        [-1.3339e-05,  5.1984e-06, -9.7028e-06,  ...,  4.3043e-06,
         -2.6511e-05, -2.4246e-05],
        [-4.0268e-02,  1.3306e-02, -5.5435e-02,  ...,  3.0991e-02,
          5.2964e-02, -1.3343e-02],
        ...,
        [-3.0224e-05,  3.6906e-06, -4.3874e-05,  ..., -1.2522e-05,
          1.8651e-05, -5.0713e-06],
        [-8.6227e-06,  3.7765e-07,  3.4591e-05,  ..., -2.3643e-05,
         -1.3749e-05,  2.0703e-05],
        [ 3.4700e-05,  1.7295e-05,  1.8845e-05,  ..., -2.4936e-06,
         -3.6261e-06, -5.4738e-06]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[[-1.9019e-02, -4.5817e-03, -5.8410e-03,  ...,  1.8321e-02,
           4.5304e-02,  5.3359e-03],
         [-1.5665e-02,  1.5674e-03,  1.3378e-02,  ..., -2.1674e-02,
           7.1694e-03, -3.2571e-02],
         [ 4.5031e-03, -3.2967e-02, -7.5577e-04,  ...,  8.2200e-03,
         

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint(

In [92]:
# Print every parameter tensor of the TransformerLens SFT model, then display
# its module tree.
for weights in tl_model_sft.parameters():
    print(weights)
tl_model_sft

Parameter containing:
tensor([[-1.0292e-02, -6.1875e-03,  8.4457e-03,  ..., -9.7581e-03,
          3.2333e-02, -1.7990e-02],
        [-1.3339e-05,  5.1984e-06, -9.7028e-06,  ...,  4.3043e-06,
         -2.6511e-05, -2.4246e-05],
        [-4.0108e-02,  1.3055e-02, -5.5338e-02,  ...,  3.1266e-02,
          5.2706e-02, -1.3561e-02],
        ...,
        [-3.0224e-05,  3.6906e-06, -4.3874e-05,  ..., -1.2522e-05,
          1.8651e-05, -5.0713e-06],
        [-8.6227e-06,  3.7765e-07,  3.4591e-05,  ..., -2.3643e-05,
         -1.3749e-05,  2.0703e-05],
        [ 3.4700e-05,  1.7295e-05,  1.8845e-05,  ..., -2.4936e-06,
         -3.6261e-06, -5.4738e-06]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[[-1.8906e-02, -4.6398e-03, -5.6957e-03,  ...,  1.8453e-02,
           4.5274e-02,  5.2218e-03],
         [-1.5819e-02,  1.4255e-03,  1.3373e-02,  ..., -2.1503e-02,
           6.9732e-03, -3.2739e-02],
         [ 4.3338e-03, -3.2973e-02, -8.6588e-04,  ...,  8.2098e-03,
         

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint(

In [94]:
# Print every parameter tensor of the TransformerLens DPO model, then display
# its module tree.
for weights in tl_model_dpo.parameters():
    print(weights)
tl_model_dpo

Parameter containing:
tensor([[-1.0292e-02, -6.1875e-03,  8.4457e-03,  ..., -9.7581e-03,
          3.2333e-02, -1.7990e-02],
        [-1.3339e-05,  5.1984e-06, -9.7028e-06,  ...,  4.3043e-06,
         -2.6511e-05, -2.4246e-05],
        [-3.9968e-02,  1.3042e-02, -5.5322e-02,  ...,  3.1405e-02,
          5.2649e-02, -1.3493e-02],
        ...,
        [-3.0224e-05,  3.6906e-06, -4.3874e-05,  ..., -1.2522e-05,
          1.8651e-05, -5.0713e-06],
        [-8.6227e-06,  3.7765e-07,  3.4591e-05,  ..., -2.3643e-05,
         -1.3749e-05,  2.0703e-05],
        [ 3.4700e-05,  1.7295e-05,  1.8845e-05,  ..., -2.4936e-06,
         -3.6261e-06, -5.4738e-06]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[[-1.8851e-02, -4.6803e-03, -5.7019e-03,  ...,  1.8470e-02,
           4.5205e-02,  5.4182e-03],
         [-1.5904e-02,  1.4855e-03,  1.3371e-02,  ..., -2.1468e-02,
           6.9852e-03, -3.2718e-02],
         [ 4.4306e-03, -3.2950e-02, -8.8395e-04,  ...,  8.1646e-03,
         

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint(