In [1]:
import sys
import math
from tqdm import tqdm
sys.path.insert(0, '../')

import torch
from torch import nn
from torch.nn import functional as F

from attention import MultiHeadAttention
from encoder import Encoder

In [2]:
# Toy hyper-parameters for the attention walkthrough; every later cell in this
# section should read these names instead of repeating the literals.
D_MODEL = 6
NUM_HEADS = 2
MAX_LEN = 4
BATCH_SIZE = 5

# Build the encoder from the named constants so that changing one of them
# above cannot silently leave the encoder configured with stale values.
encoder = Encoder(d_model=D_MODEL, num_heads=NUM_HEADS, max_len=MAX_LEN)

# NOTE(review): the leading dimension is 10 while BATCH_SIZE is 5. It is kept
# as-is to preserve the tensor's shape -- confirm whether BATCH_SIZE was meant.
sample_batch = torch.rand(10, MAX_LEN, D_MODEL)

In [3]:
# Attention layer built with mask=True.
# NOTE(review): a later cell constructs MultiHeadAttention with `hidden_dim=`
# instead of `d_model=`; confirm which keyword the current class accepts.
attn_layer = MultiHeadAttention(
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    mask=True,
)

In [4]:
# Three independent uniform-random (BATCH_SIZE, MAX_LEN, D_MODEL) tensors,
# drawn in the order query -> key -> value.
query, key, value = (
    torch.rand(BATCH_SIZE, MAX_LEN, D_MODEL) for _ in range(3)
)

In [21]:
# Split the model dimension into NUM_HEADS chunks and move the head axis in
# front of the sequence axis:
# (BATCH_SIZE, seq, D_MODEL) -> (BATCH_SIZE, NUM_HEADS, seq, D_MODEL // NUM_HEADS).
# NOTE: this rebinds query/key/value, so the cell is not safe to run twice in a
# row without re-running the cell that creates them.
head_dim = D_MODEL // NUM_HEADS
query, key, value = (
    tensor.view(BATCH_SIZE, -1, NUM_HEADS, head_dim).transpose(1, 2)
    for tensor in (query, key, value)
)

In [25]:
# Scaled dot-product attention weights, one (seq x seq) matrix per head.
# query/key were split into heads in the previous cell, so the dot products
# run over the per-head width d_k = key.size(-1). The scale must therefore be
# sqrt(d_k), not sqrt(D_MODEL), which over-damps the scores by sqrt(NUM_HEADS).
d_k = key.size(-1)
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)

# Softmax over the key axis: each query row sums to 1.
attention_raw = F.softmax(attention_scores, dim=-1)
attention_raw.size()

torch.Size([5, 2, 4, 4])

In [45]:
# Trailing two dimensions of the attention weights (query length x key length).
attention_raw.shape[-2:]

torch.Size([4, 4])

