In [None]:
import compyute as cp

In [None]:
# Run on the GPU when compyute reports one is available, otherwise on the CPU.
if cp.engine.gpu_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Seed compyute's random number generator with a fixed value.
cp.random.set_seed(1337)

In [None]:
# Load the whole training corpus into memory as one string.
# The encoding is pinned so the characters read (and therefore the vocabulary
# built from them) do not depend on the platform's default text encoding.
with open("../data/tinyshakespeare.txt", "r", encoding="utf-8") as f:
    data = f.read()

In [None]:
from compyute.preprocessing.text import CharacterTokenizer

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(data))

tokenizer = CharacterTokenizer()
# `vocab` maps index -> character, `ivocab` maps character -> index.
tokenizer.vocab = dict(enumerate(chars))
tokenizer.ivocab = dict(zip(chars, range(len(chars))))
tokenizer.vocab_size

In [None]:
# Encode the full corpus with the tokenizer and display the resulting length.
data_enc = tokenizer.encode(data)
len(data_enc)

In [None]:
# Length, in tokens, of each training sequence (also the model's sequence_length).
block_size = 256

In [None]:
# Cut the encoded corpus into non-overlapping chunks of `block_size` tokens.
# The target of each chunk is the same chunk shifted right by one token, so the
# last chunk needs one token beyond its own end. Using `len(data_enc) - 1` keeps
# the final target full-length even when the corpus length is an exact multiple
# of `block_size` (otherwise `cp.stack` would receive a shorter last row).
n_chunks = (len(data_enc) - 1) // block_size

X = cp.stack([data_enc[i * block_size : (i + 1) * block_size] for i in range(n_chunks)])
y = cp.stack([data_enc[i * block_size + 1 : (i + 1) * block_size + 1] for i in range(n_chunks)])

# Convert once, then split: first 90% of the chunks for training, the rest for validation.
X_int = X.to_int()
y_int = y.to_int()

n = int(len(X) * 0.9)
X_train, X_val = X_int[:n], X_int[n:]
y_train, y_val = y_int[:n], y_int[n:]

print(f"{X_train.shape=}")
print(f"{y_train.shape=}")
print(f"{X_val.shape=}")
print(f"{y_val.shape=}")

In [None]:
import compyute.nn as nn
from transformer import Transformer, get_causal_mask

In [None]:
embed_dims = 384

# Causal mask covering the full (block_size x block_size) context window.
mask = get_causal_mask((block_size, block_size))

# Model hyperparameters, gathered in one place before construction.
model_kwargs = dict(
    n_embeddings=tokenizer.vocab_size,
    embedding_dim=embed_dims,
    feedforward_channels=embed_dims * 4,
    n_heads=6,
    n_blocks=6,
    sequence_length=block_size,
    mask=mask,
)
model = Transformer(**model_kwargs)

# Move the model to the selected device.
model.to_device(device)

In [None]:
# Print the module summary compyute generates for the model, given an int32
# input of shape (block_size,).
summary = cp.nn.utils.get_module_summary(
    model,
    input_shape=(block_size,),
    input_dtype=cp.int32,
)
print(summary)

In [None]:
# Batching: gradients of several micro-batches are accumulated per optimizer
# step, so that each step covers `batch_size` sequences.
batch_size = 64
micro_batch_size = 32
grad_accumulation_steps = batch_size // micro_batch_size

# Schedule, counted in optimizer steps.
max_iter = 5000
val_interval = 200
checkpoint_interal = 500  # (sic) name is referenced by the training cell
step = 0

# Data loaders yield micro-batches on the selected device.
train_dl = nn.utils.Dataloader(X_train, y_train, micro_batch_size, device=device)
val_dl = nn.utils.Dataloader(X_val, y_val, micro_batch_size, device=device)

# Loss and optimizer.
loss_func = nn.CrossEntropy()
optim = nn.optimizers.AdamW(model.get_parameters(), lr=3e-4)

