In [1]:
!pip install torch

Collecting torch
  Downloading torch-2.5.1-cp312-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting filelock (from torch)
  Using cached filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Using cached typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.5-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting setuptools (from torch)
  Downloading setuptools-75.8.0-py3-none-any.whl.metadata (6.7 kB)
Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Using cached MarkupSafe-3.

# Key-Value Attention

In [4]:
import torch


def key_value_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, k_v_cache: dict) -> torch.Tensor:

    if not k_v_cache:
        k_v_cache = {"key": [], "value": []}
    if k_v_cache["key"] is None:
        k_v_cache["key"] = key
        k_v_cache["value"] = value
    else:
        k_v_cache["key"] = torch.cat([k_v_cache["key"], key], dim=1)
        k_v_cache["value"] = torch.cat([k_v_cache["value"], value], dim=1)
    # Compute attention
    scale = 1/ torch.sqrt(torch.tensor(query.shape[-1])) 
    compatibility = torch.matmul(query, k_v_cache["key"].transpose(-2, -1)) # (batch_size, seq_len, seq_len)
    attention = torch.softmax(compatibility * scale, dim=-1) # (batch_size, seq_len, seq_len)
    weighted_attention = torch.matmul(attention, k_v_cache["value"]) # (batch_size, seq_len, d_value)

    return weighted_attention, k_v_cache




In [5]:
# Test
query = torch.randn(1, 1, 10)
key = torch.randn(1, 1, 10)
value = torch.randn(1, 1, 10)

print(key_value_attention(query, key, value))

tensor([[[ 0.0574,  1.3514,  0.4958,  0.0324,  0.1390,  1.5929,  0.3412,
           1.3585, -0.2910, -0.5900]]])
