In [None]:
%load_ext autoreload
%autoreload 2

### Embeddings

In [None]:
import torch
from transformer.layers import Embedding

In [None]:
# Seed torch's global RNG once, up front, so every random tensor drawn in this
# notebook is reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Toy dimensions for the embedding check.
d_vocabulary = 10  # token ids are drawn from [0, d_vocabulary)
d_model = 16       # width of each embedding vector
d_sentence = 8     # tokens per sentence
d_batch = 3        # sentences per batch

In [None]:
# One row of random token ids per sentence, each id in [0, d_vocabulary).
x = torch.randint(0, d_vocabulary, (d_batch, d_sentence))
x

In [None]:
# Embed the token ids; every token must come back as a d_model-wide vector.
model_embedding = Embedding(d_vocabulary, d_model, d_sentence)
emb = model_embedding(x)

assert tuple(emb.shape) == (d_batch, d_sentence, d_model)

### Padding Mask

In [None]:
from transformer.utils import get_attn_mask

In [None]:
# Padding mask without a head dimension.
d_vocabulary, d_batch, d_sentence = 4, 3, 5

x = torch.randint(0, d_vocabulary, (d_batch, d_sentence))
mask = get_attn_mask(x)

# The mask must equal the "token id is 0" pattern of each sentence, repeated
# once per row so that it is square in the sentence dimension.
assert torch.equal(
    mask,
    x.eq(0)[:, None, :].repeat(1, d_sentence, 1),
)

In [None]:
# Padding mask with an explicit head dimension.
n_heads = 2
d_vocabulary, d_batch, d_sentence = 4, 3, 5

x = torch.randint(0, d_vocabulary, (d_batch, d_sentence))
mask = get_attn_mask(x, n_heads=n_heads)

# One square (sentence x sentence) mask per head and per batch entry.
assert tuple(mask.shape) == (d_batch, n_heads, d_sentence, d_sentence)

### ScaledDotProductAttention

In [None]:
import torch
from transformer.layers import ScaledDotProductAttention
from transformer.utils import get_attn_mask

In [None]:
# Attention inputs without a head dimension.
d_vocabulary = 7
d_b, d_l = 4, 6  # batch size, length of sentences
d_k, d_v = 3, 5  # dim of W_k, dim of W_v

# Random query / key / value tensors.
Q = torch.rand(d_b, d_l, d_k)
K = torch.rand(d_b, d_l, d_k)
V = torch.rand(d_b, d_l, d_v)

# The mask is derived from random token ids of matching batch size and length.
x = torch.randint(0, d_vocabulary, (d_b, d_l))
mask = get_attn_mask(x)

In [None]:
# Scaled dot-product attention on the head-less inputs: the context keeps the
# value width, the attention map is square in the sentence length.
model_sdpa = ScaledDotProductAttention(d_k)
context, attn = model_sdpa(Q, K, V, mask)

assert tuple(context.shape) == (d_b, d_l, d_v)
assert tuple(attn.shape) == (d_b, d_l, d_l)

In [None]:
# Attention inputs with a head dimension.
d_vocabulary = 7
d_b, d_l = 4, 6  # batch size, length of sentences
d_k, d_v = 3, 5  # dim of W_k, dim of W_v
n_h = 2          # number of heads

# Random query / key / value tensors, one slice per head.
Q = torch.rand(d_b, n_h, d_l, d_k)
K = torch.rand(d_b, n_h, d_l, d_k)
V = torch.rand(d_b, n_h, d_l, d_v)

# The mask is derived from random token ids and built for n_h heads.
x = torch.randint(0, d_vocabulary, (d_b, d_l))
mask = get_attn_mask(x, n_h)

In [None]:
# Same check as the head-less case, now with the head axis carried through.
model_sdpa = ScaledDotProductAttention(d_k)
context, attn = model_sdpa(Q, K, V, mask)

assert tuple(context.shape) == (d_b, n_h, d_l, d_v)
assert tuple(attn.shape) == (d_b, n_h, d_l, d_l)

### Multihead Attention

In [None]:
from transformer.layers import MultiHeadAttention

In [None]:
# Vocabulary size for the random token ids below. Defined here (same value as
# in the attention cells above) so this cell does not depend on leftover state.
d_vocabulary = 7