In [None]:
from datetime import datetime
import os

from compyute.nn.utils.tensorboard import SummaryWriter

# create tensorboard logging directory
label = "transformer_shakespeare"
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logdir = f"./runs/{label}_{timestamp}/"
os.makedirs(logdir, exist_ok=True)  # no-op when the directory already exists

writer = SummaryWriter(log_dir=logdir)
loss = 0  # loss of the current optimizer step, averaged over its micro-batches
accum_step = 0  # micro-batches accumulated since the last optimizer step

# NOTE(review): `step == max_iter` is checked before the increment, so once
# `step` reaches `max_iter` mid-epoch one more update (with logging, validation
# and checkpoint for step `max_iter`) is performed before the break. If the
# epoch ends exactly when `step` reaches `max_iter`, the `while` condition
# stops the run without that final update. Confirm which is intended.
while step < max_iter:
    # Batch variables use their own names so the dataset tensors `y` and
    # `y_val` created in the data-split cell are not overwritten by a single
    # batch (which would break re-running the dataloader cell).
    for x_batch, y_batch in train_dl():
        accum_step += 1

        # training
        with model.train():
            # forward pass
            y_pred = model(x_batch)
            loss += loss_func(y_pred, y_batch).item() / grad_accumulation_steps

            # backward pass
            loss_grads = loss_func.backward() / grad_accumulation_steps  # scale by grad accumulation steps
            model.backward(loss_grads)  # compute new gradients

        if accum_step == grad_accumulation_steps:
            optim.step()  # update parameters
            optim.reset_grads()  # reset all gradients

            writer.add_scalar("train/loss", loss, step)

            # validation
            if step > 1 and step % val_interval == 0:
                val_loss = 0
                for x_val_batch, y_val_batch in val_dl():
                    y_pred = model(x_val_batch)
                    val_loss += loss_func(y_pred, y_val_batch).item()
                val_loss /= len(val_dl)
                writer.add_scalar("val/loss", val_loss, step)

            # save checkpoints
            if step > 1 and step % checkpoint_interal == 0:
                model_state = model.get_state_dict()
                optim_state = optim.get_state_dict()
                state_dict = {"model": model_state, "optim": optim_state}
                checkpoint_name = f"{label}_{step}.cp"
                cp.save(state_dict, checkpoint_name)

            if step == max_iter:
                break
            step += 1
            loss = accum_step = 0  # start accumulating the next optimizer step

In [None]:
# Sampling settings. `top_k` is used both for selecting candidates and for the
# number of categories sampled from, so it is defined once to keep them in sync.
top_k = 50
n_new_tokens = 500

output = ""
prompt = "\n"
print(prompt, end="")

context = cp.tensor(tokenizer.encode(prompt), dtype=cp.int32)  # encode prompt
context = context.to_shape((1, -1)).to_device(model.device)  # add batch dimension

for _ in range(n_new_tokens):
    logits = model(context)[0, -1].to_cpu()  # logits at the last position
    probs, _ = cp.nn.functional.softmax(logits)  # convert to probs
    topk_probs, topk_indices = cp.topk(probs, top_k)  # keep the top_k most likely tokens
    topk_probs /= cp.sum(topk_probs)  # renormalize over the kept tokens
    index = cp.random.multinomial(x=top_k, p=topk_probs, shape=(1,))  # sample a position within the top_k
    index = topk_indices[index]  # map the position back to a token id
    char = tokenizer.decode(index)
    print(char, end="")
    output += char
    context = cp.append(context, values=cp.reshape(index, shape=(1, 1)), axis=1).to_int()  # append to context
    context = context[:, -block_size:].to_device(device)  # keep at most block_size tokens

# Encoding is pinned so the saved text does not depend on the platform default.
with open("transformer_shakespeare.txt", "w", encoding="utf-8") as f:
    f.write(output)