In [1]:
import torch
import math


#Norm层
class LlamaRMSNorm(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(1024))

    def forward(self, x):
        #[4, 125, 1024] -> [4, 125, 1]
        var = x.pow(2).mean(2, keepdim=True)

        #差不多相当于x除以自身的绝对值的均值,相当于一种缩放
        #计算结果的均值总是在-1到1之间
        #[4, 125, 1024] * [4, 125, 1] -> [4, 125, 1024]
        x = x * (var + 1e-5).rsqrt()

        #[1024] * [4, 125, 1024] -> [4, 125, 1024]
        return self.weight * x


LlamaRMSNorm()(torch.randn(4, 125, 1024)).shape

torch.Size([4, 125, 1024])

In [2]:
#计算结果是常量,有必要的话可以保存起来节省计算资源
@torch.no_grad()
def llama_rotary_embedding(lens):
    #[0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.3125, 0.3750, 0.4375, 0.5000, 0.5625, 0.6250, 0.6875, 0.7500, 0.8125, 0.8750, 0.9375]
    inv_freq = torch.arange(0, 32, 2) / 32

    #[1.0000e+00, 4.4037e-01, 1.9392e-01, 8.5397e-02, 3.7606e-02, 1.6560e-02, 7.2927e-03, 3.2114e-03, 1.4142e-03, 6.2277e-04, 2.7425e-04, 1.2077e-04, 5.3183e-05, 2.3420e-05, 1.0313e-05, 4.5417e-06]
    inv_freq = 1.0 / (50_0000.0**inv_freq)

    #[16] -> [1, 16, 1]
    inv_freq = inv_freq.reshape(1, 16, 1)

    #[1, 1, 16]
    position_ids = torch.arange(lens).reshape(1, 1, -1).float()

    #[1, 16, 1] * [1, 16, 1] -> [1, 16, 16]
    freqs = inv_freq.matmul(position_ids).transpose(1, 2)

    #[1, 16, 16+16] -> [1, 16, 32]
    emb = torch.cat((freqs, freqs), 2)

    return emb.cos(), emb.sin()


cos, sin = llama_rotary_embedding(16)

cos.shape, sin.shape

(torch.Size([1, 16, 32]), torch.Size([1, 16, 32]))

In [3]:
#简单线性层
class LlamaMLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.gate_proj = torch.nn.Linear(1024, 14336, bias=False)
        self.up_proj = torch.nn.Linear(1024, 14336, bias=False)
        self.down_proj = torch.nn.Linear(14336, 1024, bias=False)
        self.act_fn = torch.nn.SiLU()

    def forward(self, x):
        #[4, 125, 1024] -> [4, 125, 14336]
        left = self.act_fn(self.gate_proj(x))

        #[4, 125, 1024] -> [4, 125, 14336]
        right = self.up_proj(x)

        #[4, 125, 14336] -> [4, 125, 1024]
        return self.down_proj(left * right)


LlamaMLP()(torch.randn(4, 125, 1024)).shape

torch.Size([4, 125, 1024])

In [4]:
def apply_rotary_pos_emb(x, cos, sin):
    #x -> [4, 32, 125, 32]
    #sin -> [1, 125, 32]
    #cos -> [1, 125, 32]

    def rotate_half(x):
        #x -> [4, 32, 125, 32]

        #[4, 32, 125, 32] -> [4, 32, 125, 16]
        left = x[..., :16]
        right = -x[..., 16:]

        #[4, 32, 125, 16+16] -> [4, 32, 125, 32]
        return torch.cat((right, left), -1)

    #[1, 125, 32] -> [1, 1, 125, 32]
    cos = cos.unsqueeze(1)
    #[1, 125, 32] -> [1, 1, 125, 32]
    sin = sin.unsqueeze(1)

    #[4, 32, 125, 32] * [1, 1, 125, 32] + [4, 32, 125, 32] * [1, 1, 125, 32] -> [4, 32, 125, 32]
    x = (x * cos) + (rotate_half(x) * sin)

    return x


input = {
    'x': torch.randn(4, 32, 125, 32),
    'sin': torch.randn(1, 125, 32),
    'cos': torch.randn(1, 125, 32)
}
apply_rotary_pos_emb(**input).shape

torch.Size([4, 32, 125, 32])

In [5]:
def repeat_kv(x):
    shape = list(x.shape)
    shape[1] *= 4
    #[4, 2, 125, 32] -> [4, 2, 4, 125, 32] -> [4, 8, 125, 32]
    return x.unsqueeze(2).repeat(1, 1, 4, 1, 1).reshape(shape)


repeat_kv(torch.randn(4, 2, 125, 32)).shape

torch.Size([4, 8, 125, 32])

