First, we import the *transformers* library and load a pretrained GPT-Neo model. The result is a model that is an instance of *torch.nn.Module*.

In [1]:
import transformers
import torch
# Hugging Face Hub name of the checkpoint; the tokenizer cell below reuses it
model_name = "EleutherAI/gpt-neo-1.3B"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
# Module tree, confirmation that this is a plain PyTorch module, and the config - one per line
print(
    model,
    isinstance(model, torch.nn.Module),
    model.config,
    sep = "\n",
)

  from .autonotebook import tqdm as notebook_tqdm


GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 2048)
    (wpe): Embedding(2048, 2048)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPTNeoBlock(
        (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
        )
        (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=2048, out_features=8192, bias=True)
          (c_proj):

Next we need a matching tokenizer, which will also act as the encoder. We can create one using *AutoTokenizer.from_pretrained*. We will see that, apart from the vocabulary and the methods to encode and decode, the tokenizer also contains information on special tokens, which in general should match the special tokens in the model configuration *model.config*.

In [2]:
# Load the tokenizer from the same checkpoint name as the model
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name
)
print(tokenizer)

GPT2TokenizerFast(name_or_path='EleutherAI/gpt-neo-1.3B', vocab_size=50257, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})


In [3]:
# Vocabulary size followed by each special token and its ID, one per line
print(
    f"Vocabulary size: {tokenizer.vocab_size}",
    f"UNK token: {tokenizer.unk_token} (ID = {tokenizer.unk_token_id})",
    f"BOS token: {tokenizer.bos_token} (ID = {tokenizer.bos_token_id})",
    f"EOS token: {tokenizer.eos_token} (ID = {tokenizer.eos_token_id})",
    sep = "\n",
)

Vocabulary size: 50257
UNK token: <|endoftext|> (ID = 50256)
BOS token: <|endoftext|> (ID = 50256)
EOS token: <|endoftext|> (ID = 50256)


In [4]:
test = "This is a short test"
ids = tokenizer.encode(test)
print(ids)
# Named `token_id` rather than `id`: a notebook-level loop variable called `id`
# would shadow the builtin id() for the rest of the session
for token_id in ids:
    print(f"Token ID {token_id} is {tokenizer.convert_ids_to_tokens(token_id)}")
print(f"Result of decoding: {tokenizer.decode(ids)}")
# Calling the tokenizer directly also returns the attention mask
print(tokenizer(test))

[1212, 318, 257, 1790, 1332]
Token ID 1212 is This
Token ID 318 is Ġis
Token ID 257 is Ġa
Token ID 1790 is Ġshort
Token ID 1332 is Ġtest
Result of decoding: This is a short test
{'input_ids': [1212, 318, 257, 1790, 1332], 'attention_mask': [1, 1, 1, 1, 1]}


Knowing how to encode and decode, we can now write a simple function to continue a prompt, as we did for our own transformer trained on the Wikipedia data set. The only difference we have to observe is that the output of the model is not simply the logits, but a dictionary containing the logits and the past key values (we will get to these later).

In [5]:
out = model(torch.tensor(ids))
# `out` is dict-like; for this unbatched 1-D input the logits come back as L x V
print(out.keys(), out.logits.shape, sep = "\n")
#
# past_key_values holds one (key, value) entry per layer;
# each of these tensors has shape B x H x L x head_dim
# see https://github.com/huggingface/transformers/blob/4e9f6fc67ce6290b3ab6efe2ddb1fcfc3e554382/src/transformers/models/gpt_neo/modeling_gpt_neo.py#L230
#
print(
    len(out.past_key_values),
    out.past_key_values[0][0].shape,  # keys of the first layer
    out.past_key_values[0][1].shape,  # values of the first layer
    sep = "\n",
)

odict_keys(['logits', 'past_key_values'])
torch.Size([5, 50257])
24
torch.Size([1, 16, 5, 128])
torch.Size([1, 16, 5, 128])


In [6]:
def do_p_sampling(p, p_val = 0.95):
    """Draw one token ID from ``p`` with top-p (nucleus) sampling.

    The candidate set is the smallest group of highest-probability entries whose
    cumulative probability reaches ``p_val`` (always at least one entry). One
    entry is then drawn from that set with probabilities proportional to ``p``.

    Args:
        p: 1-D tensor of probabilities over the vocabulary.
        p_val: nucleus threshold; 1.0 keeps the whole distribution.

    Returns:
        The sampled index into ``p`` as a Python int.
    """
    items, indices = torch.sort(p, descending = True)
    cumulative = torch.cumsum(items, dim = 0)
    #
    # Entries strictly below the threshold, plus the one that crosses it.
    # Clamp because float round-off can leave the total just under p_val = 1.0.
    #
    _k = min(int((cumulative < p_val).sum().item()) + 1, items.numel())
    keep = indices[:_k]
    # Tensor indexing keeps the kept probabilities on p's device;
    # Categorical renormalises them itself
    idx = torch.distributions.categorical.Categorical(probs = p[keep]).sample()
    return keep[idx].item()

    
