In [None]:
import compyute as cp

In [None]:
# Prefer the GPU when compyute reports one is available; otherwise use the CPU.
device = "cpu"
if cp.engine.gpu_available():
    device = "cuda"
device

In [None]:
cp.engine.set_cuda_tf32(True)  # allow TF32 math on CUDA (presumably faster, lower-precision matmuls — confirm)
cp.random.set_seed(42)  # fixed seed so compyute's random operations are reproducible across runs

In [None]:
from datasets import load_dataset

# Load the WikiText-2 ("wikitext-2-v1") dataset from the Hugging Face Hub.
# The result is indexed by split name ("train", "validation", ...) further down.
dataset = load_dataset("Salesforce/wikitext", "wikitext-2-v1")
dataset

In [None]:
def get_training_corpus(split="train", batch_size=1000, data=None):
    """Yield the raw text of a dataset split in fixed-size batches (for tokenizer training).

    Args:
        split: Name of the split to iterate over (default ``"train"``).
        batch_size: Maximum number of text lines per yielded batch.
        data: Mapping of split name -> dataset. Defaults to the notebook-level
            ``dataset``. Each split must support ``len()`` and slicing that
            returns a dict with a ``"text"`` list (as Hugging Face datasets do).

    Yields:
        Lists of at most ``batch_size`` strings, covering the whole split in order.
    """
    if data is None:
        data = dataset
    texts = data[split]
    # Bound the loop by the length of the split itself. The length of the
    # DatasetDict is the number of splits, which would yield only the first batch.
    for start in range(0, len(texts), batch_size):
        yield texts[start : start + batch_size]["text"]

In [None]:
from tokenizers import (
    models,
    normalizers,
    pre_tokenizers,
    trainers,
    Tokenizer,
)
import os

file = "tokenizer.json"  # cache path for the trained tokenizer (delete it to force retraining)

if not os.path.exists(file):
    # No cached tokenizer: train a WordPiece tokenizer on the training corpus and cache it.
    tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]", max_input_chars_per_word=1000000000))
    tokenizer.normalizer = normalizers.Sequence([normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()])
    # NOTE(review): this pre-tokenizer is built but never assigned to
    # `tokenizer.pre_tokenizer`, so it has no effect. Without a pre-tokenizer each
    # whole line reaches WordPiece as a single "word" (presumably why
    # max_input_chars_per_word is set so high). Assign it if word-level
    # splitting was intended; note that this changes tokenization and decoding.
    pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.WhitespaceSplit(), pre_tokenizers.Punctuation()])
    special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
    # Empty continuing_subword_prefix: subword pieces carry no marker, so
    # decoded pieces can be printed back-to-back.
    trainer = trainers.WordPieceTrainer(vocab_size=8192, special_tokens=special_tokens, continuing_subword_prefix="")
    tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)
    tokenizer.save(file)  # use the same path the existence check uses
else:
    tokenizer = Tokenizer.from_file(file)

In [None]:
def encode(split):
    """Tokenize every line of a dataset split and concatenate the ids.

    Args:
        split: Split name to read from the notebook-level ``dataset``.

    Returns:
        A single flat compyute int tensor holding the token ids of all lines,
        in line order.
    """
    flat_ids = []
    for enc in tokenizer.encode_batch(dataset[split]["text"]):
        flat_ids.extend(enc.ids)
    return cp.tensor(flat_ids).to_int()

In [None]:
# One flat token-id tensor per split.
train_data_enc, val_data_enc = encode("train"), encode("validation")

In [None]:
len(train_data_enc)  # total number of training tokens after encoding

In [None]:
block_size = 256  # context length: tokens per training sequence and the model's sequence_length

