# Transformer LM


<img src="../images/TransformerLM.png">


In [None]:
from cs336_basics.linear import Linear
from cs336_basics.rmsnorm import RMSNorm
from cs336_basics.embedding import Embedding
from cs336_basics.swinglu import SwiGLUFFN
from cs336_basics.rope import RoPE
from cs336_basics.softmax import Softmax
from cs336_basics.multi_head_self_attention import MultiHeadSelfAttentionRoPE


def run_transformer_lm(
    vocab_size: int,
    context_length: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    d_ff: int,
    rope_theta: float,
    weights: dict[str, Tensor],
    in_indices: Int[Tensor, " batch_size sequence_length"],
) -> Float[Tensor, " batch_size sequence_length vocab_size"]:
    """Run one forward pass of a RoPE Transformer language model.

    The model is assembled on the fly from the supplied state dict: a token
    embedding, `num_layers` pre-norm blocks (RMSNorm -> multi-head
    self-attention with RoPE -> residual add, then RMSNorm -> SwiGLU
    feed-forward -> residual add), a final RMSNorm and a linear LM head.
    No softmax is applied; the raw logits are returned.

    Args:
        vocab_size: Number of entries in the output vocabulary.
        context_length: Maximum number of tokens processed at once; passed to
            each attention module as `max_seq_len`.
        d_model: Width of the embeddings and of every sublayer output.
        num_layers: Number of Transformer blocks to apply.
        num_heads: Number of attention heads. `d_model` must be evenly
            divisible by `num_heads`.
        d_ff: Width of the feed-forward inner layer.
        rope_theta: RoPE theta value, passed to each attention module.
        weights: State dict of the reference implementation. In the keys
            below, `{i}` is a layer index from `0` to `num_layers - 1`.
            Shapes are those documented for the reference state dict.
            - `token_embeddings.weight`: token embedding matrix,
              (vocab_size, d_model).
            - `layers.{i}.attn.q_proj.weight`,
              `layers.{i}.attn.k_proj.weight`,
              `layers.{i}.attn.v_proj.weight`: query / key / value
              projections for all heads, each
              (num_heads * (d_model / num_heads), d_model). Rows are the
              per-head matrices concatenated along dim 0, e.g.
              `torch.cat([q_heads.0.weight, ..., q_heads.N.weight], dim=0)`.
            - `layers.{i}.attn.output_proj.weight`: attention output
              projection, ((d_model / num_heads) * num_heads, d_model).
            - `layers.{i}.ln1.weight`: gain of the RMSNorm that precedes
              attention, (d_model,).
            - `layers.{i}.ffn.w1.weight`: first FFN linear map,
              (d_model, d_ff).
            - `layers.{i}.ffn.w2.weight`: second FFN linear map,
              (d_ff, d_model).
            - `layers.{i}.ffn.w3.weight`: third FFN linear map,
              (d_model, d_ff).
            - `layers.{i}.ln2.weight`: gain of the RMSNorm that precedes
              the FFN, (d_model,).
            - `ln_final.weight`: gain of the RMSNorm applied after the last
              block, (d_model,).
            - `lm_head.weight`: output embedding, (vocab_size, d_model).
        in_indices: Token indices of shape (batch_size, sequence_length),
            where `sequence_length` is at most `context_length`.

    Returns:
        Tensor of shape (batch_size, sequence_length, vocab_size) holding the
        unnormalized next-token scores for every input position.
    """

    def load_norm(key: str) -> RMSNorm:
        # Build an RMSNorm and point its gain at the stored tensor `key`.
        norm = RMSNorm(d_model=d_model)
        norm.weight.data = weights[key]
        return norm

    # Token embeddings: look up one d_model-wide vector per input index.
    token_table = Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    token_table.token_ids.data = weights["token_embeddings.weight"]
    hidden = token_table(in_indices)  # (batch, seq_len, d_model)

    # Positions 0..seq_len-1, created on the same device as the activations.
    positions = torch.arange(hidden.shape[1], device=hidden.device)

    # Transformer blocks, applied in order.
    for layer_idx in range(num_layers):
        pfx = f"layers.{layer_idx}"

        # Attention sublayer: normalize, attend, add back to the residual.
        attn_norm = load_norm(f"{pfx}.ln1.weight")
        attention = MultiHeadSelfAttentionRoPE(
            d_model=d_model,
            num_heads=num_heads,
            max_seq_len=context_length,
            theta=rope_theta,
        )
        # The projection weights are handed to the attention call directly
        # (q, k, v, output -- in that positional order) rather than being
        # stored on the module first.
        hidden = hidden + attention(
            weights[f"{pfx}.attn.q_proj.weight"],
            weights[f"{pfx}.attn.k_proj.weight"],
            weights[f"{pfx}.attn.v_proj.weight"],
            weights[f"{pfx}.attn.output_proj.weight"],
            attn_norm(hidden),
            token_positions=positions,
        )

        # Feed-forward sublayer: normalize, SwiGLU FFN, add back.
        ffn_norm = load_norm(f"{pfx}.ln2.weight")
        feed_forward = SwiGLUFFN(d_model=d_model, d_ff=d_ff)
        feed_forward.w1_weight.data = weights[f"{pfx}.ffn.w1.weight"]
        feed_forward.w2_weight.data = weights[f"{pfx}.ffn.w2.weight"]
        feed_forward.w3_weight.data = weights[f"{pfx}.ffn.w3.weight"]
        hidden = hidden + feed_forward(ffn_norm(hidden))

    # Final normalization after the last block.
    hidden = load_norm("ln_final.weight")(hidden)

    # LM head: project to vocabulary size and return logits (no softmax).
    output_head = Linear(in_features=d_model, out_features=vocab_size)
    output_head.weight.data = weights["lm_head.weight"]
    return output_head(hidden)  # (batch, seq_len, vocab_size)