# Wizard Coder Approach

The goal is to generate a breadth of solutions and common errors for easy-to-medium problems that LLMs can usually solve but on which they still hallucinate.

1. Propose k different approaches to solve the problem (high level overview)
2. Generate complete solutions for k different approaches
   a. Execute code portion against test cases (only proceed if it passes all the cases)
3. For each validated solution:
   a. Break into step-by-step nodes
   b. Link the nodes into a linear trajectory
   c. Validate each step (marking which ones are executable)
4. Connect all trajectories to the problem root
5. For each trajectory:
   a. Identify key decision points
   b. Generate error branches at these points
   c. Connect error branches to the tree
   d. Follow branching factors to augment trajectories with valid variants (pass context of nodes on the same level to ensure uniqueness)
   e. Connect these branches to the tree
6. Create embeddings for all nodes


Note: this was the very first implementation I came up with (several later versions exist that I cannot share). Its goal is to consistently identify the solution branch or error that someone working through a problem is on, by generating an annotated reasoning tree that grounds the model for certain problems.

In [259]:
def visualize_tree_advanced(root, indent=0, max_code_preview=50, show_code=True, show_metadata=True, depth_limit=None, use_colors=True):
    """Print the tree under *root* as an indented text outline.

    Each node gets one line with a status icon followed by optional code-preview,
    metadata and id lines; its children are then printed one level deeper.
    Icons: 🏁 terminal, ❌ error (``is_correct is False``), 🔍 root (step 0),
    🔚 leaf that is not terminal, ✅ correct step, ⚪ neutral step.

    Args:
        root: Node to print (reads WizardNode attributes such as ``step_number``,
            ``step_description``, ``code_state``, ``children``, ``is_terminal``,
            ``is_correct`` and ``id``).
        indent: Current depth; callers normally leave the default 0.
        max_code_preview: Characters of flattened code shown per node.
        show_code: Print the code-preview line.
        show_metadata: Print complexity/data structures/concepts for terminal
            nodes and the id line for every node.
        depth_limit: Deepest ``indent`` still printed; None means unlimited.
        use_colors: Emit ANSI color escape codes.
    """
    # Stop if we've reached the depth limit
    if depth_limit is not None and indent > depth_limit:
        return

    # ANSI color codes (blank when colors are disabled)
    if use_colors:
        BLUE = "\033[94m"
        GREEN = "\033[92m"
        YELLOW = "\033[93m"
        RED = "\033[91m"
        BOLD = "\033[1m"
        END = "\033[0m"
    else:
        BLUE = GREEN = YELLOW = RED = BOLD = END = ""

    prefix = "  " * indent

    # Determine node type indicator and color. Error nodes are checked before the
    # plain-leaf case: error branches are usually leaves, and previously they were
    # all shown as 🔚 so the ❌ marker never appeared for them.
    if root.is_terminal:
        node_type = "🏁"  # Terminal node
        color = GREEN
    elif root.is_correct is False:
        node_type = "❌"  # Error node
        color = RED
    elif root.step_number == 0:
        node_type = "🔍"  # Root node
        color = BLUE
    elif not root.children:
        node_type = "🔚"  # Leaf but not marked as terminal
        color = YELLOW
    else:
        node_type = "✅" if root.is_correct else "⚪"  # Correct node or neutral
        color = GREEN if root.is_correct else BLUE

    # Format step description (truncated to 80 characters)
    step_info = f"[Step {root.step_number}]" if root.step_number > 0 else "[ROOT]"
    step_desc = f"{step_info} {root.step_description}"
    if len(step_desc) > 80:
        step_desc = step_desc[:77] + "..."

    print(f"{prefix}{node_type} {color}{BOLD}{step_desc}{END}")

    # Code preview, flattened to a single line
    if show_code and root.code_state and (indent > 0 or root.step_number > 0):
        code_preview = root.code_state.replace("\n", " ").strip()
        if len(code_preview) > max_code_preview:
            code_preview = code_preview[:max_code_preview] + "..."
        print(f"{prefix}   └─ {YELLOW}Code:{END} {code_preview}")

    # Metadata is only populated for terminal nodes
    if show_metadata and root.is_terminal:
        if root.time_complexity:
            print(f"{prefix}   └─ {BLUE}Complexity:{END} Time: {root.time_complexity} | Space: {root.space_complexity}")
        if root.data_structures:
            print(f"{prefix}   └─ {BLUE}Data structures:{END} {', '.join(root.data_structures)}")
        if root.concepts:
            print(f"{prefix}   └─ {BLUE}Concepts:{END} {', '.join(root.concepts)}")

    # Node ID for debugging
    if show_metadata and root.id:
        print(f"{prefix}   └─ {BLUE}ID:{END} {root.id}")

    # Recursively print children, each preceded by a connector line
    last_index = len(root.children) - 1
    for i, child in enumerate(root.children):
        connector = "└─" if i == last_index else "├─"
        print(f"{prefix}   {connector}{'─' * (4 if indent > 0 else 2)}")
        visualize_tree_advanced(child, indent + 1, max_code_preview, show_code, show_metadata, depth_limit, use_colors)

def visualize_tree_with_graphviz(
    root,
    output_file='z_solution_tree',
    format='svg',  # Changed default to svg
    view=False,
    show_code=False,
    show_metadata=False,
    use_colors=True,
    max_node_width=50,
    rank_direction='TB',  # 'TB' (top-bottom), 'LR' (left-right)
    font_size=10,
    include_ids=False
):
    """Render the solution tree rooted at *root* to a file with Graphviz.

    Args:
        root: Tree root; nodes are read through WizardNode attributes (``id``,
            ``step_number``, ``children``, ``is_terminal``, ``is_correct``, ...).
        output_file: Output path without extension, passed to ``Digraph.render``.
        format: Graphviz output format (e.g. ``'svg'``, ``'png'``).
        view: Open the rendered file afterwards.
        show_code: Include each node's ``code_state`` (terminal nodes always show it).
        show_metadata: Include complexity/data-structure/concept metadata
            (terminal nodes always show it).
        use_colors: Colored fills; a greyscale palette when False.
        max_node_width: Wrap width, in characters, for step descriptions.
        rank_direction: Graphviz ``rankdir`` (``'TB'`` or ``'LR'``).
        font_size: Node font size; edge labels use one point less.
        include_ids: Also append an ``ID: ...`` line (the id is always the first
            line of every label).

    Returns:
        Path of the rendered file, or None when the ``graphviz`` package is missing.
    """
    try:
        import graphviz
        import textwrap
    except ImportError:
        print("Required libraries not installed. Please install with:")
        print("pip install graphviz")
        print("Note: You also need the Graphviz executable installed on your system.")
        return None

    # Create a new directed graph
    dot = graphviz.Digraph(
        comment='Solution Tree',
        format=format,
        node_attr={
            'shape': 'box',
            'style': 'filled',
            'fontname': 'Arial',
            'fontsize': str(font_size),
            'margin': '0.2,0.1',
            'width': '0',  # Auto-width
            'height': '0'  # Auto-height
        },
        edge_attr={'fontname': 'Arial', 'fontsize': str(font_size-1)}
    )

    # Set graph direction
    dot.attr(rankdir=rank_direction)

    def wrap_text(text, width):
        """Wrap *text* to lines of at most *width* characters."""
        return '\n'.join(textwrap.wrap(text, width=width))

    def add_node_to_dot(node):
        """Add *node* and, recursively, its subtree and edges to ``dot``."""
        # Fill color: first matching rule wins (root, terminal, error, ...)
        if use_colors:
            if node.step_number == 0:
                fillcolor = 'lightblue'
            elif node.is_terminal:
                fillcolor = 'palegreen'
            elif node.is_correct is False:
                fillcolor = 'salmon'
            elif getattr(node, 'is_key_decision', False):
                fillcolor = 'moccasin'
            else:
                fillcolor = 'lightyellow'
        else:
            # Greyscale colors if use_colors is False
            if node.step_number == 0:
                fillcolor = 'lightgrey'
            elif node.is_terminal:
                fillcolor = 'white'
            elif node.is_correct is False:
                fillcolor = 'lightgrey'
            else:
                fillcolor = 'white'

        label_parts = [f"{node.id}\n"]

        # Header uses the same precedence as the fill color, so a terminal or error
        # node at step 1 is no longer labelled as a trajectory HEAD.
        if node.step_number == 0:
            label_parts.append("ROOT")
        elif node.is_terminal:
            label_parts.append(f"TERMINAL (Step {node.step_number})")
        elif node.is_correct is False:
            label_parts.append(f"ERROR (Step {node.step_number})")
        elif node.step_number == 1:
            label_parts.append(f"HEAD - Step 1 ({node.trajectory_approach})")
        else:
            label_parts.append(f"Step {node.step_number}")

        # Description, wrapped to the maximum width
        if node.step_description:
            label_parts.append(wrap_text(node.step_description, max_node_width))

        # Code snippet if requested, always for terminal nodes
        if node.code_state and (show_code or node.is_terminal):
            label_parts.append("CODE:\n" + node.code_state)

        # Metadata if requested, always for terminal nodes
        if show_metadata or node.is_terminal:
            metadata_parts = []
            if node.time_complexity:
                metadata_parts.append(f"Time: {node.time_complexity}")
            if node.space_complexity:
                metadata_parts.append(f"Space: {node.space_complexity}")
            if node.data_structures:
                metadata_parts.append(f"Data structures: {', '.join(node.data_structures)}")
            if node.concepts:
                concepts_text = ", ".join(node.concepts)
                metadata_parts.append(f"Concepts: {wrap_text(concepts_text, max_node_width - 10)}")
            if metadata_parts:
                label_parts.append('\n'.join(metadata_parts))

        # Node ID if requested (for debugging)
        if include_ids and node.id:
            label_parts.append(f"ID: {node.id}")

        # python-graphviz passes backslashes through, and Graphviz treats sequences
        # like \n, \l or \N in labels as escapes; doubling them keeps code such as
        # s.split('\n') literal in the rendered label.
        label = '\n\n'.join(label_parts).replace('\\', '\\\\')

        dot.node(node.id, label=label, fillcolor=fillcolor)

        for child in node.children:
            add_node_to_dot(child)
            dot.edge(node.id, child.id)

    add_node_to_dot(root)

    # Render the graph (cleanup=True removes the intermediate DOT source file)
    output_path = dot.render(filename=output_file, cleanup=True, view=view)

    return output_path



In [110]:
from getpass import getpass
from dotenv import load_dotenv
import os
from openai import OpenAI

load_dotenv()
groq_key = os.getenv("GROQ_KEY", "Empty")

# getpass hides what is typed; the values are echoed by the prints below.
model = getpass("Enter the model name: ").strip()
api_endpoint = getpass("Enter the API endpoint (default: https://api.openai.com): ").strip()
# Models: "llama3-8b-8192", "llama3-70b-8192"
# Endpoint: "https://api.groq.com/openai"

# Fall back to the public OpenAI endpoint, and drop a trailing slash so the base
# URL does not become ".../openai//v1".
api_endpoint = (api_endpoint or "https://api.openai.com").rstrip("/")

# GROQ_KEY keeps priority (previous behaviour); without it, try the standard
# OpenAI variable instead of sending the literal "Empty" placeholder.
api_key = groq_key if groq_key != "Empty" else os.getenv("OPENAI_API_KEY", "Empty")

openai_api_base = f"{api_endpoint}/v1"

print(f"Model: {model}")
print(f"API Endpoint: {api_endpoint}")
print(f"OpenAI API Base: {openai_api_base}")
if api_key == "Empty":
    print("No API key needed.")
else:
    print("API Key Set")

client = OpenAI(  # Initialize OpenAI-compatible client
    api_key=api_key,
    base_url=openai_api_base
)

Model: llama3-70b-8192
API Endpoint: https://api.groq.com/openai
OpenAI API Base: https://api.groq.com/openai/v1
API Key Set


