In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# hyperparameters
batch_size = 32 # how many independent sentence pairs will we process in parallel?
block_size = 256 # maximum sequence length in tokens (source and target, BOS/EOS included)
max_iters = 5 # NOTE(review): not referenced in the cells shown here
eval_interval = 200 # run validation / print losses every `eval_interval` training batches
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200 # NOTE(review): not referenced in the cells shown here
n_embd = 512 # model (embedding) dimension
n_head = 8 # attention heads per multi-head attention layer
n_layer = 6 # number of encoder blocks and of decoder blocks
dropout = 0.1
seed = 1337
# seed torch's global RNG (weight init, dropout, sampling) so runs are reproducible
torch.manual_seed(seed)

### 1. Build a WordPiece tokenizer with a vocabulary size of 30,000
- dataset: https://huggingface.co/datasets/opus100/viewer/en-es/train
- instructions (the code follows the page's WordPiece section; this anchor points to its BPE section): https://huggingface.co/learn/nlp-course/chapter6/8?fw=pt#building-a-bpe-tokenizer-from-scratch

In [2]:
from tokenizers import decoders, models, normalizers, pre_tokenizers, trainers, Tokenizer
from datasets import load_dataset

def get_tokenizer(lang, train_iter, vocab_size=30_000, min_frequency=2, chunk_size=1000):
    """Train a WordPiece tokenizer for one language, save it, and return it.

    Args:
        lang: language key read from every example (e.g. 'en' or 'es').
        train_iter: iterable of translation dicts mapping language key -> sentence.
        vocab_size: target vocabulary size passed to the WordPiece trainer.
        min_frequency: minimum frequency passed to the WordPiece trainer.
        chunk_size: number of sentences handed to the trainer per batch.

    Returns:
        The trained ``Tokenizer``; it is also written to ``tokenizer-{lang}.json``.

    NOTE(review): text is lowercased and accent-stripped for every language,
    including the Spanish target, so decoded output has no capitals or accents.
    """
    tokenizer = Tokenizer(models.WordPiece(unk_token="<unk>"))
    tokenizer.normalizer = normalizers.Sequence(
        [normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()]
    )
    tokenizer.pre_tokenizer = Whitespace_pre_tokenizer() if False else pre_tokenizers.Whitespace()
    # "##" marks word-continuation pieces; the decoder joins them back into words
    tokenizer.decoder = decoders.WordPiece(prefix="##")

    special_tokens = ["<unk>", "<pad>", "<bos>", "<eos>"]
    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=special_tokens)

    def get_training_corpus():
        # stream fixed-size batches instead of first building a list of every sentence
        batch = []
        for example in train_iter:
            batch.append(example[lang])
            if len(batch) == chunk_size:
                yield batch
                batch = []
        if batch:
            yield batch

    tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)
    tokenizer.save(f"tokenizer-{lang}.json")
    return tokenizer
    
SRC_LANG = 'en'
TGT_LANG = 'es'
retrain = False  # True: train both tokenizers from the opus100 en-es train split; False: load the saved JSON files

tokenizers = {}
if retrain:
    # load the training corpus once and reuse it for both languages
    train_iter = load_dataset('opus100', language_pair='en-es', split='train')['translation']
for lang in [SRC_LANG, TGT_LANG]:
    if retrain:
        tokenizers[lang] = get_tokenizer(lang, train_iter)
    else:
        tokenizers[lang] = Tokenizer.from_file(f"tokenizer-{lang}.json")

# A single set of PAD/BOS/EOS ids is used for both source and target sequences,
# so the two tokenizers must define these tokens and agree on their ids.
PAD_IDX = tokenizers[SRC_LANG].token_to_id("<pad>")
BOS_IDX = tokenizers[SRC_LANG].token_to_id("<bos>")
EOS_IDX = tokenizers[SRC_LANG].token_to_id("<eos>")
assert None not in (PAD_IDX, BOS_IDX, EOS_IDX), "special tokens missing from the source tokenizer"
for special in ("<pad>", "<bos>", "<eos>"):
    assert tokenizers[TGT_LANG].token_to_id(special) == tokenizers[SRC_LANG].token_to_id(special), \
        f"{special} id differs between the {SRC_LANG} and {TGT_LANG} tokenizers"

