Reasoning with LLMs: Chain-of-Thought (CoT) and Verifiers on Math and Code Problems



This project implements a compact reasoning pipeline for large language models (LLMs) that combines chain-of-thought prompting, self-consistency sampling, and lightweight verifiers. A small hand-crafted dataset of math and programming tasks is used to evaluate how well an LLM reasons step by step, where it fails, and how verifiers can improve reliability. The goal is to demonstrate, and better understand, transparent and trustworthy reasoning methods that align with current research on efficient LLM reasoning.

Small Dataset (Math and Programming Samples)

In [1]:
# Math dataset: ten short arithmetic / word problems. Answers are stored as
# strings; each row becomes {"id", "question", "answer"} with ids from 1.
_MATH_ROWS = [
    ("If Lina has 5 apples and buys 7 more, how many does she have?", "12"),
    ("A train travels 40 km in 2 hours. What is its speed in km/h?", "20"),
    ("What is 15 × 6?", "90"),
    ("If x + 3 = 10, what is x?", "7"),
    ("What is 144 ÷ 12?", "12"),
    ("If a rectangle has sides 4 and 9, what is its area?", "36"),
    ("Add 123 and 456.", "579"),
    ("What is 2 to the power 8?", "256"),
    ("If a = 3 and b = 4, what is a^2 + b^2?", "25"),
    ("What is (7 × 8) - 10?", "46"),
]
math_dataset = [
    {"id": idx, "question": question, "answer": answer}
    for idx, (question, answer) in enumerate(_MATH_ROWS, start=1)
]

# Programming dataset
programming_dataset = [
    {
        "id": 1,
        "task": "Write a Python function is_prime(n) that returns True if n is prime.",
        "function_name": "is_prime",
        "tests": [
            ("is_prime(2)", True),
            ("is_prime(15)", False),
            ("is_prime(17)", True),
        ]
    },
    {
        "id": 2,
        "task": "Write a Python function factorial(n) returning factorial of n.",
        "function_name": "factorial",
        "tests": [
            ("factorial(0)", 1),
            ("factorial(5)", 120),
        ]
    },
    {
        "id": 3,
        "task": "Write a Python function reverse_string(s) returning the reversed string.",
        "function_name": "reverse_string",
        "tests": [
            ("reverse_string('abc')", "cba"),
        ]
    },
    {
        "id": 4,
        "task": "Write a Python function sum_list(xs) returning the sum of numbers in a list.",
        "function_name": "sum_list",
        "tests": [
            ("sum_list([1,2,3])", 6),
        ]
    },
    {
        "id": 5,
        "task": "Write a Python function fibonacci(n) returning nth Fibonacci number (0-indexed).",
        "function_name": "fibonacci",
        "tests": [
            ("fibonacci(0)", 0),
            ("fibonacci(5)", 5),
        ]
    },
]



In [2]:
# Both task suites keyed by domain; the evaluation loops read dataset["math"]
# and dataset["programming"].
dataset = dict(math=math_dataset, programming=programming_dataset)


Load the Model Using Hugging Face

In [3]:
import os
# Turn off Hugging Face tokenizers' internal parallelism (presumably to avoid the
# library's fork-after-parallelism warning; set before transformers is imported).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [4]:
from transformers import pipeline, set_seed

# FLAN-T5 XL wrapped in a text-to-text generation pipeline; every later cell
# calls this `generator` object.
generator = pipeline(task="text2text-generation", model="google/flan-t5-xl")
# Seed the RNGs after loading (order kept) so the sampled decodes are repeatable.
set_seed(42)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use mps:0


In [5]:
# Trial run: one hand-written step-by-step prompt to sanity-check the pipeline.
prompt = (
    "Question: What is 12 * 8?\n"
    "Let's think step by step and calculate the final numeric answer only."
)
output = generator(prompt, max_new_tokens=50)
first_completion = output[0]
print(first_completion["generated_text"])


12 * 8 = 96. The final answer: 96.


In [6]:
# Chain-of-thought (CoT) style math prompt
def cot_prompt_math(question):
    """Wrap a math question in a step-by-step (chain-of-thought) instruction."""
    instruction = "Solve step by step and give the final numeric answer."
    return "\n".join((instruction, f"Question: {question}"))

# Programming prompt
def code_prompt(task):
    """Wrap a coding task in an instruction asking for function code only."""
    instruction = "Write a Python function. Only output the function code."
    return "\n".join((f"Task: {task}", instruction))


In [7]:
import re

# A number as a model writes it: optional sign, digits, optional decimals.
# A bare trailing "." (sentence period) is deliberately NOT captured.
_NUMBER_PATTERN = r"[-+]?\d+(?:\.\d+)?"
# "Final Answer: 12", "The final answer is 12", "final answer = 12", ...
_FINAL_ANSWER_RE = re.compile(
    r"final\s*answer\s*(?:is\s*)?[:=]?\s*(" + _NUMBER_PATTERN + r")", re.IGNORECASE
)
# One binary operation, e.g. "56 - 10", "2.5 * 4" or "15 × 6".
_BINARY_EXPR_RE = re.compile(
    r"(" + _NUMBER_PATTERN + r")\s*([-+*/×÷])\s*(" + _NUMBER_PATTERN + r")"
)
_NUMBER_RE = re.compile(_NUMBER_PATTERN)
# Explicit operator table instead of eval(): only these operations can run.
_BINARY_OPS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    "×": lambda a, b: a * b,
    "/": lambda a, b: a / b,
    "÷": lambda a, b: a / b,
}