def predict(model, prompt, length, tokenizer, temperature = 0.7,  p_val = 0.95):
    """Continue ``prompt`` by sampling tokens, re-running the whole sequence each step.

    No key/value cache is used: every step feeds the most recent
    ``model.config.max_position_embeddings`` tokens through the model and samples
    the next token from the logits of the last position. The decoded text is
    printed after each new token.

    Args:
        model: causal LM whose output exposes ``.logits`` (B x L x V) and whose
            ``config`` has ``max_position_embeddings``.
        prompt: text to continue.
        length: total number of tokens, prompt included, at which to stop.
        tokenizer: tokenizer matching ``model`` (``encode`` / ``decode``).
        temperature: softmax temperature applied to the logits.
        p_val: nucleus threshold handed to ``do_p_sampling``.

    Returns:
        The decoded prompt plus continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    # Maximum number of positions the model accepts; older tokens are cut off
    window = model.config.max_position_embeddings
    #
    # Turn prompt into sequence of token IDs
    #
    encoded_sample = tokenizer.encode(prompt)
    with torch.no_grad():
        while len(encoded_sample) < length:
            #
            # Feed the window-truncated sequence as a batch of one (B x L, B = 1).
            # Running the model at the top of the loop means no forward pass is
            # wasted after the final token has been sampled.
            #
            input_ids = torch.tensor([encoded_sample[-window:]], dtype = torch.long, device = device)
            logits = model(input_ids = input_ids).logits  # B x L x V
            #
            # Drop the batch dimension and keep only the last position: shape (V,)
            #
            p = torch.nn.functional.softmax(logits[0, -1, :] / temperature, dim = -1)
            encoded_sample.append(do_p_sampling(p, p_val))
            print(tokenizer.decode(encoded_sample))

    return tokenizer.decode(encoded_sample)


In [7]:
# Sample until the sequence (prompt tokens included) is 15 tokens long
predict(
    prompt = "My name is Joe and I am",
    tokenizer = tokenizer,
    model = model,
    length = 15,
)

My name is Joe and I am a
My name is Joe and I am a professional
My name is Joe and I am a professional baseball
My name is Joe and I am a professional baseball player
My name is Joe and I am a professional baseball player for
My name is Joe and I am a professional baseball player for the
My name is Joe and I am a professional baseball player for the St
My name is Joe and I am a professional baseball player for the St.


'My name is Joe and I am a professional baseball player for the St.'

In [8]:
prompt = "A long time ago"
encoded_prompt = tokenizer.encode(prompt)
# Nesting the ID list inside another list adds a batch axis of size 1
input_ids = torch.tensor([encoded_prompt])
print(input_ids.shape) # B x L where B = 1 and L = 4

torch.Size([1, 4])


In [9]:
# Full pass over the prompt; keep the logits and the per-layer key/value cache
out = model(input_ids = input_ids)
logits, past_key_values = out.logits, out.past_key_values
print(logits.shape) # B x L x V
print(past_key_values[0][0].shape) # B x H x L x head_dim (keys of the first layer)

torch.Size([1, 4, 50257])
torch.Size([1, 16, 4, 128])


In [10]:
#
# First append a new token and run the entire new sample through the model. It does not matter which token we choose.
# Build a new list rather than appending to encoded_prompt through an alias: that would also
# modify encoded_prompt, and re-running this cell would then append a second comma.
#
new_id = tokenizer.convert_tokens_to_ids(",")
encoded_sample = encoded_prompt + [new_id]
input_ids = torch.tensor(encoded_sample).unsqueeze(dim = 0)
out1 = model(input_ids = input_ids)
print(out1.logits.shape) # B x (L + 1) x V

torch.Size([1, 5, 50257])


In [11]:
#
# Same step again, but now feed only the new token together with the
# cached keys/values of the prompt
#
out2 = model(
    input_ids = torch.tensor([new_id]),
    past_key_values = past_key_values,
)
print(out2.logits.shape) # shape (1, V) for this 1-D input
#
# The returned cache covers every position again - shape B x H x (L+1) x head_dim
#
print(out2.past_key_values[0][0].shape)
# Spot-check one vocabulary entry (token ID 1) of the new position from both runs
print(out2.logits[0, 1], out1.logits[0, -1, 1], sep = "\n")
V = model.config.vocab_size
assert torch.allclose(out1.logits[0, -1, :], out2.logits[0], rtol = 1e-2), "Logits do not match"

torch.Size([1, 50257])
torch.Size([1, 16, 5, 128])
tensor(-9.7990, grad_fn=<SelectBackward0>)
tensor(-9.7990, grad_fn=<SelectBackward0>)


Let us now use this pattern for a new, significantly more efficient prediction method. When you run it, you will find that the first pass through the model still takes some time, but the remaining passes are much faster — fast enough to be bearable even on a CPU.

In [12]:
def predict(model, prompt, length, tokenizer, temperature = 0.7,  p_val = 0.95):
    """Continue ``prompt`` by sampling tokens, reusing the model's key/value cache.

    The first forward pass runs the whole prompt. After that, only the newest
    token is fed, together with the ``past_key_values`` returned by the previous
    pass. Once the next position would exceed ``model.config.max_position_embeddings``,
    the cache is dropped. From then on, each step recomputes the most recent window
    of tokens instead of running past the position table. The decoded text is
    printed after each new token.

    Args:
        model: causal LM whose output exposes ``.logits`` (B x L x V) and
            ``.past_key_values``, and whose ``config`` has ``max_position_embeddings``.
        prompt: text to continue.
        length: total number of tokens, prompt included, at which to stop.
        tokenizer: tokenizer matching ``model`` (``encode`` / ``decode``).
        temperature: softmax temperature applied to the logits.
        p_val: nucleus threshold handed to ``do_p_sampling``.

    Returns:
        The decoded prompt plus continuation.
    """
    model.eval()
    device = next(model.parameters()).device
    # Maximum number of positions the model accepts (size of its position table)
    window = model.config.max_position_embeddings
    #
    # Turn prompt into sequence of token IDs
    #
    encoded_sample = tokenizer.encode(prompt)
    past_key_values = None
    with torch.no_grad():
        while len(encoded_sample) < length:
            if past_key_values is None:
                #
                # Full pass over the most recent `window` tokens (the prompt on the
                # first step). This builds the key/value cache from scratch.
                #
                new_ids = encoded_sample[-window:]
            else:
                # The cache already covers all earlier positions - feed only the newest token
                new_ids = encoded_sample[-1:]
            # Explicit batch of one (B x L, B = 1) so the logits are always B x L x V
            input_ids = torch.tensor([new_ids], dtype = torch.long, device = device)
            out = model(input_ids = input_ids, past_key_values = past_key_values)
            #
            # Keep the cache only while the next token's position still fits the
            # position table; beyond that, fall back to recomputing the window
            #
            past_key_values = out.past_key_values if len(encoded_sample) < window else None
            #
            # Drop the batch dimension and keep only the last position: shape (V,)
            #
            p = torch.nn.functional.softmax(out.logits[0, -1, :] / temperature, dim = -1)
            #
            # Sample new index and append to encoded sample. No forward pass follows
            # the final token, so none of the model's work is thrown away.
            #
            encoded_sample.append(do_p_sampling(p, p_val))
            print(tokenizer.decode(encoded_sample))

    return tokenizer.decode(encoded_sample)


In [13]:
# Same prompt with the cache-based predict, until the sequence (prompt included) has 25 tokens
predict(
    prompt = "My name is Joe and I am",
    tokenizer = tokenizer,
    model = model,
    length = 25,
)

My name is Joe and I am a
My name is Joe and I am a senior
My name is Joe and I am a senior in
My name is Joe and I am a senior in high
My name is Joe and I am a senior in high school
My name is Joe and I am a senior in high school.
My name is Joe and I am a senior in high school. I
My name is Joe and I am a senior in high school. I am
My name is Joe and I am a senior in high school. I am not
My name is Joe and I am a senior in high school. I am not a
My name is Joe and I am a senior in high school. I am not a gang
My name is Joe and I am a senior in high school. I am not a gang member
My name is Joe and I am a senior in high school. I am not a gang member.
My name is Joe and I am a senior in high school. I am not a gang member. I
My name is Joe and I am a senior in high school. I am not a gang member. I am
My name is Joe and I am a senior in high school. I am not a gang member. I am not
My name is Joe and I am a senior in high school. I am not a gang member. I am not a
My name is Joe 

'My name is Joe and I am a senior in high school. I am not a gang member. I am not a murderer'