In [None]:
import numpy as np
import math
import torch
import torch.nn.functional as F
import random
from torch.nn import Linear, Module, LayerNorm, Dropout, ReLU, Embedding, ModuleList, CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm

In [None]:
def fix_seed(seed=42):
    """Seed the global RNGs of `random`, NumPy and PyTorch with the same value.

    Args:
        seed: integer seed passed unchanged to each library.
    """
    # Same order as before: Python stdlib, then NumPy, then PyTorch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [None]:
def attention(queries, keys, values, Q, K, V, mask=None):
    """Single-head scaled dot-product attention.

    Args:
        queries, keys, values: tensors of shape (batch_size, length, emb_size).
        Q, K, V: callables (linear layers emb_size -> emb_size) applied to
            queries, keys and values respectively.
        mask: optional tensor broadcastable to (batch_size, q_length, k_length);
            positions where it equals 0 receive zero attention weight.

    Returns:
        Tensor of shape (batch_size, q_length, emb_size).
    """
    # The scale is taken from the input width, before any projection.
    scale = math.sqrt(queries.size(-1))

    q_proj = Q(queries)
    k_proj = K(keys)
    v_proj = V(values)

    scores = torch.matmul(q_proj, k_proj.transpose(-1, -2)) / scale
    if mask is not None:
        # -inf scores become exactly 0 after the softmax.
        scores = scores.masked_fill(mask == 0, -torch.inf)

    weights = F.softmax(scores, -1)
    return torch.matmul(weights, v_proj)
    

def multi_head_attention(queries, keys, values, Q, K, V, proj, n_heads=8, mask=None):
    """Multi-head scaled dot-product attention.

    Args:
        queries, keys, values: tensors of shape (batch_size, length, emb_size).
        Q, K, V: callables (linear layers emb_size -> emb_size) applied before
            the result is split into heads.
        proj: callable applied to the concatenated head outputs.
        n_heads: number of heads; must divide emb_size.
        mask: optional tensor broadcastable to (batch_size, q_length, k_length);
            positions where it equals 0 receive zero attention weight.

    Returns:
        The result of `proj` on a (batch_size, q_length, emb_size) tensor.

    Raises:
        AssertionError: if emb_size is not divisible by n_heads.
    """
    batch_size = queries.size(0)
    emb_size = queries.size(-1)
    head_emb_size = emb_size // n_heads

    assert emb_size % n_heads == 0

    def split_heads(t):
        # (batch, length, emb) -> (batch, heads, length, head_emb)
        return t.view(batch_size, -1, n_heads, head_emb_size).transpose(1, 2)

    q_heads = split_heads(Q(queries))
    k_heads = split_heads(K(keys))
    v_heads = split_heads(V(values))

    scores = torch.matmul(q_heads, k_heads.transpose(-1, -2)) / math.sqrt(head_emb_size)

    if mask is not None:
        # Insert a head axis so the same mask is shared by every head.
        head_mask = mask.unsqueeze(1)
        scores = scores.masked_fill(head_mask == 0, -torch.inf)

    weights = F.softmax(scores, -1)

    # (batch, heads, q_len, head_emb) -> (batch, q_len, emb)
    merged = torch.matmul(weights, v_heads).transpose(1, 2).contiguous()
    merged = merged.view(batch_size, -1, emb_size)

    return proj(merged)

In [None]:
# Toy inputs for the attention functions: 10 queries attend over 5 key/value
# vectors, batch size 2, embedding width 32.
queries = torch.randn(2, 10, 32)
keys = torch.randn(2, 5, 32)
values = torch.randn(2, 5, 32)

# Four independent 32 -> 32 projections, created in the order Q, K, V, proj.
Q, K, V, proj = (Linear(32, 32) for _ in range(4))

# All-ones mask of shape (batch_size, n_queries, n_keys): nothing is masked.
mask = torch.ones(2, 10, 5)

In [None]:
# Shape check: one 32-dim output vector per query, i.e. (2, 10, 32).
attention(queries, keys, values, Q, K, V, mask).shape

In [None]:
# Shape check with 8 heads (head width 32 / 8 = 4); the output is again (2, 10, 32).
multi_head_attention(queries, keys, values, Q, K, V, proj, 8, mask).shape

In [None]:
# Attention masks are laid out as (batch_size, n_queries, n_keys).
# Show the demo mask's shape instead of referencing undefined names, which
# raised NameError on a fresh kernel.
mask.shape

