In [56]:
# language model
from dataclasses import dataclass
import torch
import torch.nn as nn
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Sketch of a future config object (not implemented yet):
# @dataclass
# class LlamaConfig:
#     windows: B, T
#     input/output: V, D
#     attn: N, H, K
#     layers: int = 32

# N is the number of Block layers stacked by Llama.
# H and K are still zero placeholders (presumably attention head count / head
# width, per the "attn: N, H, K" note above — TODO confirm).
N = 32
H = 0
K = 0


class MHA(nn.Module):
    """Multi-head attention parameter container (projections only, no forward yet).

    Registers the query/key/value/output projections under the attribute names
    ``q_proj``, ``k_proj``, ``v_proj`` and ``o_proj``, each a square
    ``nn.Linear(dim, dim)``.

    Args:
        dim: Input and output width of every projection. Defaults to 768, the
            width that used to be hard-coded, so ``MHA()`` is unchanged.
    """

    def __init__(self, dim: int = 768):
        super().__init__()
        # Creation order (q, k, v, o) is kept so that parameter registration
        # order and seeded initialisation match the previous behaviour.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

class FFN(nn.Module):
    """Gated feed-forward parameter container (projections only, no forward yet).

    Registers ``up_proj`` and ``gate_proj`` (``dim -> hidden_dim``) and
    ``down_proj`` (``hidden_dim -> dim``).

    Args:
        dim: Model width. Defaults to 768, the width that used to be hard-coded.
        hidden_dim: Inner width of the feed-forward layer. ``None`` (the
            default) means "same as ``dim``", which reproduces the previous
            all-square 768x768 layers when called as ``FFN()``.
    """

    def __init__(self, dim: int = 768, hidden_dim: "int | None" = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = dim
        # Creation order (up, gate, down) is kept so that parameter registration
        # order and seeded initialisation match the previous behaviour.
        self.up_proj = nn.Linear(dim, hidden_dim)
        self.gate_proj = nn.Linear(dim, hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, dim)





class Block(nn.Module):
    """Parameter container for one transformer layer (no forward pass yet).

    Sub-modules are registered in the order ``self_attn``,
    ``input_layernorm``, ``mlp``, ``post_attention_layernorm``.
    """

    def __init__(self):
        super().__init__()
        width = 768  # in/out features of both norm placeholders below

        self.self_attn = MHA()
        # NOTE(review): both "layernorm" slots are plain nn.Linear layers here —
        # confirm this is an intentional placeholder for a real norm layer.
        self.input_layernorm = nn.Linear(in_features=width, out_features=width)
        self.mlp = FFN()
        self.post_attention_layernorm = nn.Linear(in_features=width, out_features=width)


class Llama(nn.Module):
    """Skeleton of a Llama-style decoder: parameter containers plus a key check.

    The attribute names (``embed_tokens``, ``layers``, ``norm``, ``lm_head``)
    are what ``from_pretrained`` compares against the Hugging Face state dict.
    Every layer is currently a 768x768 ``nn.Linear`` placeholder.
    """

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Linear(768, 768)
        self.layers = nn.ModuleList([Block() for _ in range(N)])
        self.norm = nn.Linear(768, 768)
        self.lm_head = nn.Linear(768, 768)

    def forward(self, X):
        """Placeholder forward pass: returns ``X`` unchanged.

        Planned pipeline (not implemented yet):
            1. embeddings: X_BT -> X_BTD
            2. blocks + unembed: Nx(X_BTD -> X_BTK -> X_BTD)
            3. logits: X_BTD -> X_BTV
        """
        return X

    @classmethod
    def from_pretrained(cls, model_type):
        """Load the Hugging Face Llama 2 checkpoint and diff its keys against this class.

        Builds the Hugging Face model and a fresh local model, then prints every
        state-dict key that the Hugging Face model has and the local one lacks.
        No weights are copied yet.

        Args:
            model_type: Hugging Face model id; only
                ``'meta-llama/Llama-2-7b-hf'`` is accepted.

        Returns:
            The Hugging Face model, NOT the local model. The local model is
            only used for the key comparison and is discarded.

        Raises:
            AssertionError: if ``model_type`` is not a supported id.
        """
        assert model_type in {'meta-llama/Llama-2-7b-hf'}
        # Imported lazily so the class can be defined without transformers installed.
        from transformers import LlamaModel
        print("loading weights from pretrained llama: %s" % model_type)

        # NOTE(review): a later cell calls `.generate(...)` on the returned
        # object; `LlamaModel` has no LM head, so confirm whether
        # `LlamaForCausalLM` is what should be loaded here.

        # 1. model init: mom i want a llama. we have llama at home
        model_hf, model = LlamaModel.from_pretrained(model_type), cls()
        sd_hf, sd = model_hf.state_dict(), model.state_dict()

        # 2. compare: report every pretrained key the local model does not define
        for k in sd_hf:
            if k not in sd:
                print(f"llama at home is missing {k}")

        # TODO: once names and shapes line up, copy the weights inside the loop:
        #     assert sd_hf[k].shape == sd[k].shape
        #     with torch.no_grad():
        #         sd[k].copy_(sd_hf[k])
        # The old GPT-2 Conv1D transpose list (c_attn / c_proj / c_fc) was never
        # read and matches none of the module names used here; presumably no
        # transpose is needed — confirm when enabling the copy.

        return model_hf

# model: fetch the pretrained checkpoint, then move it onto the selected device
model = Llama.from_pretrained("meta-llama/Llama-2-7b-hf")
model = model.to(device)  # Module.to returns the same module, so this rebinding is a no-op
print(f"loaded model to device {device}")

loading weights from pretrained llama: meta-llama/Llama-2-7b-hf


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.11s/it]


loaded model to device cuda


In [26]:
# inference loop

# tokenize: encode one prompt, then tile it into a batch of B identical rows
B, T_MAX = 5, 30
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokens_1N = tokenizer("Hello, I'm a language model,", return_tensors="pt").input_ids # (1, N)
# Fixed: this used to read the undefined name `tokens_N`, which only worked
# when a stale variable was left in the kernel.
tokens_BT = tokens_1N.repeat(B, 1) # (B, T)
X_BT = tokens_BT.to(device)
print(X_BT)

# seed both the CPU and CUDA generators before generating
torch.manual_seed(1337)
torch.cuda.manual_seed(1337)

# generate up to T_MAX tokens per row (prompt included)
generate_ids = model.generate(X_BT, max_length=T_MAX)
print(generate_ids)

# TODO: hand-rolled decoding loop (not implemented yet)
# while X_BT.size(1) < T_MAX:
#     with torch.no_grad():
#         generate_ids = model.generate(X_BT, max_length=T_MAX)
#         print(generate_ids)
        # logits, probs, topk_probs
        # tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        # X_BT = torch.cat((X_BT, generate_ids), dim=1)

        # concat

tensor([[    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892]],
       device='cuda:0')
tensor([[    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892,
           322,   306, 29915, 29885,  1244,   304,  1371,   366,   411,   596,
         10432,  3271,  1287, 29889,    13, 12024, 29915, 29879,   679,  4687],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892,
           322,   306, 29915, 29885,  1244,   304,  1371,   366,   411,   596,
          2060, 29889,   306, 29915, 29885,   263, 10257,   297,  5164,  4235],
        [    1, 15043, 29892,   306, 29915, 29885,   263,  4086,  1904, 29892,
           322,   30