In [2]:
def chat_completion_request_openai(prompt, temperature=1.0, max_tokens=1500): # function that will call the OpenAI API
    """Send *prompt* as a single user message and return the reply text.

    Uses the module-level ``client`` and ``model`` configured in the setup cell.

    Args:
        prompt: User message content.
        temperature: Sampling temperature (default 1.0, the previously hard-coded value).
        max_tokens: Completion token cap (default 1500, the previously hard-coded value).

    Returns:
        The first choice's message text, or ``"No response from the model."`` when
        the API returns no choices or a message without text content.
    """
    messages = [
        {"role": "user", "content": prompt}
    ]

    chat_response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    completion_text = None
    if chat_response.choices:
        completion_text = chat_response.choices[0].message.content
    # The SDK types message.content as Optional[str]; callers do string operations
    # (split, find, regex) on the result, so never hand them None.
    if completion_text is None:
        completion_text = "No response from the model."
    return completion_text

In [None]:
class WizardNode:
    """A single node of the annotated solution/reasoning tree.

    A node holds one step of a solution trajectory (its description and the
    cumulative code at that point), optional analysis metadata, its links to the
    parent and child nodes, and status flags. Every attribute starts empty; the
    tree-building functions fill them in.
    """

    def __init__(self):
        # --- what this step says and does ---
        self.trajectory_approach: str = ""   # approach name this trajectory follows
        self.step_description: str = ""      # natural-language description of the step
        self.code_state: str = ""            # cumulative code after this step
        self.step_number: int = 0            # 0 for the root, 1..N along a trajectory

        # --- analysis metadata ---
        self.time_complexity: str = ""       # e.g. "O(n)"
        self.space_complexity: str = ""      # e.g. "O(1)"
        self.data_structures: list = []      # e.g. ["array", "variable"]
        self.concepts: list = []             # e.g. ["iteration", "variable tracking"]

        # --- tree links ---
        self.parent = None                   # node above this one (None for the root)
        self.children: list = []             # next-step options, correct and incorrect

        # --- status ---
        self.is_correct = True               # False marks an error branch
        self.is_terminal = False             # True for a final solution state
        self.error_type = None               # error classification for incorrect steps
        self.id = ""                         # unique identifier

        # --- cached embeddings ---
        self.embeddings = {}

        # TODO: maybe progressively more specific hints.
        # TODO: add cumulative steps (really important): all steps leading to this
        # one, which would help expand the tree quickly. Example:
        #   Step 1: Initialize two pointers, left and right, to the center of the palindrome.
        #   Step 2: Expand the pointers outward, checking if the characters at the pointers
        #           are the same and update the maximum length and the longest palindromic substring.
        #   Step 3: Repeat step 2 until the characters at the pointers are not the same or
        #           the pointers are out of the string bounds.

In [None]:
import re, ast, concurrent.futures, traceback, json
from typing import List, Dict, Any
import importlib
import problems.wizard_coder_problems as wizard_coder_problems
importlib.reload(wizard_coder_problems) # Reload the module to ensure the latest version is used
from problems.wizard_coder_problems import problems
"""
- Root node: Higher branching (4-5) to capture main solution approaches (k-determined)
- Decision points: Medium branching (2-3) for important algorithm decision points
- Implementation details: Lower branching (1-2) for minor variations
- Error nodes: Limited branching (0-1) to avoid exponential growth (also limit branching depth here)
"""

def propose_best_approaches(question_text, k=4): # Get k-best approaches
    """Ask the model for *k* distinct approach names for a coding problem.

    Args:
        question_text: The problem statement shown to the model.
        k: Maximum number of approaches to return.

    Returns:
        Up to *k* approach names in the order the model gave them, with empty
        entries and case-insensitive duplicates removed. The raw reply is printed.
    """
    prompt = f"""
    You are an expert Python programmer analyzing the following coding problem:
    
    {question_text}
    
    Based on this problem, list exactly {k} distinct approaches to solve it.
    
    Each approach should be fundamentally different in its core strategy (not just implementation details).
    Include both optimal and interesting suboptimal approaches.
    
    Format your response as a simple comma-separated list of approach names, 
    with each name being just 1-4 words (e.g., "Hash Map", "Two Pointers", "Binary Search").
    Do not include numbering, explanations, or any other text.
    """

    response = chat_completion_request_openai(prompt)  # Get response from LLM
    print(response)

    unique_approaches = []
    seen = set()
    # Although a comma-separated list is requested, models often answer with one
    # item per line, bullets, numbering or quotes; split on both separators and
    # strip that decoration so it does not end up in node ids and prompts.
    for raw in re.split(r"[,\n]", response or ""):
        approach = re.sub(r"^\s*(?:[-*•]\s*|\d+[.)]\s*)", "", raw)
        approach = approach.strip().strip("\"'`*").strip()
        if not approach:
            continue
        key = approach.lower()
        if key in seen:
            continue
        seen.add(key)
        unique_approaches.append(approach)
        if len(unique_approaches) >= k:
            break

    return unique_approaches


def generate_one_shot_solutions(question_text, approaches):
    """Request one complete solution per approach and parse each reply.

    Args:
        question_text: Problem statement (including starter code) shown to the model.
        approaches: Iterable of approach names, e.g. from ``propose_best_approaches``.

    Returns:
        A list with one ``parse_solution_response`` dict per approach, in the same
        order; the model is called once per approach, sequentially.
    """
    def solve_with(approach):
        # Prompt text is sent verbatim to the model.
        prompt = f"""
        You are implementing a solution to the following problem:
        
        {question_text}
        
        Please implement a solution using the "{approach}" approach.
        
        Your response must include:
        1. Numbered steps (5-7 steps) explaining your approach
        2. Complete, executable Python code that follows the starter code format
        
        For the steps:
        - Each step must be EXACTLY ONE SENTENCE
        - Make each step clear and focused on one specific action or concept
        - Avoid explanations or reasoning - just state what is done in each step
        - Use simple, direct language
        
        For the code:
        - Ensure it's fully executable and handles all edge cases
        - Build from the starter code provided in the question
        - Follow good Python practices with proper indentation
        - Make sure it passes all the example test cases
        
        Format your response as follows:
        
        STEPS:
        Step 1: [One sentence description]
        Step 2: [One sentence description]
        ...
        Step 5: [One sentence description]
        
        CODE:
        ```python
        [Your complete solution code here, starting with the provided starter code]
        ```
        """
        reply = chat_completion_request_openai(prompt)
        return parse_solution_response(reply, approach)

    return [solve_with(approach) for approach in approaches]


def parse_solution_response(response, approach):
    """Parse a model reply into ``{'approach', 'steps', 'code'}``.

    Steps are read from the ``STEPS:`` section (or any ``Step N:`` lines when that
    header is missing) and trimmed to their first sentence, ending in punctuation.
    Code is taken from the first ```` ```python ```` fence (any case); failing that,
    from the text after ``CODE:``, with a bare ```` ``` ```` fence removed.

    Args:
        response: Raw model reply text.
        approach: Approach name stored in the result.

    Returns:
        Dict with ``approach`` (str), ``steps`` (list of str) and ``code`` (str,
        empty if nothing could be extracted).
    """
    solution = {
        'approach': approach,
        'steps': [],
        'code': '',
    }

    # Extract steps
    steps_match = re.search(r'STEPS:(.*?)(?:CODE:|```python)', response, re.DOTALL)
    if steps_match:
        steps_text = steps_match.group(1).strip()
        step_pattern = r'Step\s+(\d+):\s+(.*?)(?=Step\s+\d+:|$)'
        steps = re.findall(step_pattern, steps_text, re.DOTALL)
    else:
        # Fallback extraction over the whole reply
        step_pattern = r'(?:Step|STEP)\s+(\d+)[:.]\s+(.*?)(?=(?:Step|STEP)\s+\d+[:."]|CODE:|```python|$)'
        steps = re.findall(step_pattern, response, re.DOTALL | re.IGNORECASE)

    # Keep only the first sentence of each step. A sentence ends at . ! or ? that is
    # followed by whitespace; splitting on every '.' used to cut steps such as
    # "Call s.lower() ..." down to "Call s.".
    cleaned_steps = []
    for _, step in steps:
        first_sentence = re.split(r'(?<=[.!?])\s+', step.strip(), maxsplit=1)[0].strip()
        if first_sentence:
            if not first_sentence.endswith(('.', '!', '?')):
                first_sentence += '.'
            cleaned_steps.append(first_sentence)
    solution['steps'] = cleaned_steps

    # Extract code
    code_match = re.search(r'```python(.*?)```', response, re.DOTALL | re.IGNORECASE)
    if code_match:
        solution['code'] = code_match.group(1).strip()
    else:
        # Fallback: everything after CODE:, unwrapping a fence without a language tag
        code_section = re.search(r'CODE:(.*?)(?=$)', response, re.DOTALL)
        if code_section:
            code_text = code_section.group(1).strip()
            fenced = re.search(r'```[^\n]*\n(.*?)```', code_text, re.DOTALL)
            solution['code'] = (fenced.group(1) if fenced else code_text).strip()

    return solution


# Execute test_cases on code associated with each generated Solution (validate correctness)
def _make_solution_namespace():
    """Return a fresh globals dict in which generated solution code is executed.

    It is pre-seeded with the ``typing`` names (``List``, ``Optional``, ...) and a
    few stdlib modules that generated solutions may use without importing them.
    A non-``"__main__"`` ``__name__`` keeps example-usage ``if __name__ == "__main__"``
    blocks in generated code from running.
    """
    namespace = {"__name__": "__wizard_solution__"}
    exec(
        "from typing import *\n"
        "import bisect, collections, functools, heapq, itertools, math, re\n",
        namespace,
    )
    return namespace


def _assigned_names(source):
    """Return the plain variable names assigned at the top level of *source*, in order."""
    names = []
    for stmt in ast.parse(source).body:
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        elif isinstance(stmt, (ast.AnnAssign, ast.AugAssign)):
            targets = [stmt.target]
        else:
            continue
        for target in targets:
            elements = target.elts if isinstance(target, (ast.Tuple, ast.List)) else [target]
            for element in elements:
                # Attribute targets such as ``root.left = ...`` are not arguments.
                if isinstance(element, ast.Name) and element.id not in names:
                    names.append(element.id)
    return names


def _call_with_timeout(func, kwargs, timeout):
    """Call ``func(**kwargs)`` in a daemon thread, waiting at most *timeout* seconds.

    Raises ``concurrent.futures.TimeoutError`` if the call has not finished in time;
    an exception raised by *func* is re-raised in the calling thread.
    """
    import threading  # local import: only this helper needs threads

    outcome = {}

    def runner():
        try:
            outcome["value"] = func(**kwargs)
        except BaseException as exc:  # handed back to the caller below
            outcome["error"] = exc

    # A ThreadPoolExecutor used as a context manager waits for its worker on exit,
    # so an infinite loop blocked forever even after the timeout fired. A daemon
    # thread lets us return. Python threads cannot be killed: a timed-out call keeps
    # running in the background until the interpreter exits.
    worker = threading.Thread(target=runner, daemon=True)
    worker.start()
    worker.join(timeout)
    if worker.is_alive():
        raise concurrent.futures.TimeoutError()
    if "error" in outcome:
        raise outcome["error"]
    return outcome["value"]