In [6]:
#根据attention_mask获取注意力遮罩
#遮罩值为0表示保留,min_value表示丢弃
#遮罩的用法是和注意力矩阵相加后再求softmax
def get_causal_mask(attention_mask):
    # attention_mask -> [4, 125]

    b, lens = attention_mask.shape
    min_value = -1e15

    #上三角矩阵,对角线以上为min_value,对角线以下为0,对角线为0
    #[4, 1, 125, 125]
    causal_mask = torch.full((lens, lens), min_value).triu(diagonal=1)
    causal_mask = causal_mask.reshape(1, 1, lens, lens).repeat(b, 1, 1, 1)
    causal_mask = causal_mask.to(attention_mask.device)

    # 是pad的位置填充为min_value
    # [4, 125] -> [4, 1, 1, 125]
    mask = attention_mask.reshape(b, 1, 1, lens) == 0
    # [4, 1, 125, 125]
    causal_mask = causal_mask.masked_fill(mask, min_value)

    return causal_mask


get_causal_mask(torch.ones(4, 125).long()).shape

torch.Size([4, 1, 125, 125])

In [7]:
#注意力层
class LlamaAttention(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(1024, 1024, bias=False)
        self.k_proj = torch.nn.Linear(1024, 256, bias=False)
        self.v_proj = torch.nn.Linear(1024, 256, bias=False)
        self.o_proj = torch.nn.Linear(1024, 1024, bias=False)

    def forward(self, hidden_states, attention_mask):
        # hidden_states -> [4, 125, 1024]
        # attention_mask -> [4, 125]

        b, lens, _ = hidden_states.shape

        #线性投影,并拆分成多头注意力
        # [4, 125, 1024] -> [4, 125, 1024] -> [4, 125, 32, 32] -> [4, 32, 125, 32]
        q = self.q_proj(hidden_states).reshape(b, lens, 32, 32).transpose(1, 2)
        # [4, 125, 1024] -> [4, 125, 256] -> [4, 125, 8, 32] -> [4, 8, 125, 32]
        k = self.k_proj(hidden_states).reshape(b, lens, 8, 32).transpose(1, 2)
        # [4, 125, 1024] -> [4, 125, 256] -> [4, 125, 8, 32] -> [4, 8, 125, 32]
        v = self.v_proj(hidden_states).reshape(b, lens, 8, 32).transpose(1, 2)

        #计算位置编码
        # [1, 125, 32],[1, 125, 32]
        cos, sin = llama_rotary_embedding(lens)
        cos, sin = cos.to(hidden_states.device), sin.to(hidden_states.device)

        #在q,k上应用位置编码
        #[4, 32, 125, 32] -> [4, 32, 125, 32]
        q = apply_rotary_pos_emb(q, cos, sin)
        #[4, 8, 125, 32] -> [4, 8, 125, 32]
        k = apply_rotary_pos_emb(k, cos, sin)

        #k,v复制4分
        # [4, 8, 125, 32] -> [4, 32, 125, 32]
        k = repeat_kv(k)
        # [4, 8, 125, 32] -> [4, 32, 125, 32]
        v = repeat_kv(v)

        #q,k,v连乘,计算注意力
        # [4, 32, 125, 32] * [4, 32, 32, 125] -> [4, 32, 125, 125]
        attn = q.matmul(k.transpose(2, 3)) / math.sqrt(32)

        #根据attention_mask获得注意力遮罩
        #[4, 125] -> [4, 1, 125, 125]
        attention_mask = get_causal_mask(attention_mask)

        #应用注意力遮罩
        # [4, 32, 125, 125] + [4, 1, 125, 125] -> [4, 32, 125, 125]
        attn = (attn + attention_mask).softmax(3)

        #q,k,v连乘,计算注意力
        # [4, 32, 125, 125] * [4, 32, 125, 32] -> [4, 32, 125, 32]
        attn = attn.matmul(v)

        #合并多头注意力
        # [4, 32, 125, 32] -> [4, 125, 32, 32] -> [4, 125, 1024]
        attn = attn.transpose(1, 2).reshape(b, lens, 1024)

        #线性输出
        # [4, 125, 1024] -> [4, 125, 1024]
        attn = self.o_proj(attn)

        return attn


input = {
    'hidden_states': torch.randn(4, 125, 1024),
    'attention_mask': torch.ones(4, 125)
}
LlamaAttention()(**input).shape

torch.Size([4, 125, 1024])

In [8]:
class LlamaDecoderLayer(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.self_attn = LlamaAttention()
        self.mlp = LlamaMLP()
        self.input_layernorm = LlamaRMSNorm()
        self.post_attention_layernorm = LlamaRMSNorm()

    def forward(self, hidden_states, attention_mask):
        #hidden_states -> [4, 125, 1024]
        #attention_mask -> [4, 125]

        res = hidden_states

        #norm
        #[4, 125, 1024] -> [4, 125, 1024]
        hidden_states = self.input_layernorm(hidden_states)

        #计算注意力,短接
        #[4, 125, 1024],[4, 125] + [4, 125, 1024] -> [4, 125, 1024]
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       attention_mask=attention_mask) + res

        res = hidden_states

        #norm
        #[4, 125, 1024] -> [4, 125, 1024]
        hidden_states = self.post_attention_layernorm(hidden_states)

        #线性计算,短接
        #[4, 125, 1024] + [4, 125, 1024] -> [4, 125, 1024]
        hidden_states = self.mlp(hidden_states) + res

        return hidden_states


input = {
    'hidden_states': torch.randn(4, 125, 1024),
    'attention_mask': torch.ones(4, 125).long()
}
LlamaDecoderLayer()(**input).shape

torch.Size([4, 125, 1024])

In [9]:
class LlamaModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(128256, 1024, None)
        self.layers = torch.nn.ModuleList(
            [LlamaDecoderLayer() for _ in range(4)])
        self.norm = LlamaRMSNorm()

    def forward(self, input_ids, attention_mask):
        #input_ids -> [4, 125]
        #attention_mask -> [4, 125]

        #编码
        #[4, 125] -> [4, 125, 1024]
        hidden_states = self.embed_tokens(input_ids)

        #n层计算
        for layer in self.layers:
            #[4, 125, 1024] -> [4, 125, 1024]
            hidden_states = layer(hidden_states, attention_mask=attention_mask)

        #norm
        #[4, 125, 1024] -> [4, 125, 1024]
        hidden_states = self.norm(hidden_states)

        return hidden_states


input = {
    'input_ids': torch.randint(100, 50000, [4, 125]),
    'attention_mask': torch.ones(4, 125).long(),
}

input['attention_mask'][:, 120:] = 0

LlamaModel()(**input).shape

torch.Size([4, 125, 1024])

In [10]:
class LlamaForCausalLM(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = LlamaModel()
        self.lm_head = torch.nn.Linear(1024, 128256, bias=False)

    def forward(self, input_ids, attention_mask, labels=None):
        #input_ids -> [4, 125]
        #attention_mask -> [4, 125]
        #labels -> [4, 125]

        #[4, 125] -> [4, 125, 1024]
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask)

        #[4, 125, 1024] -> [4, 125, 128256]
        logits = self.lm_head(logits)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1].reshape(-1, 128256)
            shift_labels = labels[:, 1:].reshape(-1)
            loss = torch.nn.functional.cross_entropy(shift_logits,
                                                     shift_labels)

        return loss, logits


