# Cosine ML take-home task

## Task Description

Your task is to take Llama 3.1 8B and convert its dense feed-forward layers into Mixture-of-Experts (MoE) layers. A reference Llama 3 implementation is supplied below.
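One possible way to turn a dense FFN into experts without changing what it computes is sketched below. It is an assumption, not something the task prescribes. The SwiGLU intermediate dimension is partitioned into `num_experts` equal slices: expert `e` keeps a block of rows of `gate_proj`/`up_proj` and the matching columns of `down_proj`. Because `silu(gate(x)) * up(x)` acts element-wise over the intermediate dimension and `down_proj` is linear in it, summing all experts' outputs reproduces the dense output in exact arithmetic. In floating point the results can still differ by rounding, mostly because the `down_proj` reduction is split into partial sums, which is why a tolerance is still needed in bf16. The helper names `swiglu`, `split_ffn` and `identity_moe` are hypothetical.

```python
import torch
import torch.nn.functional as F


def swiglu(x, w_gate, w_up, w_down):
    """Same maths as FeedForward.forward; weights use nn.Linear's [out, in] layout."""
    return F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up), w_down)


def split_ffn(w_gate, w_up, w_down, num_experts):
    """Partition the intermediate dimension into `num_experts` equal slices.

    Expert e keeps a block of rows of gate/up and the matching columns of down.
    """
    assert w_gate.shape[0] % num_experts == 0, "intermediate_size must divide evenly"
    return list(
        zip(
            w_gate.chunk(num_experts, dim=0),
            w_up.chunk(num_experts, dim=0),
            w_down.chunk(num_experts, dim=1),
        )
    )


def identity_moe(x, experts):
    """'Identity mode': every expert fires with weight 1 and outputs are summed."""
    return sum(swiglu(x, g, u, d) for g, u, d in experts)


torch.manual_seed(0)
hidden, intermediate, num_experts = 64, 256, 8  # toy sizes; 14336 / 8 = 1792 also divides
w_gate = torch.randn(intermediate, hidden, dtype=torch.float64) / hidden**0.5
w_up = torch.randn(intermediate, hidden, dtype=torch.float64) / hidden**0.5
w_down = torch.randn(hidden, intermediate, dtype=torch.float64) / intermediate**0.5
x = torch.randn(2, 5, hidden, dtype=torch.float64)

experts = split_ffn(w_gate, w_up, w_down, num_experts)
err = (swiglu(x, w_gate, w_up, w_down) - identity_moe(x, experts)).abs().max().item()
print(f"max abs error: {err:.2e}")
```

An alternative with the same guarantee is to copy the whole dense FFN into every expert and route each token to exactly one of them with weight 1. That keeps per-token compute equal to the dense model but multiplies FFN memory by the number of experts.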

Architect the MoE-ified Llama 3.1 8B so that it behaves identically to the base model: when running inference on a single prompt at temperature 0, the base dense model and your MoE Llama should produce the same output. Write a test demonstrating that the MoE-ified Llama produces the same logits as the base model to within a reasonable tolerance, such that greedy decoding produces the same token sequence in either setting. Describe and/or demonstrate the approach you took to ensure this, and whether it differs from how you would expect to run inference with the model in a production setting.
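The logit comparison and the greedy-decode check could be factored into small helpers like the ones below. This is a sketch under a few assumptions: each model is any callable mapping token ids `[bsz, seq]` to logits `[bsz, seq, vocab]` (the shape `Transformer.forward` returns), the tolerance is left to the caller because what is "reasonable" depends on the dtype, and decoding re-runs the full prefix at each step because the reference model has no KV cache. The names `max_logit_diff`, `greedy_decode` and `check_fidelity` are hypothetical, and the toy `nn.Sequential` models only stand in for the 8B checkpoints so the sketch runs anywhere.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def max_logit_diff(model_a, model_b, input_ids: torch.Tensor) -> float:
    """Largest absolute logit difference between two models, measured in float32."""
    a = model_a(input_ids).float()
    b = model_b(input_ids).float()
    return (a - b).abs().max().item()


@torch.no_grad()
def greedy_decode(model, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Temperature-0 decoding: append the argmax token at each step."""
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(ids)  # no KV cache, so the whole prefix is recomputed
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [bsz, 1]
        ids = torch.cat([ids, next_id], dim=-1)
    return ids[:, input_ids.shape[1]:]  # only the newly generated tokens


def check_fidelity(dense, moe, input_ids, atol, max_new_tokens=16):
    """Assert logits agree within `atol` and greedy decodes are identical."""
    diff = max_logit_diff(dense, moe, input_ids)
    assert diff <= atol, f"max |logit diff| = {diff:.3e} exceeds atol = {atol:.1e}"
    dense_tokens = greedy_decode(dense, input_ids, max_new_tokens)
    moe_tokens = greedy_decode(moe, input_ids, max_new_tokens)
    assert torch.equal(dense_tokens, moe_tokens), "greedy decodes diverged"
    return diff


# Toy stand-ins (hypothetical) so the sketch runs without the 8B checkpoint.
torch.manual_seed(0)
toy_dense = nn.Sequential(nn.Embedding(50, 8), nn.Linear(8, 50))
toy_moe = copy.deepcopy(toy_dense)
prompt = torch.tensor([[1, 2, 3]])
print(check_fidelity(toy_dense, toy_moe, prompt, atol=1e-6, max_new_tokens=5))  # 0.0
```

With the real model, holding both 8B models in bf16 at once needs roughly 32 GB. One way around that is to run the dense model first, keep only its logits and generated tokens, free it, and then build and check the MoE model against the stored results.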

When writing the MoE layer, consider the performance and memory issues that arise when running inference in real-world scenarios. Discuss the implications for performance of optimising for memory when running inference with your MoE-ified Llama, and vice versa. The task is open-ended: you may wish to produce a more performant MoE implementation and another implementation better suited to situations where memory is the bottleneck. This is not required, however, and a detailed discussion can also suffice.
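A back-of-the-envelope sketch of how the expert layout drives the memory/compute trade-off, using the `ModelConfig` defaults below. It counts parameters for the dense model and for two hypothetical MoE layouts: experts that are slices of the dense FFN (total size unchanged) and experts that are full copies of it (size grows with the expert count). It assumes an untied `lm_head`, counts bf16 as 2 bytes per parameter with GB = 1e9 bytes, and ignores activations and the KV cache. The function names are hypothetical.

```python
def llama_param_counts(vocab=128256, hidden=4096, inter=14336, layers=32,
                       heads=32, kv_heads=8):
    """Return (ffn_params, other_params) for a Llama-style model with untied lm_head."""
    head_dim = hidden // heads
    attn = 2 * hidden * hidden + 2 * hidden * kv_heads * head_dim  # q, o + k, v
    ffn = 3 * hidden * inter  # gate, up, down
    norms = 2 * hidden  # attn_norm, ffn_norm
    embeddings = 2 * vocab * hidden  # embed_tokens + lm_head
    return layers * ffn, layers * (attn + norms) + embeddings + hidden  # + final norm


def moe_budget(num_experts, top_k, replicate):
    """(total, active-per-token) parameters after MoE-ifying every FFN.

    replicate=False: experts are slices of the dense FFN (total unchanged).
    replicate=True:  every expert is a full copy of the dense FFN.
    """
    ffn, other = llama_param_counts()
    expert = ffn if replicate else ffn // num_experts
    return num_experts * expert + other, top_k * expert + other


GB = 1e9
BF16_BYTES = 2
ffn, other = llama_param_counts()
print(f"dense: {(ffn + other) / GB:.2f}B params, FFN share {ffn / (ffn + other):.0%}")
for name, (total, active) in {
    "sliced, all 8 active (identity)": moe_budget(8, 8, replicate=False),
    "sliced, top-2 of 8": moe_budget(8, 2, replicate=False),
    "replicated, top-2 of 8": moe_budget(8, 2, replicate=True),
}.items():
    print(f"{name:32} total {total * BF16_BYTES / GB:5.1f} GB bf16, "
          f"active {active / GB:.2f}B params/token")
```

Because the FFNs hold roughly 70% of the parameters, the expert layout dominates the memory budget. At small batch sizes decoding tends to be limited by how many weight bytes are read per step, so the "active" column is a rough proxy for latency; at large batch sizes most experts are hit anyway and the "total" column matters more.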

## Core Requirements

1. Replace the dense feed-forward layers (MLP blocks) in Llama with an MoE equivalent. Your MoE design should support:

- an “identity / toy mode” that guarantees logit fidelity with the base dense model
- a “production-like mode” that resembles a realistic MoE setup (e.g., learned router + top-k routing), even if it is not trained. You may decide what “production-like” means, but it should meaningfully reflect real MoE usage (routing, multiple experts, dispatch pattern, etc.) rather than only being a no-op wrapper.

2. Write a runnable test that demonstrates logit fidelity between the original dense Llama 3.1 8B model and your MoE-ified model in identity/toy mode. The test must:

- load pretrained dense weights
- compare logits between models and assert they match within a reasonable tolerance (you decide what’s reasonable and justify it)
- ideally include a greedy decode check to confirm identical token outputs

3. A discussion of your MoE implementation and the performance/memory trade-offs you have considered. Consider the implications of optimising for:

- maximum throughput (fast inference)
- minimum memory footprint 
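For the production-like mode in requirement 1, a minimal sketch, similar in spirit to Mixtral's design: an untrained linear router, a softmax over the top-k router logits, and a gather/scatter dispatch that runs each expert only on the tokens routed to it. `SwiGLUExpert` and `TopKMoE` are hypothetical names. There is no capacity limit, load-balancing loss, or shared expert, and the per-expert Python loop is the simplest dispatch rather than the fastest.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLUExpert(nn.Module):
    """One expert: the same gated MLP as FeedForward, at a chosen width."""

    def __init__(self, hidden_size: int, expert_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, expert_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, expert_size, bias=False)
        self.down_proj = nn.Linear(expert_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TopKMoE(nn.Module):
    """Learned linear router + top-k gating + token gather/scatter dispatch."""

    def __init__(self, hidden_size: int, expert_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [SwiGLUExpert(hidden_size, expert_size) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.reshape(-1, x.shape[-1])  # [T, hidden]: flatten batch and sequence
        logits = self.router(tokens)  # [T, num_experts]
        top_logits, top_ids = logits.topk(self.top_k, dim=-1)  # [T, k] each
        top_weights = top_logits.softmax(dim=-1)  # renormalise over the chosen k
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # Which (token, slot) pairs picked expert e?
            token_idx, slot = (top_ids == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue  # no tokens routed here, so skip the matmuls entirely
            y = expert(tokens[token_idx])  # gather: run only the routed tokens
            y = y * top_weights[token_idx, slot].unsqueeze(-1)
            out.index_add_(0, token_idx, y)  # scatter-add back to token positions
        return out.reshape(x.shape)


torch.manual_seed(0)
moe = TopKMoE(hidden_size=16, expert_size=32, num_experts=4, top_k=2)
x = torch.randn(2, 3, 16)
y = moe(x)
print(y.shape)  # torch.Size([2, 3, 16])
```

On the throughput side, this loop launches small, irregularly sized matmuls per expert. Faster implementations typically sort tokens by expert and use grouped GEMM kernels, at the cost of extra buffers for the permuted activations.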

In addition to runnability, you will be judged on code quality and readability. There are no strict requirements for the deliverable: you can return a notebook, or split the code into a package along with a script.
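For the minimum-memory end of the trade-off in requirement 3, one common pattern is expert offloading: keep all expert weights in host memory and hold only a few on the accelerator, loading them on demand. The sketch below is a plain LRU cache over experts. `ExpertCache`, its `loader` callback and `fake_loader` are hypothetical, and the sketch ignores asynchronous prefetching, pinned memory and the transfer cost itself. It illustrates the trade-off: resident memory is bounded by `capacity` experts, while every miss puts a host-to-device copy on the critical path.

```python
from collections import OrderedDict


class ExpertCache:
    """Hold at most `capacity` experts on the device; evict the least recently used."""

    def __init__(self, capacity, loader):
        self.capacity = capacity
        self.loader = loader  # expert_id -> expert weights, already on the device
        self.resident = OrderedDict()  # expert_id -> weights, oldest first
        self.hits = 0
        self.misses = 0

    def get(self, expert_id):
        if expert_id in self.resident:
            self.resident.move_to_end(expert_id)  # now the most recently used
            self.hits += 1
            return self.resident[expert_id]
        self.misses += 1
        if len(self.resident) >= self.capacity:
            self.resident.popitem(last=False)  # evict the least recently used
        weights = self.loader(expert_id)  # the slow host-to-device copy happens here
        self.resident[expert_id] = weights
        return weights


loaded = []


def fake_loader(expert_id):
    # Hypothetical stand-in for something like
    # host_experts[expert_id].to("cuda", non_blocking=True).
    loaded.append(expert_id)
    return f"weights-{expert_id}"


cache = ExpertCache(capacity=2, loader=fake_loader)
for expert_id in [0, 1, 0, 2, 1]:  # experts requested by successive tokens/layers
    cache.get(expert_id)
print(loaded, cache.hits, cache.misses)  # [0, 1, 2, 1] 1 4
```

For scale: if each Llama FFN were sliced into 8 experts, one expert would hold 3 × 4096 × 1792 ≈ 22 M parameters, about 44 MB in bf16, per layer.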

## Model Code

In [None]:
import math
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    vocab_size: int = 128256
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8  # GQA
    max_seq_len: int = 8192
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

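    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the normalisation in float32 (as the HF Llama reference does)
        # so bf16 rounding does not enter the variance, then cast back.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)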


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_position_emb(q, k, cos, sin):
    # q, k: [bs, nheads, seq, head_dim]
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    return q, k


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len_cache = max_seq_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._set_cos_sin_cache(max_seq_len)

    def _set_cos_sin_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    @torch.no_grad()
    def forward(self, seq_len: int):
        if seq_len > self.cos_cached.shape[0]:
            self._set_cos_sin_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


class Attention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads

        assert self.head_dim * self.num_heads == self.hidden_size, (
            "hidden_size must be divisible by num_heads"
        )
        assert self.num_heads % self.num_kv_heads == 0, (
            "num_heads must be multiple of num_kv_heads"
        )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=config.max_seq_len,
            base=config.rope_theta,
        )

    @staticmethod
    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        if n_rep == 1:
            return x
        bsz, kv_heads, seq, hd = x.shape
        x = x[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq, hd)
        return x.reshape(bsz, kv_heads * n_rep, seq, hd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [bsz, seq, hidden] -> returns [bsz, seq, hidden]

        Causal self-attention over the full sequence. There is no KV cache, so
        every call recomputes keys and values for all positions.
        """
        bsz, seq_len, _ = x.shape

        q = (
            self.q_proj(x)
            .view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )  # [bsz, heads, seq, hd]
        k = (
            self.k_proj(x)
            .view(bsz, seq_len, self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )  # [bsz, kv, seq, hd]
        v = (
            self.v_proj(x)
            .view(bsz, seq_len, self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )

        cos, sin = self.rotary_emb(seq_len)
        q, k = apply_position_emb(q, k, cos, sin)

        n_rep = self.num_heads // self.num_kv_heads
        k = self.repeat_kv(k, n_rep)  # [bsz, heads, total_seq, hd]
        v = self.repeat_kv(v, n_rep)

        total_seq = k.size(2)

        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(
            self.head_dim
        )  # [bsz, heads, seq, total_seq]
        causal = torch.tril(
            torch.ones(seq_len, total_seq, device=x.device, dtype=torch.bool)
        )
        attn_scores = attn_scores.masked_fill(~causal, float("-inf"))

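        # Softmax in float32 (as HF's eager attention does), then back to q's dtype.
        attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)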

        out = torch.matmul(attn_probs, v)  # [bsz, heads, seq, hd]
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        out = self.o_proj(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class DenseBlock(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = Attention(config)
        self.ffn = FeedForward(config)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        attn_out = self.self_attn(self.attn_norm(x))
        x = residual + attn_out

        residual = x
        x = residual + self.ffn(self.ffn_norm(x))

        return x


class Transformer(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [DenseBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed_tokens(input_ids)  # [bsz, seq, hidden]
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.lm_head(x)  # [bsz, seq, vocab]

## State Dict loading

In [None]:
import json
import os

import torch
from safetensors.torch import load_file as safetensors_load_file

mapping = {
    ".mlp": ".ffn",
    ".post_attention_layernorm": ".ffn_norm",
    ".input_layernorm": ".attn_norm",
}


def load_weights(
    model: torch.nn.Module,
    model_dir: str | None,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    strict: bool = False,
    verbose: bool = True,
):
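    """
    Minimal loader for a model whose weights are sharded across multiple safetensors files.

    Maps the Llama weight names used in the HF safetensors to the parameter names in the
    implementation above. If `model_dir` is None, the model is filled with random weights
    instead, which is useful for smoke tests without the checkpoint.
    """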

    sd = {}
    if model_dir is not None:
        # Read model.safetensors.index.json
        index_path = os.path.join(model_dir, "model.safetensors.index.json")
        if not os.path.isfile(index_path):
            raise FileNotFoundError(f"Missing index file: {index_path}")

        with open(index_path, "r", encoding="utf-8") as f:
            idx = json.load(f)

        if "weight_map" not in idx:
            raise ValueError(f"Index file missing 'weight_map': {index_path}")

        weight_map = idx["weight_map"]
        shard_files = sorted(set(weight_map.values()))

        # Merge every shard into a single state_dict
        for shard in shard_files:
            shard_path = os.path.join(model_dir, shard)
            if not os.path.isfile(shard_path):
                raise FileNotFoundError(
                    f"Shard referenced by index not found: {shard_path}"
                )
            sd.update(safetensors_load_file(shard_path))
    else:
        shard_files = []

    # iterate over keys and rename according to mapping
    # optionally cast dtype and/or move to device
    for param in list(sd.keys()):
        weights = sd.pop(param)
        param = param.removeprefix("model.")
        for old, new in mapping.items():
            if old in param:
                param = param.replace(old, new)
        if dtype is not None:
            weights = weights.to(dtype)
        if device is not None:
            weights = weights.to(device)
        sd[param] = weights

    meta = {
        "model_dir": model_dir,
        "num_shards": len(shard_files),
        "shards": shard_files,
    }

    if not sd:
        for name, tensor in model.state_dict().items():
            if tensor.is_floating_point():
                rand = torch.randn_like(tensor)
            else:
                rand = torch.zeros_like(tensor)

            if dtype is not None and rand.is_floating_point():
                rand = rand.to(dtype)
            if device is not None:
                rand = rand.to(device)

            sd[name] = rand

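    # Load state_dict into model. load_state_dict copies values into the existing
    # parameters and keeps their dtype, so the cast to `dtype` must happen here.
    missing, unexpected = model.load_state_dict(sd, strict=strict)
    model.to(device=device, dtype=dtype)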

    if verbose:
        print("=== load_weights ===")
        print(f"model_dir: {model_dir}")
        print(f"num_shards: {meta['num_shards']}")
        print(f"num_tensors_loaded: {len(sd)}")
        print(f"missing_keys: {len(missing)}")
        print(f"unexpected_keys: {len(unexpected)}")

    return {
        "model_dir": model_dir,
        "checkpoint_meta": meta,
        "num_tensors_loaded": len(sd),
        "missing_keys": missing,
        "unexpected_keys": unexpected,
    }
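To see what the `mapping` does to individual checkpoint keys, the renaming step can be pulled out into a standalone function. This sketch copies the mapping above as `MAPPING` and mirrors the loop in `load_weights`, using `str.removeprefix` so that only a leading `model.` is stripped. The function name `rename_hf_key` is hypothetical, and the example keys follow the Hugging Face Llama naming.

```python
MAPPING = {
    ".mlp": ".ffn",
    ".post_attention_layernorm": ".ffn_norm",
    ".input_layernorm": ".attn_norm",
}


def rename_hf_key(name: str) -> str:
    """Map one HF Llama checkpoint key to this notebook's parameter name."""
    name = name.removeprefix("model.")  # strip only a leading "model."
    for old, new in MAPPING.items():
        name = name.replace(old, new)
    return name


for key in [
    "model.embed_tokens.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.post_attention_layernorm.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.norm.weight",
    "lm_head.weight",
]:
    print(f"{key:48} -> {rename_hf_key(key)}")
```

Loading the same dense checkpoint into an MoE layout would need one more step on top of this, fanning each `ffn.*` tensor out into per-expert keys.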

In [None]:
cache_dir = f"{os.path.expanduser('~')}/.cache/huggingface/hub/"
model_name = "meta-llama/Llama-3.1-8B-Instruct"

snapshots_dir = os.path.join(
    cache_dir, "models--" + model_name.replace("/", "--"), "snapshots"
)
assert os.path.isdir(snapshots_dir), f"Model directory not found: {snapshots_dir}"
model_dir = os.path.join(snapshots_dir, os.listdir(snapshots_dir)[0])

config = ModelConfig()
model = Transformer(config)

load_weights(
    model,
    model_dir=model_dir,
    dtype=torch.bfloat16,
    device="cuda",
    strict=True,
    verbose=True,
)

In [None]:
from transformers import AutoTokenizer, PreTrainedTokenizer

system_prompt = (
    "You are a helpful coding assistant. Answer the user's questions like a pirate."
)
user_prompt = "Explain the difference between a list and a tuple in Python."

conversation = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]

tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name)