In [1]:
import torch


class MLP(torch.nn.Module):
    """Gated feed-forward block: ``out(gelu_tanh(left(x)) * right(x))``.

    Args:
        dim: size of the last axis of the input and of the output
            (default 1024, the width used by the rest of this notebook).
        hidden_dim: width of the two parallel projections that are
            multiplied together (default 2048).
    """

    def __init__(self, dim=1024, hidden_dim=2048):
        super().__init__()
        self.left = torch.nn.Linear(dim, hidden_dim, bias=False)
        self.right = torch.nn.Linear(dim, hidden_dim, bias=False)
        self.out = torch.nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        """Apply the gated MLP.

        Args:
            x: tensor whose last axis has size ``dim``,
                e.g. [b, lens, dim].

        Returns:
            Tensor with the same shape as ``x``.
        """
        #[b, lens, dim] -> [b, lens, hidden_dim], tanh-approximated GELU gate
        gate = torch.nn.functional.gelu(self.left(x), approximate='tanh')

        #[b, lens, dim] -> [b, lens, hidden_dim], linear branch
        value = self.right(x)

        #[b, lens, hidden_dim] -> [b, lens, dim]
        return self.out(gate * value)


# Smoke test: run a freshly initialised MLP on a random [2, 15, 1024] batch and display the output shape.
MLP()(torch.randn(2, 15, 1024)).shape

torch.Size([2, 15, 1024])

In [2]:
class Norm(torch.nn.Module):
    """RMS normalisation over the last axis with a ``(1 + weight)`` scale.

    ``weight`` starts at zero, so a freshly built layer scales by exactly 1.

    Args:
        dim: size of the last axis of the inputs.
        eps: constant added to the mean square before the reciprocal
            square root (default 1e-6).
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Normalise ``x`` along its last axis; the shape is unchanged.

        Two layouts are used in this notebook:
            x -> [b, lens, dim]
            x -> [b, heads, lens, dim]

        Half-precision inputs (float16 / bfloat16) are normalised in float32:
        squaring in float16 overflows for |x| above roughly 255, which would
        make the reciprocal square root zero and wipe out the activations.
        The result is cast back to the dtype that plain type promotion of
        ``x`` and ``weight`` would produce, so output dtypes are unchanged.
        """
        weight = self.weight
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        if low_precision:
            out_dtype = torch.result_type(x, weight)
            x, weight = x.float(), weight.float()

        x = x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
        x = x * (1.0 + weight)

        return x.to(out_dtype) if low_precision else x


# Smoke test: normalise a random [2, 15, 64] batch and display the (unchanged) shape.
Norm(64)(torch.randn(2, 15, 64)).shape

torch.Size([2, 15, 64])

