In [1]:
import torch
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Custom stopping criteria that stops when a specific token (e.g., period) is generated.
class StopOnToken(StoppingCriteria):
    def __init__(self, token_id):
        self.token_id = token_id

    def __call__(self, input_ids, scores, **kwargs):
        # Check if the last token equals the target token
        if input_ids[0, -1].item() == self.token_id:
            return True
        return False

    


In [3]:
# Load model and tokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [4]:
prompt = "Once upon a time"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

In [5]:
# Define the token id for the stop token
stop_token_id = tokenizer.convert_tokens_to_ids(".")

In [6]:
# Setup stopping criteria using our custom class
stopping_criteria = StoppingCriteriaList([StopOnToken(stop_token_id)])

In [7]:
# Generate text with scores output
outputs = model.generate(
    input_ids,
    max_length=100,
    stopping_criteria=stopping_criteria,
    output_scores=True,
    return_dict_in_generate=True
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [8]:
# Retrieve generated sequence and scores
generated_ids = outputs.sequences[0]
scores = outputs.scores # This is a list of tensors with shape (batch_size, vocab_size)

In [9]:
# Calculate entropy for each generation step
entropies = []
epsilon = 1e-10 # Small constant to avoid log(0)

In [10]:
for step_logits in scores:
    # Compute probability distribution
    probs = torch.softmax(step_logits, dim=-1)
    # Compute entropy for this step: H = -sum(p * log(p))
    entropy = -(probs * torch.log(probs + epsilon)).sum()
    entropies.append(entropy.item())

In [11]:
# Calculate average entropy over the steps
average_entropy = sum(entropies) / len(entropies) if entropies else 0

In [12]:
# Decode and print the generated text and average entropy
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

In [13]:
print("Generated text:\n", generated_text)
print("Average Entropy:\n", average_entropy)

Generated text:
 Once upon a time, the world was a place of great beauty and great danger.
Average Entropy:
 5.2147185160563545
