In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import random
import math
import string

# Training a Transformer

First, let's look at an attention layer in detail.

The job of the attention layer is to let a token "see" information from distant, earlier tokens. Each token forms a "query" vector, which is compared (by dot product) against the "key" vectors of the other tokens. The match scores are turned into weights with a softmax, and the weights are used to take a weighted sum of the tokens' "value" vectors; that sum is then added back to the querying token's state.

This can all be done in parallel - let's look at how.

In [None]:
class SingleHeadAttention(nn.Module):
    """Single-head (scaled) dot-product self-attention.

    Projects the input into queries, keys and values with three learned
    d -> d linear layers, so input and output dimensions are both ``d``, and
    mixes the values with weights softmax(Q K^T / sqrt(d)).
    """

    def __init__(self, d, scale=True):
        """
        Args:
            d: embedding dimension of both the input and the output.
            scale: if True (default), divide the attention logits by sqrt(d)
                as in standard scaled dot-product attention; if False, use the
                raw dot products.
        """
        super().__init__()

        self.q_layer = torch.nn.Linear(d, d)
        self.k_layer = torch.nn.Linear(d, d)
        self.v_layer = torch.nn.Linear(d, d)
        self.scale = scale

    def forward(self, x, mask=None, return_weights=False):
        """
        Args:
            x: input of shape <t x d>, or <... x t x d> with leading batch
                dims (the linear layers and ``transpose(-2, -1)`` act on the
                last two dims). t is the sequence length, d the embed dim.
            mask: optional tensor broadcastable to the <t x t> score matrix;
                positions where ``mask == 0`` (False) receive no attention.
            return_weights: if True, also return the attention weights.

        Returns:
            The attended values ``weights @ V`` (same shape as ``x``), or the
            tuple ``(weights, weights @ V)`` when ``return_weights`` is True.
            ``weights`` has shape <... x t x t> and each row sums to 1.
        """
        Q = self.q_layer(x)
        K = self.k_layer(x)
        V = self.v_layer(x)

        scores = Q @ K.transpose(-2, -1)
        if self.scale:
            # Keeps the logits' variance independent of d so the softmax does
            # not saturate for larger embeddings.
            scores = scores / math.sqrt(Q.size(-1))
        if mask is not None:
            # The dtype's most negative finite value also works in half
            # precision, where a literal -1e9 would overflow.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        values = weights @ V

        if return_weights:
            return weights, values
        return values

Now we put lots of attention layers into a transformer model.

We wrap each attention layer in a residual connection, and we use LayerNorm to keep the activations from blowing up.

The last detail is positional encoding: a vector we add to each token based on where it is, rather than what it is. This allows a token to be "queried" according to its position rather than its content.

In [None]:
class TransformerModel(nn.Module):
    """Causal (decoder-only) language model built from stacked attention layers.

    ``forward`` takes ``src``, a LongTensor of token ids shaped
    (batch, seq_len), and returns next-token logits shaped
    (batch, seq_len, vocab_size).
    """

    def __init__(self, vocab_size, embed_dim, num_layers):
        super(TransformerModel, self).__init__()
        # Creation order determines how the RNG is consumed during weight
        # initialization; keep it stable for reproducibility.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim)
        self.layers = nn.ModuleList([SelfAttentionLayer(embed_dim) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, src):
        seq_len = src.size(1)
        # Causal mask: True on and below the diagonal, so position i may only
        # attend to positions j <= i.
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=src.device))
        x = self.embed(src)
        # PositionalEncoding indexes its table along dim 0 (sequence-first).
        # Passing the batch-first tensor directly would add pe[batch_index]
        # to every token of a row and carry no positional signal, so switch
        # to (seq_len, batch, d) for the encoder and back afterwards.
        x = self.pos_encoder(x.transpose(0, 1)).transpose(0, 1).contiguous()
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        return self.decoder(x)

class SelfAttentionLayer(nn.Module):
    """One residual attention block computing ``src + LayerNorm(attention(src))``.

    The LayerNorm is applied to the attention output before it is added to
    the residual stream; the residual stream itself is not normalized here.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Attention is registered before the norm, which fixes parameter order.
        self.self_attn = SingleHeadAttention(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)

    def forward(self, src, mask=None):
        attended = self.self_attn(src, mask=mask)
        update = self.norm1(attended)
        return src + update

class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    The table is registered as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``, and ``forward`` indexes it along dim 0. The
    input is therefore expected in sequence-first layout
    ``(seq_len, batch, d_model)``; batch-first callers must transpose before
    and after calling it.

    Even columns hold ``sin(pos * base**(-2i/d_model))`` and odd columns the
    matching ``cos``. The original Transformer uses ``base=10000``. The
    default here, 100, keeps this notebook's choice and gives shorter
    wavelengths, which presumably suits the short sequences in this task.
    """

    def __init__(self, d_model, max_len=5000, base=100.0):
        """
        Args:
            d_model: embedding size (may be odd).
            max_len: number of positions to precompute.
            base: wavelength base of the sinusoids (default 100.0).
        """
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(base) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x + pe[:seq_len]`` for ``x`` shaped (seq_len, batch, d_model).

        Raises:
            ValueError: if seq_len exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :]


## Training data

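Here we generate synthetic training data. Each example is a random lowercase string, followed by an instruction word, followed by the result of applying that instruction to the string, and ends with a period.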

In [None]:
# Generate synthetic dataset

# Instruction word -> transformation applied to the input string.
STRING_OPERATIONS = {
    'reverse': lambda s: s[::-1],
    'upper': str.upper,
    'lower': str.lower,
    'capital': str.capitalize,
    'swap': str.swapcase,
    'echo': lambda s: s,
}


def generate_string_manipulation_data(num_examples, operations=('upper', 'echo')):
    """Return `num_examples` strings of the form "<input> <instruction> <output>. ".

    <input> is 5-10 random lowercase ASCII letters. <instruction> is drawn
    uniformly from `operations`. <output> is the result of applying that
    instruction (see STRING_OPERATIONS) to <input>.

    Args:
        num_examples: number of examples to generate.
        operations: instruction words to sample from; each must be a key of
            STRING_OPERATIONS. Defaults to ('upper', 'echo').

    Raises:
        ValueError: if `operations` is empty or contains an unknown instruction.
    """
    operations = list(operations)
    if not operations:
        raise ValueError("operations must not be empty")
    unknown = [op for op in operations if op not in STRING_OPERATIONS]
    if unknown:
        raise ValueError(f"unknown operations: {unknown}")

    examples = []
    for _ in range(num_examples):
        input_string = ''.join(random.choices(string.ascii_lowercase, k=random.randint(5, 10)))
        instruction = random.choice(operations)
        output_string = STRING_OPERATIONS[instruction](input_string)
        examples.append(f"{input_string} {instruction} {output_string}. ")
    return examples


print('\n'.join(generate_string_manipulation_data(10)))

Ideally, each of our special words "upper" and "echo" would be a single token.
We use a Byte Pair Encoding (BPE) tokenizer, which learns such merges from frequently occurring character pairs in our data.

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Split

# Create a BPE tokenizer: a plain `Tokenizer` with a BPE model, not the
# `CharBPETokenizer` convenience class. Splitting on ' ' with the 'isolated'
# behavior makes every space its own pre-token, so learned merges never
# cross word boundaries.
tokenizer = Tokenizer(BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = Split(' ', 'isolated')

# Train the tokenizer on the synthetic dataset
trainer = BpeTrainer(vocab_size=100, min_frequency=2, special_tokens=["<unk>"])
tokenizer.train_from_iterator(generate_string_manipulation_data(1000), trainer=trainer)

# Save the trained tokenizer
tokenizer.save("char_bpe_tokenizer.json")

# Example usage. "upper" is one of the generator's instruction choices
# ('upper', 'echo'), so it is likely merged into a single token. "capital"
# never appears in the training data, so it is expected to split into
# smaller pieces.
for sample in ("abcd upper .", "abcd capital ."):
    encoded = tokenizer.encode(sample)
    print(encoded.tokens)
    print(encoded.ids)

The code below tokenizes batches of examples into tensors, for training.

In [None]:
# Token -> id mapping from the trained tokenizer, and its inverse for decoding.
vocab = tokenizer.get_vocab()
decode = dict(zip(vocab.values(), vocab.keys()))

# Padding length: the longest tokenized example in a fresh random sample of
# 10000 examples. This is an empirical maximum, not a guaranteed bound for
# examples generated later.
max_len = max(len(tokenizer.encode(text).ids) for text in generate_string_manipulation_data(10000))

def data_process(data, length=None):
    """Tokenize `data` into a zero-padded LongTensor of shape (len(data), length).

    Args:
        data: iterable of example strings.
        length: row length; defaults to the global ``max_len + 1``.

    Each row holds the example's token ids followed by zeros. An example
    longer than ``length - 1`` tokens is truncated rather than raising a
    shape-mismatch error, so every row still ends with at least one 0. The
    training targets include these zeros, which is how the model learns
    where an example ends.
    """
    if length is None:
        length = max_len + 1
    rows = []
    for item in data:
        ids = tokenizer.encode(item).ids[: max(length - 1, 0)]
        row = torch.zeros(length, dtype=torch.long)
        row[: len(ids)] = torch.tensor(ids, dtype=torch.long)
        rows.append(row)
    if not rows:
        # torch.stack rejects an empty list.
        return torch.zeros((0, length), dtype=torch.long)
    return torch.stack(rows)

def make_epoch_data(batch_size, epoch_len):
    """Build one epoch of fresh random examples, grouped into batches.

    Returns a LongTensor shaped (epoch_len, batch_size, L), where L is the
    padded row length produced by `data_process`.
    """
    n_examples = batch_size * epoch_len
    padded = data_process(generate_string_manipulation_data(n_examples))
    # Flat row i * batch_size + j becomes batch i, example j.
    return padded.view(epoch_len, batch_size, -1)


make_epoch_data(2, 3)

## The training loop

Play with the hyperparameters here.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 100
seq_len = 100  # NOTE(review): unused below; sequence length comes from max_len in the data cell
embed_dim = 32
num_layers = 5
num_epochs = 100
warmup_learning_rate = 0.0002
learning_rate = 0.002
num_batches = 50
warmup_epochs = 4  # epochs run at warmup_learning_rate before the linear decay starts
seed = 0


def set_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group in `optimizer` to `lr`."""
    for group in optimizer.param_groups:
        group['lr'] = lr


# Seed both RNGs used here: `random` drives the synthetic data generator and
# torch drives weight initialization.
random.seed(seed)
torch.manual_seed(seed)

# Initialize the model, loss function, and optimizer
model = TransformerModel(len(vocab), embed_dim, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=warmup_learning_rate)

# Training loop
for epoch in range(num_epochs):
    epoch_data = make_epoch_data(batch_size, num_batches).to(device)
    epoch_loss = 0.0
    for batch in epoch_data:  # batch: (batch_size, padded_len)
        # Next-token prediction: the target at each position is the following token.
        inputs = batch[:, :-1]
        targets = batch[:, 1:]
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        epoch_loss += loss.item()

    # Report the mean over the epoch rather than the noisy last-batch loss.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / len(epoch_data):.4f}")
    # After warmup, decay the learning rate linearly with the epoch index.
    if epoch >= warmup_epochs:
        set_learning_rate(optimizer, learning_rate * (num_epochs - epoch) / num_epochs)




## Generation

In [None]:
# Id of the "." token. Note: generate_text below stops on `stop_token`
# (padding id 0), not on this id.
end_of_text = tokenizer.encode(".").ids[0]

# Inference
def generate_text(prompt, max_len=28, stop_token=0):
    """Greedily continue `prompt` with the trained model and return the decoded text.

    Args:
        prompt: text to encode and continue; must encode to at least one token.
        max_len: maximum number of tokens to generate. Inside this function
            it shadows the global max_len.
        stop_token: generation stops, without appending it, when the model
            predicts this id. Defaults to 0, the id `data_process` pads
            with, which the model is trained to predict after an example ends.

    Returns:
        The prompt tokens plus the generated tokens, joined via `decode`.

    Raises:
        ValueError: if `prompt` encodes to no tokens.
    """
    tokens = tokenizer.encode(prompt).ids
    if not tokens:
        raise ValueError("prompt must encode to at least one token")

    was_training = model.training
    model.eval()
    generated = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)
    try:
        with torch.no_grad():
            for _ in range(max_len):
                logits = model(generated)
                next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # (1, 1)
                if next_token.item() == stop_token:
                    break
                generated = torch.cat((generated, next_token), dim=1)
    finally:
        # Restore whichever mode the model was in, even if generation fails.
        model.train(was_training)

    # Index the batch row explicitly: squeeze() would produce a 0-d tensor
    # (not iterable) when the result is a single token.
    return ''.join(decode[t] for t in generated[0].tolist())

# Example usage: greedy completion of a prompt that ends in an instruction word
prompt = "amazing echo"
generated_text = generate_text(prompt)
print(f"Prompt: {prompt}\nGenerated Text: {generated_text}")

