In [1]:
from math import sqrt
import torch
import torch.nn as nn

In [2]:
class SelfAttention(nn.Module):
    # input: batch_size * seq_len * input_dim
    # q: batch_size * input_dim * dim_k
    # k: batch_size * input_dim * dim_k
    # v: batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self._norm_fact = 1 / sqrt(dim_k)
    
    
    def forward(self, x):
        Q = self.q(x)   # Q: batch_size * seq_len * dim_k
        K = self.k(x)   # K: batch_size * seq_len * dim_k
        V = self.v(x)   # V: batch_size * seq_len * dim_v

        # Q * K.T: batch_size * seq_len * seq_len
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)) * self._norm_fact)
        
        # Q * K.T * V: batch_size * seq_len * dim_v
        output = torch.bmm(atten, V)

        return output

In [3]:
X = torch.randn(4, 3, 2)
print(X)
print(X.size())

tensor([[[-0.7180,  0.1240],
         [-0.5025, -0.3688],
         [-1.4666, -1.7369]],

        [[-0.4069,  1.3145],
         [ 1.9539,  0.5227],
         [-1.3205, -2.3105]],

        [[ 0.7404,  1.4654],
         [ 0.7514, -1.5712],
         [-0.7534, -0.8105]],

        [[ 0.9951, -0.8740],
         [ 0.0501, -0.8154],
         [ 0.1744,  0.2346]]])
torch.Size([4, 3, 2])


In [4]:
self_attention = SelfAttention(2, 4, 5)
res = self_attention(X)
print(res)
print(res.size())

tensor([[[-0.7596,  0.2057,  0.0712, -0.3409, -0.4780],
         [-0.7930,  0.1943,  0.0763, -0.3505, -0.4690],
         [-0.8356,  0.1842,  0.0718, -0.3530, -0.4443]],

        [[ 0.2769,  0.6979, -0.4471,  0.2858, -0.3194],
         [-0.2596,  0.3906, -0.0443, -0.1604, -0.5647],
         [ 0.7067,  1.1031, -1.1776,  1.0127,  0.3715]],

        [[ 0.2204,  0.6942, -0.4782,  0.3054, -0.2562],
         [-0.1416,  0.6858, -0.7164,  0.4663,  0.1968],
         [-0.0530,  0.7529, -0.8249,  0.5780,  0.2882]],

        [[ 0.0608,  0.7631, -0.7695,  0.5450,  0.1694],
         [ 0.0708,  0.7870, -0.8234,  0.5954,  0.2303],
         [ 0.0765,  0.7644, -0.7615,  0.5402,  0.1526]]],
       grad_fn=<BmmBackward0>)
torch.Size([4, 3, 5])


In [32]:
class SelfAttentionMultiHead(nn.Module):
    # input: batch_size * seq_len * input_dim
    # q: batch_size * input_dim * dim_k
    # k: batch_size * input_dim * dim_k
    # v: batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v, num_heads):
        super(SelfAttentionMultiHead, self).__init__()
        assert dim_k % num_heads == 0
        assert dim_v % num_heads == 0
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)

        self.num_heads = num_heads
        self.dim_k = dim_k
        self.dim_v = dim_v
        self._norm_fact = 1 / sqrt(dim_k)
    

    def forward(self, x):
        Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k // self.num_heads)
        K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k // self.num_heads)
        V = self.v(x).reshape(-1, x.shape[0], x.shape[1], self.dim_v // self.num_heads)
        
        print(x.shape)
        print(Q.size())

        atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0, 1, 3, 2))) * self._norm_factor # Q * K.T() # batch_size * seq_len * seq_len

        output = torch.matmul(atten, V).reshape(x.shape[0], x.shape[1], -1) # Q * K.T() * V # batch_size * seq_len * dim_v

        return output

In [33]:
X = torch.randn(6, 10, 12)
print(X)
print(X.size())

tensor([[[ 5.8137e-01, -6.9116e-01, -2.1204e-01,  9.1610e-01,  7.1715e-02,
           6.1210e-01,  8.6303e-02, -7.8856e-02, -1.7583e+00,  1.4410e-01,
          -1.5569e+00,  1.5800e+00],
         [ 9.8535e-01, -6.9716e-01,  4.1188e-01, -8.6874e-01, -2.5236e+00,
          -1.0130e+00,  1.2952e+00,  5.3968e-01,  1.9581e-01, -2.9223e-01,
           7.9378e-01, -1.3396e-01],
         [ 2.1546e-01, -3.4884e-01,  3.5852e-01,  5.9518e-01, -2.8232e+00,
          -1.4204e+00,  2.0986e-01, -2.1886e-01,  9.2015e-01,  1.2289e+00,
          -2.4733e-01, -2.0131e+00],
         [ 6.6980e-01, -4.4293e-01, -6.5551e-01, -6.3324e-01, -6.7950e-01,
          -1.1531e+00,  2.1958e+00, -1.0471e+00, -7.9204e-01, -5.7783e-01,
           4.1470e-01, -2.6242e-01],
         [-1.2670e-01,  8.8041e-01,  1.6829e+00,  1.0102e+00, -7.8223e-01,
           1.0122e+00,  1.7825e+00, -9.7064e-01,  1.4639e+00, -8.5901e-01,
           6.4491e-01,  7.1196e-01],
         [ 8.5569e-01, -3.0760e-01, -4.2791e-01,  9.1599e-01,  2.

In [35]:
self_attention = SelfAttentionMultiHead(12, 4, 6, 2)
res = self_attention(X)
print(res)
print(res.size())

torch.Size([6, 10, 12])
torch.Size([2, 6, 10, 2])
tensor([[[ 2.0342e-01,  3.4062e-01, -5.3461e-02,  1.7939e-01,  4.1577e-01,
          -2.6760e-01],
         [ 2.0257e-01,  3.4769e-01, -8.5435e-02,  1.4806e-01,  4.9734e-01,
          -2.8467e-01],
         [ 2.1995e-01,  2.4925e-01, -3.6779e-03,  1.6946e-01,  4.5658e-01,
          -2.3527e-01],
         [ 2.1141e-01,  2.9893e-01, -3.1834e-02,  1.8276e-01,  4.1040e-01,
          -2.4414e-01],
         [ 1.8845e-01,  4.0583e-01, -4.1880e-02,  1.6045e-01,  4.6153e-01,
          -2.9468e-01],
         [ 1.4068e-01,  4.4603e-01, -2.6512e-01,  1.1452e-01,  3.2217e-01,
          -1.3647e-01],
         [ 1.3684e-01,  4.0859e-01, -2.2066e-01,  1.2875e-01,  3.9079e-01,
          -2.0597e-01],
         [ 1.2489e-01,  4.0010e-01, -2.1458e-01,  1.2532e-01,  3.7144e-01,
          -1.8546e-01],
         [ 1.2418e-01,  3.6196e-01, -1.7459e-01,  1.2473e-01,  3.6365e-01,
          -1.7481e-01],
         [ 9.3679e-02,  3.5114e-01, -1.7168e-01,  1.3597e-0