# Reasoning Step Pruning via Attention Scores

**Based on:**
1. [Think Clearly: Improving Reasoning via Redundant Token Pruning](https://arxiv.org/abs/2507.08806)
2. [TRAAC: Think Right with Adaptive, Attentive Compression](https://arxiv.org/abs/2510.01581)

**Author:** Naveen Pasupuleti

## 0. Setup

In [None]:
!pip install transformers accelerate bitsandbytes sentencepiece protobuf -q

In [None]:
import os, re, json, time, torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import warnings
warnings.filterwarnings('ignore')

# Report which accelerator (if any) and which PyTorch build this run will use.
has_gpu = torch.cuda.is_available()
print(f'GPU Available: {has_gpu}')
if has_gpu:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    props = torch.cuda.get_device_properties(0)
    # Read `total_memory` first; if it is missing or falsy, fall back to
    # `total_mem`, and finally to 0.
    mem = getattr(props, 'total_memory', None)
    if not mem:
        mem = getattr(props, 'total_mem', 0)
    print(f'Memory: {mem / 1e9:.1f} GB')
print(f'PyTorch: {torch.__version__}')

## 1. AIME24 Dataset

In [None]:
# Evaluation set: each entry has an integer `id`, the `problem` statement sent
# to the model verbatim, and the integer ground-truth `answer` that predictions
# are compared against.
#
# NOTE(review): the wording of ids 1, 3 and 5 does not appear to match the
# official AIME 2024 statements — confirm statements and answers against the
# source before reporting accuracy numbers.
AIME24_PROBLEMS = [
    {"id": 1, "problem": "Every morning Asha decides randomly whether to walk left or right and independently Sasha does the same. They start on different ends of a 4-block long street. What is the probability that they meet? Express your answer as a fraction m/n in lowest terms and find m+n.", "answer": 31},
    {"id": 2, "problem": "There exist real numbers x and y, both greater than 1, such that log_x(y^x) = log_y(x^(4y)) = 10. Find xy.", "answer": 25},
    # For the game as stated, 2024 = 7 * 289 + 1: the first player wins only by
    # opening with 1 and then answering every move k with 7 - k, so the
    # smallest (and only) winning opening move is 1.
    {"id": 3, "problem": "Alice and Bob play a game. Alice starts first and they alternate turns. Alice's move is to choose an integer from 1 to 6 (inclusive) and add it to the running total. Bob does the same. The player who brings the running total to exactly 2024 wins. What is the smallest starting move Alice can use to guarantee a win?", "answer": 1},
    {"id": 4, "problem": "Let x, y, and z be positive real numbers satisfying the system: log_2(x/yz) = 1/2, log_2(y/xz) = 1/3, log_2(z/xy) = 1/4. Find the value of |log_2(x^4 y^3 z^2)|. Express as a fraction p/q in lowest terms and find p+q.", "answer": 33},
    {"id": 5, "problem": "Rectangle ABCD has side lengths AB=10 and BC=4. Point M is the midpoint of CD. Triangle AMB is removed, leaving a quadrilateral. Find the perimeter of the quadrilateral formed. If the answer is a+b*sqrt(c), find a+b+c.", "answer": 18},
]

print(f'Loaded {len(AIME24_PROBLEMS)} AIME24 problems')

## 2. Load Quantized Reasoning Model

In [None]:
# Hub identifier of the reasoning model used by every later cell.
MODEL_NAME = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'

print(f'Loading {MODEL_NAME}...')
print('This may take 2-3 minutes for download + quantization.')

# NOTE: trust_remote_code=True permits code shipped in the model repository to
# be executed while loading; only use it with repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# 4-bit NF4 weights with double quantization; compute runs in float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# 'eager' attention is presumably selected because score_steps_by_attention
# asks the model for its attention weights (output_attentions=True).
model_kwargs = {
    'quantization_config': bnb_config,
    'device_map': 'auto',
    'trust_remote_code': True,
    'attn_implementation': 'eager',
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)
model.eval()

print('Model loaded successfully!')
if torch.cuda.is_available():
    alloc = torch.cuda.memory_allocated() / 1e9
    print(f'GPU Memory used: {alloc:.1f} GB')

## 3. Core Functions

In [None]:
def generate_reasoning(prompt, max_tokens=4096):
    """Sample one completion for `prompt` and return only the new text.

    The prompt is wrapped as a single user turn (via the tokenizer's chat
    template when it has one, otherwise a ChatML-style fallback), sent to the
    module-level `model`, and sampled with temperature 0.6 / top-p 0.95.

    Args:
        prompt: The user message to send to the model.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded continuation, without the prompt tokens and with special
        tokens skipped.
    """
    if not hasattr(tokenizer, 'apply_chat_template'):
        rendered = f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
    else:
        conversation = [{'role': 'user', 'content': prompt}]
        rendered = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

    encoded = tokenizer(rendered, return_tensors='pt').to(model.device)
    # Remember the prompt length so it can be sliced off the generated sequence.
    prompt_len = encoded['input_ids'].shape[1]

    sampling = dict(max_new_tokens=max_tokens, temperature=0.6, top_p=0.95, do_sample=True)
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling)

    completion_ids = sequences[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


def segment_steps(text):
    """Split a model response into reasoning steps, raw reasoning and answer.

    The reasoning is the content of the `<think>...</think>` block. Three
    degraded shapes are also handled:

    * only `</think>` is present (e.g. the opening tag was already supplied by
      the generation prompt and is therefore not part of the decoded output):
      everything before the closing tag is the reasoning;
    * only `<think>` is present (e.g. generation stopped before the block was
      closed): everything after the opening tag is the reasoning and the
      answer is empty;
    * neither tag is present: the whole text is the reasoning and the answer
      is empty.

    Steps are the paragraphs of the reasoning (separated by one or more blank
    lines), stripped of surrounding whitespace; paragraphs shorter than 10
    characters are dropped.

    Args:
        text: Decoded model output.

    Returns:
        A tuple `(steps, thinking, answer)` where `steps` is a list of
        `{'text': str, 'step_id': int}` dicts numbered from 0, `thinking` is
        the raw reasoning string and `answer` is the text after `</think>`.
    """
    open_tag, close_tag = '<think>', '</think>'
    think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if think_match:
        thinking = think_match.group(1)
        answer = text[think_match.end():]
    elif close_tag in text:
        # Opening tag missing: split at the first closing tag instead of
        # treating the final answer as part of the reasoning.
        thinking, _, answer = text.partition(close_tag)
    elif open_tag in text:
        # Block never closed: keep the literal tag out of the first step.
        thinking = text.partition(open_tag)[2]
        answer = ''
    else:
        thinking = text
        answer = ''

    steps = []
    for chunk in re.split(r'\n\n+', thinking):
        chunk = chunk.strip()
        # Very short fragments (stray symbols, "So.", ...) are not kept as steps.
        if len(chunk) >= 10:
            steps.append({'text': chunk, 'step_id': len(steps)})
    return steps, thinking, answer


def score_steps_by_attention(steps, full_text, prompt):
    """Score each reasoning step by the attention the final token pays to it.

    The prompt is rendered the same way `generate_reasoning` renders it and
    `full_text` is appended verbatim, so the scored sequence is the prompt
    followed by the generated text. The sequence is run through the
    module-level `model` once with attention outputs enabled; the attention
    row of the last position is averaged over all layers and heads, and each
    step receives the sum of that row over its (approximate) token span.
    Scores are finally normalised to sum to 1 when the total is positive.

    The generated text is deliberately NOT passed through the chat template
    as an assistant turn: templates may rewrite assistant content (for
    example by dropping the reasoning section), which would leave nothing to
    score and break the character-to-token alignment.

    Args:
        steps: Step dicts with a 'text' key, expected in document order.
            Mutated in place: 'importance', 'avg_attention' and 'num_tokens'
            are set on every step.
        full_text: The decoded generation the steps were cut from.
        prompt: The user prompt that produced `full_text`.

    Returns:
        The same `steps` list.
    """
    if hasattr(tokenizer, 'apply_chat_template'):
        prefix = tokenizer.apply_chat_template(
            [{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
    else:
        prefix = f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
    text = prefix + full_text

    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=4096).to(model.device)
    # Number of tokens that precede the generated text. Encoded with the same
    # default settings as the full sequence so special tokens are counted alike.
    offset = len(tokenizer(prefix)['input_ids'])

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Each per-layer tensor is indexed as (batch, heads, query, key): keep the
    # last query row, stack the layers, then average over layers (dim 0) and
    # heads (dim 2). Averaging in float32 avoids half-precision round-off.
    last_rows = torch.stack([layer[:, :, -1, :] for layer in outputs.attentions])
    last_attn = last_rows.float().mean(dim=(0, 2))[0].cpu().numpy()
    # The full attention maps are no longer needed; release them before the
    # tokenisation loop below.
    del outputs, last_rows
    torch.cuda.empty_cache()

    seq_len = len(last_attn)
    cursor = 0  # search position, so repeated step texts map to distinct spans
    for step in steps:
        start = full_text.find(step['text'], cursor)
        if start == -1:
            # Steps supplied out of document order: fall back to a global search.
            start = full_text.find(step['text'])
        if start == -1:
            step['importance'] = 0.0
            step['avg_attention'] = 0.0
            step['num_tokens'] = 0
            continue
        cursor = max(cursor, start + len(step['text']))

        # Token positions are approximated by tokenising the text before the
        # step and the step itself separately.
        prefix_toks = len(tokenizer.encode(full_text[:start], add_special_tokens=False))
        step_toks = len(tokenizer.encode(step['text'], add_special_tokens=False))
        step['num_tokens'] = step_toks

        t_start = offset + prefix_toks
        t_end = min(t_start + step_toks, seq_len)
        if t_end <= t_start:
            # The step lies entirely beyond the truncated input (or is empty):
            # there is no attention to attribute to it.
            step['importance'] = 0.0
            step['avg_attention'] = 0.0
            continue

        step_attn = last_attn[t_start:t_end]
        step['importance'] = float(np.sum(step_attn))
        step['avg_attention'] = float(np.mean(step_attn))

    total = sum(s['importance'] for s in steps)
    if total > 0:
        for s in steps:
            s['importance'] /= total
    return steps


def prune_steps(steps, thinking, answer, threshold):
    """Drop the least important steps and rebuild the reasoning text.

    `threshold` is a fraction in [0, 1] used as a percentile over the step
    importances: steps whose importance is below that percentile are removed,
    ties at the cutoff are kept. A threshold of exactly 0.0 (or an empty step
    list) returns the untouched reasoning.

    In both cases the `answer` text is appended unchanged after a newline, so
    the returned string still contains the original post-reasoning answer.

    Args:
        steps: Step dicts carrying 'text' and 'importance'.
        thinking: The raw, unpruned reasoning string.
        answer: Text that followed the reasoning in the model output.
        threshold: Fraction of the importance distribution to cut below.

    Returns:
        `(text, original_step_count, kept_step_count)`.
    """
    if threshold == 0.0 or not steps:
        return '\n'.join((thinking, answer)), len(steps), len(steps)

    cutoff = np.percentile([step['importance'] for step in steps], threshold * 100)
    survivors = list(filter(lambda step: step['importance'] >= cutoff, steps))
    if len(survivors) == 0:
        # Never return an empty trace: fall back to the single best step.
        survivors = [max(steps, key=lambda step: step['importance'])]

    body = '\n\n'.join(step['text'] for step in survivors)
    return body + '\n' + answer, len(steps), len(survivors)


def extract_answer(text):
    """Pull an integer answer in the range 0..999 out of free-form model output.

    Patterns are tried from most to least specific (boxed value, "the answer
    is", "final answer", "answer", trailing "= N"). For each pattern only its
    last occurrence is considered, and it is accepted only when it lies in
    0..999; otherwise the next pattern is tried. If no pattern yields a value,
    the last standalone 1-3 digit number in the final 300 characters is used.

    Args:
        text: Model output to scan.

    Returns:
        The extracted integer, or None when no candidate number is found.
    """
    answer_patterns = (
        r'\\boxed\{(\d+)\}',
        r'the answer is\s*[:\s]*(\d+)',
        r'final answer[:\s]*(\d+)',
        r'answer[:\s]*\*?\*?(\d+)',
        r'= (\d+)\s*$',
    )
    flags = re.IGNORECASE | re.MULTILINE

    for pattern in answer_patterns:
        hits = re.findall(pattern, text, flags)
        if not hits:
            continue
        try:
            candidate = int(hits[-1])
        except ValueError:
            # int() can reject pathological digit strings; try the next pattern.
            continue
        if 0 <= candidate <= 999:
            return candidate

    # Last resort: the final short number near the end of the text.
    tail_numbers = re.findall(r'\b(\d{1,3})\b', text[-300:])
    if not tail_numbers:
        return None
    return int(tail_numbers[-1])


def reevaluate(problem, pruned_reasoning):
    """Ask the model for a final answer given a (pruned) reasoning trace.

    Builds a single prompt containing the problem statement followed by the
    supplied reasoning and generates up to 512 new tokens from it.

    Args:
        problem: The problem statement.
        pruned_reasoning: Reasoning text to condition the answer on.

    Returns:
        The raw text generated by `generate_reasoning` for that prompt.
    """
    prompt_parts = [
        'Here is a math problem and a partial reasoning trace. ',
        'Based on the reasoning, give the final numeric answer.\n\n',
        'Problem: {}\n\nReasoning:\n{}\n\n'.format(problem, pruned_reasoning),
        'Based on the above reasoning, the final answer is:',
    ]
    return generate_reasoning(''.join(prompt_parts), max_tokens=512)


print('All functions defined.')

## 4. Run Experiment

In [None]:
# Pruning levels to compare. Each value is handed to prune_steps, which uses it
# as a percentile (value * 100) over the step importances; 0.0 is the unpruned
# baseline.
THRESHOLDS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
# Per-threshold accumulators: number correct, number attempted, fraction of
# steps kept per problem, and character length of the pruned text per problem.
metrics = {t: {'correct': 0, 'total': 0, 'kept_pct': [], 'lengths': []} for t in THRESHOLDS}
# Scored steps of every problem, kept for the visualization cell.
all_step_data = []

print(f'Running: {len(AIME24_PROBLEMS)} problems x {len(THRESHOLDS)} thresholds')
print(f'Estimated time: 15-25 minutes\n')

for i, prob in enumerate(AIME24_PROBLEMS):
    print(f'\n{"="*60}')
    print(f'Problem {i+1}/{len(AIME24_PROBLEMS)} (ID: {prob["id"]})')
    print(f'{"="*60}')

    # One sampled generation per problem; every threshold reuses it.
    t0 = time.time()
    full_text = generate_reasoning(prob['problem'])
    print(f'  Generated {len(full_text)} chars in {time.time()-t0:.1f}s')

    steps, thinking, answer = segment_steps(full_text)
    print(f'  {len(steps)} reasoning steps')

    if not steps:
        # Count the problem as attempted (and not correct) at every threshold.
        # Note that kept_pct / lengths receive no entry for it.
        print('  No steps found, skipping')
        for t in THRESHOLDS:
            metrics[t]['total'] += 1
        continue

    print('  Scoring by attention...')
    try:
        steps = score_steps_by_attention(steps, full_text, prob['problem'])
        sorted_s = sorted(steps, key=lambda s: s['importance'], reverse=True)
        print(f'  Top: score={sorted_s[0]["importance"]:.4f}')
        print(f'  Bottom: score={sorted_s[-1]["importance"]:.4f}')
    except Exception as e:
        # Best-effort fallback: give every step the same score. With equal
        # scores every step ties at the percentile cutoff in prune_steps, so
        # no step is removed for this problem.
        print(f'  Attention failed: {e}, using uniform')
        for s in steps:
            s['importance'] = 1.0 / len(steps)
            s['avg_attention'] = 1.0 / len(steps)

    all_step_data.append({'problem_id': prob['id'], 'steps': steps})

    for t in THRESHOLDS:
        pruned, orig, kept = prune_steps(steps, thinking, answer, t)
        if t == 0.0:
            # Baseline: read the answer straight from the original generation,
            # without a second model call.
            predicted = extract_answer(full_text)
        else:
            # Pruned runs: ask the model again, conditioned on the pruned text.
            # NOTE(review): prune_steps appends the original post-reasoning
            # `answer` text to `pruned`, so the re-evaluation prompt (and the
            # fallback below) can see the model's original final answer —
            # confirm this is intended.
            eval_out = reevaluate(prob['problem'], pruned)
            predicted = extract_answer(eval_out)
            if predicted is None:
                predicted = extract_answer(pruned)
        correct = predicted is not None and predicted == prob['answer']
        metrics[t]['correct'] += int(correct)
        metrics[t]['total'] += 1
        # max(orig, 1) guards the division; orig is the step count before pruning.
        metrics[t]['kept_pct'].append(kept / max(orig, 1))
        metrics[t]['lengths'].append(len(pruned))
        mark = 'Y' if correct else 'N'
        print(f'  t={t:.1f}: {orig}->{kept} steps, pred={predicted}, exp={prob["answer"]}, {mark}')

    # Release cached GPU memory before moving on to the next problem.
    torch.cuda.empty_cache()

print(f'\nDone!')

## 5. Comparison Table

In [None]:
# Render the per-threshold results as a fixed-width text table.
print()
for banner_line in (
    '=' * 85,
    '  COMPARISON TABLE: Attention-Based Reasoning Pruning on AIME24',
    f'  Model: {MODEL_NAME}',
    '  Papers: Think Clearly (2507.08806) + TRAAC (2510.01581)',
    '=' * 85,
    f'{"Threshold":<14} {"Accuracy":<14} {"Steps Kept":<14} {"Avg Length":<16} {"Correct/Total":<15}',
    '-' * 85,
):
    print(banner_line)

# Filled in when the 0.0 row is processed; later rows are reported relative to it.
baseline_acc = None
baseline_len = None

for t in THRESHOLDS:
    m = metrics[t]
    if m['total'] == 0:
        # Nothing was evaluated at this threshold.
        continue

    acc = m['correct'] / m['total'] * 100
    kept = 0
    if m['kept_pct']:
        kept = np.mean(m['kept_pct']) * 100
    avg_len = 0
    if m['lengths']:
        avg_len = np.mean(m['lengths'])

    delta_acc = ''
    delta_len = ''
    if t != 0.0:
        label = f'{t:.1f}'
        if baseline_acc is not None:
            delta_acc = f' ({acc - baseline_acc:+.1f}%)'
        if baseline_len and baseline_len > 0:
            delta_len = f' (-{(1 - avg_len/baseline_len)*100:.0f}%)'
    else:
        baseline_acc = acc
        baseline_len = avg_len
        label = f'{t:.1f} (baseline)'

    print(f'{label:<14} {acc:.1f}%{delta_acc:<10} {kept:.1f}%{"":<10} {avg_len:.0f}{delta_len:<12} {m["correct"]}/{m["total"]}')

for footer_line in (
    '=' * 85,
    '',
    'Interpretation:',
    '  - Threshold 0.0 = full chain (baseline, no pruning)',
    '  - Moderate pruning (0.1-0.3): removes distracting steps, may improve accuracy',
    '  - Aggressive pruning (0.4+): risks removing critical steps',
):
    print(footer_line)

## 6. Visualizations

In [None]:
import matplotlib.pyplot as plt

# Attention importance per step
# One bar chart per scored problem, stacked vertically in a single figure.
if all_step_data:
    n = len(all_step_data)
    fig, axes = plt.subplots(n, 1, figsize=(12, 3 * n))
    # With a single subplot, plt.subplots returns one Axes object rather than
    # an array; wrap it so the indexing below works in both cases.
    if n == 1:
        axes = [axes]
    for idx, data in enumerate(all_step_data):
        ax = axes[idx]
        steps = data['steps']
        ids = [s['step_id'] for s in steps]
        imps = [s['importance'] for s in steps]
        if not imps:
            continue
        # Bars are coloured relative to the median importance of the problem.
        # NOTE(review): the title labels green/red as Keep/Prune, but the
        # colouring uses the median regardless of which pruning threshold is
        # applied elsewhere — it corresponds to the 0.5 setting only.
        med = np.median(imps)
        colors = ['#27ae60' if v >= med else '#e74c3c' for v in imps]
        ax.bar(ids, imps, color=colors, alpha=0.85)
        ax.axhline(y=med, color='black', linestyle='--', alpha=0.4, label=f'Median')
        ax.set_xlabel('Step ID')
        ax.set_ylabel('Importance')
        ax.set_title(f'Problem {data["problem_id"]}: Step Importance (Green=Keep, Red=Prune)')
        ax.legend()
    plt.tight_layout()
    plt.savefig('attention_importance.png', dpi=150, bbox_inches='tight')
    plt.show()

# Accuracy vs threshold
# Second figure: accuracy (left) and share of steps kept (right) per threshold.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ts, accs, kepts = [], [], []
for t in THRESHOLDS:
    m = metrics[t]
    if m['total'] > 0:
        ts.append(t)
        accs.append(m['correct'] / m['total'] * 100)
        # NOTE(review): kept_pct is empty when every problem was skipped for
        # having no steps, in which case this mean is NaN.
        kepts.append(np.mean(m['kept_pct']) * 100)

ax1.plot(ts, accs, 'bo-', linewidth=2, markersize=8)
ax1.set_xlabel('Pruning Threshold')
ax1.set_ylabel('Accuracy (%)')
ax1.set_title('Accuracy vs Pruning')
ax1.set_ylim([-5, 105])
ax1.grid(True, alpha=0.3)

ax2.plot(ts, kepts, 'rs-', linewidth=2, markersize=8)
ax2.set_xlabel('Pruning Threshold')
ax2.set_ylabel('Steps Kept (%)')
ax2.set_title('Compression vs Pruning')
ax2.set_ylim([-5, 105])
ax2.grid(True, alpha=0.3)

# These pyplot calls act on the current figure, i.e. the one created just above.
plt.suptitle('Attention-Based Reasoning Pruning - AIME24', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('pruning_results.png', dpi=150, bbox_inches='tight')
plt.show()
print('Plots saved.')

## 7. Save Results

In [None]:
# Collect the run configuration and the per-threshold summary into one
# JSON-serialisable dict and write it next to the notebook.
results = {
    'model': MODEL_NAME,
    'dataset': 'AIME24',
    'num_problems': len(AIME24_PROBLEMS),
    'papers': ['Think Clearly (arXiv:2507.08806)', 'TRAAC (arXiv:2510.01581)'],
    'thresholds': {}
}
for t in THRESHOLDS:
    m = metrics[t]
    if m['total'] > 0:
        # Problems skipped for having no steps increase `total` without adding
        # to kept_pct / lengths, so these lists can be empty even when
        # total > 0. np.mean([]) is NaN and round(NaN) raises ValueError, so
        # fall back to 0 exactly as the comparison table does.
        steps_kept_pct = round(float(np.mean(m['kept_pct'])) * 100, 1) if m['kept_pct'] else 0.0
        avg_length = round(float(np.mean(m['lengths']))) if m['lengths'] else 0
        results['thresholds'][str(t)] = {
            'accuracy': round(m['correct'] / m['total'] * 100, 1),
            'steps_kept_pct': steps_kept_pct,
            'avg_length': avg_length,
            'correct': m['correct'],
            'total': m['total'],
        }

# Explicit encoding so the output does not depend on the platform default.
with open('aime24_pruning_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

print('Saved: aime24_pruning_results.json')
print()
print('Upload to GitHub:')
print('  1. This notebook')
print('  2. aime24_pruning_results.json')
print('  3. attention_importance.png')
print('  4. pruning_results.png')
print('  5. README.md')