In [48]:
# Boolean look-ahead mask sized from MAX_LEN instead of a repeated literal.
# True marks the strictly upper-triangular entries (column > row), i.e. the
# key positions that lie after the query position.
mask = torch.ones(MAX_LEN, MAX_LEN).triu(diagonal=1).bool()
mask

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [42]:
# `attn_masked` was not defined by any earlier cell, so this cell raised
# NameError on a fresh kernel. It is rebuilt here from the pre-softmax scores:
# masked positions are set to -inf so that softmax assigns them exactly zero
# weight and the remaining entries of each row still sum to 1.
# NOTE(review): reconstructed from the displayed output pattern -- confirm it
# matches the cell that originally produced `attn_masked`.
causal_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(key.size(-1))
attn_masked = causal_scores.masked_fill(mask, float("-inf"))
F.softmax(attn_masked, dim=-1)

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5024, 0.4976, 0.0000, 0.0000],
          [0.3337, 0.3322, 0.3340, 0.0000],
          [0.2446, 0.2492, 0.2455, 0.2607]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4960, 0.5040, 0.0000, 0.0000],
          [0.3316, 0.3362, 0.3322, 0.0000],
          [0.2435, 0.2456, 0.2418, 0.2691]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4799, 0.5201, 0.0000, 0.0000],
          [0.3257, 0.3425, 0.3318, 0.0000],
          [0.2407, 0.2599, 0.2457, 0.2537]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4972, 0.5028, 0.0000, 0.0000],
          [0.3361, 0.3287, 0.3352, 0.0000],
          [0.2524, 0.2487, 0.2521, 0.2469]]],


        [[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5083, 0.4917, 0.0000, 0.0000],
          [0.3363, 0.3234, 0.3404, 0.0000],
          [0.2520, 0.2488, 0.2558, 0.2434]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4999, 0.5001, 0.0000, 0.0000],
          [0.3283,

In [12]:
# Settings for the embedding / attention walkthrough below.
# pad_id: filler value used by padding(); vocab_size: number of embedding rows;
# hidden_dim: embedding width. max_len is only referenced by the commented-out
# padding call further down in this cell.
pad_id, vocab_size = 0, 100
max_len, hidden_dim = 10, 6

def padding(data: list, pad_id: int = 0, progress: bool = True) -> tuple:
    """Right-pad every sample to the length of the longest sample.

    Args:
        data: Sequence of token-id sequences. It is iterated twice and never
            modified; each sample is copied into a new list.
        pad_id: Value appended to samples shorter than the longest one.
        progress: When True (the default, matching the previous behaviour)
            the samples are iterated through ``tqdm`` to show a progress bar.
            Pass False to pad silently.

    Returns:
        A ``(padded, max_len)`` tuple: ``padded`` is a new list of lists that
        all have length ``max_len``; ``max_len`` is the longest sample length,
        or 0 when ``data`` is empty.
    """
    # default=0 makes an empty batch return ([], 0) instead of raising
    # ValueError from max() on an empty sequence.
    max_len = max(map(len, data), default=0)
    samples = tqdm(data) if progress else data
    # list(sample) copies the sample and also lets tuples be padded.
    output = [
        list(sample) + [pad_id] * (max_len - len(sample)) for sample in samples
    ]
    return output, max_len

# Five toy sequences of ten token ids each; every id is below vocab_size.
token_rows = [
    [62, 13, 47, 39, 78, 33, 56, 13, 39, 29],
    [60, 96, 51, 32, 90, 44, 86, 71, 36, 18],
    [35, 45, 48, 65, 91, 99, 92, 10, 31, 21],
    [75, 51, 45, 48, 65, 91, 99, 11, 13, 28],
    [66, 88, 98, 47, 48, 65, 17, 13, 67, 99],
]

# All rows already share the same length, so padding(...) is not applied here.
data = torch.LongTensor(token_rows)
print(data.shape)

torch.Size([5, 10])


In [13]:
# Token-embedding table: one hidden_dim-wide row per vocabulary id.
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_dim)

# Re-initialise the weights to U(-1, 1) through nn.init, which performs the
# in-place fill without autograd tracking, instead of writing through the
# `.data` attribute.
nn.init.uniform_(embedding.weight, a=-1.0, b=1.0)
print('Weight initialized')

Weight initialized


In [14]:
# Look up an embedding vector for every token id in `data`, then show the
# resulting shape.
X_embedded = embedding(data)
X_embedded.shape

torch.Size([5, 10, 6])

In [15]:
# Three independent hidden_dim -> hidden_dim projections, created in the order
# query, key, value.
w_query, w_key, w_value = (
    nn.Linear(in_features=hidden_dim, out_features=hidden_dim) for _ in range(3)
)

# Project the same embedded batch through each of them.
query, key, value = (
    projection(X_embedded) for projection in (w_query, w_key, w_value)
)

In [19]:
# Run the projected query/key/value through the project's multi-head attention.
# NOTE(review): an earlier cell constructs MultiHeadAttention with `d_model=`
# (and `mask=`) while this one passes `hidden_dim=`; confirm which keyword the
# current class accepts, otherwise one of the two cells fails on a fresh run.
attention = MultiHeadAttention(
    hidden_dim=hidden_dim,
    num_heads=2,
)
attention_mat = attention(query, key, value)

In [21]:
attention_mat.size()

torch.Size([5, 10, 6])

In [26]:
# Take the normalised shape from the tensor itself instead of repeating the
# literal (10, 6), so the layer follows the data if the sequence length or
# width changes.
# NOTE(review): this normalises jointly over the sequence and feature axes.
# Transformer blocks usually normalise the feature axis only (normalized_shape
# = hidden_dim) -- confirm which is intended.
attn_layer_norm = nn.LayerNorm(normalized_shape=attention_mat.shape[1:])

In [25]:
# Shape of the attention output without its leading (batch) dimension.
attention_mat.shape[1:]

torch.Size([10, 6])

In [27]:
# Apply the layer norm and show the shape of its output.
attn_layer_norm(attention_mat).shape

torch.Size([5, 10, 6])