In [None]:
def make_masks(x, y, pad_id=0):
    """Build the encoder padding mask and the decoder causal+padding mask.

    Args:
        x: source token ids, shape (batch_size, src_length).
        y: target token ids, shape (batch_size, tgt_length).
        pad_id: id of the padding token; positions equal to it are masked out.

    Returns:
        (enc_mask, dec_mask) as uint8 tensors:
        enc_mask has shape (batch_size, 1, src_length) and is 1 where x is not padding;
        dec_mask has shape (batch_size, tgt_length, tgt_length) and is 1 where the
        key position is not in the future of the query and is not padding.
    """
    enc_mask = (x != pad_id).unsqueeze(1)

    tgt_length = y.size(-1)
    # Upper triangle above the diagonal marks "future" positions; invert it.
    # Built on y's device so the `&` below works for non-CPU inputs too.
    future = torch.triu(torch.ones((1, tgt_length, tgt_length), device=y.device), 1)
    dec_mask = ~future.to(torch.bool)
    dec_mask = dec_mask & (y != pad_id).unsqueeze(1)

    return enc_mask.to(torch.uint8), dec_mask.to(torch.uint8)

In [None]:
# Toy padded batch of token ids; 0 is the padding id.
x = torch.tensor(
    [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]],
    dtype=torch.int32,
)
y = torch.tensor(
    [[1, 2, 5, 0, 0],
     [1, 2, 0, 0, 0]],
    dtype=torch.int32,
)

In [None]:
# Displays (enc_mask, dec_mask): shapes (2, 1, 7) and (2, 5, 5), both uint8.
make_masks(x, y)

