In [1]:
import math
import numpy as np
from numpy.random import randn

# I. Define the input data X
# X consists out of 32 samples, each sample has dimensionality 256
n = 32
d = 256
X = randn(n, d) # (32, 256)

# II. Generate the projection weights
Wq = randn(d, d) #(256, 256)
Wk = randn(d, d)
Wv = randn(d, d)

# III. Project X to find its query, keys and values vectors
Q = np.dot(X, Wq) # (32, 256)
K = np.dot(X, Wk)
V = np.dot(X, Wv)

# IV. Compute the self-attention score, denoted by A
# A = softmax(QK^T / \sqrt{d})
# Define the softmax function
def softmax(z):
    z = np.clip(z, 100, -100) # clip in case softmax explodes
    tmp = np.exp(z)
    res = np.exp(z) / np.sum(tmp, axis=1)
    return res

A = softmax(np.dot(Q, K.transpose())/math.sqrt(d)) #(32, 32)

# V. Compute the self-attention output
# outputs = A * V
outputs = np.dot(A, V) #(32, 256)

print("The attention outputs are\n {}".format(outputs))

The attention outputs are
 [[ 0.12928972  0.56180036  1.44882858 ... -1.74847076 -1.55449529
  -5.00044474]
 [ 0.12928972  0.56180036  1.44882858 ... -1.74847076 -1.55449529
  -5.00044474]
 [ 0.12928972  0.56180036  1.44882858 ... -1.74847076 -1.55449529
  -5.00044474]
 ...
 [ 0.12928972  0.56180036  1.44882858 ... -1.74847076 -1.55449529
  -5.00044474]
 [ 0.12928972  0.56180036  1.44882858 ... -1.74847076 -1.55449529
  -5.00044474]
 [ 0.12928972  0.56180036  1.44882858 ... -1.74847076 -1.55449529
  -5.00044474]]


In [2]:
import math
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, dim_input, dim_q, dim_v):
        '''
        dim_input: the dimension of each sample
        dim_q: dimension of Q matrix, should be equal to dim_k
        dim_v: dimension of V matrix, also the  dimension of the attention output
        '''
        super(SelfAttention, self).__init__()

        self.dim_input = dim_input
        self.dim_q = dim_q
        self.dim_k = dim_q
        self.dim_v = dim_v

        # Define the linear projection
        self.linear_q = nn.Linear(self.dim_input, self.dim_q, bias=False)
        self.linear_k = nn.Linear(self.dim_input, self.dim_k, bias=False)
        self.linear_v = nn.Linear(self.dim_input, self.dim_v, bias=False)
        self._norm_fact = 1 / math.sqrt(self.dim_k)

    def forward(self, x):
        batch, n, dim_q = x.shape

        q = self.linear_q(x) # (batchsize, seq_len, dim_q)
        k = self.linear_k(x) # (batchsize, seq_len, dim_k)
        v = self.linear_v(x) # (batchsize, seq_len, dim_v)
        print(f'x.shape:{x.shape} \n Q.shape:{q.shape} \n K.shape:{k.shape} \n V.shape:{v.shape}')

        dist = torch.bmm(q, k.transpose(1,2)) * self._norm_fact
        dist = torch.softmax(dist, dim=-1)
        print('attention matrix: ', dist.shape)

        outputs = torch.bmm(dist, v)
        print('attention outputs: ', outputs.shape)

        return outputs


batch_size = 32 # number of samples in a batch
dim_input = 128 # dimension of each item in the sample sequence
seq_len = 20 # sequence length for each sample
x = torch.randn(batch_size, seq_len, dim_input)
self_attention = SelfAttention(dim_input, dim_q = 64, dim_v = 32)

attention = self_attention(x)

print(attention)

x.shape:torch.Size([32, 20, 128]) 
 Q.shape:torch.Size([32, 20, 64]) 
 K.shape:torch.Size([32, 20, 64]) 
 V.shape:torch.Size([32, 20, 32])
attention matrix:  torch.Size([32, 20, 20])
attention outputs:  torch.Size([32, 20, 32])
tensor([[[ 0.1804,  0.1362,  0.1774,  ...,  0.2700,  0.2799, -0.2076],
         [ 0.1551,  0.0536,  0.1839,  ...,  0.2479,  0.2560, -0.1974],
         [ 0.0705,  0.0541,  0.1718,  ...,  0.1769,  0.2233, -0.2124],
         ...,
         [ 0.0950,  0.0392,  0.1534,  ...,  0.2462,  0.2793, -0.2168],
         [ 0.1311,  0.1044,  0.2314,  ...,  0.2689,  0.2349, -0.1613],
         [ 0.1274,  0.0536,  0.1555,  ...,  0.2373,  0.2401, -0.1581]],

        [[ 0.1901,  0.0841, -0.0561,  ...,  0.2223,  0.0370,  0.0209],
         [ 0.1805,  0.0632, -0.1261,  ...,  0.2702,  0.0404,  0.0188],
         [ 0.1816,  0.0806, -0.0515,  ...,  0.2256,  0.0554,  0.0304],
         ...,
         [ 0.2062,  0.0815, -0.0661,  ...,  0.2092,  0.0722, -0.0224],
         [ 0.1423,  0.0710, -0.0