In [87]:
import math
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
# Toy problem sizes for the demo cells.
bs, seq_len = 2, 16                        # batch size, tokens per sequence
dim, n_heads = 512, 8                      # model width, number of query heads
head_dim = dim // n_heads                  # width of a single attention head
inputs = torch.randn(bs, seq_len, dim)     # random activations, (bs, seq_len, dim)

In [18]:
def precompute_freqs_cis(
        seq_len: int,
        head_dim: int,
        base: int = 10_000
) -> torch.Tensor:
    """Build the rotary-embedding lookup table of cos/sin values.

    Args:
        seq_len: number of positions to tabulate (positions 0 .. seq_len - 1).
        head_dim: width of one attention head; one frequency is produced per
            pair of channels.
        base: base of the geometric progression of frequencies.

    Returns:
        A bfloat16 tensor of shape (seq_len, head_dim // 2, 2). The last axis
        holds (cos(angle), sin(angle)) with angle = position * frequency.
    """
    pair_exponents = torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim
    inv_freq = 1.0 / (base ** pair_exponents)                      # (head_dim // 2)
    positions = torch.arange(seq_len, device=inv_freq.device)      # [0, 1, ..., seq_len - 1]
    angles = torch.outer(positions, inv_freq)                      # (seq_len, head_dim // 2)
    # Unit-magnitude complex numbers: cos(angle) + i * sin(angle), i being the imaginary unit.
    unit_rotations = torch.polar(torch.ones_like(angles), angles)  # (seq_len, head_dim // 2)
    cos_sin = torch.view_as_real(unit_rotations)                   # (seq_len, head_dim // 2, 2)
    return cos_sin.to(dtype=torch.bfloat16)

In [19]:

