In [39]:
import torch
import torch.nn as nn

In [40]:
# Toy embedding matrix: one 3-dimensional row per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1 -> "Your"
        [0.55, 0.87, 0.66],  # x^2 -> "journey"
        [0.57, 0.85, 0.64],  # x^3 -> "starts"
        [0.22, 0.58, 0.33],  # x^4 -> "with"
        [0.77, 0.25, 0.10],  # x^5 -> "one"
        [0.05, 0.80, 0.55],  # x^6 -> "step"
    ]
)

In [41]:
class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are stored as plain ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the queries, keys, values and
            of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()

        # Creation order (Wq, Wk, Wv) fixes which random draws each matrix
        # receives under a given seed, so it must not be changed.
        self.Wq = nn.Parameter(torch.rand(d_in, d_out))
        self.Wk = nn.Parameter(torch.rand(d_in, d_out))
        self.Wv = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input row.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = x @ self.Wq
        keys = x @ self.Wk
        values = x @ self.Wv

        # Swap only the last two dimensions so batched inputs work too; on a
        # 2-D tensor this is exactly ``keys.T``.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row of scores.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

In [42]:
# Projection width is chosen by hand; the input width is read off the data
# (number of columns of `inputs`).
d_out = 2
d_in = inputs.shape[1]

In [43]:
# Seed the global RNG first: SelfAttentionV1 draws its weights with torch.rand,
# so the seed decides the values printed below.
torch.manual_seed(666)
self_attn_v1 = SelfAttentionV1(d_in=d_in, d_out=d_out)

In [44]:
# Calling the module runs its forward(); the result has one row per row of
# `inputs`.
print(self_attn_v1(inputs))

tensor([[0.9859, 1.0707],
        [0.9891, 1.0748],
        [0.9892, 1.0749],
        [0.9825, 1.0657],
        [0.9863, 1.0707],
        [0.9825, 1.0659]], grad_fn=<MmBackward0>)


In [45]:
class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention using ``nn.Linear`` layers.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the queries, keys, values and
            of the returned context vectors.
        qkv_bias: Whether the query/key/value projections include a bias
            term. Defaults to ``False``.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()

        # Creation order (Wq, Wk, Wv) fixes which random draws each layer
        # receives under a given seed, so it must not be changed.
        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input row.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = self.Wq(x)
        keys = self.Wk(x)
        values = self.Wv(x)

        # Swap only the last two dimensions so batched inputs work too; on a
        # 2-D tensor this is exactly ``keys.T``.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row of scores.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

In [46]:
# Bias-free projections (qkv_bias=False is also the constructor default).
self_attn_v2 = SelfAttentionV2(d_in=d_in, d_out=d_out, qkv_bias=False)

In [47]:
# Calling the module runs its forward(); the result has one row per row of
# `inputs`.
print(self_attn_v2(inputs))

tensor([[-0.5231,  0.2215],
        [-0.5259,  0.2240],
        [-0.5260,  0.2239],
        [-0.5282,  0.2260],
        [-0.5281,  0.2231],
        [-0.5274,  0.2267]], grad_fn=<MmBackward0>)


In [48]:
# Exercise 3.1 - Comparing SelfAttentionV1 and SelfAttentionV2
#
# Transplant the V2 projection weights into a fresh V1 module so both produce
# the same context vectors. V1 multiplies as `x @ W`, so each nn.Linear weight
# is transposed before being assigned.

sa_v1 = SelfAttentionV1(d_in, d_out)
# `.detach().clone()` gives sa_v1 its own copy of the values. Wrapping the bare
# transposed view in nn.Parameter would share storage with self_attn_v2, so an
# in-place update to one module would silently change the other.
sa_v1.Wq = nn.Parameter(self_attn_v2.Wq.weight.T.detach().clone())
sa_v1.Wk = nn.Parameter(self_attn_v2.Wk.weight.T.detach().clone())
sa_v1.Wv = nn.Parameter(self_attn_v2.Wv.weight.T.detach().clone())

In [50]:
# Run the V1 module that carries the weights copied from self_attn_v2; the
# printed values match the self_attn_v2 output shown earlier.
print("SelfAttentionV1 Context Vectors:\n\n", sa_v1(inputs))

SelfAttentionV1 Context Vectors:

 tensor([[-0.5231,  0.2215],
        [-0.5259,  0.2240],
        [-0.5260,  0.2239],
        [-0.5282,  0.2260],
        [-0.5281,  0.2231],
        [-0.5274,  0.2267]], grad_fn=<MmBackward0>)
