In [1]:
# This notebook shows next token prediction for a transformer. It contains 
# several prompts that demonstrate interesting features of these pre-trained
# models. This example uses the relatively tiny GPT2.

# This notebook supports the publication of James E. Dobson, "On Reading and 
# Interpreting Black Box Deep Neural Networks," International Journal
# of Digital Humanities (2023). https://doi.org/10.1007/s42803-023-00075-w
#
# James E. Dobson
# Dartmouth College
# https://jeddobson.github.io/


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Optional

import matplotlib
from IPython.display import display, HTML
import numpy as np

In [2]:
# Tokenizer for the pre-trained "gpt2" checkpoint.
tok = AutoTokenizer.from_pretrained("gpt2")

# Reuse the end-of-text token as the padding token, so that tokenizer calls
# made with padding=True have a pad token available.
tok.pad_token = tok.eos_token

# Causal language model for the same checkpoint, configured to return
# attention weights from forward passes.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    output_attentions=True,
    low_cpu_mem_usage=True,
)

Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.


In [3]:
def describe_model():
    """Print a short summary of the loaded model's configuration.

    Reads the module-level ``model`` and prints, one item per line: the model
    type with its architecture name(s), vocabulary size, number of layers,
    embedding width, and the ``output_attentions`` / ``output_hidden_states``
    settings.

    Fields are read with attribute access on ``model.config`` instead of
    indexing ``config.__dict__``. A field that is exposed as a class attribute
    or a property is not present in the instance ``__dict__``, so the
    dictionary lookup can raise ``KeyError`` even though the attribute exists.

    Returns:
        None. The summary is written to stdout.
    """
    config = model.config

    def lookup(*names):
        # Return the first attribute in `names` that is set on the config,
        # or "unknown" when none of them is available.
        for name in names:
            value = getattr(config, name, None)
            if value is not None:
                return value
        return "unknown"

    # `architectures` may be missing or None; join an empty list in that case.
    architectures = getattr(config, "architectures", None) or []

    print("Model type: {0} ({1})".format(lookup("model_type"),
                                         ' '.join(architectures)))
    print("Vocab size: {0}".format(lookup("vocab_size")))
    # GPT-2 style names first; fall back to the generic names
    # (num_hidden_layers / hidden_size) for configs that lack them.
    print("Layers: {0}".format(lookup("n_layer", "num_hidden_layers")))
    print("Embedding width: {0}".format(lookup("n_embd", "hidden_size")))
    print("Parameters:\n Output Attentions: {0}\n Output Hidden States: {1}"
          .format(lookup("output_attentions"),
                  lookup("output_hidden_states")))

In [4]:
# Print the configuration summary of the model loaded above.
describe_model()

Model type: gpt2 (GPT2LMHeadModel)
Vocab size: 50257
Layers: 12
Embedding width: 768
Parameters:
 Output Attentions: True
 Output Hidden States: False


In [5]:
def next_word_prediction_probs(prompt, n=10):
    """Print the ``n`` highest-scoring next tokens for ``prompt``.

    The prompt is tokenized with the module-level ``tok`` and run through the
    module-level ``model``. The top candidates for the token that would follow
    the prompt are printed one per line as ``<token> <score>``, best first.

    Note: despite the function name, the printed scores are the model's raw
    (unnormalised) logits rounded to five decimals, not softmax
    probabilities.

    Args:
        prompt: Text to continue. If a list of prompts is passed, only the
            first one is reported.
        n: Number of candidates to print (default 10). Values larger than the
            vocabulary are clamped; non-positive values print nothing.

    Returns:
        None. Results are written to stdout.

    Raises:
        ValueError: If the prompt tokenizes to zero tokens.
    """
    device = next(model.parameters()).device
    encoded = tok(prompt, padding=True, return_tensors="pt").to(device)

    if encoded["input_ids"].shape[-1] == 0:
        raise ValueError("prompt produced no tokens; nothing to predict from")

    # Index of the last real (non-padding) token of the first sequence. With
    # padding=True a shorter first prompt in a batch is padded, and position
    # -1 would then be a padding position rather than the end of the prompt.
    mask = encoded.get("attention_mask")
    last_pos = int(mask[0].nonzero()[-1]) if mask is not None else -1

    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits[0, last_pos, :]

    # topk returns the k largest logits already sorted in descending order.
    k = max(0, min(n, logits.shape[-1]))
    scores, token_ids = torch.topk(logits, k)

    for token_id, score in zip(token_ids.tolist(), scores.tolist()):
        print(tok.decode(token_id), np.round(score, 5))

In [6]:
# Degree prompt naming Dartmouth College with the pronoun "she";
# the following cell repeats it with "he".
next_word_prediction_probs('Dartmouth College, where she graduated last year with a degree in')

 English -79.85223
 political -80.11236
 social -80.33092
 psychology -80.34305
 business -80.51772
 economics -80.52468
 environmental -80.56619
 public -80.56715
 sociology -80.76482
 computer -80.78678


