In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import pandas as pd
from torch.utils.data import TensorDataset, DataLoader

In [2]:
n_embd = 256  # d_model
block_size = 16
n_head = 4
n_layer = 6
dropout = 0.2

In [198]:
df = pd.read_parquet('train.parquet')
df = df['translation'].apply(lambda r: pd.Series([r['en'].lower(), r['fr'].lower()]))
df.columns = 'en', 'fr'
df = df[df.en.str.len() < 30]
df = df.apply(lambda s: s.str.strip().str.split()).reset_index(drop=True)
print(f"[EN] max_length: {df.en.str.len().max()}, [FR] max_length: {df.fr.str.len().max()}")
df

[EN] max_length: 6, [FR] max_length: 10


Unnamed: 0,en,fr
0,[dessert],[dessert]
1,"[white, rice,, cooked,, unsalted]","[riz, blanc,, cuit,, non, salé]"
2,"[cashew,, dry, roasted,, unsalted]","[noix, de, cajou,, grillée, à, sec,, non, salée]"
3,"[burger, sauce]","[sauce, burger]"
4,"[horse, mackerel,, oily,, raw]","[chinchard,, gras,, cru]"
...,...,...
1610,"[fruit, jelly]","[pâte, de, fruits]"
1611,"[celery, stalk]","[céleri, branche]"
1612,"[agar, (seaweed),, raw]","[agar, (algue),, cru]"
1613,"[chilli, pepper,, raw]","[piment,, cru]"


In [4]:
from torchtext.vocab import Vocab, build_vocab_from_iterator

specials = ['<pad>', '<unk>', '<sos>', '<eos>']
PAD, UNK, SOS, EOS = specials
vocab_en = build_vocab_from_iterator(df.en, specials=specials, max_tokens=2048)
vocab_fr = build_vocab_from_iterator(df.fr, specials=specials, max_tokens=2048)
vocab_en.set_default_index(vocab_en[UNK])
vocab_fr.set_default_index(vocab_fr[UNK])
vocab_en_size, vocab_fr_size = len(vocab_en), len(vocab_fr)
vocab_en_size, vocab_fr_size

(1425, 1537)

In [5]:
def text_to_tensor(text, vocab, add_sos=True):
    tokenized_text = [SOS] + text + [EOS] if add_sos else text
    tensor = torch.zeros(block_size).long()
    tensor[:len(tokenized_text)] = torch.as_tensor(vocab.lookup_indices(tokenized_text))
    return tensor.unsqueeze(0)

tokens_en = df.en.apply(lambda x: text_to_tensor(x, vocab_en, add_sos=False)).tolist()
tokens_en = torch.cat(tokens_en, 0)

tokens_fr = df.fr.apply(lambda x: text_to_tensor(x, vocab_fr, add_sos=True)).tolist()
tokens_fr = torch.cat(tokens_fr, 0)

In [6]:
tokens_en.shape, tokens_fr.shape

(torch.Size([1615, 16]), torch.Size([1615, 16]))

---

In [7]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, Xq, Xk, Xv, mask=None):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)

        # B, T, C = Xq.shape
        q = self.query(Xq) # (B,T,hs)
        k = self.key(Xk)   # (B,T,hs)
        v = self.value(Xv) # (B,T,hs)

        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)

        if mask is not None:
            wei = wei.masked_fill(mask, float('-inf')) # (B, T, T)
        
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        
        # perform the weighted aggregation of the values
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

#------------------------------------
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, n_embd):
        super().__init__()
        head_size = n_embd // n_head

        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *parts, mask=None):
        Xq, Xk, Xv = parts if len(parts) == 3 else parts * 3
        out = torch.cat([h(Xq, Xk, Xv, mask=mask) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

#------------------------------------
class EncoderBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()

        self.mha = MultiHeadAttention(n_head, n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4), 
            nn.ReLU(), 
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout)
        )

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, mask=None):
        x = x + self.mha(self.ln1(x), mask=mask)
        x = x + self.ffwd(self.ln2(x))
        return x
    
#------------------------------------
class DecoderBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()

        self.mha1 = MultiHeadAttention(n_head, n_embd)
        self.mha2 = MultiHeadAttention(n_head, n_embd)

        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4), 
            nn.ReLU(), 
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout)
        )

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)

    def forward(self, y, x, pad_mask=None, future_mask=None):
        y = y + self.mha1(self.ln1(y), mask=future_mask)
        x = y + self.mha2(self.ln2(y), x, x, mask=pad_mask)
        x = x + self.ffwd(self.ln3(y))
        return x

