In [None]:
import compyute as cp

In [None]:
# Device handle passed to model.to_device(...) and to the training Dataloader below.
# NOTE(review): presumably requires a CUDA-capable GPU to be available — confirm,
# or point this at a CPU device when running without one.
device = cp.cuda

In [None]:
# Read the whole training corpus into memory as a single string.
# The encoding is pinned explicitly: without it, open() falls back to the
# platform/locale default, so the same file could decode differently (or fail)
# on another machine.
with open("../data/tinyshakespeare.txt", "r", encoding="utf-8") as f:
    data = f.read()

In [None]:
from compyute.preprocessing.text import CharacterTokenizer

# Character-level vocabulary: every distinct character in the corpus, in sorted order
# so that the index assignment is deterministic across runs.
chars = sorted(set(data))

# The vocabulary tables are assigned by hand rather than fitted:
#   vocab  : index -> character
#   ivocab : character -> index
# NOTE(review): this direction is exactly what the notebook assigns; confirm it matches
# what CharacterTokenizer.encode/decode expect in the installed compyute version.
tokenizer = CharacterTokenizer()
tokenizer.vocab = dict(enumerate(chars))
tokenizer.ivocab = {char: index for index, char in tokenizer.vocab.items()}

# Encode the full corpus once; the training windows below are slices of this.
data_enc = tokenizer.encode(data)

In [None]:
# Length (in tokens) of each training sequence. Also used as the model's
# sequence_length and as the side length of the causal mask.
block_size = 256

In [None]:
# Split the encoded corpus into non-overlapping windows of block_size tokens.
# Targets are the same windows shifted right by one token (next-token prediction).
#
# One token is held back when counting windows: the last target window ends at
# index start + block_size, so it needs one token beyond the last input window.
# Counting with len(data_enc) // block_size would produce a target window that is
# one token short whenever the corpus length is an exact multiple of block_size,
# which makes the stacked targets ragged.
n_sequences = (len(data_enc) - 1) // block_size
offsets = [i * block_size for i in range(n_sequences)]

X = cp.stack([data_enc[start : start + block_size] for start in offsets])
y = cp.stack([data_enc[start + 1 : start + block_size + 1] for start in offsets])

X_train = X.to_int()
y_train = y.to_int()

print(f"{X_train.shape=}")
print(f"{y_train.shape=}")

In [None]:
import compyute.nn as nn
from transformer import Transformer, get_causal_mask

In [None]:
embed_dims = 384

# Mask with one row and one column per position in a training sequence.
# NOTE(review): get_causal_mask comes from the local `transformer` module; presumably
# it blocks attention to future positions — confirm there.
mask = get_causal_mask((block_size, block_size))

# Model hyperparameters collected in one place so they are easy to read and tune.
transformer_config = dict(
    n_embeddings=tokenizer.vocab_size,  # one embedding per vocabulary entry
    embedding_dim=embed_dims,
    feedforward_channels=4 * embed_dims,
    n_heads=6,
    n_blocks=6,
    sequence_length=block_size,
    mask=mask,
)
model = Transformer(**transformer_config)

# Kept as the cell's final expression so its return value is still displayed.
model.to_device(device)

In [None]:
# Build and print a summary of the model for a single input of shape (block_size,)
# with dtype int32, i.e. one sequence of token ids.
summary = cp.nn.utils.get_module_summary(model, input_shape=(block_size,), input_dtype=cp.int32)
print(summary)

In [None]:
# --- Training hyperparameters ---
batch_size = 32

val_interval = 200
max_iter = 5000  # total number of optimisation steps performed by the training loop
# NOTE(review): "checkpoint_interal" is a misspelling of "checkpoint_interval". The name
# is kept as-is because cells outside this view may reference it; rename everywhere at once.
checkpoint_interal = 500

# --- Training objects ---
# Batches are served from X_train / y_train on the selected device.
train_dl = nn.utils.Dataloader(X_train, y_train, batch_size, device=device)
loss_func = nn.CrossEntropy()
optim = nn.optimizers.AdamW(
    model.get_parameters(),
    lr=3e-4,
    beta1=0.9,
    beta2=0.95,
)

In [None]:
import time

# Train for exactly max_iter optimisation steps, restarting the dataloader whenever
# it is exhausted.
#
# An explicit "done" flag replaces the former `while step < max_iter` test. That test
# skipped the final step whenever a pass over the dataloader happened to end just as
# step reached max_iter (and ran zero steps for max_iter == 1). After the loop, step
# equals max_iter, as before.
step = 1
done = step > max_iter  # a non-positive max_iter performs no training at all
while not done:
    for x, y in train_dl():
        # perf_counter is monotonic and high-resolution, so it is the appropriate
        # clock for measuring elapsed intervals (time.time can jump with clock changes).
        start = time.perf_counter()

        with model.train():
            # Forward pass and loss value, then backpropagate the loss gradient.
            loss = loss_func(model(x), y).item()
            model.backward(loss_func.backward())

        optim.step()  # update parameters
        optim.reset_grads()  # reset all gradients

        dt = time.perf_counter() - start
        print(f"step {step:4} | loss {loss:.4f} | dt {dt:.4f} s")

        if step == max_iter:
            done = True
            break
        step += 1