In [1]:
# Model dimensions. Vocabulary size and padding id are kept as-is; the width
# (2304) and depth (26) are each divided by 4 (integer division) to keep the
# model small for experimentation.
vocab_size, padding_idx = 256000, 0
hidden_size, num_hidden_layers = (full // 4 for full in (2304, 26))

hidden_size, num_hidden_layers

(576, 6)

In [2]:
import torch

def repeat_kv(x, n_rep=2):
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Used for grouped-query attention, where there are fewer key/value heads
    than query heads. Copies of a head are placed next to each other: kv head
    ``j`` becomes heads ``j * n_rep`` .. ``j * n_rep + n_rep - 1``, which is the
    same result as ``x.repeat_interleave(n_rep, dim=1)``.

    Args:
        x: tensor of shape [batch, num_kv_heads, seq_len, head_dim],
            e.g. [4, 4, 125, 256].
        n_rep: copies per kv head (default 2: 4 kv heads -> 8 heads).

    Returns:
        Tensor of shape [batch, num_kv_heads * n_rep, seq_len, head_dim].
        When ``n_rep == 1`` the input tensor itself is returned.

    Raises:
        ValueError: if ``n_rep < 1``.
    """
    if n_rep < 1:
        raise ValueError(f'n_rep must be >= 1, got {n_rep}')
    if n_rep == 1:
        return x

    b, kv_heads, lens, head_dim = x.shape

    #[b, kv_heads, lens, head_dim] -> [b, kv_heads, n_rep, lens, head_dim]
    #expand creates a view; the copy happens in reshape below
    x = x.unsqueeze(2).expand(b, kv_heads, n_rep, lens, head_dim)

    #[b, kv_heads, n_rep, lens, head_dim] -> [b, kv_heads * n_rep, lens, head_dim]
    return x.reshape(b, kv_heads * n_rep, lens, head_dim)


repeat_kv(torch.randn(4, 4, 125, 256)).shape

torch.Size([4, 8, 125, 256])

In [3]:
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key vectors by their rotary position angles.

    Args:
        q: query states, e.g. [4, 8, 125, 256].
        k: key states, e.g. [4, 4, 125, 256].
        cos, sin: rotary tables, e.g. [1, 125, 256]. A new axis is inserted
            at dim 1 so they broadcast over the head axis of q and k.

    Returns:
        Tuple ``(q_rot, k_rot)`` with the same shapes as ``q`` and ``k``.
    """
    #[1, 125, 256] -> [1, 1, 125, 256]
    cos_b, sin_b = cos[:, None], sin[:, None]

    def _rotate(t):
        #split the last axis into two halves (the first half gets floor(d / 2)),
        #swap them and negate the half that moves to the front: (a, b) -> (-b, a)
        d = t.shape[-1]
        first, second = t.split([d // 2, d - d // 2], dim=-1)
        return torch.cat([second.neg(), first], dim=-1)

    #weighted sum with cos and sin injects position information into q and k
    return tuple(t * cos_b + _rotate(t) * sin_b for t in (q, k))


out = apply_rotary_pos_emb(torch.randn(4, 8, 125, 256),
                           torch.randn(4, 4, 125, 256),
                           torch.randn(1, 125, 256), torch.randn(1, 125, 256))

out[0].shape, out[1].shape

(torch.Size([4, 8, 125, 256]), torch.Size([4, 4, 125, 256]))

In [4]:
class Gemma2RotaryEmbedding(torch.nn.Module):
    """Cos/sin tables for rotary position embedding (RoPE).

    Calling the module with a sequence length ``lens`` returns ``(cos, sin)``,
    each of shape [1, lens, dim], in the dtype of the ``inv_freq`` buffer.

    Args:
        dim: head dimension the tables are built for (default 256, the head
            size used by ``Gemma2Attention``).
        base: RoPE base; inverse frequencies are base**(-2i / dim)
            (default 10000).
    """

    def __init__(self, dim=256, base=1_0000.0):
        super().__init__()
        self.dim = dim
        self.base = base

        #[dim // 2] -> [1, dim // 2, 1]; for dim=256, base=10000:
        #tensor([1.0000e+00, 9.3057e-01, 8.6596e-01, ..., 1.2409e-04, 1.1548e-04, 1.0746e-04])
        self.register_buffer('inv_freq',
                             tensor=self._compute_inv_freq(),
                             persistent=False)

    def _compute_inv_freq(self, device=None):
        """Return base**(-2i / dim) for i in [0, dim // 2), shaped [1, dim // 2, 1]."""
        exponent = torch.arange(0, self.dim, 2, device=device) / self.dim
        return (1.0 / (self.base**exponent)).reshape(1, -1, 1)

    @torch.no_grad()
    def forward(self, lens):
        out_dtype = self.inv_freq.dtype
        inv_freq = self.inv_freq
        if torch.finfo(out_dtype).bits < 32:
            #e.g. after model.to(torch.bfloat16) the buffer is half precision:
            #integer positions above 256 are not exactly representable in
            #bfloat16 and the frequencies themselves are rounded, which
            #corrupts the angles. Rebuild the frequencies in float32 and only
            #cast the final tables back to the buffer's dtype.
            inv_freq = self._compute_inv_freq(inv_freq.device).float()

        #[lens] -> [1, 1, lens]
        position_ids = torch.arange(lens,
                                    device=inv_freq.device,
                                    dtype=inv_freq.dtype).reshape(1, 1, -1)

        #[1, dim // 2, 1] * [1, 1, lens] -> [1, dim // 2, lens] -> [1, lens, dim // 2]
        freqs = inv_freq.matmul(position_ids).transpose(1, 2)

        #[1, lens, dim // 2] -> [1, lens, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        #[1, lens, dim], [1, lens, dim]
        return emb.cos().to(out_dtype), emb.sin().to(out_dtype)


out = Gemma2RotaryEmbedding()(125)

out[0].shape, out[1].shape

(torch.Size([1, 125, 256]), torch.Size([1, 125, 256]))

In [5]:
class Gemma2Attention(torch.nn.Module):
    """Grouped-query self-attention with rotary embeddings and logit soft-capping.

    Uses 8 query heads and 4 key/value heads of size 256; each kv head is
    shared by 2 query heads via ``repeat_kv``. Attention logits are scaled by
    256**-0.5, soft-capped to (-50, 50) with ``50 * tanh(x / 50)``, and the
    additive ``attention_mask`` is added before the softmax.
    """

    def __init__(self):
        super().__init__()

        self.q_proj = torch.nn.Linear(hidden_size, 2048, bias=False)  # 8 heads * 256
        self.k_proj = torch.nn.Linear(hidden_size, 1024, bias=False)  # 4 heads * 256
        self.v_proj = torch.nn.Linear(hidden_size, 1024, bias=False)  # 4 heads * 256
        self.o_proj = torch.nn.Linear(2048, hidden_size, bias=False)
        self.rotary_emb = Gemma2RotaryEmbedding()

    def forward(self, hidden_states, attention_mask):
        """Attend over the sequence.

        Args:
            hidden_states: [b, lens, hidden_size], e.g. [4, 125, hidden_size].
            attention_mask: additive mask broadcastable to [b, 8, lens, lens],
                e.g. the [4, 1, 125, 125] output of ``get_mask`` (0 = attend,
                very negative = blocked).

        Returns:
            Tensor of shape [b, lens, hidden_size].
        """
        b, lens, _ = hidden_states.size()

        #[b, lens, hidden_size] -> [b, lens, 2048]
        query_states = self.q_proj(hidden_states)
        #[b, lens, hidden_size] -> [b, lens, 1024]
        key_states = self.k_proj(hidden_states)
        #[b, lens, hidden_size] -> [b, lens, 1024]
        value_states = self.v_proj(hidden_states)

        #[b, lens, 2048] -> [b, lens, 8, 256] -> [b, 8, lens, 256]
        query_states = query_states.reshape(b, lens, 8, 256).transpose(1, 2)
        #[b, lens, 1024] -> [b, lens, 4, 256] -> [b, 4, lens, 256]
        key_states = key_states.reshape(b, lens, 4, 256).transpose(1, 2)
        #[b, lens, 1024] -> [b, lens, 4, 256] -> [b, 4, lens, 256]
        value_states = value_states.reshape(b, lens, 4, 256).transpose(1, 2)

        #[1, lens, 256], [1, lens, 256]
        cos, sin = self.rotary_emb(lens)

        #shapes unchanged; injects position information into q and k
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin)

        #[b, 4, lens, 256] -> [b, 8, lens, 256]
        key_states = repeat_kv(key_states)
        #[b, 4, lens, 256] -> [b, 8, lens, 256]
        value_states = repeat_kv(value_states)

        #[b, 8, lens, 256] @ [b, 8, 256, lens] -> [b, 8, lens, lens]
        atten = query_states.matmul(key_states.transpose(2, 3))
        atten = atten * 256**-0.5

        #shape unchanged: soft-cap the logits to (-50, 50), then add the mask
        atten = (atten / 50.0).tanh() * 50.0
        atten = atten + attention_mask

        #softmax in float32 and cast back: with bfloat16/float16 activations a
        #half-precision softmax loses accuracy; for float32 this is a no-op
        atten = atten.softmax(dim=-1, dtype=torch.float32).to(value_states.dtype)

        #[b, 8, lens, lens] @ [b, 8, lens, 256] -> [b, 8, lens, 256]
        atten = atten.matmul(value_states)

        #[b, 8, lens, 256] -> [b, lens, 8, 256] -> [b, lens, 2048]
        atten = atten.transpose(1, 2).reshape(b, lens, -1)

        #[b, lens, 2048] -> [b, lens, hidden_size]
        atten = self.o_proj(atten)

        return atten


Gemma2Attention()(torch.randn(4, 125, hidden_size),
                  torch.randn(4, 1, 125, 125)).shape

torch.Size([4, 125, 576])

In [6]:
class Gemma2MLP(torch.nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))``.

    ``gate_proj`` and ``up_proj`` both map hidden_size -> 9216. The
    tanh-approximated GELU of the gate branch multiplies the up branch
    element-wise, and ``down_proj`` maps the product back to hidden_size.
    None of the projections has a bias.
    """

    def __init__(self):
        super().__init__()
        inner = 9216
        self.gate_proj = torch.nn.Linear(hidden_size, inner, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, inner, bias=False)
        self.down_proj = torch.nn.Linear(inner, hidden_size, bias=False)

    def forward(self, x):
        #x: [4, 125, hidden_size] -> gate branch [4, 125, 9216]
        gated = torch.nn.functional.gelu(self.gate_proj(x), approximate='tanh')
        #gate * up ([4, 125, 9216]) -> [4, 125, hidden_size]
        return self.down_proj(gated * self.up_proj(x))


Gemma2MLP()(torch.randn(4, 125, hidden_size)).shape

torch.Size([4, 125, 576])

In [7]:
class Gemma2RMSNorm(torch.nn.Module):
    """RMS normalization over the last axis with a ``(1 + weight)`` scale.

    y = x / sqrt(mean(x**2, dim=-1) + eps) * (1 + weight)

    ``weight`` is initialised to zero, so a freshly built layer only
    normalizes. The statistics are computed in at least float32, and the result
    is cast back to the input dtype. This keeps bfloat16/float16 activations
    from losing precision in the mean of squares.

    Args:
        dim: size of the normalized (last) axis; defaults to the
            notebook-level ``hidden_size``.
        eps: added to the mean square before the reciprocal square root
            (default 1e-6).
    """

    def __init__(self, dim=None, eps=1e-6):
        super().__init__()
        if dim is None:
            dim = hidden_size
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        #x -> [4, 125, hidden_size]; the output has the same shape and dtype
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)

        #1 / RMS(x): like the reciprocal of the L2 norm, but with a mean
        #instead of a sum over the last axis
        inv_rms = (x_c.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()

        #normalize, then apply the learned per-channel scale (1 + weight)
        out = x_c * inv_rms * (1.0 + self.weight.float())

        return out.to(x.dtype)


# Smoke test: normalization preserves the input shape, so this should display [4, 125, hidden_size].
Gemma2RMSNorm()(torch.randn(4, 125, hidden_size)).shape

torch.Size([4, 125, 576])

In [8]:
class Gemma2DecoderLayer(torch.nn.Module):
    """One decoder block: attention then MLP, each normalized before and after.

    Each sub-layer is applied as ``x = x + post_norm(sublayer(pre_norm(x)))``.
    The attention sub-layer uses ``input_layernorm`` / ``post_attention_layernorm``.
    The MLP sub-layer uses ``pre_feedforward_layernorm`` / ``post_feedforward_layernorm``.
    """

    def __init__(self):
        super().__init__()
        self.self_attn = Gemma2Attention()
        self.mlp = Gemma2MLP()
        self.input_layernorm = Gemma2RMSNorm()
        self.post_attention_layernorm = Gemma2RMSNorm()
        self.pre_feedforward_layernorm = Gemma2RMSNorm()
        self.post_feedforward_layernorm = Gemma2RMSNorm()

    def forward(self, hidden_states, attention_mask):
        """hidden_states: [4, 125, hidden_size]; attention_mask: [4, 1, 125, 125].

        Returns a tensor with the same shape as ``hidden_states``.
        """
        #attention sub-layer with residual connection (shape unchanged)
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask)
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        #feed-forward sub-layer with residual connection (shape unchanged)
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        return hidden_states + self.post_feedforward_layernorm(mlp_out)


Gemma2DecoderLayer()(torch.randn(4, 125, hidden_size),
                     torch.randn(4, 1, 125, 125)).shape

torch.Size([4, 125, 576])

In [9]:
def get_mask(attention_mask, dtype, device):
    """Build the additive causal + padding mask consumed by the attention layers.

    Args:
        attention_mask: [batch, lens] tensor, non-zero for real tokens and 0
            for padding (e.g. [4, 125]).
        dtype: floating dtype of the returned mask (match the hidden states).
        device: device of the returned mask.

    Returns:
        [batch, 1, lens, lens] tensor. Entry (i, j) is 0 when query i may
        attend to key j, i.e. j <= i and key j is not padding. Otherwise it is
        ``torch.finfo(dtype).min``. The mask is meant to be added to the
        attention logits.
    """
    b, lens = attention_mask.shape
    min_value = torch.finfo(dtype).min

    #causal part: 0 on and below the diagonal, min_value strictly above it.
    #Applied for every length, including lens == 1, where the single token
    #must still be allowed to attend to itself (the mask is then all zeros).
    mask = torch.full((lens, lens),
                      fill_value=min_value,
                      device=device,
                      dtype=dtype)
    mask = torch.triu(mask, diagonal=1)

    #[lens, lens] -> [b, 1, lens, lens]
    mask = mask.reshape(1, 1, lens, lens).expand(b, 1, lens, lens)

    #padding part: no query may attend to a padded key column
    pad_mask = attention_mask.to(device).reshape(b, 1, 1, lens) == 0
    mask = mask.masked_fill(pad_mask, min_value)

    return mask


get_mask(torch.randint(0, 2, (4, 125)), torch.float32, 'cpu').shape

torch.Size([4, 1, 125, 125])

In [10]:
class Gemma2Model(torch.nn.Module):
    """Token embedding, ``num_hidden_layers`` decoder layers and a final RMS norm.

    forward(input_ids [b, lens], attention_mask [b, lens]) -> [b, lens, hidden_size]
    """

    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(num_embeddings=vocab_size,
                                               embedding_dim=hidden_size,
                                               padding_idx=padding_idx)
        self.layers = torch.nn.ModuleList(
            Gemma2DecoderLayer() for _ in range(num_hidden_layers))
        self.norm = Gemma2RMSNorm()

    def forward(self, input_ids, attention_mask):
        #[b, lens] -> [b, lens, hidden_size]
        embeds = self.embed_tokens(input_ids)

        #[b, lens] padding mask -> [b, 1, lens, lens] additive causal mask
        mask_4d = get_mask(attention_mask, embeds.dtype, embeds.device)

        #embeddings are scaled by sqrt(hidden_size) before the first layer
        hidden_states = embeds * hidden_size**0.5

        #each layer keeps the shape [b, lens, hidden_size]
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states, attention_mask=mask_4d)

        return self.norm(hidden_states)


Gemma2Model()(torch.randint(100, 10000, (4, 125)),
              torch.randint(0, 2, (4, 125))).shape

torch.Size([4, 125, 576])

In [11]:
class Gemma2ForCausalLM(torch.nn.Module):
    """``Gemma2Model`` plus a vocabulary projection with soft-capped logits.

    ``forward`` returns ``(loss, logits)``:
      * logits: [b, lens, vocab_size], float32, soft-capped to (-30, 30) with
        ``30 * tanh(x / 30)``;
      * loss: next-token cross-entropy when ``labels`` is given, else None.
        Labels equal to -100 (the ``cross_entropy`` default ``ignore_index``)
        are ignored.
    """

    def __init__(self):
        super().__init__()
        self.model = Gemma2Model()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask, labels=None):
        #input_ids, attention_mask, labels: [b, lens]
        hidden_states = self.model(input_ids=input_ids,
                                   attention_mask=attention_mask)

        #[b, lens, hidden_size] -> [b, lens, vocab_size], soft-capped, float32
        cap = 30.0
        logits = (self.lm_head(hidden_states) / cap).tanh() * cap
        logits = logits.float()

        if labels is None:
            return None, logits

        #position t predicts token t + 1: drop the last logit and the first label
        #[b, lens - 1, vocab_size] -> [b * (lens - 1), vocab_size] vs [b * (lens - 1)]
        loss = torch.nn.functional.cross_entropy(
            logits[:, :-1].flatten(end_dim=1), labels[..., 1:].flatten())

        return loss, logits


out = Gemma2ForCausalLM()(torch.randint(100, 10000, (4, 125)),
                          torch.randint(0, 2, (4, 125)),
                          torch.randint(100, 10000, (4, 125)))

out[0], out[1].shape

(tensor(12.6092, grad_fn=<NllLossBackward0>), torch.Size([4, 125, 256000]))

In [12]:
# from transformers import Gemma2Config, Gemma2ForCausalLM as Gemma2ForCausalLM_Original

# config = "{'vocab_size': 256000, 'max_position_embeddings': 8192, 'hidden_size': hidden_size, 'intermediate_size': 9216, 'num_hidden_layers': 26, 'num_attention_heads': 8, 'head_dim': 256, 'num_key_value_heads': 4, 'hidden_activation': 'gelu_pytorch_tanh', 'initializer_range': 0.02, 'rms_norm_eps': 1e-06, 'use_cache': True, 'rope_theta': 10000.0, 'attention_bias': False, 'attention_dropout': 0.0, 'attn_logit_softcapping': 50.0, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'bfloat16', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['Gemma2ForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 2, 'pad_token_id': 0, 'eos_token_id': [1, 107], 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': 'google/gemma-2-2b-it', 'transformers_version': '4.43.3', 'cache_implementation': 'hybrid', 'hidden_act': 'gelu_pytorch_tanh', 'model_type': 'gemma2', 'final_logit_softcapping': 30.0, 'query_pre_attn_scalar': 256, 'sliding_window': 4096}"
# config = Gemma2Config.from_dict(eval(config))
# config.vocab_size = vocab_size
# config.padding_idx = padding_idx
# config.hidden_size = hidden_size
# config.num_hidden_layers = num_hidden_layers
# config.use_cache = False
# config.cache_implementation = None

# model = Gemma2ForCausalLM()
# model_original = Gemma2ForCausalLM_Original(config)

# model.load_state_dict(model_original.state_dict())

# input = {
#     'input_ids': torch.randint(100, 10000, [4, 125]),
#     'attention_mask': torch.ones(4, 125).long(),
#     'labels': torch.randint(100, 10000, [4, 125])
# }
# input['attention_mask'][:, 120:] = 0

# with torch.no_grad():
#     loss, logits = model(**input)
#     out_original = model_original(**input)

# loss == out_original.loss, (logits == out_original.logits).all()

(tensor(True), tensor(True))

In [13]:
# Llama-named aliases for the Gemma-2 classes; presumably so code written
# against LlamaModel / LlamaForCausalLM can use this implementation unchanged
# (confirm with the notebooks/scripts that import these names).
LlamaModel, LlamaForCausalLM = Gemma2Model, Gemma2ForCausalLM