# Execute test_cases on code associated with each generated Solution (validate correctness)
def execute_test_cases(code: str, test_cases: list, function_name: str, print_cases: bool = True, timeout: float = 5):
    """Run ``Solution().<function_name>`` from generated *code* against *test_cases*.

    Args:
        code: Python source that must define a ``Solution`` class.
        test_cases: Dicts with ``"input"`` (Python assignment statements, one per
            argument, e.g. ``nums = [1, 2]``) and ``"output"`` (one expected value or
            a list of acceptable values; strings are parsed with ``ast.literal_eval``
            when possible). Outputs are compared via ``str()``.
        function_name: ``Solution`` method to call; each top-level variable assigned
            by a case's input is passed to it as a keyword argument.
        print_cases: Print a per-case report.
        timeout: Seconds allowed per case before it is reported as a timeout.

    Returns:
        Dict with ``success`` (True only if every case passed), ``results``
        (per-case dicts), ``passed_count`` and ``total_count``. Setup failures
        (code does not run, no ``Solution``, missing method) also carry ``error``.
    """
    total_count = len(test_cases)
    namespace = _make_solution_namespace()

    try:
        # One dict serves as both globals and locals. With separate dicts, imports,
        # helper functions and classes defined at the top level of ``code`` were
        # invisible inside Solution's methods (NameError at call time).
        exec(code, namespace)
    except Exception as e:
        details = traceback.format_exc()
        if print_cases:
            print(f"Error in executing provided code: {e}")
            print(details)
        return {
            "success": False,
            "error": f"Code execution error: {str(e)}",
            "details": details,
            "results": [],
            "passed_count": 0,
            "total_count": total_count,
        }

    solution_class = namespace.get("Solution")
    if not solution_class:
        if print_cases:
            print("No 'Solution' class found in the provided code.")
        return {
            "success": False,
            "error": "No 'Solution' class found in the provided code.",
            "results": [],
            "passed_count": 0,
            "total_count": total_count,
        }

    try:
        solution_instance = solution_class()
    except Exception as e:
        details = traceback.format_exc()
        if print_cases:
            print(f"Error instantiating 'Solution': {e}")
            print(details)
        return {
            "success": False,
            "error": f"Could not instantiate 'Solution': {str(e)}",
            "details": details,
            "results": [],
            "passed_count": 0,
            "total_count": total_count,
        }

    if not hasattr(solution_instance, function_name):
        if print_cases:
            print(f"Function '{function_name}' not found in 'Solution' class.")
        return {
            "success": False,
            "error": f"Function '{function_name}' not found in 'Solution' class.",
            "results": [],
            "passed_count": 0,
            "total_count": total_count,
        }

    func = getattr(solution_instance, function_name)

    results = []
    all_passed = True

    for i, case in enumerate(test_cases):
        test_input = case["input"]
        expected_outputs = case["output"]

        # Accept a single expected value or a list of acceptable values
        if not isinstance(expected_outputs, list):
            expected_outputs = [expected_outputs]

        try:
            # Evaluate the inputs in a copy of the solution namespace so classes the
            # solution defines (e.g. TreeNode) are usable and cases stay isolated.
            case_namespace = dict(namespace)
            exec(test_input, case_namespace)
            input_vars = {name: case_namespace[name] for name in _assigned_names(test_input)}

            try:
                actual_output = _call_with_timeout(func, input_vars, timeout)
            except concurrent.futures.TimeoutError:
                if print_cases:
                    print(f"Test case {i + 1} timed out (possible infinite loop).")
                all_passed = False
                results.append({
                    "test_id": i + 1,
                    "input": test_input,
                    "expected": expected_outputs,
                    "actual": None,
                    "pass": False,
                    "error": "TimeoutError: Execution exceeded the time limit, likely infinite loop."
                })
                continue

            # Check if the output matches any of the expected outputs
            is_correct = False
            for expected_output in expected_outputs:
                try:
                    expected_obj = ast.literal_eval(expected_output) if isinstance(expected_output, str) else expected_output
                except (ValueError, SyntaxError):
                    expected_obj = expected_output

                if str(actual_output) == str(expected_obj):
                    is_correct = True
                    break

            if print_cases:
                print(f"Test case {i + 1}:")
                print(f"Input: {test_input}")
                print(f"Expected (any of): {expected_outputs}")
                print(f"Actual: {actual_output}")
                print(f"Result: {'✅ Pass' if is_correct else '❌ Fail'}")
                print()

            if not is_correct:
                all_passed = False

            results.append({
                "test_id": i + 1,
                "input": test_input,
                "expected": expected_outputs,
                "actual": str(actual_output),
                "pass": is_correct
            })

        except Exception as e:
            if print_cases:
                print(f"Error executing test case {i + 1}: {e}")
                print(traceback.format_exc())
            all_passed = False
            results.append({
                "test_id": i + 1,
                "input": test_input,
                "expected": expected_outputs,
                "actual": None,
                "pass": False,
                "error": str(e),
                "traceback": traceback.format_exc()
            })

    return {
        "success": all_passed,
        "passed_count": sum(1 for r in results if r.get("pass", False)),
        "total_count": total_count,
        "results": results
    }




# -------------- Initial Tree Creation --------------


# Break down each solution into step-wise WizardNodes and attach these linear trajectories to the root node; this baseline tree is expanded in the next step.
def create_solution_tree(question_name, valid_solutions):
    """Build the baseline tree: one root with a linear trajectory per solution.

    Args:
        question_name: Problem name, used in the root's description, code and id.
        valid_solutions: Solution dicts (``approach``, ``steps``, ``code``) that
            already passed their test cases.

    Returns:
        The root WizardNode. The first step of each trajectory produced by
        ``break_down_solution`` becomes a direct child of the root; solutions that
        yield no steps are skipped.
    """
    root = WizardNode()
    root.step_description = f"Problem analysis for {question_name}"
    root.code_state = f"# Starting to solve {question_name}"
    root.step_number = 0
    root.id = f"{question_name}_root"

    print("\n\n\n\n----------------------")
    for number, solution in enumerate(valid_solutions, start=1):
        print(f"Processing solution {number} for {question_name}")
        head = break_down_solution(question_name, solution)

        # Guard clause: nothing to attach when the trajectory has no steps.
        if not (head and head.children):
            continue

        first_step = head.children[0]
        first_step.parent = root
        root.children.append(first_step)

    return root

def break_down_solution(question_name, solution):
    """Turn one validated solution into a linear chain of WizardNodes.

    Args:
        question_name: Problem name, used as the id prefix.
        solution: Dict with ``approach`` (str), ``steps`` (list of str) and
            ``code`` (the validated final code).

    Returns:
        A head node whose single chain of descendants holds one node per step
        (numbered from 1). The last step is terminal, stores the validated code
        verbatim and carries complexity/data-structure/concept metadata; earlier
        steps store the intermediate code from ``generate_step_by_step_code``.
        With no steps, the head is returned without children.
    """
    approach = solution['approach']
    steps = solution['steps']
    final_code = solution['code']  # validated code that passed the test cases
    slug = approach.replace(' ', '_').lower()

    head = WizardNode()
    head.step_description = f"Solution approach: {approach}"
    head.id = f"{question_name}_{slug}_head"
    head.trajectory_approach = approach

    if not steps:
        return head

    step_codes = generate_step_by_step_code(final_code, steps)
    last_index = len(steps) - 1

    tail = head
    for index, (description, intermediate_code) in enumerate(zip(steps, step_codes)):
        number = index + 1
        node = WizardNode()
        node.step_description = description
        node.step_number = number
        node.id = f"{question_name}_{slug}_step_{number}"
        node.trajectory_approach = approach

        if index != last_index:
            node.code_state = intermediate_code
            node.is_terminal = False
        else:
            # The terminal step always uses the exact validated code.
            node.code_state = final_code
            node.is_terminal = True
            node.time_complexity = extract_complexity(final_code, "time")
            node.space_complexity = extract_complexity(final_code, "space")
            node.data_structures = extract_data_structures(final_code)
            node.concepts = extract_concepts(final_code)

        node.parent = tail
        tail.children.append(node)
        tail = node

    return head

def generate_step_by_step_code(final_code, steps): # Use LLM to generate code for each step
    """Ask the model for the cumulative code state after each step.

    Args:
        final_code: The complete, validated solution.
        steps: Step descriptions, in order.

    Returns:
        A list with exactly ``len(steps)`` code strings. Each comes from the
        ``STEP_i_CODE:`` section of the reply (its ```` ```python ```` block if present,
        otherwise the raw section text). When a marker is missing, the i-th python
        block in the reply is used instead, or ``final_code`` if there is none.
    """
    prompt = f"""
    Given the following complete solution code and step descriptions, break down the code into incremental states 
    that correspond to each step. Each state should build upon the previous one and be fully executable if possible.
    
    The final solution code is:
    ```python
    {final_code}
    ```
    
    And the steps are:
    {steps}
    
    For each step, provide the cumulative code state at that point. The code should gradually evolve from the initial 
    structure to the final solution. Include the class and function definition in each step.
    
    Format your response as:
    
    STEP_1_CODE:
    ```python
    [Code state after step 1]
    ```
    
    STEP_2_CODE:
    ```python
    [Code state after step 2]
    ```
    
    And so on for all steps.
    """

    response = chat_completion_request_openai(prompt)

    code_block_pattern = r"```python(.*?)```"
    step_codes = []
    for i in range(1, len(steps) + 1):
        marker = f"STEP_{i}_CODE:"
        marker_pos = response.find(marker)

        if marker_pos == -1:
            # Fallback: use the i-th python code block, else the final code
            code_blocks = re.findall(code_block_pattern, response, re.DOTALL)
            step_code = code_blocks[i - 1].strip() if i <= len(code_blocks) else final_code
        else:
            start_idx = marker_pos + len(marker)
            end_idx = response.find(f"STEP_{i + 1}_CODE:", start_idx) if i < len(steps) else -1
            if end_idx == -1:
                # Read to the end. Previously find()'s -1 was used as the slice end,
                # silently dropping the last character (often the closing fence).
                end_idx = len(response)
            step_text = response[start_idx:end_idx].strip()
            code_match = re.search(code_block_pattern, step_text, re.DOTALL)
            step_code = code_match.group(1).strip() if code_match else step_text

        step_codes.append(step_code)

    return step_codes

