# self attention with out weights

In [1]:
import torch 
import torch.nn as nn

In [2]:
inputs=torch.randn(5,3)

In [3]:
query=inputs[2]

attn_score=torch.empty(inputs.shape[0])

for i,val in enumerate(inputs):
    attn_score[i]=torch.dot(val,query)


print(f"the input:{query}")
print(f"the attention score:{attn_score}")


the input:tensor([-0.9503, -0.2291, -0.0102])
the attention score:tensor([-0.1662,  2.4919,  0.9556,  0.7978, -0.0709])


# we know make the attention weight by normalizing the attention score

In [4]:
attn_weight=attn_score/attn_score.sum()
print(f"the attention score is: {attn_score}")
print(f"the attention weight is:{attn_weight}")

the attention score is: tensor([-0.1662,  2.4919,  0.9556,  0.7978, -0.0709])
the attention weight is:tensor([-0.0415,  0.6217,  0.2384,  0.1990, -0.0177])


In [5]:
def local_softmax(x):
    return torch.exp(x)/torch.exp(x).sum(dim=0)

In [6]:
local_softmax(attn_score)

tensor([0.0453, 0.6468, 0.1392, 0.1189, 0.0499])

In [7]:
torch.softmax(attn_score,dim=0)

tensor([0.0453, 0.6468, 0.1392, 0.1189, 0.0499])

# calculate the context vector

In [8]:
#with respect to the third(2) token
context_value=torch.zeros(query.size())
for i,val in enumerate(inputs):
    context_value+=attn_weight[i]*val

print(context_value)

tensor([-2.0000, -0.1502, -0.8964])


# extracting context vector from all inputs

In [9]:
inputs = torch.tensor(
[[0.43, 0.15, 0.89],
[0.55, 0.87, 0.66], 
[0.57, 0.85, 0.64],
[0.22, 0.58, 0.33], 
[0.77, 0.25, 0.10],
[0.05, 0.80, 0.55]]
)


In [10]:
#attention score is
attention_score=torch.zeros(6,6)

for i,i_val in enumerate(inputs):
    for j,j_val in enumerate(inputs):
        attention_score[i,j]=torch.dot(i_val,j_val)

In [11]:
#we can also transpose the values here
att_score=inputs@inputs.T

In [12]:
att_score

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [13]:
print(attention_score)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [14]:
#using torch softmax
attnetion_weight=torch.softmax(att_score,dim=1)

In [15]:
attnetion_weight

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [16]:
attnetion_weight.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [17]:
context_vector=attnetion_weight@inputs

In [18]:
context_vector

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

# implementing self attention with trainable weights
* self attention mechanism is also called scaled dot product

In [19]:
inputs = torch.tensor(
[[0.43, 0.15, 0.89],
[0.55, 0.87, 0.66], 
[0.57, 0.85, 0.64],
[0.22, 0.58, 0.33], 
[0.77, 0.25, 0.10],
[0.05, 0.80, 0.55]]
)


In [20]:
x_2=inputs[1]
d_in=inputs.shape[1]
d_out=2

In [21]:
print(d_in)
print(d_out)

3
2


In [22]:
torch.manual_seed(123)

#lets define the query,key and value
w_query=torch.nn.parameter.Parameter(torch.randn(d_in,d_out),requires_grad=False)
w_key=torch.nn.parameter.Parameter(torch.randn(d_in,d_out),requires_grad=False)
w_value=torch.nn.parameter.Parameter(torch.randn(d_in,d_out),requires_grad=False)

print(w_query)
print(w_key)
print(w_value)

Parameter containing:
tensor([[-0.1115,  0.1204],
        [-0.3696, -0.2404],
        [-1.1969,  0.2093]])
Parameter containing:
tensor([[-0.9724, -0.7550],
        [ 0.3239, -0.1085],
        [ 0.2103, -0.3908]])
Parameter containing:
tensor([[ 0.2350,  0.6653],
        [ 0.3528,  0.9728],
        [-0.0386, -0.8861]])


In [23]:
query_2=x_2@w_query
key_2=x_2@w_key
value_2=x_2@w_value

In [24]:
query_2

tensor([-1.1729, -0.0048])

In [25]:
#general key and value 
key=inputs@w_key
value=inputs@w_value

In [26]:
print(f"{key}\n")
print(value)