def _parse_number(token):
    """Parse a token matched by _NUMBER_PATTERN as int, or float if it has a dot."""
    return float(token) if "." in token else int(token)


def _canonical_number(value):
    """Render a number canonically: integral floats become ints ("12.0" -> "12")."""
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return str(value)


def extract_and_eval(text):
    """Extract a numeric answer from a model's step-by-step output.

    Strategies, tried in order:
      1. An explicit "final answer" phrase.
      2. The last simple binary arithmetic expression in the text, evaluated.
      3. The last number appearing anywhere in the text.

    Returns:
        The answer as a canonical numeric string ("96", "-3", "3.5"), or None
        if the text contains no number. Canonical form matters because
        ``majority_vote`` counts strings: "12", "12." and "12.0" must agree.
    """
    m = _FINAL_ANSWER_RE.search(text)
    if m:
        return _canonical_number(_parse_number(m.group(1)))

    exprs = _BINARY_EXPR_RE.findall(text)
    if exprs:
        left, op, right = exprs[-1]
        try:
            return _canonical_number(_BINARY_OPS[op](_parse_number(left), _parse_number(right)))
        except (ZeroDivisionError, OverflowError):
            pass  # e.g. "5 / 0": fall back to the last number below

    nums = _NUMBER_RE.findall(text)
    return _canonical_number(_parse_number(nums[-1])) if nums else None

def extract_function(text):
    """Return the Python function source embedded in a model completion.

    The code starts at the first ``def`` keyword that begins a word (so
    ``undef`` or ``my_def`` are not matched). It runs to the end of the text,
    or up to a Markdown code fence ("```") if the model wrapped its answer in
    one. Trailing whitespace is stripped so identical samples compare equal.

    Returns:
        The extracted source string, or None if no ``def`` is present.
    """
    start = re.search(r"\bdef\s", text)
    if start is None:
        return None
    code = text[start.start():]
    # A closing fence (and any prose after it) is not Python and would make
    # exec() raise SyntaxError, so cut it off.
    fence = code.find("```")
    if fence != -1:
        code = code[:fence]
    return code.rstrip()


In [8]:
import math


def verify_math(pred, truth, rel_tol=1e-9, abs_tol=1e-9):
    """Return True if a predicted answer numerically matches the ground truth.

    Args:
        pred: predicted answer (numeric string, number, or None).
        truth: reference answer (numeric string or number).
        rel_tol, abs_tol: tolerances passed to ``math.isclose``. They absorb
            float noise such as "0.30000000000000004" vs "0.3" from evaluated
            divisions, while integers still have to match exactly.

    Returns:
        False when either value is missing or not parseable as a number.
    """
    try:
        return math.isclose(float(pred), float(truth), rel_tol=rel_tol, abs_tol=abs_tol)
    except (TypeError, ValueError):
        # TypeError: pred is None; ValueError: non-numeric string.
        return False
    

import multiprocessing, traceback

def _worker_run(code, test_cases, q):
    try:
        local_env = {}
        exec(code, {}, local_env)
        funcs = [v for v in local_env.values() if callable(v)]
        if not funcs:
            q.put((False, "No function found"))
            return
        func = funcs[0]
        for expr, expected in test_cases:
            if eval(expr, {}, local_env) != expected:
                q.put((False, f"Failed case {expr} != {expected}"))
                return
        q.put((True, "ok"))
    except:
        q.put((False, traceback.format_exc()))

def run_unit_tests(func_code, test_cases, function_name=None):
    """Execute generated function code and check it against unit tests.

    Args:
        func_code: Python source expected to define at least one function.
        test_cases: iterable of ``(expression, expected)`` pairs. Each
            expression is evaluated in the namespace produced by ``func_code``
            and compared to ``expected`` with ``!=``.
        function_name: optional name the code must define. When given, a
            missing definition is reported explicitly.

    Returns:
        ``(True, "ok")`` if every case matches, otherwise ``(False, reason)``.
        The reason is one of: "No function found", a missing-function message,
        "Failed case <expr> != <expected>", or "<ExceptionType>: <message>".

    WARNING: this ``exec``s model-generated code in-process, with no sandbox and
    no timeout. A candidate containing an infinite loop will hang the caller.
    """
    # A single dict serves as both globals and locals. With separate dicts,
    # the defs land only in the locals dict, while the functions look names up
    # in the (empty) globals dict. Recursive calls (factorial, fibonacci) and
    # calls to helper functions then fail with NameError.
    namespace = {}
    try:
        exec(func_code, namespace)
        if function_name is not None:
            if not callable(namespace.get(function_name)):
                return False, f"Function {function_name} not defined"
        elif not any(callable(v) for v in namespace.values()):
            return False, "No function found"
        for expr, expected in test_cases:
            if eval(expr, namespace) != expected:
                return False, f"Failed case {expr} != {expected}"
        return True, "ok"
    except Exception as e:
        # Include the exception type: str(e) alone is often just a name or "".
        return False, f"{type(e).__name__}: {e}"



