# **Attention Mechanisms**

In [1]:
import torch

## The Need for Attention - In a Nutshell

As this notebook is being written (mid 2025), the seminal paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) is nearly 8 years old. The abstract states:

>The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. 
>
>Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data. 

Eschewing encoder-decoder RNNS, which were commonly used in translation tasks, in favour of an architecture which could selectively access all input tokens of a given sequence and assign signifance on a comparative basis helped to usher in the age of LLMs as we know them.

Transformers built using attention mechanisms addressed the limitations of RNNS i.e.:

1. **Short range dependency** i.e. the failure to grasp connections between distant words / sequences.
2. **Limited Parallelism** i.e. slow processing of information due to sequential design.
3. **Focus on Local Context** i.e. primarily considering immediate neighbours and potentially missing critical information from other parts of the input sequence.

## 1. Self Attention

The primary function of self attention is to generate "context-aware" vectors from the input sequence.

![title](.//images/self-attention.webp)

_figure from Deep Learning with Python by Francios Chollet_


> The "self" refers to the mechanisms ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image. - S. Raschka

### 1.1 Simple Self-Attention Without Trainable Weights

The goal is to compute a context vector for each input element that combines information from all other input elements.

In [2]:
# Input tensor representing a sequence which has already by embedded into 3 dimensional vectors
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your    (x^1)
     [0.55, 0.87, 0.66], # journey (x^2)
     [0.57, 0.85, 0.64], # starts  (x^3)
     [0.22, 0.58, 0.33], # with    (x^4)
     [0.77, 0.25, 0.10], # one     (x^5)
     [0.05, 0.80, 0.55]] # step    (x^6)
)

In [3]:
# Calculate intermediate attention scores by taking dot product of the query with every other input token
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    # dot product without transpose since these are 1 dim vectors
    attn_scores_2[i] = torch.dot(x_i, query)

print(attn_scores_2)


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [4]:
# Illustration of dot products at work
res = 0.

for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx]

print(res)
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


> Beyond viewing the dot product operations as a mathematical tool...(it) is a measure of similarity because it quantifies how closely two vectors are aligned... In the context of self attention mechanisms, the dot product determines the extent to which each element in a sequence focuses on, or 'attends to' any other element.

In [5]:
# Simple normalization of the attention scores, they should sum upto ~1.0
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights: ", attn_weights_2_tmp)
print("Sum: ", attn_weights_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum:  tensor(1.0000)


In [6]:
# In practice softmax is used since it is better at handling extreme values and always results in positive values.
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights: ", attn_weights_2_naive)
print("Sum: ", attn_weights_2_naive.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


In [7]:
# To correct for numerical instability i.e. over and underflow...stick with the pytorch version which is extensively optimized
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights: ", attn_weights_2)
print("Sum: ", attn_weights_2.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)


In [8]:
# The final step focuses on creating the context vector
# Multiply the embedded input tokens with the corresponding attention weights
# and then summing.
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
    
print(context_vec_2)


tensor([0.4419, 0.6515, 0.5683])


### 1.2 Computing Attention Weights for all Input Tokens

Let's extend the computation to calculate attention weights and context vectors for all inputs.

In [9]:
# Apply the previous step to all pairwise elements to compute the un-normalized attention score matrix
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs): # additional loop to compute the dot products for all pairs of inputs
        attn_scores[i, j] = torch.dot(x_i, x_j)
        
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [10]:
# Replace the for loops with matrix multiplication
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [11]:
# Normalize each row so values sum up to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [12]:
# Quick check
row_3_sum = attn_weights[2].sum(dim=0)
all_row_sum = attn_weights.sum(dim=-1)

print(f"Row 3 sum: {row_3_sum}")
print(f"All row sum: {all_row_sum}")

Row 3 sum: 1.0000001192092896
All row sum: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [13]:
# Using the attention weights to compute all context vectors via matmul
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## 2. Implementing Self Attention With Trainable Weights

With the fundamentals out of the way, it is time to implement the self-attention mechanism used in the original transformer architecture. This self-attention mechanism is also called _scaled dot-product attention_.

Context vectors now have to be computed as weighted sums over the input vectors specific to a certain input element.

### 2.1 Computing the Attention Weights

 Let us begin by adding three trainable weight matrices into the mix i.e. $W_{q}$, $W_{k}$ and $W_{v}$. These are used to project the embedded input tokens into query, key and value vectors.

- Query vector: $q^{(i)} = x^{(i)}\,W_q $
- Key vector: $k^{(i)} = x^{(i)}\,W_k $
- Value vector: $v^{(i)} = x^{(i)}\,W_v $

For illustration, we can begin by computing only one context vector before moving onto all context vectors.

In [14]:
# GPT models have similar input and output dimensions, but we're changing things to understand computations better.

x_2 = inputs[1] # 2nd input element
d_in = inputs.shape[1] # input embedding size, d=3
d_out = 2 # output embedding size, d=2

In [15]:
torch.manual_seed(42)
# Setting requires_grad=False for cleaner outputs
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [16]:
# Computing query, key and value vectors wrt the 2nd input element
query_2 = x_2 @ W_query
key_2   = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2, key_2, value_2)

