In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
# Build the toy input (this version of self-attention has no batch dimension).
# Per the original note, x is laid out as (token, embedding dimension):
# 5 tokens with 4 dimensions each.
# NOTE(review): later cells feed x.T through nn.Linear(5, 3), which treats the
# size-5 axis as the feature axis -- confirm which layout is intended.
x = torch.rand((5, 4))

In [13]:
# Display the randomly generated input tensor
x

tensor([[0.8311, 0.8966, 0.9652, 0.3645],
        [0.7580, 0.4778, 0.5988, 0.4520],
        [0.7356, 0.1954, 0.3560, 0.4245],
        [0.8762, 0.1240, 0.9374, 0.5385],
        [0.3336, 0.9823, 0.6054, 0.5508]])

In [14]:
# Three bias-free linear projections (we only want their weights), one each
# for the queries, values, and keys. They are built in q, v, k order so the
# random weight initialisation is consumed in the same sequence.
q_lin, v_lin, k_lin = (nn.Linear(5, 3, bias=False) for _ in range(3))

In [16]:
# Before grabbing the queries, values, and keys, sanity-check the shapes
# going into and coming out of one of the projections.
print(
    f'Input dimensions: {x.shape}',
    f'Output dimensions: {q_lin(x.T).shape}',
    sep='\n',
)

Input dimensions: torch.Size([5, 4])
Output dimensions: torch.Size([4, 3])


In [17]:
# 1. Getting Queries, Values, and Keys
# Every projection is applied to the transposed input.
queries, keys, values = (proj(x.T) for proj in (q_lin, k_lin, v_lin))

In [19]:
# 2. Attention Filter
# With the keys, queries, and values in hand, we build the attention filter.
# Think of it like a photo filter: the idea is to 'filter' for what is most
# important, i.e. what the model pays the most attention to.
A1 = queries @ keys.T

In [21]:
# Our initial (unscaled) filter: simply the queries multiplied by the
# transposed keys. Show it together with its shape.
A1, A1.size()

(tensor([[-0.0461, -0.0809, -0.1210, -0.0057],
         [ 0.2983,  0.1843,  0.2416,  0.1894],
         [ 0.0242, -0.0446, -0.0475,  0.0374],
         [ 0.0418,  0.0318, -0.0084,  0.0325]], grad_fn=<MmBackward>),
 torch.Size([4, 4]))

In [25]:
# Scale factor: the square root of the "k dimension" (hard-coded to 4 here)
scaler = torch.tensor(4.).sqrt()

In [27]:
# Divide every entry of the raw filter by the scale factor
A2 = torch.div(A1, scaler)

In [28]:
# This is our scaled-down filter (A1 divided by scaler)
A2

tensor([[-0.0230, -0.0404, -0.0605, -0.0029],
        [ 0.1491,  0.0922,  0.1208,  0.0947],
        [ 0.0121, -0.0223, -0.0238,  0.0187],
        [ 0.0209,  0.0159, -0.0042,  0.0162]], grad_fn=<DivBackward0>)

In [33]:
# Apply softmax along dim 1 of the scaled filter to produce the softmax filter.
# This softmax filter can be viewed as a 'heat map'.
# The shape should be 4x4, which should relate to (token x token).
A3 = F.softmax(A2, dim=1)
A3, A3.size()

(tensor([[0.2521, 0.2478, 0.2428, 0.2573],
         [0.2588, 0.2445, 0.2516, 0.2451],
         [0.2540, 0.2454, 0.2450, 0.2556],
         [0.2522, 0.2509, 0.2459, 0.2510]], grad_fn=<SoftmaxBackward>),
 torch.Size([4, 4]))

In [37]:
# 3. Filtered Value
# Now that we have our filter, we apply it -- not directly to the input, but
# to the values, which were calculated as v_lin(x.T) (i.e. X.T * Wv).
# The result is deliberately NOT named `F`: that would overwrite the
# `torch.nn.functional as F` import and break any later `F.softmax(...)` call.
filtered_values = torch.mm(A3, values)
filtered_values, filtered_values.shape

(tensor([[-0.6343,  0.0544, -0.4861],
         [-0.6419,  0.0541, -0.4920],
         [-0.6364,  0.0542, -0.4877],
         [-0.6354,  0.0549, -0.4872]], grad_fn=<MmBackward>),
 torch.Size([4, 3]))

In [38]:
"""
Modulated
"""
class SelfAttention(nn.Module):
    """Single-head, unbatched self-attention built from three linear layers.

    The bias-free query/key/value layers are applied to the transpose of the
    input, so an input of shape (ni, m) yields queries/keys/values of shape
    (m, nf), an (m, m) softmax filter, and (m, nf) filtered values.
    """

    def __init__(self, ni, nf, emb_dim):
        """
        ::param ni: in_features of the q/k/v linear layers; must equal x.shape[0]
                    because the layers are applied to x.T
        ::param nf: out_features of the q/k/v linear layers
        ::param emb_dim: embedding dimension per token; stored for reference only,
                         the computation itself does not read it
        """
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise the assignments below raise AttributeError.
        super().__init__()
        self.emb_dim = emb_dim
        self.q_lin = nn.Linear(ni, nf, bias=False)
        self.v_lin = nn.Linear(ni, nf, bias=False)
        self.k_lin = nn.Linear(ni, nf, bias=False)

    def forward(self, x):
        """
        ::param x: 2-D input of shape (ni, m); a batch dimension is not supported
        ::return: tuple of (filtered values with shape (m, nf),
                  softmax filter with shape (m, m), usable as a heatmap)
        ::raises ValueError: if x is not 2-D
        """
        if x.dim() != 2:
            raise ValueError(f'expected a 2-D input, got shape {tuple(x.shape)}')

        # Scale by the square root of x's second dimension.
        # NOTE(review): scaled dot-product attention is usually scaled by
        # sqrt(d_k), the key feature size (nf here) -- confirm which is intended.
        _, k = x.shape
        scaler = math.sqrt(k)

        # 1. Grabbing queries, keys, and values
        x_t = x.T
        queries = self.q_lin(x_t)
        values = self.v_lin(x_t)
        keys = self.k_lin(x_t)

        # 2. Attention Filter
        attn_scores = torch.mm(queries, keys.T)
        attn_scaled = attn_scores / scaler
        # torch.softmax is used instead of F.softmax so this method does not
        # depend on the name `F`, which is easy to shadow with a tensor.
        attn_weights = torch.softmax(attn_scaled, dim=1)

        # 3. Filtered Values
        # Inputs are 2-D, so this is a plain matrix product (torch.bmm needs 3-D).
        filtered = torch.mm(attn_weights, values)

        # Returning both the filtered values -> This will be the input to the linear layer
        # along with the softmax filter, which can be used as the heatmap
        return filtered, attn_weights