In [None]:
import compyute as cp

In [None]:
from tokenizers import Tokenizer

# Deserialize the trained tokenizer from its JSON file one directory up;
# it is used below both to encode the prompt and to decode sampled token ids.
tokenizer = Tokenizer.from_file(path="../tokenizer.json")

In [None]:
# Architecture hyperparameters. They presumably have to match the shapes
# stored in the checkpoint that is loaded further down — TODO confirm.
vocab_size = 2**13  # 8192 rows in the token-embedding table
block_size = 2**10  # 1024: max_seq_len of the model and the context window used when sampling
embed_dims = 384  # embedding width
device = cp.cuda  # compyute's CUDA device; the model and context tensors are moved here

In [None]:
from gpt_transformer import GPTTransformer
from attention_s import get_causal_mask

# Square (block_size x block_size) causal mask, sized for the longest
# sequence the model is built to accept.
mask = get_causal_mask((block_size,) * 2)

model = GPTTransformer(
    max_seq_len=block_size,
    n_embeddings=vocab_size,
    embedding_dim=embed_dims,
    n_blocks=6,
    n_heads=6,
    ffwd_channels=embed_dims * 4,  # presumably the feed-forward hidden width (4x embedding)
    mask=mask,
)
model.to_device(device)

# Restore trained weights; the checkpoint keeps the model's state under the "model" key.
state_dict = cp.load("../transformer_wikitext_1_2500.cp")
model.load_state_dict(state_dict["model"])

In [None]:
# Autoregressive top-k sampling: feed the running context through the model,
# sample the next token among the `top_k` most likely candidates, print it,
# append it to the context, and repeat.
# NOTE(review): no random seed is set, so the generated text differs between runs.
max_new_tokens = 300  # number of tokens to generate
top_k = 50  # size of the candidate pool to sample from (used by both topk and multinomial)

prompt = "Hi, I am "
print(prompt, end="")

context = cp.tensor(tokenizer.encode(prompt).ids, dtype=cp.int32)  # encode prompt
context = context.to_shape((1, -1)).to_device(model.device)  # add batch dim -> (1, T)

for _ in range(max_new_tokens):
    logits = model(context)[0, -1].to_cpu()  # logits at the last position of batch item 0
    probs = cp.nn.functional.softmax(logits)  # convert to probs
    topk_probs, topk_indices = cp.topk(probs, top_k)  # keep the top_k candidates
    topk_probs /= cp.sum(topk_probs)  # renormalize so the candidates sum to 1
    index = cp.random.multinomial(x=top_k, p=topk_probs, shape=(1,))  # position within the top-k
    index = topk_indices[index]  # map back to a vocabulary token id
    char = tokenizer.decode([index.item()])
    print(char, end="")
    # `index` lives on the CPU (the logits were moved there with .to_cpu()), while
    # `context` lives on `device`; move the new token over before appending so
    # both operands of cp.append share a device.
    next_token = cp.reshape(index, shape=(1, 1)).to_device(device)
    context = cp.append(context, values=next_token, axis=1).to_int()  # append to context
    context = context[:, -block_size:]  # never exceed the model's max_seq_len