In [3]:
import torch
from transformers import AutoModelForCausalLM , AutoTokenizer

class LMHeadModel:
    """Wrapper around a causal language model for inspecting next-token predictions.

    Holds a model and its tokenizer (both loaded from the same ``model_name``)
    and exposes helpers that turn a text prompt into a ranked list of
    candidate next tokens with their probabilities.
    """

    def __init__(self, model_name):
        """Load the model and the tokenizer.

        Args:
            model_name: Identifier passed unchanged to
                ``AutoModelForCausalLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_predictions(self, sentence):
        """Encode ``sentence`` and return the first element of the model output.

        Args:
            sentence: Text prompt to encode with the tokenizer.

        Returns:
            ``outputs[0]`` of the model call, computed without gradient
            tracking. ``get_next_word_probabilities`` indexes it as
            ``[0, -1, :]``, i.e. it is expected to be a 3-D tensor whose last
            axis holds one score per vocabulary entry.

        Raises:
            ValueError: If the tokenizer produces no tokens for ``sentence``,
                since there is then no position to predict from.
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")

        # Fail early with a clear message instead of an opaque shape/index
        # error raised from inside the model.
        if inputs.numel() == 0:
            raise ValueError(
                "sentence must encode to at least one token; got an empty input")

        with torch.no_grad():
            outputs = self.model(inputs)
        return outputs[0]

    def get_next_word_probabilities(self, sentence, top_k=500):
        """Return the ``top_k`` most likely next tokens after ``sentence``.

        Args:
            sentence: Text prompt to continue.
            top_k: Maximum number of candidates to return. Values larger than
                the number of scores available are clamped to that number
                rather than raising.

        Returns:
            A list of ``(token_text, probability)`` tuples ordered from most
            to least likely. ``token_text`` is the decoded token with
            surrounding whitespace stripped; ``probability`` is the softmax
            probability taken over the full score vector, not renormalized
            over the returned candidates.

        Raises:
            ValueError: Propagated from ``get_predictions`` when ``sentence``
                encodes to no tokens.
        """
        # Scores for the position following the last input token.
        predictions = self.get_predictions(sentence)
        next_token_scores = predictions[0, -1, :]

        # torch.topk raises if k exceeds the size of the dimension, so cap it.
        k = min(top_k, next_token_scores.shape[-1])

        # Probabilities are computed over every candidate before filtering.
        all_probabilities = torch.nn.functional.softmax(
            next_token_scores, dim=-1)

        # Rank on the raw scores and keep the indices as a tensor so they can
        # index the probability vector directly.
        topk_indexes = torch.topk(next_token_scores, k).indices
        topk_probabilities = all_probabilities[topk_indexes].tolist()

        # Decode each candidate id back to text.
        topk_tokens = [
            self.tokenizer.decode([idx]).strip()
            for idx in topk_indexes.tolist()
        ]

        return list(zip(topk_tokens, topk_probabilities))




  from .autonotebook import tqdm as notebook_tqdm


In [4]:

# Build the wrapper for the "gpt2" model name.
model = LMHeadModel("gpt2")

# Prompt whose continuation candidates we want to inspect.
sentence = "I enjoy walking in the"

# Last expression of the cell: up to 500 (token, probability) pairs, best first.
model.get_next_word_probabilities(sentence, top_k=500)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


[('park', 0.15904049575328827),
 ('woods', 0.10028108954429626),
 ('streets', 0.04183783754706383),
 ('dark', 0.031174374744296074),
 ('door', 0.02961907349526882),
 ('street', 0.02388927899301052),
 ('rain', 0.021734017878770828),
 ('city', 0.018898695707321167),
 ('same', 0.01503657829016447),
 ('halls', 0.013454659841954708),
 ('field', 0.012773651629686356),
 ('middle', 0.012384142726659775),
 ('garden', 0.010566969402134418),
 ('neighborhood', 0.010260550305247307),
 ('snow', 0.009522601962089539),
 ('forest', 0.009171221405267715),
 ('parks', 0.009017538279294968),
 ('open', 0.00845408346503973),
 ('world', 0.0075866952538490295),
 ('hallway', 0.006888187490403652),
 ('shoes', 0.00647016242146492),
 ('footsteps', 0.0062239086255431175),
 ('hall', 0.0059839216992259026),
 ('room', 0.005687463562935591),
 ('sun', 0.005530762020498514),
 ('doors', 0.004972364753484726),
 ('house', 0.00484781339764595),
 ('yard', 0.004810491111129522),
 ('sand', 0.004204212222248316),
 ('mud', 0.0038