In [46]:
import torch
import torch.nn as nn
import math


In [48]:
PAD = 0          # token id reserved for padding; never sampled by generate_batch
VOCAB_SIZE = 12  # token ids live in [0, VOCAB_SIZE)
MAX_LEN = 10     # default length of every generated sequence
SEED = 0

# Seed the global RNG so that Restart Kernel -> Run All reproduces the same
# batches and the same parameter initialisation.
torch.manual_seed(SEED)


def generate_batch(batch_size=32, seq_len=None, vocab_size=None):
    """Sample a random batch for the sequence-reversal task.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: length of each sequence; defaults to the module-level MAX_LEN.
        vocab_size: exclusive upper bound for sampled token ids; defaults to
            the module-level VOCAB_SIZE.

    Returns:
        (src, tgt): two int64 tensors of shape (batch_size, seq_len). `src`
        holds ids drawn uniformly from [1, vocab_size), so PAD (0) never
        appears; `tgt` is `src` reversed along the sequence dimension.
    """
    # Resolve defaults at call time so later edits to the globals are honoured.
    seq_len = MAX_LEN if seq_len is None else seq_len
    vocab_size = VOCAB_SIZE if vocab_size is None else vocab_size

    src = torch.randint(1, vocab_size, (batch_size, seq_len))
    tgt = torch.flip(src, dims=[1])
    return src, tgt


In [50]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings.

    A table `pe` of shape (max_len, d_model) is precomputed once and stored as
    a buffer: even feature columns hold sin(pos * w_i), odd columns hold
    cos(pos * w_i), with w_i = 10000 ** (-2i / d_model).

    Args:
        d_model: feature dimension of the embeddings (even or odd).
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angle matrix instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return `x` plus the encodings of its first `x.size(1)` positions.

        `x` is indexed as (batch, seq, d_model); the (seq, d_model) slice of
        the table broadcasts over the batch dimension.
        """
        return x + self.pe[:x.size(1)]


In [52]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q, K, V: tensors whose last two dimensions are (length, feature);
            leading dimensions must broadcast under matmul.
        mask: optional tensor broadcastable to the score matrix. Entries equal
            to 0 mark key positions that must receive no attention.

    Returns:
        The attention-weighted combination of `V`. A query row whose keys are
        all masked yields a zero vector rather than NaN.
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is None:
        return torch.matmul(torch.softmax(scores, dim=-1), V)

    blocked = mask == 0
    scores = scores.masked_fill(blocked, float("-inf"))

    # softmax over a row of all -inf is NaN. Neutralise such rows before the
    # softmax (keeps forward and backward finite) and zero their weights after.
    dead_rows = blocked.all(dim=-1, keepdim=True)
    scores = scores.masked_fill(dead_rows, 0.0)

    attn = torch.softmax(scores, dim=-1).masked_fill(dead_rows, 0.0)
    return torch.matmul(attn, V)


In [54]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from four d_model -> d_model projections.

    The query, key and value inputs are projected, cut into `num_heads` slices
    of width `d_k = d_model // num_heads`, attended per head, re-joined and
    passed through the output projection.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        n_batch, n_tokens, _ = x.size()
        per_head = x.view(n_batch, n_tokens, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from `Q` over `K`/`V`; `mask` is forwarded unchanged to the
        attention function, where it is applied to the per-head scores."""
        q_heads, k_heads, v_heads = (
            self.split_heads(project(inp))
            for project, inp in ((self.W_q, Q), (self.W_k, K), (self.W_v, V))
        )

        context = scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)

        # Undo split_heads: (batch, heads, seq, d_k) -> (batch, seq, heads * d_k)
        n_batch, _, n_tokens, _ = context.size()
        merged = context.transpose(1, 2).reshape(
            n_batch, n_tokens, self.num_heads * self.d_k
        )

        return self.W_o(merged)


In [56]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> ReLU -> Linear.

    Expands from `d_model` to `d_ff` features and projects back to `d_model`.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Return the MLP output; the shape of `x` is preserved."""
        hidden = self.linear1(x)
        activated = hidden.relu()
        return self.linear2(activated)


In [58]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is wrapped as LayerNorm(x + sublayer(x)), i.e. the
    normalisation is applied after the residual addition.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1, self.norm2 = (nn.LayerNorm(d_model) for _ in range(2))

    def forward(self, x, mask):
        """Run the block on `x`; `mask` is passed to the self-attention."""
        attended = self.attn(x, x, x, mask)
        x = self.norm1(x + attended)

        transformed = self.ffn(x)
        return self.norm2(x + transformed)


In [60]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, feed-forward.

    Each of the three sub-layers is wrapped as LayerNorm(x + sublayer(x)).
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)

        self.norm1, self.norm2, self.norm3 = (
            nn.LayerNorm(d_model) for _ in range(3)
        )

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Run the block on decoder states `x`.

        `tgt_mask` restricts the self-attention; `src_mask` restricts the
        attention whose keys and values are `enc_out`.
        """
        self_ctx = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self_ctx)

        cross_ctx = self.enc_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + cross_ctx)

        return self.norm3(x + self.ffn(x))


