In [1]:
# !pip install torchdata
# !pip install -U torchtext
!pip install sacremoses

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 23.7 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=b912e9a18e2098aab8356a59debe162857916527a86688d233dfb6fb536607ca
  Stored in directory: /root/.cache/pip/wheels/87/39/dd/a83eeef36d0bf98e7a4d1933a4ad2d660295a40613079bafc9
Successfully built sacremoses
Installing collected packages: sacremoses
Successfully installed sacremoses-0.0.53


In [2]:
import torch
import torch.nn as nn

import os

import torchtext.transforms as T
from torchtext.data.utils import get_tokenizer

from torchtext.vocab import vocab
from collections import OrderedDict

from torch.utils.data import Dataset, DataLoader

import torch.optim as optim

import matplotlib.pyplot as plt

import math

import torch.nn.functional as F

from tqdm import tqdm

## Dataset Setup

In [3]:
!rm -rf data
!rm -rf *.tgz*
!wget https://www.statmt.org/europarl/v7/fr-en.tgz
!tar xfz *.tgz
!mkdir -p data/training
!mv *.en *.fr data/training

--2022-11-07 02:47:04--  https://www.statmt.org/europarl/v7/fr-en.tgz
Resolving www.statmt.org (www.statmt.org)... 129.215.197.184
Connecting to www.statmt.org (www.statmt.org)|129.215.197.184|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 202718517 (193M) [application/x-gzip]
Saving to: ‘fr-en.tgz’


2022-11-07 02:50:59 (844 KB/s) - ‘fr-en.tgz’ saved [202718517/202718517]



In [4]:
# TODO: need to clean and speed this up, but for now it works fine
def get_phrases(path, tokenizer, size):
    phrases = []
    en_tokenize = get_tokenizer(tokenizer, language='en')
    fr_tokenize = get_tokenizer(tokenizer, language='fr')
    # get tokenized dataset
    with open(f'{path}/europarl-v7.fr-en.en') as en:
        with open(f'{path}/europarl-v7.fr-en.fr') as fr:
            k = 0
            for phrase in zip(en.readlines(), fr.readlines()):
                en_phrase = en_tokenize(phrase[0].strip())
                fr_phrase = fr_tokenize(phrase[1].strip())
                if len(en_phrase) > 16 or len(fr_phrase) > 16:
                    continue
                phrases.append({
                    'en': en_phrase,
                    'fr': fr_phrase,
                })
                k+=1
                if k >= size: break
    return phrases

def get_vocab(phrases, special_tokens):
    vocab_freq_en = {}
    vocab_freq_fr = {}

    for phrase in phrases:
        for word in phrase['en']:
            vocab_freq_en[word] = vocab_freq_en.get(word, 0) + 1

        for word in phrase['fr']:
            vocab_freq_fr[word] = vocab_freq_fr.get(word, 0) + 1

    vocab_en_ = sorted(vocab_freq_en.keys(), key=lambda x: vocab_freq_en[x], reverse=True)
    vocab_fr_ = sorted(vocab_freq_fr.keys(), key=lambda x: vocab_freq_fr[x], reverse=True)

    

    vocab_en = vocab_en_[:math.floor(0.6*len(vocab_en_))] + special_tokens
    vocab_fr = vocab_fr_[:math.floor(0.6*len(vocab_fr_))] + special_tokens

    vocab_en = vocab(OrderedDict([(word, 1) for word in vocab_en]))
    idx = vocab_en[special_tokens[2]]
    print(idx)
    vocab_en.set_default_index(idx)
    
    vocab_fr = vocab(OrderedDict([(word, 1) for word in vocab_fr]))
    idx = vocab_fr[special_tokens[2]]
    print(idx)
    vocab_fr.set_default_index(vocab_fr[special_tokens[-1]])

    return vocab_en, vocab_fr


def to_idx(phrases, vocab, special_tokens):
    max_len = 0
    idx_phrases = []
    vocab_en, vocab_fr = vocab
    for phrase in phrases:
        phrase['en'] = special_tokens[:1] + phrase['en'] + special_tokens[1:2]
        phrase['fr'] = special_tokens[:1] + phrase['fr'] + special_tokens[1:2]

        if max_len < len(phrase['en']):
            max_len = len(phrase['en'])
        if max_len < len(phrase['fr']):
            max_len = len(phrase['fr'])

    for phrase in phrases:
        en_phrase = phrase['en']
        fr_phrase = phrase['fr']

        en_phrase_idx = T.VocabTransform(vocab_en)(en_phrase)
        fr_phrase_idx = T.VocabTransform(vocab_fr)(fr_phrase)

        idx_phrases.append({
            'en': en_phrase_idx,
            'fr': fr_phrase_idx,
        })
        
    return idx_phrases, max_len