def _max_loop_depth(code):
    """Return the deepest loop nesting in *code*, or None if it does not parse.

    ``for``/``while`` statements count one level each; a comprehension counts one
    level per ``for`` clause.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return None

    comprehension_types = (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)
    loop_types = (ast.For, ast.AsyncFor, ast.While)

    def depth(node):
        if isinstance(node, comprehension_types):
            parts = [gen.iter for gen in node.generators]
            parts += [cond for gen in node.generators for cond in gen.ifs]
            parts += [node.key, node.value] if isinstance(node, ast.DictComp) else [node.elt]
            return len(node.generators) + max((depth(part) for part in parts), default=0)
        inner = max((depth(child) for child in ast.iter_child_nodes(node)), default=0)
        return inner + (1 if isinstance(node, loop_types) else 0)

    return depth(tree)


def extract_complexity(code, complexity_type):
    """Estimate the time or space complexity of *code* as a big-O string.

    A comment such as ``# Time complexity: O(n log n)`` wins. Otherwise a rough
    heuristic is used. For ``"time"``, the deepest loop nesting decides:
    0 -> O(1), 1 -> O(n), 2 -> O(n²), 3 -> O(n³), deeper -> O(n^k). Any other
    type is treated as space: O(n) if the code appends/extends or uses a
    set/dict, else O(1). This is a guess, not analysis; recursion, sorting and
    log factors are not detected.
    """
    # Look for comments that mention complexity
    pattern = r"#.*" + complexity_type + r".*complexity.*?[oO]\(([^)]+)\)"
    match = re.search(pattern, code, re.IGNORECASE)
    if match:
        return f"O({match.group(1)})"

    if complexity_type == "time":
        # Real nesting from the AST. The old substring test called any two
        # occurrences of "for" (including "format" or two sequential loops) O(n²)
        # and ignored while loops.
        depth = _max_loop_depth(code)
        if depth is None:
            # Unparseable code: count whole-word loop keywords (no nesting info)
            depth = min(len(re.findall(r"\b(?:for|while)\b", code)), 2)
        return {0: "O(1)", 1: "O(n)", 2: "O(n²)", 3: "O(n³)"}.get(depth, f"O(n^{depth})")

    # Space: word-bounded so "offset"/"subset"/"reset" no longer count as a set
    if re.search(r"\bappend|\bextend|\bset\b|dict\b", code):
        return "O(n)"  # Growing data structures
    return "O(1)"  # Constant space
    
def extract_data_structures(code):
    """Return heuristic data-structure tags found in *code*.

    Tags, in this fixed order: ``list`` (``[]`` or the word list/List),
    ``dictionary`` (``{}`` or a name ending in dict, e.g. Dict, defaultdict),
    ``set``, ``tuple`` and ``string`` (str/string, optionally plural). Matching
    is case-insensitive and word-bounded, so "offset", "reset", "strip" or
    "destroy" no longer produce false tags.
    """
    patterns = [
        ("list", r"\[\]|\blist\b"),
        ("dictionary", r"\{\}|dict\b"),
        ("set", r"\bset\b"),
        ("tuple", r"\btuple\b"),
        ("string", r"\b(?:str|string)s?\b"),
    ]
    return [name for name, pattern in patterns if re.search(pattern, code, re.IGNORECASE)]

def _extract_concepts_by_substring(code):
    """Substring-based concept tagging; fallback for code that does not parse."""
    concepts = []
    if "for" in code:
        concepts.append("iteration")
    if "while" in code:
        concepts.append("loops")
    if "if" in code:
        concepts.append("conditional logic")
    if code.count("def ") > 1:
        concepts.append("helper functions")
    if "return" in code:
        concepts.append("return values")
    if "==" in code or "!=" in code or ">" in code or "<" in code:
        concepts.append("comparison")
    if "+=" in code or "-=" in code:
        concepts.append("compound assignment")
    if "[" in code and ":" in code and "]" in code:
        concepts.append("slicing")
    if "try" in code and "except" in code:
        concepts.append("exception handling")
    if "class" in code and "self" in code:
        concepts.append("object-oriented programming")
    return concepts


def extract_concepts(code): # Make this more sophisticated
    """Return heuristic programming-concept tags for *code*.

    Tags, in this fixed order: iteration, loops, conditional logic, helper
    functions, return values, comparison, compound assignment, slicing,
    exception handling, object-oriented programming. Parseable code is inspected
    through its AST. As a result, keywords inside identifiers ("format", "diff",
    "entry"), comments and strings, and annotations such as ``-> int`` or
    ``List[int]`` no longer produce false tags. Unparseable code falls back to the
    original substring checks.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return _extract_concepts_by_substring(code)

    nodes = list(ast.walk(tree))

    def has(*types):
        return any(isinstance(node, types) for node in nodes)

    try_types = (ast.Try, getattr(ast, "TryStar", ast.Try))
    ordering_ops = (ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE)

    checks = [
        ("iteration", has(ast.For, ast.AsyncFor, ast.comprehension)),
        ("loops", has(ast.While)),
        ("conditional logic", has(ast.If, ast.IfExp)
            or any(isinstance(n, ast.comprehension) and n.ifs for n in nodes)),
        ("helper functions",
            sum(isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) for n in nodes) > 1),
        ("return values", has(ast.Return)),
        ("comparison", any(isinstance(n, ast.Compare)
                           and any(isinstance(op, ordering_ops) for op in n.ops)
                           for n in nodes)),
        ("compound assignment", any(isinstance(n, ast.AugAssign)
                                    and isinstance(n.op, (ast.Add, ast.Sub))
                                    for n in nodes)),
        ("slicing", has(ast.Slice)),
        ("exception handling", has(*try_types)),
        ("object-oriented programming", has(ast.ClassDef) and any(
            (isinstance(n, ast.Name) and n.id == "self")
            or (isinstance(n, ast.arg) and n.arg == "self")
            for n in nodes)),
    ]
    return [name for name, present in checks if present]





# -------------- Generate Key Decision Points and Errors for each Solution -------------- Acceptable for now (TODO: reduce how often malformed JSON is returned)


def identify_key_decision_points(solution, question_name):
    """
    Ask the LLM (JSON mode) which steps of a solution are key decision points.

    Args:
        solution (dict): Solution with 'approach' (str), 'steps' and 'code' (str).
        question_name (str): Problem name, interpolated into the prompt.

    Returns:
        list[dict]: Key decision points. Each is a dict whose "step" has been
        converted to int; other fields (e.g. "reason") are kept as returned.
        Entries that are not objects, or whose "step" is not an integer, are
        dropped. Returns [] if the API call or JSON parsing fails.
    """
    approach = solution['approach']
    steps = solution['steps']
    code = solution['code']

    # Create the prompt in parts to avoid f-string formatting issues
    prompt_part1 = f"""
    I'm building a reasoning tree for teaching students how to solve the coding problem "{question_name}".
    
    Below is a {approach} solution broken down into steps:
    
    {steps}
    
    Final code:
    ```python
    {code}
    ```
    
    For this solution, identify which steps (by number) represent KEY DECISION POINTS where students often need to make important algorithmic or implementation choices.
    
    Key decision points are steps where:
    1. There are multiple valid ways to proceed
    2. Students commonly make conceptual errors
    3. The step involves a critical insight or technique
    4. The step requires careful implementation
    
    Include 2-3 key decision points (more for complex solutions, fewer for simple ones).
    
    Return your response as a valid JSON object with a "key_points" array of objects, where each object has:
    - "step": The step number (integer)
    - "reason": A brief explanation of why this is a key decision point (string)
    
    Example format:
    """

    # Example JSON kept out of the f-string so its braces need no escaping.
    # In this (non-raw) string "\\n" renders as the two characters \n.
    example_json = """
    {
        "key_points": [
            { "step": 2, "reason": "Choose the right initialization strategy" },
            { "step": 4, "reason": "This is where you decide how far to expand pointers" }
        ]
    }

    Remember: newlines inside JSON strings MUST be written as the escape sequence \\n, never as raw line breaks, to be valid JSON.
    """

    prompt = prompt_part1 + example_json

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a JSON-only assistant. Your response must be a valid JSON object containing an array of key decision points."},
                {"role": "user", "content": prompt}
            ],
            response_format={"type": "json_object"}
        )
        result = json.loads(response.choices[0].message.content)
    except Exception as e:
        print(f"Error getting response: {e}")
        return []

    # JSON mode returns an object; the list normally sits under "key_points".
    # Also accept a bare list, or the first list-valued field if the model renamed it.
    if isinstance(result, list):
        items = result
    elif isinstance(result, dict):
        items = result.get("key_points")
        if not isinstance(items, list):
            items = next((value for value in result.values() if isinstance(value, list)), [])
    else:
        items = []

    # Normalize so callers can rely on dict entries with an int "step"
    # (the tree indexes nodes by integer step_number).
    key_points = []
    for item in items:
        if not isinstance(item, dict):
            continue
        try:
            step = int(item.get("step"))
        except (TypeError, ValueError, OverflowError):
            continue
        key_points.append({**item, "step": step})
    return key_points

def generate_common_errors(solution, question_name):
    """
    Use LLM with JSON mode to generate common student errors for a solution.

    Args:
        solution (dict): Solution with 'approach' (str), 'steps' and 'code' (str).
        question_name (str): Problem name, interpolated into the prompt.

    Returns:
        list[dict]: Errors. Each is a dict whose "step" has been converted to
        int; other fields ("description", "type", ...) are kept as returned.
        Entries that are not objects, or whose "step" is not an integer, are
        dropped. Returns [] if the API call or JSON parsing fails.
    """
    approach = solution['approach']
    steps = solution['steps']
    code = solution['code']

    # Create the prompt in parts to avoid f-string formatting issues
    prompt_part1 = f"""
    I'm building a reasoning tree for teaching students how to solve the coding problem "{question_name}".
    
    Below is a {approach} solution broken down into steps:
    
    {steps}
    
    Final code:
    ```python
    {code}
    ```
    
    For this solution, generate 3-4 common student errors that might occur at different steps.
    
    For each error:
    1. Identify which step number the error would occur at
    2. Provide a brief description of the error
    3. Categorize the error type (e.g., "off-by-one", "initialization error", "logic error")
    
    DO NOT include code snippets in your response, just focus on the concepts.
    
    Return your response as a valid JSON object with a "common_errors" array of objects, where each object has:
    - "step": The step number (integer)
    - "description": A brief description of the error (string)
    - "type": The type of error (string)
    
    Example format:
    """

    # Example JSON kept out of the f-string so its braces need no escaping.
    # In this (non-raw) string "\\n" renders as the two characters \n.
    example_json = """
    {
        "common_errors": [
            { "step": 2, "description": "Initializing max_value to 0 instead of the first element", "type": "initialization error" },
            { "step": 3, "description": "Using incorrect loop bounds", "type": "off-by-one" }
        ]
    }

    Remember: newlines inside JSON strings MUST be written as the escape sequence \\n, never as raw line breaks, to be valid JSON.
    """

    prompt = prompt_part1 + example_json

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a JSON-only assistant. Your response must be a valid JSON object containing an array of common errors."},
                {"role": "user", "content": prompt}
            ],
            response_format={"type": "json_object"}
        )
        result = json.loads(response.choices[0].message.content)
    except Exception as e:
        print(f"Error getting response: {e}")
        return []

    # The prompt asks for "common_errors"; also accept "errors", a bare list,
    # or, failing those, the first list-valued field.
    if isinstance(result, list):
        items = result
    elif isinstance(result, dict):
        items = next(
            (result[key] for key in ("common_errors", "errors") if isinstance(result.get(key), list)),
            None,
        )
        if items is None:
            items = next((value for value in result.values() if isinstance(value, list)), [])
    else:
        items = []

    # Normalize so callers can rely on dict entries with an int "step"
    # (augment_solution_tree calls .get() on each entry and matches by int step).
    errors = []
    for item in items:
        if not isinstance(item, dict):
            continue
        try:
            step = int(item.get("step"))
        except (TypeError, ValueError, OverflowError):
            continue
        errors.append({**item, "step": step})
    return errors
    





# -------------- Augment tree with generated key_decision_points and common_errors -------------- MISSING key decision_points logic


