In [21]:
#https://peterbloem.nl/blog/transformers
import torch
import torch.nn.functional as F

In [22]:
# Assume we have some tensor x with size (b, t, k); here b=2, t=2, k=3.
x = [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]

# x, one (t, k) matrix per batch element:
#    b1       b2
# |1,2,3| |7 ,8 ,9 |
# |4,5,6| |10,11,12|
#
# x.transpose(1, 2) swaps the t and k axes inside each batch element,
# giving one (k, t) matrix per batch element:
#   b1     b2
# |1,4| |7,10|
# |2,5| |8,11|
# |3,6| |9,12|

xt = torch.tensor(x)
# Raw self-attention weights: the dot product of every row vector with every
# row vector of the same batch element, so the result has size (b, t, t).
raw_weights = torch.bmm(xt, xt.transpose(1, 2))
print(raw_weights)

tensor([[[ 14,  32],
         [ 32,  77]],

        [[194, 266],
         [266, 365]]])


In [31]:
# Turn each row of raw scores into positive weights that sum to one.
# softmax needs a floating-point input, hence the cast of the integer scores.
weights = F.softmax(raw_weights.float(), dim=-1)
print("weights: ",weights)
# Every output vector is the weighted sum of the input vectors of its batch
# element; the integer inputs are cast so both bmm operands share a dtype.
y = weights.bmm(xt.float())
print("y: ",y)

weights:  tensor([[[1.5230e-08, 1.0000e+00],
         [2.8625e-20, 1.0000e+00]],

        [[5.3802e-32, 1.0000e+00],
         [1.0089e-43, 1.0000e+00]]])
y:  tensor([[[ 4.,  5.,  6.],
         [ 4.,  5.,  6.]],

        [[10., 11., 12.],
         [10., 11., 12.]]])


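We add three k×k weight matrices $W_q$, $W_k$, $W_v$ and compute three linear transformations of each input vector $x_i$: the query $q_i = W_q x_i$, the key $k_i = W_k x_i$ and the value $v_i = W_v x_i$.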

In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [93]:
# A single bias-free k x k linear map, as used for the key projection.
k = 10
key = nn.Linear(k, k, bias=False, dtype=torch.float)
# One random input row whose width follows k, so the cell keeps working when
# k is changed (a literal 10 here would raise a shape mismatch).
x = torch.randn(1, k)
y = key(x)
print("y: \n",y)
# To inspect the weight matrix of the layer, iterate over key.parameters().


y: 
 tensor([[-0.1837,  0.1492, -0.4361, -0.8322,  1.1574, -0.0720,  0.5413, -0.0395,
          0.6671, -1.2583]], grad_fn=<MmBackward0>)


In [79]:
# Draw twelve random values as a single row and show them, then display the
# same values regrouped, in order, into three 2x2 matrices.
x = torch.randn((1, 12))
print(x)
x.view(-1, 2, 2)

tensor([[-1.0626, -0.4637,  0.5218,  0.7086, -0.1552,  1.1626, -0.7062,  1.8758,
         -1.2824, -0.4376, -0.7546,  0.5570]])


tensor([[[-1.0626, -0.4637],
         [ 0.5218,  0.7086]],

        [[-0.1552,  1.1626],
         [-0.7062,  1.8758]],

        [[-1.2824, -0.4376],
         [-0.7546,  0.5570]]])

In [95]:
# One batch element holding a single 2x2 matrix.
xt = torch.tensor([[[10, 0.2], [0.3, 0.4]]])
# Matrix times its own transpose: the result is symmetric.
print(xt.bmm(xt.transpose(1, 2)))
# Matrix times itself: a different, non-symmetric result.
print(xt.bmm(xt))
print("xt.transpose(1, 2): \n",xt.transpose(1, 2))
print("xt: \n",xt)

tensor([[[100.0400,   3.0800],
         [  3.0800,   0.2500]]])
tensor([[[100.0600,   2.0800],
         [  3.1200,   0.2200]]])
xt.transpose(1, 2): 
 tensor([[[10.0000,  0.3000],
         [ 0.2000,  0.4000]]])
xt: 
 tensor([[[10.0000,  0.2000],
         [ 0.3000,  0.4000]]])


In [123]:
# Contrast reshaping (.view) with transposing (.t()): a view re-reads the same
# values in the same order under a new shape, a transpose swaps the two axes.
matrix = torch.tensor([[3, 5, 2], [1, 3, 4]])

# view test: mv shares its storage with matrix, so the write shows up in both
mv = matrix.view(3, 2)
mv[0, 0] = 99
print("mv",mv,"matrix:",matrix)

# "reshape" test: matrix is already contiguous, so .contiguous() does not copy
# and mr is one more view of the same storage -- the write reaches mv and
# matrix as well
mr = matrix.contiguous().view(3, 2)
mr[0, 0] = 0
print("mv \n",mv,"\nmatrix: \n",matrix,"\nmr: \n",mr )

# Transpose first, then view as (2, 3): the values come out in column order.
# NOTE(review): the printed label says "matrix.view(3, 2)" although the
# expression shown is matrix.t().contiguous().view(2, 3).
print("matrix.view(3, 2): ",matrix.t().contiguous().view(2, 3))

print("matrix: ", matrix)
# A view with the original shape still aliases matrix.
mv = matrix.view(2, 3)
mv[0, 0] = 99
print(mv)
print("matrix",matrix)

mv tensor([[99,  5],
        [ 2,  1],
        [ 3,  4]]) matrix: tensor([[99,  5,  2],
        [ 1,  3,  4]])
mv 
 tensor([[0, 5],
        [2, 1],
        [3, 4]]) 
matrix: 
 tensor([[0, 5, 2],
        [1, 3, 4]]) 
mr: 
 tensor([[0, 5],
        [2, 1],
        [3, 4]])
matrix.view(3, 2):  tensor([[0, 1, 5],
        [3, 2, 4]])
matrix:  tensor([[0, 5, 2],
        [1, 3, 4]])
tensor([[99,  5,  2],
        [ 1,  3,  4]])
matrix tensor([[99,  5,  2],
        [ 1,  3,  4]])


# Forward pass of multi-head self-attention

In [23]:
# Multi-head self-attention forward pass on a toy input, one step at a time.
x = torch.randn(1,2,40)  # size (b, t, k): batch, sequence length, embedding width
b,t,k = x.size()
h = 4  # number of attention heads
assert k % h == 0, "embedding width k must be divisible by the number of heads h"

# Create net
tokeys    = nn.Linear(k, k, bias=False)
toqueries = nn.Linear(k, k, bias=False)
tovalues  = nn.Linear(k, k, bias=False)
unifyheads = nn.Linear(k, k)

# Calculate q,k,v
queries = toqueries(x)
keys    = tokeys(x)
values  = tovalues(x)

# Cut every projected vector into h chunks of width s, one chunk per head:
# (b, t, k) -> (b, t, h, s)
s = k // h
keys    = keys.view(b, t, h, s)
queries = queries.view(b, t, h, s)
values  = values.view(b, t, h, s)

print(keys.size())

# - fold heads into the batch dimension. This ensures that we can use torch.bmm()
#   (b, t, h, s) -> (b, h, t, s) -> (b*h, t, s); .contiguous() is required
#   because view() cannot re-group the axes of a transposed tensor.
fkeys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
fqueries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
fvalues = values.transpose(1, 2).contiguous().view(b * h, t, s)
print("fkeys.size()",fkeys.size())
print("fkeys.transpose(1, 2).size():",fkeys.transpose(1, 2).size())
# Get dot product of queries and keys, and scale
dot = torch.bmm(fqueries, fkeys.transpose(1, 2))

print("dot.size()",dot.size())
# -- dot has size (b*h, t, t) containing raw weights
# scale the dot product
# NOTE(review): this divides by sqrt(k), the full embedding width, not by
# sqrt(s), the per-head width -- confirm which scaling is intended.
dot = dot / (k ** (1/2))
# normalize: every row of weights sums to one
dot = F.softmax(dot, dim=2)
print("\n dot size: ",dot.size())

# apply the self attention to the values
print(fvalues[0][0])
print(dot[0][0])
out = torch.bmm(dot, fvalues)  # size (b*h, t, s)
print(out[0][0])
# swap h, t back, unify heads
# The head axis has to be un-folded from the batch axis BEFORE the transpose:
# transposing the (b*h, t, s) tensor directly would swap t and s and the
# following view would silently scramble the values while keeping the shape.
out = out.view(b, h, t, s).transpose(1, 2).contiguous().view(b, t, s * h)
# Mix the concatenated head outputs with the final linear layer.
out = unifyheads(out)
print("out.size(): ",out.size())

torch.Size([1, 2, 4, 10])
fkeys.size() torch.Size([4, 2, 10])
fkeys.transpose(1, 2).size(): torch.Size([4, 10, 2])
dot.size() torch.Size([4, 2, 2])

 dot size:  torch.Size([4, 2, 2])
tensor([ 3.0379e-01,  1.1340e+00, -5.7817e-02, -5.4739e-04, -7.4721e-02,
        -5.8388e-01,  2.6483e-01, -1.5910e-01,  1.1531e-01, -2.3734e-01],
       grad_fn=<SelectBackward0>)
tensor([0.4273, 0.5727], grad_fn=<SelectBackward0>)
tensor([ 0.1169,  0.5041, -0.3150,  0.7244, -0.0101, -0.4413,  0.1939,  0.0449,
        -0.2114, -0.0252], grad_fn=<SelectBackward0>)
out.size():  torch.Size([1, 2, 40])


: 