In [3]:
class Atten(torch.nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    Queries are projected to 2048 features and keys/values to 1024; all are
    split into heads of size 256, so there are 8 query heads sharing 4
    key/value heads. Queries and keys are normalised per head before the
    rotary embedding is applied.
    """

    def __init__(self):
        super().__init__()

        self.q = torch.nn.Linear(1024, 2048, bias=False)
        self.k = torch.nn.Linear(1024, 1024, bias=False)
        self.v = torch.nn.Linear(1024, 1024, bias=False)
        self.out = torch.nn.Linear(2048, 1024, bias=False)

        self.norm_q = Norm(dim=256)
        self.norm_k = Norm(dim=256)

        #[128] inverse rotary frequencies (base 1e4, head size 256)
        inv_freq = 1.0 / (1e4**(torch.arange(0, 256, 2).float() / 256))
        #[128] -> [1, 128, 1]
        inv_freq = inv_freq.reshape(1, 128, 1).float()
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, mask):
        """Run self-attention.

        Args:
            x: [b, lens, 1024] hidden states.
            mask: tensor that is added to the scaled attention scores and
                must broadcast to [b, 8, lens, lens]
                (e.g. [b, 1, lens, lens]).

        Returns:
            [b, lens, 1024] tensor.
        """
        b, lens = x.shape[0], x.shape[1]

        #[b, lens, 1024] -> [b, lens, 2048] -> [b, lens, 8, 256] -> [b, 8, lens, 256]
        q = self.q(x).reshape(b, lens, -1, 256).transpose(1, 2)
        #[b, lens, 1024] -> [b, lens, 1024] -> [b, lens, 4, 256] -> [b, 4, lens, 256]
        k = self.k(x).reshape(b, lens, -1, 256).transpose(1, 2)
        #[b, lens, 1024] -> [b, lens, 1024] -> [b, lens, 4, 256] -> [b, 4, lens, 256]
        v = self.v(x).reshape(b, lens, -1, 256).transpose(1, 2)

        q = self.norm_q(q)
        k = self.norm_k(k)

        with torch.no_grad():
            # The buffer is cast together with the module by .half()/.bfloat16(),
            # and matmul does not promote dtypes, so the angles are always
            # computed in float32.
            inv_freq = self.inv_freq.float()

            #[lens] -> [1, 1, lens]
            position_ids = torch.arange(lens, device=x.device)
            position_ids = position_ids.reshape(1, 1, -1).float()

            #[1, 128, 1] @ [1, 1, lens] -> [1, 128, lens] -> [1, lens, 128]
            emb = (inv_freq @ position_ids).transpose(1, 2)
            #[1, lens, 128] -> [1, lens, 256]
            emb = torch.cat((emb, emb), dim=2)

            #[1, 1, lens, 256] each; broadcasts over batch and heads
            cos, sin = emb.cos().unsqueeze(1), emb.sin().unsqueeze(1)

        def rotate(t):
            # (first_half, second_half) -> (-second_half, first_half)
            half = t.shape[3] // 2
            return torch.cat((-t[:, :, :, half:], t[:, :, :, :half]), dim=3)

        def apply_rope(t):
            # cast the float32 tables to the activation dtype (no-op in float32)
            return (t * cos.to(t.dtype)) + (rotate(t) * sin.to(t.dtype))

        #[b, 8, lens, 256]
        q = apply_rope(q)
        #[b, 4, lens, 256]
        k = apply_rope(k)

        # Each key/value head serves `groups` consecutive query heads
        # (2 with the projections defined above).
        groups = q.shape[1] // k.shape[1]
        #[b, 4, lens, 256] -> [b, 8, lens, 256]
        k = k.repeat_interleave(groups, dim=1)
        v = v.repeat_interleave(groups, dim=1)

        #[b, 8, lens, 256] @ [b, 8, 256, lens] -> [b, 8, lens, lens]
        scores = q @ k.transpose(2, 3)
        scores = scores * 256**-0.5 + mask

        # Softmax in float32, then back to the value dtype so the matmul
        # operands match even when the mask is float32 and the model is not.
        probs = scores.float().softmax(dim=3).to(v.dtype)
        #[b, 8, lens, lens] @ [b, 8, lens, 256] -> [b, 8, lens, 256]
        out = probs @ v

        #[b, 8, lens, 256] -> [b, lens, 8, 256] -> [b, lens, 2048]
        out = out.transpose(1, 2).reshape(b, lens, -1)
        #[b, lens, 2048] -> [b, lens, 1024]
        out = self.out(out)

        return out


# Smoke test: the mask here is random 0/1 integers of shape [2, 1, 15, 15], so this only checks the output shape.
Atten()(torch.randn(2, 15, 1024), torch.randint(0, 2, [2, 1, 15, 15])).shape

torch.Size([2, 15, 1024])

In [4]:
class Decoder(torch.nn.Module):
    """One transformer layer: an attention sub-block followed by an MLP sub-block.

    Each sub-block is normalised before and after, and its normalised output
    is added to the sub-block's input (residual connection).
    """

    def __init__(self):
        super().__init__()
        self.self_attn = Atten()
        self.mlp = MLP()
        self.norm1 = Norm(1024)
        self.norm2 = Norm(1024)
        self.norm3 = Norm(1024)
        self.norm4 = Norm(1024)

    def forward(self, x, mask):
        """Apply the layer; the shape of ``x`` is preserved.

        Args:
            x: [b, lens, 1024] hidden states.
            mask: [b, 1, lens, lens] mask forwarded to the attention module.
        """
        # attention sub-block: norm -> attention -> norm -> add input back
        attn_out = self.self_attn(self.norm1(x), mask)
        after_attn = self.norm2(attn_out) + x

        # MLP sub-block: norm -> MLP -> norm -> add input back
        mlp_out = self.mlp(self.norm3(after_attn))
        return self.norm4(mlp_out) + after_attn


# Smoke test: the mask here is random 0/1 integers of shape [2, 1, 15, 15], so this only checks the output shape.
Decoder()(torch.randn(2, 15, 1024), torch.randint(0, 2, [2, 1, 15, 15])).shape

torch.Size([2, 15, 1024])

In [5]:
def get_mask(attention_mask):
    """Build the additive attention mask from a padding mask.

    Args:
        attention_mask: [b, lens] tensor; entries equal to 0 mark key
            positions that must not be attended to.

    Returns:
        [b, 1, lens, lens] tensor on ``attention_mask``'s device holding
        -1e32 where the key lies after the query or is masked out by
        ``attention_mask``, and 0 everywhere else.
    """
    seq_len = attention_mask.shape[1]
    device = attention_mask.device

    #[lens, lens], True strictly above the diagonal (key index > query index)
    index = torch.arange(seq_len, device=device)
    future = index.unsqueeze(0) > index.unsqueeze(1)

    #[b, lens] -> [b, 1, 1, lens], True on masked-out key positions
    padded = (attention_mask == 0).unsqueeze(1).unsqueeze(1)

    #[1, 1, lens, lens] | [b, 1, 1, lens] -> [b, 1, lens, lens]
    blocked = future.unsqueeze(0).unsqueeze(0) | padded

    #[b, 1, lens, lens]
    mask = torch.zeros(blocked.shape, device=device)
    return mask.masked_fill(blocked, -1e32)


# Smoke test: build the mask for a random 0/1 padding mask of shape [2, 15] and display its shape.
get_mask(torch.randint(0, 2, [2, 15])).shape

torch.Size([2, 1, 15, 15])

In [6]:
class Gemma3Model(torch.nn.Module):
    """Decoder-only backbone: token embedding, 16 Decoder layers, final Norm.

    Args:
        vocab_size: number of rows in the embedding table.
        pad_token_id: passed to ``torch.nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, pad_token_id):
        super().__init__()

        self.embed = torch.nn.Embedding(vocab_size, 1024, pad_token_id)
        self.layers = torch.nn.ModuleList([Decoder() for _ in range(16)])
        self.norm = Norm(1024)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids.

        Args:
            input_ids: [b, lens] token ids.
            attention_mask: [b, lens] padding mask handed to ``get_mask``.

        Returns:
            [b, lens, 1024] normalised hidden states.
        """
        #[b, 1, lens, lens], shared by every layer
        mask = get_mask(attention_mask)

        #[b, lens] -> [b, lens, 1024], embeddings scaled by sqrt(1024)
        hidden = self.embed(input_ids) * 1024**0.5

        for layer in self.layers:
            hidden = layer(hidden, mask)

        return self.norm(hidden)


# Smoke test: vocabulary of 16 with pad id 0, random ids and a random 0/1 padding mask; displays the output shape.
Gemma3Model(16, 0)(torch.randint(0, 16, [2, 15]), torch.randint(0, 2,
                                                                [2, 15])).shape

torch.Size([2, 15, 1024])

In [7]:
class Gemma3Actor(torch.nn.Module):
    """Language-model head on top of Gemma3Model.

    Args:
        vocab_size: vocabulary size; also the width of the output logits.
        pad_token_id: forwarded to Gemma3Model.
    """

    def __init__(self, vocab_size, pad_token_id):
        super().__init__()
        self.model = Gemma3Model(vocab_size, pad_token_id)
        self.lm_head = torch.nn.Linear(1024, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits and, when labels are given, the next-token loss.

        Args:
            input_ids: [b, lens] token ids.
            attention_mask: [b, lens] padding mask.
            labels: optional [b, lens] target ids; entries equal to -100
                are ignored by the loss.

        Returns:
            ``(loss, logits)`` where ``logits`` is [b, lens, vocab_size] and
            ``loss`` is the mean cross-entropy, or None without labels.
        """
        #[b, lens] -> [b, lens, 1024]
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)
        #[b, lens, 1024] -> [b, lens, vocab_size]
        logits = self.lm_head(hidden)

        if labels is None:
            return None, logits

        # Shift targets one step left so position t is scored against token
        # t + 1; the final position gets -100 and is therefore ignored.
        targets = torch.nn.functional.pad(labels, (0, 1), value=-100)[..., 1:]

        loss = torch.nn.functional.cross_entropy(
            logits.flatten(end_dim=1),
            targets.flatten(),
            ignore_index=-100,
            reduction='mean',
        )

        return loss, logits