In [None]:
def make_lm_blocks(token_ids, block_size):
    """Cut a flat token sequence into non-overlapping next-token-prediction blocks.

    Args:
        token_ids: 1-D tensor of token ids (e.g. the output of ``encode``).
        block_size: Number of tokens per block (the model's context length).

    Returns:
        ``(inputs, targets)``, each built by ``cp.stack`` from ``block_size``-long
        slices; ``targets`` is ``inputs`` shifted right by one token.
    """
    # Block i's targets end one token past its inputs, so only blocks whose
    # shifted window still fits are usable: (len - 1) // block_size.
    # Using len // block_size leaves the last target slice one token short
    # whenever len is an exact multiple of block_size, which breaks cp.stack.
    n_blocks = (len(token_ids) - 1) // block_size
    starts = [i * block_size for i in range(n_blocks)]
    inputs = cp.stack([token_ids[s : s + block_size] for s in starts])
    targets = cp.stack([token_ids[s + 1 : s + block_size + 1] for s in starts])
    return inputs, targets


X_train, y_train = make_lm_blocks(train_data_enc, block_size)
X_val, y_val = make_lm_blocks(val_data_enc, block_size)

print(f"{X_train.shape=}")
print(f"{y_train.shape=}")
print(f"{X_val.shape=}")
print(f"{y_val.shape=}")

In [None]:
import compyute.nn as nn
from transformer import Transformer, get_causal_mask

In [None]:
embed_dims = 384
# Causal mask from the local `transformer` module (presumably blocks attention
# to future positions — confirm in transformer.get_causal_mask).
mask = get_causal_mask((block_size, block_size))

# Hyper-parameters collected in one place, then unpacked into the constructor.
model_config = dict(
    n_embeddings=tokenizer.get_vocab_size(),
    embedding_dim=embed_dims,
    feedforward_channels=4 * embed_dims,
    n_heads=6,
    n_blocks=6,
    sequence_length=block_size,
    mask=mask,
    activation="gelu",
)
model = Transformer(**model_config)

model.to_device(device)

In [None]:
# Summarize the model for one sequence of `block_size` int32 token ids
# (input_shape presumably excludes the batch dimension — confirm).
summary = cp.nn.utils.get_module_summary(model, input_shape=(block_size,), input_dtype=cp.int32)
print(summary)

In [None]:
batch_size = 64  # effective batch size per optimizer step
micro_batch_size = 16  # samples per forward/backward pass
# Integer division would silently shrink the effective batch if these don't divide evenly.
assert batch_size % micro_batch_size == 0, "batch_size must be a multiple of micro_batch_size"
grad_accumulation_steps = batch_size // micro_batch_size  # micro-batches per optimizer step

train_dl = nn.utils.Dataloader(X_train, y_train, micro_batch_size, device=device)
val_dl = nn.utils.Dataloader(X_val, y_val, micro_batch_size, device=device)
loss_func = nn.CrossEntropy()
optim = nn.optimizers.AdamW(model.get_parameters(), lr=6e-4)
# NOTE(review): warmup_steps + decay_steps = 1250, while the training cell runs
# max_iter = 2500 optimizer steps; confirm the schedule length is intended.
scheduler = nn.utils.lr_schedulers.CosineLrScheduler(optim, target_lr=3e-5, warmup_steps=125, decay_steps=1125)

In [None]:
val_interval = 50  # run validation every `val_interval` optimizer steps
# Save a checkpoint every `checkpoint_interal` optimizer steps.
# (sic: misspelling of "interval", kept because the training cell reads this name)
checkpoint_interal = 250
step = 0  # global optimizer-step counter; lives in its own cell so re-running the training cell continues from here

In [None]:
from datetime import datetime
import os

from compyute.nn.utils.tensorboard import SummaryWriter

max_iter = 2500  # total number of optimizer steps to train for

# Create a fresh, timestamped TensorBoard log directory for this run.
label = "transformer_wikitext_2_v2"
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logdir = f"./runs/{label}_{timestamp}/"
os.makedirs(logdir, exist_ok=True)