tensor([[-0.1823, -0.6888],
        [-0.1142, -0.7676],
        [-0.1443, -0.7728],
        [ 0.0434, -0.3580],
        [-0.6467, -0.6476],
        [ 0.3262, -0.3395]])

tensor([[ 0.1196, -0.3566],
        [ 0.4107,  0.6274],
        [ 0.4091,  0.6390],
        [ 0.2436,  0.4182],
        [ 0.2653,  0.6668],
        [ 0.2728,  0.3242]])


In [27]:
print(key.shape)
print(value.shape)

torch.Size([6, 2])
torch.Size([6, 2])


In [28]:
#attention score
key_2=key[1]
attn_score=query_2.dot(key_2)
print(f"the attention scoer {attn_score}")

the attention scoer 0.13763877749443054


In [29]:
#general attention score of the token
attn_score=query_2@key.T

In [30]:
attn_score

tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809])

In [31]:
d_k=key.shape[1]
attn_weight_2=torch.softmax(attn_score/d_k**0.5,dim=-1)

In [32]:
attn_weight_2

tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117])

In [33]:
context_value_2=attn_weight_2@value

In [34]:
print(context_value_2)

tensor([0.2854, 0.4081])


# lets formalize the process

In [35]:
class SelfAttentionScoreV1(nn.Module):
    def __init__(self,d_in,d_out):
        super().__init__()
        self.d_out=d_out
        self.w_query=nn.Parameter(torch.rand(d_in,d_out))
        self.w_key=nn.Parameter(torch.rand(d_in,d_out))
        self.w_value=nn.Parameter(torch.rand(d_in,d_out))
    
    def forward(self,x):
        query=x@self.w_query
        keys=x@self.w_key
        values=x@self.w_value

        #self attention
        attention_score=query@keys.T   #attenion weight 
        attention_weight=torch.softmax(attention_score/(self.d_out**0.5),dim=-1)
        context_vector=attention_weight@values

        return context_vector

In [36]:
sl_attn=SelfAttentionScoreV1(d_in=3,d_out=2)
print(sl_attn(inputs))

tensor([[1.1672, 1.1043],
        [1.1878, 1.1235],
        [1.1870, 1.1228],
        [1.1620, 1.0994],
        [1.1568, 1.0944],
        [1.1716, 1.1084]], grad_fn=<MmBackward0>)


# using linear layer rather than paramter

In [37]:
torch.manual_seed(123)
class SelfAttentionScoreV2(nn.Module):
    def __init__(self,d_in,d_out,bias=False):
        super().__init__()
        self.d_out=d_out
        self.w_query=nn.Linear(d_in,d_out,bias=bias)
        self.w_key=nn.Linear(d_in,d_out,bias=bias)
        self.w_value=nn.Linear(d_in,d_out,bias=bias)
    
    def forward(self,x):
        query=self.w_query(x)
        keys=self.w_key(x)
        values=self.w_value(x)

        #self attention
        attention_score=query@keys.T   #attenion weight 
        attention_weight=torch.softmax(attention_score/(self.d_out**0.5),dim=-1)
        context_vector=attention_weight@values

        return context_vector

In [38]:
sl_attn=SelfAttentionScoreV2(d_in=3,d_out=2)
print(sl_attn(inputs))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


# tranfer weight from linear into paramter format between classes

In [39]:
#linear weight
query_weight=sl_attn.w_query.weight.data
key_weight=sl_attn.w_key.weight.data
value_weight=sl_attn.w_value.weight.data

#paramter weight
new_vl_class=SelfAttentionScoreV2(d_in,d_out)

new_vl_class.w_query.weight.data=query_weight
new_vl_class.w_key.weight.data=key_weight
new_vl_class.w_value.weight.data=value_weight

#feed forward the value
outputs=new_vl_class(inputs)
print(outputs)


tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


# applying causal attention

In [40]:
#lets make the attention score and weight
attn_class=SelfAttentionScoreV2(d_in,d_out)

#pass values into each layer oursefs
queries=attn_class.w_query(inputs)
keys=attn_class.w_key(inputs)
values=attn_class.w_value(inputs)

attention_score=queries@keys.T
attn_weight=torch.softmax(attention_score/d_out**0.5,dim=-1)

In [41]:
print(attention_score)