def _as_step_number(value):
    """Return `value` converted to an int step number, or None if that is not possible."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _index_correct_nodes_by_step(nodes, approach):
    """
    Map step_number -> node for the correct (non-error) nodes of `approach`.

    Skips nodes with a non-int step_number or step_number <= 0 (root/head
    nodes), and nodes with ``is_correct is False`` (error branches added by a
    previous augmentation). If several nodes share a step number, the first
    one in `nodes` wins.
    """
    index = {}
    for node in nodes:
        if getattr(node, 'trajectory_approach', None) != approach:
            continue
        if getattr(node, 'is_correct', None) is False:
            continue
        step = getattr(node, 'step_number', None)
        if isinstance(step, int) and step > 0:
            index.setdefault(step, node)
    return index


def augment_solution_tree(solution_tree_root, key_decision_points_by_approach, common_errors_by_approach, question_name):
    """
    Mark key decision points and attach error branches to an existing solution tree.

    For each approach present in either mapping:
      * the correct node for every key-decision step gets ``is_key_decision = True``
        and ``key_decision_reason``;
      * every common error becomes a new WizardNode (``is_correct = False``) added
        as a sibling of the correct node for that step. Its code comes from
        ``generate_error_code`` applied to that node's ``code_state``.

    Steps are matched on ``step_number``. LLM-supplied step values are coerced
    with ``int()``. Malformed entries (non-dicts, non-integer steps, unknown
    steps) are skipped. An error node ID gets a numeric suffix if it would
    collide with an existing ID (e.g. two errors of the same type at one step).

    Args:
        solution_tree_root (WizardNode): Root of the existing solution tree
        key_decision_points_by_approach (dict): approach name -> list of
            {"step", "reason"} dicts
        common_errors_by_approach (dict): approach name -> list of
            {"step", "description", "type"} dicts
        question_name (str): Name of the problem (used in error node IDs)

    Returns:
        WizardNode: The same root, mutated in place.
    """
    print("\nAugmenting solution tree with error branches...")

    # Error nodes created below are never looked up again during this call,
    # so the node list only needs to be collected once.
    all_nodes = get_all_nodes(solution_tree_root)
    used_ids = {getattr(node, 'id', None) for node in all_nodes}

    # Include approaches that only have errors; otherwise their errors
    # would never be attached.
    approaches = list(key_decision_points_by_approach)
    approaches.extend(a for a in common_errors_by_approach if a not in key_decision_points_by_approach)

    for approach in approaches:
        key_points = key_decision_points_by_approach.get(approach) or []
        errors = common_errors_by_approach.get(approach) or []

        print(f"\nProcessing {approach} approach...")
        print(f"  Key Decision Points: {len(key_points)}")
        print(f"  Common Errors: {len(errors)}")

        nodes_by_step = _index_correct_nodes_by_step(all_nodes, approach)

        # Mark key decision points (unknown steps are ignored)
        for key_point in key_points:
            if not isinstance(key_point, dict):
                continue
            step_num = _as_step_number(key_point.get("step"))
            node = nodes_by_step.get(step_num)
            if node is None:
                continue
            reason = key_point.get("reason", "")
            # Dynamically add metadata to mark as key decision point
            node.is_key_decision = True
            node.key_decision_reason = reason
            print(f"  Marked step {step_num} as key decision point: {reason}")

        # Add error branches
        for error in errors:
            if not isinstance(error, dict):
                print(f"  Warning: Skipping malformed error entry: {error!r}")
                continue

            step_num = _as_step_number(error.get("step"))
            target_node = nodes_by_step.get(step_num)
            if target_node is None:
                print(f"  Warning: Step {error.get('step')} not found for error: {error.get('description', '')}")
                continue

            # The error node becomes a sibling, so the target needs a parent
            if not getattr(target_node, 'parent', None):
                print(f"  Warning: Node for step {step_num} has no parent")
                continue

            error_node = WizardNode()
            error_node.step_number = step_num
            error_node.step_description = error.get("description", "Common error")
            error_node.is_correct = False
            error_node.error_type = error.get("type", "logic error")
            error_node.trajectory_approach = approach

            # Copy the code from the target node and introduce the error
            error_node.code_state = generate_error_code(target_node.code_state, error)

            # Unique ID: suffix _2, _3, ... if the natural ID is already taken
            error_type_slug = str(error.get("type", "error")).replace(" ", "_").lower()
            base_id = f"{question_name}_{approach.replace(' ', '_').lower()}_step_{step_num}_error_{error_type_slug}"
            error_id, suffix = base_id, 2
            while error_id in used_ids:
                error_id = f"{base_id}_{suffix}"
                suffix += 1
            used_ids.add(error_id)
            error_node.id = error_id

            # Add to parent (making it a sibling of the target node)
            error_node.parent = target_node.parent
            target_node.parent.children.append(error_node)

            print(f"  Added error branch for step {step_num}: {error.get('description', '')}")

    print("\nTree augmentation complete!")
    return solution_tree_root

def get_all_nodes(root):
    """
    Return every node reachable from `root` in breadth-first order.

    The result list doubles as the BFS queue: a cursor walks it while each
    node's children are appended. This avoids ``list.pop(0)``, which is O(n)
    per call and makes the traversal O(n^2).

    Args:
        root: Tree node with a ``children`` list, or None.

    Returns:
        list: ``[root, *children of root, *grandchildren, ...]``; [] if root
        is falsy. Nodes lacking ``children`` (or with ``children=None``) are
        treated as leaves.
    """
    if not root:
        return []

    nodes = [root]
    cursor = 0
    while cursor < len(nodes):
        nodes.extend(getattr(nodes[cursor], 'children', None) or [])
        cursor += 1
    return nodes

def _fallback_error_code(code, error_type, error_desc):
    """
    Rule-based stand-in for the LLM: apply ONE targeted mutation matching `error_type`.

    Each rule changes only its first match. Operators that belong to other
    tokens (``->``, ``<<``, ``>>``, and ``==``/``<=`` when looking for ``= 0``)
    are left alone, so the result stays syntactically valid. If no rule
    applies, the code is returned unchanged plus a trailing comment naming
    the intended error.
    """
    kind = error_type.lower()
    flips = {"<": ">", ">": "<", "<=": ">=", ">=": "<="}

    if "off-by-one" in kind:
        rules = [
            # range(n) -> range(n - 1); tolerates one level of nested parentheses
            (r"\brange\(((?:[^()]|\([^()]*\))*)\)", r"range(\1 - 1)"),
            (r"<=", "<"),
            (r">=", ">"),
        ]
    elif "initialization" in kind:
        rules = [
            (r"(?<![=!<>+\-*/%&|^])= 0(?![\w.])", "= 1"),
            (r"(?<![=!<>])= \[\]", "= None"),
            (r'(?<![=!<>])= ""', "= None"),
        ]
    elif "logic" in kind:
        rules = [
            (r"==", "!="),
            # A standalone comparison operator (not part of ->, <<, >> or <<=)
            (r"(?<![-<>=!])(<=|>=|<|>)(?![<>])", lambda m: flips[m.group(1)]),
        ]
    else:
        rules = []

    for pattern, replacement in rules:
        mutated, count = re.subn(pattern, replacement, code, count=1)
        if count:
            return mutated

    # If all else fails, return original with a comment
    return code + f"\n# Error would be: {error_desc} (type: {error_type})"


def generate_error_code(original_code, error):
    """
    Return a copy of `original_code` with one specific error injected.

    The LLM is asked to introduce the error described by `error`. If the call
    fails or returns nothing, ``_fallback_error_code`` applies a simple
    rule-based mutation instead.

    Args:
        original_code (str | None): Correct code for the step (None is treated as "").
        error (dict): Error spec with optional "type" and "description" keys.

    Returns:
        str: The code with the error introduced (markdown fences stripped).
    """
    original_code = original_code or ""
    # Keep the caller's casing in the prompt (descriptions may name identifiers
    # such as maxValue); lowercasing is only needed for keyword matching.
    error_type = str(error.get("type") or "")
    error_desc = str(error.get("description") or "")

    prompt = f"""
    You are introducing a specific error into correct code.
    
    Here is the correct code:
    ```python
    {original_code}
    ```
    
    Introduce this specific error:
    Type: {error_type}
    Description: {error_desc}
    
    Return ONLY the modified code with the error introduced. Do not include any explanations.
    """

    try:
        response = client.chat.completions.create(
            model=model,  # Can use a smaller model to save costs
            messages=[
                {"role": "system", "content": "You are a code modification assistant that introduces specific errors."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2  # Low temperature for more deterministic results
        )
        error_code = response.choices[0].message.content.strip()
        # Remove markdown code fences (```python / ```py / ```) if present
        error_code = re.sub(r"```(?:python|py)?", "", error_code).strip()
        if error_code:
            return error_code
        print("Error generating error code: empty response")
    except Exception as e:
        print(f"Error generating error code: {e}")

    return _fallback_error_code(original_code, error_type, error_desc)
    





# -------------- Create k-solutions rooted at key-decision points --------------



def find_parents_of_key_decision_points(tree_root):
    """
    Collect (parent, child) pairs for every child flagged as a key decision point.

    A child counts as flagged when its ``is_key_decision`` attribute is truthy.

    Args:
        tree_root (WizardNode): Root of the solution tree

    Returns:
        list: (parent_node, key_decision_child) tuples, with parents in
        breadth-first order and children in their stored order. A parent
        with several flagged children appears once per flagged child.
    """
    return [
        (candidate, child)
        for candidate in get_all_nodes(tree_root)
        for child in (getattr(candidate, 'children', None) or [])
        if getattr(child, 'is_key_decision', False)
    ]

def print_key_decision_points_info(tree_root):
    """
    Print a report on every key decision point and the node it hangs from.

    For each (parent, key child) pair the report shows the parent's ID, step,
    description and approach; the child's ID, step, description and reason;
    and, when the parent has several children, a one-line summary of each
    sibling (error vs. normal node). Missing attributes print as 'Unknown'.

    Args:
        tree_root (WizardNode): Root of the solution tree

    Returns:
        list: The (parent, key_child) pairs, for later use
    """
    pairs = find_parents_of_key_decision_points(tree_root)

    print(f"\nFound {len(pairs)} parent nodes with children marked as key decision points:\n")

    for number, (parent, key_child) in enumerate(pairs, start=1):
        print(f"Parent {number}:")
        print(f"  ID: {getattr(parent, 'id', 'Unknown')}")
        print(f"  Step Number: {getattr(parent, 'step_number', 'Unknown')}")
        print(f"  Description: {getattr(parent, 'step_description', 'Unknown')}")
        print(f"  Approach: {getattr(parent, 'trajectory_approach', 'Unknown')}")
        print(f"  Has code: {'Yes' if getattr(parent, 'code_state', None) else 'No'}")

        print("\n  Key Decision Point Child:")
        print(f"    ID: {getattr(key_child, 'id', 'Unknown')}")
        print(f"    Step Number: {getattr(key_child, 'step_number', 'Unknown')}")
        print(f"    Description: {getattr(key_child, 'step_description', 'Unknown')}")
        print(f"    Reason: {getattr(key_child, 'key_decision_reason', 'Unknown')}")
        print(f"    Has code: {'Yes' if getattr(key_child, 'code_state', None) else 'No'}")
        print()

        # Summarize the other children of the same parent
        if hasattr(parent, 'children') and len(parent.children) > 1:
            siblings = [other for other in parent.children if other != key_child]
            print(f"  Siblings of Key Decision Point: {len(siblings)}")
            for order, sibling in enumerate(siblings, start=1):
                label = 'Error Node' if getattr(sibling, 'is_correct', None) is False else 'Normal Node'
                print(f"    Sibling {order}: Step {getattr(sibling, 'step_number', 'Unknown')}, {label}")

        print("\n" + "-" * 80 + "\n")

    return pairs

def build_cumulative_steps(tree_root, parents_list):
    """
    Attach to each key-decision parent the step descriptions that lead to it.

    For every (parent, key_child) pair, the path root -> parent is found with
    find_path_by_approach, preferring the key child's trajectory_approach.
    The step_description of every non-root node on that path with
    step_number > 0 is collected in path order. The list is also stored on
    the parent as ``parent.cumulative_steps``.

    Args:
        tree_root (WizardNode): Root of the solution tree
        parents_list (list): (parent, key_child) tuples from
            find_parents_of_key_decision_points

    Returns:
        list: (parent, key_child, cumulative_steps) tuples, one per input
        pair. cumulative_steps is [] (and the parent is left untouched) when
        the approach is unknown or no path is found.
    """
    enhanced = []

    for parent, key_child in parents_list:
        approach = getattr(key_child, 'trajectory_approach', None)

        if not approach:
            print(f"Warning: Cannot determine approach for key decision point {getattr(key_child, 'id', 'Unknown')}")
            enhanced.append((parent, key_child, []))
            continue

        path = find_path_by_approach(tree_root, parent, approach)

        if not path:
            print(f"Warning: Could not find path from root to parent for approach {approach}")
            enhanced.append((parent, key_child, []))
            continue

        # The root and head nodes (step_number <= 0) contribute no step text.
        steps_so_far = [
            node.step_description
            for node in path
            if getattr(node, 'step_description', None)
            and node != tree_root
            and node.step_number > 0
        ]

        parent.cumulative_steps = steps_so_far
        enhanced.append((parent, key_child, steps_so_far))

    return enhanced

def find_path_by_approach(root, target_node, approach):
    """
    Find the breadth-first path from `root` to `target_node`.

    Children whose ``trajectory_approach`` equals `approach` are expanded
    before their siblings. In a tree the path is unique, so this ordering only
    matters when a node can be reached by several routes; in that case the
    route discovered first under this ordering is returned.

    The search keeps a predecessor map and rebuilds the path once the target
    is reached. It does not copy a growing path for every queued node, and it
    uses a deque instead of ``list.pop(0)``, so the work is linear in the
    number of nodes.

    Args:
        root (WizardNode): The root node
        target_node (WizardNode): The node to find (compared with ``==``)
        approach (str): Approach whose children are explored first

    Returns:
        list: Nodes from root to target inclusive, or [] if not reachable.
    """
    from collections import deque

    start = object()  # predecessor marker for the root
    came_from = {root: start}
    queue = deque([root])

    while queue:
        current = queue.popleft()

        if current == target_node:
            path = [current]
            while came_from[path[-1]] is not start:
                path.append(came_from[path[-1]])
            path.reverse()
            return path

        children = current.children if hasattr(current, 'children') else []

        # Expand children matching the approach first, then the rest
        matched, others = [], []
        for child in children:
            child_approach = getattr(child, 'trajectory_approach', None)
            if child_approach and child_approach == approach:
                matched.append(child)
            else:
                others.append(child)

        for child in matched + others:
            if child not in came_from:  # first discovery fixes the route
                came_from[child] = current
                queue.append(child)

    # Path not found
    return []

def print_cumulative_steps(enhanced_parents):
    """
    Print, for each key decision point, the steps leading to its parent plus a
    short preview (first 5 lines) of the parent's code.

    Args:
        enhanced_parents (list): (parent, key_child, cumulative_steps) tuples,
            as returned by build_cumulative_steps

    Returns:
        None
    """
    preview_limit = 5

    print("\nCumulative Steps from Root to Parents of Key Decision Points:\n")

    for trajectory_no, (parent, key_child, steps) in enumerate(enhanced_parents, start=1):
        approach = getattr(key_child, 'trajectory_approach', "Unknown")
        parent_step = getattr(parent, 'step_number', 'Unknown')
        child_step = getattr(key_child, 'step_number', 'Unknown')

        print(f"Trajectory {trajectory_no} (Approach: {approach}):")
        print(f"  Target: Step {parent_step}, Parent of Key Decision at Step {child_step}")

        print("\n  Cumulative Steps:")
        for step_no, step in enumerate(steps, start=1):
            print(f"    Step {step_no}: {step}")

        code_state = getattr(parent, 'code_state', None)
        if not code_state:
            print("\n  No code state available")
        else:
            code_lines = code_state.split('\n')
            print(f"\n  Code State Preview ({len(code_lines)} lines):")
            for line in code_lines[:preview_limit]:
                print(f"    {line}")
            hidden = len(code_lines) - preview_limit
            if hidden > 0:
                print(f"    ... ({hidden} more lines)")

        print("\n" + "-" * 80 + "\n")

def _first_sentence(text):
    """
    Return the first sentence of `text`, whitespace-normalized and ending in
    '.', '!' or '?'.

    A terminator only ends the sentence when whitespace follows it, so dotted
    code such as ``nums.sort()`` or ``0.5`` inside a step is kept intact, and
    a sentence ending in '?' or '!' keeps its own terminator.
    """
    collapsed = " ".join(text.split())
    sentence = re.split(r"(?<=[.!?])\s+", collapsed, maxsplit=1)[0].strip()
    if sentence and not sentence.endswith(('.', '!', '?')):
        sentence += '.'
    return sentence


def _parse_alternative_response(response_text, alt_approach, key_step):
    """
    Parse an LLM completion in the STEPS:/CODE: format into a solution dict.

    Only steps numbered >= `key_step` are kept, each reduced to its first
    sentence. 'original_steps' holds (step_number, text) pairs and 'steps'
    holds the text only. 'code' is the first fenced code block, or
    everything after "CODE:" if there is no fence.
    """
    solution = {
        'approach': alt_approach,
        'steps': [],
        'original_steps': [],  # (original step number, text) pairs
        'code': '',
    }

    steps_match = re.search(r'STEPS:(.*?)(?:CODE:|```python)', response_text, re.DOTALL)
    if steps_match:
        steps_text = steps_match.group(1).strip()
        step_pattern = r'Step\s+(\d+):\s+(.*?)(?=Step\s+\d+:|$)'
        for step_num_str, step_desc in re.findall(step_pattern, steps_text, re.DOTALL):
            step_num = int(step_num_str)
            if step_num < key_step:
                continue  # the model repeated an earlier step
            sentence = _first_sentence(step_desc)
            if sentence:
                solution['original_steps'].append((step_num, sentence))
                solution['steps'].append(sentence)

    code_match = re.search(r'```(?:python|py)?[^\S\n]*\n(.*?)```', response_text, re.DOTALL)
    if code_match:
        solution['code'] = code_match.group(1).strip()
    else:
        code_section = re.search(r'CODE:(.*)', response_text, re.DOTALL)
        if code_section:
            solution['code'] = code_section.group(1).strip()

    return solution


def _record_validation(solution, problem, question_name):
    """Run the problem's test cases (if any) and store 'is_validated' / 'percent_passed' on `solution`."""
    alt_approach = solution['approach']

    if not solution['code']:
        solution['is_validated'] = False
        solution['percent_passed'] = 0
        print(f"  {alt_approach}: ❌ Failed (no valid code generated)")
        return

    if 'test_cases' not in problem:
        solution['is_validated'] = False
        solution['percent_passed'] = 0
        print(f"  {alt_approach}: ⚠️ Not validated (no test cases available)")
        return

    print(f"  Validating {alt_approach}...")
    test_results = execute_test_cases(solution['code'], problem['test_cases'], question_name, print_cases=False)

    # `or 1` also covers a reported total of 0, which would otherwise raise ZeroDivisionError.
    total = test_results.get('total_count') or 1
    solution['is_validated'] = test_results['success']
    solution['percent_passed'] = test_results.get('passed_count', 0) / total

    print(f"  {alt_approach}: {'✅ Passed' if solution['is_validated'] else '❎ Partial Pass'} "
          f"({int(solution['percent_passed'] * 100)}% of tests)")


