# Chain-of-Thought Reasoning

In this notebook, you'll explore the mechanism behind chain-of-thought (CoT) prompting: intermediate reasoning tokens give the model additional forward passes' worth of computation, expanding its effective capacity beyond what a single forward pass can do.
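The claim that each generated token costs one forward pass can be made concrete with a toy decoding loop. This is a minimal sketch, not how the OpenAI API works internally. `generate` and `scripted_model` are hypothetical names, tokens are whitespace-split words (real tokenizers split text differently), and the scripted "model" ignores its context. It only illustrates the pass counting, not why the extra context helps.

```python
from __future__ import annotations

from typing import Callable

STOP = "<eos>"


def generate(next_token: Callable[[list[str]], str], prompt: list[str],
             max_new_tokens: int) -> tuple[list[str], int]:
    """Toy greedy decoding loop that counts forward passes.

    Each call to next_token stands in for one full forward pass: it sees the
    whole context so far and returns a single next token.
    """
    context = list(prompt)
    forward_passes = 0
    for _ in range(max_new_tokens):
        token = next_token(context)  # one forward pass
        forward_passes += 1
        if token == STOP:
            break
        context.append(token)  # the new token is input to every later pass
    return context[len(prompt):], forward_passes


def scripted_model(script: list[str]) -> Callable[[list[str]], str]:
    """Hypothetical stand-in for an LLM: emits a fixed token script, then STOP."""
    tokens = iter(script)
    return lambda context: next(tokens, STOP)


prompt = "What is 17 × 24 ?".split()
direct_script = ["408"]
cot_script = "17 × 20 = 340 , 17 × 4 = 68 , 340 + 68 = 408".split()

_, direct_passes = generate(scripted_model(direct_script), prompt, 50)
_, cot_passes = generate(scripted_model(cot_script), prompt, 50)
print(direct_passes, cot_passes)  # prints: 2 18
```

The direct answer costs 2 passes (the answer plus the stop token), while the written-out calculation costs 18. A real model reads the growing `context` on every pass, which is what lets later passes use intermediate results such as `340` and `68`.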

**What you'll do:**
- Compare direct answers vs chain-of-thought on arithmetic problems of varying complexity, and see the accuracy difference empirically
- Count intermediate tokens generated during CoT and plot them against problem complexity—each token is an additional forward pass
- Manually corrupt intermediate reasoning steps and observe error propagation—the model does not "catch" mistakes
- Design your own experiment to find the complexity boundary where CoT starts helping

**For each exercise, PREDICT the output before running the cell.** Wrong predictions are more valuable than correct ones—they reveal gaps in your mental model.

In [None]:
# Setup — self-contained for Google Colab
!pip install -q openai

import os
import json
import textwrap
import random
import re
from openai import OpenAI
import matplotlib.pyplot as plt
import numpy as np

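# --- API Key Setup ---
# Option 1 (recommended): keep the key out of the notebook.
#   In Colab: open the key icon in the left sidebar, add a secret named
#   OPENAI_API_KEY, and enable notebook access. Colab secrets are not exposed
#   as environment variables automatically, so copy it over once:
#     from google.colab import userdata
#     os.environ["OPENAI_API_KEY"] = userdata.get("OPENAI_API_KEY")
#   Elsewhere: export OPENAI_API_KEY in your shell before starting Jupyter.
# Option 2: Paste it directly (less secure; don't commit this)
#   os.environ["OPENAI_API_KEY"] = "sk-..."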

# You can also use any OpenAI-compatible API (e.g., local Ollama, Together AI)
# by changing the base_url and setting MODEL below to a model that server provides:
#   client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")

client = OpenAI()

# Use a small, cheap model for the exercises
MODEL = "gpt-4o-mini"

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Seed local randomness (API outputs can still vary slightly between runs,
# even at temperature 0)
random.seed(42)
np.random.seed(42)


def call_llm(prompt: str, temperature: float = 0.0, max_tokens: int = 300) -> str:
    """Call the LLM with a single prompt. Returns the response text."""
    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content.strip()


def call_llm_with_usage(prompt: str, temperature: float = 0.0,
                        max_tokens: int = 300) -> tuple[str, int]:
    """Call the LLM and return (response_text, completion_tokens)."""
    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    text = response.choices[0].message.content.strip()
    tokens = response.usage.completion_tokens
    return text, tokens


def print_wrapped(text: str, width: int = 80, prefix: str = ""):
    """Print text with word wrapping for readability."""
    for line in text.split("\n"):
        wrapped = textwrap.fill(line, width=width, initial_indent=prefix,
                                subsequent_indent=prefix)
        print(wrapped)


def extract_number(text: str) -> int | None:
    """Extract the last integer in a response string (commas stripped first).

    Heuristic: assumes the final answer is the last number the model wrote.
    """
    numbers = re.findall(r'-?\d+', text.replace(',', ''))
    if not numbers:
        return None
    return int(numbers[-1])


# Quick test to verify the API is working
test = call_llm("Say 'API connection successful' and nothing else.")
print(test)
print(f"\nUsing model: {MODEL}")
print("Setup complete.")

## Shared Data

All exercises use arithmetic and word problems of varying complexity. Problems are defined here so exercises can share them.
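Because every exercise scores the model against these hard-coded answers, a wrong entry would make both strategies look wrong. A quick independent check recomputes each answer with plain Python arithmetic. The `ANSWER_KEY` list is hypothetical and simply mirrors `PROBLEMS` by hand:

```python
# Each entry: (what the problem asks, value computed by Python, hard-coded answer)
ANSWER_KEY = [
    ("23 + 45", 23 + 45, 68),
    ("91 - 37", 91 - 37, 54),
    ("6 × 9", 6 * 9, 54),
    ("8 × 7", 8 * 7, 56),
    ("17 × 24", 17 * 24, 408),
    ("34 × 56", 34 * 56, 1904),
    ("(15 × 8) + (12 × 3)", 15 * 8 + 12 * 3, 156),
    ("3 shelves × 8 books - 5", 3 * 8 - 5, 19),
    ("4 fields × 6 rows × 15 plants", 4 * 6 * 15, 360),
    ("13 × 17 + 8 × 9 - 45", 13 * 17 + 8 * 9 - 45, 248),
]

mismatches = [label for label, computed, key in ANSWER_KEY if computed != key]
print("Answer key OK" if not mismatches else f"Mismatches: {mismatches}")
```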

In [None]:
# --- Arithmetic problems: single-step and multi-step ---
# Each entry: (description, expression or problem text, correct answer, num_steps)

PROBLEMS = [
    # Single-step (should NOT benefit from CoT)
    ("simple addition", "What is 23 + 45?", 68, 1),
    ("simple subtraction", "What is 91 - 37?", 54, 1),
    ("simple multiplication", "What is 6 × 9?", 54, 1),
    ("single-digit multiply", "What is 8 × 7?", 56, 1),

    # Multi-step (SHOULD benefit from CoT)
    ("two-digit multiply", "What is 17 × 24?", 408, 3),
    ("two-digit multiply", "What is 34 × 56?", 1904, 3),
    ("three operations", "What is (15 × 8) + (12 × 3)?", 156, 3),
    ("word problem",
     "A store has 3 shelves with 8 books each. They remove 5 books. How many books are left?",
     19, 2),
    ("word problem",
     "A farmer has 4 fields, each with 6 rows of corn. Each row has 15 plants. How many plants total?",
     360, 3),
    ("chained operations", "What is 13 × 17 + 8 × 9 - 45?", 248, 4),
]

print(f"Loaded {len(PROBLEMS)} problems:")
print(f"  Single-step: {sum(1 for _, _, _, s in PROBLEMS if s == 1)}")
print(f"  Multi-step:  {sum(1 for _, _, _, s in PROBLEMS if s > 1)}")
print("\nData loaded.")

---

## Exercise 1: Direct vs CoT Comparison (Guided)

The lesson demonstrated that the same model gets different answers with and without chain-of-thought prompting. In this exercise, you'll test that claim systematically across 10 problems of varying complexity.

For each problem, you'll run two prompts:
- **Direct:** Ask for the answer only, no intermediate steps
- **CoT:** Ask to think step by step before answering

**Before running, predict:**
- Which problems will benefit from CoT? (Hint: the lesson says the criterion is computational complexity—whether the problem exceeds single-forward-pass capacity)
- Will CoT ever *hurt* accuracy on the simple problems?
- How large will the accuracy gap be on multi-step problems?

In [None]:
# --- Step 1: Run all problems with both prompting strategies ---

results = []

for desc, problem, answer, steps in PROBLEMS:
    # Direct prompt: ask for the answer only
    direct_prompt = f"{problem} Answer with ONLY the number, nothing else."
    direct_response = call_llm(direct_prompt)
    direct_answer = extract_number(direct_response)
    direct_correct = direct_answer == answer

    # CoT prompt: ask to think step by step
    cot_prompt = f"{problem} Let's work through this step by step, then give the final answer."
    cot_response = call_llm(cot_prompt, max_tokens=500)
    cot_answer = extract_number(cot_response)
    cot_correct = cot_answer == answer

    results.append({
        "desc": desc,
        "problem": problem,
        "answer": answer,
        "steps": steps,
        "direct_response": direct_response,
        "direct_answer": direct_answer,
        "direct_correct": direct_correct,
        "cot_response": cot_response,
        "cot_answer": cot_answer,
        "cot_correct": cot_correct,
    })

    symbol_d = "✓" if direct_correct else "✗"
    symbol_c = "✓" if cot_correct else "✗"
    print(f"{desc:25s} | correct={answer:5d} | direct={str(direct_answer):>6s} {symbol_d} | cot={str(cot_answer):>6s} {symbol_c}")

print("\nAll problems tested.")

In [None]:
# --- Step 2: Summarize and visualize ---

single_step = [r for r in results if r["steps"] == 1]
multi_step = [r for r in results if r["steps"] > 1]

single_direct_acc = sum(r["direct_correct"] for r in single_step) / len(single_step)
single_cot_acc = sum(r["cot_correct"] for r in single_step) / len(single_step)
multi_direct_acc = sum(r["direct_correct"] for r in multi_step) / len(multi_step)
multi_cot_acc = sum(r["cot_correct"] for r in multi_step) / len(multi_step)

print("ACCURACY SUMMARY")
print("=" * 50)
print(f"Single-step problems ({len(single_step)}):")
print(f"  Direct: {single_direct_acc:.0%}")
print(f"  CoT:    {single_cot_acc:.0%}")
print(f"  Improvement: {single_cot_acc - single_direct_acc:+.0%}")
print(f"\nMulti-step problems ({len(multi_step)}):")
print(f"  Direct: {multi_direct_acc:.0%}")
print(f"  CoT:    {multi_cot_acc:.0%}")
print(f"  Improvement: {multi_cot_acc - multi_direct_acc:+.0%}")

# Bar chart
fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(2)
width = 0.35
bars1 = ax.bar(x - width/2, [single_direct_acc * 100, multi_direct_acc * 100],
               width, label='Direct', color='#f59e0b', edgecolor='white', linewidth=0.5)
bars2 = ax.bar(x + width/2, [single_cot_acc * 100, multi_cot_acc * 100],
               width, label='CoT', color='#8b5cf6', edgecolor='white', linewidth=0.5)

ax.set_ylabel('Accuracy (%)', fontsize=11)
ax.set_title('Direct vs Chain-of-Thought Accuracy', fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(['Single-Step', 'Multi-Step'])
ax.set_ylim(0, 110)
ax.legend()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

for bar in bars1:
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 2,
            f'{bar.get_height():.0f}%', ha='center', va='bottom', fontsize=10, color='white')
for bar in bars2:
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 2,
            f'{bar.get_height():.0f}%', ha='center', va='bottom', fontsize=10, color='white')

plt.tight_layout()
plt.show()

print("\nExpected pattern: CoT helps on multi-step problems that exceed single-forward-pass capacity.")
print("Single-step problems fit within one forward pass, so CoT adds little (and spends extra tokens).")
print("The criterion is computational complexity, not difficulty in the human sense.")
print("With only 4 single-step and 6 multi-step problems, one flipped answer moves accuracy by 25 or ~17 points.")

In [None]:
# --- Step 3: Inspect a CoT response ---
# Let's look at one CoT response in detail to see the intermediate tokens.

# Find a multi-step problem where CoT was correct
cot_wins = [r for r in results if r["cot_correct"] and r["steps"] > 1]
if cot_wins:
    example = cot_wins[0]
    print(f"Problem: {example['problem']}")
    print(f"Correct answer: {example['answer']}")
    print(f"\nDirect response: {example['direct_response']}")
    print(f"Direct correct: {example['direct_correct']}")
    print(f"\nCoT response:")
    print_wrapped(example['cot_response'])
    print(f"\nCoT correct: {example['cot_correct']}")
    print("\nNotice the intermediate results in the CoT response.")
    print("Each one was generated token by token — each token is a forward pass.")
    print("The intermediate results are IN THE CONTEXT for subsequent forward passes.")
else:
    print("No multi-step CoT wins found — try adjusting the problems or model.")

**What just happened:** You tested 10 problems with both direct and CoT prompting. The results should confirm the lesson's claim: CoT improves accuracy on multi-step problems (where the computation exceeds a single forward pass) but adds little on single-step problems (where one pass is enough).

The key is not that CoT makes the model "think harder": CoT generates intermediate tokens, and each token triggers another forward pass. More forward passes means more computation. The intermediate results (for example "17 × 20 = 340", if the model splits 17 × 24 into 17 × 20 + 17 × 4) are written into the context, where subsequent forward passes can attend to them.
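The decomposition mentioned above can be written out step by step, in the same "Step N:" format that the corruption prompts in Exercise 3 use. This is a sketch under an assumption: the model may well decompose the problem differently, and `partial_product_steps` is a hypothetical helper (it assumes `b > 0`). Each line is an intermediate result the model would have to generate into its context; the last line only adds two numbers that are already there.

```python
def partial_product_steps(a: int, b: int) -> list:
    """Write a × b as numbered steps, splitting b by place value (24 -> 20 + 4)."""
    digits = str(b)
    # Keep only nonzero place-value parts, largest first
    parts = [int(d) * 10 ** (len(digits) - i - 1)
             for i, d in enumerate(digits) if d != "0"]
    steps = [f"Break {b} into {' + '.join(str(p) for p in parts)}"]
    partials = []
    for p in parts:
        partials.append(a * p)
        steps.append(f"{a} × {p} = {a * p}")
    total = sum(partials)
    steps.append(f"{' + '.join(str(x) for x in partials)} = {total}")
    # Number the steps so they can be pasted into a prompt
    return [f"Step {i}: {s}" for i, s in enumerate(steps, start=1)]


for line in partial_product_steps(17, 24):
    print(line)
```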

---

## Exercise 2: Token Counting as Computation Measurement (Supported)

If each generated token is a forward pass, then counting tokens is counting computation. In this exercise, you'll measure how many tokens the model generates when solving problems with CoT, and compare that to problem complexity.

The API returns `completion_tokens` — the number of tokens the model generated. For CoT, this includes all the intermediate reasoning tokens plus the final answer. For direct prompting, it's just the answer tokens.

**Before running, predict:**
- Will more complex problems generate more tokens? (The lesson says yes — more complex problems require more intermediate results)
- How will the relationship look? Linear? Exponential?
- How many more tokens will CoT use compared to direct prompting?
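One way to check the "linear or exponential?" prediction after running the next cell is to fit both shapes and compare how well each explains the data. A minimal sketch using `numpy.polyfit`; `compare_trend_fits` is a hypothetical helper, and the sanity-check numbers are made up to be exactly linear, not measurements:

```python
from __future__ import annotations

import numpy as np


def compare_trend_fits(steps, tokens) -> dict[str, float]:
    """Compare a linear and an exponential fit of token count vs. steps.

    Linear:      tokens ≈ a * steps + b
    Exponential: tokens ≈ exp(c * steps + d), fitted as a line on log(tokens)
    Both are scored with R² on the original token scale so they are comparable.
    """
    x = np.asarray(steps, dtype=float)
    y = np.asarray(tokens, dtype=float)  # must be > 0 for the log fit

    def r_squared(predicted: np.ndarray) -> float:
        ss_res = np.sum((y - predicted) ** 2)
        ss_tot = np.sum((y - y.mean()) ** 2)
        return float(1.0 - ss_res / ss_tot)

    slope, intercept = np.polyfit(x, y, 1)
    c, d = np.polyfit(x, np.log(y), 1)
    return {
        "linear_slope": float(slope),
        "linear_r2": r_squared(slope * x + intercept),
        "exponential_r2": r_squared(np.exp(c * x + d)),
    }


# Sanity check on made-up, exactly linear numbers (NOT measured data)
fake_steps = [1, 1, 2, 3, 3, 4]
fake_tokens = [40 * s + 20 for s in fake_steps]
print(compare_trend_fits(fake_steps, fake_tokens))
```

Once `token_data` exists, call `compare_trend_fits([d["steps"] for d in token_data], [d["cot_tokens"] for d in token_data])`. With only ten problems and four distinct step counts, treat any difference in R² as suggestive rather than conclusive.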

In [None]:
# --- Step 1: Measure tokens for both strategies on all problems ---

token_data = []

for desc, problem, answer, steps in PROBLEMS:
    # Direct: answer only
    direct_prompt = f"{problem} Answer with ONLY the number, nothing else."
    direct_text, direct_tokens = call_llm_with_usage(direct_prompt)

    # CoT: step by step
    cot_prompt = f"{problem} Let's work through this step by step, then give the final answer."
    cot_text, cot_tokens = call_llm_with_usage(cot_prompt, max_tokens=500)

    token_data.append({
        "desc": desc,
        "steps": steps,
        "direct_tokens": direct_tokens,
        "cot_tokens": cot_tokens,
        "ratio": cot_tokens / max(direct_tokens, 1),
    })

    print(f"{desc:25s} | steps={steps} | direct_tokens={direct_tokens:3d} | cot_tokens={cot_tokens:3d} | ratio={cot_tokens/max(direct_tokens,1):.1f}x")

print("\nToken counting complete.")

In [None]:
# --- Step 2: Plot tokens vs problem complexity ---
# TODO: Create two plots that visualize the relationship between
# problem complexity and token count.
#
# Left plot (scatter): CoT tokens vs problem steps
#   - X axis: number of reasoning steps (token_data[i]["steps"])
#   - Y axis: CoT tokens generated (token_data[i]["cot_tokens"])
#   - Color single-step (#f59e0b) vs multi-step (#8b5cf6) differently
#
# Right plot (horizontal bar): token ratio per problem
#   - Y axis: problem descriptions (token_data[i]["desc"])
#   - X axis: ratio of CoT/direct tokens (token_data[i]["ratio"])
#   - Same color scheme as above
#
# Useful variables from the previous cell:
#   token_data[i]["steps"]         — number of reasoning steps
#   token_data[i]["cot_tokens"]    — tokens generated with CoT
#   token_data[i]["direct_tokens"] — tokens generated without CoT
#   token_data[i]["ratio"]         — cot_tokens / direct_tokens
#   token_data[i]["desc"]          — problem description string
#
# Hint: fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# YOUR CODE HERE (20-30 lines)



<details>
<summary>Solution for Step 2</summary>

```python
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left plot: CoT tokens vs problem steps
steps_arr = [d["steps"] for d in token_data]
cot_tokens_arr = [d["cot_tokens"] for d in token_data]
direct_tokens_arr = [d["direct_tokens"] for d in token_data]
colors = ['#f59e0b' if s == 1 else '#8b5cf6' for s in steps_arr]

ax1.scatter(steps_arr, cot_tokens_arr, c=colors, s=80, edgecolors='white',
            linewidth=0.5, zorder=3)
ax1.set_xlabel('Problem Complexity (reasoning steps)', fontsize=11)
ax1.set_ylabel('CoT Tokens Generated', fontsize=11)
ax1.set_title('More Complex Problems → More Tokens → More Computation',
              fontsize=12, fontweight='bold')
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)

# Right plot: token ratio (CoT / direct)
ratios = [d["ratio"] for d in token_data]
descs = [d["desc"][:15] for d in token_data]
bar_colors = ['#f59e0b' if d["steps"] == 1 else '#8b5cf6' for d in token_data]

bars = ax2.barh(range(len(ratios)), ratios, color=bar_colors,
                edgecolor='white', linewidth=0.5)
ax2.set_yticks(range(len(descs)))
ax2.set_yticklabels(descs, fontsize=9)
ax2.set_xlabel('Token Ratio (CoT / Direct)', fontsize=11)
ax2.set_title('CoT Computation Multiplier per Problem', fontsize=12, fontweight='bold')
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.axvline(x=1, color='white', linewidth=0.5, alpha=0.3)

for bar, ratio in zip(bars, ratios):
    ax2.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height() / 2,
             f'{ratio:.1f}×', va='center', fontsize=9, color='white')

# Legend
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor='#f59e0b', label='Single-step'),
                   Patch(facecolor='#8b5cf6', label='Multi-step')]
ax1.legend(handles=legend_elements)

plt.suptitle('Tokens as Computation: Each Token = One Forward Pass',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

avg_single_ratio = np.mean([d["ratio"] for d in token_data if d["steps"] == 1])
avg_multi_ratio = np.mean([d["ratio"] for d in token_data if d["steps"] > 1])
print(f"\nAverage token ratio (CoT/Direct):")
print(f"  Single-step: {avg_single_ratio:.1f}×")
print(f"  Multi-step:  {avg_multi_ratio:.1f}×")
print(f"\nThe model 'allocates' computation proportional to problem difficulty —")
print(f"not because it decides to, but because more complex problems require")
print(f"more intermediate results, and each result is generated token by token.")
```

</details>

**What just happened:** You measured the computation cost of CoT in tokens. More complex problems generated more intermediate tokens, and each token is an additional forward pass through all N transformer blocks. The model doesn't "decide" to think harder; it generates more intermediate results because the problem requires them, and those results happen to provide the additional computation needed for a correct answer.

The token ratio approximates a forward-pass multiplier: a multi-step problem might use 10-20× more forward passes with CoT than without, each running through the same N transformer blocks. Each pass also gets slightly more expensive as the context grows, because the new token attends to every earlier one. Same model, same weights, dramatically more computation.
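Token counts map onto forward passes, but the cost of each pass is not constant. A rough cost model, assuming greedy decoding with a KV cache and ignoring the prefill work among prompt tokens (`decode_cost` is a hypothetical helper and the sizes below are illustrative, not measured):

```python
from __future__ import annotations


def decode_cost(prompt_tokens: int, generated_tokens: int) -> dict[str, int]:
    """Rough cost model for greedy decoding with a KV cache (a simplification).

    forward_passes: one per generated token. The first comes from the prefill
        pass over the prompt; each later one is a single-token decode step.
    attended_positions: total key positions attended to by the positions that
        produce each generated token, a proxy for attention work.
    """
    forward_passes = generated_tokens
    # The position producing generated token i sees the prompt plus i earlier tokens
    attended_positions = sum(prompt_tokens + i for i in range(generated_tokens))
    return {"forward_passes": forward_passes, "attended_positions": attended_positions}


# Illustrative sizes only (not measurements): a short direct answer vs a longer CoT
direct = decode_cost(prompt_tokens=25, generated_tokens=2)
cot = decode_cost(prompt_tokens=30, generated_tokens=120)
print(direct, cot)
print(cot["forward_passes"] / direct["forward_passes"],
      cot["attended_positions"] / direct["attended_positions"])
```

With these made-up sizes, CoT uses 60× the forward passes but roughly 210× the attention work, because every new token attends to all earlier ones. The feed-forward work per pass is roughly constant, so total compute grows by a factor somewhere between those two ratios.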

---

## Exercise 3: Error Propagation Experiment (Supported)

The lesson emphasized that the model does not "catch" errors in intermediate reasoning steps. If step 2 is wrong, subsequent steps build on the error. In this exercise, you'll test that claim by manually corrupting intermediate steps and observing how errors propagate.

The setup: take a multi-step problem, get the model's CoT solution, then replace one intermediate result with a wrong number. Feed the corrupted chain back to the model, ask it to continue, and observe whether it builds on the error or corrects it.
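The corrupted prompts below are written by hand, but the construction is mechanical: keep the steps up to the one you corrupt, swap in the wrong value, and stop so the model has to continue from the error. A sketch of that as a hypothetical helper, `build_corrupted_prompt`; it is a simplified variant of the hand-written prompts (it omits the "Break 24 into 20 + 4" line) and assumes each step is an `lhs = value` equation:

```python
def build_corrupted_prompt(problem: str, steps: list, corrupt_index: int,
                           wrong_value: int) -> str:
    """Return a partial-CoT prompt in which steps[corrupt_index] shows wrong_value.

    steps is a list of (left-hand side, correct value) pairs. Only the steps up to
    and including the corrupted one are kept, so the model must continue from it.
    """
    lines = [f"{problem} Let's work through this step by step.", ""]
    for i, (lhs, value) in enumerate(steps[: corrupt_index + 1]):
        shown = wrong_value if i == corrupt_index else value
        lines.append(f"Step {i + 1}: {lhs} = {shown}")
    lines += ["", "Continue from here and give the final answer."]
    return "\n".join(lines)


steps = [("17 × 20", 340), ("17 × 4", 68)]
print(build_corrupted_prompt("What is 17 × 24?", steps, corrupt_index=1, wrong_value=72))
```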

**Before running, predict:**
- Will the model notice the error and correct it? Or continue from it?
- Will an error early in the chain have a larger impact than an error late in the chain?
- What does this tell you about whether the model is "reasoning" vs "generating"?

In [None]:
# --- Step 1: Get a correct CoT solution ---

problem = "What is 17 × 24?"
correct_answer = 408

cot_prompt = f"{problem} Let's work through this step by step."
correct_cot = call_llm(cot_prompt, max_tokens=300)

print("CORRECT CoT SOLUTION")
print("=" * 50)
print_wrapped(correct_cot)
print(f"\nExtracted answer: {extract_number(correct_cot)}")
print(f"Correct answer: {correct_answer}")

In [None]:
# --- Step 2: Corrupt an early intermediate step ---
# We'll give the model a partial CoT with a WRONG intermediate result
# and ask it to continue.

# Early corruption: change 17 × 20 = 340 to 17 × 20 = 350 (off by 10)
early_corruption = f"""What is 17 × 24? Let's work through this step by step.

Step 1: Break 24 into 20 + 4
Step 2: 17 × 20 = 350

Continue from here and give the final answer."""

early_result = call_llm(early_corruption, max_tokens=300)

print("EARLY CORRUPTION (17 × 20 = 350 instead of 340)")
print("=" * 50)
print_wrapped(early_result)
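early_answer = extract_number(early_result)
early_propagated = 350 + 68  # 418: the answer if the wrong 350 is carried forward
print(f"\nExtracted answer: {early_answer}")
print(f"Correct answer: {correct_answer}")
print(f"Answer if the error propagates: {early_propagated}")
if early_answer == correct_answer:
    print("\nDid the model catch the error? Yes: it recovered the correct answer.")
elif early_answer == early_propagated:
    print("\nDid the model catch the error? No: it continued from the wrong intermediate result.")
else:
    print("\nDid the model catch the error? No, but the answer matches neither 408 nor 418; read the output above.")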

In [None]:
# --- Step 3: Corrupt a late intermediate step ---
# Late corruption: correct first step, wrong second step

late_corruption = f"""What is 17 × 24? Let's work through this step by step.

Step 1: Break 24 into 20 + 4
Step 2: 17 × 20 = 340
Step 3: 17 × 4 = 72

Continue from here and give the final answer."""

late_result = call_llm(late_corruption, max_tokens=300)

print("LATE CORRUPTION (17 × 4 = 72 instead of 68)")
print("=" * 50)
print_wrapped(late_result)
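late_answer = extract_number(late_result)
late_propagated = 340 + 72  # 412: the answer if the wrong 72 is carried forward
print(f"\nExtracted answer: {late_answer}")
print(f"Correct answer: {correct_answer}")
print(f"Answer if the error propagates: {late_propagated}")
if late_answer == correct_answer:
    print("\nDid the model catch the error? Yes: it recovered the correct answer.")
elif late_answer == late_propagated:
    print("\nDid the model catch the error? No: it continued from the wrong intermediate result.")
else:
    print("\nDid the model catch the error? No, but the answer matches neither 408 nor 412; read the output above.")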

In [None]:
# --- Step 4: Test with more problems ---
# TODO: Try corruption on a different problem. Choose one of the multi-step
# problems from PROBLEMS, get the correct CoT, then corrupt one step.
#
# Suggested problem: "A farmer has 4 fields, each with 6 rows of corn.
#   Each row has 15 plants. How many plants total?"
#   Correct: 4 × 6 = 24 rows, 24 × 15 = 360 plants
#   Corruption: change 4 × 6 = 24 to 4 × 6 = 28
#
# YOUR CODE HERE (5-15 lines)
# 1. Write the corrupted prompt (with the wrong intermediate result)
# 2. Call the LLM
# 3. Print the result and check if the model caught the error



In [None]:
# --- Step 5: Summary of corruption results ---

print("ERROR PROPAGATION SUMMARY")
print("=" * 50)
print(f"\nProblem: 17 × 24 = {correct_answer}")
print(f"  No corruption:     {extract_number(correct_cot)}")
print(f"  Early corruption:  {extract_number(early_result)} (17 × 20 = 350 instead of 340; propagated answer: 418)")
print(f"  Late corruption:   {extract_number(late_result)} (17 × 4 = 72 instead of 68; propagated answer: 412)")
print()
print("The model does not 'catch' errors. It continues from whatever context")
print("exists — because it is not 'reasoning,' it is generating tokens that")
print("feed back as context for subsequent forward passes.")
print()
print("This is why CoT quality matters more than quantity. An error in an")
print("early step corrupts all subsequent steps. The model builds on the")
print("context, and if the context is wrong, the building is on a bad foundation.")

**What just happened:** You corrupted intermediate reasoning steps and observed that the model continues from the error without catching it. This demonstrates two key points from the lesson:

1. **The model is not "reasoning"** — it is generating tokens that feed back as context. If the context contains an error, subsequent forward passes build on the error. There is no "checking" step.

2. **CoT quality matters more than quantity** — a wrong intermediate step produces a wrong final answer. Longer chains are not automatically better; they need to contain *correct* intermediate results. This connects to the ICL finding that more examples don't always help, and the RAG finding that irrelevant context dilutes attention.
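The second point can be made quantitative under an idealized assumption: if the model faithfully carries a wrong intermediate value forward, the size of the final error depends on what the later steps do with it. A hypothetical `run_chain` helper replays the farmer and chained-operations problems from `PROBLEMS` as a sequence of operations (with 8 × 9 pre-computed as 72 so the chain stays linear):

```python
from __future__ import annotations

import operator

OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def run_chain(start: int, ops: list[tuple[str, int]],
              overrides: dict[int, int] | None = None) -> list[int]:
    """Apply ops in order and return every intermediate result.

    overrides maps a step index to a wrong value that replaces the true
    intermediate result, like a corrupted CoT step. Later steps use the
    overridden value, i.e. this models an error that is never caught.
    """
    overrides = overrides or {}
    value = start
    trace = []
    for i, (op, operand) in enumerate(ops):
        value = OPS[op](value, operand)
        value = overrides.get(i, value)  # swap in the corrupted result, if any
        trace.append(value)
    return trace


farmer = [("*", 6), ("*", 15)]                 # 4 fields × 6 rows × 15 plants
print(run_chain(4, farmer))                    # [24, 360]
print(run_chain(4, farmer, {0: 28}))           # [28, 420]: +4 rows becomes +60 plants

chained = [("*", 17), ("+", 72), ("-", 45)]    # 13 × 17 + (8 × 9) - 45
print(run_chain(13, chained))                  # [221, 293, 248]
print(run_chain(13, chained, {0: 231}))        # [231, 303, 258]: +10 stays +10
```

A wrong row count gets multiplied by 15 in the next step, while an error in 13 × 17 passes through the later additions unchanged. So "early errors matter more" holds when later steps amplify them, which is one way to reason about the prediction at the start of this exercise.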

<details>
<summary>Solution for Step 4</summary>

```python
farmer_corruption = """A farmer has 4 fields, each with 6 rows of corn. Each row has 15 plants. How many plants total?
Let's work through this step by step.

Step 1: Find total rows: 4 fields × 6 rows = 28 rows

Continue from here and give the final answer."""

farmer_result = call_llm(farmer_corruption, max_tokens=300)
print("FARMER PROBLEM CORRUPTION (4 × 6 = 28 instead of 24)")
print("=" * 50)
print_wrapped(farmer_result)
print(f"\nExtracted answer: {extract_number(farmer_result)}")
print(f"Correct answer: 360")
print(f"Expected wrong answer: 28 × 15 = 420")
caught = extract_number(farmer_result) == 360
print(f"Did the model catch the error? {'Yes' if caught else 'No'}")
```

If the error propagates, the model produces 420 (28 × 15) instead of 360 (24 × 15): it continues from the corrupted intermediate result "28 rows" without noticing that 4 × 6 = 24, not 28.

</details>

---

## Exercise 4: Find the CoT Boundary (Independent)

You've seen that CoT helps on multi-step problems but not on single-step ones. Where exactly is the boundary? In this exercise, you design your own experiment to find it.

**Your task:**
1. Choose a domain (arithmetic, logic puzzles, word problems, or something else)
2. Design a set of problems with increasing complexity (at least 6 problems, spanning from trivially easy to genuinely hard)
3. Test each with and without CoT
4. Plot accuracy vs complexity for both strategies
5. Identify the approximate threshold where CoT starts helping

**No skeleton is provided.** Design the experiment yourself. Think about:
- How will you measure "complexity"? (Number of operations? Digit count? Reasoning steps?)
- How will you ensure the problems are comparable?
- How many trials per problem for reliable accuracy estimates?
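For multiplication, one candidate complexity metric is the digit-level work in pencil-and-paper long multiplication. This is an assumption about a useful proxy, not a claim about how the model computes internally; `schoolbook_ops` is a hypothetical helper:

```python
from __future__ import annotations


def schoolbook_ops(a: int, b: int) -> dict[str, int]:
    """Count the digit-level work in pencil-and-paper long multiplication of a × b."""
    da, db = str(abs(a)), str(abs(b))
    return {
        "total_digits": len(da) + len(db),
        # every digit of a is multiplied by every digit of b
        "single_digit_products": len(da) * len(db),
        # one shifted partial-product row per nonzero digit of b, summed at the end
        "partial_product_rows": sum(1 for d in db if d != "0"),
    }


for a, b in [(8, 7), (17, 4), (17, 24), (123, 45), (123, 456)]:
    print(f"{a} × {b}: {schoolbook_ops(a, b)}")
```

Total digits of 2, 3, 4, 5, 6 correspond to 1, 2, 4, 6, 9 single-digit products, so the two metrics do not grow at the same rate. Either is a defensible x-axis, but the choice changes how sharp a boundary looks.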

The solution is in the `<details>` block at the end.

In [None]:
# Your experiment here.
#
# 1. Define your problems with increasing complexity
# 2. Test each with direct and CoT prompting
# 3. Measure accuracy (run multiple trials if using temperature > 0)
# 4. Plot the results
# 5. Identify the boundary



In [None]:
# Reflection:
#
# 1. Where was the boundary? At what complexity did CoT start helping?
# 2. Was the boundary sharp or gradual?
# 3. Is the boundary task-dependent? Would it be different for logic puzzles vs arithmetic?
# 4. Is the boundary model-dependent? Would a larger model have the boundary at a different point?
#
# Print your observations:
print("Reflection:")
print("  1. Boundary location: ...")
print("  2. Sharp or gradual: ...")
print("  3. Task-dependent: ...")
print("  4. Model-dependent: ...")

<details>
<summary>Solution</summary>

**Design rationale:** Use arithmetic with increasing digit count as the complexity axis. This gives a clean, measurable complexity metric (number of digits) and unambiguous correctness (the answer is a number).

```python
import random

# Generate multiplication problems with increasing digit counts
# Complexity = total digits across both operands
rng = random.Random(42)

boundary_problems = []
for complexity in [2, 3, 4, 5, 6]:
    # Generate 3 problems at each complexity level
    for trial in range(3):
        if complexity == 2:  # 1×1 digit
            a, b = rng.randint(2, 9), rng.randint(2, 9)
        elif complexity == 3:  # 2×1 digit
            a, b = rng.randint(10, 99), rng.randint(2, 9)
        elif complexity == 4:  # 2×2 digit
            a, b = rng.randint(10, 99), rng.randint(10, 99)
        elif complexity == 5:  # 3×2 digit
            a, b = rng.randint(100, 999), rng.randint(10, 99)
        else:  # 3×3 digit
            a, b = rng.randint(100, 999), rng.randint(100, 999)
        boundary_problems.append((complexity, a, b, a * b))

# Test each problem
boundary_results = []
for complexity, a, b, answer in boundary_problems:
    problem = f"What is {a} × {b}?"
    
    direct = call_llm(f"{problem} Answer with ONLY the number.")
    direct_correct = extract_number(direct) == answer
    
    cot = call_llm(f"{problem} Let's work through this step by step.", max_tokens=500)
    cot_correct = extract_number(cot) == answer
    
    boundary_results.append({
        "complexity": complexity,
        "problem": f"{a}×{b}",
        "answer": answer,
        "direct_correct": direct_correct,
        "cot_correct": cot_correct,
    })
    
    d_sym = "✓" if direct_correct else "✗"
    c_sym = "✓" if cot_correct else "✗"
    print(f"complexity={complexity} | {a}×{b}={answer} | direct {d_sym} | cot {c_sym}")

# Aggregate by complexity level
complexities = sorted(set(r["complexity"] for r in boundary_results))
direct_accs = []
cot_accs = []
for c in complexities:
    level_results = [r for r in boundary_results if r["complexity"] == c]
    direct_accs.append(sum(r["direct_correct"] for r in level_results) / len(level_results))
    cot_accs.append(sum(r["cot_correct"] for r in level_results) / len(level_results))

# Plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(complexities, [a * 100 for a in direct_accs], 'o-', color='#f59e0b',
        label='Direct', linewidth=2, markersize=8)
ax.plot(complexities, [a * 100 for a in cot_accs], 's-', color='#8b5cf6',
        label='CoT', linewidth=2, markersize=8)
ax.set_xlabel('Complexity (total digits)', fontsize=11)
ax.set_ylabel('Accuracy (%)', fontsize=11)
ax.set_title('Finding the CoT Boundary', fontsize=13, fontweight='bold')
ax.set_ylim(-5, 110)
ax.legend(fontsize=11)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Annotate the boundary region
for i, c in enumerate(complexities):
    if direct_accs[i] < cot_accs[i]:
        ax.axvspan(c - 0.3, c + 0.3, alpha=0.1, color='#8b5cf6')

plt.tight_layout()
plt.show()

# Find the boundary
boundary = None
for i, c in enumerate(complexities):
    if direct_accs[i] < cot_accs[i]:
        boundary = c
        break

if boundary is None:
    print("\nCoT never beat direct prompting at these levels; try larger operands or more trials.")
else:
    print(f"\nApproximate boundary: complexity = {boundary}")
    print("Below this, direct prompting is sufficient (single-pass capacity).")
    print("Above this, CoT provides needed additional computation.")
print("\nThe boundary is task-dependent (arithmetic vs logic) and model-dependent")
print("(larger models can handle more complexity in a single pass).")
```

**Expected findings:** The boundary typically appears around 2×2 digit multiplication (complexity 4). Single-digit multiplication fits in one forward pass. Two-digit multiplication starts to exceed it. The boundary is somewhat model-dependent — larger models may handle 2×2 in a single pass but fail at 3×2.

**Key insight:** The boundary corresponds to the point where the problem exceeds single-forward-pass computational capacity. Below the boundary, the model can map input to answer in N transformer blocks. Above it, the intermediate results provide the additional computation needed.
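The solution uses only three problems per complexity level, so each accuracy estimate is coarse. A Wilson score interval (a standard confidence interval for a binomial proportion) shows how coarse. This sketch treats the three problems as independent trials of equal difficulty, which is an assumption, and `wilson_interval` is a hypothetical helper:

```python
from __future__ import annotations

import math


def wilson_interval(correct: int, trials: int, z: float = 1.96) -> tuple[float, float]:
    """Approximate 95% Wilson score interval for an accuracy of correct/trials."""
    if trials == 0:
        return (0.0, 1.0)
    p = correct / trials
    denom = 1 + z**2 / trials
    center = (p + z**2 / (2 * trials)) / denom
    half = z * math.sqrt(p * (1 - p) / trials + z**2 / (4 * trials**2)) / denom
    # Clamp to [0, 1] to absorb floating-point rounding at the edges
    return (max(0.0, center - half), min(1.0, center + half))


for k in range(4):
    lo, hi = wilson_interval(k, 3)
    print(f"{k}/3 correct -> plausible accuracy {lo:.0%} to {hi:.0%}")
```

Even 3/3 correct is consistent with a true accuracy well below 50%, so a one-problem difference between direct and CoT at a single level is weak evidence for where the boundary lies; more problems per level narrow the interval.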

</details>

---

## Key Takeaways

1. **Chain-of-thought works because each intermediate token triggers an additional forward pass.** The model does not "think harder." It runs more forward passes, each building on the context of previous ones. You measured this in Exercise 2: multi-step problems generated many more tokens (and therefore forward passes) than direct answers, often around 10-20× more.

2. **CoT helps when the problem exceeds single-forward-pass capacity.** Single-step arithmetic fits in one pass. Multi-step problems do not. The boundary is task-dependent and model-dependent, but the criterion is always computational complexity.

3. **The model does not catch errors in intermediate steps.** Corrupted intermediate results propagate to the final answer. The model continues from whatever context exists. This is strong evidence that it is generating tokens, not "reasoning": it has no built-in mechanism to verify its own intermediate results.

4. **CoT quality matters more than quantity.** Each intermediate token must provide a useful result. More tokens add more forward passes, but only useful tokens add useful computation. Noise in the chain dilutes attention, just like irrelevant documents in RAG.

5. **The mechanism is the autoregressive loop you already understand.** Same `generate()` method. Same forward pass. Same N transformer blocks. The only difference is how many tokens are generated. CoT is not a new mechanism — it is a new way to see the existing one.