In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import random
import math
import string

from typing import Float, List, Tuple
from torch import Tensor

# Training data

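Here we generate synthetic training data. Each example has the form `<input> <instruction> <output>. `, where the output is the input string transformed according to the instruction.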

In [None]:
# Synthetic dataset generator
def generate_string_manipulation_data(num_examples):
    """Build ``num_examples`` random string-manipulation examples.

    Each example is the text ``"<input> <instruction> <output>. "``, where
    ``<input>`` is 5-10 random lowercase ASCII letters, ``<instruction>`` is
    drawn from the active instruction list, and ``<output>`` is the input
    transformed according to that instruction.

    Args:
        num_examples: How many examples to generate.

    Returns:
        A list of ``num_examples`` strings.
    """
    # Every instruction the generator knows how to apply. Only the names in
    # `active` are actually sampled; the rest are kept for experimentation.
    transforms = {
        'reverse': lambda text: text[::-1],
        'upper': str.upper,
        'lower': str.lower,
        'capital': str.capitalize,
        'swap': str.swapcase,
        'echo': lambda text: text,
    }
    active = ['upper', 'echo']

    def one_example():
        # Draw order (length, letters, instruction) matters for seeded runs.
        length = random.randint(5, 10)
        source = ''.join(random.choices(string.ascii_lowercase, k=length))
        operation = random.choice(active)
        return f"{source} {operation} {transforms[operation](source)}. "

    return [one_example() for _ in range(num_examples)]

# Print ten freshly generated examples, one per line, to eyeball the format.
print('\n'.join(generate_string_manipulation_data(10)))

Ideally, each of our instruction words ("upper" and "echo") would be encoded as a single token.
We train a Byte Pair Encoding (BPE) tokenizer on our data so that it can learn such merges.

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Split

# Build a BPE tokenizer (Hugging Face `tokenizers`) whose unknown-token symbol is "<unk>".
tokenizer = Tokenizer(BPE(unk_token="<unk>"))
# Pre-tokenize on single spaces with the 'isolated' split behavior
# (presumably each space becomes its own piece — confirm against the tokenizers docs).
tokenizer.pre_tokenizer = Split(' ', 'isolated')

# Train on 1000 freshly generated examples. The vocabulary is capped at 100
# entries, merges need a pair frequency of at least 2, and "<unk>" is the only
# special token (presumably assigned id 0 — TODO confirm).
trainer = BpeTrainer(vocab_size=100, min_frequency=2, special_tokens=["<unk>"])
tokenizer.train_from_iterator(generate_string_manipulation_data(1000), trainer=trainer)

# Persist the trained tokenizer to a JSON file in the working directory.
tokenizer.save("char_bpe_tokenizer.json")

# Example usage: show the token strings and ids for a sample sentence.
# NOTE(review): "capital" is not one of the instructions the generator currently
# samples ('upper', 'echo'), so this demonstrates a word absent from the training data.
encoded = tokenizer.encode("abcd capital .")
print(encoded.tokens)
print(encoded.ids)

The code below tokenizes batches of examples into zero-padded tensors for training.

In [None]:
# token -> id mapping from the trained tokenizer, and its inverse (id -> token).
vocab = tokenizer.get_vocab()
decode = {token_id: token for token, token_id in vocab.items()}

# Longest tokenized example in a fresh random sample of 10000 examples;
# data_process pads every row to max_len + 1 ids.
max_len = max(
    len(tokenizer.encode(sample).ids)
    for sample in generate_string_manipulation_data(10000)
)

def data_process(data):
    """Tokenize text examples into one zero-padded tensor of token ids.

    Args:
        data: Iterable of example strings.

    Returns:
        A ``torch.long`` tensor of shape ``(number of examples, max_len + 1)``.
        Row ``i`` holds the token ids of the i-th example, right-padded with
        0. An example that tokenizes to more than ``max_len + 1`` ids is
        truncated to fit: the module-level ``max_len`` is only an estimate
        taken from a random sample, so a later example can exceed it. An
        empty ``data`` yields a tensor with zero rows.
    """
    examples = list(data)
    width = max_len + 1  # module-level max_len is only read, so no `global` is needed
    batch = torch.zeros((len(examples), width), dtype=torch.long)
    for row, item in enumerate(examples):
        # Clip over-long examples so the slice assignment below cannot fail.
        ids = tokenizer.encode(item).ids[:width]
        batch[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return batch

def make_epoch_data(batch_size, epoch_len):
    """Generate one epoch of fresh examples, grouped into batches.

    Args:
        batch_size: Number of examples per batch.
        epoch_len: Number of batches in the epoch.

    Returns:
        The tensor produced by ``data_process`` for
        ``batch_size * epoch_len`` new examples, viewed as
        ``(epoch_len, batch_size, -1)``.
    """
    total_examples = batch_size * epoch_len
    examples = generate_string_manipulation_data(total_examples)
    token_rows = data_process(examples)
    return token_rows.view(epoch_len, batch_size, -1)

# Smoke test: 3 batches of 2 examples each; the tensor is shown as the cell output.
make_epoch_data(2, 3)

## The training loop

Adjust the hyperparameters below to experiment with training.

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 100  # examples per batch
seq_len = 100  # NOTE(review): not referenced anywhere in the visible cells
embed_dim = 32
num_layers = 5
num_epochs = 100
warmup_learning_rate = 0.0002  # learning rate used for the first epochs
learning_rate = 0.002  # starting point of the linear decay applied after warm-up
num_batches = 50  # batches per epoch


# Initialize the model, loss function, and optimizer
# NOTE(review): TransformerModel is not defined in the visible cells; it must be
# defined before this cell runs.
model = TransformerModel(len(vocab), embed_dim, num_layers).to(device)
# No ignore_index is passed, so the zero-padded positions also contribute to the loss.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=warmup_learning_rate)

# Training loop: every epoch trains on freshly generated synthetic data.
for epoch in range(num_epochs):
    # Indexed below as (batch index, example, token position).
    epoch_data = make_epoch_data(batch_size, num_batches).to(device)
    for i in range(0, len(epoch_data)):
        # Next-token prediction: the targets are the inputs shifted left by one position.
        inputs = epoch_data[i,:,:-1]
        targets = epoch_data[i,:,1:]
        #print('inputs:', tokenizer.decode(inputs[0].tolist()))
        #print('targets', tokenizer.decode(targets[0].tolist()))
        optimizer.zero_grad()
        outputs = model(inputs)
        # Flatten to one row of logits per token position
        # (assumes the model returns logits with a trailing len(vocab) dimension — TODO confirm).
        loss = criterion(outputs.view(-1, len(vocab)), targets.contiguous().view(-1))
        loss.backward()
        # Clip the total gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # Reports the loss of the last batch of the epoch only, not an epoch average.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    # The update runs after the epoch's batches, so epochs 0-4 train at the warm-up
    # rate; from epoch 5 on the rate decays linearly from learning_rate towards zero.
    if epoch > 3:
        optimizer.param_groups[0]['lr'] = learning_rate * (num_epochs - epoch) / num_epochs




## Generation

In [None]:
# Id of the first token produced when encoding ".".
# NOTE(review): not referenced by generate_text below, which stops on token id 0 instead.
end_of_text = tokenizer.encode(".").ids[0]

# Inference
def generate_text(prompt, max_len=28):
    """Greedily extend ``prompt`` with the trained model and return the text.

    The prompt is tokenized and the model is asked repeatedly for the most
    likely next token, which is appended to the sequence. Generation stops
    after ``max_len`` new tokens or as soon as the model predicts token id 0,
    the value ``data_process`` uses for padding.

    Args:
        prompt: Text to start from.
        max_len: Maximum number of tokens to generate beyond the prompt.

    Returns:
        The decoded prompt tokens followed by the generated tokens, joined
        without separators.

    Side effects:
        Switches the global ``model`` to eval mode while generating and always
        puts it back in train mode afterwards, even if generation raises.
    """
    tokens = tokenizer.encode(prompt).ids
    prompt_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_len):
                output = model(prompt_tensor)
                # Greedy choice at the last position; shape (1,) for the single prompt.
                pred_token = output[:, -1].argmax(dim=-1)
                if pred_token.item() == 0:
                    break
                prompt_tensor = torch.cat((prompt_tensor, pred_token.unsqueeze(0)), dim=1)
    finally:
        # Restore training mode even when the forward pass fails.
        model.train()

    # Index the batch dimension explicitly: a bare squeeze() would also drop the
    # sequence dimension of a one-token sequence, leaving a 0-d tensor that
    # cannot be iterated.
    return ''.join(decode[token_id] for token_id in prompt_tensor[0].tolist())

# Example usage: ask the trained model to complete an "echo" instruction and
# print the prompt next to the full generated text (which includes the prompt).
prompt = "amazing echo"
generated_text = generate_text(prompt)
print(f"Prompt: {prompt}")
print(f"Generated Text: {generated_text}")

