<a href="https://colab.research.google.com/github/eissana/translator/blob/master/translator2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install torchtext==0.6

In [None]:
%pip install --upgrade spacy

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.data import Field, BucketIterator
from torchtext.datasets import Multi30k
from torchtext.data.metrics import bleu_score
import spacy
import numpy as np

In [25]:
def get_loss(logits, y, ignore_index):
    """
    Computes cross-entropy loss, given logits and labels.
    """
    B, T, C = logits.shape
    # F.cross_entropy expects size C, (B, C), or (B, C, ...)
    # logits shape is (B, T, C), so we flatten the first two dimensions.
    return F.cross_entropy(
        logits.view(B * T, C), y.reshape(B * T), ignore_index=ignore_index
    )
    # loss_fn = nn.CrossEntropyLoss(ignore_index=tgt_pad_index)
    # return loss_fn(logits.view(B*T, C), y.reshape(B*T))


def text2tokens(text, tokenizer):
    tokens = [Preprocessor.INIT_TOKEN]
    tokens.extend([t.text.lower() for t in tokenizer(text)])
    tokens.append(Preprocessor.EOS_TOKEN)
    return tokens


def src2target(tokens, model, preprocessor, block_size, device):
    """
    Gets source language tokens, calls model to translate them, and returns
    target tokens.
    """
    token_ids = [preprocessor.src_field.vocab.stoi[token] for token in tokens]

    x = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)

    sos = preprocessor.tgt_field.vocab.stoi[Preprocessor.INIT_TOKEN]
    eos = preprocessor.tgt_field.vocab.stoi[Preprocessor.EOS_TOKEN]

    y = torch.tensor([[sos]], dtype=torch.long, device=device)

    for _ in range(block_size):
        with torch.no_grad():
            logits = model(x, y)

        logits = logits[:, -1, :]

        scores = F.softmax(logits, dim=-1)
        next_token = scores.multinomial(1)
        # next_token = logits.argmax(dim=-1).unsqueeze(0)

        if next_token.item() == eos:
            break

        y = torch.cat((y, next_token), dim=-1)

    y = y.view(-1)
    y = [preprocessor.tgt_field.vocab.itos[t] for t in y]

    return y[1:]


def translate(text, model, preprocessor, block_size, device):
    src_tokens = text2tokens(text, preprocessor.src_spacy.tokenizer)
    tgt_tokens = src2target(src_tokens, model, preprocessor, block_size, device)
    return " ".join(tgt_tokens)


def bleu(data, model, preprocessor, block_size, device):
    targets = []
    outputs = []

    for example in data:
        src = example.src
        trg = example.trg
        prediction = src2target(src, model, preprocessor, block_size, device)

        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)


In [None]:
%%python3 -m spacy download en_core_web_sm

In [None]:
%%python3 -m spacy download de_core_news_sm

In [None]:
!wget https://raw.githubusercontent.com/eissana/translator/master/multi30k.sh
!sh multi30k.sh

In [8]:
class Preprocessor():
    INIT_TOKEN = "<init>"
    EOS_TOKEN = "<eos>"  # end of sentence

    def __init__(self, spacy_names, exts, data_root, min_freq, max_size):
        self.src_spacy = spacy.load(spacy_names[0])
        self.tgt_spacy = spacy.load(spacy_names[1])

        def src_tokenize(text):
            return [t.text for t in self.src_spacy.tokenizer(text)]

        def tgt_tokenize(text):
            return [t.text for t in self.tgt_spacy.tokenizer(text)]

        self.src_field = Field(
            tokenize=src_tokenize,
            init_token=self.INIT_TOKEN,
            eos_token=self.EOS_TOKEN,
            lower=True,
        )
        self.tgt_field = Field(
            tokenize=tgt_tokenize,
            init_token=self.INIT_TOKEN,
            eos_token=self.EOS_TOKEN,
            lower=True,
        )

        # Run the following command to download data:
        # > sh multi30k.sh
        self.train, self.val, self.test = Multi30k.splits(
            exts=exts,
            fields=(self.src_field, self.tgt_field),
            root=data_root,  # data/multi30k/
        )
        self.src_field.build_vocab(self.train, max_size=max_size, min_freq=min_freq)
        self.tgt_field.build_vocab(self.train, max_size=max_size, min_freq=min_freq)


In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"running on {device}")
model_filename = "models/model.pt"

example_text = "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche."

load_model = True
save_model = True

