In [None]:
%load_ext autoreload
%autoreload 2

### Embeddings

In [None]:
import torch
from torch import nn

from transformer.layers import Embedding

In [None]:
# Seed torch's RNG once so every random input in this notebook is reproducible on Restart & Run All.
torch.manual_seed(0)

d_vocabulary = 10  # token ids are drawn from [0, d_vocabulary)
d_model = 16       # size of each embedding vector
d_sentence = 8     # tokens per sentence
d_batch = 3        # sentences per batch

In [None]:
# random token ids in [0, d_vocabulary), one row per sentence
x = torch.randint(0, d_vocabulary, (d_batch, d_sentence))
x

In [None]:
model_embedding = Embedding(d_vocabulary, d_model, d_sentence)
emb = model_embedding(x)

# each token id becomes a d_model-dimensional vector
assert tuple(emb.shape) == (d_batch, d_sentence, d_model)

### Padding Mask

In [None]:
from transformer.utils import get_attn_mask

In [None]:
# mask without heads
d_vocabulary, d_batch, d_sentence = 4, 3, 5

x = torch.randint(0, d_vocabulary, (d_batch, d_sentence))
mask = get_attn_mask(x)

# every query row carries the same mask: True wherever the key token id is 0
assert torch.equal(mask, (x == 0)[:, None, :].expand(d_batch, d_sentence, d_sentence))

In [None]:
# mask with heads
n_heads = 2
d_vocabulary, d_batch, d_sentence = 4, 3, 5

x = torch.randint(0, d_vocabulary, (d_batch, d_sentence))
mask = get_attn_mask(x, n_heads=n_heads)

# one (d_sentence x d_sentence) mask per head
assert tuple(mask.shape) == (d_batch, n_heads, d_sentence, d_sentence)

### ScaledDotProductAttention

In [None]:
import torch
from transformer.layers import ScaledDotProductAttention
from transformer.utils import get_attn_mask

In [None]:
# without dimension for heads
d_vocabulary = 7
d_b, d_k, d_v, d_l = 4, 3, 5, 6  # batch size, dim of W_k, dim of W_v, length of sentences

# queries and keys share d_k, values use d_v (drawn in the order Q, K, V)
Q, K, V = (
    torch.rand(shape)
    for shape in ((d_b, d_l, d_k), (d_b, d_l, d_k), (d_b, d_l, d_v))
)

x = torch.randint(0, d_vocabulary, (d_b, d_l))
mask = get_attn_mask(x)

In [None]:
model_sdpa = ScaledDotProductAttention(d_k)
context, attn = model_sdpa(Q, K, V, mask)

# context keeps the value dimension d_v; attn is square in the sentence length
assert tuple(context.shape) == (d_b, d_l, d_v)
assert tuple(attn.shape) == (d_b, d_l, d_l)

In [None]:
# with dimensions for heads
d_vocabulary = 7
# batch size, dim of W_k, dim of W_v, length of sentences, number of heads
d_b, d_k, d_v, d_l, n_h = 4, 3, 5, 6, 2

# queries and keys share d_k, values use d_v (drawn in the order Q, K, V)
Q, K, V = (
    torch.rand(shape)
    for shape in ((d_b, n_h, d_l, d_k), (d_b, n_h, d_l, d_k), (d_b, n_h, d_l, d_v))
)

x = torch.randint(0, d_vocabulary, (d_b, d_l))
mask = get_attn_mask(x, n_h)

In [None]:
model_sdpa = ScaledDotProductAttention(d_k)
context, attn = model_sdpa(Q, K, V, mask)

# same checks as without heads, with the head axis kept at position 1
assert tuple(context.shape) == (d_b, n_h, d_l, d_v)
assert tuple(attn.shape) == (d_b, n_h, d_l, d_l)

### Multi-Head Attention

In [None]:
from transformer.layers import MultiHeadAttention

In [None]:
d_m = 8
# The modules are tightly coupled, so d_v must equal d_m for now; derive it so they cannot drift apart.
d_v = d_m
#
d_k = 6
n_h = 2
d_l = 7
d_b = 3
# Set explicitly rather than relying on the value left over from the SDPA cells above.
d_vocabulary = 7
#
model_mha = MultiHeadAttention(d_m, d_k, d_v, n_h)
#
# padding mask from random token ids (no head dimension)
x = torch.randint(low=0, high=d_vocabulary, size=(d_b, d_l))
mask = get_attn_mask(x)

# random embedding
emb = torch.rand((d_b, d_l, d_m))

In [None]:
output, attn = model_mha(emb, mask)

# output keeps the (batch, sentence, d_v) layout; attn holds one square map per head
assert tuple(output.shape) == (d_b, d_l, d_v)
assert tuple(attn.shape) == (d_b, n_h, d_l, d_l)

### Position-wise Feed-Forward Network

In [None]:
from transformer.layers import PoswiseFeedForwardNet

In [None]:
d_b, d_m, d_ff, d_l = 1, 3, 4, 8  # batch size, model dim, d_ff, length of sentences

x = torch.rand((d_b, d_l, d_m))
model_pffn = PoswiseFeedForwardNet(d_m, d_ff)
out = model_pffn(x)

# the feed-forward net must hand back a tensor shaped exactly like its input
assert out.shape == x.shape

In [None]:
# Input whose positions alternate between two constant vectors (v1 at even, v2 at odd
# positions) for every sentence in the batch; the next cell shows the PFFN output for it.
v1 = 0.6
v2 = 0.7

x = torch.empty((d_b, d_l, d_m))  # every element is overwritten below
x[:, 0::2, :] = v1  # even positions, all batch entries (the old loop only filled x[0])
x[:, 1::2, :] = v2  # odd positions
x

In [None]:
# keep the result in `out` and display it
(out := model_pffn(x))

### EncoderLayer

In [None]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a position-wise feed-forward net.

    Args:
        d_model: model (embedding) dimension.
        d_k: dimension of the key projection W_k.
        d_v: dimension of the value projection W_v; MultiHeadAttention currently
            requires d_v == d_model.
        n_heads: number of attention heads.
        d_ff: hidden size passed to PoswiseFeedForwardNet.

    NOTE(review): this class adds no residual connection or LayerNorm of its own;
    presumably those live inside MultiHeadAttention / PoswiseFeedForwardNet — confirm.
    """

    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, x, attn_mask):
        """Run self-attention over ``x``, then feed the result through the feed-forward net.

        Args:
            x: input embeddings; presumably (batch, sentence, d_model), as passed to
                MultiHeadAttention in the section above — TODO confirm.
            attn_mask: attention mask, e.g. as produced by ``get_attn_mask``.

        Returns:
            Tuple ``(enc_outputs, attn)``: the feed-forward output and the attention
            tensor returned by the self-attention sublayer.
        """
        enc_outputs, attn = self.enc_self_attn(x, attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

In [None]:
model_el = EncoderLayer(config.d_model, config.d_k, config.d_v, config.n_heads, config.d_ff)

In [None]:
out, attn = model_el.forward(x, attn_mask)

In [None]:
# Like the other sections, assert the shapes instead of printing them:
# the layer preserves the input shape and the attention maps are square in the sentence length.
assert out.shape == x.shape
assert attn.shape[-2:] == (x.shape[1], x.shape[1])
out.shape, attn.shape