# Smoke test: random ids, random 0/1 padding mask and random labels; displays the (loss, logits) tuple.
Gemma3Actor(16, 0)(input_ids=torch.randint(0, 16, [2, 15]),
                   attention_mask=torch.randint(0, 2, [2, 15]),
                   labels=torch.randint(0, 16, [2, 15]))

(tensor(2.8809, grad_fn=<NllLossBackward0>),
 tensor([[[ 4.9931e-01,  2.6358e-01, -6.9108e-01, -8.2092e-01, -3.5185e-01,
           -1.1082e-01,  6.2080e-02, -7.7483e-01, -1.2577e-01,  3.1631e-01,
            2.6910e-02, -2.7915e-01, -2.5486e-02, -7.3420e-02, -1.3707e+00,
            1.0384e+00],
          [ 4.0725e-01,  8.3875e-02,  4.6034e-01, -3.0774e-01, -7.2191e-01,
           -6.1969e-01, -5.2242e-01, -1.5075e-01, -1.0387e+00,  8.4529e-01,
           -2.0267e-01,  3.5437e-01,  4.0266e-01,  4.3016e-02,  3.5819e-01,
            5.5118e-01],
          [ 1.0875e-01, -1.3081e+00,  9.5620e-01,  3.8039e-01, -3.7686e-01,
            4.5391e-01, -9.2286e-01,  4.3188e-01,  6.1821e-01, -3.4398e-01,
           -9.0507e-02,  3.6013e-01, -3.7015e-01, -3.0767e-01,  3.1152e-03,
           -5.1721e-01],
          [-2.6928e-01, -3.5677e-01, -1.1369e+00, -4.9205e-01, -1.2345e-02,
            6.1982e-01,  5.4395e-01, -6.5105e-01, -1.1177e+00, -7.6442e-01,
           -1.4860e-01,  2.3622e-01, -3.9530

In [8]:
def generate(model_actor, input_ids, pad_token_id, eos_token_id, max_length=32):
    """Greedy decoding: repeatedly append the arg-max next token.

    Args:
        model_actor: callable taking ``input_ids`` and ``attention_mask``
            keyword arguments and returning ``(loss, logits)`` with logits
            of shape [b, lens, vocab_size].
        input_ids: [b, lens] prompt token ids.
        pad_token_id: positions equal to this id get attention mask 0.
        eos_token_id: decoding stops once every row contains this id
            anywhere in the sequence (prompt included).
        max_length: decoding stops once the sequence length reaches this
            value (default 32).

    Returns:
        [b, new_lens] tensor: the prompt followed by the generated tokens.
        At least one token is always appended, even if the prompt is
        already ``max_length`` long or already contains ``eos_token_id``.
    """
    # Only arg-max indices are kept, so no autograd graph needs to be recorded.
    with torch.no_grad():
        while True:
            attention_mask = (input_ids != pad_token_id).long()

            _, logits = model_actor(input_ids=input_ids,
                                    attention_mask=attention_mask)

            #[b, vocab_size] -> [b, 1]
            predict = logits[:, -1].argmax(1, keepdim=True)
            input_ids = torch.cat((input_ids, predict), 1)

            if input_ids.shape[1] >= max_length:
                break

            # stop when every row has produced (or already held) an EOS token
            if (input_ids == eos_token_id).any(1).all():
                break

    return input_ids


# Smoke test: greedy-decode from a random [2, 15] prompt with an untrained actor and display the resulting shape.
generate(model_actor=Gemma3Actor(16, 0),
         input_ids=torch.randint(0, 16, [2, 15]),
         pad_token_id=0,
         eos_token_id=2).shape

torch.Size([2, 16])

In [9]:
class Gemma3Critic(torch.nn.Module):
    """Scalar scoring head on top of Gemma3Model.

    Args:
        vocab_size: forwarded to Gemma3Model.
        pad_token_id: forwarded to Gemma3Model.
        eos_token_id: token id whose first occurrence marks the position
            at which each sequence is scored.
    """

    def __init__(self, vocab_size, pad_token_id, eos_token_id):
        super().__init__()

        self.model = Gemma3Model(vocab_size, pad_token_id)
        self.eos_token_id = eos_token_id
        self.score = torch.nn.Linear(1024, 1, bias=False)

    def forward(self, input_ids, attention_mask, labels=None):
        """Score each sequence at its first EOS token.

        Args:
            input_ids: [b, lens] token ids.
            attention_mask: [b, lens] padding mask.
            labels: optional [b, 1] regression targets.

        Returns:
            ``(loss, scores)`` where ``scores`` is [b, 1] and ``loss`` is the
            mean squared error against ``labels``, or None without labels.
        """
        #[b, lens] -> [b, lens, 1024]
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)
        #[b, lens, 1024] -> [b, lens, 1]
        per_token = self.score(hidden)

        # Index of the first EOS token in each row; rows without one fall
        # back to their last position.
        last = input_ids.shape[1] - 1
        ends = [
            next((pos for pos, token in enumerate(row)
                  if token == self.eos_token_id), last)
            for row in input_ids.tolist()
        ]

        #[b, lens, 1] -> [b, 1]
        logits = per_token[range(len(ends)), ends]

        if labels is None:
            return None, logits

        loss = torch.nn.functional.mse_loss(logits.flatten(), labels.flatten())
        return loss, logits


# Smoke test: random ids, random 0/1 padding mask and random [2, 1] targets; displays the (loss, scores) tuple.
Gemma3Critic(16, 0, 1)(torch.randint(0, 16, [2, 15]),
                       torch.randint(0, 2, [2, 15]), torch.randn(2, 1))

(tensor(0.0837, grad_fn=<MseLossBackward0>),
 tensor([[-0.6086],
         [ 0.9969]], grad_fn=<IndexBackward0>))