In [None]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

# Read key/value pairs from a local .env file into the process environment,
# then authenticate against the Hugging Face Hub with the token stored under
# HUGGING_FACE_KEY (the lookup yields None when the variable is not set).
load_dotenv()
login(token=os.environ.get("HUGGING_FACE_KEY"))

In [None]:
# import torch
# Load model directly
# from transformers import AutoTokenizer, AutoModelForCausalLM
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float32, device_map="auto")

In [None]:
# from transformers import pipeline
# pipe = pipeline("summarization", model="")
# response = pipe("test")
# print(response)

# Coding Multi-Head Latent Attention (MLA) from scratch

In [None]:
# Initial Setup and Parameters
import torch
import torch.nn as nn
import torch.nn.functional as F


class RopelessMLA(nn.Module):
    """Causal multi-head latent attention (MLA) without rotary embeddings.

    Each token is compressed by ``W_dkv`` (followed by LayerNorm) into a latent
    vector of size ``kv_latent_dim``; that latent is the only thing returned as
    the cache. Values are decompressed from the latents with ``W_uv``. Keys are
    never materialised: the attention scores are computed directly against the
    latents through the product ``W_q.weight @ W_uk.weight``.

    Concretely, head ``h`` scores its ``dh``-wide slice of the *raw* input
    ``x`` against the latents through rows ``[h*dh, (h+1)*dh)`` of
    ``W_q.weight @ W_uk.weight``.

    NOTE(review): this is not algebraically the same as projecting with
    ``W_q`` / ``W_uk`` and taking per-head dot products (that would be
    ``W_q.weight[rows].T @ W_uk.weight[rows]`` applied to the full ``x``).
    The formulation is kept as-is so existing weights keep their meaning --
    confirm it is intended.

    Caching of the absorbed matrix: it is stored in the non-persistent buffer
    ``absorbed_k`` only while the module is in eval mode *and* autograd is
    disabled. Otherwise it is rebuilt on every call so gradients reach
    ``W_q`` / ``W_uk`` and weight updates are reflected. If you edit the
    weights by hand in eval mode, reset ``absorbed_k = None`` yourself.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        kv_latent_dim: Size of the compressed per-token latent.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, kv_latent_dim: int):
        super().__init__()
        if d_model % n_heads != 0:
            # Fail here instead of with an opaque view() error inside forward().
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.dh = d_model // n_heads  # dimension per head

        self.W_q = nn.Linear(self.d_model, self.d_model, bias=False)  # Query projection
        self.W_dkv = nn.Linear(self.d_model, kv_latent_dim, bias=False)  # Compress into latent KV space
        self.W_uk = nn.Linear(kv_latent_dim, self.d_model, bias=False)  # Decompress K
        self.W_uv = nn.Linear(kv_latent_dim, self.d_model, bias=False)  # Decompress V
        self.W_o = nn.Linear(self.d_model, self.d_model, bias=False)  # Final output projection

        self.ln = nn.LayerNorm(kv_latent_dim)
        # Holds W_q @ W_uk reshaped to (n_heads, dh, kv_latent_dim). It is
        # derived from the weights, so it is kept out of state_dict().
        self.register_buffer("absorbed_k", None, persistent=False)

    def _compute_absorbed_k(self) -> torch.Tensor:
        """Return ``W_q.weight @ W_uk.weight`` as (n_heads, dh, kv_latent_dim)."""
        absorbed = torch.matmul(self.W_q.weight, self.W_uk.weight)
        return absorbed.view(self.n_heads, self.dh, -1)

    def _absorbed_query_key(self) -> torch.Tensor:
        """Return the absorbed matrix, reusing the cache only when that is safe."""
        if self.training or torch.is_grad_enabled():
            # A cached tensor would either carry an autograd graph that is
            # freed after the first backward() or go stale after an optimizer
            # step, so drop any cache and rebuild from the live weights.
            self.absorbed_k = None
            return self._compute_absorbed_k()
        if self.absorbed_k is None:
            self.absorbed_k = self._compute_absorbed_k()
        return self.absorbed_k

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Invalidate the cache on load and ignore a stored ``absorbed_k`` entry."""
        self.absorbed_k = None
        # Checkpoints written while the buffer was persistent contain this key;
        # it is redundant with the weights, so discard it instead of failing.
        state_dict.pop(prefix + "absorbed_k", None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x: torch.Tensor, kv_cache=None, past_length=0):
        """Run causal latent attention over ``x``.

        Args:
            x: Input of shape (B, S, d_model).
            kv_cache: Optional latents from earlier steps, shape
                (B, S_past, kv_latent_dim); new latents are appended along
                dim 1.
            past_length: Offset of the causal mask: query ``i`` may attend to
                key positions ``j <= i + past_length``. When a cache is given
                pass ``kv_cache.size(1)``; the default ``0`` is only
                appropriate without a cache.

        Returns:
            Tuple ``(out, c_kv)`` where ``out`` has shape (B, S, d_model) and
            ``c_kv`` is the updated latent cache of shape
            (B, S_full, kv_latent_dim).
        """
        B, S, _ = x.size()
        absorbed = self._absorbed_query_key()  # (H, dh, L)

        # Compress the new tokens and extend the latent cache.
        new_c_kv = self.ln(self.W_dkv(x))
        if kv_cache is None:
            c_kv = new_c_kv
        else:
            c_kv = torch.cat([kv_cache, new_c_kv], dim=1)
        S_full = c_kv.size(1)

        # Values are decompressed from the latents and split into heads.
        v = self.W_uv(c_kv).reshape(B, S_full, self.n_heads, self.dh).transpose(1, 2)

        # reshape (not view) so a non-contiguous x is accepted as well.
        q = x.reshape(B, S, self.n_heads, self.dh)

        # Map every head's query slice into latent space, then score against
        # the latents. All heads are handled in one batched call and the result
        # inherits the dtype of the inputs (no float32 scratch tensor).
        q_latent = torch.einsum("bshd,hdl->bhsl", q, absorbed)  # (B, H, S, L)
        attn_scores = torch.matmul(q_latent, c_kv.transpose(1, 2).unsqueeze(1))  # (B, H, S, S_full)
        attn_scores = attn_scores / (self.dh ** 0.5)

        # Causal mask: hide key j from query i when j > i + past_length.
        query_pos = torch.arange(S, device=x.device).unsqueeze(1) + past_length
        key_pos = torch.arange(S_full, device=x.device).unsqueeze(0)
        attn_scores = attn_scores.masked_fill(key_pos > query_pos, float("-inf"))

        attn_weights = attn_scores.softmax(dim=-1)

        context = torch.matmul(attn_weights, v)  # (B, H, S, dh)
        # Put heads back side by side along the feature axis.
        out = context.transpose(1, 2).reshape(B, S, self.d_model)

        return self.W_o(out), c_kv