input = {
    'input_ids': torch.randint(100, 50000, [4, 125]),
    'attention_mask': torch.ones(4, 125).long(),
    'labels': torch.randint(100, 50000, [4, 125]),
}

input['attention_mask'][:, 120:] = 0

loss, logits = LlamaForCausalLM()(**input)

loss, logits.shape

(tensor(11.9310, grad_fn=<NllLossBackward0>), torch.Size([4, 125, 128256]))

In [11]:
# from transformers import LlamaConfig, LlamaForCausalLM as LlamaForCausalLM_Original

# #测试是否和官方模型的计算输出一样
# config = "{'vocab_size': 128256, 'max_position_embeddings': 8192, 'hidden_size': 1024, 'intermediate_size': 14336, 'num_hidden_layers': 4, 'num_attention_heads': 32, 'num_key_value_heads': 8, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-05, 'pretraining_tp': 1, 'use_cache': False, 'rope_theta': 500000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'bfloat16', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['LlamaForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 128000, 'pad_token_id': None, 'eos_token_id': 128001, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'transformers_version': '4.38.2', 'model_type': 'llama'}"
# config = LlamaConfig.from_dict(eval(config))

# model_actor1 = LlamaForCausalLM_Original(config)
# model_actor2 = LlamaForCausalLM()

# model_actor2.load_state_dict(model_actor1.state_dict())

# input = {
#     'input_ids': torch.randint(100, 50000, [4, 125]),
#     'attention_mask': torch.ones(4, 125).long(),
#     'labels': torch.randint(100, 50000, [4, 125])
# }
# input['attention_mask'][:, 120:] = 0

# out = model_actor1(**input)
# loss, logits = model_actor2(**input)

# print(out.loss, out.logits.shape)
# print(loss, logits.shape)

# out.loss == loss, (out.logits == logits).all()

tensor(11.9328, grad_fn=<NllLossBackward0>) torch.Size([4, 125, 128256])
tensor(11.9328, grad_fn=<NllLossBackward0>) torch.Size([4, 125, 128256])


(tensor(True), tensor(True))