tensor([1.0760, 1.7344]) tensor([1.5764, 0.9441]) tensor([1.7073, 1.0646])


In [17]:
# 6 input tokens projected from 3d to 2d embedding space
keys = inputs @ W_key 
values = inputs @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [18]:
# Now, computing the un-normalized attention scores by taking the dot product of the query and each
# key vector
keys_2 = keys[1]
attn_score_22 = query_2.dot(key_2)
print(attn_score_22)

tensor(3.3338)


With regard to attention scores, the difference from the previous section is that instead of directly computing the dot-product between the input elements, we will use the query and key obtained by transforming the inputs via respective weight matrices.

In [19]:
# Generalizing to all attention scores via matrix multiplications
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([2.7084, 3.3338, 3.3013, 1.7563, 1.7869, 2.1966])


Attention weights are computed by scaling the attention scores and using softmax. But, at this point, we can scale the attention scores by dividing them by the square root of the embedding dimension of the keys, which is the mathematically similar to exponentiating by 0.5.

Normalizing by the embedding dimension size improves training performance by avoiding small gradients caused by the usage of softmax when scaling up.

> As dot products increase, the softmax function behaves like a step function, resulting in gradients nearing zero...Scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled-dot product attention.

In [20]:
d_k = keys.shape[1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1723, 0.2681, 0.2620, 0.0879, 0.0898, 0.1200])


In [21]:
# Computing context vector
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([1.4201, 0.8892])


### 2.2 Creating a Compact SelfAttention Class 

In [22]:
import torch.nn as nn 

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        
        context_vec = attn_weights @ values
        return context_vec        
    

torch.manual_seed(42)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[1.3751, 0.8610],
        [1.4201, 0.8892],
        [1.4198, 0.8890],
        [1.3533, 0.8476],
        [1.3746, 0.8606],
        [1.3620, 0.8532]], grad_fn=<MmBackward0>)


In [23]:
 # Quick check
context_vec_2

tensor([1.4201, 0.8892])

The Self-Attention class can make use of `nn.Linear` layers to carry out matrix multiplications (especially when bias units have been disabled). These layers also have optimized weight initialization schemes which contribute to more stable training. Consequently, when compared, `SelfAttention_v1` and `SelfAttention_v2` will have differing outputs due to the differences in weight initialization.

In [24]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
    def forward(self, x):
        keys =  self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        
        context_vec = attn_weights @ values
        return context_vec
    
torch.manual_seed(42)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)


In [25]:
# Exercise - Assigning the weights from an instance of SelfAttention__v2 to an instance of v1
sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)

sa_v1(inputs)

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

## 3. Implementing Causal Attention 

Causal or _masked_ attention restricts the model to only consider previous and current inputs in a sequence when processing any given input token's attention scores. Attention weights "above the diagonal" are masked out and non-masked attention weights are normalized so that they sum to 1. in each row.

![title](.//images/causal-self-attention.webp)

To apply this mask the previous self-attention mechanism will be converted to causal self-attention.

In [26]:
# We shall reuse the query and key weights of SelfAttention_v2
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1605, 0.1726, 0.1714, 0.1681, 0.1473, 0.1801],
        [0.1627, 0.1780, 0.1758, 0.1648, 0.1306, 0.1880],
        [0.1625, 0.1782, 0.1759, 0.1648, 0.1302, 0.1885],
        [0.1661, 0.1726, 0.1715, 0.1654, 0.1475, 0.1768],
        [0.1596, 0.1777, 0.1755, 0.1664, 0.1312, 0.1896],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<SoftmaxBackward0>)


In [27]:
# Pytorch's tril function is a simple way to create a mask.
# Here values above diagonal are zero
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [28]:
# Multiplying the mask with the attention weights to zero out the values above the diagonal
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1605, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1627, 0.1780, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1625, 0.1782, 0.1759, 0.0000, 0.0000, 0.0000],
        [0.1661, 0.1726, 0.1715, 0.1654, 0.0000, 0.0000],
        [0.1596, 0.1777, 0.1755, 0.1664, 0.1312, 0.0000],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<MulBackward0>)


It should be noted that applying the mask after normalizing with softmax would disrupt the probability distribution of the initial softmax norm.

When negative `inf` values are present in a row, the softmax function treats them as zero probability, since $e^{-inf}$ approaches zero. So we can create a more effective mask by replacing zeros above the diagonal with $-inf$

