In [None]:
import compyute as cp
import compyute.nn.functional as F
from simple_tokenizers import CharacterTokenizer

In [None]:
import requests

DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# Download the tiny-shakespeare corpus. The character vocabulary is rebuilt from it,
# so the token ids presumably have to match the ones the checkpoint was trained
# with — NOTE(review): confirm the training notebook built its vocab the same way.
response = requests.get(DATA_URL, timeout=30)  # never hang indefinitely on a stalled connection
response.raise_for_status()  # fail loudly instead of building a vocabulary from an HTTP error page
data = response.text

# Sorted unique characters -> stable, deterministic token ids.
chars = sorted(set(data))
tokenizer = CharacterTokenizer()
tokenizer.vocab = dict(enumerate(chars))  # token id -> character
tokenizer.ivocab = {c: i for i, c in enumerate(chars)}  # character -> token id

In [None]:
# Hyperparameters. These presumably have to agree with the values the checkpoint
# was trained with — NOTE(review): confirm against the training run.
block_size = 256  # longest token sequence the model is built to accept (max_context_len)
embed_dims = 384  # width of the token embeddings
vocab_size = tokenizer.vocab_size  # number of distinct characters known to the tokenizer
device = cp.cuda  # compyute device the model weights are placed on

In [None]:
from transformer.gpt import GPTTransformer
from transformer.attention_utils import get_causal_mask

# Causal attention mask sized for the full context window; presumably it stops each
# position from attending to later positions (TODO confirm in get_causal_mask).
mask = get_causal_mask(block_size)

# Build the architecture. Every value here must line up with the parameter shapes
# stored in the checkpoint that is loaded afterwards.
model = GPTTransformer(
    n_embeds=vocab_size,  # one embedding row per token id
    embed_dim=embed_dims,
    n_heads=6,
    n_blocks=6,
    mlp_channels=embed_dims * 4,  # MLP hidden width: 4x the embedding width
    max_context_len=block_size,
    mask=mask,
)

In [None]:
# Restore trained weights. The loaded object is indexed with "model", so the file is
# expected to hold the model's state dict under that key; it is loaded onto `device`.
# NOTE(review): if cp.load is pickle-based, only load checkpoint files from trusted sources — confirm.
state_dict = cp.load("transformer_shakespeare_8_2500.cp")
model.load_state_dict(state_dict["model"], target_device=device)

In [None]:
# Autoregressive sampling: feed the running context to the model, sample the next
# character from the top-k most likely tokens, print it and append it to the context.
prompt = "Hello, my name is"
n_new_tokens = 500
top_k = min(50, vocab_size)  # cp.topk cannot return more entries than the vocabulary has
assert prompt, "prompt must contain at least one character"

print(prompt, end="", flush=True)
prompt_ids = tokenizer.encode(prompt)  # list of token ids
# Insert a batch dim -> (1, T) and keep at most block_size tokens: the model was
# built with max_context_len=block_size.
context = cp.tensor(prompt_ids, device, cp.int32).view((1, -1))[:, -block_size:]

for _ in range(n_new_tokens):
    logits = model(context)[0, -1].to_cpu()  # logits of the last position
    probs = F.softmax(logits)  # convert to probs
    topk_probs, topk_indices = cp.topk(probs, top_k)  # keep the top-k probs
    topk_probs /= topk_probs.sum()  # renormalize over the kept tokens
    index = cp.random.multinomial(x=top_k, p=topk_probs, shape=(1,))  # sample a position within the top-k
    index = topk_indices[index]  # map back to a token id
    char = tokenizer.decode(index.to_list())
    print(char, end="", flush=True)  # flush so the text streams as it is generated
    index = index.view((1, 1)).to_device(device)
    # Grow the context by one token and only then crop to the last block_size tokens.
    # (Previously the oldest token was dropped every step, so the model only ever saw
    # a fixed window the length of the prompt.)
    context = cp.append(context, values=index, dim=1)[:, -block_size:].to_int()