In [1]:
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint to load: per the author, this BLOOMZ variant works better for Q/A.
model_name = "bigscience/bloomz-560m"

# Fetch the tokenizer and the causal-LM weights for that checkpoint.
tokenizer = BloomTokenizerFast.from_pretrained(model_name)
model = BloomForCausalLM.from_pretrained(model_name)

# Put the network in evaluation mode. eval() returns the module itself, and
# assigning the result keeps the cell from echoing the model repr as output.
model = model.eval()

In [23]:
# Define text for interpretation
input_text = "What are the 5 tallest mountains in the world?"

# generate() below samples (do_sample=True), so seed torch's global RNG right
# before the call: re-running this cell then reproduces the same continuation
# instead of a different draw each time.
torch.manual_seed(0)

# Tokenize the prompt into PyTorch tensors and sample a short continuation.
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    temperature=0.05,  # very low temperature: draws stay close to the top token
    top_p=0.9,  # nucleus sampling
    do_sample=True,
)

# outputs[0] is the token-id sequence for the single prompt; the decoded text
# contains the prompt followed by the generated continuation.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)


What are the 5 tallest mountains in the world? Mount Kilimanjaro, Mount Everest, Mount Kilimanjaro, Mount Kilimanjaro, and Mount Kilimanjaro


In [24]:

# Tokenize the initial prompt
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]

# Number of new tokens to generate
max_new_tokens = 10
generated_tokens = []
top_logits_per_token = []

# Id of the end-of-sequence token; greedy decoding stops once it is selected.
eos_token_id = tokenizer.eos_token_id

# Greedy decoding, one token per iteration, recording the 5 highest-scoring
# candidates (raw logits, not probabilities) at every step.
for _ in range(max_new_tokens):
    # Run the model on the whole sequence so far; no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids=input_ids)

    # Logits at the last position, i.e. the scores for the next token
    logits = outputs.logits[0, -1, :]  # Shape [vocab_size]

    # Top 5 candidates; torch.topk returns them sorted by descending logit
    top_k_logits, top_k_indices = torch.topk(logits, k=5)

    # Decode each candidate from a plain int id. decode() already returns
    # ordinary text (leading spaces included), so no marker stripping is done:
    # stripping "Ġ" would only ever delete a genuine "Ġ" character.
    top_k_tokens = [tokenizer.decode([token_id]) for token_id in top_k_indices.tolist()]

    # Store the top 5 logits and corresponding tokens
    top_logits_per_token.append(list(zip(top_k_tokens, top_k_logits.tolist())))

    # The highest-logit candidate (index 0) is the greedy choice
    next_token_id = top_k_indices[0].unsqueeze(0)
    generated_tokens.append(next_token_id.item())

    # Stop once end-of-sequence is chosen. The step is recorded above first so
    # generated_tokens and top_logits_per_token stay the same length.
    if generated_tokens[-1] == eos_token_id:
        break

    # Append the chosen token so the next iteration conditions on it
    input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)

# Decode the generated continuation (the prompt is not included; an
# end-of-sequence token, if one was generated, is shown as-is)
generated_text = tokenizer.decode(generated_tokens)

# Display the results
print(f"Generated Text: {generated_text}\n")
print("Top 5 Logits for Each New Token:")
for i, token_info in enumerate(top_logits_per_token):
    print(f"\nToken {i + 1}: '{tokenizer.decode([generated_tokens[i]])}'")
    for token, logit in token_info:
        print(f"    {token}: {logit:.4f}")


Generated Text:  Mount Kilimanjaro, Mount Everest, Mount Kilimanjaro, Mount

Top 5 Logits for Each New Token:

Token 1: ' Mount'
     Mount: 397.5774
     Mt: 396.3606
     Everest: 396.0087
     Kilimanjaro: 395.5673
     Alps: 395.3434

Token 2: ' Kilimanjaro'
     Kilimanjaro: 388.0668
     Everest: 387.4675
     Rush: 385.2826
     Kin: 383.7083
     Kenya: 383.6555

Token 3: ','
    ,: 419.4991
    </s>: 419.4828
     and: 417.3188
     (: 416.9478
     in: 415.8520

Token 4: ' Mount'
     Mount: 403.3391
     Kilimanjaro: 401.9385
     Tanzania: 400.4769
     Mt: 400.3781
     the: 400.2496

Token 5: ' Everest'
     Everest: 381.4550
     Kilimanjaro: 381.0041
     Et: 380.6830
     Kenya: 379.9611
     Rush: 379.8688

Token 6: ','
    ,: 424.3529
    </s>: 420.8809
     and: 419.6128
     (: 418.7540
     ,: 416.8163

Token 7: ' Mount'
     Mount: 402.5578
     Mt: 400.0130
     Kilimanjaro: 398.9368
     the: 398.7545
     Everest: 397.6899

Token 8: ' Kilimanjaro'
     Kiliman