# 1. Understanding Attention

- Before running the Jupyter notebook, don't forget to copy it into your Drive **(`File` => `Save a copy in Drive`)**. *Skipping this step may cause you to lose your progress.*
- For this notebook, please replace the placeholder answers directly after a `#TODO` comment with your answers.
- Please only use constants. If you want to use one row or column of `key` or `value` as your answer, please write it out explicitly (i.e., `torch.tensor([...])`).

## Imports and Setup

In [38]:
print("hello")

hello


In [39]:
import torch
import torch.nn.functional as F

In [40]:
torch.manual_seed(447)

key = torch.randn(4, 3)
key /= torch.norm(key, dim=1, keepdim=True)
key.round_(decimals=2)

value = torch.randn(4, 3)
value /= torch.norm(value, dim=1, keepdim=True)
value.round_(decimals=2)

print(f"key:\n{key}")
print(f"value:\n{value}")

key:
tensor([[ 0.4700,  0.6500,  0.6000],
        [ 0.6400,  0.5000, -0.5900],
        [-0.0300, -0.4800, -0.8800],
        [ 0.4300, -0.8300,  0.3500]])
value:
tensor([[-0.0700, -0.8800,  0.4700],
        [ 0.3700, -0.9300, -0.0700],
        [-0.2500, -0.7500,  0.6100],
        [ 0.9400,  0.2000,  0.2800]])


In [41]:
def attention(query, key, value):
    """
    Note that we remove scaling for simplicity.
    """
    return F.scaled_dot_product_attention(query, key, value, scale=1)


def check_query(query, target, key, value):
    """
    Helper function for you to check if your query is close to the required target matrix.
    """
    a_out = attention(query, key, value)
    return (target - a_out).abs().max()
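
With `scale=1` and no mask or dropout, `attention` above should amount to `softmax(query @ key.T) @ value`. As a sanity check on that reading, here is a minimal dependency-free sketch in plain Python. The helper names `softmax` and `attend` are illustrative and not part of the notebook, the query `[1.0, 0.0, 0.0]` is an arbitrary example, and the key and value rows are copied from the printed output above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Unscaled dot-product attention for a single query vector (illustrative helper)."""
    # One logit per key: the dot product of the query with that key.
    logits = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(logits)
    # The output is the weighted sum of the value rows.
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, output


# Rows copied from the printed `key` and `value` tensors.
KEYS = [[0.47, 0.65, 0.60], [0.64, 0.50, -0.59],
        [-0.03, -0.48, -0.88], [0.43, -0.83, 0.35]]
VALUES = [[-0.07, -0.88, 0.47], [0.37, -0.93, -0.07],
          [-0.25, -0.75, 0.61], [0.94, 0.20, 0.28]]

# Arbitrary example query: its logits are just the first column of KEYS.
weights, output = attend([1.0, 0.0, 0.0], KEYS, VALUES)
print(weights)
print(output)
```

The weights always sum to 1, so the output is a convex combination of the value rows. The exercises below amount to choosing queries (or keys) that push those weights toward a desired pattern.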

## 1.2. Selection via Attention

In [42]:
# Define a query vector to "select" the first value vector


def get_query121():
    return torch.tensor([[4.7, 6.5, 6.0]])


print(get_query121())

# compare output of attention with desired output
diff = check_query(get_query121(), value[0], key=key, value=value)
print(diff)

assert diff < 0.05

tensor([[4.7000, 6.5000, 6.0000]])
tensor(0.0004)


In [48]:
# Define a query matrix which results in an identity mapping – select all the value vectors


def get_query122():
    return 10 * torch.tensor([[ 0.4700,  0.6500,  0.6000],
                        [ 0.6400,  0.5000, -0.5900],
                        [-0.0300, -0.4800, -0.8800],
                        [ 0.4300, -0.8300,  0.3500]])


print(get_query122())

# compare output of attention with desired output
diff = check_query(get_query122(), value, key=key, value=value)
print(diff)

assert diff < 0.05

tensor([[ 4.7000,  6.5000,  6.0000],
        [ 6.4000,  5.0000, -5.9000],
        [-0.3000, -4.8000, -8.8000],
        [ 4.3000, -8.3000,  3.5000]])
tensor(0.0007)


### Write-up part:

A transformer’s attention mechanism can “copy” or “re-use” the most relevant tokens by giving them high attention weight. In language modeling, this is highly valuable because the model often needs to reproduce or refer to words or phrases from the recent context (e.g., entity names, rare tokens, or repeating phrases). Being able to “copy” directly helps maintain consistency and coherence in generated text, especially for longer contexts or structured repetitions.
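
To make the "copy" behavior concrete, the sketch below recomputes the attention weights for a query equal to `s * key[0]` at two scales. It is a plain-Python illustration under the assumption that attention is an unscaled `softmax(q · k)`, as in `attention` above. `softmax` and `weights_for` are hypothetical helper names, and the key rows are copied from the printed `key` tensor.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Rows copied from the printed `key` tensor.
KEYS = [[0.47, 0.65, 0.60], [0.64, 0.50, -0.59],
        [-0.03, -0.48, -0.88], [0.43, -0.83, 0.35]]


def weights_for(scale):
    """Attention weights for the query `scale * KEYS[0]` (illustrative helper)."""
    query = [scale * x for x in KEYS[0]]
    logits = [sum(q * k for q, k in zip(query, key)) for key in KEYS]
    return softmax(logits)


soft = weights_for(1.0)    # query is the first key itself
sharp = weights_for(10.0)  # same direction, ten times longer
print([round(w, 4) for w in soft])
print([round(w, 4) for w in sharp])
```

Scaling the query multiplies every logit by the same factor, which widens the gaps between them before the softmax. At `s = 10` nearly all of the weight lands on the first key, which is consistent with `get_query121` reproducing `value[0]` within the tolerance.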

## 1.3. Averaging via Attention

In [44]:
# define a query vector which averages all the value vectors


def get_query131():
    return torch.tensor([[0.0, 0.0, 0.0]])


print(get_query131())

# compare output of attention with desired output
target = torch.reshape(value.mean(0, keepdims=True), (3,))  # reshape to a vector
diff = check_query(get_query131(), target, key=key, value=value)
print(diff)

assert diff < 0.05

tensor([[0., 0., 0.]])
tensor(0.)


In [45]:
# define a query vector which averages the first two value vectors


def get_query132():
    S = 100.0
    k0 = torch.tensor([[0.47, 0.65, 0.60]])
    k1 = torch.tensor([[0.64, 0.50, -0.59]])
    q = S * ((k0 + k1) / 2)
    return q

print(get_query132())

# compare output of attention with desired output
target = torch.reshape(
    value[(0, 1),].mean(0, keepdims=True), (3,)
)  # reshape to a vector
diff = check_query(get_query132(), target, key=key, value=value)
print(diff)

assert diff < 0.05

tensor([[55.5000, 57.5000,  0.5000]])
tensor(0.0289)


### Write-up part:

In a language-modeling context, the model often needs to blend or “smoothly combine” information from multiple tokens, e.g., when resolving a coreference that depends on multiple clues, or when synthesizing context from multiple parts of a sentence or paragraph. The ability to aggregate via attention means the model can learn to produce representations that are “mixes” of relevant tokens, rather than copying just a single one. This flexibility to combine rather than only copy is crucial to capturing nuanced context and producing coherent, contextually informed predictions.
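
The blending described above corresponds to attention weights that are spread over several keys. The following plain-Python sketch, which assumes unscaled softmax attention as in `attention` above, prints the weights for the two queries used in this section. `softmax` and `weights_for` are illustrative helper names, and the key rows are copied from the printed `key` tensor.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Rows copied from the printed `key` tensor.
KEYS = [[0.47, 0.65, 0.60], [0.64, 0.50, -0.59],
        [-0.03, -0.48, -0.88], [0.43, -0.83, 0.35]]


def weights_for(query):
    """Attention weights of `query` over KEYS (illustrative helper)."""
    logits = [sum(q * k for q, k in zip(query, key)) for key in KEYS]
    return softmax(logits)


# Zero query: every logit is 0, so the softmax is uniform.
uniform = weights_for([0.0, 0.0, 0.0])

# Query from get_query132: S times the mean of the first two keys.
S = 100.0
q_avg = [S * (a + b) / 2 for a, b in zip(KEYS[0], KEYS[1])]
two_way = weights_for(q_avg)

print(uniform)
print(two_way)
```

The second set of weights is close to, but not exactly, an even split between the first two keys. The rounded keys are not exactly unit-norm, so `q · k0` and `q · k1` differ slightly, and that gap grows with `S`. This appears to be the main source of the nonzero `diff` reported for `get_query132`.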

## 1.4. Interactions within Attention

In [49]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) averages the first three value vectors.
m_key = key.clone()


def get_key141():
    key = torch.tensor([[ 0.4700,  0.6500,  0.6000],
                  [ 0.6400,  0.5000, -0.5900],
                  [-0.0300, -0.4800, -0.8800],
                  [ 0.4300, -0.8300,  0.3500]])
    return (key[0] + key[1]) / 2

m_key[2] = get_key141()

# compare output of attention with desired output
diff = check_query(get_query132(), value[(0, 1, 2),].mean(0, keepdims=True), key=m_key, value=value)
print(diff)

assert diff < 0.05

tensor(0.0198)


In [50]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) returns the third value vector v[2].
m_key = key.clone()


def get_key142():
    key = torch.tensor([[ 0.4700,  0.6500,  0.6000],
                        [ 0.6400,  0.5000, -0.5900],
                        [-0.0300, -0.4800, -0.8800],
                        [ 0.4300, -0.8300,  0.3500]])
    avg = (key[0] + key[1]) / 2
    return avg / avg.norm()


m_key[2] = get_key142()
m_key[2] /= m_key[2].norm()  # effectively a no-op safeguard: get_key142() already returns a unit vector

# compare output of attention with desired output
diff = check_query(get_query132(), value[2], key=m_key, value=value)
print(f"diff = {diff}")

assert diff < 0.05

diff = 1.1920928955078125e-07
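
The two replacement keys in this section point in the same direction and differ only in their norm. The plain-Python sketch below compares the attention weights they produce for the query from (1.3.2). It assumes unscaled softmax attention as in `attention` above; `softmax` and `weights_with_k2` are illustrative helper names, and the key rows are copied from the printed `key` tensor.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Rows copied from the printed `key` tensor.
KEYS = [[0.47, 0.65, 0.60], [0.64, 0.50, -0.59],
        [-0.03, -0.48, -0.88], [0.43, -0.83, 0.35]]

# Query from (1.3.2): 100 times the mean of the first two keys.
avg = [(a + b) / 2 for a, b in zip(KEYS[0], KEYS[1])]
query = [100.0 * x for x in avg]

# Same direction as `avg`, rescaled to unit length.
avg_norm = math.sqrt(sum(x * x for x in avg))
unit = [x / avg_norm for x in avg]


def weights_with_k2(new_k2):
    """Attention weights after replacing the third key with `new_k2` (illustrative helper)."""
    keys = [KEYS[0], KEYS[1], new_k2, KEYS[3]]
    logits = [sum(q * k for q, k in zip(query, key)) for key in keys]
    return softmax(logits)


w_avg = weights_with_k2(avg)    # as in get_key141
w_unit = weights_with_k2(unit)  # as in get_key142
print([round(w, 4) for w in w_avg])
print([round(w, 4) for w in w_unit])
```

With the unnormalized mean, the third logit is `100 * ||avg||^2`, roughly 64, which is about the same as the logits of the first two keys, so the weight is shared roughly three ways. After normalization the third logit becomes `100 * ||avg||`, roughly 80, and that gap is large enough for the softmax to put almost all of the weight on `k[2]`.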
