# DeepSeek

In this notebook, I'll be experimenting with the architectural enhancements that characterize DeepSeek.

## Attention

The attention mechanism used by DeepSeek is Multi-Head Latent Attention (MLA), so I'll be working towards that.
But rather than immediately jumping to MLA, I'll build up to it in steps:

1. MHA with a KV cache
2. Multi-Query Attention
3. Grouped-Query Attention
4. Multi-Head Latent Attention

### KV Cache

The idea behind the KV cache is that the autoregressive nature of LLM text
generation results in a ton of redundant calculations.
Since tokens are predicted one at a time, and only the final context vector in a
sequence influences the prediction, we should only have to calculate one context
vector at a time.

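To make that possible, we cache the key and value vectors already computed for
earlier tokens and reuse them at every later step. Because attention is causal,
those keys and values never change when a new token is appended, so each step
only has to compute the query, key, and value for the newest token.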

In [2]:
import torch
import torch.nn as nn
from typing import Optional

class KVCache:
    """Cache of the key/value projections of the tokens seen so far.

    During autoregressive generation each step's input is the previous step's
    input with one more token appended. ``save_key``/``save_val`` detect that
    shared prefix and push only the new rows through the projection layer,
    reusing the cached rows for the rest.

    This is a plain object rather than a ``torch.Tensor`` subclass: it has no
    tensor data of its own, it only holds references to tensors and to the
    layers that produced them.

    The cache checks layer *identity*, not weights, so call ``reset()`` after
    the projection weights change (e.g. after a training step).
    """

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Drop all cached state, e.g. before starting an unrelated sequence."""
        self.w_key: Optional[nn.Linear] = None  # layer that produced `keys`
        self.w_val: Optional[nn.Linear] = None  # layer that produced `values`
        self.key_x: Optional[torch.Tensor] = None  # input `keys` was computed from
        self.val_x: Optional[torch.Tensor] = None  # input `values` was computed from
        self.keys: Optional[torch.Tensor] = None
        self.values: Optional[torch.Tensor] = None

    @staticmethod
    def _project(
        layer: nn.Linear,
        x: torch.Tensor,
        cached_layer: Optional[nn.Linear],
        cached_x: Optional[torch.Tensor],
        cached_out: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return ``layer(x)``, reusing ``cached_out`` when ``cached_x`` is a prefix of ``x``.

        Expects ``x`` in the [batch, num_tokens, d_in] layout used by
        MultiHeadAttention; the prefix is taken along the token dimension.
        """
        if cached_layer is layer and cached_x is not None and cached_out is not None:
            n_cached = cached_x.shape[1]
            # torch.equal is False when the sizes differ, so a shorter input or a
            # different batch simply falls through to a full recompute below.
            if torch.equal(x[:, :n_cached, :], cached_x):
                return torch.cat([cached_out, layer(x[:, n_cached:, :])], dim=1)
        return layer(x)

    def save_key(self, x: torch.Tensor, val: nn.Linear) -> torch.Tensor:
        """Compute the keys ``val(x)`` (incrementally when possible), cache and return them."""
        keys = self._project(val, x, self.w_key, self.key_x, self.keys)
        self.w_key, self.key_x, self.keys = val, x, keys
        return keys

    def save_val(self, x: torch.Tensor, val: nn.Linear) -> torch.Tensor:
        """Compute the values ``val(x)`` (incrementally when possible), cache and return them."""
        values = self._project(val, x, self.w_val, self.val_x, self.values)
        self.w_val, self.val_x, self.values = val, x, values
        return values

    def get_key(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the cached keys if they were computed from exactly ``x``, else ``None``."""
        if self.key_x is not None and torch.equal(self.key_x, x):
            return self.keys
        return None

    def get_val(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the cached values if they were computed from exactly ``x``, else ``None``."""
        if self.val_x is not None and torch.equal(self.val_x, x):
            return self.values
        return None


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a KV-cached decoding path.

    ``forward`` computes a context vector for every token of its input;
    ``forward_cached`` computes only the last token's, reusing the keys and
    values of previously seen tokens from ``self.kv_cache``.
    """

    def __init__(
        self,
        d_in: int,  # embedding dimension
        d_out: int, # embedding dimension
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias: bool = False,
    ):
        super().__init__()
        if d_out % num_heads != 0:
            raise ValueError("The number of heads must evenly divide d_out.")
        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_width = d_out // num_heads
        self.qkv_bias = qkv_bias

        # construct the weights for Q, K, and V.
        # these will be registered as trainable parameters automatically.
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # the KV cache is kept as a plain attribute, not a buffer: it is not a
        # tensor, and it holds inference-time state that does not belong in the
        # module's state_dict.
        self.kv_cache = KVCache()

        # and the output projection, also trainable.
        self.w_out = nn.Linear(d_out, d_out)

        # and the dropout layer. not trainable, just drops random values
        # to zero with a probability determined by the dropout parameter
        self.dropout = nn.Dropout(dropout)

        # and the mask, which prevents each token from "seeing" later ones
        mask = torch.triu(  # an upper triangular matrix
            torch.ones(context_length, context_length),  # consisting of ones
            diagonal=1,  # starting one row above the diagonal, leaving the diagonal itself as zeroes.
        )
        self.register_buffer(
            "mask", mask
        )  # register this tensor as non-trainable, but keep it on the same device
        self.mask: torch.Tensor  # to make the type-checker happy

    def _check_length(self, num_tokens: int) -> None:
        """Raise ValueError if the sequence is longer than the causal mask."""
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            raise ValueError(
                f"Got {num_tokens} tokens, but context_length is {context_length}."
            )

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """[batch, num_tokens, d_out] -> [batch, num_heads, num_tokens, head_width]."""
        batch, num_tokens, _ = t.shape
        # split the last dimension into heads, then move heads in front of tokens
        return t.view(batch, num_tokens, self.num_heads, self.head_width).transpose(1, 2)

    def _context(self, attention_scores: torch.Tensor, v_heads: torch.Tensor) -> torch.Tensor:
        """Turn raw (already masked) scores into the projected output.

        attention_scores: [batch, num_heads, num_queries, num_keys]
        v_heads:          [batch, num_heads, num_keys, head_width]
        returns:          [batch, num_queries, d_out]
        """
        # softmax over the keys of the scaled scores, then dropout
        attention_weights = torch.softmax(
            attention_scores / self.head_width**0.5, dim=-1
        )
        attention_weights = self.dropout(attention_weights)

        context = attention_weights @ v_heads  # [batch, num_heads, num_queries, head_width]
        batch, _, num_queries, _ = context.shape
        # back to [batch, num_queries, num_heads, head_width], then concatenate the heads
        context = context.transpose(1, 2).contiguous().view(batch, num_queries, self.d_out)
        return self.w_out(context)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a context vector for every token.

        x: [batch, num_tokens, d_in]  ->  returns [batch, num_tokens, d_out]
        Raises ValueError if num_tokens exceeds context_length.
        """
        num_tokens = x.shape[1]
        self._check_length(num_tokens)

        # each is [batch, num_heads, num_tokens, head_width]
        q_heads = self._split_heads(self.w_query(x))
        k_heads = self._split_heads(self.w_key(x))
        v_heads = self._split_heads(self.w_value(x))

        # raw dot-product scores Q @ K^T: [batch, num_heads, num_tokens, num_tokens]
        attention_scores = q_heads @ k_heads.transpose(2, 3)
        # and apply the causal mask: token i may not attend to any token j > i
        mask = self.mask[:num_tokens, :num_tokens]
        attention_scores = attention_scores.masked_fill(mask == 1, float("-inf"))

        return self._context(attention_scores, v_heads)

    def forward_cached(self, x: torch.Tensor) -> torch.Tensor:
        """Return the context vector for only the last token of ``x``.

        Matches ``self.forward(x)[:, -1:, :]`` (up to dropout and float
        rounding), but keys/values for the prefix of ``x`` that an earlier call
        already saw are reused from ``self.kv_cache``, and only one query is
        computed.

        x: [batch, num_tokens, d_in]  ->  returns [batch, 1, d_out]
        Raises ValueError if x has no tokens or exceeds context_length.
        """
        num_tokens = x.shape[1]
        if num_tokens == 0:
            raise ValueError("forward_cached needs at least one token.")
        self._check_length(num_tokens)

        # [batch, num_tokens, d_out], computed incrementally by the cache
        keys = self.kv_cache.save_key(x, self.w_key)
        values = self.kv_cache.save_val(x, self.w_value)
        # only the newest token needs a query
        query = self.w_query(x[:, -1:, :])  # [batch, 1, d_out]

        q_heads = self._split_heads(query)  # [batch, num_heads, 1, head_width]
        k_heads = self._split_heads(keys)   # [batch, num_heads, num_tokens, head_width]
        v_heads = self._split_heads(values)

        # the last token may attend to every token, so no causal mask is needed
        attention_scores = q_heads @ k_heads.transpose(2, 3)  # [batch, num_heads, 1, num_tokens]
        return self._context(attention_scores, v_heads)

In [None]:
# Scratch check: can the score matrix Q @ K^T of the extended sequence be
# rebuilt from the scores of the previous sequence plus one new row and one
# new column? (A real KV cache stores K and V rather than the scores; this
# only confirms that appending a token just adds a row and a column.)
torch.manual_seed(0)  # make the experiment reproducible
w_q = nn.Linear(768, 768, False)
w_k = nn.Linear(768, 768, False)
w_v = nn.Linear(768, 768, False)

x0 = torch.randn(1, 68, 768)  # the "previous" sequence: [batch, tokens, d]
new_tok = torch.randn(1, 1, 768)  # the new row/token
x1 = torch.cat([x0, new_tok], dim=1)  # the new sequence: [1, 69, 768]

q0 = w_q(x0)
k0 = w_k(x0)
v0 = w_v(x0)
qkt0 = q0 @ k0.transpose(1, 2)  # [1, 68, 68]

q1 = w_q(x1)
k1 = w_k(x1)
v1 = w_v(x1)
qkt1 = q1 @ k1.transpose(1, 2)  # [1, 69, 69], computed from scratch

# start from the "cached" scores of the previous sequence
cache = qkt0.clone()
qkt1c = cache

# new row: the new token's query against every old key -> [1, 1, 68]
q_new = (q1[:, -1:, :] @ k1.transpose(1, 2))[:, :, :-1]
print(f"q_new shape: {q_new.shape}")
qkt1c = torch.cat([qkt1c, q_new], dim=1)
print(f"qkt1c shape: {qkt1c.shape}")
# new column: every query (old and new) against the new token's key -> [1, 69, 1]
k_new = q1 @ k1[:, -1:, :].transpose(1, 2)
print(f"k_new shape: {k_new.shape}")
qkt1c = torch.cat([qkt1c, k_new], dim=2)
print(f"qkt1c shape: {qkt1c.shape}")

# the point of the experiment: the incrementally built matrix should match
# the one computed from scratch, up to float32 rounding
max_diff = (qkt1c - qkt1).abs().max().item()
print(f"max |qkt1c - qkt1|: {max_diff:.3e}")
print(f"matches: {torch.allclose(qkt1c, qkt1, rtol=1e-4, atol=1e-3)}")

In [57]:
# display the score matrix recomputed from scratch for the extended sequence.
# NOTE(review): the output below is symmetric, which suggests it came from an
# earlier run (the cell above is shown as In [None]) — re-run before comparing.
qkt1

tensor([[[235.0335,  -2.6566, -12.2506,  ..., -24.8320,  -2.8388,  15.2576],
         [ -2.6566, 221.3925, -12.1649,  ...,   5.8514,  -6.2172,   1.1360],
         [-12.2506, -12.1649, 301.8590,  ...,  18.5389,   5.1033,  -1.4170],
         ...,
         [-24.8320,   5.8514,  18.5389,  ..., 222.6481, -20.0580, -36.5833],
         [ -2.8388,  -6.2172,   5.1033,  ..., -20.0580, 237.3867,   2.4074],
         [ 15.2576,   1.1360,  -1.4170,  ..., -36.5833,   2.4074, 268.4286]]],
       grad_fn=<UnsafeViewBackward0>)

In [58]:
# display the score matrix rebuilt from the cached qkt0 plus one new row and
# one new column; compare it with qkt1.
# NOTE(review): the output below does not match the qkt1 output except in its
# last column, so it appears stale — re-run the experiment cell before reading it.
qkt1c

tensor([[[-4.7412e+00,  1.5995e+00, -3.0473e+00,  ..., -8.5874e+00,
          -4.3166e+00,  1.5258e+01],
         [ 1.0625e+01,  2.1516e+01, -8.1298e+00,  ..., -1.4035e+01,
           7.6252e-01,  1.1360e+00],
         [-8.6974e-01,  1.4532e+01,  1.0480e+00,  ..., -6.3004e-01,
           2.4673e+01, -1.4170e+00],
         ...,
         [-7.2940e-02,  5.4469e+00, -3.5401e+00,  ...,  1.5306e+00,
           3.0205e+00, -3.6583e+01],
         [ 2.4017e+01,  1.1086e+01, -4.3359e+00,  ..., -1.0554e+01,
          -1.7331e+00,  2.4074e+00],
         [ 1.5258e+01,  1.1360e+00, -1.4170e+00,  ..., -3.6583e+01,
           2.4074e+00,  2.6843e+02]]], grad_fn=<CatBackward0>)