In [36]:
import torch 
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Each token is projected to queries, keys and values of width ``d_out``.
    That width is split into ``num_heads`` heads of ``head_dim = d_out // num_heads``
    features. Every head runs scaled dot-product attention under a causal mask
    (a token may only attend to itself and earlier tokens). The head outputs are
    concatenated back to ``d_out`` features and mixed by a final linear layer.

    Args:
        d_in: feature size of each input token.
        d_out: total output feature size; must be divisible by ``num_heads``.
        context_length: size of the precomputed causal mask, i.e. the maximum
            number of tokens ``forward`` can handle.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width; all heads together span d_out

        # Query/key/value projections. They are created in this order so that
        # seeded parameter initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)

        # Float mask: 1.0 strictly above the main diagonal ("future" positions), 0.0 elsewhere.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def _split_heads(self, tensor, batch, seq_len):
        """Reshape (batch, seq_len, d_out) into (batch, num_heads, seq_len, head_dim)."""
        return tensor.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (batch, num_tokens, d_in). num_tokens must not
                exceed context_length; otherwise the sliced mask is too small
                to line up with the attention scores and masking fails.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        batch, seq_len, _ = x.shape

        q = self._split_heads(self.W_query(x), batch, seq_len)
        k = self._split_heads(self.W_key(x), batch, seq_len)
        v = self._split_heads(self.W_value(x), batch, seq_len)

        # One score per (query token, key token) pair, per head: (batch, heads, seq_len, seq_len).
        scores = q @ k.transpose(-2, -1)

        # Hide future positions; -inf becomes exactly 0 after softmax.
        hide_future = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(hide_future, -torch.inf)

        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads, head_dim) -> (batch, seq_len, d_out)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(merged)

In [37]:
import torch
import torch.nn as nn

# Define parameters for a small, hand-inspectable walkthrough of MultiHeadAttention.
# Seed the RNG so the random inputs (and the randomly initialised layers created in
# later cells) are reproducible under Restart Kernel -> Run All.
torch.manual_seed(123)

batch_size = 2
context_length = 5  # Also serves as num_tokens for this example
d_in = 3
d_out = 4
num_heads = 2
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
head_dim = d_out // num_heads  # 2
dropout = 0.0
qkv_bias = False

# Create a batch of random inputs
x = torch.rand(batch_size, context_length, d_in)
print("Input tensor x:")
print(x)
print("Shape of x:", x.shape)  # Expected: (2, 5, 3) - batch, tokens, d_in

Input tensor x:
tensor([[[0.6506, 0.5452, 0.4177],
         [0.7808, 0.9484, 0.7105],
         [0.2714, 0.9145, 0.9046],
         [0.0708, 0.0866, 0.5198],
         [0.3952, 0.7225, 0.4597]],

        [[0.1547, 0.9644, 0.5771],
         [0.4492, 0.3180, 0.4338],
         [0.2439, 0.5791, 0.9714],
         [0.7821, 0.0414, 0.4141],
         [0.3151, 0.0916, 0.4109]]])
Shape of x: torch.Size([2, 5, 3])


In [None]:
# Build the three learned projection layers (not tensors): each nn.Linear maps the
# d_in features of every token to d_out features. They are created in the same
# order as W_query / W_key / W_value in MultiHeadAttention.__init__.
W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

# Project the same input three ways: (batch, tokens, d_in) -> (batch, tokens, d_out)
queries, keys, values = W_query(x), W_key(x), W_value(x)

print("Queries tensor:", queries, sep="\n")
print("Shape of queries:", queries.shape)  # Expected: (2, 5, 4) - batch, tokens, d_out

print("\nKeys tensor:", keys, sep="\n")
print("Shape of keys:", keys.shape)  # Expected: (2, 5, 4)

print("\nValues tensor:", values, sep="\n")
print("Shape of values:", values.shape)  # Expected: (2, 5, 4)

