In [6]:
import torch
import torch.nn as nn

In [10]:
import torch
import torch.nn as nn

class RoPE(nn.Module):
    """Rotary positional embedding (RoPE) over interleaved feature pairs.

    The feature pair ``(x[..., 2i], x[..., 2i + 1])`` at position ``p`` is rotated by the
    angle ``p * base ** (-2i / head_dim)``. Cos/sin tables for positions
    ``0 .. max_seq_len - 1`` are precomputed once and stored as buffers.

    Args:
        head_dim: size of the last dimension of the inputs; must be even.
        max_seq_len: number of positions covered by the precomputed tables.
        base: frequency base (10000 by default).

    Input of ``forward`` is any tensor shaped ``(..., seq_len, head_dim)``, e.g.
    ``(batch, num_heads, seq_len, head_dim)``; the output has the same shape and dtype.
    """

    def __init__(self, head_dim, max_seq_len=2048, base=10000):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even for RoPE"
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # One inverse frequency per feature pair: base^(-2i / head_dim).
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = positions * theta  # (max_seq_len, head_dim // 2)

        # Buffer layout (1, 1, max_seq_len, head_dim // 2) is kept so existing
        # state_dicts remain loadable.
        self.register_buffer("cos", freqs.cos()[None, None, :, :])
        self.register_buffer("sin", freqs.sin()[None, None, :, :])

    def forward(self, x):
        seq_len = x.shape[-2]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )

        # Even / odd features form the (real, imaginary) halves of each pair.
        x1 = x[..., ::2]   # (..., seq_len, head_dim // 2)
        x2 = x[..., 1::2]  # (..., seq_len, head_dim // 2)

        # (seq_len, head_dim // 2) broadcasts against any leading dims; casting keeps
        # half-precision inputs from being promoted to float32.
        cos = self.cos[0, 0, :seq_len].to(x.dtype)
        sin = self.sin[0, 0, :seq_len].to(x.dtype)

        # 2-D rotation of every pair, then re-interleave back to head_dim.
        rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        return rotated.flatten(-2)

In [None]:
class FeedForwardWithSwiGLU(nn.Module):
    """Position-wise feed-forward network with a SwiGLU gate.

    Computes ``linear2(silu(linear1(x)) * linear3(x))``:
    ``linear1`` is the gated branch, ``linear3`` the linear (value) branch, and
    ``linear2`` projects back to ``d_model``.

    Args:
        d_model: input and output feature size.
        hidden_dim: inner size; defaults to ``4 * d_model``.
        bias: whether the linear layers carry bias terms.
    """

    def __init__(self, d_model, hidden_dim=None, bias=True):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * d_model
        self.linear1 = nn.Linear(d_model, hidden_dim, bias=bias)
        self.linear2 = nn.Linear(hidden_dim, d_model, bias=bias)
        self.linear3 = nn.Linear(d_model, hidden_dim, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, x):
        # SiLU gates only the linear1 branch; applying it to the product is not SwiGLU.
        gated = self.activation(self.linear1(x)) * self.linear3(x)
        return self.linear2(gated)
   





In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Every vector along the last axis is divided by ``sqrt(mean(x ** 2) + eps)`` and then
    scaled element-wise by the learnable ``scale_weight`` (initialised to ones).
    There is no mean subtraction and no bias.

    Args:
        dim: size of the last dimension of the inputs.
        eps: small constant added inside the square root for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.scale_weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        normalized = x / rms
        return normalized * self.scale_weight

    

In [None]:
class GMQA(nn.Module):
    """Causal grouped-query attention (GQA) with rotary position embeddings.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads: query head ``i``
    attends with key/value head ``i // (num_heads // num_kv_groups)``. RoPE is applied
    to queries and keys only; values carry no positional rotation.

    Args:
        d_out: total query/output width; split evenly into ``num_heads`` heads.
        num_heads: number of query heads.
        d_in: input feature size.
        context_length: maximum sequence length (size of the causal mask and RoPE tables).
        num_kv_groups: number of key/value heads; must divide ``num_heads``.

    ``forward`` takes ``x`` of shape ``(batch, num_tokens, d_in)`` and returns
    ``(batch, num_tokens, d_out)``.
    """

    def __init__(self, d_out, num_heads, d_in, context_length, num_kv_groups):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups  # query heads per KV head
        self.pos_emb = RoPE(self.head_dim, context_length)

        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, self.head_dim * num_kv_groups)
        self.W_value = nn.Linear(d_in, self.head_dim * num_kv_groups)
        self.W_out = nn.Linear(d_out, d_out)
        # Ones strictly above the diagonal mark future positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, _ = x.shape

        # Project and split into heads, head axis before sequence: (b, heads, T, head_dim).
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # RoPE rotates along the sequence axis (dim -2); positions go into q and k only.
        queries = self.pos_emb(queries)
        keys = self.pos_emb(keys)

        # Share each KV head with its group of consecutive query heads.
        keys = keys.repeat_interleave(self.group_size, dim=1)      # (b, num_heads, T, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # (b, num_heads, T, head_dim)

        attn_scores = queries @ keys.transpose(-2, -1)  # (b, num_heads, T, T)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        context = attn_weights @ values  # (b, num_heads, T, head_dim)
        context = context.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
        return self.W_out(context)



        




In [None]:
import torch
import torch.nn as nn
class TransformerBlock(nn.Module):
    """Pre-norm decoder block.

    ``x + att(norm1(x))`` followed by ``x + ff(norm2(x))``, using grouped-query
    attention (``GMQA``), ``RMSNorm`` and a SwiGLU feed-forward network.

    ``cfg`` is read by attribute: ``d_in``, ``d_out``, ``num_heads``, ``context_length``,
    ``num_kv_groups`` and ``norm_eps``. The residual additions require
    ``cfg.d_in == cfg.d_out``.
    """

    def __init__(self, cfg):
        super().__init__()
        if cfg.d_in != cfg.d_out:
            raise ValueError(
                f"residual connections require d_in == d_out, got {cfg.d_in} and {cfg.d_out}"
            )
        self.att = GMQA(cfg.d_out, cfg.num_heads, cfg.d_in, cfg.context_length, cfg.num_kv_groups)
        self.norm1 = RMSNorm(cfg.d_out, eps=cfg.norm_eps)
        self.norm2 = RMSNorm(cfg.d_out, eps=cfg.norm_eps)
        self.ff = FeedForwardWithSwiGLU(cfg.d_out)

    def forward(self, x):
        # No dropout after either sub-layer: dropout is not used in LLaMA models.
        x = x + self.att(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x


    

