In [None]:
import compyute as cp

In [None]:
from tokenizers import Tokenizer

# WikiText-103 tokenizer (8192-token vocabulary, per the file name).
tokenizer_file = "experiments/wikitext/wikitext_103_tokenizer_8192.json"
tokenizer = Tokenizer.from_file(tokenizer_file)

In [None]:
# Inference configuration; presumably must match the settings the checkpoint
# and tokenizer were trained with (confirm against the training run).
device = cp.cuda  # run the model on the GPU
vocab_size = 8_192  # number of token ids / rows of the embedding table
block_size = 1_024  # longest token context the model is built to accept
embed_dims = 384  # embedding width shared by all transformer blocks

In [None]:
from transformer.gpt import GPTTransformer
from transformer.attention_funcs import get_causal_mask

# Every id the tokenizer can emit must index into the model's embedding table
# (n_embeddings=vocab_size); fail fast instead of erroring mid-generation.
assert tokenizer.get_vocab_size() <= vocab_size, (
    f"tokenizer has {tokenizer.get_vocab_size()} tokens but the model only embeds {vocab_size}"
)

# Square causal mask covering the full context window.
mask = get_causal_mask((block_size, block_size))

model = GPTTransformer(
    n_embeddings=vocab_size,
    embedding_dim=embed_dims,
    ffwd_channels=4 * embed_dims,  # feed-forward hidden width: 4x the embedding width
    n_heads=6,  # attention heads per block (384 / 6 = 64 dims per head)
    n_blocks=6,  # number of stacked transformer blocks
    max_seq_len=block_size,
    mask=mask,
)
# NOTE(review): parameters are moved to `device` before the checkpoint is loaded;
# confirm load_state_dict keeps them on that device (order left unchanged on purpose).
model.to_device(device)

# NOTE(review): cp.load may deserialize via pickle — only load trusted checkpoint files.
state_dict = cp.load("experiments/wikitext/transformer_wikitext_103_1_7500.cp")
model.load_state_dict(state_dict["model"])  # checkpoint stores the model weights under "model"

In [None]:
prompt = "America was discovered by Christoph Columbus in the year "
max_new_tokens = 50  # how many tokens to sample after the prompt

# NOTE(review): sampling is unseeded, so every run prints different text.
# NOTE(review): confirm the model is in inference mode here (e.g. dropout off);
# the relevant compyute switch is not visible in this notebook.

print(prompt, end="")

token_ids = tokenizer.encode(prompt).ids  # prompt ids followed by sampled ids
generated_ids = []  # sampled ids only, decoded together for printing
printed_text = ""  # part of the decoded generation already written to stdout

for _ in range(max_new_tokens):
    # Crop to the most recent `block_size` tokens *before* the forward pass:
    # the model was built with max_seq_len=block_size, and this also covers prompts
    # that are longer than the window. Rebuilding one int32 tensor on the model's
    # device keeps dtype and device consistent on every step.
    context = cp.tensor(token_ids[-block_size:], dtype=cp.int32).view((1, -1)).to_device(model.device)

    logits = model(context)[0, -1].to_cpu()  # logits at the last position
    probs = cp.nn.functional.softmax(logits)  # convert to probabilities
    index = cp.random.multinomial(x=len(probs), p=probs, shape=(1,))  # sample one token id
    next_id = index.item()

    token_ids.append(next_id)
    generated_ids.append(next_id)

    # Decode the whole generation and print only the new suffix. Decoding tokens one
    # at a time can lose spacing or split a multi-byte character across tokens.
    # A trailing U+FFFD means the last character is still incomplete, so wait for
    # the next token before printing.
    text = tokenizer.decode(generated_ids)
    if text.endswith("\ufffd"):
        continue
    print(text[len(printed_text):], end="")
    printed_text = text

# Flush anything held back above and end the line.
print(tokenizer.decode(generated_ids)[len(printed_text):])