vocab_size = tokenizers[TGT_LANG].get_vocab_size()      # decoder / output vocabulary
vocab_size_enc = tokenizers[SRC_LANG].get_vocab_size()  # encoder / input vocabulary

  from .autonotebook import tqdm as notebook_tqdm


### 2. Processing pipeline to build the DataLoaders
- instructions: https://pytorch.org/tutorials/beginner/translation_transformer.html#collation

In [3]:
class token_encode:
    """Callable that maps a raw string to its list of token ids.

    Meant to be used as one stage of a ``sequential_transforms`` pipeline.
    """
    def __init__(self, lang):
        # tokenizer for `lang`, taken from the module-level `tokenizers` dict
        self.tokenizer = tokenizers[lang]

    def __call__(self, x):
        encoding = self.tokenizer.encode(x)
        return encoding.ids

# one string -> token-id transform per language, keyed by language code
token_encodes = {lang: token_encode(lang) for lang in [SRC_LANG, TGT_LANG]}

# helper that chains several single-argument transforms into one callable
def sequential_transforms(*transforms):
    """Compose ``transforms`` left to right.

    ``sequential_transforms(f, g, h)(x)`` evaluates ``h(g(f(x)))``; with no
    transforms the returned callable returns its input unchanged.
    """
    def composed(value):
        result = value
        for step in transforms:
            result = step(result)
        return result
    return composed

# truncation step: cap sequence length so BOS + ids + EOS fits in block_size
def truncation_transform(token_ids):
    """Return at most the first ``block_size - 2`` token ids (room for BOS and EOS)."""
    return token_ids[: block_size - 2]

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids, bos_idx=None, eos_idx=None):
    """Wrap ``token_ids`` with BOS/EOS and return a 1-D ``torch.long`` tensor.

    Args:
        token_ids: sequence of int token ids (may be empty).
        bos_idx: id to prepend; defaults to the module-level ``BOS_IDX``.
        eos_idx: id to append; defaults to the module-level ``EOS_IDX``.

    The tensor is built from a single list with an explicit integer dtype.
    Concatenating with ``torch.tensor([])`` (which is float) would promote an
    empty sentence to a float tensor that cannot index an embedding.
    """
    bos = BOS_IDX if bos_idx is None else bos_idx
    eos = EOS_IDX if eos_idx is None else eos_idx
    return torch.tensor([bos, *token_ids, eos], dtype=torch.long)

# per-language text pipeline: raw string -> token ids -> truncated ids -> tensor with BOS/EOS
text_transform = {
    lang: sequential_transforms(
        token_encodes[lang],   # tokenization
        truncation_transform,  # keep at most block_size - 2 ids
        tensor_transform,      # add BOS/EOS and build the tensor
    )
    for lang in [SRC_LANG, TGT_LANG]
}

# collate raw translation pairs into padded batch tensors
# https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html
def collate_fn(batch):
    """Turn a list of translation dicts into padded ``(src, tgt)`` id tensors.

    Each sample maps SRC_LANG / TGT_LANG to a sentence. Trailing newlines are
    stripped before the text pipeline runs. Both returned tensors are
    (batch, longest sequence) and filled with PAD_IDX past each sequence's end.
    """
    src_seqs = [text_transform[SRC_LANG](sample[SRC_LANG].rstrip("\n")) for sample in batch]
    tgt_seqs = [text_transform[TGT_LANG](sample[TGT_LANG].rstrip("\n")) for sample in batch]
    return (
        pad_sequence(src_seqs, batch_first=True, padding_value=PAD_IDX),
        pad_sequence(tgt_seqs, batch_first=True, padding_value=PAD_IDX),
    )