Queries tensor:
tensor([[[ 0.1977, -0.3273,  0.1582, -0.0144],
         [ 0.3817, -0.4335,  0.1787,  0.1413],
         [ 0.5048, -0.1818,  0.0891,  0.5025],
         [ 0.1947,  0.0545,  0.1086,  0.2730],
         [ 0.2869, -0.2679,  0.0666,  0.1609]],

        [[ 0.4162, -0.2173, -0.0175,  0.3804],
         [ 0.1734, -0.1793,  0.1494,  0.0653],
         [ 0.4507, -0.0614,  0.1565,  0.5099],
         [ 0.0586, -0.2365,  0.2842, -0.1491],
         [ 0.1274, -0.0663,  0.1504,  0.0893]]], grad_fn=<UnsafeViewBackward0>)
Shape of queries: torch.Size([2, 5, 4])

Keys tensor:
tensor([[[-0.4958,  0.3919, -0.1996, -0.6609],
         [-0.7172,  0.4835, -0.2045, -0.9477],
         [-0.5063,  0.2617, -0.1046, -0.7034],
         [-0.0921,  0.1712, -0.2130, -0.2117],
         [-0.4594,  0.2353, -0.0436, -0.5871]],

        [[-0.4576,  0.0979,  0.1017, -0.5638],
         [-0.3291,  0.3204, -0.2247, -0.4772],
         [-0.3737,  0.3171, -0.2647, -0.5959],
         [-0.3617,  0.5415, -0.4623, -0.5681],


In [None]:
# Split the d_out features of every token into num_heads chunks of head_dim features:
# (batch, tokens, d_out) -> (batch, tokens, num_heads, head_dim)
# Without this split, attention could not be computed independently per head. Cutting one
# d_out-wide projection into num_heads smaller subspaces lets different heads learn
# different attention patterns.
queries, keys, values = (
    t.view(batch_size, context_length, num_heads, head_dim)
    for t in (queries, keys, values)
)

print("Reshaped queries:", queries, sep="\n")
print("Shape of reshaped queries:", queries.shape)  # Expected: (2, 5, 2, 2) - batch, tokens, heads, head_dim

print("\nReshaped keys:", keys, sep="\n")
print("Shape of reshaped keys:", keys.shape)  # Expected: (2, 5, 2, 2)

print("\nReshaped values:", values, sep="\n")
print("Shape of reshaped values:", values.shape)  # Expected: (2, 5, 2, 2)

Reshaped queries:
tensor([[[[ 0.1977, -0.3273],
          [ 0.1582, -0.0144]],

         [[ 0.3817, -0.4335],
          [ 0.1787,  0.1413]],

         [[ 0.5048, -0.1818],
          [ 0.0891,  0.5025]],

         [[ 0.1947,  0.0545],
          [ 0.1086,  0.2730]],

         [[ 0.2869, -0.2679],
          [ 0.0666,  0.1609]]],


        [[[ 0.4162, -0.2173],
          [-0.0175,  0.3804]],

         [[ 0.1734, -0.1793],
          [ 0.1494,  0.0653]],

         [[ 0.4507, -0.0614],
          [ 0.1565,  0.5099]],

         [[ 0.0586, -0.2365],
          [ 0.2842, -0.1491]],

         [[ 0.1274, -0.0663],
          [ 0.1504,  0.0893]]]], grad_fn=<ViewBackward0>)
Shape of reshaped queries: torch.Size([2, 5, 2, 2])

Reshaped keys:
tensor([[[[-0.4958,  0.3919],
          [-0.1996, -0.6609]],

         [[-0.7172,  0.4835],
          [-0.2045, -0.9477]],

         [[-0.5063,  0.2617],
          [-0.1046, -0.7034]],

         [[-0.0921,  0.1712],
          [-0.2130, -0.2117]],

         [[-0.4594

In [None]:
# Swap the tokens and heads axes: (batch, tokens, heads, head_dim) -> (batch, heads, tokens, head_dim).
# The @ operator treats every dimension except the last two as batch dimensions and multiplies
# the trailing matrices. With heads in front of tokens, each head's (tokens, head_dim) matrices
# are multiplied independently.
# NOTE: this cell overwrites queries/keys/values, so running it twice transposes them back.
# Re-run from the reshape cell instead.
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

print("Transposed queries:")
print(queries)
print("Shape of transposed queries:", queries.shape)  # Expected: (2, 2, 5, 2) - batch, heads, tokens, head_dim

print("\nTransposed keys:")
print(keys)
print("Shape of transposed keys:", keys.shape)  # Expected: (2, 2, 5, 2)

print("\nTransposed values:")
print(values)
print("Shape of transposed values:", values.shape)  # Expected: (2, 2, 5, 2)

Transposed queries:
tensor([[[[ 0.1977, -0.3273],
          [ 0.3817, -0.4335],
          [ 0.5048, -0.1818],
          [ 0.1947,  0.0545],
          [ 0.2869, -0.2679]],

         [[ 0.1582, -0.0144],
          [ 0.1787,  0.1413],
          [ 0.0891,  0.5025],
          [ 0.1086,  0.2730],
          [ 0.0666,  0.1609]]],


        [[[ 0.4162, -0.2173],
          [ 0.1734, -0.1793],
          [ 0.4507, -0.0614],
          [ 0.0586, -0.2365],
          [ 0.1274, -0.0663]],

         [[-0.0175,  0.3804],
          [ 0.1494,  0.0653],
          [ 0.1565,  0.5099],
          [ 0.2842, -0.1491],
          [ 0.1504,  0.0893]]]], grad_fn=<TransposeBackward0>)
Shape of transposed queries: torch.Size([2, 2, 5, 2])

Transposed keys:
tensor([[[[-0.4958,  0.3919],
          [-0.7172,  0.4835],
          [-0.5063,  0.2617],
          [-0.0921,  0.1712],
          [-0.4594,  0.2353]],

         [[-0.1996, -0.6609],
          [-0.2045, -0.9477],
          [-0.1046, -0.7034],
          [-0.2130, -0.21

In [None]:
# Compute the raw attention scores (dot products before scaling, masking and softmax).
# For every batch element and head this is queries @ keys^T: one score per
# (query token, key token) pair. Queries and keys come from the same tokens
# (tokens_q == tokens_k == context_length), so each head yields a square matrix.
# The earlier reshape/transpose lets a single batched matmul handle all heads independently.
attn_scores = queries @ keys.transpose(2, 3)  # keys^T: (batch, heads, head_dim, tokens)

print("Keys transposed for matmul:")
print(keys.transpose(2, 3))
print("Shape of transposed keys:", keys.transpose(2, 3).shape)  # Expected: (2, 2, 2, 5)

print("\nAttention scores tensor:")
print(attn_scores)
print("Shape of attention scores:", attn_scores.shape)  # Expected: (2, 2, 5, 5) - batch, heads, tokens, tokens

Keys transposed for matmul:
tensor([[[[-0.4958, -0.7172, -0.5063, -0.0921, -0.4594],
          [ 0.3919,  0.4835,  0.2617,  0.1712,  0.2353]],

         [[-0.1996, -0.2045, -0.1046, -0.2130, -0.0436],
          [-0.6609, -0.9477, -0.7034, -0.2117, -0.5871]]],


        [[[-0.4576, -0.3291, -0.3737, -0.3617, -0.1881],
          [ 0.0979,  0.3204,  0.3171,  0.5415,  0.2749]],

         [[ 0.1017, -0.2247, -0.2647, -0.4623, -0.2589],
          [-0.5638, -0.4772, -0.5959, -0.5681, -0.3209]]]],
       grad_fn=<TransposeBackward0>)
Shape of transposed keys: torch.Size([2, 2, 2, 5])

Attention scores tensor:
tensor([[[[-0.2263, -0.3000, -0.1858, -0.0742, -0.1678],
          [-0.3591, -0.4833, -0.3067, -0.1094, -0.2773],
          [-0.3215, -0.4499, -0.3032, -0.0776, -0.2747],
          [-0.0752, -0.1133, -0.0843, -0.0086, -0.0766],
          [-0.2472, -0.3353, -0.2154, -0.0723, -0.1948]],

         [[-0.0221, -0.0187, -0.0064, -0.0306,  0.0015],
          [-0.1291, -0.1705, -0.1181, -0.0680, 

In [None]:
# Recreate the causal mask built in MultiHeadAttention.__init__:
#   self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
# torch.triu keeps the upper triangle ("u" = upper). diagonal=1 also excludes the main
# diagonal, so the mask is 1 strictly above the diagonal (future positions) and 0 elsewhere.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print("Causal mask tensor:")
print(mask)
print("Shape of mask:", mask.shape)  # Expected: (5, 5)

# Convert to boolean (True = position to mask, i.e. fill with -inf). Slice it to the actual
# number of tokens in the scores, exactly as forward() slices with num_tokens.
num_tokens = attn_scores.shape[-1]
mask_bool = mask.bool()[:num_tokens, :num_tokens]
print("\nBoolean mask:")
print(mask_bool)
print("Shape of boolean mask:", mask_bool.shape)  # Expected: (5, 5)

Causal mask tensor:
tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])
Shape of mask: torch.Size([5, 5])

Boolean mask:
tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])
Shape of boolean mask: torch.Size([5, 5])


In [None]:
# Apply the causal mask so that no token can attend to later (future) tokens.
# masked_fill_ works in place: it writes -inf wherever mask_bool is True, and softmax later
# turns those entries into exactly 0. mask_bool is (tokens, tokens) and broadcasts over the
# batch and head dimensions.
attn_scores.masked_fill_(mask_bool, -torch.inf)

print("Masked attention scores:")
print(attn_scores)
print("Shape of masked attention scores:", attn_scores.shape)  # Still (2, 2, 5, 5)

Masked attention scores:
tensor([[[[-0.2263,    -inf,    -inf,    -inf,    -inf],
          [-0.3591, -0.4833,    -inf,    -inf,    -inf],
          [-0.3215, -0.4499, -0.3032,    -inf,    -inf],
          [-0.0752, -0.1133, -0.0843, -0.0086,    -inf],
          [-0.2472, -0.3353, -0.2154, -0.0723, -0.1948]],

         [[-0.0221,    -inf,    -inf,    -inf,    -inf],
          [-0.1291, -0.1705,    -inf,    -inf,    -inf],
          [-0.3499, -0.4944, -0.3628,    -inf,    -inf],
          [-0.2021, -0.2809, -0.2034, -0.0809,    -inf],
          [-0.1196, -0.1661, -0.1202, -0.0483, -0.0974]]],


        [[[-0.2117,    -inf,    -inf,    -inf,    -inf],
          [-0.0969, -0.1145,    -inf,    -inf,    -inf],
          [-0.2123, -0.1680, -0.1879,    -inf,    -inf],
          [-0.0500, -0.0950, -0.0969, -0.1492,    -inf],
          [-0.0648, -0.0632, -0.0686, -0.0820, -0.0422]],

         [[-0.2162,    -inf,    -inf,    -inf,    -inf],
          [-0.0216, -0.0648,    -inf,    -inf,    -inf]

In [None]:
# Divide the scores by sqrt(head_dim), then apply softmax (scaled dot-product attention).
# Why scale: dot products of head_dim-dimensional vectors tend to grow in magnitude with head_dim.
# Large logits push softmax toward a near one-hot output whose gradients are tiny.
# Softmax over the last dimension (the keys) turns each query's row into non-negative
# attention weights that sum to 1. The -inf entries from the causal mask become exactly 0.
scale_factor = keys.shape[-1] ** 0.5  # sqrt(head_dim); keys is (batch, heads, tokens, head_dim)
attn_weights = torch.softmax(attn_scores / scale_factor, dim=-1)

print("Attention weights after softmax:")
print(attn_weights)
print("Shape of attention weights:", attn_weights.shape)  # Expected: (2, 2, 5, 5)

Attention weights after softmax:
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5219, 0.4781, 0.0000, 0.0000, 0.0000],
          [0.3417, 0.3121, 0.3462, 0.0000, 0.0000],
          [0.2491, 0.2424, 0.2475, 0.2611, 0.0000],
          [0.1949, 0.1831, 0.1993, 0.2205, 0.2022]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5073, 0.4927, 0.0000, 0.0000, 0.0000],
          [0.3456, 0.3120, 0.3424, 0.0000, 0.0000],
          [0.2479, 0.2344, 0.2476, 0.2700, 0.0000],
          [0.1986, 0.1922, 0.1985, 0.2089, 0.2018]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5031, 0.4969, 0.0000, 0.0000, 0.0000],
          [0.3280, 0.3384, 0.3337, 0.0000, 0.0000],
          [0.2585, 0.2504, 0.2501, 0.2410, 0.0000],
          [0.1999, 0.2001, 0.1994, 0.1975, 0.2031]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5076, 0.4924, 0.0000, 0.0000, 0.0000],
          [0.3396, 0.3380, 0.3224, 0.0000, 0.0000],
          [0.2664, 0.24

In [None]:
# Apply dropout to the attention weights, mirroring self.dropout = nn.Dropout(dropout) in the class.
# In training mode (the default for a new module), nn.Dropout zeroes each element with
# probability `dropout` and rescales the survivors by 1 / (1 - dropout). This regularises
# against overfitting. In eval mode it is a no-op.
# Here dropout = 0.0, so the weights are unchanged.
dropout_layer = nn.Dropout(dropout)
attn_weights = dropout_layer(attn_weights)

print("Attention weights after dropout:")
print(attn_weights)
print("Shape after dropout:", attn_weights.shape)  # Still (2, 2, 5, 5)

Attention weights after dropout:
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5219, 0.4781, 0.0000, 0.0000, 0.0000],
          [0.3417, 0.3121, 0.3462, 0.0000, 0.0000],
          [0.2491, 0.2424, 0.2475, 0.2611, 0.0000],
          [0.1949, 0.1831, 0.1993, 0.2205, 0.2022]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5073, 0.4927, 0.0000, 0.0000, 0.0000],
          [0.3456, 0.3120, 0.3424, 0.0000, 0.0000],
          [0.2479, 0.2344, 0.2476, 0.2700, 0.0000],
          [0.1986, 0.1922, 0.1985, 0.2089, 0.2018]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5031, 0.4969, 0.0000, 0.0000, 0.0000],
          [0.3280, 0.3384, 0.3337, 0.0000, 0.0000],
          [0.2585, 0.2504, 0.2501, 0.2410, 0.0000],
          [0.1999, 0.2001, 0.1994, 0.1975, 0.2031]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5076, 0.4924, 0.0000, 0.0000, 0.0000],
          [0.3396, 0.3380, 0.3224, 0.0000, 0.0000],
          [0.2664, 0.24

In [None]:
# Compute the context vectors: for each query token, the attention-weighted sum of the value vectors.
# attn_weights (b, heads, tokens_q, tokens_k) @ values (b, heads, tokens_k, head_dim)
#   -> (b, heads, tokens_q, head_dim)
# This aggregates information from the (earlier) tokens each query attends to, separately per head.
context_vec = attn_weights @ values

print("Context vectors before transpose:")
print(context_vec)
print("Shape of context_vec:", context_vec.shape)  # Expected: (2, 2, 5, 2) - batch, heads, tokens, head_dim

Context vectors before transpose:
tensor([[[[ 0.1867, -0.3404],
          [ 0.2780, -0.4507],
          [ 0.3282, -0.5084],
          [ 0.2450, -0.4368],
          [ 0.2569, -0.4251]],

         [[-0.0205,  0.3134],
          [ 0.0517,  0.3706],
          [ 0.1448,  0.3457],
          [ 0.1017,  0.2873],
          [ 0.1210,  0.2811]]],


        [[[ 0.4876, -0.4963],
          [ 0.2888, -0.3873],
          [ 0.2677, -0.4408],
          [ 0.1784, -0.3860],
          [ 0.1347, -0.3467]],

         [[ 0.4297,  0.2038],
          [ 0.1909,  0.2225],
          [ 0.1760,  0.2447],
          [ 0.0526,  0.2654],
          [ 0.0137,  0.2496]]]], grad_fn=<UnsafeViewBackward0>)
Shape of context_vec: torch.Size([2, 2, 5, 2])


In [None]:
# Undo the earlier heads/tokens swap so that each token's per-head outputs sit side by side,
# ready to be concatenated: (b, heads, tokens, head_dim) -> (b, tokens, heads, head_dim)
context_vec = torch.transpose(context_vec, 1, 2)

print("Transposed context_vec:", context_vec, sep="\n")
print("Shape after transpose:", context_vec.shape)  # Expected: (2, 5, 2, 2)

Transposed context_vec:
tensor([[[[ 0.1867, -0.3404],
          [-0.0205,  0.3134]],

         [[ 0.2780, -0.4507],
          [ 0.0517,  0.3706]],

         [[ 0.3282, -0.5084],
          [ 0.1448,  0.3457]],

         [[ 0.2450, -0.4368],
          [ 0.1017,  0.2873]],

         [[ 0.2569, -0.4251],
          [ 0.1210,  0.2811]]],


        [[[ 0.4876, -0.4963],
          [ 0.4297,  0.2038]],

         [[ 0.2888, -0.3873],
          [ 0.1909,  0.2225]],

         [[ 0.2677, -0.4408],
          [ 0.1760,  0.2447]],

         [[ 0.1784, -0.3860],
          [ 0.0526,  0.2654]],

         [[ 0.1347, -0.3467],
          [ 0.0137,  0.2496]]]], grad_fn=<TransposeBackward0>)
Shape after transpose: torch.Size([2, 5, 2, 2])


In [None]:
# Merge the heads by flattening the trailing (heads, head_dim) axes into one axis of
# num_heads * head_dim == d_out features. This concatenates the per-head outputs back to full width.
# .contiguous() is needed because the preceding transpose left a memory layout that
# .view() cannot reinterpret as the merged shape.
context_vec = context_vec.contiguous().view(batch_size, context_length, num_heads * head_dim)

print("Combined context_vec:", context_vec, sep="\n")
print("Shape after view:", context_vec.shape)  # Expected: (2, 5, 4) - batch, tokens, d_out

Combined context_vec:
tensor([[[ 0.1867, -0.3404, -0.0205,  0.3134],
         [ 0.2780, -0.4507,  0.0517,  0.3706],
         [ 0.3282, -0.5084,  0.1448,  0.3457],
         [ 0.2450, -0.4368,  0.1017,  0.2873],
         [ 0.2569, -0.4251,  0.1210,  0.2811]],

        [[ 0.4876, -0.4963,  0.4297,  0.2038],
         [ 0.2888, -0.3873,  0.1909,  0.2225],
         [ 0.2677, -0.4408,  0.1760,  0.2447],
         [ 0.1784, -0.3860,  0.0526,  0.2654],
         [ 0.1347, -0.3467,  0.0137,  0.2496]]], grad_fn=<ViewBackward0>)
Shape after view: torch.Size([2, 5, 4])


In [49]:
# Finally, apply the output projection: a linear layer that mixes the combined head outputs.
# This mirrors self.out_proj = nn.Linear(d_out, d_out) in the class. It is optional but lets the
# model learn a better combination of the heads.
# The layer here is freshly (randomly) initialised, so its weights are unrelated to any
# MultiHeadAttention instance.
# NOTE: context_vec is overwritten, so re-running this cell would project the result a second
# time with a new random layer. Re-run from the head-combining cell instead.
out_proj = nn.Linear(d_out, d_out)

context_vec = out_proj(context_vec)

print("Final context_vec after output projection:")
print(context_vec)
print("Shape of final context_vec:", context_vec.shape)  # Expected: (2, 5, 4) - batch, tokens, d_out

Final context_vec after output projection:
tensor([[[ 0.0787,  0.6367, -0.2230,  0.4223],
         [ 0.1166,  0.6805, -0.1950,  0.5090],
         [ 0.1025,  0.6753, -0.1976,  0.5708],
         [ 0.0571,  0.6449, -0.2240,  0.5105],
         [ 0.0575,  0.6312, -0.2254,  0.5138]],

        [[ 0.0580,  0.5379, -0.2327,  0.6814],
         [ 0.0375,  0.5728, -0.2434,  0.5246],
         [ 0.0336,  0.6157, -0.2356,  0.5385],
         [ 0.0337,  0.6351, -0.2368,  0.4610],
         [ 0.0198,  0.6258, -0.2460,  0.4240]]], grad_fn=<ViewBackward0>)
Shape of final context_vec: torch.Size([2, 5, 4])