writer = SummaryWriter(log_dir=logdir)
# NOTE(review): the writer is never flushed/closed here; call its close/flush
# method after training if SummaryWriter provides one.

# Accumulators live outside the epoch loop. Gradients are not reset at epoch
# boundaries, so a partial accumulation must be completed by the next epoch's
# micro-batches rather than silently folded into an oversized step with its
# loss discarded. Start from clean gradients in case a previous run of this
# cell was interrupted mid-accumulation.
loss = 0
accum_step = 0
optim.reset_grads()

while step < max_iter:
    for x, y in train_dl():
        accum_step += 1

        # training: forward/backward on one micro-batch; gradients accumulate
        with model.train():
            y_pred = model(x)
            loss += loss_func(y_pred, y).item() / grad_accumulation_steps

            # scale by grad accumulation steps so the summed gradient matches one full batch
            loss_grads = loss_func.backward() / grad_accumulation_steps
            model.backward(loss_grads)  # compute new gradients

        if accum_step < grad_accumulation_steps:
            continue  # keep accumulating before taking an optimizer step

        # one optimizer step per `grad_accumulation_steps` micro-batches
        grad_norm = cp.nn.utils.clip_grad_norm(model.get_parameters(), 1.0)  # clip gradients
        scheduler.step()
        optim.step()  # update parameters
        optim.reset_grads()  # reset all gradients

        writer.add_scalar("train/loss", loss, step)
        writer.add_scalar("train/grad_norm", grad_norm, step)
        writer.add_scalar("train/lr", optim.lr, step)
        accum_step = 0
        step += 1
        loss = 0

        # Validation and checkpointing only run right after an optimizer step, so
        # each fires once per interval (not once per micro-batch while `step` is unchanged).
        if step % val_interval == 0:
            val_loss = 0
            # distinct names so the global X_val / y_val tensors are not overwritten
            for x_val_batch, y_val_batch in val_dl():
                y_val_pred = model(x_val_batch)
                val_loss += loss_func(y_val_pred, y_val_batch).item()
            val_loss /= len(val_dl)
            writer.add_scalar("val/loss", val_loss, step)

        if step % checkpoint_interal == 0:
            model_state = model.get_state_dict()
            optim_state = optim.get_state_dict()
            cp.save({"model": model_state, "optim": optim_state, "step": step}, f"{label}_{step}.cp")

        if step == max_iter:
            break

In [None]:
# Final checkpoint: model and optimizer state plus the step reached
# (presumably so the run can be resumed or the model reloaded).
state = dict(
    model=model.get_state_dict(),
    optim=optim.get_state_dict(),
    step=step,
)
cp.save(state, f"{label}_{step}.cp")

In [None]:
prompt = "The most common disease is"
max_new_tokens = 300  # number of tokens to sample
top_k = 50  # sample only among the k most likely next tokens
print(prompt, end="")

# Encode the prompt into a (1, n_tokens) int32 batch on the model's device.
context = cp.tensor(tokenizer.encode(prompt).ids, dtype=cp.int32)
context = context.to_shape((1, -1)).to_device(model.device)

for _ in range(max_new_tokens):
    logits = model(context)[0, -1].to_cpu()  # logits for the last position
    probs, _ = cp.nn.functional.softmax(logits)  # convert to probs
    topk_probs, topk_indices = cp.topk(probs, top_k)
    topk_probs /= cp.sum(topk_probs)  # renormalize over the top-k subset
    index = cp.random.multinomial(x=top_k, p=topk_probs, shape=(1,))  # sample a position within the top-k
    index = topk_indices[index]  # map back to a vocabulary token id
    token_text = tokenizer.decode([index.item()])
    print(token_text, end="")
    # `index` lives on the CPU, so append on the CPU to keep both operands on one
    # device, then crop to the context window and move back to the model's device.
    context = cp.append(context.to_cpu(), values=cp.reshape(index, shape=(1, 1)), axis=1).to_int()
    context = context[:, -block_size:].to_device(model.device)