In [2]:
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self,x):
        keys = x @ self.W_key
        queries  = x @ self.W_query
        values = x @ self.W_value

        attn_score = queries @ keys.T
        attn_weights = torch.softmax(
            attn_score / keys.shape[-1]**0.5, dim = -1
        )
        context_vec = attn_weights @ values
        return context_vec


In [3]:
torch.manual_seed(123)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

d_in = inputs.shape[1]
d_out = 2
sa_v1 = SelfAttention_v1(d_in,d_out)

In [4]:
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [5]:
import torch
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)



    def forward(self,x):
        keys = self.W_key(x)
        queries  = self.W_query(x)
        values = self.W_value(x)

        attn_score = queries @ keys.T
        attn_weights = torch.softmax(
            attn_score / keys.shape[-1]**0.5, dim = -1
        )
        context_vec = attn_weights @ values
        return context_vec

In [6]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in,d_out)
#sa_v2(inputs)

In [7]:
sa_v2.W_key.weight

Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True)

In [8]:
sa_v2.W_key.weight

Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True)

In [9]:
m = 4
mask = torch.triu(torch.ones(m,m),diagonal=1)
attn_scores = torch.rand(m,m)
attn_scores


tensor([[0.7662, 0.8018, 0.6371, 0.3464],
        [0.2020, 0.1266, 0.4615, 0.7050],
        [0.3849, 0.3085, 0.6775, 0.0966],
        [0.9746, 0.7144, 0.4761, 0.2971]])

In [10]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.7662,   -inf,   -inf,   -inf],
        [0.2020, 0.1266,   -inf,   -inf],
        [0.3849, 0.3085, 0.6775,   -inf],
        [0.9746, 0.7144, 0.4761, 0.2971]])

In [11]:
torch.softmax(masked, dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5188, 0.4812, 0.0000, 0.0000],
        [0.3061, 0.2836, 0.4102, 0.0000],
        [0.3465, 0.2671, 0.2105, 0.1760]])

In [12]:
# implementing a compact causal attention class

In [13]:
batch = torch.stack((inputs, inputs),dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [14]:
# Lets understand dropouts

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)

example = torch.ones(6,3)/6
print(dropout(example))

tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.0000, 0.3333, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3333],
        [0.0000, 0.3333, 0.0000]])


In [15]:
torch.triu(torch.ones(5, 5), diagonal=1)

tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

In [16]:
#Listing 3.3: A compact causal attention class

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, 
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens,:num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec



In [17]:
torch.manual_seed(123)
context_length = batch.shape[1]

ca = CausalAttention(d_in, d_out, context_length,0.0)

In [18]:
context_vecs = ca(batch)
context_vecs

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)

In [19]:
# Extendeing single-head attention to multi-head attention

In [20]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out,context_length, dropout,qkv_bias)
             for _ in range(num_heads)]
        )
    
    def forward(self,x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
    

In [21]:
context_length

6

In [22]:
torch.manual_seed(123)
context_length = batch.shape[1]

d_in, d_out = 3,1
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)
context_vecs

tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)

In [32]:
# An efficient multi-head attention class

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias = True):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self,x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b,num_tokens,self.num_heads, self.head_dim)
        values = values.view(b,num_tokens,self.num_heads, self.head_dim)
        queries = queries.view(b,num_tokens,self.num_heads, self.head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        attn_scores = queries @ keys.transpose(2,3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(
            attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1,2)
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec)
        return context_vec
 
        



b = 1
num_heads = 2
num_token = 3
d_in = 4
d_out = 9
num_heads = 3
context_length = num_token

a = torch.rand(b,num_heads, num_token,d_in)
mha = MultiHeadAttention(d_in, d_out,context_length,0.0,num_heads)



In [33]:
mha(batch)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (12x3 and 4x9)