In [5]:
train_iter = load_dataset('opus100', language_pair='en-es', split='train')['translation']
val_iter = load_dataset('opus100', language_pair='en-es', split='validation')['translation']
# shuffle training pairs so batches are not drawn in corpus order;
# validation keeps a fixed order so its loss is comparable across evaluations
train_dataloader = DataLoader(train_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

### 3. Build the model
- based on: https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py
- positional encoding class: https://github.com/hyunwoongko/transformer/blob/master/models/embedding/positional_encoding.py
- also check out this notebook for an analysis of Transformer sizing: https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb

In [None]:
class Head(nn.Module):
    """ A single attention head for self- or cross-attention, with an optional causal mask.

    Queries are projected from ``x``; keys and values are projected from ``y``.
    Pass the same tensor twice for self-attention.

    NOTE(review): no padding mask is applied, so PAD positions in ``y`` are
    attended to like any other key.
    """

    def __init__(self, head_size, is_causal):
        super().__init__()
        self.is_causal = is_causal
        # projection order kept as key, query, value (affects init RNG order and state_dict)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y):
        """x: (B, T, C) query source; y: (B, T_y, C) key/value source. Returns (B, T, head_size)."""
        B, T, C = x.shape
        queries = self.query(x)  # (B, T, hs)
        keys = self.key(y)       # (B, T_y, hs)
        values = self.value(y)   # (B, T_y, hs)
        scale = keys.shape[-1] ** -0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale  # (B, T, T_y)
        if self.is_causal:
            # position t may only attend to positions <= t; the [:T, :T] slice assumes T_y == T
            future = self.tril[:T, :T] == 0
            scores = scores.masked_fill(future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T_y)
        return torch.matmul(weights, values)  # (B, T, hs)

class MultiHeadAttention(nn.Module):
    """ Several attention heads run side by side, concatenated, and projected back to n_embd.

    Used for self-attention (same tensor as ``x`` and ``y``) and for
    cross-attention (``y`` = encoder output).
    """

    def __init__(self, num_heads, head_size, is_causal=False):
        super().__init__()
        heads = [Head(head_size, is_causal) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y):
        per_head = [head(x, y) for head in self.heads]  # each (B, T, head_size)
        merged = torch.cat(per_head, dim=-1)            # (B, T, num_heads * head_size)
        return self.dropout(self.proj(merged))

class FeedFoward(nn.Module):
    """ Position-wise MLP: Linear(n_embd -> 4*n_embd), ReLU, Linear(4*n_embd -> n_embd), Dropout.

    (Class name spelling kept as-is because other blocks reference it.)
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        # keep the nn.Sequential named `net` so parameter names (net.0.*, net.2.*) are unchanged
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class EncoderBlock(nn.Module):
    """ Pre-LayerNorm encoder block: bidirectional self-attention then MLP, each wrapped in a residual. """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: number of attention heads
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, is_causal=False)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # normalize once and use the result as both query and key/value source
        h = self.ln1(x)
        x = x + self.sa(h, h)
        x = x + self.ffwd(self.ln2(x))
        return x
    
class DecoderBlock(nn.Module):
    """ Pre-LayerNorm decoder block.

    Applies causal self-attention, then cross-attention from the decoder
    stream to the (separately normalized) encoder output, then an MLP. Each
    sub-layer is wrapped in a residual connection.
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: number of attention heads
        super().__init__()
        head_size = n_embd // n_head
        self.msa = MultiHeadAttention(n_head, head_size, is_causal=True)
        self.xa = MultiHeadAttention(n_head, head_size, is_causal=False)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)  # before masked self-attention
        self.ln2 = nn.LayerNorm(n_embd)  # decoder stream before cross-attention
        self.ln3 = nn.LayerNorm(n_embd)  # encoder output before cross-attention
        self.ln4 = nn.LayerNorm(n_embd)  # before the MLP

    def forward(self, x, y):
        # x is decoder input (B, T, C), y is encoder output (B, T_y, C)
        h = self.ln1(x)  # computed once, used as query and key/value
        x = x + self.msa(h, h)
        x = x + self.xa(self.ln2(x), self.ln3(y))
        x = x + self.ffwd(self.ln4(x))
        return x