d_m = 8  # model width
d_v = 8  # must be equal to d_m so far (the modules are still tightly coupled)

d_k = 6  # dim of W_k
n_h = 2  # number of heads
d_l = 7  # length of sentences
d_b = 3  # batch size

model_mha = MultiHeadAttention(d_m, d_k, d_v, n_h)

# Padding mask derived from random token ids (no head dimension).
x = torch.randint(low=0, high=d_vocabulary, size=(d_b, d_l))
mask = get_attn_mask(x)

# random embedding
emb = torch.rand((d_b, d_l, d_m))

In [None]:
# Multi-head attention on the random embedding: the output keeps the value
# width, and one square attention map is returned per head.
output, attn = model_mha(emb, mask)

assert tuple(output.shape) == (d_b, d_l, d_v)
assert tuple(attn.shape) == (d_b, n_h, d_l, d_l)

### Position Wise Feed Forward

In [None]:
from transformer.layers import PoswiseFeedForwardNet

In [None]:
# Shape check: the feed-forward net must hand back a tensor shaped like its input.
d_b, d_l = 1, 8   # batch size, length of sentences
d_m, d_ff = 3, 4  # model width, feed-forward width

x = torch.rand(d_b, d_l, d_m)
model_pffn = PoswiseFeedForwardNet(d_m, d_ff)
out = model_pffn(x)

assert tuple(out.shape) == (d_b, d_l, d_m)

In [None]:
# Build an input whose positions alternate between two constant vectors:
# even positions are filled with v1, odd positions with v2.
v1 = 0.6
v2 = 0.7

x = torch.rand(d_b, d_l, d_m)
for i in range(d_l):
    x[0][i, :] = v1 if i % 2 == 0 else v2
x

In [None]:
# Run the feed-forward net on the alternating-constant input and display the
# result for visual inspection.
# NOTE(review): presumably rows that share an input value should produce
# identical output rows if the net is applied per position -- confirm by eye.
out = model_pffn(x)
out

### EncoderLayer

In [None]:
import torch
from transformer.layers import AttentionEncoder
from transformer.utils import get_attn_mask

In [None]:
# Dimensions for the encoder-layer check.
d_voc = 10       # vocabulary size for the random token ids
d_m = 3          # model width
d_v = d_m        # value width, kept equal to the model width
d_k = 4          # dim of W_k
n_h = 2          # number of heads
d_ff = 4 * d_m   # feed-forward width
d_l = 6          # length of sentences
d_b = 8          # batch size

# Padding mask derived from random token ids.
x = torch.randint(0, d_voc, (d_b, d_l))
mask = get_attn_mask(x)

# random embedding
emb = torch.rand(d_b, d_l, d_m)

model_el = AttentionEncoder(d_m, d_k, d_v, n_h, d_ff)

In [None]:
# Call the module itself rather than `.forward(...)`: going through
# `nn.Module.__call__` runs any registered hooks and matches how the other
# layers in this notebook are invoked.
out, attn = model_el(emb, mask)

assert out.shape == torch.Size((d_b, d_l, d_v))
assert attn.shape == torch.Size((d_b, n_h, d_l, d_l))

### BERT

In [None]:
import torch
import torch.nn as nn
from transformer.layers import Embedding, AttentionEncoder
from transformer.utils import get_attn_mask