def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor
) -> torch.Tensor:
    """Rotate consecutive channel pairs of ``x`` by position-dependent angles.

    Args:
        x: tensor of shape (bs, block_size, n_heads, head_dim).
        freqs_cis: cos/sin table of shape (block_size, head_dim // 2, 2), with
            cos in ``[..., 0]`` and sin in ``[..., 1]``.

    Returns:
        Tensor with the same shape and dtype as ``x``; the arithmetic itself is
        carried out in float32.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)          # (bs, block_size, n_heads, head_dim // 2, 2)
    n_positions, n_pairs = pairs.size(1), pairs.size(3)
    table = freqs_cis.view(1, n_positions, 1, n_pairs, 2)    # broadcast over batch and heads

    x_first, x_second = pairs[..., 0], pairs[..., 1]
    cos, sin = table[..., 0], table[..., 1]

    # First component of the rotation: x1 * cos(theta) - x2 * sin(theta)
    rotated_first = x_first * cos - x_second * sin
    # Second component of the rotation: x2 * cos(theta) + x1 * sin(theta)
    rotated_second = x_second * cos + x_first * sin

    rotated = torch.stack([rotated_first, rotated_second], -1).flatten(3)
    return rotated.type_as(x)

In [88]:
class GroupedQueryAttention(nn.Module):
    """Self-attention in which groups of query heads share one key/value head.

    ``n_heads`` query heads are split into ``n_kv_heads`` groups of
    ``n_heads // n_kv_heads`` heads each; every group attends with the same
    key/value projection. Rotary position embeddings are applied to queries
    and keys before the scores are computed.
    """

    def __init__(self, dim: int, n_heads: int, head_dim: int, n_kv_heads: int, dropout: float) -> None:
        """
        Args:
            dim: width of the input and output features.
            n_heads: number of query heads.
            head_dim: width of a single head.
            n_kv_heads: number of key/value heads (groups); must divide ``n_heads``.
            dropout: dropout probability applied to the attention weights.

        Raises:
            ValueError: if ``n_heads`` is not a positive multiple of ``n_kv_heads``.
        """
        super().__init__()

        # Without this check a bad configuration only surfaces later as an
        # opaque shape mismatch inside the score matmul.
        if n_heads <= 0 or n_kv_heads <= 0 or n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be a positive multiple of n_kv_heads ({n_kv_heads})"
            )

        self.n_heads = n_heads
        self.head_dim = head_dim                        # dim // n_heads
        self.n_kv_heads = n_kv_heads                    # n groups

        self.repeats = self.n_heads // self.n_kv_heads  # group size

        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, cache = None, mask = None) -> torch.Tensor:
        """Attend over ``x`` and project the result back to ``dim`` features.

        Args:
            x: input of shape (bs, seq_len, dim).
            freqs_cis: rotary cos/sin table forwarded to ``apply_rotary_emb``.
            cache: key/value cache; not supported yet.
            mask: optional attention mask broadcastable to
                (bs, n_heads, seq_len, seq_len). A boolean mask marks the
                positions that MAY be attended to with ``True``; any other
                dtype is added to the scaled scores (use ``-inf`` to block).
                A query row that is fully masked yields NaN after softmax.

        Returns:
            Tensor of shape (bs, seq_len, dim).

        Raises:
            NotImplementedError: if ``cache`` is given.
        """
        assert mask is None or cache is None
        if cache is not None:
            # Fail before doing any work; the cached path is still a TODO.
            raise NotImplementedError # TODO

        bs, seq_len, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.view(bs, seq_len, self.n_heads, self.head_dim)
        k = k.view(bs, seq_len, self.n_kv_heads, self.head_dim)
        v = v.view(bs, seq_len, self.n_kv_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis=freqs_cis)
        k = apply_rotary_emb(k, freqs_cis=freqs_cis)

        # (bs, seq_len, heads, head_dim) -> (bs, heads, seq_len, head_dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        # Duplicate each key/value head so every query head in a group sees it.
        k = k.repeat_interleave(self.repeats, dim=1)
        v = v.repeat_interleave(self.repeats, dim=1)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dtype == torch.bool:
                attn_weights = attn_weights.masked_fill(~mask, float("-inf"))
            else:
                attn_weights = attn_weights + mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous()
        output = output.view(bs, seq_len, self.n_heads * self.head_dim)
        return self.wo(output)

In [89]:
# Demo configuration for the cells below.
torch.manual_seed(0)  # fix the RNG so `inputs` and later module initialisation are reproducible
bs = 2                                       # batch size
seq_len = 16                                 # tokens per sequence
dim = 512                                    # model width
n_heads = 8                                  # number of query heads
head_dim = dim // n_heads                    # width of a single attention head
inputs = torch.randn((bs, seq_len, dim))     # (bs, seq_len, dim)

In [90]:
# Rotary cos/sin table for the demo sizes; the cell displays its shape.
freqs_cis = precompute_freqs_cis(seq_len, head_dim)
freqs_cis.size()

torch.Size([16, 32, 2])

In [91]:
# 8 query heads sharing 4 key/value heads, i.e. groups of 2.
gqa = GroupedQueryAttention(
    dim=dim,
    n_heads=n_heads,
    head_dim=head_dim,
    n_kv_heads=4,
    dropout=0.1,
)

In [92]:
# Shape check: attention maps (bs, seq_len, dim) back to (bs, seq_len, dim).
gqa(inputs, freqs_cis=freqs_cis).size()

torch.Size([2, 16, 512])

In [94]:
class GatedFeedForward(nn.Module):
    """Bias-free gated MLP: ``w2(SiLU(w1(x)) * wg(x))``.

    ``w1`` and ``wg`` both map ``dim -> hidden_dim``; the SiLU-activated ``w1``
    branch is multiplied element-wise with the linear ``wg`` branch, and ``w2``
    maps the product back to ``dim``.
    """

    def __init__(self, dim: int, hidden_dim: int) -> None:
        """
        Args:
            dim: width of the input and output features.
            hidden_dim: width of the intermediate representation.
        """
        super().__init__()

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)   # branch passed through the activation
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)   # projection back to the model width
        self.wg = nn.Linear(dim, hidden_dim, bias=False)   # linear branch multiplied in
        self.act_fn = nn.SiLU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP over the last dimension of ``inputs``."""
        activated = self.act_fn(self.w1(inputs))
        linear_branch = self.wg(inputs)
        return self.w2(activated * linear_branch)

In [95]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        """
        Args:
            dim: size of the normalised (last) dimension.
            eps: constant added to the mean square before the reciprocal
                square root, guarding against division by zero.
        """
        super().__init__()

        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # per-channel scale, initialised to 1

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal of its RMS along the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to ``x``'s dtype, then apply the scale."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight


In [98]:
class MoELayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k routing.

    A linear router scores every token against every expert. Each token is
    processed by its ``n_experts_per_tok`` highest-scoring experts and the
    expert outputs are summed, weighted by a softmax over those selected scores.
    """

    def __init__(self, dim: int, hidden_dim: int, n_experts: int, n_experts_per_tok: int = 2) -> None:
        """
        Args:
            dim: width of the input and output features.
            hidden_dim: hidden width of every expert.
            n_experts: total number of experts.
            n_experts_per_tok: number of experts each token is routed to.

        Raises:
            ValueError: if ``n_experts_per_tok`` is not in ``[1, n_experts]``.
        """
        super().__init__()

        # torch.topk would otherwise fail deep inside forward with a less helpful message.
        if not 1 <= n_experts_per_tok <= n_experts:
            raise ValueError(
                f"n_experts_per_tok ({n_experts_per_tok}) must be between 1 and n_experts ({n_experts})"
            )

        self.n_experts = n_experts
        self.n_experts_per_tok = n_experts_per_tok

        # gating, router, routing function (dim -> num_experts)
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.experts = nn.ModuleList([
            GatedFeedForward(dim=dim, hidden_dim=hidden_dim) for _ in range(n_experts)
        ])

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route every token through its selected experts.

        Args:
            inputs: tensor of shape (batch_size, seq_len, dim); it does not
                need to be contiguous.

        Returns:
            Tensor of shape (batch_size, seq_len, dim).
        """
        batch_size, seq_len, dim = inputs.shape

        # reshape (not view) so non-contiguous inputs, e.g. after a transpose, are accepted.
        tokens = inputs.reshape(-1, dim)      # (batch_size * seq_len, dim)
        router_logits = self.gate(tokens)     # (batch_size * seq_len, n_experts)

        # top_logits        := router scores of the chosen experts,   (batch_size * seq_len, n_experts_per_tok)
        # selected_experts  := index of each chosen expert per token, (batch_size * seq_len, n_experts_per_tok)
        top_logits, selected_experts = torch.topk(router_logits, self.n_experts_per_tok)
        # Mixing coefficients: softmax over the chosen experts only, so they sum to 1 per token.
        routing_weights = F.softmax(top_logits, dim=-1, dtype=torch.float).to(tokens.dtype)

        results = torch.zeros_like(tokens)    # (batch_size * seq_len, dim)

        for expert_id, expert in enumerate(self.experts):
            # token_idx: rows routed to this expert; slot_idx: which of the top-k slots chose it.
            # topk returns distinct experts per token, so token_idx has no duplicates.
            token_idx, slot_idx = torch.where(selected_experts == expert_id)
            results[token_idx] += routing_weights[token_idx, slot_idx, None] * expert(tokens[token_idx])

        return results.reshape(batch_size, seq_len, dim)

In [104]:
# Tiny MoE smoke test: 3 experts, each token routed to its top 2.
x = torch.randn(2, 4, 4)  # (batch_size, seq_len, dim)
moe_layer = MoELayer(
    dim=4,
    hidden_dim=8,
    n_experts=3,
    n_experts_per_tok=2,
)

In [105]:
# Shape check: the MoE layer keeps the (batch_size, seq_len, dim) shape.
moe_layer(x).size()

torch.Size([2, 4, 4])

In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual.

    The feed-forward sub-layer is a dense ``GatedFeedForward`` by default and a
    sparse ``MoELayer`` when ``n_experts`` is positive.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        norm_eps: float,
        attn_dropout: float,
        n_experts: int = 0,
        n_experts_per_tok: int = 2,
    ) -> None:
        """
        Args:
            dim: width of the input and output features.
            hidden_dim: hidden width of the feed-forward network (of every expert when MoE).
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads.
            head_dim: width of a single attention head.
            norm_eps: epsilon used by both RMSNorm layers.
            attn_dropout: dropout probability on the attention weights.
            n_experts: number of experts; ``0`` selects the dense feed-forward network.
            n_experts_per_tok: experts used per token; only read when ``n_experts > 0``.
        """
        super().__init__()

        self.n_heads = n_heads
        self.dim = dim
        self.attention = GroupedQueryAttention(
            dim=dim, n_heads=n_heads, head_dim=head_dim, n_kv_heads=n_kv_heads, dropout=attn_dropout
        )
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

        self.feed_forward: nn.Module
        if n_experts > 0:
            self.feed_forward = MoELayer(
                dim=dim, hidden_dim=hidden_dim, n_experts=n_experts, n_experts_per_tok=n_experts_per_tok
            )
        else:
            self.feed_forward = GatedFeedForward(dim=dim, hidden_dim=hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache = None,
        mask = None,
    ) -> torch.Tensor:
        """Run one decoder block.

        Args:
            x: input of shape (bs, seq_len, dim).
            freqs_cis: rotary cos/sin table passed through to the attention layer.
            cache: forwarded unchanged to the attention layer.
            mask: forwarded unchanged to the attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Call the sub-modules (not .forward) so registered hooks still run.
        attn_out = self.attention(self.attention_norm(x), freqs_cis, cache=cache, mask=mask)
        h = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out