tensor([[-0.0320, -0.2465, -0.2400, -0.1684, -0.0549, -0.2388],
        [-0.0337, -0.2816, -0.2742, -0.1929, -0.0627, -0.2733],
        [-0.0330, -0.2810, -0.2736, -0.1926, -0.0625, -0.2728],
        [-0.0181, -0.1395, -0.1358, -0.0953, -0.0311, -0.1351],
        [-0.0108, -0.1907, -0.1857, -0.1326, -0.0421, -0.1872],
        [-0.0263, -0.1566, -0.1524, -0.1060, -0.0351, -0.1506]],
       grad_fn=<MmBackward0>)


In [42]:
print(attn_weight)

tensor([[0.1825, 0.1568, 0.1576, 0.1657, 0.1796, 0.1577],
        [0.1852, 0.1554, 0.1562, 0.1655, 0.1814, 0.1563],
        [0.1852, 0.1554, 0.1562, 0.1654, 0.1814, 0.1563],
        [0.1756, 0.1611, 0.1615, 0.1662, 0.1740, 0.1616],
        [0.1804, 0.1589, 0.1594, 0.1655, 0.1765, 0.1593],
        [0.1760, 0.1605, 0.1610, 0.1664, 0.1749, 0.1612]],
       grad_fn=<SoftmaxBackward0>)


## lets prepare the mask

In [43]:
context_length=attn_weight.shape[0]
mask_sample=torch.tril(torch.ones(context_length,context_length))

In [44]:
mask_sample

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [45]:
masked_sample=attn_weight*mask_sample

In [46]:
masked_sample

tensor([[0.1825, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1852, 0.1554, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1852, 0.1554, 0.1562, 0.0000, 0.0000, 0.0000],
        [0.1756, 0.1611, 0.1615, 0.1662, 0.0000, 0.0000],
        [0.1804, 0.1589, 0.1594, 0.1655, 0.1765, 0.0000],
        [0.1760, 0.1605, 0.1610, 0.1664, 0.1749, 0.1612]],
       grad_fn=<MulBackward0>)

In [47]:
#lets normalize the inputs again
normalized_version=masked_sample/masked_sample.sum(dim=1,keepdim=True)
print(normalized_version)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5437, 0.4563, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3728, 0.3128, 0.3144, 0.0000, 0.0000, 0.0000],
        [0.2642, 0.2425, 0.2431, 0.2502, 0.0000, 0.0000],
        [0.2146, 0.1890, 0.1896, 0.1969, 0.2099, 0.0000],
        [0.1760, 0.1605, 0.1610, 0.1664, 0.1749, 0.1612]],
       grad_fn=<DivBackward0>)


In [48]:
mask=torch.triu(torch.ones(6,6),diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [49]:
print(attention_score)

tensor([[-0.0320, -0.2465, -0.2400, -0.1684, -0.0549, -0.2388],
        [-0.0337, -0.2816, -0.2742, -0.1929, -0.0627, -0.2733],
        [-0.0330, -0.2810, -0.2736, -0.1926, -0.0625, -0.2728],
        [-0.0181, -0.1395, -0.1358, -0.0953, -0.0311, -0.1351],
        [-0.0108, -0.1907, -0.1857, -0.1326, -0.0421, -0.1872],
        [-0.0263, -0.1566, -0.1524, -0.1060, -0.0351, -0.1506]],
       grad_fn=<MmBackward0>)


## a simpler implementation of mask

In [50]:
masked_score=attention_score.masked_fill(mask.bool(),-torch.inf)
masked_score

tensor([[-0.0320,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0337, -0.2816,    -inf,    -inf,    -inf,    -inf],
        [-0.0330, -0.2810, -0.2736,    -inf,    -inf,    -inf],
        [-0.0181, -0.1395, -0.1358, -0.0953,    -inf,    -inf],
        [-0.0108, -0.1907, -0.1857, -0.1326, -0.0421,    -inf],
        [-0.0263, -0.1566, -0.1524, -0.1060, -0.0351, -0.1506]],
       grad_fn=<MaskedFillBackward0>)

In [51]:
attn_weight=torch.softmax(masked_score/keys.shape[-1]**0.5,dim=1)
print(attn_weight)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5437, 0.4563, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3728, 0.3128, 0.3144, 0.0000, 0.0000, 0.0000],
        [0.2642, 0.2425, 0.2431, 0.2502, 0.0000, 0.0000],
        [0.2146, 0.1890, 0.1896, 0.1969, 0.2099, 0.0000],
        [0.1760, 0.1605, 0.1610, 0.1664, 0.1749, 0.1612]],
       grad_fn=<SoftmaxBackward0>)


