In [3]:
import torch

Generating text from output tokens

Let's implement the token-generation process as follows:
* step 1: idx is a (batch, n_tokens) array of indices in the current context
* step 2: crop the current context if it exceeds the supported context size (e.g., if the LLM supports only 5 tokens and the current context has 10 tokens, only the last 5 tokens are used as context)
* step 3: focus only on the last time step, so that the logits of shape (batch, n_tokens, vocab_size) become (batch, vocab_size)
* step 4: apply softmax to turn the logits into probabilities; probas has shape (batch, vocab_size)
* step 5: select the index with the highest probability (greedy decoding); idx_next has shape (batch, 1)
* step 6: append the selected index to the running sequence, after which idx has shape (batch, n_tokens + 1)

In [4]:
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedily extend a batch of token sequences with a language model.

    At every step the running context is cropped to its last ``context_size``
    tokens and passed through ``model``. The single most likely next token
    (argmax over the distribution at the last position) is then appended.

    Args:
        model: Callable mapping an index tensor of shape (batch, n_tokens) to
            logits of shape (batch, n_tokens, vocab_size). It is called under
            ``torch.no_grad()``. This function does not change the model's
            train/eval mode; call ``model.eval()`` beforehand if the model
            uses dropout and deterministic output is wanted.
        idx: Integer tensor of shape (batch, n_tokens) holding the current
            context.
        max_new_tokens: Number of tokens to generate.
        context_size: Maximum number of trailing tokens fed to the model at
            each step. Must be positive.

    Returns:
        Tensor of shape (batch, n_tokens + max_new_tokens): the original
        context followed by the generated tokens.

    Raises:
        ValueError: If ``context_size`` is not positive.
    """
    if context_size <= 0:
        # idx[:, -0:] would silently keep the whole context, and a negative
        # value would drop tokens from the front instead of cropping.
        raise ValueError(f"context_size must be positive, got {context_size}")

    with torch.no_grad():  # inference only: no autograd graph is needed
        for _ in range(max_new_tokens):
            # Crop the context to the last context_size tokens the model supports.
            idx_cond = idx[:, -context_size:]
            logits = model(idx_cond)  # (batch, n_tokens, vocab_size)
            # Only the prediction at the last position determines the next token.
            logits = logits[:, -1, :]  # (batch, vocab_size)
            probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)
            # Greedy decoding: take the most probable token. softmax is
            # monotonic, so this matches argmax over the logits; it is kept
            # to mirror the step-by-step description.
            idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)
            # Append the chosen index to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens + 1)
    return idx