In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Parameter-free root-mean-square normalisation over the last dimension.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps)``. There is no learnable
    gain, so the module owns no parameters.

    Args:
        eps: constant added to the mean square before the square root, keeping
            the denominator away from zero for all-zero inputs.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Mean of squares along the feature axis, kept broadcastable against x.
        mean_square = x.square().mean(dim=-1, keepdim=True)
        denom = (mean_square + self.eps).sqrt()
        return x / denom


class SimpleMTP(nn.Module):
    """Minimal multi-token-prediction (MTP) stack of ``num_heads`` chained depths.

    For every start position ``i`` and depth ``k`` (0-indexed) the module emits
    logits meant to predict token ``i + k + 1``:

        h^{-1}_i = init_hidden[:, i]   (or the token embedding of position i)
        h^k_i    = Transformer_k(Proj_k([RMSNorm(h^{k-1}_i); RMSNorm(Emb(t_{i+k}))]))
        logits   = Unembed(h^k_i)

    Depth ``k`` is fed the embedding of token ``i + k`` (the token right before
    its target), never the target itself. Each depth's encoder layer processes
    all start positions together under a causal mask, so position ``i`` only
    attends to positions ``<= i``. The unembedding matrix is tied to the
    embedding matrix.

    Args:
        d_model: hidden size.
        vocab_size: vocabulary size.
        num_heads: number of prediction depths (future tokens per position).
        nhead: attention heads inside each ``nn.TransformerEncoderLayer``.
    """

    def __init__(self, d_model: int, vocab_size: int, num_heads: int = 3, nhead: int = 2):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_heads = num_heads

        self.rmsnorm = RMSNorm()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.unembed.weight = self.embed.weight

        # One projection per depth: concat(hidden, token embedding) -> d_model.
        self.projections = nn.ModuleList([
            nn.Linear(d_model * 2, d_model)
            for _ in range(num_heads)
        ])

        # One encoder layer per depth (default batch_first=False: (seq, batch, feature)).
        self.transformers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
            for _ in range(num_heads)
        ])

    def forward(self, token_ids: torch.LongTensor, init_hidden: torch.Tensor = None):
        """Compute logits for every start position and every depth.

        Args:
            token_ids: ``(B, T)`` token ids.
            init_hidden: optional ``(B, T', d_model)`` starting hidden states with
                ``T' >= T - num_heads``; defaults to the token embeddings.

        Returns:
            ``(B, T - num_heads, num_heads, vocab_size)`` logits, where
            ``out[:, i, k]`` is meant to predict ``token_ids[:, i + k + 1]``.

        Raises:
            ValueError: if ``T <= num_heads`` (no start position has a target
                for every depth).
        """
        B, T = token_ids.shape
        num_positions = T - self.num_heads
        if num_positions <= 0:
            raise ValueError(
                f"sequence length {T} must exceed num_heads={self.num_heads}"
            )

        embeds = self.embed(token_ids)  # (B, T, d_model)
        # NOTE: when init_hidden is None, depth 0 receives the embedding of
        # token i as both of its inputs.
        h0_seq = embeds if init_hidden is None else init_hidden

        # True marks a future key position that a query may not attend to.
        causal_mask = torch.ones(
            num_positions, num_positions, dtype=torch.bool, device=token_ids.device
        ).triu(diagonal=1)

        h_prev = h0_seq[:, :num_positions, :]  # (B, L, d_model), L = T - D
        depth_logits = []
        for k in range(self.num_heads):
            # Embedding of token i + k for each start position i: the token right
            # before depth k's target (i + k + 1), so the target is never an input.
            tok_embed = embeds[:, k:k + num_positions, :]

            merged = torch.cat((self.rmsnorm(h_prev), self.rmsnorm(tok_embed)), dim=-1)
            proj = self.projections[k](merged)  # (B, L, d_model)

            # Encoder layers expect (seq, batch, feature); attend causally over
            # the start positions instead of over a length-1 sequence.
            x = self.transformers[k](proj.transpose(0, 1), src_mask=causal_mask)
            h_curr = x.transpose(0, 1)  # back to (B, L, d_model)

            depth_logits.append(self.unembed(h_curr))  # (B, L, vocab_size)
            h_prev = h_curr

        # D tensors of (B, L, V) -> (B, L, D, V), with D = num_heads, V = vocab_size.
        return torch.stack(depth_logits, dim=2)

In [None]:
# Smoke test: build a tiny model and inspect its multi-token predictions.
torch.manual_seed(0)  # reproducible weight init and random tokens

batch_size, seq_len, d_model, vocab_size = 1, 8, 8, 5000
model = SimpleMTP(d_model=d_model, vocab_size=vocab_size)
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
print(tokens)

# eval() disables the encoder layers' dropout so predictions are deterministic;
# no_grad() skips autograd bookkeeping since we only inspect outputs here.
model.eval()
with torch.no_grad():
    logits = model(tokens)

# (batch, seq_len - num_heads, num_heads, vocab_size)
assert logits.shape == (batch_size, seq_len - model.num_heads, model.num_heads, vocab_size)
print(logits.shape)

print(logits[0, 0, 0])

# get all predictions at i=0 as token IDs (entry k is scored against token k + 1)
pred_ids = logits[0, 0].argmax(dim=-1)
pred_ids

In [None]:
# Multi-token prediction loss. The objective is self-supervised: the targets are
# the input sequence itself, and logits[:, i, k] is scored against token i + k + 1.
targets = tokens

model.train()  # the training objective is computed with dropout active
logits = model(tokens)
B, L, D, V = logits.shape
assert targets.shape[1] >= L + D, "need a target token for every (position, depth) pair"

# target_ids[b, i, k] = targets[b, i + k + 1]
offsets = (
    torch.arange(L, device=targets.device).unsqueeze(1)
    + torch.arange(D, device=targets.device).unsqueeze(0)
    + 1
)
target_ids = targets[:, offsets]  # (B, L, D)

# A single mean over all (batch, position, depth) terms. This equals the
# average of per-(i, k) batch-mean cross-entropies, since every term covers B examples.
loss = F.cross_entropy(logits.reshape(-1, V), target_ids.reshape(-1))

print(loss.item())