def _build_alternative_prompt(problem_desc, approach, key_step, cumulative_steps, current_code,
                              original_key_step, previous_alternatives):
    """
    Build the prompt asking for a different completion starting at `key_step`.

    The prompt names the original solution's step at `key_step` (if known) and
    summarizes alternatives already generated, so that successive variations
    are pushed to differ from each other.
    """
    previous_steps = "\n".join(f"- {step}" for step in cumulative_steps) or "(none)"

    original_step_note = ""
    if original_key_step:
        original_step_note = (
            f'In the existing solution, step {key_step} was: "{original_key_step}". '
            f"Your implementation must take a DIFFERENT route from this step onward."
        )

    already_generated = ""
    earlier = [" ".join(alt['steps']) for alt in previous_alternatives if alt.get('steps')]
    if earlier:
        already_generated = (
            "Alternatives already generated for this step (yours must differ from all of them):\n"
            + "\n".join(f"- {summary}" for summary in earlier)
        )

    return f"""
    You are implementing a solution to the following problem:

    {problem_desc}

    I've already implemented the solution up to step {key_step-1} using the {approach} approach.

    Previous steps:
    {previous_steps}

    Current code state:
    ```python
    {current_code}
    ```

    {original_step_note}
    {already_generated}

    Please provide an ALTERNATIVE implementation for ONLY the remaining steps, starting from step {key_step}.
    Your implementation must build on the existing code - DO NOT modify or repeat the previous steps.

    Your response must include:
    1. Numbered steps (starting from step {key_step}) explaining your approach to complete the solution
    2. Complete, executable Python code for the FULL solution

    For the steps:
    - ONLY include steps from step {key_step} onwards (DO NOT repeat earlier steps)
    - Each step must be EXACTLY ONE SENTENCE
    - Make each step clear and focused on one specific action or concept
    - Avoid explanations or reasoning - just state what is done in each step
    - Use simple, direct language

    For the code:
    - Ensure it's fully executable and handles all edge cases
    - Build from the existing code state provided above
    - Follow good Python practices with proper indentation

    Format your response as follows:

    STEPS:
    Step {key_step}: [One sentence description]
    Step {key_step+1}: [One sentence description]
    Step {key_step+2}: [One sentence description]
    ...and so on with NO GAPS in step numbers

    CODE:
    ```python
    [Your complete solution code here, building on the existing code]
    ```
    """


def generate_alternative_solutions(enhanced_parents, question_name, k=1):
    """
    Generate `k` alternative completions of each trajectory from its key decision step.

    For every (parent, key_child, cumulative_steps) tuple, the LLM is shown
    the problem, the steps taken so far and the parent's code. It is asked
    for a different way to finish the solution, starting at the key child's
    step. Each completion is parsed into steps and code. When the global
    `problems` entry for `question_name` has test cases, the code is
    validated with `execute_test_cases`.

    Args:
        enhanced_parents (list): (parent, key_child, cumulative_steps) tuples
            from build_cumulative_steps
        question_name (str): Key into the global `problems` mapping
        k (int): Number of alternatives to request per key decision point

    Returns:
        list: (parent, key_child, alternative_solutions) tuples. Each solution
        dict has 'approach', 'steps', 'original_steps', 'code',
        'is_validated' and 'percent_passed'. Generations that raise are
        reported and skipped.
    """
    problem = problems[question_name] if question_name in problems else {}
    problem_desc = problem.get('description', "") if problem else ""

    results = []

    for parent, key_child, cumulative_steps in enhanced_parents:
        approach = getattr(key_child, 'trajectory_approach', "Unknown")
        current_code = getattr(parent, 'code_state', "") or ""

        # Prefer the key child's own step number; otherwise assume it directly follows the parent.
        if hasattr(key_child, 'step_number'):
            key_step = key_child.step_number
        elif hasattr(parent, 'step_number'):
            key_step = parent.step_number + 1
        else:
            key_step = 1
        original_key_step = getattr(key_child, 'step_description', "")

        print(f"\nGenerating {k} alternative solutions for {approach} from step {key_step}...")

        alternative_solutions = []

        for variation in range(1, k + 1):
            alt_approach = f"{approach} Step {key_step} Variation {variation}"
            prompt = _build_alternative_prompt(
                problem_desc, approach, key_step, cumulative_steps, current_code,
                original_key_step, alternative_solutions,
            )

            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "You are a programming expert that completes partial solutions, providing only the REMAINING steps."},
                        {"role": "user", "content": prompt}
                    ]
                )

                response_text = response.choices[0].message.content or ""
                # Log a bounded preview rather than the whole completion.
                print(f"  Response for {alt_approach}: {response_text[:300]}...")

                solution = _parse_alternative_response(response_text, alt_approach, key_step)
                _record_validation(solution, problem, question_name)
                alternative_solutions.append(solution)
            except Exception as e:
                print(f"  Error generating solution for {alt_approach}: {e}")

        results.append((parent, key_child, alternative_solutions))

    return results

def print_alternative_solutions(alternatives_results):
    """
    Print every generated alternative: validation status, its steps and up to
    10 lines of its code.

    Steps are printed with their original step numbers when
    'original_steps' is present and non-empty; otherwise they are numbered
    sequentially from 1.

    Args:
        alternatives_results (list): (parent, key_child, alternative_solutions)
            tuples, as returned by generate_alternative_solutions
    """
    max_code_lines = 10

    for _parent, key_child, alternative_solutions in alternatives_results:
        approach = getattr(key_child, 'trajectory_approach', "Unknown")
        print(f"\nAlternatives for {approach} (Parent of Key Decision Point):")

        for number, solution in enumerate(alternative_solutions, start=1):
            print(f"\nAlternative {number}: {solution['approach']}")
            status = '✅ Passed' if solution.get('is_validated', False) else '❎ Partial Pass'
            percent = int(solution.get('percent_passed', 0) * 100)
            print(f"  Validation: {status} ({percent}% of tests)")

            print("\n  Steps:")
            numbered_steps = solution.get('original_steps') or list(enumerate(solution['steps'], start=1))
            for step_no, step_text in numbered_steps:
                print(f"    Step {step_no}: {step_text}")

            print("\n  Code:")
            code_lines = solution['code'].split('\n')
            shown = code_lines[:max_code_lines]
            for line in shown:
                print(f"    {line}")
            if len(code_lines) > len(shown):
                print(f"    ... ({len(code_lines) - len(shown)} more lines)")

            print("\n" + "-" * 40)





# -------------- Break down those k-solutions into branches and attach them to the tree --------------


def attach_alternative_solutions(root, alternatives_results):
    """
    Break down alternative solutions into linear trajectories of WizardNodes
    and attach them to the parent node in the existing tree.

    Only alternatives whose ``is_validated`` flag is truthy are used. Each is
    turned into a chain of nodes by ``break_down_alternative_solution``, and the
    first node of that chain becomes a new child of ``parent_node``.

    Note: ``root`` is not read here; nodes are attached to the ``parent_node``
    objects stored in ``alternatives_results``. Those parents must belong to the
    tree rooted at ``root``. If ``root`` is a ``copy.deepcopy`` of the tree the
    parents came from, the branches land on the original tree, not the copy.

    Args:
        root: The root node of the existing solution tree (kept for API compatibility)
        alternatives_results: The list of tuples containing
            (parent_node, key_child, alternative_solutions)
    """
    for parent_node, key_child, alternative_solutions in alternatives_results:
        approach_name = getattr(key_child, 'trajectory_approach', "Unknown")
        start_step = getattr(key_child, 'step_number', "?")
        print(f"\nAttaching alternatives for {approach_name} from step {start_step}")

        # Only process validated solutions
        valid_alternatives = [sol for sol in alternative_solutions if sol.get('is_validated', False)]

        if not valid_alternatives:
            print("  No valid alternatives found.")
            continue

        for alt_idx, solution in enumerate(valid_alternatives, start=1):
            print(f"  Processing alternative {alt_idx}: {solution['approach']}")

            alt_trajectory = break_down_alternative_solution(parent_node, key_child, solution)

            if not alt_trajectory:
                print(f"  Failed to create trajectory for {solution['approach']}")
                continue

            # Connect first node of trajectory to the parent node
            first_alt_node = alt_trajectory[0]
            first_alt_node.parent = parent_node
            parent_node.children.append(first_alt_node)

            print(f"  Successfully attached alternative solution {solution['approach']} with {len(alt_trajectory)} steps")