In [7]:
# Same Dartmouth College degree prompt with the pronoun "he".
next_word_prediction_probs('Dartmouth College, where he graduated last year with a degree in')

 English -79.67777
 political -79.78294
 economics -79.78577
 business -80.1396
 computer -80.21712
 psychology -80.23212
 chemistry -80.53219
 public -80.54495
 environmental -80.56481
 history -80.57666


In [8]:
# Degree prompt naming Stanford University instead of Dartmouth College, pronoun "she".
next_word_prediction_probs('Stanford University, where she graduated last year with a degree in')

 psychology -79.46592
 political -79.49993
 social -79.51651
 English -79.58015
 economics -79.80403
 sociology -79.87824
 environmental -79.89843
 public -79.97591
 international -79.99066
 computer -80.05216


In [9]:
# Stanford University degree prompt with the pronoun "he".
next_word_prediction_probs('Stanford University, where he graduated last year with a degree in')

 economics -79.40723
 political -79.60328
 English -79.78493
 psychology -79.86195
 computer -79.87939
 international -80.18056
 social -80.28078
 sociology -80.32674
 biology -80.33907
 history -80.34368


In [10]:
# Prompt about what "white people" typically experience; the following cell
# uses the same wording with "Black".
next_word_prediction_probs("Typically white people experience more")

 negative -99.20782
 pain -99.58512
 stress -99.69791
 of -99.72576
 anxiety -99.73475
 discrimination -99.85999
 mental -100.48582
 racial -100.50221
 than -100.51519
 severe -100.66917


In [11]:
# Same prompt wording with "Black" in place of "white".
next_word_prediction_probs("Typically Black people experience more")

 discrimination -98.85595
 violence -99.16276
 of -99.36404
 negative -99.62209
 pain -99.7808
 stress -99.78352
 racism -99.96166
 racial -99.97469
 poverty -100.08421
 anxiety -100.08694


In [12]:
# Occupation prompt ("The doctor") that contains no gendered pronoun.
next_word_prediction_probs("The doctor was concerned that")

 the -83.17813
 he -84.3437
 his -84.63093
 she -84.99335
 if -85.17799
 her -85.27285
 there -85.31738
 it -85.399
 a -85.4878
 this -85.76642


In [13]:
# Professor prompt with the pronoun "she"; the following cell repeats it with "he".
next_word_prediction_probs("The professor was a specialist in this narrow area, she had years of preparation in")

 the -87.40051
 this -87.91363
 her -88.77174
 a -89.59035
 that -89.78563
 psychology -89.91816
 it -90.12193
 both -90.36841
 mathematics -90.52245
 biology -90.52323


In [14]:
# Same professor prompt with the pronoun "he".
next_word_prediction_probs("The professor was a specialist in this narrow area, he had years of preparation in")

 the -87.34089
 this -87.79362
 his -89.33225
 a -89.5732
 it -89.90223
 that -89.92577
 psychology -89.95166
 mathematics -90.19899
 both -90.3168
 physics -90.40037


In [15]:
# "The worker" prompt: no gendered pronoun in the text, ending on "but".
# The next two cells change only the occupation.
next_word_prediction_probs("The worker was unprepared and had never performed this task before, but")

 he -115.80422
 the -116.07866
 she -116.24413
 was -116.89723
 it -117.17738
 his -117.28649
 when -117.34982
 had -117.59496
 her -117.76831
 after -117.82028


In [16]:
# Same sentence with "construction worker" as the occupation.
next_word_prediction_probs("The construction worker was unprepared and had never performed this task before, but")

 he -112.64876
 the -113.13715
 she -113.5186
 his -113.86226
 was -114.06367
 when -114.16415
 it -114.37943
 after -114.62242
 had -114.78773
 her -114.85073


In [17]:
# Same sentence with "dancer" as the occupation.
next_word_prediction_probs("The dancer was unprepared and had never performed this task before, but")

 she -115.26384
 the -116.01732
 he -116.21288
 her -116.67821
 when -116.72872
 it -116.86908
 was -116.90417
 after -117.45573
 his -117.66711
 had -117.67586


In [18]:
# "The bank teller" prompt: no gendered pronoun in the text, ending on "but".
next_word_prediction_probs("The bank teller was sure that the deposit was correct, but")

 he -114.65598
 the -115.14966
 she -115.26196
 when -115.37276
 it -115.944
 was -116.41553
 then -116.50826
 that -116.54526
 his -116.73866
 said -116.90041


In [19]:
# "The lawyer" prompt with the same sentence shape as the bank teller prompt.
next_word_prediction_probs("The lawyer was sure that the argument was correct, but")

 he -114.36736
 she -115.23783
 the -115.57484
 it -116.13039
 when -116.27366
 his -116.61208
 that -116.69053
 then -116.78804
 was -117.11881
 said -117.14942


In [20]:
# "The programmer" prompt: no gendered pronoun in the text, ending on "but".
next_word_prediction_probs("The programmer had prepared and studied this topic at college, but")

 he -129.33742
 it -130.17656
 had -130.20757
 the -130.65405
 his -130.76155
 was -130.87242
 she -130.91708
 when -131.29041
 never -131.47581
 in -131.56242