In [29]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) # insert 1s above the diagonal
masked = attn_scores.masked_fill(mask.bool(), -torch.inf) # swap with -inf
print(masked)

tensor([[ 0.0508,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2157,  0.3428,    -inf,    -inf,    -inf,    -inf],
        [ 0.2163,  0.3467,  0.3282,    -inf,    -inf,    -inf],
        [ 0.1257,  0.1799,  0.1707,  0.1191,    -inf,    -inf],
        [ 0.1667,  0.3193,  0.3012,  0.2258, -0.1098,    -inf],
        [ 0.1269,  0.1548,  0.1475,  0.0978, -0.0247,  0.1731]],
       grad_fn=<MaskedFillBackward0>)


In [30]:
# Normalizing the updated masked attention weights
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4775, 0.5225, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3146, 0.3450, 0.3405, 0.0000, 0.0000, 0.0000],
        [0.2459, 0.2555, 0.2538, 0.2448, 0.0000, 0.0000],
        [0.1969, 0.2193, 0.2165, 0.2053, 0.1619, 0.0000],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<SoftmaxBackward0>)


### 3.1 Masking Attention Weights with Dropout

In [31]:
torch.manual_seed(42)
dropout = torch.nn.Dropout(0.5) # Dropout rate of 50% for illustration
example = torch.ones(6, 6)

print(dropout(example))

tensor([[2., 2., 2., 2., 0., 2.],
        [0., 0., 2., 2., 2., 2.],
        [0., 0., 2., 0., 2., 0.],
        [0., 2., 2., 0., 2., 2.],
        [2., 2., 0., 2., 2., 2.],
        [2., 2., 2., 2., 0., 0.]])


In [33]:
# Doing the same for the attention weights we calculated in the previous section
torch.manual_seed(42)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6809, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5110, 0.5077, 0.0000, 0.0000, 0.0000],
        [0.3938, 0.4387, 0.0000, 0.4106, 0.3239, 0.0000],
        [0.3364, 0.3431, 0.3413, 0.3295, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


In [34]:
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4775, 0.5225, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3146, 0.3450, 0.3405, 0.0000, 0.0000, 0.0000],
        [0.2459, 0.2555, 0.2538, 0.2448, 0.0000, 0.0000],
        [0.1969, 0.2193, 0.2165, 0.2053, 0.1619, 0.0000],
        [0.1682, 0.1715, 0.1707, 0.1648, 0.1511, 0.1738]],
       grad_fn=<SoftmaxBackward0>)

With a 50% dropout rate, the values of the remaining elements are scaled up by a factor of `1 / 0.5 = 2`, which is crucial to maintain the overall balance of the attention weights and the average attention mechanism remains consistent during training and inference.

Numerical values of the tensor passed to `nn.Dropout()` produces different ouputs on different OS'. This is currently an open issue.

### 3.2 Implementing a Compact Causal Self-Attention Class

In [39]:
batch = torch.stack((inputs, inputs), dim=0)
# This results in batch size of 2 with 6 tokens each and each token has embedding dimension of 3
batch.shape

torch.Size([2, 6, 3])

In [40]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out   = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
        
    def forward(self, x):
        # b is the batch dim
        b, num_tokens, d_in = x.shape
        # num_tokens exceeding contenxt_length will result in errors during masking
        keys    = self.W_key(x)
        queries = self.W_query(x)
        values  = self.W_value(x)
        
        attn_scores = queries @ keys.transpose(1, 2)
        # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        
        context_vec = attn_weights @ values
        return context_vec

In [41]:
torch.manual_seed(42)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape", context_vecs.shape)

tensor([[[0.3755, 0.2777],
         [0.3761, 0.2831],
         [0.3761, 0.2833],
         [0.3768, 0.2763],
         [0.3754, 0.2836],
         [0.3772, 0.2746]],

        [[0.3755, 0.2777],
         [0.3761, 0.2831],
         [0.3761, 0.2833],
         [0.3768, 0.2763],
         [0.3754, 0.2836],
         [0.3772, 0.2746]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape torch.Size([2, 6, 2])


In [None]:
torch.manual_seed(42)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.5) # Testing 50% dropout rate.

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape", context_vecs.shape)

tensor([[[0.5627, 0.4484],
         [0.3605, 0.3069],
         [0.5189, 0.3813],
         [0.7535, 0.5526],
         [0.3299, 0.2006],
         [0.5052, 0.2738]],

        [[0.3351, 0.2701],
         [0.4826, 0.4155],
         [0.4810, 0.4097],
         [0.3352, 0.2702],
         [0.3169, 0.1386],
         [0.3387, 0.1419]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape torch.Size([2, 6, 2])