def break_down_alternative_solution(parent_node, key_child, solution):
    """
    Break down an alternative solution into a linear trajectory of WizardNodes.

    Step numbers come from ``solution['original_steps']`` (``(step_number, text)``
    pairs, sorted here by number) when that list is non-empty. Otherwise
    ``solution['steps']`` is numbered consecutively from ``key_child.step_number``,
    or from 1 if the key child has no ``step_number``.

    The last node is terminal and holds the exact validated ``solution['code']``
    plus complexity / data-structure / concept metadata. Earlier nodes hold the
    intermediate code from ``generate_step_by_step_code_for_alternative_solutions``.

    Args:
        parent_node: The parent node where the alternative solution will be attached
            (only its ``id`` is read here, to derive the question name)
        key_child: The key child node that defines the step number where we start
        solution: The alternative solution dictionary; it is not modified

    Returns:
        A list of WizardNode objects representing the linear trajectory. Each node
        is linked to the next via ``parent``/``children``. The first node's
        ``parent`` is not set here; the caller attaches it. The list is empty if
        the solution has no steps.
    """
    approach = solution['approach']
    final_code = solution['code']  # This is the validated code that passed test cases

    original_steps = solution.get('original_steps')
    if original_steps:
        # sorted() builds a new list so the caller's solution dict is not reordered.
        numbered_steps = sorted(original_steps, key=lambda pair: pair[0])
    else:
        # Without original numbers, continue from the key child's step number.
        start_step_number = getattr(key_child, 'step_number', 1)
        numbered_steps = [
            (start_step_number + offset, step_text)
            for offset, step_text in enumerate(solution['steps'])
        ]

    if not numbered_steps:
        print(f"  No steps found for {approach}")
        return []

    steps = [step_text for _, step_text in numbered_steps]

    # Get the question name from the parent node's ID
    question_name = parent_node.id.split('_')[0] if '_' in parent_node.id else "unknown"

    step_codes = generate_step_by_step_code_for_alternative_solutions(final_code, steps)

    trajectory_nodes = []
    last_idx = len(numbered_steps) - 1

    for step_idx, (step_number, step_desc) in enumerate(numbered_steps):
        node = WizardNode()
        node.step_description = step_desc
        node.step_number = step_number
        node.id = f"{question_name}_{approach.replace(' ', '_').lower()}_step_{step_number}"
        node.trajectory_approach = approach

        if step_idx == last_idx:
            # Terminal node: keep the exact validated code, not a regenerated version.
            node.code_state = final_code
            node.is_terminal = True

            node.time_complexity = extract_complexity(final_code, "time")
            node.space_complexity = extract_complexity(final_code, "space")
            node.data_structures = extract_data_structures(final_code)
            node.concepts = extract_concepts(final_code)
        else:
            # Fall back to the final code if fewer step codes were produced than
            # steps, so no step (in particular the terminal one) is dropped.
            node.code_state = step_codes[step_idx] if step_idx < len(step_codes) else final_code
            node.is_terminal = False

        # Link to previous node in trajectory if not the first node
        if trajectory_nodes:
            previous_node = trajectory_nodes[-1]
            node.parent = previous_node
            previous_node.children.append(node)

        trajectory_nodes.append(node)

    return trajectory_nodes

def generate_step_by_step_code_for_alternative_solutions(final_code, steps):
    """
    Generate intermediate code states for each step in the solution.

    This is a placeholder heuristic; it does not ask an LLM for real
    intermediate code. For step ``i`` (0-based) of ``n``, it emits the first
    ``max(1, int(num_lines * (i + 1) / n))`` lines of ``final_code`` followed by
    a ``# Step {i+1}: {description}`` comment. The last step always gets
    ``final_code`` unchanged.

    Args:
        final_code: The complete validated code
        steps: List of step descriptions

    Returns:
        List of code states, one per step (empty if ``steps`` is empty)
    """
    # TODO: replace the line-prefix heuristic with LLM-generated intermediate code.
    lines = final_code.split('\n')  # loop-invariant, so split once
    total_steps = len(steps)
    step_codes = []

    for i, step in enumerate(steps):
        if i == total_steps - 1:
            # For the last step, use the final code
            step_codes.append(final_code)
            continue

        num_lines = max(1, int(len(lines) * ((i + 1) / total_steps)))
        # Join continuation lines so a multi-line description stays inside one
        # comment instead of leaking uncommented text into the code.
        step_comment = " ".join(str(step).splitlines())
        step_codes.append('\n'.join(lines[:num_lines]) + f"\n# Step {i+1}: {step_comment}")

    return step_codes

def process_and_attach_alternative_solutions(root, question_name, enhanced_parents, k=1):
    """
    Run the whole alternative-solution pass in one call.

    It asks ``generate_alternative_solutions`` for ``k`` alternatives per
    key-decision-point parent. It then passes the results to
    ``attach_alternative_solutions``, which turns the validated ones into node
    trajectories and hangs them off their parents.

    Args:
        root: Root node of the solution tree; the same object is returned.
        question_name: Name of the problem being solved.
        enhanced_parents: Parents of key decision points (with cumulative steps).
        k: How many alternatives to request for each parent.

    Returns:
        ``root``, to allow chaining.
    """
    attach_alternative_solutions(
        root,
        generate_alternative_solutions(enhanced_parents, question_name, k),
    )
    return root


In [None]:
import importlib
import problems.wizard_coder_problems as wizard_coder_problems
importlib.reload(wizard_coder_problems) # Reload the module to ensure the latest version is used
from problems.wizard_coder_problems import problems
from pprint import pprint
# Example question names available in `problems`: twoSum, longestPalindrome

question_name = "longestPalindrome"
question = problems[question_name]

# ----------------- Propose and generate k-best approaches and solutions -----------------

k_approaches = propose_best_approaches(question['description'])
k_solutions = generate_one_shot_solutions(question['description'], k_approaches)

for i, solution in enumerate(k_solutions):  # Print k-solutions
    print(f"Solution {i+1} using approach '{solution['approach']}':\n")
    # enumerate() numbers repeated step texts correctly; list.index() would reuse
    # the first match's position (and is O(n) per lookup).
    for step_num, step in enumerate(solution['steps'], start=1):
        print(f"Step {step_num}: {step}")
    print(f"\n{solution['code']}")
    print("\n\n-------------------- Next Solution --------------------\n\n")

# -----------------  Execute test cases on each solution to validate them -----------------

for solution in k_solutions:
    print(f"Executing Test Cases for approach <{solution['approach']}> with following code:\n{solution['code']}\n")
    test_results = execute_test_cases(solution['code'], question['test_cases'], question_name)
    # Validate the solution with test results
    solution['is_validated'] = test_results['success']  # All test cases passed
    # `or 1` guards against a reported total of 0 (or None), which would raise ZeroDivisionError.
    total_count = test_results.get('total_count', 1) or 1
    solution['percent_passed'] = test_results.get('passed_count', 0) / total_count

for solution in k_solutions:
    pprint(solution)

# -----------------  Create baseline solution tree -----------------

valid_solutions = [sol for sol in k_solutions if sol['is_validated']]
solution_tree_root = create_solution_tree(question_name, valid_solutions)


# Next: expand the base trajectories using branching factors and erroneous branches
# to build a representative baseline tree.


Expand Around Center, Two Pointers, Dynamic Programming, Manacher's Algorithm
Solution 1 using approach 'Expand Around Center':

Step 1: Initialize two pointers, left and right, to the center of the palindrome.
Step 2: Expand the pointers outward, checking if the characters at the pointers are the same and update the maximum length and the longest palindromic substring.
Step 3: Repeat step 2 until the characters at the pointers are not the same or the pointers are out of the string bounds.
Step 4: Repeat steps 1-3 for all possible centers of palindrome in the string.

class Solution:
    def longestPalindrome(self, s: str) -> str:
        def expand_around_center(s, left, right):
            while left >= 0 and right < len(s) and s[left] == s[right]:
                left -= 1
                right += 1
            return s[left + 1:right]
        
        longest = ""
        for i in range(len(s)):
            palindrome1 = expand_around_center(s, i, i)  # odd length
            palin

In [182]:
# For every validated solution, ask the helpers for (a) key decision points and
# (b) common errors, each tied to a step number of that solution.
# Results are keyed by approach name and passed to augment_solution_tree in a later cell.

key_decision_points_by_approach = {}  # approach name -> key decision points
common_errors_by_approach = {}  # approach name -> common errors

for solution in valid_solutions:
    approach = solution['approach']

    # Judging by the printed output, both helpers return lists of dicts with a
    # 'step' key ('reason' for decision points; 'description'/'type' for errors).
    key_decision_points = identify_key_decision_points(solution, question_name)
    common_errors = generate_common_errors(solution, question_name)

    # Print the results
    print(f"Key Decision Points for {approach}:")
    pprint(key_decision_points)

    print(f"\nCommon Errors for {approach}:")
    pprint(common_errors)

    key_decision_points_by_approach[approach] = key_decision_points
    common_errors_by_approach[approach] = common_errors

    print("\n\n-------------------- Next Solution --------------------\n\n")

Key Decision Points for Expand Around Center:
[{'reason': 'Decide how to initialize pointers to the center of the '
            'palindrome, considering the possibility of odd-length '
            'palindromes.',
  'step': 1},
 {'reason': 'Determine how to expand the pointers outward, ensuring the '
            'correct updating of the maximum length and longest palindromic '
            'substring.',
  'step': 2},
 {'reason': 'Choose how to iterate through all possible centers of palindrome '
            'in the string, considering the need to handle both odd and even '
            'length palindromes.',
  'step': 4}]

Common Errors for Expand Around Center:
[{'description': 'Not considering even length palindromes',
  'step': 1,
  'type': 'logic error'},
 {'description': 'Not updating the maximum length correctly',
  'step': 2,
  'type': 'logic error'},
 {'description': 'Not iterating over all possible centers of palindrome',
  'step': 3,
  'type': 'logic error'}]


-----------------

In [281]:
# Augment the tree with common-error branches AND mark the nodes that are key decision points.
import copy

# Augment a deep copy so the baseline tree (solution_tree_root) does not have to be regenerated.
solution_tree_copy_root = copy.deepcopy(solution_tree_root)

augmented_tree = augment_solution_tree(
    solution_tree_copy_root,  # root of the existing solution tree (the copy)
    key_decision_points_by_approach,
    common_errors_by_approach,
    question_name
)
# NOTE(review): a later cell visualizes solution_tree_copy_root as the "errors" tree,
# which suggests augment_solution_tree mutates its argument in place -- confirm.

# TODO(debug): compare the two expansion strategies. Key-decision-point branches should
# add more correct (representative) variants; common-error branches should capture the
# mistakes students make most often.


Augmenting solution tree with error branches...

Processing Expand Around Center approach...
  Key Decision Points: 3
  Common Errors: 3
  Marked step 1 as key decision point: Decide how to initialize pointers to the center of the palindrome, considering the possibility of odd-length palindromes.
  Marked step 2 as key decision point: Determine how to expand the pointers outward, ensuring the correct updating of the maximum length and longest palindromic substring.
  Marked step 4 as key decision point: Choose how to iterate through all possible centers of palindrome in the string, considering the need to handle both odd and even length palindromes.
  Added error branch for step 1: Not considering even length palindromes
  Added error branch for step 2: Not updating the maximum length correctly
  Added error branch for step 3: Not iterating over all possible centers of palindrome