params = {
    "epochs": 10,
    "learning_rate": 3.0e-4,
    "batch_size": 32,
    "embedding_dim": 512,
    "nhead": 8,
    "num_encoder_layers": 3,
    "num_decoder_layers": 3,
    "dropout": 0.1,
    "block_size": 100,
    "dim_feedforward": 4,
}

losses = {
    "train": [],
    "val": [],
}

running on cuda


In [32]:
class Head(nn.Module):
    '''
    Self-attention head layer.
    '''
    def __init__(self, head_size, params, use_mask):
        super().__init__()

        embedding_dim = params['embedding_dim']
        block_size = params['block_size']

        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(params['dropout'])

        self.use_mask = use_mask
        if use_mask:
          # tril is not a model parameter so we register it as a buffer.
          # block_size is the maximum size. The actual size can be smaller.
          self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, v, k, q):
        _, T, C = q.shape
        value, key, query = self.value(v), self.key(k), self.query(q)
        weights = query @ key.transpose(-2, -1) * C**-0.5

        if self.use_mask:
          # The time dimension can be smaller than the block-size.
          weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        out = weights @ value
        return out


class MultiHead(nn.Module):
    def __init__(self, params, use_mask, device):
        super().__init__()
        self.device = device

        embedding_dim = params['embedding_dim']
        nhead = params['nhead']
        assert embedding_dim % nhead == 0, f"{embedding_dim=} must be divisible by {nhead=}"
        head_size = embedding_dim // nhead

        self.ln = nn.LayerNorm(embedding_dim)
        self.heads = nn.ModuleList([
            Head(head_size, params, use_mask) for _ in range(nhead)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(params['dropout'])

    def forward(self, v, k, q):
        v, k, q = self.ln(v), self.ln(k), self.ln(q)
        out = torch.cat([head(v, k, q) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, params):
        super().__init__()

        embedding_dim = params['embedding_dim']
        dim_feedforward = params['dim_feedforward']

        # feed-forward network
        self.ffn = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, dim_feedforward * embedding_dim),
            nn.ReLU(),
            nn.Linear(dim_feedforward * embedding_dim, embedding_dim),  # projection
            nn.Dropout(params['dropout'])
        )

    def forward(self, x):
        out = self.ffn(x)
        return out


class EncoderBlock(nn.Module):
    def __init__(self, params, device):
        super().__init__()

        embedding_dim = params['embedding_dim']
        nhead = params['nhead']

        # multi-head self attention with no mask. All nodes are allowed to
        # communicate freely.
        self.attn = MultiHead(params, use_mask=False, device=device)
        self.ffn = FeedForward(params)

    def forward(self, v, k, q):
        out = q + self.attn(v, k, q)
        out = out + self.ffn(out)
        return out


class DecoderBlock(nn.Module):
    def __init__(self, params, device):
        super().__init__()

        embedding_dim = params['embedding_dim']

        # multi-head self attention with triangular mask. Nodes communicate only
        # with previous nodes.
        self.attn = MultiHead(params, use_mask=True, device=device)
        # Reusing Encoder as the top part of the decoder with a multi-head
        # cross-attention and a feed-forward network on top of it.
        self.attn_ffn = EncoderBlock(params, device)

    def forward(self, enc_out, dec_in):
        out = dec_in
        out = out + self.attn(out, out, out)
        out = out + self.attn_ffn(enc_out, enc_out, out)
        return out


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, device, params):
        super().__init__()
        self.device = device

        embedding_dim = params['embedding_dim']
        block_size = params['block_size']
        num_encoder_layers = params['num_encoder_layers']
        num_decoder_layers = params['num_decoder_layers']
        dropout = params['dropout']

        self.src_emb = nn.Embedding(src_vocab_size, embedding_dim)
        self.src_pos = nn.Embedding(block_size, embedding_dim)

        self.tgt_emb = nn.Embedding(tgt_vocab_size, embedding_dim)
        self.tgt_pos = nn.Embedding(block_size, embedding_dim)

        self.encoders = nn.ModuleList(
            [EncoderBlock(params, device) for _ in range(num_encoder_layers)])
        self.decoders = nn.ModuleList(
            [DecoderBlock(params, device) for _ in range(num_decoder_layers)])

        self.proj = nn.Linear(embedding_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt):
        _, srcT = src.shape
        src_positions = torch.arange(srcT).unsqueeze(0).to(self.device)
        src_out = self.src_emb(src) + self.src_pos(src_positions)
        src_out = self.dropout(src_out)

        _, tgtT = tgt.shape
        tgt_positions = torch.arange(tgtT).unsqueeze(0).to(self.device)
        tgt_out = self.tgt_emb(tgt) + self.tgt_pos(tgt_positions)
        tgt_out = self.dropout(tgt_out)

        for encoder in self.encoders:
            src_out = encoder(src_out, src_out, src_out)

        for decoder in self.decoders:
            tgt_out = decoder(src_out, tgt_out)

        tgt_out = self.proj(tgt_out)
        tgt_out = self.dropout(tgt_out)

        return tgt_out