#------------------------------------
class Transformer(nn.Module):
    def __init__(self, vocab_src_size, vocab_tgt_size):
        super().__init__()

        self.embeding_src = nn.Embedding(vocab_src_size, n_embd, padding_idx=0)
        self.embeding_tgt = nn.Embedding(vocab_tgt_size, n_embd, padding_idx=0)
        self.encoders = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.decoders = nn.ModuleList([DecoderBlock(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_tgt_size)

        self.register_buffer('triu_mask', torch.tril(torch.ones(block_size, block_size)) == 0)
        self.register_buffer('pos_enc', self._positional_encoding(n_embd, max_len=block_size))
    
    def forward(self, src, tgt):
        # B, T = src.shape
        pad_mask = src.eq(0).unsqueeze(1).repeat(1, block_size, 1)
        future_mask = torch.logical_or(pad_mask, self.triu_mask)

        # idx and targets are both (B,T) tensor of integers
        src_emb = self.embeding_src(src) # (B,T,C)
        tgt_emb = self.embeding_tgt(tgt) # (B,T,C)

        x = src_emb + self.pos_enc # (B,T,C) + (T,C)
        y = tgt_emb + self.pos_enc # (B,T,C) + (T,C)

        # Encoder Blocks
        for encoder in self.encoders:
            x = encoder(x, mask=pad_mask) # (B,T,C)

        # Decoder Blocks
        for decoder in self.decoders:
            y = decoder(y, x, pad_mask=pad_mask, future_mask=future_mask) # (B,T,C)

        y = self.ln_f(y) # (B,T,C)
        logits = self.lm_head(y) # (B,T,vocab_size)
        return logits

    def _positional_encoding(self, d_model, max_len=1000):
        pos = torch.arange(max_len).view(-1, 1).float()
        pe = torch.arange(d_model // 2).repeat_interleave(2).repeat(max_len, 1).float()
        pe[:, 0::2] = torch.sin(pos / (10000 ** (2 * pe[:, 0::2] / d_model)))
        pe[:, 1::2] = torch.cos(pos / (10000 ** (2 * pe[:, 1::2] / d_model)))
        return pe

In [8]:
EPOCHS = 10
BATCH_SIZE = 128

In [9]:
train_ds = TensorDataset(tokens_en, tokens_fr)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
model = Transformer(vocab_en_size, vocab_fr_size)
optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), amsgrad=True)
criterion = nn.CrossEntropyLoss()

In [11]:
def transfomer_train():
    for i in range(1, EPOCHS + 1):
        accs = []
        for i_batch, (src, tgt) in enumerate(train_dl, start=1):
            out = model(src, tgt)
            out = out[:, :-1, :]
            tgt = tgt[:, 1:]
            loss = criterion(out.permute(0, 2, 1), tgt)
            acc = (out.argmax(-1) == tgt).float().mean()
            accs.append(acc.item())
            
            optim.zero_grad()
            loss.backward()
            optim.step()

            print(f"EPOCH: {i:>02}   BATCH: {i_batch:>02}   Acc: {acc.item():.4f}   Loss: {loss.item():.4f}")
        
        print('-' * 50)
        print(f"EPOCH {i:>02} Full Accuracy: {np.mean(accs):.4f}")
        print('-' * 50)

In [13]:
TRAIN = True
PATH = f'weights/weight-transformer-{n_embd:>03}embd.pt'

if TRAIN:
    transfomer_train()
    torch.save(model.state_dict(), PATH)
else:
    model = Transformer(vocab_en_size, vocab_fr_size)
    model.load_state_dict(torch.load(PATH))

EPOCH: 01   BATCH: 01   Acc: 0.8214   Loss: 1.0232
EPOCH: 01   BATCH: 02   Acc: 0.8266   Loss: 1.0424
EPOCH: 01   BATCH: 03   Acc: 0.8333   Loss: 0.9743
EPOCH: 01   BATCH: 04   Acc: 0.8271   Loss: 1.0023
EPOCH: 01   BATCH: 05   Acc: 0.8302   Loss: 0.9950
EPOCH: 01   BATCH: 06   Acc: 0.8224   Loss: 1.0599
EPOCH: 01   BATCH: 07   Acc: 0.8234   Loss: 1.0107
EPOCH: 01   BATCH: 08   Acc: 0.8208   Loss: 1.0203
EPOCH: 01   BATCH: 09   Acc: 0.8203   Loss: 1.0566
EPOCH: 01   BATCH: 10   Acc: 0.8281   Loss: 1.0114
EPOCH: 01   BATCH: 11   Acc: 0.8375   Loss: 0.9597
EPOCH: 01   BATCH: 12   Acc: 0.8260   Loss: 0.9834
EPOCH: 01   BATCH: 13   Acc: 0.8245   Loss: 1.0027
--------------------------------------------------
EPOCH 01 Full Accuracy: 0.8263
--------------------------------------------------
EPOCH: 02   BATCH: 01   Acc: 0.8302   Loss: 0.9517
EPOCH: 02   BATCH: 02   Acc: 0.8417   Loss: 0.9506
EPOCH: 02   BATCH: 03   Acc: 0.8328   Loss: 0.9599
EPOCH: 02   BATCH: 04   Acc: 0.8406   Loss: 0.9214


In [197]:
def  eng_to_fr(english_phrase):
    model.eval()

    phrase_en = f"{english_phrase} {f'{PAD} ' * block_size}".strip()
    tokens_en = vocab_en.lookup_indices(phrase_en.split()[:block_size]) 
    tokens_en = torch.tensor(tokens_en).long()

    phrase_fr = ' '.join([SOS] + [PAD] * (block_size - 1))
    tokens_fr = vocab_fr.lookup_indices(phrase_fr.split()[:block_size]) 
    tokens_fr = torch.tensor(tokens_fr).long()

    for i in range(1, block_size + 1):
        print(f"Iter {i} - {phrase_fr}")
        
        if EOS in phrase_fr: break

        with torch.no_grad():
            tokens_fr = vocab_fr.lookup_indices(phrase_fr.split()) 
            tokens_fr = torch.tensor(tokens_fr).long()
            logits = model(tokens_en.unsqueeze(0), tokens_fr.unsqueeze(0))
            tokens_out = logits.squeeze().argmax(-1).tolist()
            phrase_fr = vocab_fr.lookup_tokens([vocab_fr[SOS]] + tokens_out)
            phrase_fr = ' '.join(phrase_fr[:block_size])

eng_to_fr("duck mousse")

Iter 1 - <sos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Iter 2 - <sos> mousse de <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Iter 3 - <sos> mousse de canard <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Iter 4 - <sos> mousse de canard <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


In [200]:
eng_to_fr("white rice")

Iter 1 - <sos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Iter 2 - <sos> riz de <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Iter 3 - <sos> riz blanc riz <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Iter 4 - <sos> riz blanc <eos> <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
