<a href="https://colab.research.google.com/github/jiwan-gharti-savi/free-for-dev/blob/master/SelfAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [59]:
values = nn.Linear(2, 256, bias=False)


In [60]:
values

Linear(in_features=2, out_features=256, bias=False)

In [61]:
linear_features = values(torch.tensor([1,2], dtype=torch.float32))

In [62]:
linear_features.size()

torch.Size([256])

In [63]:
values.weight

Parameter containing:
tensor([[ 6.7313e-01, -1.8661e-02],
        [-6.2013e-02,  2.8488e-01],
        [-2.8281e-01,  3.0558e-02],
        [-6.0827e-01, -2.3294e-01],
        [-4.6213e-01,  2.8557e-03],
        [ 5.0369e-01, -8.0881e-02],
        [ 2.9752e-02,  6.5976e-01],
        [ 2.9833e-01,  6.2319e-01],
        [-2.4784e-01, -5.5923e-01],
        [-9.2070e-03,  5.0430e-01],
        [-4.5312e-01,  6.2466e-01],
        [-1.8034e-01, -2.3809e-01],
        [ 1.7344e-01, -4.3611e-01],
        [-6.6708e-01,  1.6521e-01],
        [-1.5760e-01,  3.1007e-01],
        [ 6.1163e-01,  6.6549e-01],
        [-6.2187e-01,  1.6859e-01],
        [-6.6035e-01,  1.2335e-01],
        [ 5.6807e-01,  3.6127e-01],
        [-5.6038e-02,  3.8307e-01],
        [ 6.4438e-01, -1.3951e-02],
        [-3.3232e-01,  6.8673e-01],
        [ 4.2567e-01,  1.3697e-01],
        [ 6.1080e-02,  4.9819e-01],
        [ 5.6329e-02, -3.8066e-01],
        [-5.7878e-01, -3.3766e-01],
        [-1.0353e-02, -2.9903e-01],
      

In [64]:
values.bias

In [65]:
class SimpleSelfAttention(nn.Module):

  def __init__(self, d_in, d_out):
    super(SimpleSelfAttention, self).__init__()

    self.d_in = d_in
    self.d_out = d_out

    self.Q = nn.Linear(d_in, d_out)
    self.K = nn.Linear(d_in, d_out)
    self.V = nn.Linear(d_in, d_out)


  def forward(self, x):
    queries = self.Q(x)
    keys = self.K(x)
    values = self.V(x)

    scores = torch.bmm(queries, keys.transpose(1,2))
    scores = scores / (self.d_out ** 0.5)

    attention = F.softmax(scores)

    hidden_state = torch.bmm(attention, values)

    return hidden_state


In [66]:
SOS_TOKEN = 0
EOS_TOKEN = 1

index2words = {
    SOS_TOKEN: "SOS",
    EOS_TOKEN: "EOS"
}

words = "How are you doing ? I am good and you ?"

words_list = set(words.lower().split(' '))

In [67]:
words_list


{'?', 'am', 'and', 'are', 'doing', 'good', 'how', 'i', 'you'}

In [68]:
for word in words_list:
  index2words[len(index2words)] = word

In [69]:
index2words


{0: 'SOS',
 1: 'EOS',
 2: 'and',
 3: 'are',
 4: 'i',
 5: 'good',
 6: 'doing',
 7: 'you',
 8: 'am',
 9: '?',
 10: 'how'}

In [70]:
words2index = {w: i for i, w in index2words.items()}

In [71]:
words2index

{'SOS': 0,
 'EOS': 1,
 'and': 2,
 'are': 3,
 'i': 4,
 'good': 5,
 'doing': 6,
 'you': 7,
 'am': 8,
 '?': 9,
 'how': 10}

In [72]:
def convert2tensor(sentence):
  words_list = sentence.lower().split(' ')
  indexes = [words2index[word] for word in words_list]
  return torch.tensor(indexes, dtype=torch.long).view(1,-1)

In [73]:
sentence = "How are you doing ?"
convert2tensor(sentence)

tensor([[10,  3,  7,  6,  9]])

In [74]:
indexes = convert2tensor(sentence)

In [75]:
indexes.size()

torch.Size([1, 5])

In [76]:
HIDDEN_SIZE = 10
VOCAB_SIZE = len(words2index)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
attention = SimpleSelfAttention(HIDDEN_SIZE, HIDDEN_SIZE)


sentence = "How are you doing ?"
input_tensor = convert2tensor(sentence)
embedded = embedding(input_tensor)
embedded

tensor([[[ 0.4484,  2.8696, -0.9035,  0.5159,  1.9699,  0.4072, -1.2199,
          -0.1126, -0.1233,  0.9234],
         [-1.3750,  0.1516,  1.3441,  0.9616, -1.5485,  1.9630, -0.2597,
          -0.0829, -0.3105, -0.6952],
         [ 0.5953, -0.5703, -1.5308, -0.0853, -2.4764,  0.2231, -0.1105,
           0.1674,  1.3860, -1.0135],
         [ 1.3144, -0.3012,  1.5836,  1.6488, -0.5957,  0.9042,  0.1411,
           0.8006,  0.3337, -0.6912],
         [-0.3179,  1.5882, -0.6722,  0.9146,  0.3372, -1.2006, -0.2145,
           0.2696,  0.2773,  0.4300]]], grad_fn=<EmbeddingBackward0>)

In [77]:
embedded.size()

torch.Size([1, 5, 10])

In [78]:
hidden_states = attention(embedded)

  attention = F.softmax(scores)


In [79]:
hidden_states

tensor([[[-2.3638, -2.8103,  0.6418, -1.5894,  1.1509, -0.6191,  0.9160,
          -0.8259, -1.4455, -0.9779],
         [-2.3638, -2.8103,  0.6418, -1.5894,  1.1509, -0.6191,  0.9160,
          -0.8259, -1.4455, -0.9779],
         [-2.3638, -2.8103,  0.6418, -1.5894,  1.1509, -0.6191,  0.9160,
          -0.8259, -1.4455, -0.9779],
         [-2.3638, -2.8103,  0.6418, -1.5894,  1.1509, -0.6191,  0.9160,
          -0.8259, -1.4455, -0.9779],
         [-2.3638, -2.8103,  0.6418, -1.5894,  1.1509, -0.6191,  0.9160,
          -0.8259, -1.4455, -0.9779]]], grad_fn=<BmmBackward0>)

In [80]:
hidden_states.size()

torch.Size([1, 5, 10])

In [81]:
d_in = HIDDEN_SIZE
d_out = HIDDEN_SIZE

Q = nn.Linear(d_in, d_out)
K = nn.Linear(d_in, d_out)
V = nn.Linear(d_in, d_out)

In [82]:
Q

Linear(in_features=10, out_features=10, bias=True)

In [83]:
K

Linear(in_features=10, out_features=10, bias=True)

In [84]:
V

Linear(in_features=10, out_features=10, bias=True)

In [85]:
Q.weight

Parameter containing:
tensor([[ 0.0751,  0.0735,  0.2432, -0.2527, -0.1377, -0.1305,  0.0895,  0.1730,
          0.2954,  0.2423],
        [ 0.1623, -0.3091,  0.2639,  0.0650, -0.0324, -0.1573, -0.3036, -0.1084,
          0.2875, -0.1767],
        [ 0.0217,  0.2514, -0.0302, -0.2532,  0.1251, -0.0370,  0.0336,  0.2756,
         -0.2655,  0.1307],
        [-0.3132,  0.2718, -0.1704, -0.3006, -0.0312,  0.2924, -0.2788,  0.2873,
          0.2895, -0.0085],
        [ 0.0088, -0.2562,  0.1172, -0.2874,  0.0125,  0.3108,  0.0103, -0.3132,
          0.1727,  0.3104],
        [ 0.1546, -0.0415,  0.1655, -0.2884, -0.1611,  0.3038, -0.1051,  0.0412,
          0.2226, -0.2285],
        [-0.2137,  0.0745,  0.2620, -0.0351, -0.1941,  0.2743,  0.1025, -0.2799,
         -0.2072, -0.3105],
        [-0.2244,  0.0528, -0.1830, -0.1941,  0.1719,  0.3031,  0.1722, -0.1224,
         -0.0223,  0.0301],
        [-0.1100,  0.2692,  0.0654,  0.3083,  0.1927,  0.1840, -0.1219,  0.2234,
          0.0012, -0.1684

In [86]:
sentence = "How are you doing ?"
input_tensor = convert2tensor(sentence)
embedded = embedding(input_tensor)
embedded

tensor([[[ 0.4484,  2.8696, -0.9035,  0.5159,  1.9699,  0.4072, -1.2199,
          -0.1126, -0.1233,  0.9234],
         [-1.3750,  0.1516,  1.3441,  0.9616, -1.5485,  1.9630, -0.2597,
          -0.0829, -0.3105, -0.6952],
         [ 0.5953, -0.5703, -1.5308, -0.0853, -2.4764,  0.2231, -0.1105,
           0.1674,  1.3860, -1.0135],
         [ 1.3144, -0.3012,  1.5836,  1.6488, -0.5957,  0.9042,  0.1411,
           0.8006,  0.3337, -0.6912],
         [-0.3179,  1.5882, -0.6722,  0.9146,  0.3372, -1.2006, -0.2145,
           0.2696,  0.2773,  0.4300]]], grad_fn=<EmbeddingBackward0>)

In [89]:
keys = K(embedded)

In [90]:
keys

tensor([[[-0.4545,  0.5926, -0.4372,  0.3479, -0.1554,  0.9619, -0.3094,
           1.3898, -0.0463, -0.0612],
         [-0.1358,  0.5239, -0.6684, -0.7774, -0.1650,  0.1561,  0.0172,
          -0.1913,  0.9861, -1.1609],
         [ 0.0363, -0.0446,  0.6692, -1.6894, -0.0305, -0.2130, -0.2201,
          -0.1407, -0.5690, -0.1875],
         [-0.2554,  1.3012, -0.4328, -0.3100,  0.6252, -0.7503,  0.1489,
          -0.6630,  0.2348, -0.3658],
         [-0.2768,  0.3774,  0.2837, -0.1215,  0.0984,  0.5114,  0.4355,
           1.2921, -0.5621,  0.2994]]], grad_fn=<ViewBackward0>)

In [91]:
keys.size()

torch.Size([1, 5, 10])

In [92]:
values = V(embedded)
queries = Q(embedded)

In [115]:
# variance
values.var(), queries.var()

(tensor(0.3691, grad_fn=<VarBackward0>),
 tensor(0.4447, grad_fn=<VarBackward0>))

In [93]:
values.size()

torch.Size([1, 5, 10])

In [94]:
queries.size()

torch.Size([1, 5, 10])

In [96]:
scores = torch.bmm(input=keys, mat2=queries.transpose(1,2))

In [116]:
# variance
values.var(), queries.var()

(tensor(0.3691, grad_fn=<VarBackward0>),
 tensor(0.4447, grad_fn=<VarBackward0>))

In [117]:
scores.var()

tensor(1.5931, grad_fn=<VarBackward0>)

In [97]:
scores

tensor([[[ 0.2310,  1.3922,  1.7466,  0.2899, -0.7407],
         [ 0.1478,  1.5709,  0.1406,  3.4693,  0.5907],
         [-0.9966, -1.9246, -1.0777,  0.7015,  0.3237],
         [-1.9914, -0.1161, -0.2837,  1.4838, -0.8076],
         [-0.8486,  0.9610,  0.5063, -0.7855, -1.4285]]],
       grad_fn=<BmmBackward0>)

In [98]:
scores.size()

torch.Size([1, 5, 5])

In [101]:
#normalize variance
normalize_scores = scores / (d_out ** 0.5)

In [125]:
print("Variance of queries, values, keuys: ",  queries.var(), values.var(), keys.var())
# before normalize
print("before_ normailze:", scores.var())
# after normalize
print("After Normalize: ", normalize_scores.var())

Variance of queries, values, keuys:  tensor(0.4447, grad_fn=<VarBackward0>) tensor(0.3691, grad_fn=<VarBackward0>) tensor(0.3649, grad_fn=<VarBackward0>)
before_ normailze: tensor(1.5931, grad_fn=<VarBackward0>)
After Normalize:  tensor(0.1593, grad_fn=<VarBackward0>)


In [126]:
normalize_scores

tensor([[[ 0.0730,  0.4403,  0.5523,  0.0917, -0.2342],
         [ 0.0467,  0.4968,  0.0445,  1.0971,  0.1868],
         [-0.3151, -0.6086, -0.3408,  0.2218,  0.1024],
         [-0.6297, -0.0367, -0.0897,  0.4692, -0.2554],
         [-0.2683,  0.3039,  0.1601, -0.2484, -0.4517]]],
       grad_fn=<DivBackward0>)

In [103]:
normalize_scores.size()

torch.Size([1, 5, 5])

In [104]:
# change to probability

probability_scores = attentions =  torch.softmax(normalize_scores, dim=-1)

In [106]:
probability_scores

tensor([[[0.1720, 0.2484, 0.2778, 0.1753, 0.1265],
         [0.1320, 0.2070, 0.1317, 0.3774, 0.1519],
         [0.1681, 0.1253, 0.1638, 0.2876, 0.2552],
         [0.1114, 0.2015, 0.1911, 0.3342, 0.1619],
         [0.1623, 0.2877, 0.2492, 0.1656, 0.1351]]],
       grad_fn=<SoftmaxBackward0>)

In [108]:
probability_scores.size()

torch.Size([1, 5, 5])

In [110]:
hidden_state = torch.bmm(attentions, values)


In [111]:
hidden_state

tensor([[[-0.5315,  0.0225,  0.1819, -0.3955,  0.2597, -0.3418, -0.0068,
          -0.5554,  0.1463,  0.2392],
         [-0.5591,  0.1517,  0.1867, -0.4996,  0.1146, -0.2975,  0.2468,
          -0.4611,  0.1741,  0.1936],
         [-0.3604,  0.0889,  0.1225, -0.4348, -0.0771, -0.2743,  0.2131,
          -0.4430,  0.2728,  0.1760],
         [-0.5295,  0.1151,  0.1471, -0.4857,  0.1896, -0.3399,  0.1387,
          -0.4723,  0.1482,  0.2493],
         [-0.5731,  0.0462,  0.1938, -0.4060,  0.2623, -0.3116,  0.0116,
          -0.5832,  0.1413,  0.2076]]], grad_fn=<BmmBackward0>)

In [112]:
hidden_state.size()

torch.Size([1, 5, 10])

# Positional Encoding


In [159]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_sequence_length):
    super(PositionalEncoding, self).__init__()

    self.max_sequence_length = max_sequence_length
    self.d_model = d_model


  def forward(self):
    even_i = torch.arange(0, self.d_model, 2).float()
    denominator = torch.pow(10000, (2 * even_i / self.d_model))
    position = torch.arange(self.max_sequence_length).reshape(self.max_sequence_length, 1)
    even_PE = torch.sin(position / denominator)
    odd_PE = torch.cos(position/denominator)
    stacked = torch.stack([even_PE, odd_PE], dim=2)
    PE = torch.flatten(stacked, start_dim=1, end_dim=2)
    return PE

In [160]:
tensor1 = torch.tensor([1,2,3], dtype=torch.float32)
tensor2 = torch.tensor([11,12,13], dtype=torch.float32)



In [161]:
two = torch.stack((tensor1, tensor2), dim=-1)
two

tensor([[ 1., 11.],
        [ 2., 12.],
        [ 3., 13.]])

In [162]:
torch.flatten(two,start_dim=-2)

tensor([ 1., 11.,  2., 12.,  3., 13.])

In [163]:
pe = PositionalEncoding(d_model=6, max_sequence_length=10)
pe.forward()


tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  2.1544e-03,  1.0000e+00,  4.6416e-06,
          1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  4.3089e-03,  9.9999e-01,  9.2832e-06,
          1.0000e+00],
        [ 1.4112e-01, -9.8999e-01,  6.4633e-03,  9.9998e-01,  1.3925e-05,
          1.0000e+00],
        [-7.5680e-01, -6.5364e-01,  8.6176e-03,  9.9996e-01,  1.8566e-05,
          1.0000e+00],
        [-9.5892e-01,  2.8366e-01,  1.0772e-02,  9.9994e-01,  2.3208e-05,
          1.0000e+00],
        [-2.7942e-01,  9.6017e-01,  1.2926e-02,  9.9992e-01,  2.7850e-05,
          1.0000e+00],
        [ 6.5699e-01,  7.5390e-01,  1.5080e-02,  9.9989e-01,  3.2491e-05,
          1.0000e+00],
        [ 9.8936e-01, -1.4550e-01,  1.7235e-02,  9.9985e-01,  3.7133e-05,
          1.0000e+00],
        [ 4.1212e-01, -9.1113e-01,  1.9389e-02,  9.9981e-01,  4.1774e-05,
          1.0000e+00]])