In [None]:
class MultiHeadAttentionBlock(Module):
    """Multi-head attention followed by dropout, a residual connection and LayerNorm.

    Args:
        emb_size: width of the input and output vectors.
        n_heads: number of attention heads (must divide emb_size).
        dropout_p: dropout probability applied to the attention output.
    """

    def __init__(self, emb_size=512, n_heads=8, dropout_p=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.Q = Linear(emb_size, emb_size)
        self.K = Linear(emb_size, emb_size)
        self.V = Linear(emb_size, emb_size)
        self.proj = Linear(emb_size, emb_size)
        self.layernorm = LayerNorm(emb_size)
        # Honour the constructor argument (previously a hard-coded 0.1).
        self.dropout = Dropout(dropout_p)

    def forward(self, queries, keys, values, mask=None):
        """Attend from `queries` over `keys`/`values`.

        The residual is taken from `queries`, so the output has the same shape
        as `queries`. `mask` is forwarded unchanged to `multi_head_attention`.
        """
        z = multi_head_attention(
            queries, keys, values, self.Q, self.K, self.V, self.proj, self.n_heads, mask
        )
        return self.layernorm(queries + self.dropout(z))

In [None]:
class FCNNBlock(Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear, then
    dropout, a residual connection and LayerNorm.

    Args:
        emb_size: width of the input and output vectors.
        hidden_size: width of the intermediate layer.
        dropout_p: dropout probability applied before the residual sum.
    """

    def __init__(self, emb_size=512, hidden_size=2048, dropout_p=0.1):
        super().__init__()
        self.linear1 = Linear(emb_size, hidden_size)
        self.linear2 = Linear(hidden_size, emb_size)
        self.layernorm = LayerNorm(emb_size)
        self.dropout = Dropout(dropout_p)

    def forward(self, x):
        """Return a tensor with the same shape as `x`."""
        hidden = F.relu(self.linear1(x))
        update = self.dropout(self.linear2(hidden))
        return self.layernorm(x + update)

In [None]:
class EncoderLayer(Module):
    """One encoder layer: self-attention block followed by a feed-forward block.

    Args:
        emb_size: width of the input and output vectors.
        n_heads: number of attention heads.
        fcnn_hidden_size: hidden width of the feed-forward block.
        dropout_p: dropout probability passed to both sub-blocks.
    """

    def __init__(self, emb_size=512, n_heads=8, fcnn_hidden_size=2048, dropout_p=0.1):
        super().__init__()
        self.mha = MultiHeadAttentionBlock(emb_size, n_heads, dropout_p)
        self.fcnn = FCNNBlock(emb_size, fcnn_hidden_size, dropout_p)

    def forward(self, x, mask=None):
        """Self-attend over `x` (queries = keys = values), then apply the feed-forward block."""
        attended = self.mha(x, x, x, mask)
        return self.fcnn(attended)

In [None]:
# Random stand-in for already-embedded input: batch 2, length 10, width 512.
x = torch.randn(2, 10, 512)
encoder_layer = EncoderLayer()

In [None]:
# The layer preserves the input shape: (2, 10, 512).
encoder_layer(x).shape

In [None]:
class Embeddings(Module):
    """Token embeddings plus fixed sinusoidal positional encodings, followed by dropout.

    Args:
        vocab_size: number of distinct token ids.
        emb_size: embedding width.
        max_length: number of positions in the precomputed encoding table.
        dropout_p: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, emb_size=512, max_length=4096, dropout_p=0.1):
        super().__init__()
        self.embeddings = Embedding(vocab_size, emb_size)
        self.dropout = Dropout(dropout_p)

        positions = torch.arange(max_length).unsqueeze(1)   # (max_length, 1)
        pair_index = torch.arange(emb_size // 2)            # (emb_size // 2,)
        angles = positions / torch.pow(10000, 2 * pair_index / emb_size)

        # Even columns hold the sine, odd columns the cosine of the same angle.
        table = torch.zeros(max_length, emb_size)
        table[:, ::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # A buffer is stored with the module but is not a trainable parameter.
        self.register_buffer('pe', table)

    def forward(self, x):
        """Embed token ids `x` of shape (batch_size, length) and add the first `length` position rows."""
        seq_len = x.size(1)
        token_vectors = self.embeddings(x)
        return self.dropout(token_vectors + self.pe[:seq_len])

In [None]:
# Toy padded batch of token ids (ids lie in [0, 6), so vocab_size=6 suffices).
x = torch.tensor(
    [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]],
    dtype=torch.int32,
)

embeddings = Embeddings(6)

In [None]:
# Output has shape (2, 7, 512): one 512-dim vector per token position.
embeddings(x)

In [None]:
class Encoder(Module):
    """Embedding layer followed by a stack of `n_layers` encoder layers.

    Args:
        vocab_size: number of distinct source token ids.
        max_length: size of the positional-encoding table.
        n_layers: number of stacked `EncoderLayer`s.
        emb_size: model width.
        n_heads: attention heads per layer.
        fcnn_hidden_size: hidden width of each feed-forward block.
        dropout_p: dropout probability used throughout.
    """

    def __init__(
        self,
        vocab_size,
        max_length=4096,
        n_layers=6,
        emb_size=512,
        n_heads=8,
        fcnn_hidden_size=2048,
        dropout_p=0.1
    ):
        super().__init__()
        self.embeddings = Embeddings(
            vocab_size, emb_size=emb_size, max_length=max_length, dropout_p=dropout_p
        )
        stack = [
            EncoderLayer(emb_size, n_heads, fcnn_hidden_size, dropout_p)
            for _ in range(n_layers)
        ]
        self.layers = ModuleList(stack)

    def forward(self, x, mask=None):
        """Encode token ids `x`; `mask` is handed to every layer's self-attention."""
        hidden = self.embeddings(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return hidden

In [None]:
# Toy padded batch of source token ids; 0 is the padding id.
x = torch.tensor(
    [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]],
    dtype=torch.int32,
)

encoder = Encoder(6)

In [None]:
# Output has shape (2, 7, 512). No mask is passed here, so padding positions are attended too.
encoder(x)

In [None]:
class DecoderLayer(Module):
    """One decoder layer: self-attention, encoder-decoder attention, feed-forward.

    Args:
        emb_size: model width.
        n_heads: attention heads in each attention block.
        fcnn_hidden_size: hidden width of the feed-forward block.
        dropout_p: dropout probability passed to every sub-block.
    """

    def __init__(self, emb_size=512, n_heads=8, fcnn_hidden_size=2048, dropout_p=0.1):
        super().__init__()
        self.mha_self = MultiHeadAttentionBlock(emb_size, n_heads, dropout_p)
        self.mha_enc_dec = MultiHeadAttentionBlock(emb_size, n_heads, dropout_p)
        self.fcnn = FCNNBlock(emb_size, fcnn_hidden_size, dropout_p)

    def forward(self, h, x, enc_mask=None, dec_mask=None):
        """Run one decoding layer.

        Args:
            h: encoder output, used as keys and values of the second attention.
            x: decoder-side input vectors.
            enc_mask: mask for the encoder-decoder attention.
            dec_mask: mask for the decoder self-attention.
        """
        self_attended = self.mha_self(x, x, x, dec_mask)
        cross_attended = self.mha_enc_dec(self_attended, h, h, enc_mask)
        return self.fcnn(cross_attended)

In [None]:
# Toy padded batches of source (x) and target (y) token ids; 0 is padding.
x = [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]]

y = [[1, 2, 5, 0, 0],
     [1, 2, 0, 0, 0]]

x = torch.tensor(x, dtype=torch.int32)
y = torch.tensor(y, dtype=torch.int32)

# Masks are derived from the token ids, before y is replaced below.
enc_mask, dec_mask = make_masks(x, y)

# A bare DecoderLayer takes vectors rather than token ids, so use random
# stand-in embeddings with the same (batch, length) as the target ids.
y = torch.randn((2, 5, 512))

encoder = Encoder(6)
decoder_layer = DecoderLayer()

# Pass the padding mask so the encoder does not attend to pad positions,
# matching what Transformer.forward does.
h = encoder(x, enc_mask)

In [None]:
# Output has the shape of the decoder input y: (2, 5, 512).
decoder_layer(h, y, enc_mask, dec_mask)

In [None]:
class Decoder(Module):
    """Embedding layer followed by a stack of `n_layers` decoder layers.

    Args:
        vocab_size: number of distinct target token ids.
        max_length: size of the positional-encoding table.
        n_layers: number of stacked `DecoderLayer`s.
        emb_size: model width.
        n_heads: attention heads per attention block.
        fcnn_hidden_size: hidden width of each feed-forward block.
        dropout_p: dropout probability used throughout.
    """

    def __init__(
        self,
        vocab_size,
        max_length=4096,
        n_layers=6,
        emb_size=512,
        n_heads=8,
        fcnn_hidden_size=2048,
        dropout_p=0.1
    ):
        super().__init__()
        self.embeddings = Embeddings(
            vocab_size, emb_size=emb_size, max_length=max_length, dropout_p=dropout_p
        )
        stack = [
            DecoderLayer(emb_size, n_heads, fcnn_hidden_size, dropout_p)
            for _ in range(n_layers)
        ]
        self.layers = ModuleList(stack)

    def forward(self, h, y, enc_mask=None, dec_mask=None):
        """Decode target ids `y` against encoder output `h`; both masks go to every layer."""
        hidden = self.embeddings(y)
        for decoder_layer in self.layers:
            hidden = decoder_layer(h, hidden, enc_mask, dec_mask)
        return hidden

In [None]:
# Toy padded batches of source (x) and target (y) token ids; 0 is padding.
x = [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]]

