<a href="https://colab.research.google.com/github/colesmcintosh/entropy-injection-cot/blob/main/Chain_Of_Thought_via_Entropy_Injection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch bitsandbytes accelerate



In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from google.colab import userdata

In [None]:
# Checkpoint to load from the Hugging Face Hub.
model_name = 'meta-llama/Llama-3.2-1B-Instruct'

# Hugging Face authentication token, read through google.colab's `userdata`.
# Ensure you have accepted the model license and stored the token as HF_TOKEN.
hf_token = userdata.get('HF_TOKEN')

# Load the tokenizer.
# `token=` is the current keyword; `use_auth_token=` is deprecated.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

# 8-bit quantization config (bitsandbytes) to reduce memory use.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Load the model in 8-bit precision; device placement is left to `device_map='auto'`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    device_map='auto',
    quantization_config=bnb_config,
)



In [None]:
def calculate_entropy(logits):
    """Return the Shannon entropy (in nats) of the softmax distribution over ``logits``.

    Args:
        logits: Tensor of unnormalized scores; the distribution is taken over
            the last dimension.

    Returns:
        Tensor with the last dimension reduced away, holding
        ``-sum(p * log p)`` for each distribution.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # A -inf logit yields p == 0 and log p == -inf, whose product is NaN.
    # By convention 0 * log 0 == 0, so such entries must contribute nothing.
    weighted = torch.where(probs > 0, probs * log_probs, torch.zeros_like(probs))
    return -weighted.sum(dim=-1)

In [None]:
def sample_next_token(logits, temperature=1.0, top_p=0.9):
    """Sample one token id from ``logits`` with temperature and nucleus (top-p) filtering.

    Args:
        logits: Next-token scores. A leading dimension of size 1 is squeezed
            away; the remaining tensor is indexed as 1-D over the vocabulary.
        temperature: Divisor applied to the logits before sampling. ``0``
            selects the highest-scoring token deterministically.
        top_p: Cumulative-probability cutoff. Tokens are kept in descending
            probability order up to and including the first one that pushes
            the cumulative mass above ``top_p``; the top token is always kept.

    Returns:
        The chosen token id as a Python ``int``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    logits = logits.squeeze(0)

    if temperature == 0:
        # Dividing by zero would produce inf/NaN logits and break multinomial;
        # the zero-temperature limit of sampling is simply the argmax.
        return torch.argmax(logits, dim=-1).item()

    scaled = logits / temperature
    sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Shift the mask right by one so the token that crosses the threshold is
    # still kept, and never drop the most likely token.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = False

    # Removed tokens get -inf so they receive zero probability after softmax.
    filtered_logits = sorted_logits.masked_fill(sorted_indices_to_remove, -float('Inf'))
    probabilities = torch.softmax(filtered_logits, dim=-1)

    # The sample is a position in the sorted order; map it back to a vocab id.
    sampled_position = torch.multinomial(probabilities, num_samples=1)
    return sorted_indices[sampled_position].item()


