<a href="https://colab.research.google.com/github/fmars/n00bGPT/blob/main/colab/model_parity_debugging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

In [4]:
import torch
import transformers

In [152]:
def weight_copy(x,y,src,dst,trans):
  for a,b in zip(src, dst):
    print(f'{a} -> {b}')
    need_transpose = any([a.endswith(s) for s in trans])
    if need_transpose:
      y.state_dict()[b].copy_(x.state_dict()[a].T)
    else:
      y.state_dict()[b].copy_(x.state_dict()[a])

In [84]:
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

cfg = GPT2Config()
m_ori=transformers.AutoModel.from_pretrained('gpt2')
m_lmh= GPT2LMHeadModel(cfg).from_pretrained('gpt2')

In [109]:
cfg

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 50257
}

In [148]:
from dataclasses import dataclass
@dataclass
class ModelConfig:
  vocab_size: int = 50304
  block_size: int = 1024
  n_layer: int = 7
  n_head: int = 4
  emb_dim: int = 64
  bias: bool = True
  dropout: float = 0.0
  use_torch_mhattention: bool = True

  def __init__(self, gpt_cfg):
    self.vocab_size = gpt_cfg.vocab_size
    self.block_size = gpt_cfg.n_positions
    self.n_layer = gpt_cfg.n_layer
    self.n_head = gpt_cfg.n_head
    self.emb_dim = gpt_cfg.n_embd
    self.bias = True
    self.droput = 0.0
m_cfg = ModelConfig(cfg)

In [160]:
batch_size = 16
seq_len = 64
emb_dim = m_lmh.config.n_embd
n_head = m_lmh.config.n_head
bias = True
dropout = 0

# GPT2Attention <> MultiheadAttention <> torch.nn.MultiheadAttention

In [181]:
x = m_lmh.transformer.h[0].attn
print(type(x))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'>


In [87]:
class MultiheadAttention(torch.nn.Module):
  def __init__(self, emb_dim, n_head):
    super().__init__()
    self.emb_dim = emb_dim
    self.n_head = n_head
    self.head_dim = emb_dim // n_head
    assert self.head_dim * self.n_head == self.emb_dim
    self.bias = True

    self.attn = torch.nn.Linear(self.emb_dim, 3 * self.emb_dim, bias = self.bias)
    self.proj = torch.nn.Linear(self.emb_dim, self.emb_dim, bias=self.bias)

  def forward(self, x):
    # x: [n_batch, seq_len, emb_dim], assuming q=k=v=x in MultiheadAttentionSimple implementation
    n_batch, seq_len, emb_dim = x.shape
    # Vectorized/concated form of QW_q_i, KW_k_i, and VW_v_i
    attn = self.attn(x) # [n_batch, seq_len, 3*emb_dim]
    q,k,v = attn.split(self.emb_dim, dim=-1)
    # Reshape to per-head form
    q = q.view(n_batch, seq_len, self.n_head, self.head_dim).transpose(1,2)
    k = k.view(n_batch, seq_len, self.n_head, self.head_dim).transpose(1,2)
    v = v.view(n_batch, seq_len, self.n_head, self.head_dim).transpose(1,2)
    # Compute dot product attention, which input is [N,..., seq_len, emb_dim]
    y = torch.nn.functional.scaled_dot_product_attention(q,k,v, is_causal=True) # [n_batch, n_head, seq_len, head_dim]
    y = y.transpose(1,2).contiguous().view(n_batch, seq_len, emb_dim) # [n_batch, seq_len, emb_dim]
    y = self.proj(y)
    return y
y=MultiheadAttention(emb_dim, n_head)

In [88]:
z = torch.nn.MultiheadAttention(emb_dim, n_head,batch_first=True)

In [89]:
for k,v in x.state_dict().items():
  print(f'{k} -> {v.shape}')
print('-'*30)
for k,v in y.state_dict().items():
  print(f'{k} -> {v.shape}')
print('-'*30)
for k,v in z.state_dict().items():
  print(f'{k} -> {v.shape}')