class WMT2014(Dataset):
    def __init__(self, path, special_tokens, ds_len=10_000, show_idx=True):
        self.path = path
        self.show_idx = show_idx
        self.special_tokens = special_tokens
        self.phrases = get_phrases(self.path, 'moses', ds_len)
        self.vocab = get_vocab(self.phrases, self.special_tokens)
        self.idx_phrases, self.max_len = to_idx(self.phrases, self.vocab, self.special_tokens)

    def __getitem__(self, idx) :
        if self.show_idx:
            padded_src = torch.zeros(self.max_len, dtype=torch.int32)
            padded_tgt = torch.zeros(self.max_len, dtype=torch.int32)
            src = torch.tensor(self.idx_phrases[idx]['en'])
            tgt = torch.tensor(self.idx_phrases[idx]['fr'])
            
            padded_src[:src.shape[0]] = src
            padded_tgt[:tgt.shape[0]] = tgt

            return padded_src, padded_tgt
        else:
            src = self.phrases[idx]['en']
            tgt = self.phrases[idx]['fr']
            return src, tgt
        

        return src, tgt

    def __len__(self):
        return len(self.phrases)

## Transformer Model

In [5]:
# TODO: add masking
class MultiHeadAttention(nn.Module):
    def __init__(self, seq_len, d_model, num_heads, mask=False):
        super(MultiHeadAttention, self).__init__()

        self.d_model = d_model
        self.seq_len = seq_len
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def reshape(self, x):
        # [bs, seq_len, num_heads, d_k]
        # split into seperate heads
        out = x.reshape(*x.shape[:2], self.num_heads, self.d_k)
        # [bs, num_heads, seq_len, d_k]
        # swap heads/seq_len dim to be able to do matmul on each head in parallel
        return out.permute(0, 2, 1, 3)

    def attention(self, q, k, v):
        scale = 1 / self.d_k**0.5
        out = q @ k.transpose(-2, -1) * scale
        out = self.softmax(out) @ v
        return out

    def forward(self, q, k, v):
        # [bs, num_heads, seq_len, d_k]
        q = self.reshape(self.w_q(q))
        k = self.reshape(self.w_k(k))
        v = self.reshape(self.w_v(v))

        # [bs, seq_len, num_heads, d_k]
        # compute attention and swap back to orig dims
        attn = self.attention(q, v, k).permute(0, 2, 1, 3)
        # [bs, seq_len, d_model]
        # concatenate heads
        attn = attn.reshape(-1, self.seq_len, self.d_model)

        out = self.w_o(attn)

        return out

class PositionalEncoding(nn.Module):
    def __init__(self, seq_len, emb_dim):
        super(PositionalEncoding, self).__init__()

        self.pe = torch.zeros(seq_len, emb_dim).cuda()
        pos = torch.arange(seq_len).unsqueeze(1).cuda()
        i = torch.arange(0, emb_dim, 2).cuda()
        div = torch.exp(i * -1 * math.log(10_000) / emb_dim).cuda()
        self.pe[:, 0::2] = torch.sin(pos / div).cuda()
        self.pe[:, 1::2] = torch.cos(pos / div).cuda()

        # give it batch dim
        self.pe = self.pe.unsqueeze(0)

    def forward(self, emb):
        return emb + self.pe, self.pe

class EncoderBlock(nn.Module):
    def __init__(self, seq_len, d_model, num_heads):
        super(EncoderBlock, self).__init__()

        norm_shape = (seq_len, d_model)

        self.attn = MultiHeadAttention(seq_len, d_model, num_heads)
        self.ln_1 = nn.LayerNorm(norm_shape)

        self.ff = nn.Sequential(
            nn.Linear(d_model, 2048),
            nn.ReLU(),
            nn.Linear(2048, d_model),
        )

        self.ln_2 = nn.LayerNorm(norm_shape)
    
    def forward(self, src):
        attn = self.attn(src, src, src)
        out_1 = self.ln_1(src + attn)

        ff = self.ff(out_1)
        out_2 = self.ln_2(out_1 + ff)

        return out_2

class DecoderBlock(nn.Module):
    def __init__(self, seq_len, d_model, num_heads):
        super(DecoderBlock, self).__init__()
        
        norm_shape = (seq_len, d_model)

        self.attn_1 = MultiHeadAttention(seq_len, d_model, num_heads)
        self.ln_1 = nn.LayerNorm(norm_shape)

        self.attn_2 = MultiHeadAttention(seq_len, d_model, num_heads)
        self.ln_2 = nn.LayerNorm(norm_shape)

        self.ff = nn.Sequential(
            nn.Linear(d_model, 2048),
            nn.ReLU(),
            nn.Linear(2048, d_model),
        )
        self.ln_3 = nn.LayerNorm(norm_shape)

    # TODO: find better solutions for decoder inputs
    def forward(self, args):
        enc = args[0]
        tgt = args[1]

        attn_1 = self.attn_1(tgt, tgt, tgt)
        out_1 = self.ln_1(enc + attn_1)

        attn_2 = self.attn_2(enc, enc, out_1)
        out_2 = self.ln_2(out_1 + attn_2)

        ff = self.ff(out_2)
        out_3 = self.ln_3(out_2 + ff)

        return enc, out_3