Processing Two Pointers approach...
  Key Decision Points: 3
  Common Errors: 3
  Marked step 2 as key de

In [282]:
# Prepare to augment the tree with key-decision-point branches.
# Relies on `import copy` from the earlier augmentation cell.
augmented_tree_copy_root = copy.deepcopy(augmented_tree)

# Find the parents of nodes marked as key decision points and print them.
parents = print_key_decision_points_info(augmented_tree_copy_root)
# Build the cumulative trajectory for each such parent.
# NOTE(review): these parent objects presumably live in augmented_tree_copy_root, so any
# later deepcopy of that tree must copy them together with it -- confirm.
enhanced_parents = build_cumulative_steps(augmented_tree_copy_root, parents)

# print_cumulative_steps(enhanced_parents) # Print the cumulative trajectories for each parent of a key decision point


Found 9 parent nodes with children marked as key decision points:

Parent 1:
  ID: longestPalindrome_root
  Step Number: 0
  Description: Problem analysis for longestPalindrome
  Approach: 
  Has code: Yes

  Key Decision Point Child:
    ID: longestPalindrome_expand_around_center_step_1
    Step Number: 1
    Description: Initialize two pointers, left and right, to the center of the palindrome.
    Reason: Decide how to initialize pointers to the center of the palindrome, considering the possibility of odd-length palindromes.
    Has code: Yes

  Siblings of Key Decision Point: 3
    Sibling 1: Step 1, Normal Node
    Sibling 2: Step 1, Normal Node
    Sibling 3: Step 1, Error Node

--------------------------------------------------------------------------------

Parent 2:
  ID: longestPalindrome_expand_around_center_step_1
  Step Number: 1
  Description: Initialize two pointers, left and right, to the center of the palindrome.
  Approach: Expand Around Center
  Has code: Yes

  Key 

In [290]:
# Generate one alternative solution per key-decision-point parent.
# Each item is a (parent_node, key_child, list_of_solution_dicts) tuple, as unpacked by
# print_alternative_solutions and attach_alternative_solutions.
alternative_solutions = generate_alternative_solutions(enhanced_parents, question_name, k=1)
print("\n\n\n\n-------------------- Printing the Alternative Solutions --------------------\n\n\n\n")
print_alternative_solutions(alternative_solutions)


Generating 1 alternative solutions for Expand Around Center from step 1...
  Response for Expand Around Center Step 1 Variation 1: STEPS:
Step 1: Initialize two pointers, left and right, to the center of the palindrome.
Step 2: Expand the palindrome by moving the right pointer to the right as long as the characters at the left and right pointers are the same.
Step 4: Update the longest palindrome found so far if the current palindrome is longer.
Step 5: Repeat steps 1-4 for each character in the string as the center of the palindrome.

CODE:
```python
class Solution:
    def longestPalindrome(self, s: str) -> str:
        def expand_around_center(s, left, right):
            while left >= 0 and right < len(s) and s[left] == s[right]:
                left -= 1
                right += 1
            return s[left + 1:right]

        longest_palindrome = ""
        for i in range(len(s)):
            palindrome1 = expand_around_center(s, i, i)  # odd length palindrome
            palindr

In [291]:
# Break down alternative solutions into linear trajectories and attach them to a copy of the tree.
import copy

# The parent nodes stored in `alternative_solutions` belong to `augmented_tree_copy_root`.
# Deep-copying the tree and the alternatives in ONE deepcopy call shares the memo, so every
# parent reference is remapped onto the copied tree. Copying the tree alone would attach the
# branches to `augmented_tree_copy_root` instead of the copy, and re-running this cell would
# keep appending duplicate branches to it.
augmented_tree_with_key_points, alternatives_for_copy = copy.deepcopy(
    (augmented_tree_copy_root, alternative_solutions)
)
attach_alternative_solutions(augmented_tree_with_key_points, alternatives_for_copy)

# TODO: verify the alternatives hang off the intended parents and that intermediate step
# code is meaningful (generate_step_by_step_code_for_alternative_solutions is a placeholder).



Attaching alternatives for Expand Around Center from step 1
  Processing alternative 1: Expand Around Center Step 1 Variation 1
  Successfully attached alternative solution Expand Around Center Step 1 Variation 1 with 4 steps

Attaching alternatives for Expand Around Center from step 2
  Processing alternative 1: Expand Around Center Step 2 Variation 1
  Successfully attached alternative solution Expand Around Center Step 2 Variation 1 with 4 steps

Attaching alternatives for Two Pointers from step 2
  No valid alternatives found.

Attaching alternatives for Dynamic Programming from step 2
  No valid alternatives found.

Attaching alternatives for Two Pointers from step 3
  No valid alternatives found.

Attaching alternatives for Dynamic Programming from step 3
  Processing alternative 1: Dynamic Programming Step 3 Variation 1
  Successfully attached alternative solution Dynamic Programming Step 3 Variation 1 with 3 steps

Attaching alternatives for Expand Around Center from step 4
  

In [292]:
# Dump every alternatives entry. Each `solution` here is a whole
# (parent_node, key_child, list_of_solution_dicts) tuple, not a single solution; `i` is unused.
for i, solution in enumerate(alternative_solutions):
    pprint(solution)
    print("\n\n-------------------- Next Solution --------------------\n\n")

# Render the tree with the attached alternatives twice: once with code and metadata...
visualize_tree_with_graphviz(
    augmented_tree_with_key_points,
    output_file='z_detailed_tree_copy_key_points_detailed',
    show_code=True,
    show_metadata=True,
)

# ...and once as a compact structure-only view.
visualize_tree_with_graphviz(
    augmented_tree_with_key_points,
    output_file='z_detailed_tree_copy_key_points_simple',
    show_code=False,
    show_metadata=False,
)

(<__main__.WizardNode object at 0x119dd3810>,
 <__main__.WizardNode object at 0x119dd1090>,
 [{'approach': 'Expand Around Center Step 1 Variation 1',
   'code': 'class Solution:\n'
           '    def longestPalindrome(self, s: str) -> str:\n'
           '        def expand_around_center(s, left, right):\n'
           '            while left >= 0 and right < len(s) and s[left] == '
           's[right]:\n'
           '                left -= 1\n'
           '                right += 1\n'
           '            return s[left + 1:right]\n'
           '\n'
           '        longest_palindrome = ""\n'
           '        for i in range(len(s)):\n'
           '            palindrome1 = expand_around_center(s, i, i)  # odd '
           'length palindrome\n'
           '            palindrome2 = expand_around_center(s, i, i + 1)  # '
           'even length palindrome\n'
           '            if len(palindrome1) > len(longest_palindrome):\n'
           '                longest_palindrome

'z_detailed_tree_copy_key_points_simple.svg'

In [209]:
# Re-print each validated solution's steps and code for inspection.
for solution in valid_solutions:
    print(f"Solution {solution['approach']}:\n")
    # enumerate() numbers repeated step texts correctly; list.index() would reuse the first match.
    for step_num, step in enumerate(solution['steps'], start=1):
        print(f"Step {step_num}: {step}")
    print(f"\n{solution['code']}")
    print("\n\n-------------------- Next Solution --------------------\n\n")

Solution Expand Around Center:

Step 1: Initialize two pointers, left and right, to the center of the palindrome.
Step 2: Expand the pointers outward, checking if the characters at the pointers are the same and update the maximum length and the longest palindromic substring.
Step 3: Repeat step 2 until the characters at the pointers are not the same or the pointers are out of the string bounds.
Step 4: Repeat steps 1-3 for all possible centers of palindrome in the string.

class Solution:
    def longestPalindrome(self, s: str) -> str:
        def expand_around_center(s, left, right):
            while left >= 0 and right < len(s) and s[left] == s[right]:
                left -= 1
                right += 1
            return s[left + 1:right]
        
        longest = ""
        for i in range(len(s)):
            palindrome1 = expand_around_center(s, i, i)  # odd length
            palindrome2 = expand_around_center(s, i, i + 1)  # even length
            if len(palindrome1) > len(l

In [None]:
# Sanity check: how many generated solutions passed all test cases.
print(f"k_solutions: {len(k_solutions)}")
print(f"valid_solutions: {len(valid_solutions)}")

# Basic visualization of the baseline (un-augmented) tree
visualize_tree_with_graphviz(solution_tree_root)

# Detailed visualization with all information (code and metadata shown)
visualize_tree_with_graphviz(
    solution_tree_root,
    output_file='z_detailed_tree',
    show_code=True,
    show_metadata=True,
)


k_solutions: 4
valid_solutions: 3


'z_detailed_tree.svg'

In [266]:
# Visualize the copied tree that was passed to augment_solution_tree
# (expected to contain the common-error branches if augmentation mutates it in place -- confirm).

visualize_tree_with_graphviz(
    solution_tree_copy_root,
    output_file='z_detailed_tree_copy_errors',
    show_code=True,
    show_metadata=True,
)

'z_detailed_tree_copy_errors.svg'

In [143]:
# Text-mode rendering of the baseline tree, with ANSI colours, using visualize_tree_advanced.
print("\nAdvanced Tree Visualization:")
visualize_tree_advanced(
    solution_tree_root,
    max_code_preview=100,
    show_code=True,
    show_metadata=True,
    depth_limit=None,  # Set a number to stop descending below that depth
    use_colors=True    # Set to False if terminal doesn't support ANSI colors
)


Advanced Tree Visualization:
🔍 [94m[1m[ROOT] Problem analysis for longestPalindrome[0m
   └─ [94mID:[0m longestPalindrome_root
   ├───
  ✅ [92m[1m[Step 1] Initialize two pointers, left and right, to the center of the palind...[0m
     └─ [93mCode:[0m ``` class Solution:     def longestPalindrome(self, s: str) -> str:         def expand_around_center...
     └─ [94mID:[0m longestPalindrome_expand_around_center_step_1
     └─────
    ✅ [92m[1m[Step 2] Expand the pointers outward, checking if the characters at the point...[0m
       └─ [93mCode:[0m class Solution:     def longestPalindrome(self, s: str) -> str:         def expand_around_center(s, ...
       └─ [94mID:[0m longestPalindrome_expand_around_center_step_2
       └─────
      ✅ [92m[1m[Step 3] Repeat step 2 until the characters at the pointers are not the same ...[0m
         └─ [93mCode:[0m class Solution:     def longestPalindrome(self, s: str) -> str:         def expand_around_center(s, ...
         └

In [None]:
# Print the raw code of every generated solution (validated or not).
for solution in k_solutions:
    print(solution['code'])
    print("\n\n\n---------------------\n\n\n")

# Issue --> terminal nodes have a refactored version of the validated code
# (compare the tree's terminal nodes against the code printed above).

# Print each solution's validation status and its numbered steps.
for solution in k_solutions:
    print(f"Solution {solution['approach']}:")
    print(f"Validated: {solution['is_validated']}\n")
    # enumerate() numbers repeated step texts correctly; list.index() would reuse the first match.
    for step_num, step in enumerate(solution['steps'], start=1):
        print(f"Step {step_num}: {step}")
    print("\n\n---------------------\n\n")

class Solution:
    def longestPalindrome(self, s: str) -> str:
        def expand_around_center(left, right):
            while left >= 0 and right < len(s) and s[left] == s[right]:
                left -= 1
                right += 1
            return s[left + 1:right]
        
        longest = ""
        for i in range(len(s)):
            palindrome1 = expand_around_center(i, i)  # for odd length palindrome
            palindrome2 = expand_around_center(i, i + 1)  # for even length palindrome
            if len(palindrome1) > len(longest):
                longest = palindrome1
            if len(palindrome2) > len(longest):
                longest = palindrome2
        return longest



---------------------



class Solution:
    def longestPalindrome(self, s: str) -> str:
        n = len(s)
        dp = [[False]*n for _ in range(n)]
        start, max_len = 0, 1
        
        for i in range(n-1, -1, -1):
            for j in range(i, n):
                if s[i] == s[j] and (