In [None]:
def entropy_based_cot_injection_with_logging(prompt, entropy_threshold=4.0, max_length=150, max_cot_injections=1, cooldown_steps=5):
    """Generate text token by token, injecting a chain-of-thought prompt on high entropy.

    Uses the module-level ``tokenizer`` and ``model``. At each step the
    next-token entropy is computed; when it exceeds ``entropy_threshold`` (and
    the injection budget and cooldown allow), a fixed CoT prompt is switched
    on. While active, the CoT prompt is placed between the original prompt and
    the tokens generated so far in the model input only; it is never added to
    the returned text.

    Args:
        prompt: Text to continue.
        entropy_threshold: Entropy above which a CoT injection is triggered.
        max_length: Maximum number of loop iterations. An iteration that
            triggers an injection produces no token but still counts.
        max_cot_injections: Maximum number of injections per call.
        cooldown_steps: Number of generated tokens required between
            injections. The CoT prompt stays in the model input while
            ``steps_since_cot <= cooldown_steps`` after an injection.

    Returns:
        Tuple ``(output_text, entropies, tokens_generated)``: the decoded
        prompt plus generated tokens, the entropy recorded at every iteration
        (including injection iterations, so it can be longer than the token
        list), and the list of generated token ids.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
    generated_ids = input_ids.clone()
    prompt_len = input_ids.size(1)

    entropies = []
    tokens_generated = []
    cot_injections = 0
    steps_since_cot = cooldown_steps  # Start at the cooldown so an immediate injection is allowed

    # Use a more guiding CoT prompt
    cot_prompt = " To determine the answer, let's breakdown the problem step by step, then provide a final answer. "
    # The CoT text is spliced into the middle of the sequence, so it must not
    # carry its own special tokens (e.g. a second beginning-of-text token).
    cot_ids = tokenizer.encode(cot_prompt, add_special_tokens=False, return_tensors='pt').to(model.device)

    with torch.no_grad():
        for step in range(max_length):
            # Prepare the model input
            if cot_injections > 0 and steps_since_cot <= cooldown_steps:
                # Include the CoT prompt in the model input but not in generated_ids
                model_input_ids = torch.cat((input_ids, cot_ids, generated_ids[:, prompt_len:]), dim=1)
            else:
                model_input_ids = generated_ids
            attention_mask = torch.ones_like(model_input_ids)

            outputs = model(input_ids=model_input_ids, attention_mask=attention_mask)
            next_token_logits = outputs.logits[:, -1, :]

            # Calculate entropy once and reuse the Python float below
            entropy_value = calculate_entropy(next_token_logits).item()
            entropies.append(entropy_value)

            # Check entropy threshold and cooldown
            if (
                entropy_value > entropy_threshold
                and cot_injections < max_cot_injections
                and steps_since_cot >= cooldown_steps
            ):
                print(f"Step {step+1}: Entropy {entropy_value:.4f} exceeds threshold. Injecting CoT prompt.")
                cot_injections += 1
                steps_since_cot = 0
                continue  # Recompute logits with the CoT prompt in the next iteration

            # Generate next token
            next_token_id = sample_next_token(next_token_logits)
            next_token_id_tensor = torch.tensor([[next_token_id]], device=model.device)

            # Append the token to the generated sequence
            generated_ids = torch.cat((generated_ids, next_token_id_tensor), dim=1)
            tokens_generated.append(next_token_id)

            # Decode token for logging
            token_str = tokenizer.decode([next_token_id])
            print(f"Step {step+1}: Generated token: '{token_str}' | Entropy: {entropy_value:.4f}")

            # Check for end-of-text token
            if next_token_id == tokenizer.eos_token_id:
                print(f"Step {step+1}: End-of-text token generated. Stopping generation.")
                break

            steps_since_cot += 1

        # Decode the generated text; generated_ids never contains the CoT prompt
        output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output_text, entropies, tokens_generated


In [None]:
# Initial prompt
prompt = "Is 9.9 greater than 9.11?"

# Run the entropy-based CoT injection.
# The call returns a tuple: (generated text, per-step entropies, generated token ids).
output = entropy_based_cot_injection_with_logging(prompt)
generated_text = output[0]

# Print only the text; printing the whole tuple would also dump the entropy
# and token-id lists and show the text as an escaped repr.
print("Output:")
print(generated_text)

Step 1: Generated token: ' Yes' | Entropy: 3.3201
Step 2: Generated token: ',' | Entropy: 1.4193
Step 3: Generated token: ' or' | Entropy: 2.1664
Step 4: Generated token: ' no' | Entropy: 0.9418
Step 5: Generated token: '.

' | Entropy: 1.5586
Step 6: Generated token: '##' | Entropy: 1.0052
Step 7: Generated token: ' Step' | Entropy: 0.0001
Step 8: Generated token: ' ' | Entropy: 0.0006
Step 9: Generated token: '1' | Entropy: 0.0011
Step 10: Generated token: ':' | Entropy: 0.0829
Step 11: Generated token: ' Determine' | Entropy: 2.3941
Step 12: Generated token: ' the' | Entropy: 0.4219
Step 13: Generated token: ' comparison' | Entropy: 3.6358
Step 14: Generated token: ' we' | Entropy: 2.6477
Step 15: Generated token: ' need' | Entropy: 0.6533
Step 16: Generated token: ' to' | Entropy: 0.0155
Step 17: Generated token: ' make' | Entropy: 0.0506
Step 18: Generated token: '.
' | Entropy: 0.7651
Step 19: Generated token: 'To' | Entropy: 0.4037
Step 20: Generated token: ' compare' | Entropy: