In [1]:
from __future__ import annotations

import json
import os
import re
import time
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from tqdm import tqdm

In [2]:
class LLMClient:
    """Minimal client for an OpenAI-compatible ``/chat/completions`` endpoint.

    Every invocation of :meth:`call` (including each attempt made by
    :meth:`call_with_retry`) increments ``call_count`` so solvers can report how
    many API requests they used.

    All request methods return a result dict with the keys:
        ok     -- True when the server answered 200 with a JSON body.
        text   -- assistant message content ("" if the server sent none), None on failure.
        status -- HTTP status code, or -1 for transport-level failures.
        error  -- error description on failure, otherwise None.
        raw    -- decoded JSON body on success, otherwise None.
    """

    def __init__(
        self,
        api_key: str = "cse476",
        api_base: str = "http://10.4.58.53:41701/v1",
        model: str = "bens_model"
    ):
        """Store connection settings and start the call counter at 0.

        NOTE(review): the default key/host are hardcoded; prefer passing them in
        (e.g. read from ``os.environ``) instead of committing them to the notebook.
        """
        self.api_key = api_key
        self.api_base = api_base
        self.model = model
        self.call_count = 0

    @staticmethod
    def _failure(status: int, error: str) -> Dict:
        """Build the uniform failure result dict."""
        return {
            "ok": False,
            "text": None,
            "status": status,
            "error": error,
            "raw": None
        }

    def call(
        self,
        prompt: str,
        system: str = "You are a helpful assistant.",
        temperature: float = 0.0,
        max_tokens: int = 128,
        timeout: int = 60
    ) -> Dict:
        """Send one chat-completion request (no retries).

        Args:
            prompt: user message content.
            system: system message content.
            temperature: sampling temperature forwarded to the server.
            max_tokens: completion length limit forwarded to the server.
            timeout: ``requests`` timeout in seconds.

        Returns:
            The result dict described in the class docstring. HTTP errors,
            ``requests`` transport errors and non-JSON 200 bodies are reported
            with ``ok=False`` instead of being raised.
        """
        self.call_count += 1

        url = f"{self.api_base}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": prompt}
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        try:
            resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
        except requests.RequestException as e:
            return self._failure(-1, str(e))

        status = resp.status_code
        if status != 200:
            try:
                err_text = resp.json()
            except Exception:
                # Error bodies are not always JSON (e.g. an HTML proxy page).
                err_text = resp.text
            return self._failure(status, str(err_text))

        try:
            data = resp.json()
        except ValueError as e:
            return self._failure(status, f"Invalid JSON in response body: {e}")

        # Be defensive about the response shape: an empty ``choices`` list or a
        # null ``content`` must not crash downstream ``.strip()`` calls, so both
        # collapse to "".
        choices = data.get("choices") if isinstance(data, dict) else None
        first = choices[0] if isinstance(choices, list) and choices else {}
        message = first.get("message") if isinstance(first, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        text = "" if content is None else str(content)

        return {
            "ok": True,
            "text": text,
            "status": status,
            "error": None,
            "raw": data
        }

    def reset_counter(self):
        """Reset the API call counter to 0."""
        self.call_count = 0

    def get_call_count(self) -> int:
        """Return how many times :meth:`call` has run since creation or the last reset."""
        return self.call_count

    @staticmethod
    def _is_retryable(status: int) -> bool:
        """True unless ``status`` is a client error that resending cannot fix."""
        # 4xx means the request itself is wrong (bad payload, auth, unknown model);
        # 408 (request timeout) and 429 (rate limited) are the transient exceptions.
        return not (400 <= status < 500 and status not in (408, 429))

    def call_with_retry(
        self,
        prompt: str,
        system: str = "You are a helpful assistant.",
        temperature: float = 0.0,
        max_tokens: int = 128,
        timeout: int = 60,
        max_retries: int = 3
    ) -> Dict:
        """Like :meth:`call`, but retries transient failures with exponential backoff.

        Up to ``max_retries`` attempts are made in total (values below 1 are
        treated as 1, so a result dict is always returned), sleeping 1s, 2s,
        4s, ... between attempts. Non-transient 4xx responses (anything except
        408/429) are returned immediately.
        """
        attempts = max(1, max_retries)
        result = None

        for attempt in range(attempts):
            result = self.call(prompt, system, temperature, max_tokens, timeout)
            if result["ok"] or not self._is_retryable(result["status"]):
                return result
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)

        return result

    @staticmethod
    def call_model_chat_completions(
        prompt: str,
        system: str = "You are a helpful assistant.",
        model: str = "bens_model",
        temperature: float = 0.0,
        timeout: int = 60
    ) -> Dict:
        """One-shot helper: build a throwaway client for ``model`` and make a single call.

        Declared ``@staticmethod`` because it takes no instance; this keeps it
        callable both as ``LLMClient.call_model_chat_completions(...)`` and on an
        existing client object.
        """
        client = LLMClient(model=model)
        return client.call(prompt, system, temperature, timeout=timeout)

In [3]:
if __name__ == "__main__":
    print("Testing LLM Client:")

    client = LLMClient()

    # (title, prompt, system prompt); None keeps LLMClient.call's default system prompt.
    smoke_tests = [
        ("Simple Math",
         "What is 5 + 3? Answer with just the number.",
         None),
        ("Common Sense",
         "You place an ice cube in a glass of water and mark the water level. After the ice melts, does the water level rise, fall, or stay the same? ",
         "Answer with exactly one of: 'rise', 'fall', 'stay the same'."),
        ("logic_race",
         "In a race, you pass the person in second place. What position are you now in? ",
         "Answer with a single word like 'first', 'second', 'third'."),
    ]

    for test_num, (title, prompt, system) in enumerate(smoke_tests, start=1):
        print(f"\nTest {test_num}: {title}")
        extra = {} if system is None else {"system": system}
        result = client.call(prompt, **extra)

        if result["ok"]:
            print(f"Success! Answer: {result['text']}")
            print(f"   API calls made: {client.get_call_count()}")
        else:
            print(f"Failed: {result['error']}")

    print(f"\nTotal API calls: {client.get_call_count()}")

Testing LLM Client:

Test 1: Simple Math
Success! Answer: 8
   API calls made: 1

Test 2: Common Sense
Success! Answer: stay the same
   API calls made: 2
Total API calls: 2

Test 3: logic_race
Success! Answer: second
   API calls made: 3
Total API calls: 3


In [4]:
class DataLoader:
    """Load the project's JSON problem file and expose simple per-domain views.

    The file must contain a JSON array of problem objects. Records without a
    ``domain`` key are grouped under ``'unknown'``, the same default ``Problem``
    uses, instead of raising KeyError.
    """

    def __init__(self, data_path: str):
        """Read ``data_path`` eagerly and precompute per-domain counts.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            ValueError: if the top-level JSON value is not a list.
        """
        self.data_path = data_path
        self.data = self._load_data()
        self.domain_stats = self._compute_stats()

    @staticmethod
    def _domain_of(item: Dict) -> str:
        """Domain label of one record, defaulting to 'unknown'."""
        return item.get('domain', 'unknown')

    def _load_data(self) -> List[Dict]:
        """Parse the JSON file and check that it holds a list of records."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(
                f"{self.data_path}: expected a JSON array of problems, "
                f"got {type(data).__name__}"
            )
        return data

    def _compute_stats(self) -> Dict:
        """Count records per domain."""
        return dict(Counter(self._domain_of(item) for item in self.data))

    def get_all(self) -> List[Dict]:
        """Return every loaded record (the internal list, not a copy)."""
        return self.data

    def get_by_domain(self, domain: str) -> List[Dict]:
        """Return the records whose domain equals ``domain`` ('unknown' matches records without one)."""
        return [item for item in self.data if self._domain_of(item) == domain]

    def get_item(self, index: int) -> Dict:
        """Return the record at ``index`` (standard list indexing; IndexError if out of range)."""
        return self.data[index]

    def get_stats(self) -> Dict:
        """Return ``{'total': <record count>, 'domains': {domain: count}}``.

        The domain mapping is a copy, so callers cannot alter the cached stats.
        """
        return {
            'total': len(self.data),
            'domains': dict(self.domain_stats)
        }

In [5]:
class Problem:
    """One benchmark item: question text, optional reference answer, and domain."""

    def __init__(self, data: Dict):
        """Wrap a raw record.

        ``data['input']`` is required (KeyError otherwise); ``output`` defaults to
        None and ``domain`` to ``'unknown'``.
        """
        self.input = data['input']
        self.expected_output = data.get('output', None)
        self.domain = data.get('domain', 'unknown')

    def __repr__(self):
        return f"Problem(domain='{self.domain}', input_len={len(self.input)})"

    def get_input(self) -> str:
        """Return the question text."""
        return self.input

    def get_expected_output(self) -> str:
        """Return the reference answer, or "" when the record has none.

        Uses the same ``is None`` test as ``has_expected_output`` so a present but
        falsy answer (e.g. ``0``) is returned instead of being replaced by "".
        """
        return "" if self.expected_output is None else self.expected_output

    def get_domain(self) -> str:
        """Return the domain label ('unknown' if the record had none)."""
        return self.domain

    def has_expected_output(self) -> bool:
        """True when the record carried an ``output`` value (even a falsy one)."""
        return self.expected_output is not None

In [6]:
# Load the development split and report its size.
loader = DataLoader('data/cse476_final_project_dev_data.json')
print("Loaded", len(loader.get_all()), "problems")

Loaded 1000 problems


In [7]:
class BaselineSolver:
    """Single-call solver: ask for the bare final answer and lightly clean it."""

    def __init__(self, client: LLMClient):
        self.client = client
        self.name = "Baseline (Direct)"

    def solve(self, problem: Problem) -> Dict:
        """Answer ``problem`` with one temperature-0 call asking for only the final value.

        Returns a dict with ``answer``, ``success``, ``api_calls`` (always 1) and
        ``method``; on API failure ``answer`` is None and ``error`` holds the
        client's error string.
        """
        question = problem.get_input()

        system_message = (
            "You are a helpful assistant. "
            "Give ONLY the final answer - no explanation, no reasoning, no LaTeX. "
            "Just the answer value."
        )

        result = self.client.call(
            prompt=question + "\n\nAnswer with ONLY the final value, nothing else.",
            system=system_message,
            temperature=0.0,
            max_tokens=128
        )

        if not result["ok"]:
            return {
                "answer": None,
                "success": False,
                "api_calls": 1,
                "method": self.name,
                "error": result["error"]
            }

        # The server may send a null/empty message body; treat it as "".
        answer = self._clean_answer((result["text"] or "").strip())
        return {
            "answer": answer,
            "success": True,
            "api_calls": 1,
            "method": self.name
        }

    def _clean_answer(self, answer: str) -> str:
        """Strip LaTeX / markdown decoration from a short direct answer.

        If the model still phrased a sentence ("The answer is 12."), the text after
        an "is"/"=" running to the end (or else the first **bold** span) is kept.
        """
        if not answer:
            return ""

        # Rewrite \frac{a}{b} (and \dfrac / \tfrac) as a/b first; the brace and
        # command stripping below would otherwise turn "\frac{1}{2}" into "12".
        answer = re.sub(r'\\[dt]?frac\{([^{}]*)\}\{([^{}]*)\}', r'\1/\2', answer)

        if answer.lower().startswith(('we are', 'the ', 'let ', 'given')):
            match = re.search(r'(?:is|=)\s*\{?\{?([^{}]+?)\}?\}?\s*\.?\s*$', answer)
            if match:
                answer = match.group(1).strip()
            else:
                match = re.search(r'\*\*([^*]+)\*\*', answer)
                if match:
                    answer = match.group(1).strip()

        answer = re.sub(r'\{\\{', '', answer)
        answer = re.sub(r'\\}\}', '', answer)
        answer = re.sub(r'\{\{', '', answer)
        answer = re.sub(r'\}\}', '', answer)
        answer = re.sub(r'(?<!\\)\{', '', answer)
        answer = re.sub(r'(?<!\\)\}', '', answer)
        answer = re.sub(r'\\[a-zA-Z]+', '', answer)
        answer = re.sub(r'\$+', '', answer)
        answer = re.sub(r'\^', '', answer)
        answer = re.sub(r'_', '', answer)
        answer = re.sub(r'\*+', '', answer)
        answer = re.sub(r'^answer:\s*', '', answer, flags=re.IGNORECASE)

        return answer.strip()

In [8]:
class ChainOfThoughtSolver:
    """Single-call solver: prompt for step-by-step reasoning, then parse out the final answer.

    The answer is pulled from the free-form response by ``_extract_answer`` and
    canonicalised by ``_normalize_answer_format``.
    """

    # Signed number with optional thousands separators ("1,234") and decimals.
    # A leading digit is required so a lone "," or "-" is never read as a number,
    # and "3,4,5" is not glued into "345".
    _NUM = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'

    def __init__(self, client: LLMClient):
        self.client = client
        self.name = "Chain-of-Thought"

    def solve(self, problem: Problem) -> Dict:
        """Solve ``problem`` with one temperature-0 chain-of-thought call.

        Returns a dict with ``answer``, ``reasoning`` (full model text),
        ``success``, ``api_calls`` (always 1) and ``method``; on API failure
        ``answer``/``reasoning`` are None and ``error`` holds the client's error.
        """
        question = problem.get_input()

        cot_prompt = f"{question} Think step by step to solve this problem. Show your reasoning, then provide the final answer."

        system_message = (
            "You are a helpful assistant that solves problems step-by-step. "
            "Show your reasoning process clearly, then provide the final answer."
        )

        result = self.client.call(
            prompt=cot_prompt,
            system=system_message,
            temperature=0.0,
            max_tokens=512
        )

        if not result["ok"]:
            return {
                "answer": None,
                "reasoning": None,
                "success": False,
                "api_calls": 1,
                "method": self.name,
                "error": result["error"]
            }

        # The server may send a null/empty message body; treat it as "".
        full_response = (result["text"] or "").strip()
        final_answer = self._normalize_answer_format(self._extract_answer(full_response))

        return {
            "answer": final_answer,
            "reasoning": full_response,
            "success": True,
            "api_calls": 1,
            "method": self.name
        }

    def _normalize_answer_format(self, answer: str) -> str:
        """Canonicalise short free-text answers.

        Lower-cases, maps verb forms ("stays"/"remains the same" -> "stay the
        same", "rises" -> "rise", "falls" -> "fall"), drops a leading "water level",
        trailing "place"/"position" and leading "the"/"a".
        Code answers (multi-line, or starting with ``def``/``class``) are returned
        unchanged, because lower-casing would corrupt identifiers and literals.
        """
        if not answer:
            return answer

        if '\n' in answer or re.match(r'\s*(?:def|class)\s', answer):
            return answer

        answer_lower = answer.lower().strip()

        if 'water level' in answer_lower:
            match = re.search(r'water\s+level\s+(.+)', answer_lower)
            if match:
                answer_lower = match.group(1).strip()

        answer_lower = re.sub(r'\bstays\b', 'stay', answer_lower)
        answer_lower = re.sub(r'\brises\b', 'rise', answer_lower)
        answer_lower = re.sub(r'\bfalls\b', 'fall', answer_lower)
        answer_lower = re.sub(r'\bremains\b', 'remain', answer_lower)
        answer_lower = re.sub(r'\bremain(?:s)?\s+the\s+same\b', 'stay the same', answer_lower)
        answer_lower = re.sub(r'\s+place$', '', answer_lower)
        answer_lower = re.sub(r'\s+position$', '', answer_lower)
        answer_lower = re.sub(r'^the\s+', '', answer_lower)
        answer_lower = re.sub(r'^a\s+', '', answer_lower)

        return answer_lower.strip()

    def _extract_answer(self, reasoning: str) -> str:
        """Pull a final answer out of free-form reasoning text.

        Strategies are tried in priority order (code block, last \\boxed{},
        multiple-choice letter, "Final answer:", "= X", last line, ordinal word,
        "The answer is X", last number, last short line); the first hit wins.
        Returns "Unable to extract" when nothing applies.
        """
        # PRIORITY 0: coding responses -- return the code itself.
        if any(keyword in reasoning for keyword in ['def ', 'function ', 'return ', 'class ']):
            code_block = self._extract_code_block(reasoning)
            if code_block and len(code_block) > 10:
                return code_block

        # Final answers live at the end of the response.
        last_part = reasoning[-500:]
        num = self._NUM

        # PRIORITY 1: the LAST \boxed{...}; nested braces such as
        # \boxed{\frac{1}{2}} are kept whole.
        boxed = self._extract_last_boxed(reasoning)
        if boxed is not None:
            return self._clean_answer(boxed)

        # PRIORITY 2: multiple choice (A-D). Accept a capital letter (optionally
        # parenthesised) or any letter inside parentheses. A bare lowercase letter
        # is rejected: "the answer is a total of 45" must not become option A.
        option = r'(\([A-Da-d]\)|\(?[A-D]\)?)'
        mc_patterns = [
            r'[Aa]nswer\s*(?:is|:)\s*' + option + r'(?:[\.\s]|$)',
            r'[Cc]orrect\s*(?:answer|option|choice)\s*(?:is|:)\s*' + option + r'(?!\w)',
            r'\b([A-D])\s*(?:is correct|is the answer)',
            r'[Tt]he\s+answer\s+is\s+' + option + r'(?:[\.\s]|$)',
        ]
        for pattern in mc_patterns:
            mc_match = re.search(pattern, last_part)
            if mc_match:
                return mc_match.group(1).strip('()').upper()

        # PRIORITY 3: explicit "Final answer: X" / "Answer: X". The capture stops
        # at a sentence-ending period or newline but NOT at a decimal point, so
        # "Final answer: 3.5" yields "3.5" rather than "3".
        until_stop = r'((?:[^.\n]|\.(?=\d))+)'
        final_patterns = [
            r'[Ff]inal\s+[Aa]nswer\s*:?\s*\$?' + until_stop,
            r'[Tt]he\s+final\s+answer\s+is\s*:?\s*\$?' + until_stop,
            r'[Aa]nswer\s*:\s*\$?' + until_stop,
        ]
        for pattern in final_patterns:
            match = re.search(pattern, last_part)
            if match:
                answer = self._clean_answer(match.group(1).strip())
                if answer and len(answer) <= 100:
                    if not answer.lower().startswith(('the ', 'therefore', 'so ', 'thus')):
                        return answer

        # PRIORITY 4: "= X" (math). Take the LAST match of each pattern: in a
        # worked solution the final equation, not the first, holds the answer.
        equals_patterns = [
            r'=\s*\$?(' + num + r')\$?\s*$',
            r'=\s*\$?(' + num + r')\$?\s*(?:dollars|pounds|ounces|meters|cm|kg|hours|minutes)',
            r'=\s*\$?\$?(' + num + r')',
        ]
        for pattern in equals_patterns:
            matches = re.findall(pattern, last_part, re.MULTILINE)
            if matches:
                return matches[-1].replace(',', '')

        # PRIORITY 5: the last non-empty line, if it is a bare number or a short phrase.
        lines = [line.strip() for line in last_part.split('\n') if line.strip()]
        if lines:
            last_line = lines[-1]
            number_only = re.match(r'^\$?(' + num + r')\$?$', last_line)
            if number_only:
                return number_only.group(1).replace(',', '')

            clean_last = self._clean_answer(last_line)
            if clean_last and len(clean_last) <= 30:
                if not clean_last.lower().startswith(('we ', 'the ', 'therefore', 'so ', 'thus', 'hence')):
                    return clean_last

        # PRIORITY 6: ordinal position words; the last one mentioned wins
        # ("... not first; you are now in second" -> "second").
        positions = re.findall(
            r'\b(first|second|third|fourth|fifth|sixth)\b', last_part[-150:], re.IGNORECASE
        )
        if positions:
            return positions[-1].lower()

        # PRIORITY 7: "The answer is X" (same decimal-aware stop as priority 3).
        answer_is = re.search(
            r'[Tt]he\s+answer\s+is\s*:?\s*\$?((?:[^.\n]|\.(?=\d))+?)(?:\.(?!\d)|$)',
            last_part
        )
        if answer_is:
            answer = self._clean_answer(answer_is.group(1).strip())
            if answer and len(answer) <= 50:
                if not answer.lower().startswith(('the ', 'therefore', 'so ')):
                    return answer

        # PRIORITY 8: the last number anywhere in the final 100 characters.
        numbers = re.findall(num, last_part[-100:])
        if numbers:
            return numbers[-1].replace(',', '')

        # Last resort: the last short line that does not read like a reasoning step.
        for line in reversed(lines):
            clean = self._clean_answer(line)
            if clean and 1 <= len(clean) <= 50:
                if not clean.lower().startswith(('we ', 'let ', 'the ', 'therefore', 'given')):
                    return clean

        return "Unable to extract"

    @staticmethod
    def _extract_last_boxed(text: str) -> Optional[str]:
        """Return the stripped contents of the last complete, non-empty ``\\boxed{...}``.

        Braces are matched by depth, so nested groups stay intact. If the last
        occurrence is unbalanced (e.g. a truncated response) or empty, earlier
        occurrences are tried. Returns None when none qualifies.
        """
        marker = '\\boxed{'
        search_end = len(text)
        while True:
            start = text.rfind(marker, 0, search_end)
            if start == -1:
                return None
            body_start = start + len(marker)
            depth = 0
            for pos in range(body_start, len(text)):
                ch = text[pos]
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    if depth == 0:
                        content = text[body_start:pos].strip()
                        if content:
                            return content
                        break
                    depth -= 1
            # Unbalanced or empty: fall back to the previous \boxed occurrence.
            search_end = start

    def _extract_code_block(self, reasoning: str) -> Optional[str]:
        """Best-effort extraction of Python code from a response; None if nothing is found.

        Tried in order: the first ``` fence containing ``def``/``class``; the
        longest fence if it contains ``def``; the last regex-matched ``def`` with
        its indented lines; a line scan collecting a ``def`` line and the
        more-indented lines after it.
        """
        code_fence = re.search(r'```(?:python)?\s*\n(.*?)```', reasoning, re.DOTALL)
        if code_fence:
            code = code_fence.group(1).strip()
            if 'def ' in code or 'class ' in code:
                return code

        all_fences = re.findall(r'```(?:python)?\s*\n(.*?)```', reasoning, re.DOTALL)
        if all_fences:
            longest = max(all_fences, key=len)
            if 'def ' in longest:
                return longest.strip()

        func_pattern = r'(def\s+\w+\s*\([^)]*\)\s*:\s*(?:\n(?:[ \t]+.+)?)+)'
        functions = re.findall(func_pattern, reasoning)
        if functions:
            return functions[-1].strip()

        # Fallback line scan: a "def" line restarts collection; stop at the first
        # non-blank line indented no deeper than that "def".
        lines = reasoning.split('\n')
        code_lines = []
        in_function = False
        base_indent = 0

        for line in lines:
            stripped = line.strip()

            if stripped.startswith('def '):
                in_function = True
                base_indent = len(line) - len(line.lstrip())
                code_lines = [line]
            elif in_function:
                current_indent = len(line) - len(line.lstrip()) if line.strip() else base_indent + 4

                if current_indent > base_indent or stripped == '':
                    code_lines.append(line)
                elif stripped and current_indent <= base_indent:
                    break

        if code_lines:
            code = '\n'.join(code_lines).rstrip()
            if 'def ' in code and ':' in code:
                return code

        return None

    def _clean_answer(self, answer: str) -> str:
        """Strip LaTeX/markdown decoration and filler lead-ins from a candidate answer.

        When the cleaned text is longer than 10 characters and contains a number,
        the first number is returned (or a fraction / scientific literal starting
        at or before it). Otherwise a fraction or scientific literal found anywhere
        wins; otherwise the cleaned text is returned.
        """
        if not answer:
            return ""

        # \frac{a}{b} -> a/b before braces/commands are stripped; otherwise
        # "\frac{1}{2}" would collapse to "12".
        answer = re.sub(r'\\[dt]?frac\{([^{}]*)\}\{([^{}]*)\}', r'\1/\2', answer)
        answer = re.sub(r'\\boxed\{([^}]+)\}', r'\1', answer)
        answer = re.sub(r'\$+', '', answer)
        answer = re.sub(r'\\[a-zA-Z]+\{?', '', answer)
        answer = re.sub(r'\\', '', answer)
        answer = re.sub(r'\*+', '', answer)
        answer = re.sub(r'[{}()\[\]]', '', answer)
        answer = re.sub(r'\n+', ' ', answer)
        answer = re.sub(r'\s+', ' ', answer)
        answer = re.sub(r'^[.!?;:,\s]+', '', answer)
        answer = re.sub(r'[.!?;:,\s]+$', '', answer)
        # Drop conversational lead-ins such as "you are now in the ..." or "is ...".
        answer = re.sub(r'^(?:you\s+)?(?:are\s+)?(?:now\s+)?(?:in\s+)?(?:the\s+)?', '', answer, flags=re.IGNORECASE)
        answer = re.sub(r'^(?:the|is|be|it|that)\s+', '', answer, flags=re.IGNORECASE)

        fraction_match = re.search(r'((?:(?<!\w)-)?\d+/\d+)', answer)
        sci_match = re.search(r'(-?\d+\.?\d*[eE][+-]?\d+)', answer)
        # Whole-token number; the sign is kept ("-5 degrees" -> "-5").
        number_match = re.search(r'(?<!\w)(' + self._NUM + r')(?!\w)', answer)

        if number_match and len(answer) > 10:
            # Prefer a fraction / scientific literal that begins at (or before) the
            # first plain number, so "3/4 of the pie" stays "3/4" rather than "3".
            for special in (fraction_match, sci_match):
                if special and special.start() <= number_match.start():
                    return special.group(1)
            return number_match.group(1).replace(',', '')

        if fraction_match:
            return fraction_match.group(1)

        if sci_match:
            return sci_match.group(1)

        return answer.strip()

In [9]:
class SelfConsistencySolver:
    """Self-consistency: sample several chain-of-thought answers and return the majority vote."""

    # Signed number with optional thousands separators ("1,234") and decimals.
    # A leading digit is required so a lone "," or "-" is never read as a number,
    # and "3,4,5" is not glued into "345".
    _NUM = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'

    def __init__(self, client: LLMClient, num_samples: int = 5, temperature: float = 0.1):
        """
        Args:
            client: object exposing ``call(prompt=..., system=..., temperature=..., max_tokens=...)``.
            num_samples: number of CoT samples drawn per problem (= API calls made).
            temperature: sampling temperature for each sample; the vote is only
                informative when samples are allowed to differ.
        """
        self.client = client
        self.num_samples = num_samples
        self.temperature = temperature
        self.name = "Self-Consistency"

    def solve(self, problem: Problem) -> Dict:
        """Sample ``num_samples`` CoT answers and majority-vote them.

        Votes are counted on a loose normalised key (``_normalize_answer``), but
        the returned ``answer`` is the first sampled answer carrying the winning
        key, so code answers are not mangled. The result dict has ``answer``,
        ``all_answers`` (successful samples), ``vote_counts`` (key -> votes),
        ``confidence`` (winner share of successful samples), ``all_reasoning``,
        ``success``, ``api_calls`` and ``method``; when every call fails,
        ``answer`` is None, ``confidence`` 0.0 and ``error`` is set.
        """
        question = problem.get_input()

        cot_prompt = f"{question} Think step by step to solve this problem. Show your reasoning, then provide the final answer."

        system_message = (
            "You are a helpful assistant that solves problems step-by-step. "
            "Show your reasoning process clearly, then provide the final answer."
        )

        all_answers = []
        all_reasoning = []

        for _ in range(self.num_samples):
            result = self.client.call(
                prompt=cot_prompt,
                system=system_message,
                temperature=self.temperature,
                max_tokens=512
            )

            if result["ok"]:
                # The server may send a null/empty message body; treat it as "".
                full_response = (result["text"] or "").strip()
                answer = self._normalize_answer_format(self._extract_answer(full_response))
                all_answers.append(answer)
                all_reasoning.append(full_response)
            else:
                all_answers.append(None)
                all_reasoning.append(None)

        valid_answers = [a for a in all_answers if a is not None]

        if not valid_answers:
            return {
                "answer": None,
                "all_answers": all_answers,
                "vote_counts": {},
                "confidence": 0.0,
                "all_reasoning": all_reasoning,
                "success": False,
                "api_calls": self.num_samples,
                "method": self.name,
                "error": "All API calls failed"
            }

        keys = [self._normalize_answer(a) for a in valid_answers]
        vote_counts = Counter(keys)
        winner_key, winner_votes = vote_counts.most_common(1)[0]
        # Report a real sampled answer rather than the lossy vote key.
        majority_answer = valid_answers[keys.index(winner_key)]
        confidence = winner_votes / len(valid_answers)

        return {
            "answer": majority_answer,
            "all_answers": valid_answers,
            "vote_counts": dict(vote_counts),
            "confidence": confidence,
            "all_reasoning": [r for r in all_reasoning if r is not None],
            "success": True,
            "api_calls": self.num_samples,
            "method": self.name
        }

    def _extract_answer(self, reasoning: str) -> str:
        """Pull a final answer out of free-form reasoning text.

        Strategies are tried in priority order (code block, last \\boxed{},
        multiple-choice letter, "Final answer:", "= X", last line, ordinal word,
        "The answer is X", last number, last short line); the first hit wins.
        Returns "Unable to extract" when nothing applies.
        """
        # PRIORITY 0: coding responses -- return the code itself.
        if any(keyword in reasoning for keyword in ['def ', 'function ', 'return ', 'class ']):
            code_block = self._extract_code_block(reasoning)
            if code_block and len(code_block) > 10:
                return code_block

        # Final answers live at the end of the response.
        last_part = reasoning[-500:]
        num = self._NUM

        # PRIORITY 1: the LAST \boxed{...}; nested braces such as
        # \boxed{\frac{1}{2}} are kept whole.
        boxed = self._extract_last_boxed(reasoning)
        if boxed is not None:
            return self._clean_answer(boxed)

        # PRIORITY 2: multiple choice (A-D). Accept a capital letter (optionally
        # parenthesised) or any letter inside parentheses. A bare lowercase letter
        # is rejected: "the answer is a total of 45" must not become option A.
        option = r'(\([A-Da-d]\)|\(?[A-D]\)?)'
        mc_patterns = [
            r'[Aa]nswer\s*(?:is|:)\s*' + option + r'(?:[\.\s]|$)',
            r'[Cc]orrect\s*(?:answer|option|choice)\s*(?:is|:)\s*' + option + r'(?!\w)',
            r'\b([A-D])\s*(?:is correct|is the answer)',
            r'[Tt]he\s+answer\s+is\s+' + option + r'(?:[\.\s]|$)',
        ]
        for pattern in mc_patterns:
            mc_match = re.search(pattern, last_part)
            if mc_match:
                return mc_match.group(1).strip('()').upper()

        # PRIORITY 3: explicit "Final answer: X" / "Answer: X". The capture stops
        # at a sentence-ending period or newline but NOT at a decimal point, so
        # "Final answer: 3.5" yields "3.5" rather than "3".
        until_stop = r'((?:[^.\n]|\.(?=\d))+)'
        final_patterns = [
            r'[Ff]inal\s+[Aa]nswer\s*:?\s*\$?' + until_stop,
            r'[Tt]he\s+final\s+answer\s+is\s*:?\s*\$?' + until_stop,
            r'[Aa]nswer\s*:\s*\$?' + until_stop,
        ]
        for pattern in final_patterns:
            match = re.search(pattern, last_part)
            if match:
                answer = self._clean_answer(match.group(1).strip())
                if answer and len(answer) <= 100:
                    if not answer.lower().startswith(('the ', 'therefore', 'so ', 'thus')):
                        return answer

        # PRIORITY 4: "= X" (math). Take the LAST match of each pattern: in a
        # worked solution the final equation, not the first, holds the answer.
        equals_patterns = [
            r'=\s*\$?(' + num + r')\$?\s*$',
            r'=\s*\$?(' + num + r')\$?\s*(?:dollars|pounds|ounces|meters|cm|kg|hours|minutes)',
            r'=\s*\$?\$?(' + num + r')',
        ]
        for pattern in equals_patterns:
            matches = re.findall(pattern, last_part, re.MULTILINE)
            if matches:
                return matches[-1].replace(',', '')

        # PRIORITY 5: the last non-empty line, if it is a bare number or a short phrase.
        lines = [line.strip() for line in last_part.split('\n') if line.strip()]
        if lines:
            last_line = lines[-1]
            number_only = re.match(r'^\$?(' + num + r')\$?$', last_line)
            if number_only:
                return number_only.group(1).replace(',', '')

            clean_last = self._clean_answer(last_line)
            if clean_last and len(clean_last) <= 30:
                if not clean_last.lower().startswith(('we ', 'the ', 'therefore', 'so ', 'thus', 'hence')):
                    return clean_last

        # PRIORITY 6: ordinal position words; the last one mentioned wins
        # ("... not first; you are now in second" -> "second").
        positions = re.findall(
            r'\b(first|second|third|fourth|fifth|sixth)\b', last_part[-150:], re.IGNORECASE
        )
        if positions:
            return positions[-1].lower()

        # PRIORITY 7: "The answer is X" (same decimal-aware stop as priority 3).
        answer_is = re.search(
            r'[Tt]he\s+answer\s+is\s*:?\s*\$?((?:[^.\n]|\.(?=\d))+?)(?:\.(?!\d)|$)',
            last_part
        )
        if answer_is:
            answer = self._clean_answer(answer_is.group(1).strip())
            if answer and len(answer) <= 50:
                if not answer.lower().startswith(('the ', 'therefore', 'so ')):
                    return answer

        # PRIORITY 8: the last number anywhere in the final 100 characters.
        numbers = re.findall(num, last_part[-100:])
        if numbers:
            return numbers[-1].replace(',', '')

        # Last resort: the last short line that does not read like a reasoning step.
        for line in reversed(lines):
            clean = self._clean_answer(line)
            if clean and 1 <= len(clean) <= 50:
                if not clean.lower().startswith(('we ', 'let ', 'the ', 'therefore', 'given')):
                    return clean

        return "Unable to extract"

    @staticmethod
    def _extract_last_boxed(text: str) -> Optional[str]:
        """Return the stripped contents of the last complete, non-empty ``\\boxed{...}``.

        Braces are matched by depth, so nested groups stay intact. If the last
        occurrence is unbalanced (e.g. a truncated response) or empty, earlier
        occurrences are tried. Returns None when none qualifies.
        """
        marker = '\\boxed{'
        search_end = len(text)
        while True:
            start = text.rfind(marker, 0, search_end)
            if start == -1:
                return None
            body_start = start + len(marker)
            depth = 0
            for pos in range(body_start, len(text)):
                ch = text[pos]
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    if depth == 0:
                        content = text[body_start:pos].strip()
                        if content:
                            return content
                        break
                    depth -= 1
            # Unbalanced or empty: fall back to the previous \boxed occurrence.
            search_end = start

    def _extract_code_block(self, reasoning: str) -> Optional[str]:
        """Best-effort extraction of Python code from a response; None if nothing is found.

        Tried in order: the first ``` fence containing ``def``/``class``; the
        longest fence if it contains ``def``; the last regex-matched ``def`` with
        its indented lines; a line scan collecting a ``def`` line and the
        more-indented lines after it.
        """
        code_fence = re.search(r'```(?:python)?\s*\n(.*?)```', reasoning, re.DOTALL)
        if code_fence:
            code = code_fence.group(1).strip()
            if 'def ' in code or 'class ' in code:
                return code

        all_fences = re.findall(r'```(?:python)?\s*\n(.*?)```', reasoning, re.DOTALL)
        if all_fences:
            longest = max(all_fences, key=len)
            if 'def ' in longest:
                return longest.strip()

        func_pattern = r'(def\s+\w+\s*\([^)]*\)\s*:\s*(?:\n(?:[ \t]+.+)?)+)'
        functions = re.findall(func_pattern, reasoning)
        if functions:
            return functions[-1].strip()

        # Fallback line scan: a "def" line restarts collection; stop at the first
        # non-blank line indented no deeper than that "def".
        lines = reasoning.split('\n')
        code_lines = []
        in_function = False
        base_indent = 0

        for line in lines:
            stripped = line.strip()

            if stripped.startswith('def '):
                in_function = True
                base_indent = len(line) - len(line.lstrip())
                code_lines = [line]
            elif in_function:
                current_indent = len(line) - len(line.lstrip()) if line.strip() else base_indent + 4

                if current_indent > base_indent or stripped == '':
                    code_lines.append(line)
                elif stripped and current_indent <= base_indent:
                    break

        if code_lines:
            code = '\n'.join(code_lines).rstrip()
            if 'def ' in code and ':' in code:
                return code

        return None

    def _clean_answer(self, answer: str) -> str:
        """Strip LaTeX/markdown decoration and filler lead-ins from a candidate answer.

        When the cleaned text is longer than 10 characters and contains a number,
        the first number is returned (or a fraction / scientific literal starting
        at or before it). Otherwise a fraction or scientific literal found anywhere
        wins; otherwise the cleaned text is returned.
        """
        if not answer:
            return ""

        # \frac{a}{b} -> a/b before braces/commands are stripped; otherwise
        # "\frac{1}{2}" would collapse to "12".
        answer = re.sub(r'\\[dt]?frac\{([^{}]*)\}\{([^{}]*)\}', r'\1/\2', answer)
        answer = re.sub(r'\\boxed\{([^}]+)\}', r'\1', answer)
        answer = re.sub(r'\$+', '', answer)
        answer = re.sub(r'\\[a-zA-Z]+\{?', '', answer)
        answer = re.sub(r'\\', '', answer)
        answer = re.sub(r'\*+', '', answer)
        answer = re.sub(r'[{}()\[\]]', '', answer)
        answer = re.sub(r'\n+', ' ', answer)
        answer = re.sub(r'\s+', ' ', answer)
        answer = re.sub(r'^[.!?;:,\s]+', '', answer)
        answer = re.sub(r'[.!?;:,\s]+$', '', answer)
        # Drop conversational lead-ins such as "you are now in the ..." or "is ...".
        answer = re.sub(r'^(?:you\s+)?(?:are\s+)?(?:now\s+)?(?:in\s+)?(?:the\s+)?', '', answer, flags=re.IGNORECASE)
        answer = re.sub(r'^(?:the|is|be|it|that)\s+', '', answer, flags=re.IGNORECASE)

        fraction_match = re.search(r'((?:(?<!\w)-)?\d+/\d+)', answer)
        sci_match = re.search(r'(-?\d+\.?\d*[eE][+-]?\d+)', answer)
        # Whole-token number; the sign is kept ("-5 degrees" -> "-5").
        number_match = re.search(r'(?<!\w)(' + self._NUM + r')(?!\w)', answer)

        if number_match and len(answer) > 10:
            # Prefer a fraction / scientific literal that begins at (or before) the
            # first plain number, so "3/4 of the pie" stays "3/4" rather than "3".
            for special in (fraction_match, sci_match):
                if special and special.start() <= number_match.start():
                    return special.group(1)
            return number_match.group(1).replace(',', '')

        if fraction_match:
            return fraction_match.group(1)

        if sci_match:
            return sci_match.group(1)

        return answer.strip()

    def _normalize_answer_format(self, answer: str) -> str:
        """Canonicalise short free-text answers.

        Lower-cases, maps verb forms ("stays"/"remains the same" -> "stay the
        same", "rises" -> "rise", "falls" -> "fall"), drops a leading "water level",
        trailing "place"/"position" and leading "the"/"a".
        Code answers (multi-line, or starting with ``def``/``class``) are returned
        unchanged, because lower-casing would corrupt identifiers and literals.
        """
        if not answer:
            return answer

        if '\n' in answer or re.match(r'\s*(?:def|class)\s', answer):
            return answer

        answer_lower = answer.lower().strip()
        if 'water level' in answer_lower:
            match = re.search(r'water\s+level\s+(.+)', answer_lower)
            if match:
                answer_lower = match.group(1).strip()

        answer_lower = re.sub(r'\bstays\b', 'stay', answer_lower)
        answer_lower = re.sub(r'\brises\b', 'rise', answer_lower)
        answer_lower = re.sub(r'\bfalls\b', 'fall', answer_lower)
        answer_lower = re.sub(r'\bremains\b', 'remain', answer_lower)
        answer_lower = re.sub(r'\bremain(?:s)?\s+the\s+same\b', 'stay the same', answer_lower)
        answer_lower = re.sub(r'\s+place$', '', answer_lower)
        answer_lower = re.sub(r'\s+position$', '', answer_lower)
        answer_lower = re.sub(r'^the\s+', '', answer_lower)
        answer_lower = re.sub(r'^a\s+', '', answer_lower)

        return answer_lower.strip()

    def _normalize_answer(self, answer: str) -> str:
        """Loose voting key: lower-cased, markup/braces/backslashes and trailing .!? removed.

        Used only to group equivalent samples; it is too lossy to return to callers.
        """
        if not answer:
            return ""

        normalized = answer.lower().strip()
        normalized = re.sub(r'\*\*', '', normalized)
        normalized = re.sub(r'\$', '', normalized)
        normalized = re.sub(r'\\boxed\{([^}]+)\}', r'\1', normalized)
        normalized = re.sub(r'\\', '', normalized)
        normalized = re.sub(r'[{}]', '', normalized)
        normalized = re.sub(r'[.!?]+$', '', normalized)
        normalized = re.sub(r'\s+', ' ', normalized)

        return normalized.strip()

In [10]:
# Build one shared client, (re)load the dev split and put every solver on top of that client.
client = LLMClient()
print("LLM Client created")

loader = DataLoader('data/cse476_final_project_dev_data.json')
print("Loaded", len(loader.get_all()), "problems")

baseline, cot, sc = (
    BaselineSolver(client),
    ChainOfThoughtSolver(client),
    SelfConsistencySolver(client, num_samples=5),
)
print("All solvers ready!")

stats = loader.get_stats()
print("\nDataset:", stats['total'], "problems")
for domain, count in stats['domains'].items():
    print("  " + f"{domain}: {count}")

LLM Client created
Loaded 1000 problems
All solvers ready!

Dataset: 1000 problems
  math: 300
  coding: 100
  future_prediction: 100
  planning: 100
  common_sense: 400


In [11]:
test_problems = [
    {
        "input": "What is 17 + 28?",
        "output": "45",
        "domain": "math"
    },
    {
        "input": "In a race, you pass the person in second place. What position are you now in?",
        "output": "second",
        "domain": "logic"
    },
    {
        "input": "You place an ice cube in a glass of water. After the ice melts, does the water level rise, fall, or stay the same?",
        "output": "stay the same",
        "domain": "common_sense"
    }
]


def report_solver(solver, problems, title: str, show_votes: bool = False) -> None:
    """Run ``solver`` on each problem dict in ``problems`` and print expected vs. got.

    Each row is flagged MATCH/MISMATCH by exact string comparison. With
    ``show_votes`` the self-consistency confidence and per-sample answers are
    printed too (read with ``.get`` so a failed run lacking those keys does not
    crash the report). Ends with the solver client's running API call count.
    """
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)

    for prob_data in problems:
        result = solver.solve(Problem(prob_data))
        verdict = "MATCH" if result['answer'] == prob_data['output'] else "MISMATCH"

        print(f"\n {prob_data['domain'].upper()}")
        print(f"   Expected: '{prob_data['output']}'")
        print(f"   Got: '{result['answer']}'  [{verdict}]")
        if show_votes:
            print(f"   Confidence: {result.get('confidence', 0.0):.0%}")
            print(f"   All answers: {result.get('all_answers', [])}")

    print(f"\nTotal API calls: {solver.client.get_call_count()}")


client.reset_counter()
report_solver(baseline, test_problems, "TESTING BASELINE SOLVER")
report_solver(cot, test_problems, "TESTING CHAIN-OF-THOUGHT SOLVER")
report_solver(sc, test_problems, "TESTING SELF-CONSISTENCY SOLVER", show_votes=True)

TESTING BASELINE SOLVER

 MATH
   Expected: '45'
   Got: '45'

 LOGIC
   Expected: 'second'
   Got: 'second place'

 COMMON_SENSE
   Expected: 'stay the same'
   Got: 'stay the same'

Total API calls: 3
Testing CoT on all problem types:

 MATH
   Expected: '45'
   Got: '45'

 LOGIC
   Expected: 'second'
   Got: 'second'

 COMMON_SENSE
   Expected: 'stay the same'
   Got: 'stay the same'

Total API calls: 6

TESTING SELF-CONSISTENCY SOLVER

 MATH
   Expected: '45'
   Got: '45'
   Confidence: 100%
   All answers: ['45', '45', '45', '45', '45']

 LOGIC
   Expected: 'second'
   Got: 'second'
   Confidence: 80%
   All answers: ['second', 'second', 'first', 'second', 'second']

 COMMON_SENSE
   Expected: 'stay the same'
   Got: 'stay the same'
   Confidence: 100%
   All answers: ['stay the same', 'stay the same', 'stay the same', 'stay the same', 'stay the same']

Total API calls: 21


In [12]:
class IntelligentAgent:
    """Route each problem to a solver chosen from a keyword-based domain guess.

    Routing used by ``solve``:
      * ``multiple_choice``, ``knowledge_retrieval``, ``common_sense`` -> ``baseline``
      * ``planning`` -> ``sc`` (self-consistency; several model calls per problem)
      * ``math``, ``coding`` and anything unrecognised (``general``) -> ``cot``
    """

    # "A. foo ... B. bar" or "(A) foo ... (B) bar". re.DOTALL lets ".*" cross
    # newlines, so options written one per line are detected too.
    _MULTIPLE_CHOICE_RE = re.compile(r'\bA[.)]\s*\w+.*\bB[.)]\s*\w+', re.DOTALL)
    # Inline arithmetic such as "12 + 7" or "3*4".
    _ARITHMETIC_RE = re.compile(r'\d+\s*[\+\-\*\/\%]\s*\d+')

    # Keyword regexes, matched against the lower-cased question. Short stems carry
    # word boundaries so "explanation"/"planet" do not look like "plan",
    # "assume"/"consume" like "sum", "address"/"ladder" like "add", or
    # "this item" like "is it".
    _CODING_PATTERNS = (
        r'write a function', r'write code', r'implement', r'algorithm',
        r'program that', r'def ', r'return the',
    )
    _PLANNING_PATTERNS = (
        r'\bplan(?:s|ned|ner|ners|ning)?\b', r'schedule', r'steps to',
        r'how would you organize', r'strategy', r'approach to',
    )
    _MATH_PATTERNS = (
        r'calculate', r'compute', r'solve', r'equation', r'\bsum(?:s|med|ming)?\b',
        r'product', r'how many', r'how much', r'total', r'percentage', r'percent', r'%',
        r'divide', r'multiply', r'subtract', r'\badd(?:s|ed|ing|ition)?\b',
    )
    _COMMON_SENSE_PATTERNS = (
        r'what is', r'what are', r'who is', r'who was', r'where is',
        r'when did', r'when was', r'which', r'name the', r'define',
        r'true or false', r'yes or no', r'\bis it\b',
    )

    def __init__(self, baseline, cot, sc):
        """Store the three solvers; each must expose ``solve(problem)`` returning a dict.

        Args:
            baseline: Solver used for factual / multiple-choice questions.
            cot: Chain-of-thought solver used for math, coding and general questions.
            sc: Self-consistency solver used for planning questions.
        """
        self.baseline = baseline
        self.cot = cot
        self.sc = sc

    @staticmethod
    def _matches_any(text: str, patterns) -> bool:
        """Return True if any regex in `patterns` occurs somewhere in `text`."""
        return any(re.search(pattern, text) for pattern in patterns)

    def _detect_domain(self, question: str) -> str:
        """Guess the domain of `question`; earlier checks take priority over later ones."""
        q = question.lower()

        # Option letters are matched case-sensitively on the original text.
        if self._MULTIPLE_CHOICE_RE.search(question):
            return 'multiple_choice'

        # Questions that carry their own passage are factual look-ups.
        if 'context:' in q:
            return 'knowledge_retrieval'

        if self._matches_any(q, self._CODING_PATTERNS):
            return 'coding'

        if self._matches_any(q, self._PLANNING_PATTERNS):
            return 'planning'

        if self._matches_any(q, self._MATH_PATTERNS) or self._ARITHMETIC_RE.search(q):
            return 'math'

        if self._matches_any(q, self._COMMON_SENSE_PATTERNS):
            return 'common_sense'

        return 'general'

    def solve(self, problem: "Problem") -> Dict:
        """Solve `problem` with the solver chosen for its detected domain.

        Returns:
            Whatever dict the selected solver's ``solve`` returns.
        """
        domain = self._detect_domain(problem.get_input())

        if domain in ('multiple_choice', 'knowledge_retrieval', 'common_sense'):
            return self.baseline.solve(problem)

        if domain == 'planning':
            return self.sc.solve(problem)

        # math, coding and unrecognised ("general") questions all use chain-of-thought.
        return self.cot.solve(problem)

In [13]:
# Question file read by main() (load_questions requires a JSON list of question
# objects) and the file the final answers are written to. Both paths are relative
# to the current working directory.
INPUT_PATH = Path("cse_476_final_project_test_data.json")
OUTPUT_PATH = Path("cse_476_final_project_answers.json")


def load_questions(path: Path) -> List[Dict[str, Any]]:
    """Read the question file at `path` and return its top-level JSON list.

    Raises:
        ValueError: if the decoded JSON is anything other than a list.
    """
    with open(str(path), 'r', encoding='utf-8') as handle:
        payload = json.load(handle)

    if isinstance(payload, list):
        return payload
    raise ValueError("Input file must contain a list of question objects.")


def _load_checkpoint(checkpoint_path: str) -> List[Dict[str, str]]:
    """Return the answers saved at `checkpoint_path`, or [] if there is nothing usable."""
    try:
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            saved = json.load(f)
    except FileNotFoundError:
        print("No temp file found, starting fresh...")
        return []
    except json.JSONDecodeError:
        print("Temp file corrupted, starting fresh...")
        return []
    if not isinstance(saved, list):
        # Valid JSON but not an answer list; len() of it would be meaningless.
        print("Temp file corrupted, starting fresh...")
        return []
    print(f"✓ Found {checkpoint_path} with {len(saved)} answers")
    print(f"✓ Resuming from question {len(saved) + 1}...")
    return saved


def _save_checkpoint(answers: List[Dict[str, str]], checkpoint_path: str) -> None:
    """Write `answers` to `checkpoint_path` atomically (temp file + os.replace).

    An interruption mid-write therefore cannot leave a truncated checkpoint, which
    the loader would discard as corrupted (losing all progress).
    """
    tmp_path = f"{checkpoint_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(answers, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, checkpoint_path)


def _answer_question(agent: "IntelligentAgent", question_data: Dict[str, Any], idx: int) -> str:
    """Solve one question and return its cleaned answer string.

    Any exception raised while solving or post-processing is reported via
    ``tqdm.write`` and replaced by a placeholder answer, so one bad question does
    not stop the run.
    """
    problem = Problem(question_data)
    try:
        result = agent.solve(problem)
        answer_text = result.get("answer", "Error")
        if answer_text is None or answer_text == "":
            answer_text = "Unable to determine answer"
        answer_text = post_process_answer(str(answer_text))
        # Keep a safety margin under validate_results' 5000-character limit.
        return answer_text[:4900]
    except Exception as e:
        tqdm.write(f"Error on question {idx}: {str(e)}")
        return "Error processing question"


def build_answers(
    questions: List[Dict[str, Any]],
    resume: bool = True,
    checkpoint_path: str = 'temp_answers.json',
    checkpoint_every: int = 50,
) -> List[Dict[str, str]]:
    """Answer every question with an IntelligentAgent, checkpointing progress to disk.

    Args:
        questions: Question objects, in output order.
        resume: If True, continue after the answers already stored at `checkpoint_path`.
        checkpoint_path: File that in-progress answers are saved to.
        checkpoint_every: Save the checkpoint after every N-th question (1-based index).

    Returns:
        One ``{"output": str}`` dict per question, aligned with `questions`.

    Raises:
        ValueError: if `checkpoint_every` is not a positive integer.

    On KeyboardInterrupt (or any other escaping exception) the answers finished so
    far are saved to `checkpoint_path` before the exception propagates, so a rerun
    resumes exactly where it stopped.
    """
    if checkpoint_every < 1:
        raise ValueError("checkpoint_every must be a positive integer.")

    answers = _load_checkpoint(checkpoint_path) if resume else []
    start_from = len(answers)

    if start_from >= len(questions):
        print(f"Already complete! {start_from} answers for {len(questions)} questions")
        return answers

    print("\nInitializing solvers...")
    client = LLMClient()
    baseline = BaselineSolver(client)
    cot = ChainOfThoughtSolver(client)
    sc = SelfConsistencySolver(client, num_samples=3)
    agent = IntelligentAgent(baseline, cot, sc)

    total = len(questions)
    remaining = total - start_from

    print(f"Processing {remaining} remaining questions (out of {total} total)...\n")

    # idx is the 1-based position of the question in the full list.
    pbar = tqdm(
        enumerate(questions[start_from:], start=start_from + 1),
        total=remaining,
        desc="Processing",
        unit="q",
    )
    try:
        for idx, question_data in pbar:
            answers.append({"output": _answer_question(agent, question_data, idx)})

            calls = client.get_call_count()
            pbar.set_postfix({
                'API': calls,
                # idx - start_from >= 1: number processed in this run.
                'avg': f"{calls / (idx - start_from):.1f}",
            })

            if idx % checkpoint_every == 0 or idx == total:
                _save_checkpoint(answers, checkpoint_path)
    except BaseException:
        # Ctrl-C or an unexpected error: persist every finished answer, then re-raise.
        _save_checkpoint(answers, checkpoint_path)
        raise
    finally:
        pbar.close()

    print(f"\n{'='*50}")
    print(f"Total API calls: {client.get_call_count()}")
    print(f"Average per problem: {client.get_call_count()/remaining:.2f}")
    print(f"{'='*50}")

    return answers


def validate_results(questions: List[Dict[str, Any]], answers: List[Dict[str, Any]]) -> None:
    """Check that `answers` is a submittable answer list for `questions`.

    Raises:
        ValueError: lengths differ, an answer lacks ``"output"``, or an output is
            5000 characters or longer.
        TypeError: an answer is not a dict, or its ``"output"`` is not a str.
    """
    if len(questions) != len(answers):
        raise ValueError(
            f"Mismatched lengths: {len(questions)} questions vs {len(answers)} answers."
        )
    for idx, answer in enumerate(answers):
        # Without this, a string answer would turn `"output" in answer` into a
        # substring test and fail with a misleading error.
        if not isinstance(answer, dict):
            raise TypeError(
                f"Answer at index {idx} is not an object: {type(answer)}"
            )
        if "output" not in answer:
            raise ValueError(f"Missing 'output' field for answer index {idx}.")
        if not isinstance(answer["output"], str):
            raise TypeError(
                f"Answer at index {idx} has non-string output: {type(answer['output'])}"
            )
        if len(answer["output"]) >= 5000:
            raise ValueError(
                f"Answer at index {idx} is too long: must be under 5000 characters "
                f"({len(answer['output'])} chars)."
            )

def _extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last ``\\boxed{...}`` in `text`, or None.

    Nested braces inside the box (e.g. ``\\boxed{\\frac{1}{2}}``) are handled by
    brace counting; an unbalanced box returns None.
    """
    marker = '\\boxed{'
    start = text.rfind(marker)
    if start == -1:
        return None
    begin = start + len(marker)
    depth = 1
    for pos in range(begin, len(text)):
        char = text[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[begin:pos]
    return None


def post_process_answer(answer: str) -> str:
    """Normalise a raw model answer into a short, single-line final answer.

    Steps, in order:
      1. Code answers (starting with ``def `` or containing a line starting with
         ``def ``) are returned verbatim, capped at 1000 characters.
      2. If a ``\\boxed{...}`` is present, only the last boxed expression is kept.
      3. Real newlines are collapsed to spaces. If literal two-character ``\\n``
         sequences remain, the last non-empty segment is kept.
      4. ``\\frac{a}{b}`` (also ``\\dfrac``/``\\tfrac``) becomes ``a/b``. Other LaTeX
         commands, ``$`` signs and backslashes are removed.
      5. Text that opens like step-by-step reasoning is reduced to its last number,
         which is normally the result rather than an operand.
      6. Anything still longer than 200 characters is cut to its first 100.
    """
    if answer.startswith('def ') or '\ndef ' in answer:
        return answer[:1000]

    boxed = _extract_last_boxed(answer)
    if boxed is not None:
        answer = boxed

    answer = re.sub(r'\n+', ' ', answer)

    # Escaped newlines (backslash + "n") left by the model: keep the last segment.
    if '\\n' in answer:
        parts = answer.split('\\n')
        for part in reversed(parts):
            if part.strip():
                answer = part.strip()
                break

    answer = re.sub(r'\\[dt]?frac\{([^{}]*)\}\{([^{}]*)\}', r'\1/\2', answer)
    answer = re.sub(r'\\[a-zA-Z]+', '', answer)
    answer = re.sub(r'\$+', '', answer)
    answer = answer.replace('\\', '')

    if answer.lower().startswith(("let's", "to solve", "step ", "first,", "we need")):
        numbers = re.findall(r'\b(\d+\.?\d*)\b', answer)
        if numbers:
            return numbers[-1]

    if len(answer) > 200:
        # Newlines were collapsed above, so there is no "first line" to take.
        return answer[:100].strip()

    return answer.strip()

def main() -> None:
    """Generate, write, and validate answers for every question in INPUT_PATH.

    Answers are written to OUTPUT_PATH, read back and validated. Only after
    validation succeeds is the 'temp_answers.json' checkpoint removed.
    KeyboardInterrupt is reported and swallowed. Any other exception is reported
    and re-raised.
    """
    print("="*70)
    print("CSE 476 Final Project - Answer Generation")
    print("="*70)

    try:
        questions = load_questions(INPUT_PATH)
        print(f"\nLoaded {len(questions)} questions from {INPUT_PATH}")

        answers = build_answers(questions, resume=True)

        print(f"\nWriting answers to {OUTPUT_PATH}...")
        with OUTPUT_PATH.open("w", encoding="utf-8") as fp:
            json.dump(answers, fp, ensure_ascii=False, indent=2)

        # Validate what actually landed on disk, not just the in-memory list.
        with OUTPUT_PATH.open("r", encoding="utf-8") as fp:
            saved_answers = json.load(fp)
        validate_results(questions, saved_answers)

        print("\n SUCCESS!")
        print(f"Wrote {len(answers)} answers to {OUTPUT_PATH}")
        print("Format validated successfully.")
        print("="*70)

        # Best-effort cleanup: a missing checkpoint is fine; report other failures.
        try:
            os.remove('temp_answers.json')
            print("Cleaned up temp_answers.json")
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Warning: could not remove temp_answers.json: {e}")

    except KeyboardInterrupt:
        print("\n\n Interrupted! Progress up to the last checkpoint is in temp_answers.json")
        print("Run again to resume from where you left off.")

    except Exception as e:
        print(f"\n ERROR: {str(e)}")
        print("Any checkpointed progress is in temp_answers.json - run again to resume.")
        print("="*70)
        raise


# Entry point. In a Jupyter kernel __name__ is "__main__", so running this cell
# executes the full answer-generation pipeline.
if __name__ == "__main__":
    main()

CSE 476 Final Project - Answer Generation

Loaded 6208 questions from cse_476_final_project_test_data.json
✓ Found temp_answers.json with 6200 answers
✓ Resuming from question 6201...

Initializing solvers...
Processing 8 remaining questions (out of 6208 total)...



Processing: 100%|██████████| 8/8 [05:43<00:00, 42.93s/q, API=24, avg=3.0]


Total API calls: 24
Average per problem: 3.00

Writing answers to cse_476_final_project_answers.json...

 SUCCESS!
Wrote 6208 answers to cse_476_final_project_answers.json
Format validated successfully.
Cleaned up temp_answers.json