class Transformer(nn.Module):
    def __init__(self, seq_len, vocab_len, d_model, num_blocks, num_heads):
        super(Transformer, self).__init__()

        self.inp_emb = nn.Embedding(vocab_len[0], d_model)
        self.out_emb = nn.Embedding(vocab_len[1], d_model)
        self.pos_enc = PositionalEncoding(seq_len, d_model)

        self.enc = nn.Sequential(
            *[EncoderBlock(seq_len, d_model, num_heads) for _ in range(num_blocks)]
        )

        self.dec = nn.Sequential(
            *[DecoderBlock(seq_len, d_model, num_heads) for _ in range(num_blocks)]
        )

        self.out = nn.Sequential(
            nn.Linear(d_model, vocab_len[1]),
            nn.LogSoftmax(dim=-1)
        )

        self.init_params()

    def init_params(self):
        for name, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt):
        inp_emb, _ = self.pos_enc(self.inp_emb(src))
        enc = self.enc(inp_emb)

        out_emb = self.out_emb(tgt)
        _, dec = self.dec((enc, out_emb))
        out = self.out(dec)

        return out

## Training

In [6]:
ds = WMT2014('data/training', ['<SOS>', '<EOS>', '<UNK>'], ds_len=100_000)

14349
19148


In [7]:
dl = DataLoader(ds, batch_size=128, shuffle=True)

In [8]:
seq_len = ds[0][0].size(0)
vocab_len = len(ds.vocab[0]), len(ds.vocab[1])
d_model = 512
num_blks = 6
num_heads = 8

transformer = Transformer(seq_len, vocab_len, d_model, num_blks, num_heads).to('cuda')
opt = optim.Adam(transformer.parameters(), lr=2e-3, betas=(0.9,0.98))

kl = nn.KLDivLoss()

epochs = 10

for epoch in tqdm(range(epochs)):
    loss_ = 0
    # currently just passing in tgt unmasked & unshifted
    # model currently just guessing same statement for every input
    # TODO: fix these issues
    for batch_num, (src, tgt) in enumerate(dl):
        transformer.train()
        # TODO: find better fix for negative indicies
        src[src<0] = ds.vocab[0]['<UNK>']
        tgt[tgt<0] = ds.vocab[1]['<UNK>']

        src = src.cuda()
        tgt = tgt.cuda()
        
        opt.zero_grad()

        out = transformer(src, tgt)

        tgt_dist = torch.zeros_like(out).cuda()
        for i in range(tgt.size(0)):
            tgt_dist[i, :, :] = F.one_hot(tgt[i,:].long(), num_classes=tgt_dist.size(2))

        loss = kl(out, tgt_dist)
        loss_ = loss.item()
        loss.backward()

        opt.step()
    
    if (epoch+1) % 2 == 0:
        print(loss_)



  "reduction: 'mean' divides the total loss by both the batch size and the support size."
 20%|██        | 2/10 [07:13<28:55, 216.91s/it]

0.0002253243583254516


 40%|████      | 4/10 [14:28<21:44, 217.42s/it]

0.00019704042642842978


 60%|██████    | 6/10 [21:44<14:30, 217.62s/it]

0.0002007755101658404


 80%|████████  | 8/10 [28:59<07:15, 217.67s/it]

0.0002167060156352818


100%|██████████| 10/10 [36:14<00:00, 217.48s/it]

0.0002177873975597322





In [21]:
# TESTING MODEL RUNS

transformer.eval()

src = ds[10000][0].unsqueeze(0).cuda()
tgt = ds[10000][1].unsqueeze(0).cuda()

out = transformer(src, tgt)

out.argmax(-1)

def token_to_str(lang='en'):
    def func(token):
        if lang == 'en':
            out = ds.vocab[0].lookup_token(token)
        elif lang == 'fr':
            out = ds.vocab[1].lookup_token(token)
        
        return out
    
    return func

print('en:', ' '.join(map(token_to_str('en'), src[0])))
print('fr:', ' '.join(map(token_to_str('fr'), out[0].argmax(-1))))

en: <SOS> In our opinion this initiative would be excessively bureaucratic and would not make sense . <EOS> .
fr: <SOS> Il &apos;est , , . . . . . . . . . . . . .


In [None]:
# import torch.nn.functional as F
# tgt = torch.cat([ds[i][1].unsqueeze(0) for i in range(10)], dim=0)
# out = torch.zeros(10, 18, 6500)

# print('tgt.shape', tgt.shape)
# print('out.shape', out.shape)

# for i in range(tgt.size(0)):
    # out[i, :, :] = F.one_hot(tgt[i, :].long(), num_classes=6500)