In [1]:
import torch
import torch.nn.functional as F

In [2]:
# logits (bsz, seq_len, num_classes)
logits = torch.tensor([
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
])
print("logits.shape:", logits.shape)

labels = torch.tensor([
    [1, 2],  # First sequence: class 1 at position 0
    [0, 1]   # Second sequence: class 0 at position 0
])
print("labels.shape:", labels.shape)

torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

logits.shape: torch.Size([2, 2, 3])
labels.shape: torch.Size([2, 2])


tensor([[0.2000, 0.6000],
        [0.7000, 1.1000]])

In [3]:
uniform_probs = F.softmax(torch.ones_like(logits) / logits.shape[-1], dim=-1)
log_probs = F.log_softmax(logits, dim=-1)
print("uniform_probs", uniform_probs)
kl = F.kl_div(log_probs, uniform_probs, reduction='none')
print("kl", kl.shape, kl)

uniform_probs tensor([[[0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333]],

        [[0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333]]])
kl torch.Size([2, 2, 3]) tensor([[[ 0.0344,  0.0011, -0.0322],
         [ 0.0344,  0.0011, -0.0322]],

        [[ 0.0344,  0.0011, -0.0322],
         [ 0.0344,  0.0011, -0.0322]]])


In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

# generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
bsz = 2
seq_len = 4
vocab_size = 50257

fake_input_ids = torch.randint(0, vocab_size, (bsz, seq_len))  # Random input IDs for demonstration
print("fake_input_ids.shape:", fake_input_ids.shape)

fake_input_ids.shape: torch.Size([2, 4])


In [4]:
attention_mask = torch.ones(fake_input_ids.shape, dtype=torch.long)  # Dummy attention mask

outputs = model.generate(
    input_ids=fake_input_ids,
    max_length=10,
    num_return_sequences=1,
    return_dict_in_generate=True,
    output_scores=True
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [5]:
outputs.sequences.shape   # (bsz, max_length)

torch.Size([2, 10])

In [7]:
print(f"len(outputs.scores) = {len(outputs.scores)}")  # 6 (max_length - input_length)
print(f"len(outputs.scores[0]) = {len(outputs.scores[0])}")  # 2 bsz
print(f"outputs.scores[0].shape = {outputs.scores[0].shape}")  # (bsz, vocab_size)
outputs.scores

len(outputs.scores) = 6
len(outputs.scores[0]) = 2
outputs.scores[0].shape = torch.Size([2, 50257])


(tensor([[-74.6001, -74.1407, -76.1368,  ..., -83.9544, -81.3293, -75.2622],
         [-73.9760, -72.4697, -73.7191,  ..., -82.8011, -83.0292, -72.1608]]),
 tensor([[-85.5014, -87.6018, -89.7055,  ..., -94.1085, -93.2828, -83.6890],
         [-58.2937, -56.6866, -57.5061,  ..., -63.5722, -65.5105, -51.6693]]),
 tensor([[ -99.6620, -100.3908, -101.2803,  ..., -109.6314, -108.2275,
          -100.0722],
         [ -71.7999,  -69.4696,  -71.1349,  ...,  -80.8000,  -79.5367,
           -67.8075]]),
 tensor([[ -82.9104,  -82.5029,  -84.1082,  ...,  -91.5918,  -88.9574,
           -82.3946],
         [-264.2556, -259.8985, -263.1606,  ..., -284.0302, -290.5381,
          -259.8674]]),
 tensor([[-100.0969, -100.9917,  -99.5622,  ..., -104.4349, -103.3674,
           -97.2489],
         [ -85.1759,  -79.4648,  -80.6355,  ...,  -96.3441,  -97.7259,
           -85.8978]]),
 tensor([[-120.5243, -119.7859, -123.6583,  ..., -125.0602, -121.0221,
          -119.0242],
         [ -81.6380,  -80.2057,

In [8]:
# outputs.scores is a tuple of length max_length - input_length
# Each element is a tensor of shape (bsz, vocab_size)

# Convert scores to tensor, shape (bsz, max_length - input_length, vocab_size)
scores_tensor = torch.stack(outputs.scores, dim=1)  # Shape: (bsz, max_length - input_length, vocab_size)
print("scores_tensor.shape:", scores_tensor.shape)

scores_tensor.shape: torch.Size([2, 6, 50257])