# adding dropout

In [52]:
ones=torch.ones(6,6)
dropout=torch.nn.Dropout(0.5)
print(dropout(ones))

tensor([[2., 0., 0., 0., 2., 0.],
        [0., 2., 2., 2., 2., 0.],
        [2., 2., 0., 2., 2., 0.],
        [2., 0., 0., 2., 0., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 2., 0., 0., 2.]])


In [53]:
torch.ones(4,4)

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [54]:
print(dropout(attn_weight))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6256, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5284, 0.0000, 0.4862, 0.0000, 0.0000, 0.0000],
        [0.4292, 0.0000, 0.3793, 0.0000, 0.0000, 0.0000],
        [0.3520, 0.3210, 0.3220, 0.3327, 0.3498, 0.3224]],
       grad_fn=<MulBackward0>)


# compact form of casual attention class

In [55]:
#simluating batch size
batch=torch.stack((inputs,inputs),dim=0)
print(batch)
print(f"shape is {batch.shape}")

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
shape is torch.Size([2, 6, 3])


In [56]:
class CasualAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout_value,weight_bias=False) -> None:
        super().__init__()
        self.d_out=d_out
        self.w_query=nn.Linear(d_in,d_out,bias=weight_bias)
        self.w_key=nn.Linear(d_in,d_out,bias=weight_bias)
        self.w_value=nn.Linear(d_in,d_out,bias=weight_bias)
        #self.dropout=nn.Dropout(dropout_value)
        self.dropout = nn.Dropout(dropout_value)
        #print(f"initialize dropout as:{dropout_value}")

        #buffer
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
    
    def forward(self,x):
         #some values
        b,num_tokens,d_in=x.shape
        
        queries=self.w_query(x)
        keys=self.w_key(x)
        values=self.w_value(x)

        attention_score=queries@keys.transpose(1,2) #this is for batches of input like 3D inputs
     
        #lets apply casual attetion
        attention_score.masked_fill_(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)   #dyanmic even for smaller number of tokens
       
        attention_weight=torch.softmax(attention_score/self.d_out**0.5,dim=-1)  
       
        attention_weight=self.dropout(attention_weight)

        #lets make the context vector
        context_vector=attention_weight@values

        return context_vector

In [57]:
torch.manual_seed(123)
# lets pass values into it
context_length=batch.shape[1]
attention=CasualAttention(d_in=3,d_out=2,context_length=context_length,dropout_value=0.0)
context_vec=attention(batch)

In [61]:
context_vec.shape

torch.Size([2, 6, 2])

In [60]:
print(context_vec)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


# multihead attention 

In [63]:
class MultiHeadAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,n_heads,drop=0.5,bias=False) -> None:
        super().__init__()
        self.heads=nn.ModuleList(CasualAttention(d_in,d_out,context_length,drop,bias) for _ in range(n_heads))
    def forward(self,x):
        #this is the context vector concatinated along the column dim
        return torch.cat([head(x) for head in self.heads],dim=-1)

In [74]:
#lets test the output
context_length=batch.shape[1]
d_in,d_out=3,1
mha=MultiHeadAttention(d_in,d_out,context_length,n_heads=2,drop=0)
context_vec=mha(batch)

In [75]:
context_vec

tensor([[[0.1970, 0.0763],
         [0.3793, 0.2361],
         [0.4401, 0.2830],
         [0.4132, 0.2812],
         [0.4079, 0.2272],
         [0.4020, 0.2577]],

        [[0.1970, 0.0763],
         [0.3793, 0.2361],
         [0.4401, 0.2830],
         [0.4132, 0.2812],
         [0.4079, 0.2272],
         [0.4020, 0.2577]]], grad_fn=<CatBackward0>)

In [76]:
context_vec.shape

torch.Size([2, 6, 2])

In [78]:
temp=torch.randn(6,5,4,3)

In [None]:
a=torch.randn(4,3,5,6)
b=torch.randn(4,3,6,5)
c=a

# parallel computation of attention