In this notebook, I validate part of my implementation of the LLaMA 7B model by comparing the reference Hugging Face LLaMA transformer block against my own implementation of that block.

In [1]:
import torch as t
import sys
import os
# Put the parent of the working directory on the import path so that
# project-local packages can be imported from inside this notebook.
# NOTE(review): assumes the notebook is run from a directory one level below
# the project root -- confirm against the repository layout.
notebook_path = os.path.abspath('')  # abspath('') resolves to the current working directory
project_root = os.path.join(notebook_path, '..')

# Only append once: re-running this cell must not stack duplicate entries on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fetch the reference checkpoint (tokenizer and causal-LM weights) by its hub identifier.
checkpoint_name = "huggyllama/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForCausalLM.from_pretrained(checkpoint_name)

# Only the first layer of the model is written to disk; the comparison cells
# further down load this file.
first_block = model.model.layers[0]
t.save(first_block, 'llama7b_block.pt')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Comparing Llama 7B components with my implementation

In [2]:
from src.models import Llama7BModel
import torch as t

# Build my own implementation of the model; its first transformer block is the one under test.
my_model = Llama7BModel()

# 'llama7b_block.pt' holds a whole pickled module object (it was written with
# t.save on the layer itself), not a state_dict. Recent PyTorch releases default
# torch.load to weights_only=True, which refuses to unpickle arbitrary module
# classes, so full unpickling has to be requested explicitly.
# SECURITY: full unpickling executes arbitrary code from the file -- acceptable
# only because this file is produced locally by this notebook.
llama_block = t.load('llama7b_block.pt', weights_only=False)
my_block = my_model.transformer_blocks[0]

In [3]:
# Transfer the reference block's weights into my block so both compute with
# identical parameters. Each entry pairs one of my sub-modules with the
# reference sub-module whose `weight` it should carry.
weight_pairs = [
    # normalisation layers
    (my_block.norm_layer1, llama_block.input_layernorm),
    (my_block.norm_layer2, llama_block.post_attention_layernorm),
    # attention projections
    (my_block.mha_block.linear_q, llama_block.self_attn.q_proj),
    (my_block.mha_block.linear_k, llama_block.self_attn.k_proj),
    (my_block.mha_block.linear_v, llama_block.self_attn.v_proj),
    (my_block.mha_block.linear_o, llama_block.self_attn.o_proj),
    # feed-forward projections
    (my_block.mlp_block.linear_gate, llama_block.mlp.gate_proj),
    (my_block.mlp_block.linear_up, llama_block.mlp.up_proj),
    (my_block.mlp_block.linear_down, llama_block.mlp.down_proj),
]

# Wrap every weight in a Parameter so all nine layers are handled uniformly.
for mine, reference in weight_pairs:
    mine.weight = t.nn.Parameter(reference.weight)


### MLP blocks

In [4]:
# Feed one random input through both feed-forward implementations and
# compare the results with allclose's default tolerances.
x = t.randn(1, 8, 4096)
my_mlp_out = my_block.mlp_block(x)
llama_mlp_out = llama_block.mlp(x)
t.allclose(my_mlp_out, llama_mlp_out)

True

### Multihead attention


In [5]:
x = t.randn((1,8,4096))

# Run each attention implementation exactly once and compare the stored
# outputs. Neither call is given an explicit attention mask here.
m_off_out = my_block.mha_block(x)
llama_off_out = llama_block.self_attn(x)[0]  # only element 0 of the return value is compared

t.allclose(m_off_out, llama_off_out, atol=1e-6)

True

### Overall transformer block

In [6]:
x = t.randn(1, 8, 4096)
seq_len = x.shape[-2]

# Additive mask: entry (i, j) is -inf where column j lies after row i, and 0 elsewhere.
positions = t.arange(seq_len)
att_mask = t.where(positions.unsqueeze(1) < positions, -t.inf, 0)

# My transformer block applies an attention mask on its own, whereas the llama
# block expects the surrounding model code to supply it, so it is passed
# explicitly here with two leading singleton dimensions added.
llama_mask = att_mask[None, None]
my_block_out = my_block(x)
llama_block_out = llama_block(x, attention_mask=llama_mask)[0]
t.allclose(my_block_out, llama_block_out, atol=1e-6)

True