c_attn.weight -> torch.Size([768, 2304])
c_attn.bias -> torch.Size([2304])
c_proj.weight -> torch.Size([768, 768])
c_proj.bias -> torch.Size([768])
------------------------------
attn.weight -> torch.Size([2304, 768])
attn.bias -> torch.Size([2304])
proj.weight -> torch.Size([768, 768])
proj.bias -> torch.Size([768])
------------------------------
in_proj_weight -> torch.Size([2304, 768])
in_proj_bias -> torch.Size([2304])
out_proj.weight -> torch.Size([768, 768])
out_proj.bias -> torch.Size([768])


In [90]:
src = [str(i) for i in x.state_dict().keys()]
dst = [str(i) for i in y.state_dict().keys()]
trans=['c_attn.weight', 'c_proj.weight']
weight_copy(x,y,src,dst,trans)

src = [str(i) for i in x.state_dict().keys()]
dst = [str(i) for i in z.state_dict().keys()]
trans=['c_attn.weight', 'c_proj.weight']
weight_copy(x,z,src,dst,trans)

c_attn.weight -> attn.weight
c_attn.bias -> attn.bias
c_proj.weight -> proj.weight
c_proj.bias -> proj.bias
c_attn.weight -> in_proj_weight
c_attn.bias -> in_proj_bias
c_proj.weight -> out_proj.weight
c_proj.bias -> out_proj.bias


In [91]:
q=k=v = torch.randn(batch_size, seq_len, emb_dim)
attn_mask = torch.triu(torch.ones(q.size(1),q.size(1)),diagonal=1)
attn_mask = attn_mask.bool()
o_x = x(q)[0]
o_y = y(q)
o_z = z(q,k,v, attn_mask=attn_mask,is_causal=True)[0]

print(f'x <> y {torch.max(o_x-o_y)} {torch.allclose(o_x, o_y,atol=1e-3)}')
print(f'x <> z {torch.max(o_x-o_z)} {torch.allclose(o_x, o_z,atol=1e-3)}')
print(f'y <> z {torch.max(o_y-o_z)} {torch.allclose(o_y, o_z,atol=1e-3)}')

x <> y 0.0002288818359375 True
x <> z 0.0002624988555908203 True
y <> z 0.00019073486328125 True


# GPT2MLP <> FeedForward

In [94]:
x = m_lmh.transformer.h[0].mlp
print(type(x))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2MLP'>


In [110]:
from transformers.activations import ACT2FN
class FeedForward(torch.nn.Module):
  def __init__(self, emb_dim, bias, dropout):
    super().__init__()
    self.context_proj_1 = torch.nn.Linear(emb_dim, 4 *emb_dim, bias=bias)
    self.gelu = torch.nn.GELU()
    self.gelu = ACT2FN['gelu_new']
    self.context_proj_2 = torch.nn.Linear(4*emb_dim, emb_dim, bias=bias)
    self.dropout = torch.nn.Dropout(dropout)

  def forward(self, x):
    x = self.context_proj_1(x)
    x = self.gelu(x)
    x = self.context_proj_2(x)
    x = self.dropout(x)
    return x

y = FeedForward(emb_dim, bias, dropout)

In [111]:
for k,v in x.state_dict().items():
  print(f'{k} -> {v.shape}')
print('-'*30)
for k,v in y.state_dict().items():
  print(f'{k} -> {v.shape}')

c_fc.weight -> torch.Size([768, 3072])
c_fc.bias -> torch.Size([3072])
c_proj.weight -> torch.Size([3072, 768])
c_proj.bias -> torch.Size([768])
------------------------------
context_proj_1.weight -> torch.Size([3072, 768])
context_proj_1.bias -> torch.Size([3072])
context_proj_2.weight -> torch.Size([768, 3072])
context_proj_2.bias -> torch.Size([768])


In [112]:
src = [str(i) for i in x.state_dict().keys()]
dst = [str(i) for i in y.state_dict().keys()]
trans=['c_fc.weight', 'c_proj.weight']
weight_copy(x,y,src,dst,trans)

c_fc.weight -> context_proj_1.weight
c_fc.bias -> context_proj_1.bias
c_proj.weight -> context_proj_2.weight
c_proj.bias -> context_proj_2.bias


