In [1]:
!pip install transformers==4.31.0
!pip install torch




Collecting transformers==4.31.0
  Downloading transformers-4.31.0-py3-none-any.whl.metadata (116 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.31.0)
  Downloading tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl.metadata (6.7 kB)
Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl (3.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.9/3.9 MB[0m [31m42.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.15.0
    Uninstalling tokenizers-0.15.0:
      Successfully uninstalled tokenizers-0.15.0
  Attempting uninstall: transformers
    Found existing installation: transformers 4.36.2
    Uninstalling transformers-

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


#### GPT 2 WITH ATTENTION OUTPUTS ####

with this we're registering forward hooks on each attention layer to capture the attention outputs


In [3]:
class GPT2WithAttentionOutputs(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        self.attention_scores = []

    def reset_attention_scores(self):
        self.attention_scores = []

    def save_attention_scores(self, module, input, output):
        self.attention_scores.append(output[1]) # output[1] contains the attention probabilities

    def forward(self, input_ids, past_key_values=None, attention_mask=None):
        self.reset_attention_scores()

        # use hooks to capture attention outputs
        hooks = []
        for block in self.transformer.h:
            hook = block.attn.register_forward_hook(self.save_attention_scores)
            hooks.append(hook)
        outputs = super().forward(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            output_attentions=True,
            return_dict=True,
        )
        for hook in hooks:
            hook.remove()
        return outputs


In [59]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 7.03MB/s]


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [45]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', output_attentions=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)



GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

#### metric calculations ####

Entropy: Measures the uncertainty in the model's predictions

Varentropy: Measures the variance of entropy across different positions

Agreement: We measure how consistent the attention patterns are across different heads

Interaction Strength: Measures the mean absolute attention values



In [77]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Initialize model with output_attentions=True
model = GPT2LMHeadModel.from_pretrained('gpt2', output_attentions=True)

# Move model to the appropriate device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [78]:
def calculate_entropy_and_varentropy(logits):
    probs = F.softmax(logits, dim=-1) # get probabilities from logits using softmax

    # entropy is the average of the negative sum of the probabilities times the log2 of the probabilities (measures the uncertainty of the model's predictions)
    entropy = -torch.sum(probs * torch.log2(probs + 1e-10), dim=-1)

    # varentropy is the variance of the entropy -- this is across different positions 
    varentropy = torch.var(entropy)
    return entropy.mean(), varentropy


We're computing the attention entropy and varentropy to effectively understand the model's focus

In [79]:
def calculate_attention_entropy_and_varentropy(attention_scores):
    attention_entropies = []
    for attn in attention_scores:
        # attn shape: (batch_size, num_heads, seq_len, seq_len)
        attn_probs = attn  # This is already the attention probabilities
        # Compute entropy over the source sequence (last dimension)
        entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-10), dim=-1)  # Shape: (batch_size, num_heads, seq_len)
        # Mean over batch and heads
        entropy = entropy.mean(dim=(0, 1))  # Shape: (seq_len,)
        # Mean over positions
        entropy = entropy.mean()  # Scalar
        attention_entropies.append(entropy)
    
    attention_entropies = torch.stack(attention_entropies)  # Shape: (num_layers,)
    attention_entropy = attention_entropies.mean()
    attention_varentropy = attention_entropies.var()
    return attention_entropy, attention_varentropy

In [80]:
def calculate_attention_agreement(attention_scores):
    agreements = []
    for attn in attention_scores:
        # attn shape: (batch_size, num_heads, seq_len, seq_len)
        mean_attention = attn.mean(dim=1, keepdim=True)  # Mean over heads
        agreement = torch.abs(attn - mean_attention).mean(dim=(0, 1, 2, 3))  # Scalar
        agreements.append(agreement)
    attention_agreement = torch.stack(agreements).mean()
    return attention_agreement


In [81]:
def calculate_interaction_strength(attention_scores):
    strengths = []
    for attn in attention_scores:
        # attn shape: (batch_size, num_heads, seq_len, seq_len)
        strength = torch.abs(attn).mean(dim=(0, 1, 2, 3))  # Scalar
        strengths.append(strength)
    interaction_strength = torch.stack(strengths).mean()
    return interaction_strength


#### sampling

In [90]:
def greedy_sampling(logits):
    next_token = torch.argmax(logits, dim=-1, keepdim=True)
    return next_token  # Shape: (batch_size, 1)


In [89]:
CLARIFICATION_TOKEN = "<clarify>"

def clarification_insertion(tokenizer, generated_tokens):
    clarification_token_id = tokenizer.encode(CLARIFICATION_TOKEN, add_special_tokens=False)[0]
    batch_size = generated_tokens.size(0)
    clarification_token_ids = torch.full(
        (batch_size, 1),
        clarification_token_id,
        device=generated_tokens.device,
        dtype=torch.long
    )
    return clarification_token_ids  # Shape: (batch_size, 1)


##### Exploration Sampling

In [91]:
def exploration_sampling(logits, temperature=1.0, top_k=50):
    logits = logits / temperature
    top_k_logits, top_k_indices = torch.topk(logits, k=top_k)
    probs = F.softmax(top_k_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # Shape: (batch_size, 1)
    next_token = top_k_indices.gather(-1, next_token)
    return next_token  # Shape: (batch_size, 1)


In [92]:
def high_uncertainty_sampling(logits, temperature=1.5, top_p=0.9):
    logits = logits / temperature
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, float('-inf'))
    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # Shape: (batch_size, 1)
    return next_token


In [86]:
def adaptive_sampling(logits, metrics, num_samples=12, base_temperature=1.0, base_top_p=0.9, base_top_k=50):
    # Adjust parameters based on metrics
    temperature = base_temperature * (1 + 0.3 * metrics['logits_uncertainty'] + 0.2 * metrics['attention_uncertainty'] - 0.2 * metrics['attention_agreement'])
    temperature = max(0.5, min(temperature, 1.5))

    top_p = base_top_p * (1 + 0.1 * metrics['attention_varentropy'])
    top_p = max(0.1, min(top_p, 1.0))

    top_k = int(base_top_k * (1 + 0.3 * metrics['interaction_strength'] - 0.2 * metrics['attention_agreement']))
    top_k = max(5, min(top_k, logits.size(-1)))

    # Generate multiple samples
    logits = logits / temperature
    next_tokens = []
    scores = []

    for _ in range(num_samples):
        filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
        probs = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_tokens.append(next_token)
        log_prob = torch.log(probs.gather(-1, next_token))
        confidence_score = calculate_confidence_score(metrics)
        score = log_prob + confidence_score
        scores.append(score)

    # we want to select the best sample
    best_index = torch.argmax(torch.stack(scores))
    best_next_token = next_tokens[best_index]
    return best_next_token

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9, filter_value=-float('Inf')):
    # top-k filtering - top-k filtering is a technique used to limit the number of tokens that can be selected by the model
    if top_k > 0:
        values_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(values_to_remove, filter_value)

    # top-p filtering - top-p filtering is a technique used to limit the maximum probability of a token being selected
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        indices_to_remove = cumulative_probs > top_p
        indices_to_remove[..., 1:] = indices_to_remove[..., :-1].clone()
        indices_to_remove[..., 0] = 0

        sorted_indices_to_remove = indices_to_remove.scatter(1, sorted_indices, indices_to_remove)
        logits = logits.masked_fill(sorted_indices_to_remove, filter_value)
    return logits

