In [5]:
import torch
import torch.nn as nn
import math
from torch.nn import Module, ModuleList
import torch.nn.functional as F


In [2]:
class MultiHeadLatentAttention(nn.Module):
    """Multi-head self-attention whose keys/values go through a low-rank latent.

    Each token is first compressed to a ``latent_dim`` vector
    (``kv_latent_proj``) and then expanded back to full-width keys and values
    (``latent_to_kv_proj``). Queries use an ordinary full-rank projection.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        latent_dim: width of the shared key/value bottleneck.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, latent_dim):
        super().__init__()
        # Without this check, a bad config only fails later inside forward() with an
        # unhelpful `view` shape error.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.latent_dim = latent_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.kv_latent_proj = nn.Linear(embed_dim, latent_dim)
        self.latent_to_kv_proj = nn.Linear(latent_dim, 2 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for all projection weights (biases keep nn.Linear defaults)."""
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.kv_latent_proj.weight)
        nn.init.xavier_uniform_(self.latent_to_kv_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape ``(batch, seq, embed_dim)`` into ``(batch, heads, seq, head_dim)``."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attn_mask=None):
        """Apply latent-KV multi-head self-attention.

        Args:
            x: input of shape ``(batch, seq_len, embed_dim)``.
            attn_mask: optional mask broadcastable to
                ``(batch, num_heads, seq_len, seq_len)`` (e.g. ``(seq_len, seq_len)``).
                A boolean mask marks positions that MAY be attended (``True`` = keep).
                A floating-point mask is added to the attention scores.
                A query row that is fully masked produces NaN.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        batch_size, seq_len, _ = x.size()

        query = self._split_heads(self.q_proj(x), batch_size, seq_len)

        # Compress to the latent, then expand to full-width keys and values.
        kv_latent = self.kv_latent_proj(x)
        key, value = torch.chunk(self.latent_to_kv_proj(kv_latent), 2, dim=-1)
        key = self._split_heads(key, batch_size, seq_len)
        value = self._split_heads(value, batch_size, seq_len)

        # Scaled dot-product attention.
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_scores = attn_scores.masked_fill(~attn_mask, float("-inf"))
            else:
                attn_scores = attn_scores + attn_mask
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, value)

        # Merge heads back and project out.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(attn_output)

In [3]:

class RotaryPositionalEmbeddings(nn.Module):
    """Rotary positional embeddings (RoPE).

    Rotates the first ``d`` features of each vector by the angles ``m * theta_i``,
    where ``m`` is the position (index along dim 0) and
    ``theta_i = base ** (-2i / d)``. Features beyond ``d`` pass through unchanged.

    The cached cos/sin tables have shape ``(seq_len, 1, 1, d)`` and are sliced on
    dim 0, so ``x`` is expected to be 4-D: ``(seq_len, batch, heads, features)``
    with ``features >= d``.

    Args:
        d: number of features to rotate; must be even.
        base: frequency base (10_000 in the original RoPE formulation).

    Raises:
        ValueError: if ``d`` is odd.
    """

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        # arange(0, d, 2) yields ceil(d/2) frequencies; duplicating them gives d+1
        # columns for odd d, which can never broadcast against x.
        if d % 2 != 0:
            raise ValueError(f"d must be even for rotary embeddings, got {d}")
        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """(Re)build the cos/sin tables if they are missing, too short, or on another device."""
        seq_len = x.shape[0]
        if (
            self.cos_cached is not None
            and seq_len <= self.cos_cached.shape[0]
            and self.cos_cached.device == x.device
        ):
            return

        # theta_i = 1 / base^(2i/d), for i = 0 .. d/2 - 1
        theta = (1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d))).to(x.device)
        # Position indices m = 0 .. seq_len - 1
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # m * theta_i  -> (seq_len, d/2)
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
        # Repeat so both halves of the feature vector share the same angles -> (seq_len, d)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Tables are kept in float32 and cast to the input dtype at use time.
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        """Map ``[x_1 .. x_d]`` to ``[-x_{d/2+1} .. -x_d, x_1 .. x_{d/2}]`` (last dim of size d)."""
        d_2 = self.d // 2
        return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """Apply RoPE to the first ``d`` features of ``x``; the result has the shape and dtype of ``x``."""
        if x.shape[-1] < self.d:
            raise ValueError(
                f"input feature dim {x.shape[-1]} is smaller than rotary dim {self.d}"
            )
        self._build_cache(x)
        seq_len = x.shape[0]

        x_rope, x_pass = x[..., : self.d], x[..., self.d :]
        # Casting keeps half/bf16 inputs from being silently promoted to float32.
        cos = self.cos_cached[:seq_len].to(x.dtype)
        sin = self.sin_cached[:seq_len].to(x.dtype)

        # Pairwise rotation: x_i*cos - x_{i+d/2}*sin, x_{i+d/2}*cos + x_i*sin
        x_rope = x_rope * cos + self._neg_half(x_rope) * sin
        return torch.cat([x_rope, x_pass], dim=-1)

In [6]:
class RMSNorm(Module):
    """Root-mean-square normalization with a learnable per-feature gain.

    ``F.normalize`` divides by the L2 norm over the last dimension. Because
    ``||x||_2 == sqrt(dim) * rms(x)``, multiplying by ``sqrt(dim)`` afterwards
    is equivalent to dividing by the RMS.

    Args:
        dim: size of the last (normalized) dimension.
    """

    def __init__(self, dim):
        super(RMSNorm, self).__init__()
        # Learnable gain, initialised to the identity.
        self.gamma = nn.Parameter(torch.ones(dim))
        # Converts a unit-L2-norm vector into a unit-RMS vector.
        self.scale = pow(dim, 0.5)

    def forward(self, x):
        unit_vec = F.normalize(x, p=2.0, dim=-1)
        gained = unit_vec * self.gamma
        return gained * self.scale

In [8]:
class GatingNetwork(nn.Module):
    """Soft router: a linear layer followed by softmax over experts.

    Args:
        input_size: feature width of the input (last dimension of ``x``).
        num_experts: number of experts to produce weights for.
    """

    def __init__(self, input_size, num_experts):
        super(GatingNetwork, self).__init__()
        self.linear = nn.Linear(input_size, num_experts)

    def forward(self, x):
        """Return per-expert weights for ``x`` of shape ``(..., input_size)``.

        The output has shape ``(..., num_experts)`` and sums to 1 along the last dim.
        """
        logits = self.linear(x)
        # Normalize over the expert axis (last dim). The previous dim=1 mixed
        # sequence positions for (batch, seq, input_size) token inputs.
        return F.softmax(logits, dim=-1)


class MoE(nn.Module):
    """Dense mixture of experts: every expert runs and the outputs are blended by gate weights.

    Args:
        input_size: feature width of the input.
        output_size: feature width of each expert's output.
        num_experts: number of linear experts.
    """

    def __init__(self, input_size, output_size, num_experts):
        super(MoE, self).__init__()
        self.gating_network = GatingNetwork(input_size, num_experts)
        self.experts = nn.ModuleList([nn.Linear(input_size, output_size) for _ in range(num_experts)])

    def forward(self, x):
        """Mix expert outputs for ``x`` of shape ``(..., input_size)`` -> ``(..., output_size)``.

        Any number of leading dims is supported, e.g. ``(batch, input_size)`` or
        ``(batch, seq, input_size)``.
        """
        gate_weights = self.gating_network(x)  # (..., num_experts)

        # Stack on a new trailing axis -> (..., output_size, num_experts)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)

        # Broadcast the gates over output_size, then reduce over the experts.
        weighted_outputs = expert_outputs * gate_weights.unsqueeze(-2)
        return weighted_outputs.sum(dim=-1)