### Defining the Model

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from cmath import sqrt
from positional_encodings.torch_encodings import PositionalEncoding1D, Summer

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, embed_dim, query_size):
        super().__init__()
        self.heads = heads
        self.query_size = int(embed_dim / heads)
        self.lin = nn.Linear(embed_dim, embed_dim)
        self.qkv = nn.ModuleList([copy.deepcopy(self.lin) for _ in range(3)])
        self.ff = nn.Linear(embed_dim, embed_dim)

    def dot_product_attention(self, q, k, v, mask=None):
        scores = q @ torch.transpose(k,-2,-1)
        scores /= 8

        # This mask will be used when the Decoder passes input into its attention layers.
        if mask is not None:
            scores = scores.masked_fill(mask==0, -1e9)

        scores = torch.softmax(scores, dim=-1)
        z = torch.matmul(scores, v)
        return z


    def forward(self, query, key, value, mask=None): # each is of shape [BATCH_SIZE x SEQ_LEN x EMB_DIM]
        #x = torch.cat([x, x, x], dim=-1) # x reshaped to [BATCH_SIZE x SEQ_LEN x 3 * EMB_DIM]
        n_batches = query.size(0)
        mask = mask.unsqueeze(1) if mask is not None else None

        # x projected to [BATCH_SIZE x SEQ_LEN x HEADS * QUERY_SIZE], then reshaped to 
        # [BATCH_SIZE x SEQ_LEN x HEADS x QUERY_SIZE], and finally permuted to 
        # [BATCH_SIZE x HEADS x SEQ_LEN x QUERY_SIZE] for all q, k, v
        q, k, v = [
            qkv(x).view(n_batches, -1, self.heads, self.query_size).permute(0, 2, 1, 3) 
            for qkv, x in zip(self.qkv, (query, key, value))
            ]
        

        z = self.dot_product_attention(q, k, v, mask)
        
        # z made contiguous in memory and transformed from [BATCH_SIZE x HEADS x SEQ_LEN x QUERY_SIZE]
        # to [BATCH_SIZE x SEQ_LEN x HEADS * QUERY_SIZE]
        z = z.transpose(1, 2).contiguous().view(n_batches, -1, self.heads * self.query_size)

        return self.ff(z)

In [3]:
def subsequent_position_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return subsequent_mask == 0

In [4]:
inp = torch.LongTensor([[1,2,3,4],[3,2,5,1]])
mha = MultiHeadAttention(8, 512, 64)
emb = nn.Embedding(10, 512)
mask = subsequent_position_mask(4)
inp = emb(inp)
out = mha(inp, inp, inp, mask)
out.shape

torch.Size([2, 4, 512])

In [5]:
class PositionalFFN(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, 4 * embed_dim)
        self.fc2 = nn.Linear(embed_dim * 4, embed_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)

        return x

In [6]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, heads, embed_dim, dropout=0.1):
        super().__init__()
        query_size = int(embed_dim // heads)
        self.self_attn = MultiHeadAttention(heads, embed_dim, query_size)
        self.ffn = PositionalFFN(embed_dim)
        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.lnorm2 = nn.LayerNorm(embed_dim)
        self.drop1=nn.Dropout(dropout)
        self.drop2=nn.Dropout(dropout)

    def forward(self, x):
        out = self.self_attn(x, x, x)
        out = self.drop1(out)
        norm_out = self.lnorm1(out + x)
        out = self.ffn(norm_out)
        out = self.drop2(out)
        out = self.lnorm2(out + norm_out)
        return out

In [7]:
inp = torch.LongTensor([[1,2,3,4],[3,2,5,1]])
mha = TransformerEncoderLayer(8, 512)
emb = nn.Embedding(10, 512)
inp = emb(inp)
out = mha(inp)
out.shape

torch.Size([2, 4, 512])

In [8]:
class TransformerEncoder(nn.Module):
    def __init__(self, n_encoders, heads, embed_dim):
        super().__init__()
        self.encoder_stack = nn.ModuleList([TransformerEncoderLayer(heads, embed_dim) for _ in range(n_encoders)])

    def forward(self, x):
        for enc in self.encoder_stack:
            x = enc(x)
        return out

In [9]:
inp = torch.LongTensor([[1,2,3,4],[3,2,5,1]])
mha = TransformerEncoder(6, 8, 512)
emb = nn.Embedding(10, 512)
inp = emb(inp)
out = mha(inp)
out.shape

torch.Size([2, 4, 512])

In [10]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, heads, embed_dim, dropout=0.1):
        super().__init__()
        query_size = int(embed_dim // heads)
        self.self_attn = MultiHeadAttention(heads, embed_dim, query_size)
        self.enc_dec_attn = MultiHeadAttention(heads, embed_dim, query_size)
        self.ffn = PositionalFFN(embed_dim)
        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.lnorm2 = nn.LayerNorm(embed_dim)
        self.lnorm3 = nn.LayerNorm(embed_dim)
        self.drop1=nn.Dropout(dropout)
        self.drop2=nn.Dropout(dropout)
        self.drop3=nn.Dropout(dropout)

    def forward(self, x, m, src_mask, tgt_mask):
        out = self.self_attn(x, x, x, src_mask)
        out = self.drop1(out)
        norm_out = self.lnorm1(out + x)

        out = self.enc_dec_attn(norm_out, m, m, tgt_mask)
        out = self.drop2(out)
        norm_out = self.lnorm2(out + norm_out)
        
        out = self.ffn(norm_out)
        out = self.drop3(out)
        out = self.lnorm3(out + norm_out)
        return out

In [11]:
class TransformerDecoder(nn.Module):
    def __init__(self, n_decoders, heads, embed_dim):
        super().__init__()
        self.decoder_stack = nn.ModuleList([TransformerDecoderLayer(heads, embed_dim) for _ in range(n_decoders)])

    def forward(self, x, m, src_mask, tgt_mask):
        for dec in self.decoder_stack:
            x = dec(x, m, src_mask, tgt_mask)
        return x

In [12]:
class TransformerEmbeddings(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embed_dim)
        self.d_model = embed_dim

    def forward(self, x):
        x = self.emb(x) * math.sqrt(self.d_model)
        return x

In [13]:
class ClassicTransformer(nn.Module):
    def __init__(self, n_enc, n_dec, vocab_size, heads, embed_dim, output_head, dropout=0.1):
        super().__init__()
        # produces matrix of [BATCH_SIZE x SEQ_LEN x EMB_SIZE]
        self.source_emb = TransformerEmbeddings(vocab_size, embed_dim)
        self.target_emb = TransformerEmbeddings(vocab_size, embed_dim)
        
        self.pos = Summer(PositionalEncoding1D(embed_dim))
        
        self.enc = TransformerEncoder(n_encoders=n_enc, heads=8, embed_dim=embed_dim)
        self.dec = TransformerDecoder(n_decoders=n_dec, heads=8, embed_dim=embed_dim)
        
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)

        self.output_head = output_head
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        source = self.pos(self.source_emb(src))
        source = self.drop1(source)
        target = self.pos(self.target_emb(tgt))
        target = self.drop2(target)
        memory = self.enc(source)
        out = self.dec(target, memory, src_mask=mask, tgt_mask=mask)
        return out