def calculate_confidence_score(metrics):
    confidence_score = (
        (1 - metrics["logits_entropy"]) * 0.1 +
        (1 - metrics["attention_entropy"]) * 0.2 +
        (1 - metrics["logits_varentropy"]) * 0.3 +
        (1 - metrics["attention_varentropy"]) * 0.4 +
        metrics["attention_agreement"] * 0.5 +
        metrics["interaction_strength"] * 0.6
    )
    return confidence_score


### Dynamic Parameter Adjustment
We adjust sampling parameters based on the calculated metrics to dynamically adapt to the model's behavior

In [69]:
def adjust_parameters(metrics, base_params):
    temperature = base_params['temperature'] * (1 + 0.3 * metrics['logits_uncertainty'] + 0.2 * metrics['attention_uncertainty'] - 0.2 * metrics['attention_agreement'])
    temperature = max(0.5, min(temperature, 1.5))

    top_p = base_params['top_p'] * (1 + 0.1 * metrics['attention_varentropy'])
    top_p = max(0.1, min(top_p, 1.0))
    top_k = int(base_params['top_k'] * (1 + 0.3 * metrics['interaction_strength'] - 0.2 * metrics['attention_agreement']))
    top_k = max(5, min(top_k, 100))
    min_p = base_params['min_p'] * (1 - 0.5 * metrics['logits_uncertainty'])
    min_p = max(0.01, min(min_p, 0.5))
    adjusted_params = {
        'temperature': temperature,
        'top_p': top_p,
        'top_k': top_k,
        'min_p': min_p
    }
    return adjusted_params


### Entropix Sampling Function
now just does the text generation using Entropix from the blog post

