Given a start text, let's list the n most probable next tokens, along with their probabilities.

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from gptbench import Sample, empty_config

In [2]:
ben = Sample(seed=0xcabc0ffee)

# set config settings
cfg = empty_config()
cfg.model.set(dtype='bfloat16') # halve the memory requirements

ben.init_pretrained('gpt2', cfg) # 'gpt2' or 'gpt2-xl' if your GPU can handle it

Initializing model from gpt2
Dataset: dummy 0 tokens
Dataset: loading uint16 tokens
Expanding initial dataset size of 1 (less than block_size+1) by 1025 times to size of 1025
Dataset train_path: dummy empty dataset, val_path: None, train_split: 0.9, vocab_size: 50257
Model params: 124.44M


In [3]:
# What are the most probable tokens after 'The sky is'
ben.model_next(10, text='The sky is')

[(0.1865234375, ' the', 262),
 (0.11328125, ' blue', 4171),
 (0.06884765625, ' falling', 7463),
 (0.041748046875, ' full', 1336),
 (0.041748046875, ' a', 257),
 (0.0252685546875, ' not', 407),
 (0.0252685546875, ' clear', 1598),
 (0.0252685546875, ' dark', 3223),
 (0.01531982421875, ' black', 2042),
 (0.00927734375, ' so', 523)]

Note that most words were tokenized to start with a space character, hence ' blue' in the returned values.

Each returned tuple is (probability, token_text, token_id).
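As a rough illustration of where such a list comes from, the sketch below turns raw scores (logits) into probabilities with a softmax and keeps the n largest. It is pure Python over a toy vocabulary: the helper name `top_n`, the vocabulary and the logit values are all made up for illustration and are not part of gptbench or real GPT-2 output.

```python
import math

def top_n(logits, vocab, n):
    """Return the n most probable (probability, token_text, token_id) tuples."""
    # Softmax, subtracting the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token ids by descending probability and keep the first n
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(probs[i], vocab[i], i) for i in ranked[:n]]

# Toy vocabulary and made-up logits (hypothetical values)
toy_vocab = [' the', ' blue', ' falling', ' yellow']
toy_logits = [2.0, 1.5, 1.0, -3.0]

for prob, token_text, token_id in top_n(toy_logits, toy_vocab, 3):
    print(f"{prob:.3f} {token_text!r} {token_id}")
```

The tuple layout mirrors the one returned by `model_next` above, so the toy result can be read the same way.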

Let's build a function to follow and list the top 3 probabilities along n steps/tokens:

In [4]:
def follow(text, steps):
    for s in range(steps):
        gen = ben.model_next(3, text=text)
        print(f"============================== Step {s+1}")
        print(f"'{text}' -> ", end='')
        for o in (gen):
            print(f"'{o[1]}' ({o[0]*100:.2f}%), ", end='')
        print()

        text += gen[0][1]

follow('The sky is', 5)

'The sky is' -> ' the' (18.65%), ' blue' (11.33%), ' falling' (6.88%), 
'The sky is the' -> ' limit' (96.48%), ' Limit' (0.39%), ' only' (0.39%), 
'The sky is the limit' -> '.' (20.80%), ',' (12.60%), '!' (7.67%), 
'The sky is the limit.' -> '
' (22.56%), ' The' (5.03%), ' I' (3.05%), 
'The sky is the limit.
' -> '
' (98.83%), 'The' (0.15%), 'I' (0.05%), 


Using the next-token probabilities, we can score a set of candidate continuations and pick the most probable one. For example:

'The sky is' ->

- ' blue'

- ' yellow'

By checking which token has the highest probability, we can choose the winner:

In [5]:
def choose(text, options):
    print(f"Prompt: '{text}'")
    tokens = []
    for o in options:
        enc = ben.train_dataset.encode(o)
        assert len(enc) == 1, f"Only single token options: '{o}' has {len(enc)} tokens"
        tokens.append(enc[0])
        
    probs = ben.model_probs(text=text)

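    # start below any valid probability so best_text is always assigned,
    # even if every option has probability 0
    best_prob = -1.0
    best_text = None
    for option_text, t in zip(options, tokens):
        prob = probs[t].item()
        print(f"- '{option_text}': prob={prob*100:.1f}%")

        if prob > best_prob:
            best_prob = prob
            best_text = option_text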

    print(f"Best choice: '{best_text}'")
    print('-->', text + best_text)

text = 'The sky is'
# note that we're prefixing with a space character:
options = [' blue', ' yellow']

choose(text, options)

Prompt: 'The sky is'
- ' blue': prob=11.3%
- ' yellow': prob=0.1%
Best choice: ' blue'
--> The sky is blue


In [6]:
text = 'The capital of Portugal is'
# note that we're prefixing with a space character:
options = [' Lisbon', ' Cuba', ' Paris', ' white']

choose(text, options)

Prompt: 'The capital of Portugal is'
- ' Lisbon': prob=9.0%
- ' Cuba': prob=0.0%
- ' Paris': prob=0.1%
- ' white': prob=0.0%
Best choice: ' Lisbon'
--> The capital of Portugal is Lisbon


In [7]:
text = '1+1='
# no leading space here: the answer directly follows the '=' character
options = ['2', '3', '7']

choose(text, options)

Prompt: '1+1='
- '2': prob=3.4%
- '3': prob=2.1%
- '7': prob=0.5%
Best choice: '2'
--> 1+1=2


What if options encode into multiple tokens? In that case, we could score each option either by the mean of its token probabilities or by the product of the successive token probabilities.

See the ../prompting/winogrande notebook for an example.
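To make the two scoring ideas concrete, here is a minimal sketch in pure Python. It assumes the per-token probabilities of each option were already obtained, one model call per token. The function names `product_score` and `mean_score`, the option texts and the probability values are all hypothetical, and every probability is assumed to be greater than zero.

```python
import math

def product_score(token_probs):
    """Joint probability of the option: product of its token probabilities."""
    # Sum logs rather than multiplying directly, to limit underflow
    return math.exp(sum(math.log(p) for p in token_probs))

def mean_score(token_probs):
    """Arithmetic mean of the option's token probabilities."""
    return sum(token_probs) / len(token_probs)

# Made-up per-token probabilities (hypothetical, not model output)
toy_options = {
    ' Lisbon': [0.09],          # single-token option
    ' New York': [0.04, 0.90],  # two-token option
}

for option_text, token_probs in toy_options.items():
    print(f"'{option_text}': product={product_score(token_probs):.4f}, "
          f"mean={mean_score(token_probs):.4f}")

best_by_product = max(toy_options, key=lambda o: product_score(toy_options[o]))
best_by_mean = max(toy_options, key=lambda o: mean_score(toy_options[o]))
print("Best by product:", repr(best_by_product))
print("Best by mean:", repr(best_by_mean))
```

With these toy numbers the two scores disagree: the product penalizes longer options, because every extra token multiplies in a value below 1, while the mean can be lifted by a single very likely token. Which one suits a task is a judgment call.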