In [114]:
input = torch.randn(batch_size, seq_len, emb_dim)
o_x = x(input)
o_y = y(input)
print(f'x <> y {torch.max(o_x-o_y)} {torch.allclose(o_x, o_y)}')

x <> y 0.0 True


# GPT2Block <> Layer

In [115]:
x = m_lmh.transformer.h[0]
print(type(x))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2Block'>


In [145]:
class Layer(torch.nn.Module):
  def __init__(self, emb_dim, n_head, bias, dropout):
    super().__init__()
    self.ln_1 = torch.nn.LayerNorm(emb_dim)
    self.attn = MultiheadAttention(emb_dim, n_head)
    self.ln_2 = torch.nn.LayerNorm(emb_dim)
    self.feed_fwd = FeedForward(emb_dim, bias, dropout)

  def forward(self, x):
    residual = x
    x = self.ln_1(x)
    x = self.attn(x)
    x = x + residual
    residual = x
    x = self.ln_2(x)
    x = self.feed_fwd(x)
    x = x + residual
    return x
y = Layer(emb_dim, n_head, bias, dropout)

In [137]:
for k,v in x.state_dict().items():
  print(f'{k} -> {v.shape}')
print('-'*30)
for k,v in y.state_dict().items():
  print(f'{k} -> {v.shape}')

ln_1.weight -> torch.Size([768])
ln_1.bias -> torch.Size([768])
attn.c_attn.weight -> torch.Size([768, 2304])
attn.c_attn.bias -> torch.Size([2304])
attn.c_proj.weight -> torch.Size([768, 768])
attn.c_proj.bias -> torch.Size([768])
ln_2.weight -> torch.Size([768])
ln_2.bias -> torch.Size([768])
mlp.c_fc.weight -> torch.Size([768, 3072])
mlp.c_fc.bias -> torch.Size([3072])
mlp.c_proj.weight -> torch.Size([3072, 768])
mlp.c_proj.bias -> torch.Size([768])
------------------------------
ln_1.weight -> torch.Size([768])
ln_1.bias -> torch.Size([768])
attn.attn.weight -> torch.Size([2304, 768])
attn.attn.bias -> torch.Size([2304])
attn.proj.weight -> torch.Size([768, 768])
attn.proj.bias -> torch.Size([768])
ln_2.weight -> torch.Size([768])
ln_2.bias -> torch.Size([768])
feed_fwd.context_proj_1.weight -> torch.Size([3072, 768])
feed_fwd.context_proj_1.bias -> torch.Size([3072])
feed_fwd.context_proj_2.weight -> torch.Size([768, 3072])
feed_fwd.context_proj_2.bias -> torch.Size([768])


In [138]:
src = [str(i) for i in x.state_dict().keys()]
dst = [str(i) for i in y.state_dict().keys()]
trans=['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight','mlp.c_proj.weight']
weight_copy(x,y,src,dst,trans)

ln_1.weight -> ln_1.weight
ln_1.bias -> ln_1.bias
attn.c_attn.weight -> attn.attn.weight
attn.c_attn.bias -> attn.attn.bias
attn.c_proj.weight -> attn.proj.weight
attn.c_proj.bias -> attn.proj.bias
ln_2.weight -> ln_2.weight
ln_2.bias -> ln_2.bias
mlp.c_fc.weight -> feed_fwd.context_proj_1.weight
mlp.c_fc.bias -> feed_fwd.context_proj_1.bias
mlp.c_proj.weight -> feed_fwd.context_proj_2.weight
mlp.c_proj.bias -> feed_fwd.context_proj_2.bias


In [144]:
input = torch.rand(batch_size, seq_len, emb_dim)
o_x = x(input)[0]
o_y = y(input)
print(f'x <> y {torch.max(o_x-o_y)} {torch.allclose(o_x, o_y, atol=1e-5)}')

x <> y 1.52587890625e-05 True


# GPT2Model <> N00bGPT

In [146]:
x = m_lmh.transformer
print(type(x))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2Model'>