In [40]:
import warnings
warnings.filterwarnings('ignore')

pp = Preprocessor(
    spacy_names=("de_core_news_sm", "en_core_web_sm"),
    exts=(".de", ".en"),
    data_root="data",
    min_freq=2,
    max_size=10000,
)

src_vocab_size = len(pp.src_field.vocab)
tgt_vocab_size = len(pp.tgt_field.vocab)

train_iter, val_iter, _ = BucketIterator.splits(
    datasets=(pp.train, pp.val, pp.test),
    batch_size=params["batch_size"],
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device,
)

model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    device=device,
    params=params,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=params["learning_rate"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10
)

if load_model:
    state = torch.load(model_filename)

    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])

num_params = sum([p.nelement() for p in model.parameters()])
print(f"\nmodel parameters: {num_params}")
print(f"\n{params=}")

print(f"\nexample text to trasnlate: {example_text}")

answer = input("\nwould you like to proceed to training? (y/n): ")
if answer.lower() in {"y", "yes"}:
    for epoch in range(params["epochs"]):
        print(f"epoch {epoch} / {params['epochs']}")

        model.eval()
        with torch.no_grad():
          translated_text = translate(
              example_text,
              model,
              pp,
              params["block_size"],
              device,
          )
          print(f"translated example text:\n{translated_text}")

          batch_loss = []
          for batch in val_iter:
              src = batch.src.T.to(device)
              tgt = batch.trg.T.to(device)

              logits = model(src, tgt[:, :-1])
              vloss = get_loss(
                  logits,
                  tgt[:, 1:],
                  ignore_index=pp.tgt_field.vocab.stoi[pp.tgt_field.pad_token],
                )
              batch_loss.append(vloss.item())

          losses['val'].extend(batch_loss)

        model.train()
        batch_loss = []
        for batch in train_iter:
            src = batch.src.T.to(device)
            tgt = batch.trg.T.to(device)

            logits = model(src, tgt[:, :-1])
            loss = get_loss(
                logits,
                tgt[:, 1:],
                ignore_index=pp.tgt_field.vocab.stoi[pp.tgt_field.pad_token],
            )
            batch_loss.append(loss.item())

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

        losses['train'].extend(batch_loss)
        if save_model:
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }
            torch.save(checkpoint, model_filename)

        scheduler.step(np.mean(batch_loss))

print(f"\ncomputing bleu score...")
score = bleu(pp.test[:100], model, pp, params["block_size"], device) * 100
print(f"bleu score: {score:0.2f}%")



model parameters: 32218885

params={'epochs': 10, 'learning_rate': 0.0003, 'batch_size': 32, 'embedding_dim': 512, 'nhead': 8, 'num_encoder_layers': 3, 'num_decoder_layers': 3, 'dropout': 0.1, 'block_size': 100, 'dim_feedforward': 4}

example text to trasnlate: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
epoch 0 / 10
translated example text:
two young , white males are outside near many bushes .
epoch 1 / 10
translated example text:
two young , white men are outside near many bushes .
epoch 2 / 10
translated example text:
two young , white males are outside near many bushes .
epoch 3 / 10
translated example text:
two young , white males are outside near many bushes .
epoch 4 / 10
translated example text:
two young , white males are outside near many bushes .
epoch 5 / 10
translated example text:
two young , white males are outside near many bushes .
epoch 6 / 10
translated example text:
two young , white males are outside near many bushes .
epoch 7 / 10
translate

In [41]:
print(f"loss of a random model: {np.log(len(pp.tgt_field.vocab))}")
print(f"final training loss: {np.mean(losses['train'])}")
print(f"final validation loss: {np.mean(losses['val'])}")

loss of a random model: 8.681520484837913
final training loss: 1.1286912155497473
final validation loss: 3.281958619132638