class PositionalEncoding(nn.Module):
    """ Fixed sinusoidal positional encoding.

    Row ``p`` holds ``sin(p / 10000^(2i/n_embd))`` in even columns ``2i`` and
    ``cos(p / 10000^(2i/n_embd))`` in odd columns ``2i+1``. Nothing is learned.

    Args:
        n_embd: embedding dimension (odd values are supported).
        device: device the table is placed on. The table is registered as a
            non-persistent buffer, so ``module.to(...)`` moves it and it is not
            written to ``state_dict``.
        max_len: number of positions to precompute; defaults to ``block_size``.
    """
    def __init__(self, n_embd, device, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = block_size

        pos = torch.arange(0, max_len).float().unsqueeze(dim=1)  # (max_len, 1)
        # even column indices 2i = 0, 2, 4, ...
        _2i = torch.arange(0, n_embd, step=2).float()            # (ceil(n_embd/2),)
        angles = pos / (10000 ** (_2i / n_embd))                 # (max_len, ceil(n_embd/2))

        encoding = torch.zeros(max_len, n_embd)
        encoding[:, 0::2] = torch.sin(angles)
        # with odd n_embd there is one fewer odd column than even column
        encoding[:, 1::2] = torch.cos(angles[:, : n_embd // 2])
        self.register_buffer('encoding', encoding.to(device), persistent=False)

    def forward(self, x):
        """x: (B, T) token ids; returns the (T, n_embd) encodings for positions 0..T-1."""
        T = x.size(1)
        return self.encoding[:T, :]

class Transformer(nn.Module):
    """ Encoder-decoder Transformer mapping source token ids to target-vocabulary logits.

    The encoder embeds the source ids and runs ``n_layer`` EncoderBlocks. The
    decoder embeds the target ids, runs ``n_layer`` DecoderBlocks that
    cross-attend to the encoder output, and projects to ``vocab_size`` logits.
    Both sides share one sinusoidal positional encoding.
    """

    def __init__(self):
        super().__init__()
        # decoder-side (target language) token embeddings
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # encoder-side (source language) token embeddings
        self.token_embedding_table_enc = nn.Embedding(vocab_size_enc, n_embd)
        self.position_embedding = PositionalEncoding(n_embd, device=device)
        self.encoder_blocks = nn.Sequential(*[EncoderBlock(n_embd, n_head=n_head) for _ in range(n_layer)])
        # cannot use nn.Sequential for DecoderBlock because it takes two inputs
        self.decoder_blocks = nn.ModuleList([DecoderBlock(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # re-initialize Linear/Embedding weights from N(0, 0.02) and zero the Linear biases
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _encode(self, idx_enc):
        """Encoder pass: (B, T_y) source ids -> (B, T_y, C) encoder output."""
        tok_emb_enc = self.token_embedding_table_enc(idx_enc) # (B,T_y,C)
        pos_emb_enc = self.position_embedding(idx_enc) # (T_y,C), broadcast over B
        return self.encoder_blocks(tok_emb_enc + pos_emb_enc) # (B,T_y,C)

    def _decode(self, idx, y):
        """Decoder pass: (B, T) target ids attending to encoder output y -> (B, T, vocab_size) logits."""
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding(idx) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x, y) # (B,T,C)
        return self.lm_head(self.ln_f(x)) # (B,T,vocab_size)

    def forward(self, idx, idx_enc, targets=None):
        """Compute target-vocabulary logits and, optionally, the loss.

        Args:
            idx: (B, T) decoder input ids (during training: ``tgt[:, :-1]``).
            idx_enc: (B, T_y) source ids fed to the encoder.
            targets: optional (B, T) next-token ids; PAD_IDX positions are ignored.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        y = self._encode(idx_enc)
        logits = self._decode(idx, y)

        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        loss = F.cross_entropy(logits, targets.reshape(B*T), ignore_index=PAD_IDX)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx_enc, greedy=False):
        """Autoregressively decode target ids for source ids ``idx_enc`` (B, T_y).

        Every row starts from BOS. Each step appends either the arg-max token
        (``greedy=True``) or a sample from the softmax. After a row emits EOS,
        all its later positions are EOS. Decoding stops when every row ends in
        EOS or after ``block_size`` steps. Returns (B, T_out) ids including the
        leading BOS.

        The source is encoded once and reused at every step. No autograd graph
        is built. Call ``model.eval()`` first if dropout should be disabled.
        """
        B = idx_enc.shape[0]
        y = self._encode(idx_enc) # (B,T_y,C), reused at every step
        idx = torch.full((B, 1), BOS_IDX, dtype=torch.long, device=idx_enc.device)
        for _ in range(block_size):
            # logits of the last position only
            logits = self._decode(idx, y)[:, -1, :] # (B, vocab_size)
            if greedy:
                idx_next = torch.argmax(logits, dim=-1) # (B,)
            else:
                probs = F.softmax(logits, dim=-1) # (B, vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # flatten to (B,) without dropping the batch dim when B == 1
            idx_next = idx_next.reshape(B)
            # once a row has produced EOS, keep emitting EOS
            idx_next = torch.where(idx[:, -1] == EOS_IDX, EOS_IDX, idx_next)
            idx = torch.cat((idx, idx_next[:, None]), dim=1) # (B, T+1)
            # stop generation if every row ends in EOS
            if torch.all(idx[:, -1] == EOS_IDX):
                break
        return idx

### 4. train and evaluate
- https://pytorch.org/tutorials/beginner/translation_transformer.html
- see `train.py` for a more comprehensive setup

In [None]:
from timeit import default_timer as timer

def train_epoch(model, optimizer, epoch=None, accum_iter=4):
    """Run one pass over the training set with gradient accumulation.

    Args:
        model: model returning ``(logits, loss)`` from ``model(tgt_in, src, tgt_out)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number; only used in the progress message.
        accum_iter: number of batches whose gradients are accumulated before
            each optimizer step (effective batch = accum_iter * batch_size).

    Every ``eval_interval`` batches, prints the mean training loss since the
    last report, the validation loss, and the elapsed time.
    """
    model.train()
    # a fresh shuffled loader each epoch so batch order differs between epochs
    train_dataloader = DataLoader(train_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    n_batches = len(train_dataloader)

    optimizer.zero_grad(set_to_none=True)  # drop any stale gradients before accumulating
    losses = 0
    start_time = timer()
    for batch_idx, (src, tgt) in enumerate(train_dataloader):
        src = src.to(device)
        tgt = tgt.to(device)
        # teacher forcing: decoder sees tgt[:, :-1] and predicts tgt[:, 1:]
        _, loss = model(tgt[:, :-1], src, tgt[:, 1:])

        losses += loss.item()
        # scale so the accumulated gradient averages over accum_iter batches
        (loss / accum_iter).backward()

        # step every accum_iter batches, and once more on the final (possibly partial) group
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == n_batches):
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        if (batch_idx + 1) % eval_interval == 0:
            val_loss = evaluate(model)
            model.train()  # make sure dropout is active again after validation
            end_time = timer()
            print(
                (f"Epoch: {epoch}, Batch: {batch_idx + 1}, Train loss: {losses/eval_interval:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s")
            )
            losses = 0
            start_time = timer()

def evaluate(model, dataloader=None):
    """Return the mean per-batch loss of ``model`` on the validation data.

    Args:
        model: model returning ``(logits, loss)`` from ``model(tgt_in, src, tgt_out)``.
        dataloader: batches of ``(src, tgt)`` to score; defaults to ``val_dataloader``.

    Runs in eval mode under ``torch.no_grad()``. The model's previous
    train/eval mode is restored before returning. Raises ZeroDivisionError
    if the dataloader yields no batches.
    """
    if dataloader is None:
        dataloader = val_dataloader
    was_training = model.training
    model.eval()
    losses = 0
    n_batches = 0
    try:
        with torch.no_grad():
            for src, tgt in dataloader:
                src = src.to(device)
                tgt = tgt.to(device)
                _, loss = model(tgt[:, :-1], src, tgt[:, 1:])
                losses += loss.item()
                # count while iterating; len(list(dataloader)) would re-run the whole pipeline
                n_batches += 1
    finally:
        model.train(was_training)
    return losses / n_batches

In [None]:
# build the encoder-decoder model; nn.Module.to returns the module itself, already moved to `device`
model = Transformer().to(device)

In [None]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

NUM_EPOCHS = 2
for epoch in range(1, NUM_EPOCHS+1):
    # pass the epoch number so progress messages identify the epoch
    train_epoch(model, optimizer, epoch)