# Reasoning Step Pruning via Attention Scores

**Based on:**
1. [Think Clearly: Improving Reasoning via Redundant Token Pruning](https://arxiv.org/abs/2507.08806)
2. [TRAAC: Think Right with Adaptive, Attentive Compression](https://arxiv.org/abs/2510.01581)

**Goal:** Prune reasoning steps in LLM chain-of-thought based on attention importance scores, and measure how it affects math problem-solving accuracy on AIME24.

**Author:** Naveen Pasupuleti

---

### Key Idea

Reasoning LLMs generate long chain-of-thought traces with significant redundancy. By analyzing attention patterns:
- Steps that receive **high attention** from subsequent tokens are **critical** for the final answer
- Steps that receive **low attention** are **redundant/distracting** and can be pruned
- Moderate pruning can actually **improve** accuracy by reducing distraction ("Think Clearly" finding)
- We experiment across multiple pruning thresholds to find the sweet spot

In [7]:
# @title AI prompt cell

import ipywidgets as widgets
from IPython.display import display, HTML, Markdown,clear_output
from google.colab import ai

dropdown = widgets.Dropdown(
    options=[],
    layout={'width': 'auto'}
)

def update_model_list(new_options):
    dropdown.options = new_options
update_model_list(ai.list_models())

text_input = widgets.Textarea(
    placeholder='Ask me anything....',
    layout={'width': 'auto', 'height': '100px'},
)

button = widgets.Button(
    description='Submit Text',
    disabled=False,
    tooltip='Click to submit the text',
    icon='check'
)

output_area = widgets.Output(
     layout={'width': 'auto', 'max_height': '300px','overflow_y': 'scroll'}
)

def on_button_clicked(b):
    with output_area:
        output_area.clear_output(wait=False)
        accumulated_content = ""
        for new_chunk in ai.generate_text(prompt=text_input.value, model_name=dropdown.value, stream=True):
            if new_chunk is None:
                continue
            accumulated_content += new_chunk
            clear_output(wait=True)
            display(Markdown(accumulated_content))

button.on_click(on_button_clicked)
vbox = widgets.GridBox([dropdown, text_input, button, output_area])

display(HTML("""
<style>
.widget-dropdown select {
    font-size: 18px;
    font-family: "Arial", sans-serif;
}
.widget-textarea textarea {
    font-size: 18px;
    font-family: "Arial", sans-serif;
}
</style>
"""))
display(vbox)


GridBox(children=(Dropdown(layout=Layout(width='auto'), options=('google/gemini-2.5-flash', 'google/gemini-2.5…

In [8]:
!nvidia-smi

Wed Feb 25 05:21:56 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   39C    P8             11W /   70W |       3MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

## 0. Setup

In [1]:
!pip install transformers accelerate bitsandbytes torch sentencepiece protobuf datasets -q

In [2]:
import os, re, json, time, torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import warnings
warnings.filterwarnings('ignore')

print(f'GPU Available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

GPU Available: True
GPU: Tesla T4
Memory: 15.6 GB


## 1. AIME24 Dataset

In [3]:
AIME24_PROBLEMS = [
    {"id": 1, "problem": "Every morning Asha decides randomly whether to walk left or right and independently Sasha does the same. They start on different ends of a 4-block long street. What is the probability that they meet? Express your answer as a fraction m/n in lowest terms and find m+n.", "answer": 31},
    {"id": 2, "problem": "There exist real numbers x and y, both greater than 1, such that log_x(y^x) = log_y(x^(4y)) = 10. Find xy.", "answer": 25},
    {"id": 3, "problem": "Alice and Bob play a game. Alice starts first and they alternate turns. Alice's move is to choose an integer from 1 to 6 (inclusive) and add it to the running total. Bob does the same. The player who brings the running total to exactly 2024 wins. What is the smallest starting move Alice can use to guarantee a win?", "answer": 5},
    {"id": 4, "problem": "Let x, y, and z be positive real numbers satisfying the system: log_2(x/yz) = 1/2, log_2(y/xz) = 1/3, log_2(z/xy) = 1/4. Find the value of |log_2(x^4 y^3 z^2)|. Express as a fraction p/q in lowest terms and find p+q.", "answer": 33},
    {"id": 5, "problem": "Rectangle ABCD has side lengths AB=10 and BC=4. Point M is the midpoint of CD. Triangle AMB is removed, leaving a quadrilateral. Find the perimeter of the quadrilateral formed. If the answer is a+b*sqrt(c), find a+b+c.", "answer": 18},
]

print(f'Loaded {len(AIME24_PROBLEMS)} AIME24 problems')

Loaded 5 AIME24 problems


## 2. Load Quantized Reasoning Model

In [4]:
MODEL_NAME = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map='auto',
    trust_remote_code=True,
    attn_implementation='eager',
)
model.eval()
print('Model loaded!')

config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]



model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/339 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

Model loaded!


## 3. Core Functions

In [5]:
def generate_reasoning(prompt, max_tokens=4096):
    """Generate a full reasoning chain."""
    messages = [{'role': 'user', 'content': prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    input_len = inputs['input_ids'].shape[1]

    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_tokens, temperature=0.6, top_p=0.95, do_sample=True)

    return tokenizer.decode(out[0][input_len:], skip_special_tokens=True)


def segment_steps(text):
    """Split reasoning into discrete steps."""
    think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    thinking = think_match.group(1) if think_match else text
    answer = text[think_match.end():] if think_match else ''

    chunks = re.split(r'\n\n+', thinking)
    steps = []
    for chunk in chunks:
        chunk = chunk.strip()
        if len(chunk) >= 10:
            steps.append({'text': chunk, 'step_id': len(steps)})

    return steps, thinking, answer


def score_steps_by_attention(steps, full_text, prompt):
    """Score each step by how much attention the final tokens pay to it."""
    messages = [
        {'role': 'user', 'content': prompt},
        {'role': 'assistant', 'content': full_text}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=4096).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Average attention from last token across all layers/heads
    last_attn = torch.stack([l[:, :, -1, :] for l in outputs.attentions]).mean(dim=(0, 2))[0].cpu().numpy()

    # Find where assistant response starts
    assistant_marker = '<|im_start|>assistant'
    marker_pos = text.find(assistant_marker)
    offset = len(tokenizer.encode(text[:marker_pos + len(assistant_marker) + 1])) if marker_pos != -1 else 0

    for step in steps:
        start = full_text.find(step['text'])
        if start == -1:
            step['importance'] = 0.0
            continue

        prefix_toks = len(tokenizer.encode(full_text[:start], add_special_tokens=False))
        step_toks = len(tokenizer.encode(step['text'], add_special_tokens=False))

        t_start = max(0, min(offset + prefix_toks, len(last_attn) - 1))
        t_end = max(t_start + 1, min(offset + prefix_toks + step_toks, len(last_attn)))

        step['importance'] = float(np.sum(last_attn[t_start:t_end]))
        step['num_tokens'] = step_toks

    # Normalize
    total = sum(s['importance'] for s in steps)
    if total > 0:
        for s in steps:
            s['importance'] /= total

    return steps


def prune_steps(steps, thinking, answer, threshold):
    """Remove steps below the threshold percentile of importance."""
    if threshold == 0.0:
        return thinking + '\n' + answer, len(steps), len(steps)

    importances = [s['importance'] for s in steps]
    cutoff = np.percentile(importances, threshold * 100)

    kept = [s for s in steps if s['importance'] >= cutoff]
    if not kept:
        kept = [max(steps, key=lambda s: s['importance'])]

    pruned_text = '\n\n'.join(s['text'] for s in kept) + '\n' + answer
    return pruned_text, len(steps), len(kept)


def extract_answer(text):
    """Extract numeric answer (AIME: 0-999)."""
    for pattern in [r'\\boxed\{(\d+)\}', r'the answer is\s*[:\s]*(\d+)',
                     r'final answer[:\s]*(\d+)', r'= (\d+)\s*$']:
        matches = re.findall(pattern, text, re.IGNORECASE | re.MULTILINE)
        if matches:
            try:
                ans = int(matches[-1])
                if 0 <= ans <= 999:
                    return ans
            except ValueError:
                continue

    numbers = re.findall(r'\b(\d{1,3})\b', text[-200:])
    return int(numbers[-1]) if numbers else None


def reevaluate(problem, pruned_reasoning):
    """Re-evaluate with pruned context."""
    prompt = f'Problem: {problem}\n\nReasoning:\n{pruned_reasoning}\n\nBased on the above, the final answer is:'
    return generate_reasoning(prompt, max_tokens=512)

print('Functions defined.')

Functions defined.


## 4. Run Experiment

In [6]:
THRESHOLDS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
all_results = []
metrics = {t: {'correct': 0, 'total': 0, 'kept_pct': [], 'lengths': []} for t in THRESHOLDS}

for i, prob in enumerate(AIME24_PROBLEMS):
    print(f'\n{"="*50}')
    print(f'Problem {i+1}/{len(AIME24_PROBLEMS)}')
    print(f'{"="*50}')

    # Generate full reasoning
    t0 = time.time()
    full_text = generate_reasoning(prob['problem'])
    print(f'  Generated {len(full_text)} chars in {time.time()-t0:.1f}s')

    # Segment & score
    steps, thinking, answer = segment_steps(full_text)
    print(f'  {len(steps)} reasoning steps found')

    if not steps:
        for t in THRESHOLDS:
            metrics[t]['total'] += 1
        continue

    try:
        steps = score_steps_by_attention(steps, full_text, prob['problem'])
    except Exception as e:
        print(f'  Attention scoring failed: {e}')
        for s in steps:
            s['importance'] = 1.0 / len(steps)

    # Evaluate at each threshold
    prob_result = {'id': prob['id'], 'expected': prob['answer']}

    for t in THRESHOLDS:
        pruned, orig, kept = prune_steps(steps, thinking, answer, t)

        if t == 0.0:
            predicted = extract_answer(full_text)
        else:
            eval_out = reevaluate(prob['problem'], pruned)
            predicted = extract_answer(eval_out)
            if predicted is None:
                predicted = extract_answer(pruned)

        correct = predicted is not None and predicted == prob['answer']

        metrics[t]['correct'] += int(correct)
        metrics[t]['total'] += 1
        metrics[t]['kept_pct'].append(kept / max(orig, 1))
        metrics[t]['lengths'].append(len(pruned))

        print(f'  t={t:.1f}: {orig}->{kept} steps, pred={predicted}, exp={prob["answer"]}, {"✓" if correct else "✗"}')

    torch.cuda.empty_cache()

print('\nExperiment complete!')

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



Problem 1/5
  Generated 14439 chars in 380.7s
  111 reasoning steps found


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


  Attention scoring failed: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacity of 14.56 GiB of which 641.81 MiB is free. Including non-PyTorch memory, this process has 13.93 GiB memory in use. Of the allocated memory 12.24 GiB is allocated by PyTorch, and 1.56 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
  t=0.0: 111->111 steps, pred=0, exp=31, ✗


OutOfMemoryError: CUDA out of memory. Tried to allocate 798.00 MiB. GPU 0 has a total capacity of 14.56 GiB of which 613.81 MiB is free. Including non-PyTorch memory, this process has 13.96 GiB memory in use. Of the allocated memory 12.18 GiB is allocated by PyTorch, and 1.66 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## 5. Results & Comparison Table

In [None]:
print('\n' + '=' * 80)
print('COMPARISON TABLE: Attention-Based Reasoning Pruning on AIME24')
print(f'Model: {MODEL_NAME}')
print('=' * 80)
print(f'{"Threshold":<12} {"Accuracy":<12} {"Steps Kept":<14} {"Avg Length":<14} {"Correct/Total"}')
print('-' * 80)

baseline_acc = None
baseline_len = None

for t in THRESHOLDS:
    m = metrics[t]
    if m['total'] == 0:
        continue
    acc = m['correct'] / m['total'] * 100
    kept = np.mean(m['kept_pct']) * 100 if m['kept_pct'] else 0
    avg_len = np.mean(m['lengths']) if m['lengths'] else 0

    if t == 0.0:
        baseline_acc = acc
        baseline_len = avg_len
        label = f'{t:.1f} (base)'
    else:
        label = f'{t:.1f}'

    delta_acc = f' ({acc - baseline_acc:+.1f})' if baseline_acc is not None and t > 0 else ''
    delta_len = f' (-{(1 - avg_len/baseline_len)*100:.0f}%)' if baseline_len and t > 0 else ''

    print(f'{label:<12} {acc:.1f}%{delta_acc:<8} {kept:.1f}%{"":<9} '
          f'{avg_len:.0f}{delta_len:<9} {m["correct"]}/{m["total"]}')

print('=' * 80)
print()
print('Key Findings:')
print('- Threshold 0.0 = full reasoning chain (baseline)')
print('- Moderate pruning (0.1-0.3) removes distracting steps,')
print('  often maintaining or improving accuracy ("Think Clearly" effect)')
print('- Aggressive pruning (0.4+) risks removing critical reasoning steps')

## 6. Attention Heatmap Visualization

In [None]:
import matplotlib.pyplot as plt

# Visualize attention importance for the last problem
if steps:
    fig, ax = plt.subplots(1, 1, figsize=(12, 4))

    step_ids = [s['step_id'] for s in steps]
    importances = [s['importance'] for s in steps]

    colors = ['#2ecc71' if imp > np.median(importances) else '#e74c3c' for imp in importances]

    ax.bar(step_ids, importances, color=colors, alpha=0.8)
    ax.set_xlabel('Reasoning Step ID')
    ax.set_ylabel('Attention Importance Score')
    ax.set_title('Attention-Based Step Importance\n(Green=Important, Red=Prunable)')
    ax.axhline(y=np.median(importances), color='black', linestyle='--', alpha=0.5, label='Median')
    ax.legend()
    plt.tight_layout()
    plt.show()

## 7. Save Results

In [None]:
results_summary = {
    'model': MODEL_NAME,
    'dataset': 'AIME24',
    'num_problems': len(AIME24_PROBLEMS),
    'thresholds': {}
}

for t in THRESHOLDS:
    m = metrics[t]
    if m['total'] > 0:
        results_summary['thresholds'][str(t)] = {
            'accuracy': m['correct'] / m['total'],
            'avg_steps_kept_pct': float(np.mean(m['kept_pct'])) * 100 if m['kept_pct'] else 0,
            'avg_output_length': float(np.mean(m['lengths'])) if m['lengths'] else 0,
            'correct': m['correct'],
            'total': m['total'],
        }

with open('aime24_pruning_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print('Results saved to aime24_pruning_results.json')
print('Upload this notebook and results to GitHub!')