y = [[1, 2, 5, 0, 0],
     [1, 2, 0, 0, 0]]

x = torch.tensor(x, dtype=torch.int32)
y = torch.tensor(y, dtype=torch.int32)

enc_mask, dec_mask = make_masks(x, y)

encoder = Encoder(6)
decoder = Decoder(6)

# Pass the padding mask so the encoder does not attend to pad positions,
# matching what Transformer.forward does.
h = encoder(x, enc_mask)

In [None]:
# Output has shape (2, 5, 512): one 512-dim vector per target position.
decoder(h, y, enc_mask, dec_mask)

In [None]:
class Transformer(Module):
    """Encoder-decoder Transformer returning the decoder's hidden states.

    Args:
        enc_vocab_size: source vocabulary size.
        dec_vocab_size: target vocabulary size.
        max_length: size of the positional-encoding tables.
        n_layers: number of layers in both the encoder and the decoder.
        emb_size: model width.
        n_heads: attention heads per attention block.
        fcnn_hidden_size: hidden width of each feed-forward block.
        dropout_p: dropout probability used throughout.
    """

    def __init__(
        self,
        enc_vocab_size,
        dec_vocab_size,
        max_length=4096,
        n_layers=6,
        emb_size=512,
        n_heads=8,
        fcnn_hidden_size=2048,
        dropout_p=0.1
    ):
        super().__init__()
        # Everything except the vocabulary size is shared by both halves.
        shared = (max_length, n_layers, emb_size, n_heads, fcnn_hidden_size, dropout_p)
        self.encoder = Encoder(enc_vocab_size, *shared)
        self.decoder = Decoder(dec_vocab_size, *shared)

    def forward(self, x, y, enc_mask=None, dec_mask=None):
        """Encode `x` with `enc_mask`, then decode `y` against the result."""
        memory = self.encoder(x, enc_mask)
        return self.decoder(memory, y, enc_mask, dec_mask)

In [None]:
# Toy padded batches of source (x) and target (y) token ids; 0 is padding.
x = torch.tensor(
    [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]],
    dtype=torch.int32,
)
y = torch.tensor(
    [[1, 2, 5, 0, 0],
     [1, 2, 0, 0, 0]],
    dtype=torch.int32,
)

enc_mask, dec_mask = make_masks(x, y)

transformer = Transformer(6, 6)

In [None]:
# Decoder hidden states, shape (2, 5, 512).
transformer(x, y, enc_mask, dec_mask)

In [None]:
class Seq2SeqModel(Module):
    """Transformer with a linear head producing logits over the target vocabulary.

    Args:
        enc_vocab_size: source vocabulary size.
        dec_vocab_size: target vocabulary size (also the number of output logits).
        max_length: size of the positional-encoding tables.
        n_layers: number of layers in both the encoder and the decoder.
        emb_size: model width.
        n_heads: attention heads per attention block.
        fcnn_hidden_size: hidden width of each feed-forward block.
        dropout_p: dropout probability used throughout.
    """

    def __init__(
        self,
        enc_vocab_size,
        dec_vocab_size,
        max_length=4096,
        n_layers=6,
        emb_size=512,
        n_heads=8,
        fcnn_hidden_size=2048,
        dropout_p=0.1
    ):
        super().__init__()
        self.transformer = Transformer(
            enc_vocab_size,
            dec_vocab_size,
            max_length=max_length,
            n_layers=n_layers,
            emb_size=emb_size,
            n_heads=n_heads,
            fcnn_hidden_size=fcnn_hidden_size,
            dropout_p=dropout_p,
        )
        self.logits = Linear(emb_size, dec_vocab_size)

    def forward(self, x, y, enc_mask=None, dec_mask=None):
        """Return logits of shape (batch_size, tgt_length, dec_vocab_size)."""
        states = self.transformer(x, y, enc_mask, dec_mask)
        return self.logits(states)

In [None]:
# Toy padded batches of source (x) and target (y) token ids; 0 is padding.
x = torch.tensor(
    [[1, 2, 5, 3, 2, 0, 0],
     [1, 3, 5, 0, 0, 0, 0]],
    dtype=torch.int32,
)
y = torch.tensor(
    [[1, 2, 5, 0, 0],
     [1, 2, 0, 0, 0]],
    dtype=torch.int32,
)

enc_mask, dec_mask = make_masks(x, y)

model = Seq2SeqModel(6, 6)

In [None]:
# Logits over the 6-token target vocabulary, shape (2, 5, 6).
model(x, y, enc_mask, dec_mask)

# Train model

In [None]:
class Seq2SeqDataset(Dataset):
    """Synthetic sequence-reversal dataset.

    Each sample is a pair (x, y) of equal-length lists of token ids:
    x = [bos] + tokens + [eos] + padding and y = [bos] + reversed tokens + [eos] + padding.
    Random tokens are drawn from [3, vocab_size) using NumPy's global RNG, and
    every list has length max_length + 2.

    Args:
        n_samples: number of pairs to generate.
        vocab_size: exclusive upper bound for generated token ids.
        min_length: minimum number of random tokens per sample.
        max_length: maximum number of random tokens per sample (inclusive).
        seed: if not None, `fix_seed(seed)` is called before generation.
    """

    def __init__(self, n_samples, vocab_size, min_length=3, max_length=32, seed=None):
        self.pad_id = 0
        self.bos_id = 1
        self.eos_id = 2
        if seed is not None:
            fix_seed(seed)

        self.data = []
        for _ in range(n_samples):
            # Draw the length first, then the tokens, to keep the RNG stream order.
            length = np.random.randint(min_length, max_length + 1)
            tokens = np.random.randint(3, vocab_size, length).tolist()
            padding = [self.pad_id] * (max_length - length)
            source = [self.bos_id, *tokens, self.eos_id, *padding]
            target = [self.bos_id, *reversed(tokens), self.eos_id, *padding]
            self.data.append((source, target))

    def __len__(self):
        """Number of generated pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (x, y) pair stored at `idx`."""
        return self.data[idx]

