In [None]:
import numpy as np

In [None]:
import torch
import torch.nn as nn

In [None]:
from selfatt.plot import TransformerPlotter
from selfatt.training import TrainingAssistant
from selfatt.nanogpt import GPTLanguageModel
from selfatt import device

In [None]:
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

In [None]:
def to_numpy(tensor):
    """Convert a torch tensor into a NumPy array.

    The tensor is detached from the autograd graph and copied to host memory
    first, because ``Tensor.numpy()`` rejects tensors that require grad or that
    live on a non-CPU device.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()

# Data generation

In [None]:
def generate_deterministic_sequence(length=1, n=1, whitespace=False):
    """Return the first ``length`` lowercase letters, repeated ``n`` times.

    When ``whitespace`` is true a single space is appended after every
    repetition, e.g. ``length=3, n=2, whitespace=True`` gives ``'abc abc '``.
    """
    first = ord('a')
    unit = ''.join(map(chr, range(first, first + length)))
    if whitespace:
        unit += ' '
    return unit * n

In [None]:
def pick_random_words(words=1, source_vocab=None, vocab_path='../vocab.txt'):
    """Draw ``words`` entries (with replacement) from a vocabulary.

    Parameters
    ----------
    words : int
        Number of entries to draw.
    source_vocab : sequence of str, optional
        Vocabulary to sample from. When omitted, it is read from
        ``vocab_path`` as one entry per line.
    vocab_path : str
        File read when ``source_vocab`` is None. Defaults to the path the
        notebook has always used.

    Returns
    -------
    numpy.ndarray
        The sampled strings, drawn with the global ``np.random`` state.
    """
    if source_vocab is None:
        with open(vocab_path, 'r', encoding='utf-8') as f:
            # Drop blank lines: a trailing newline would otherwise add an empty
            # string to the vocabulary, which could then be sampled as a "word".
            source_vocab = [line.strip() for line in f.read().splitlines() if line.strip()]
    return np.random.choice(source_vocab, words)

In [None]:
def pick_random_vocab(n=1, words=1, source_vocab=None):
    """Build a space-separated text of ``n`` tokens from a small random vocabulary.

    First ``words`` entries are sampled via ``pick_random_words`` to form the
    universe; then ``n`` tokens are drawn (with replacement) from that universe
    and joined with single spaces.
    """
    vocabulary_subset = pick_random_words(words=words, source_vocab=source_vocab)
    sampled_tokens = np.random.choice(vocabulary_subset, n)
    return ' '.join(sampled_tokens)

In [None]:
# Build a training text that repeats the same short sequence ('abc ') 100 times.
text = generate_deterministic_sequence(
    length=3,
    n=100,
    whitespace=True,
)

In [None]:
# Generate a text comprised of a number of words repeated at random
# text = pick_random_vocab(n=1000, words=3)

In [None]:
# Preview the first 100 characters of the training text.
text[:100]

In [None]:
# Character-level tokenizer: every distinct character in `text` becomes one token id.
chars = sorted(set(text))
vocab_size = len(chars)
# create a mapping from characters to integers and back
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string into a list of token ids (KeyError on characters not in `chars`)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of token ids back into a string."""
    return ''.join(itos[i] for i in l)


# Train and test splits
train_fraction = 0.5  # share of the text used for training; the remainder is validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(train_fraction * len(data))
train_data = data[:n]
val_data = data[n:]


In [None]:
# hyperparameters
batch_size = 1  # number of independent sequences processed in parallel
block_size = 2  # maximum context length (in tokens) used for predictions

learning_rate = 3e-4  # NOTE(review): the training cell below hard-codes lr=1e-3 instead of using this
eval_iters = 4  # NOTE(review): the training cell below passes eval_iterations=10 instead of this
n_embd = 2  # embedding dimension
n_head = 1  # number of attention heads
head_size = 2  # size of each attention head
n_layer = 1  # number of Transformer blocks
dropout = 0.2  # NOTE(review): not passed to GPTLanguageModel below — confirm whether it takes effect

# Switches forwarded to the model's downstream components.
downstream_kwargs = dict(
    with_query=True,  # enable the Query part of the self-attention mechanism
    with_key=True,  # enable the Key part of the self-attention mechanism
    with_value=True,  # enable the Value part of the self-attention mechanism
    with_layer_norm=False,  # enable the final Layer normalisation step of the network
    block_with_layer_norm=False,  # enable the Layer normalisation steps of each Transformer block
    block_size=block_size,
    head_size=head_size,
    ffwd=False,  # enable the fully-connected layer at the end of each Transformer block
)

# Disabling any of the Q/K/V projections requires the embedding and head sizes to agree.
if not all(downstream_kwargs[flag] for flag in ('with_query', 'with_key', 'with_value')):
    assert n_embd == head_size, 'The embedding dimension and the head size must be equal when head matrices are disabled.'

torch.manual_seed(1337)  # reproducible weight initialisation
model = GPTLanguageModel(
    n_layer=n_layer,
    n_embd=n_embd,
    n_head=n_head,
    vocab_size=vocab_size,
    ds_kwargs=downstream_kwargs,
)
m = model.to(device)
# print the number of parameters in the model
print(sum(param.numel() for param in m.parameters()), 'parameters')
print(sum(param.numel() for param in m.parameters()) / 1e6, 'M parameters')

### Visualise the (still untrained) Transformer for a sample input

In [None]:
# Feed a short prompt through the model and plot its internals.
input_text = 'ab'
input_ids = torch.tensor([encode(input_text)]).to(device)  # a batch containing one sequence
plotter = TransformerPlotter(decode)
plotter.plot_for_input(model, input_ids, n_embd)
# NOTE(review): presumably clears output emitted by plot_for_input; wait=True defers
# the clear until new output arrives.
clear_output(wait=True)

In [None]:
# Train the model, periodically estimating loss and re-plotting its internals.
# NOTE(review): lr is hard-coded to 1e-3 here, while the hyperparameter cell declares
# learning_rate = 3e-4 — confirm which value is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_iters = 100  # How many iterations we train for (i.e. how many batches).
eval_interval = 10  # How many iterations between loss evaluations.
plotter = TransformerPlotter(decode)
# NOTE(review): eval_iterations=10 differs from the eval_iters hyperparameter (4) — confirm.
assistant = TrainingAssistant(batch_size=batch_size, block_size=block_size, eval_iterations=10)
# `step` rather than `iter`: a global named `iter` would shadow the builtin iter()
# for the rest of the kernel session.
for step in range(max_iters):
    xb, yb = assistant.get_batch(train_data)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    is_last_step = step == max_iters - 1
    if step % eval_interval == 0 or is_last_step:
        # wait=True: the previous plot is replaced only once the new one is ready.
        clear_output(wait=True)
        loss_train = assistant.estimate_loss(model, train_data)
        loss_val = assistant.estimate_loss(model, val_data)
        loss_str = f"step {step}: train loss {loss_train:.4f}, val loss {loss_val:.4f}"
        fig, ax = plotter.plot_for_input(model, xb, n_embd, loss_str)
        display(fig)
        # plt.savefig(f'../plots/fig{step}.png', bbox_inches='tight')
clear_output(wait=True)

In [None]:
# Sanity check: token id(s) the character-level encoder assigns to 'c'.
encode('c')

In [None]:
# Generate from the model, starting from a (1, 1) context holding token id 0,
# i.e. the character chars[0] (the first entry of the sorted vocabulary).
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_ids = m.generate(context, max_new_tokens=100)[0].tolist()
print(decode(generated_ids))