In [9]:
from collections import Counter

def generate_answers(prompt, n=5, max_new_tokens=120, temp=0.3):
    """Draw ``n`` randomly sampled completions for ``prompt`` from ``generator``.

    Sampling is enabled (``do_sample=True``, single beam) at temperature
    ``temp``. The generated strings are returned in the order the pipeline
    yields them.
    """
    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temp,
        num_return_sequences=n,
        num_beams=1,
    )
    outputs = generator(prompt, **sampling_kwargs)
    texts = []
    for out in outputs:
        texts.append(out["generated_text"])
    return texts

def majority_vote(ans_list):
    """Return the most frequent non-None answer, or None if there is none.

    Ties go to the answer seen first, following ``Counter.most_common`` ordering.
    """
    tally = Counter(a for a in ans_list if a is not None)
    if not tally:
        return None
    winner, _votes = tally.most_common(1)[0]
    return winner


In [10]:
import csv
from tqdm import tqdm

def _score_math_item(item):
    """Run greedy and self-consistency decoding on one math item and score both."""
    prompt = cot_prompt_math(item["question"])
    truth = item["answer"]

    # Greedy: a single deterministic decode.
    greedy_text = generator(prompt, max_new_tokens=100, do_sample=False)[0]["generated_text"]
    greedy = extract_and_eval(greedy_text)

    # Self-consistency: 5 sampled decodes, answers combined by majority vote.
    samples = generate_answers(prompt, n=5, max_new_tokens=100, temp=0.3)
    sample_answers = list(map(extract_and_eval, samples))
    voted = majority_vote(sample_answers)

    return dict(
        id=item["id"],
        question=item["question"],
        truth=truth,
        greedy_text=greedy_text,
        greedy=greedy,
        sampled_texts=samples,
        sampled_answers=sample_answers,
        majority=voted,
        greedy_correct=verify_math(greedy, truth),
        majority_correct=verify_math(voted, truth),
    )


math_results = [_score_math_item(item) for item in tqdm(dataset["math"])]

# Save CSV; the columns come from the first row, and csv converts list-valued
# fields with str().
with open("math_results.csv", "w", newline="", encoding="utf-8") as out_file:
    writer = csv.DictWriter(out_file, fieldnames=list(math_results[0]))
    writer.writeheader()
    writer.writerows(math_results)


prog_results = []

for item in tqdm(dataset["programming"]):
    prompt = code_prompt(item["task"])

    # Greedy: one deterministic completion, scored by the unit tests.
    greedy_out = generator(prompt, max_new_tokens=200, do_sample=False)[0]["generated_text"]
    greedy_func = extract_function(greedy_out)
    if greedy_func:
        greedy_ok, greedy_msg = run_unit_tests(greedy_func, item["tests"])
    else:
        greedy_ok, greedy_msg = False, "No function found"

    # Self-consistency: sample completions and extract each one's function
    # exactly once. Samples with no `def` yield None and are dropped.
    sampled_outs = generate_answers(prompt, n=5, max_new_tokens=200, temp=0.7)
    sampled_funcs = [code for code in map(extract_function, sampled_outs) if code]
    sampled_ok = [run_unit_tests(code, item["tests"])[0] for code in sampled_funcs]
    passing_funcs = [code for code, ok in zip(sampled_funcs, sampled_ok) if ok]

    # Verifier-filtered pick: the most common sample among those that pass.
    # The tests that select this sample are the same tests that score it, so
    # majority_ok just means "at least one sample passes" (pass@k with a test
    # oracle). It is not an unbiased self-consistency score.
    majority_func = Counter(passing_funcs).most_common(1)[0][0] if passing_funcs else None
    majority_ok = bool(passing_funcs)

    # Test-free self-consistency: a plain majority vote over all extracted
    # samples. The tests are consulted only after the vote has chosen.
    vote_func = majority_vote(sampled_funcs)
    vote_ok = sampled_ok[sampled_funcs.index(vote_func)] if vote_func is not None else False

    prog_results.append({
        "id": item["id"],
        "task": item["task"],
        "greedy_function": greedy_func,
        "greedy_ok": greedy_ok,
        "greedy_msg": greedy_msg,
        "sampled_texts": sampled_outs,
        "majority_function": majority_func,
        "majority_ok": majority_ok,
        # Extra columns, appended after the original ones.
        "num_sampled_funcs": len(sampled_funcs),
        "num_passing": len(passing_funcs),
        "vote_function": vote_func,
        "vote_ok": vote_ok,
    })

# Save CSV. Skipped for an empty dataset, because the header comes from row 0.
if prog_results:
    with open("programming_results.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(prog_results[0]))
        writer.writeheader()
        writer.writerows(prog_results)


100%|██████████| 10/10 [05:08<00:00, 30.86s/it]
100%|██████████| 5/5 [1:04:40<00:00, 776.19s/it] 