In [None]:
def collate(data):
    """Turn a list of (x, y) id-list pairs into two int32 tensors.

    Args:
        data: sequence of (x, y) pairs whose lists all have equal length.

    Returns:
        (x_batch, y_batch), each of shape (batch_size, length) and dtype int32.
    """
    sources, targets = zip(*data)
    x_batch = torch.tensor(sources, dtype=torch.int32)
    y_batch = torch.tensor(targets, dtype=torch.int32)
    return x_batch, y_batch

In [None]:
# Training configuration.
vocab_size = 32                 # token ids lie in [0, vocab_size)
train_dataset_size = 20_000     # number of generated training pairs
n_epoch = 2                     # passes over the training set

In [None]:
# Building the dataset with seed=42 calls fix_seed, which also seeds torch,
# so the model initialisation that follows is reproducible as well.
train_dataset = Seq2SeqDataset(train_dataset_size, vocab_size, seed=42)
model = Seq2SeqModel(
    vocab_size,
    vocab_size,
    n_layers=3,
    emb_size=128,
    fcnn_hidden_size=256,
)
optimizer = Adam(model.parameters(), lr=1e-3)
# Per-token losses are kept (no reduction) so the loop can average them itself.
loss_func = CrossEntropyLoss(reduction='none')

In [None]:
# Shuffled mini-batches of 8 samples, converted to tensors by `collate`.
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate,
)

In [None]:
# Teacher-forced training: the decoder sees y[:, :-1] and must predict y[:, 1:].
pad_id = train_dataset.pad_id
eos_id = train_dataset.eos_id
ignore_index = loss_func.ignore_index  # targets with this value contribute no loss

# Make sure dropout is active even if this cell is re-run after model.eval().
model.train()

epoch_loss = []
for i in range(n_epoch):
    losses = []
    print(f'Epoch {i + 1}')
    for x, y in tqdm(dataloader):
        curr_y = y[:, :-1]
        next_y = y[:, 1:].clone()
        # Nothing has to be predicted after EOS or from a padding position.
        next_y[(curr_y == pad_id) | (curr_y == eos_id)] = ignore_index
        valid_targets = next_y != ignore_index

        enc_mask, dec_mask = make_masks(x, curr_y, pad_id)

        logits = model(x, curr_y, enc_mask, dec_mask)
        # CrossEntropyLoss expects the class axis second: (batch, vocab, length).
        token_losses = loss_func(logits.transpose(1, 2), next_y.to(torch.long))
        # Average over real targets. Counting `token_losses > 0` would drop
        # tokens whose loss is exactly 0 and overstate the mean.
        loss = token_losses.sum() / valid_targets.sum()
        losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    epoch_loss.append(np.mean(losses))
    print(f'Loss: {epoch_loss[-1]}')

# Test model

In [None]:
def generate(model, x, max_steps=50, bos_id=1, eos_id=2):
    """Greedily decode an output sequence for a single source sequence.

    Args:
        model: callable taking (x, y, enc_mask, dec_mask) and returning logits
            of shape (1, tgt_length, vocab_size).
        x: source token ids as a flat list of ints.
        max_steps: maximum number of tokens to generate after the start token.
        bos_id: id the output sequence starts with.
        eos_id: id that stops generation once produced.

    Returns:
        List of token ids starting with `bos_id`; it ends with `eos_id` unless
        `max_steps` was exhausted first.
    """
    x = torch.tensor(x, dtype=torch.int32).unsqueeze(0)
    ids = [bos_id]
    with torch.no_grad():
        for _ in range(max_steps):
            # Re-feed everything generated so far; only the last position is read.
            y = torch.tensor(ids, dtype=torch.int32).unsqueeze(0)
            enc_mask, dec_mask = make_masks(x, y)
            logits = model(x, y, enc_mask, dec_mask)
            next_id = logits[0][-1].argmax().item()
            ids.append(next_id)
            if next_id == eos_id:
                break
    return ids

In [None]:
# Test source sequence: BOS (1), the tokens 3..20 in ascending order, EOS (2).
x = [1, *range(3, 21), 2]

In [None]:
# Switch to evaluation mode (turns dropout off) before greedy decoding.
model.eval()
y = generate(model, x)

In [None]:
# The model should emit BOS, the tokens 20..3 in descending order, then EOS.
expected_y = [1, *range(20, 2, -1), 2]
assert y == expected_y