In [None]:
class BERT(nn.Module):
    """Minimal BERT-style encoder with a masked-token prediction head.

    Token ids are embedded, passed through ``n_layers`` stacked
    ``AttentionEncoder`` blocks, and the hidden states at the requested
    positions are projected back onto the vocabulary through a decoder whose
    weight is shared with the token embedding.

    Args:
        d_vocab: vocabulary size; also the width of the returned logits.
        d_model: width of the embeddings and of the encoder outputs.
        d_sentence: sentence length handed to ``Embedding``.
        n_layers: number of stacked ``AttentionEncoder`` blocks.
        n_heads: number of attention heads per block.
        d_k: dimension of W_k handed to each block.
        d_v: dimension of W_v handed to each block; must equal ``d_model``.
        d_ff: feed-forward width handed to each block.

    Raises:
        AssertionError: if ``d_v != d_model``.
    """

    def __init__(
        self, d_vocab: int, d_model: int, d_sentence: int,
        n_layers, n_heads, d_k, d_v, d_ff
    ):
        super().__init__()

        # Keep the hyperparameters around for inspection.
        self.d_vocab = d_vocab
        self.d_model = d_model
        self.d_sentence = d_sentence
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_ff = d_ff

        # The encoder blocks are chained output-to-input, so the value width
        # has to match the model width (not optimal, but required for now).
        assert self.d_v == self.d_model, "d_v must be equal to d_model"

        # Input Embeddings
        self.embedding = Embedding(d_vocab, d_model, d_sentence)

        # Attention Layers
        self.layers = nn.ModuleList(
            [AttentionEncoder(d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)]
        )

        # Output Head
        self.norm = nn.LayerNorm(d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.gelu = torch.nn.GELU()

        # Output Decoder = inverse embedding: the projection weight is tied to
        # the token-embedding matrix.
        # NOTE(review): assumes `Embedding` exposes `tok_emb.weight` with shape
        # (d_vocab, d_model) -- confirm against transformer.layers.
        # The linear layer is created without its own bias because the
        # dedicated `decoder_bias` below is the single output bias; keeping
        # both would add two redundant biases to every logit.
        self.decoder = nn.Linear(d_model, d_vocab, bias=False)
        self.decoder.weight = self.embedding.tok_emb.weight
        self.decoder_bias = nn.Parameter(torch.zeros(d_vocab))

    def forward(self, input_ids, input_mask_pos):
        """Return vocabulary logits for the requested positions.

        Args:
            input_ids: integer token ids, one row per sentence.
            input_mask_pos: integer (long) tensor of positions along the
                sentence axis to predict, one row per sentence.

        Returns:
            Logits whose last dimension is ``d_vocab``, with one vector per
            entry of ``input_mask_pos``.
        """
        mask = get_attn_mask(input_ids)
        out = self.embedding(input_ids)
        for layer in self.layers:
            # Each block also returns its attention weights; they are unused here.
            out, _ = layer(out, mask)

        # Broadcast every position index across the hidden dimension so that
        # `gather` picks the complete hidden vector at that position.
        masked_pos = input_mask_pos[:, :, None].expand(-1, -1, out.size(-1))
        h_masked = torch.gather(out, 1, masked_pos)
        h_masked = self.norm(self.gelu(self.linear(h_masked)))

        logits = self.decoder(h_masked) + self.decoder_bias

        return logits

In [None]:
# Hyperparameters for the BERT smoke test.
d_vocab = 10          # vocabulary size
d_model = 6           # model width
d_v = d_model         # value width, kept equal to the model width
d_sentence = 8        # tokens per sentence
n_layers = 4          # number of encoder blocks
n_heads = 5           # attention heads per block
d_k = 7               # dim of W_k
d_ff = d_model * 4    # feed-forward width
d_batch = 2           # sentences per batch

d_pred_max = 3        # prediction slots per sentence

In [None]:
# Random token ids, plus the positions whose tokens are to be predicted.
input_ids = torch.randint(0, d_vocab, (d_batch, d_sentence))

# Only the first two slots of the first two rows are filled (with positions
# 1, 2 and 3, 4); every other slot stays at position 0.
input_mask_pos = torch.zeros(d_batch, d_pred_max, dtype=torch.long)
for i in range(2):
    for j in range(2):
        input_mask_pos[i, j] = 2 * i + j + 1
input_mask_pos

In [None]:
# Instantiate the model; keyword arguments spell out which dimension goes where.
model = BERT(
    d_vocab=d_vocab,
    d_model=d_model,
    d_sentence=d_sentence,
    n_layers=n_layers,
    n_heads=n_heads,
    d_k=d_k,
    d_v=d_v,
    d_ff=d_ff,
)

In [None]:
# Call the module itself rather than `.forward(...)` so that
# `nn.Module.__call__` (and any registered hooks) is exercised.
out = model(input_ids, input_mask_pos)

# One vocabulary-sized logit vector per prediction slot.
assert out.shape == torch.Size((d_batch, d_pred_max, d_vocab))

In [None]:
import torch.optim as optim
import torch.nn as nn

In [None]:
# Loss and optimizer for training the model defined above.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)