In [157]:
class GPTLMHeadModel(torch.nn.Module):
  def __init__(self, cfg: ModelConfig):
    super().__init__()
    self.cfg = cfg
    self.emb_dim = self.cfg.emb_dim

    self.token_emb = torch.nn.Embedding(self.cfg.vocab_size, self.emb_dim)
    self.pos_emb = torch.nn.Embedding(self.cfg.block_size, self.emb_dim)
    self.dropout = torch.nn.Dropout(self.cfg.dropout)
    self.layers = torch.nn.ModuleList([Layer(cfg.emb_dim, cfg.n_head, cfg.bias, cfg.dropout) for _ in range(self.cfg.n_layer)])
    self.ln = torch.nn.LayerNorm(self.emb_dim)

    self.lang_model_head = torch.nn.Linear(self.emb_dim, self.cfg.vocab_size, bias=False) # Do we need to explicitly disable bias here?

  def base_forward(self, x):
    n_batch, seq_len = x.shape
    assert seq_len <= self.cfg.block_size
    pos = torch.arange(0, seq_len)
    if torch.cuda.is_available():
      pos = pos.to(torch.device("cuda"))

    token_emb = self.token_emb(x) # [n_batch, seq_len, emb_dim]
    pos_emb = self.pos_emb(pos) # [seq_len, emb_dim]
    x = self.dropout(token_emb+pos_emb)
    for layer in self.layers:
      x = layer(x)
    x = self.ln(x)
    return x

  def forward(self, x, targets=None):
    x = self.base_forward(x)

    if targets: # Training
      logits = self.lm_head(x)
      loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
    else: # Inference
      logits = self.lang_model_head(x[:,[-1],:]) # Only need to compute lm_head for the last token
      loss = None
    return logits, loss

y = GPTLMHeadModel(m_cfg)

In [None]:
for k,v in x.state_dict().items():
  print(f'{k} -> {v.shape}')
print('-'*30)
for k,v in y.state_dict().items():
  print(f'{k} -> {v.shape}')

In [None]:
src = [str(i) for i in x.state_dict().keys()]
dst = [str(i) for i in y.state_dict().keys()]
trans=['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight','mlp.c_proj.weight']
weight_copy(x,y,src,dst,trans)

In [166]:
input = torch.randint(0, m_cfg.vocab_size, (batch_size, seq_len))
o_x = x(input)[0]
o_y = y.base_forward(input)
print(f'x <> y {torch.max(o_x-o_y)} {torch.allclose(o_x, o_y, atol=1e-4)}')

x <> y 9.1552734375e-05 True


# GPT2LMHeadModel <> GPTLMHeadModel

In [167]:
x = m_lmh
print(type(x))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>


In [168]:
y = GPTLMHeadModel(m_cfg)

In [None]:
for k,v in x.state_dict().items():
  print(f'{k} -> {v.shape}')
print('-'*30)
for k,v in y.state_dict().items():
  print(f'{k} -> {v.shape}')

In [None]:
src = [str(i) for i in x.state_dict().keys()]
dst = [str(i) for i in y.state_dict().keys()]
trans=['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight','mlp.c_proj.weight']
weight_copy(x,y,src,dst,trans)

In [176]:
tokenizer=transformers.AutoTokenizer.from_pretrained('gpt2')
s='i am a software engineer and i like to'
input_ids=tokenizer(s,return_tensors='pt')['input_ids']

In [175]:
ids = input_ids
for i in range(20):
  outs=x(ids)[0][:,-1,:]
  _, id = torch.topk(outs,1)
  ids = torch.concat([ids, id],dim=1)
tokenizer.decode(ids[0])

'i am a software engineer and i like to write code. I am also a programmer. I am a big fan of the open source community.'

In [None]:
ids = input_ids
for i in range(20):
  outs=x(ids)[0][:,-1,:]
  _, id = torch.topk(outs,1)
  ids = torch.concat([ids, id],dim=1)
tokenizer.decode(ids[0])

In [180]:
ids = input_ids
for i in range(20):
  outs=y(ids)[0][:,-1,:]
  _, id = torch.topk(outs,1)
  ids = torch.concat([ids, id],dim=1)
tokenizer.decode(ids[0])

'i am a software engineer and i like to write code. I am also a programmer. I am a big fan of the open source community.'