In [62]:
class Transformer(nn.Module):
    """Encoder-decoder model with `N` layers on each side.

    Source and target tokens share one embedding table and one positional
    encoding; a final linear layer maps decoder states to `vocab_size` logits.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, d_ff=2048, N=6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)

        encoder_layers = [EncoderLayer(d_model, num_heads, d_ff) for _ in range(N)]
        self.encoder = nn.ModuleList(encoder_layers)

        decoder_layers = [DecoderLayer(d_model, num_heads, d_ff) for _ in range(N)]
        self.decoder = nn.ModuleList(decoder_layers)

        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return logits over the vocabulary for every position of `tgt`.

        `src` and `tgt` are token-id tensors; `src_mask` is used by the encoder
        and by the decoder's encoder attention, `tgt_mask` by the decoder's
        self-attention.
        """
        memory = self.pos(self.embed(src))
        for enc_layer in self.encoder:
            memory = enc_layer(memory, src_mask)

        decoded = self.pos(self.embed(tgt))
        for dec_layer in self.decoder:
            decoded = dec_layer(decoded, memory, src_mask, tgt_mask)

        return self.out(decoded)


In [64]:
def subsequent_mask(size):
    """Build a causal mask of shape (1, 1, size, size).

    Entry [0, 0, i, j] is True when j <= i, i.e. position i may look at itself
    and at earlier positions only.
    """
    idx = torch.arange(size)
    causal = idx.unsqueeze(1) >= idx.unsqueeze(0)
    return causal[None, None]


In [66]:
# Model with its default size, and Adam with explicit betas/eps; the learning
# rate itself is overwritten on every training step.
model = Transformer(vocab_size=VOCAB_SIZE)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    betas=(0.9, 0.98),
    eps=1e-9,
)

def lr_schedule(step, d_model=512, warmup=4000):
    """Learning rate with linear warm-up followed by inverse-sqrt decay.

    Computes d_model**-0.5 * min(step**-0.5, step * warmup**-1.5): the rate
    rises linearly for the first `warmup` steps and then decays as
    1/sqrt(step).

    Args:
        step: 1-based training step. Values below 1 are treated as 1 so that
            a 0-based counter does not raise ZeroDivisionError.
        d_model: model width used to scale the rate.
        warmup: number of warm-up steps.

    Returns:
        The learning rate for `step` as a float.
    """
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

criterion = nn.CrossEntropyLoss(ignore_index=PAD)

# greedy_decode seeds every generated sequence with token id 1, so teacher
# forcing must start from that same symbol. Feeding tgt[:, :-1] directly would
# leave the first target token unpredicted and make decoding start from a
# prefix the model never saw during training.
BOS = 1
LOG_EVERY = 200  # print the loss every this many steps

model.train()
for step in range(1, 2000):
    src, tgt = generate_batch()

    # Decoder input is the target shifted right by one, with BOS in front;
    # the model is asked to predict the whole target sequence.
    bos_column = torch.full((tgt.size(0), 1), BOS, dtype=tgt.dtype, device=tgt.device)
    tgt_in = torch.cat([bos_column, tgt[:, :-1]], dim=1)
    tgt_out = tgt

    src_mask = (src != PAD).unsqueeze(1).unsqueeze(2)
    tgt_mask = subsequent_mask(tgt_in.size(1))

    logits = model(src, tgt_in, src_mask, tgt_mask)
    loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

    # Apply the scheduled learning rate for this step before the update.
    lr = lr_schedule(step)
    for group in optimizer.param_groups:
        group["lr"] = lr

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        print(f"step {step:4d}  lr {lr:.2e}  loss {loss.item():.4f}")


In [67]:
def greedy_decode(model, src, max_len=None, start_symbol=1):
    """Generate one output sequence per source row by greedy argmax decoding.

    The source is encoded once; the decoder stack is then re-run on the
    growing prefix, appending the highest-scoring next token each step.

    Args:
        model: object exposing `embed`, `pos`, `encoder`, `decoder` and `out`
            as used below.
        src: token-id tensor indexed as (batch, src_len); positions equal to
            PAD are masked out of attention.
        max_len: number of tokens to generate; defaults to the module-level
            MAX_LEN.
        start_symbol: token id used as the first decoder input (default 1).

    Returns:
        int64 tensor of shape (batch, 1 + max_len): the start symbol followed
        by the generated tokens.
    """
    steps = MAX_LEN if max_len is None else max_len
    src_mask = (src != PAD).unsqueeze(1).unsqueeze(2)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # The encoder output does not depend on the generated prefix, so it is
        # computed once here instead of on every decoding step.
        enc = model.pos(model.embed(src))
        for layer in model.encoder:
            enc = layer(enc, src_mask)

        ys = torch.full(
            (src.size(0), 1), start_symbol, dtype=torch.long, device=src.device
        )

        for _ in range(steps):
            tgt_mask = subsequent_mask(ys.size(1)).to(src.device)

            dec = model.pos(model.embed(ys))
            for layer in model.decoder:
                dec = layer(dec, enc, src_mask, tgt_mask)

            # Only the last position is needed to choose the next token.
            next_word = model.out(dec[:, -1]).argmax(dim=-1, keepdim=True)
            ys = torch.cat([ys, next_word], dim=1)

    return ys