In [14]:
class WordGenerationHead(nn.Module):
    '''Outputs softmaxed scores for next words in sequence'''
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        return F.log_softmax(self.fc(x), dim=-1)

In [15]:
generator_head = WordGenerationHead(5, 512)
transformer = ClassicTransformer(n_enc=6, n_dec=6, vocab_size=10, heads=8, embed_dim=512, output_head=generator_head)

inp = torch.LongTensor([[1,2,3,4],[3,2,5,1]])
mask = subsequent_position_mask(inp.size(1))

out = transformer(inp, inp, mask, mask)
out.shape

torch.Size([2, 4, 512])

### Data

In [122]:
from pathlib import Path

from tokenizers import ByteLevelBPETokenizer

paths = [str(x) for x in Path("./eo_data/").glob("**/*.txt")]

# Initialize a tokenizer
tokenizer = ByteLevelBPETokenizer()

# Customize training
tokenizer.train(files=paths, vocab_size=52_000, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save_model(".", "esperberto")

In [130]:
europarl_en[1]

'I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.'

### Training

In [58]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

import time

from torch.optim.lr_scheduler import OneCycleLR, LambdaLR
from torchmetrics.functional import precision
import torchmetrics.functional as tf

#### PyTorch Code

In [None]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [None]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#### Annotated Transformer Code

In [70]:
# Annotated Transformer code

class Batch:
    """Object for holding a batch of data with mask during training."""

    def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is not None:
            self.tgt = tgt[:, :-1]
            self.tgt_y = tgt[:, 1:]
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            self.ntokens = (self.tgt_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & subsequent_position_mask(tgt.size(-1)).type_as(
            tgt_mask.data
        )
        return tgt_mask

In [71]:
class TrainState:
    """Track number of steps, examples, and tokens processed"""

    step: int = 0  # Steps in the current epoch
    accum_step: int = 0  # Number of gradient accumulation steps
    samples: int = 0  # total # of examples used
    tokens: int = 0  # total # of tokens processed

In [87]:
def run_epoch(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode="train",
    accum_iter=1,
    train_state=TrainState(),
):
    """Train a single epoch"""
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, batch in enumerate(data_iter):
        out = model.forward(
            batch.src, batch.tgt, batch.src_mask, batch.tgt_mask
        )
        loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)
        # loss_node = loss_node / accum_iter
        if mode == "train" or mode == "train+log":
            loss_node.backward()
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            if i % accum_iter == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()

        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and (mode == "train" or mode == "train+log"):
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - start
            print(
                (
                    "Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f "
                    + "| Tokens / Sec: %7.1f | Learning Rate: %6.1e"
                )
                % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr)
            )
            start = time.time()
            tokens = 0
        del loss
        del loss_node
    return total_loss / total_tokens, train_state

In [88]:
def rate(step, model_size, factor, warmup):
    """
    we have to default the step to 1 for LambdaLR function
    to avoid zero raising to negative power.
    """
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )

In [89]:
class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist.clone().detach())

In [90]:
def loss(x, crit):
    d = x + 3 * 1
    predict = torch.FloatTensor([[0, x / d, 1 / d, 1 / d, 1 / d]])
    return crit(predict.log(), torch.LongTensor([1])).data

In [91]:
def data_gen(V, batch_size, nbatches):
    "Generate random data for a src-tgt copy task."
    for i in range(nbatches):
        data = torch.randint(1, V, size=(batch_size, 10))
        data[:, 0] = 1
        src = data.requires_grad_(False).clone().detach()
        tgt = data.requires_grad_(False).clone().detach()
        yield Batch(src, tgt, 0)

In [92]:
class SimpleLossCompute:
    "A simple loss compute and train function."

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        x = self.generator(x)
        sloss = (
            self.criterion(
                x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)
            )
            / norm
        )
        return sloss.data * norm, sloss

In [93]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encoder(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decoder(
            memory, src_mask, ys, subsequent_position_mask(ys.size(1)).type_as(src.data)
        )
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1
        )
    return ys

In [94]:
# Train the simple copy task.


def example_simple_model():
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = ClassicTransformer(n_enc=2, n_dec=2, vocab_size=V, heads=8, embed_dim=512, generator=Generator(512, V))

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, model_size=512, factor=1.0, warmup=400
        ),
    )

    batch_size = 80
    for epoch in range(20):
        model.train()
        run_epoch(
            data_gen(V, batch_size, 20),
            model,
            SimpleLossCompute(model.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train",
        )
        model.eval()
        run_epoch(
            data_gen(V, batch_size, 5),
            model,
            SimpleLossCompute(model.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )[0]

    model.eval()
    src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len)
    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))


execute_example(example_simple_model)

RuntimeError: The size of tensor a (4) must match the size of tensor b (9) at non-singleton dimension 3

In [None]:



class TransformerTrainModule(pl.LightningModule):
    def __init__(self, n_enc, n_dec, vocab_size, heads, embed_dim, dropout=0.1):
        super().__init__()

        self.save_hyperparameters(ignore=["model", "data"])
        self.model = ClassicTransformer(n_enc, n_dec, vocab_size, heads, embed_dim, dropout)
    

    def forward(self, x):
        x = self.model(x)
        return x

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = self.base_criterion(y_hat, y.type(torch.float))

        rmap = tf.retrieval_average_precision(y_hat, y.type(torch.int))

        category_prec = precision(
            y_hat,
            y.type(torch.int),
            average="macro",
            num_classes=self.hparams.n_classes,
            threshold=self.hparams.thresh,
            multiclass=False,
        )
        category_recall = tf.recall(
            y_hat,
            y.type(torch.int),
            average="macro",
            num_classes=self.hparams.n_classes,
            threshold=self.hparams.thresh,
            multiclass=False,
        )
        category_f1 = tf.f1_score(
            y_hat,
            y.type(torch.int),
            average="macro",
            num_classes=self.hparams.n_classes,
            threshold=self.hparams.thresh,
            multiclass=False,
        )

        overall_prec = precision(
            y_hat, y.type(torch.int), threshold=self.hparams.thresh, multiclass=False
        )
        overall_recall = tf.recall(
            y_hat, y.type(torch.int), threshold=self.hparams.thresh, multiclass=False
        )
        overall_f1 = tf.f1_score(
            y_hat, y.type(torch.int), threshold=self.hparams.thresh, multiclass=False
        )

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_rmap", rmap, prog_bar=True, on_step=False, on_epoch=True)

            self.log(f"{stage}_cat_prec", category_prec, prog_bar=True)
            self.log(f"{stage}_cat_recall", category_recall, prog_bar=True)
            self.log(f"{stage}_cat_f1", category_f1, prog_bar=True)

            self.log(f"{stage}_ovr_prec", overall_prec, prog_bar=True)
            self.log(f"{stage}_ovr_recall", overall_recall, prog_bar=True)
            self.log(f"{stage}_ovr_f1", overall_f1, prog_bar=True)

            # log prediction examples to wandb
            """
            pred = self.model(x)
            pred_keys = pred[0].sigmoid().tolist()
            pred_keys = [0 if p < self.hparams.thresh else 1 for p in pred_keys]


            mapper = cc.COCOCategorizer()
            pred_lbl = mapper.get_labels(pred_keys)
            
            try:
                self.logger.experiment.log({"val_pred_examples": [wandb.Image(x[0], caption=pred_lbl)]})
            except AttributeError:
                pass
            """

    def training_step(self, batch, batch_idx):
        if self.hparams.use_cutmix:
            x, y = batch
            y_hat = self(x)
            # y1, y2, lam = y
            loss = self.criterion(y_hat, y)

        else:
            x, y = batch
            y_hat = self(x)
            loss = self.base_criterion(y_hat, y.type(torch.float))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=self.hparams.weight_decay,
        )

        lr_scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                self.hparams.learning_rate,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=len(self.data.train_dataloader()),
                anneal_strategy="cos",
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}
        # return optimizer