In [94]:
def entropix_generate(
    model, tokenizer, input_text, max_length=50, base_params=None, num_samples=12
):
    if base_params is None:
        base_params = {
            'temperature': 1.0,
            'top_p': 0.9,
            'top_k': 50,
            'min_p': 0.01
        }
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_ids = input_ids.to(device)
    model.to(device)

    generated = input_ids  # Shape: (batch_size, seq_len)
    past_key_values = None

    # Initialize a list to store metrics history (optional)
    metrics_history = []

    for _ in range(max_length):
        # Model outputs with attentions
        outputs = model(
            input_ids=generated,
            past_key_values=past_key_values,
            attention_mask=None,
            output_attentions=True,
        )
        logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
        past_key_values = outputs.past_key_values

        # Get attention scores from outputs.attentions
        attention_scores = outputs.attentions  # Tuple of (num_layers) tensors

        # Compute metrics
        logits_entropy, logits_varentropy = calculate_entropy_and_varentropy(logits)
        attention_entropy, attention_varentropy = calculate_attention_entropy_and_varentropy(attention_scores)
        attention_agreement = calculate_attention_agreement(attention_scores)
        interaction_strength = calculate_interaction_strength(attention_scores)
        metrics = {
            'logits_entropy': logits_entropy.item(),
            'logits_varentropy': logits_varentropy.item(),
            'attention_entropy': attention_entropy.item(),
            'attention_varentropy': attention_varentropy.item(),
            'attention_agreement': attention_agreement.item(),
            'interaction_strength': interaction_strength.item(),
            'logits_uncertainty': logits_entropy.item() + logits_varentropy.item(),
            'attention_uncertainty': attention_entropy.item() + attention_varentropy.item(),
        }

        # Store metrics (optional, for plotting or analysis)
        metrics_history.append(metrics)

        # Sampling strategy based on metrics
        if logits_entropy.item() < 0.1 and logits_varentropy.item() < 0.1:
            next_token = greedy_sampling(logits)
        elif logits_entropy.item() > 3.0 and logits_varentropy.item() < 0.1:
            next_token = clarification_insertion(tokenizer, generated)
        elif logits_entropy.item() < 2.0 and logits_varentropy.item() > 5.0:
            adjusted_params = adjust_parameters(metrics, base_params)
            next_token = exploration_sampling(
                logits,
                temperature=adjusted_params['temperature'],
                top_k=adjusted_params['top_k']
            )
        elif logits_entropy.item() > 5.0 and logits_varentropy.item() > 5.0:
            # High uncertainty sampling
            adjusted_params = adjust_parameters(metrics, base_params)
            next_token = high_uncertainty_sampling(
                logits,
                temperature=adjusted_params['temperature'],
                top_p=adjusted_params['top_p']
            )
        else:
            # Adaptive sampling
            adjusted_params = adjust_parameters(metrics, base_params)
            next_token = adaptive_sampling(
                logits,
                metrics,
                num_samples=num_samples,
                base_temperature=adjusted_params['temperature'],
                base_top_p=adjusted_params['top_p'],
                base_top_k=adjusted_params['top_k']
            )

        generated = torch.cat((generated, next_token), dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break

    output_text = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)
    return output_text

Example 1: Low Uncertainty Generation

In [97]:
input_text = "the place to visit in MA is "
output_text = entropix_generate(model, tokenizer, input_text, max_length=10)
print(output_text)

the place to visit in MA is  the place to visit in MA is  


In [102]:
import matplotlib.pyplot as plt
def plot_metrics(metrics_history):
    logits_entropy = [m['logits_entropy'] for m in metrics_history]
    logits_varentropy = [m['logits_varentropy'] for m in metrics_history]
    attention_entropy = [m['attention_entropy'] for m in metrics_history]
    attention_varentropy = [m['attention_varentropy'] for m in metrics_history]
    attention_agreement = [m['attention_agreement'] for m in metrics_history]
    interaction_strength = [m['interaction_strength'] for m in metrics_history]

    plt.figure(figsize=(12, 8))

    plt.subplot(2, 3, 1)
    plt.plot(logits_entropy)
    plt.title('Logits Entropy')

    plt.subplot(2, 3, 2)
    plt.plot(logits_varentropy)
    plt.title('Logits Varentropy')

    plt.subplot(2, 3, 3)
    plt.plot(attention_entropy)
    plt.title('Attention Entropy')

    plt.subplot(2, 3, 4)
    plt.plot(attention_varentropy)
    plt.title('Attention Varentropy')

    plt.subplot(2, 3, 5)
    plt.plot(attention_agreement)
    plt.title('Attention Agreement')

    plt.subplot(2, 3, 6)
    plt.plot(interaction_strength)
    plt.title('Interaction Strength')

    plt.tight_layout()
    plt.show()


In [100]:
input_text = "In a distant future, humanity has colonized Mars and"
output_text, metrics_history = entropix_generate(model, tokenizer, input_text, max_length=50)
print(output_text)



IndexError: index out of range in self

In [72]:
# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Initialize your custom model that captures attention outputs
model = GPT2WithAttentionOutputs.from_pretrained('gpt2', output_attentions=True)

# Move model to appropriate device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

GPT2WithAttentionOutputs(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)