In [98]:
import ast
from dataclasses import dataclass
from typing import List, Tuple, Set, Dict
from collections import defaultdict



In [99]:
import ast
from dataclasses import dataclass
from typing import List, Set, Dict, Tuple

@dataclass(frozen=True)
class Operation:
    """Represents an incremental operation in the function composition"""
    input_expr: str
    output_expr: str
    function_str: str
    depth: int

    def __str__(self):
        return f"{self.input_expr} -> {self.output_expr}"

class FunctionDecomposer:
    def __init__(self, module_content: str):
        self.operations: Set[Operation] = set()
        self.variable_map: Dict[str, str] = {}
        self.counter = 0
        self.function_names = self._extract_function_names(module_content)
    
    def _extract_function_names(self, module_content: str) -> Set[str]:
        """Extract all function names defined in the module"""
        tree = ast.parse(module_content)
        return {node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)}
    
    def get_fresh_variable(self) -> str:
        """Generate a fresh variable name"""
        self.counter += 1
        return f"temp_{self.counter}"
    
    def is_function_ref(self, name: str) -> bool:
        """Check if a name refers to a function defined in the module"""
        return name in self.function_names
    
    def extract_operations(self, expr_str: str) -> List[Operation]:
        """Extract all incremental operations from a function expression"""
        tree = ast.parse(expr_str).body[0].value
        self.operations.clear()
        self.variable_map.clear()
        self.counter = 0
        self._process_node(tree, "I", 0)
        return sorted(self.operations, key=lambda x: (x.depth, x.input_expr))
    
    def _process_node(self, node: ast.AST, input_var: str, depth: int) -> Tuple[str, bool]:
        """
        Process an AST node and extract operations.
        Returns (expression_string, is_function_reference)
        """
        if isinstance(node, ast.Name):
            is_func = self.is_function_ref(node.id)
            return node.id, is_func
        
        if isinstance(node, ast.Call):
            func_name, _ = self._process_node(node.func, input_var, depth)
            
            # Process arguments
            processed_args = []
            for arg in node.args:
                arg_expr, is_func_ref = self._process_node(arg, input_var, depth + 1)
                
                # Only create variables for non-function, non-constant values
                if not (is_func_ref or arg_expr.isupper() or arg_expr == "I"):
                    var_name = self.get_fresh_variable()
                    self.variable_map[arg_expr] = var_name
                    # Create operation for the intermediate result
                    self.operations.add(Operation(
                        input_expr="I",
                        output_expr=arg_expr,
                        function_str=f"lambda x: {arg_expr}",
                        depth=depth + 1
                    ))
                    processed_args.append(var_name)
                else:
                    processed_args.append(arg_expr)
            
            # Create the function call string
            func_call = f"{func_name}({', '.join(processed_args)})"
            
            # Create operation for this function call
            self.operations.add(Operation(
                input_expr="I",
                output_expr=func_call,
                function_str=f"lambda x: {func_call}",
                depth=depth
            ))
            
            return func_call, False  # Function call is not a function reference
        
        return input_var, False
    def get_all_granularities(self) -> List[List[Operation]]:
        """
        Generate all possible granularities of the solution by merging operations.
        Returns a list of lists, where each inner list represents one possible granularity level
        ordered from finest (original operations) to coarsest (fully merged).
        """
        # First, build a dependency graph
        G = nx.DiGraph()
        var_to_op = {}  # Maps variable names to their generating operations
        
        # Add all operations as nodes
        for op in self.operations:
            G.add_node(op)
            
            # Extract output variable if this operation creates one
            if op.output_expr in self.variable_map:
                var_to_op[self.variable_map[op.output_expr]] = op
        
        # Add edges based on variable dependencies
        for op in self.operations:
            # Find all variables used in this operation's function string
            for var_name in self.variable_map.values():
                if var_name in op.function_str:
                    # Add edge from the operation that generates var_name to this operation
                    if var_name in var_to_op:
                        G.add_edge(var_to_op[var_name], op)
        
    def merge_operations(ops_to_merge: Set[Operation]) -> Operation:
        """Merge a set of operations into a single operation"""
        # Sort operations by dependency order
        sorted_ops = list(nx.topological_sort(G.subgraph(ops_to_merge)))
        
        # Create merged function string by substituting variables
        final_expr = sorted_ops[-1].output_expr
        substitutions = {}
        
        # Build substitutions map
        for op in sorted_ops[:-1]:
            if op.output_expr in self.variable_map:
                var_name = self.variable_map[op.output_expr]
                substitutions[var_name] = op.output_expr
        
        # Apply substitutions from deepest to shallowest
        merged_expr = final_expr
        while any(var in merged_expr for var in substitutions):
            for var, expr in substitutions.items():
                merged_expr = merged_expr.replace(var, f"({expr})")
        
        return Operation(
            input_expr="I",
            output_expr=merged_expr,
            function_str=f"lambda x: {merged_expr}",
            depth=min(op.depth for op in ops_to_merge)
        )
    
        def is_valid_merge(ops_to_merge: Set[Operation]) -> bool:
            """Check if a set of operations can be validly merged"""
            # Check if subgraph is connected
            subgraph = G.subgraph(ops_to_merge)
            return nx.is_weakly_connected(subgraph)
        
        def get_valid_merges(n: int) -> List[Set[Operation]]:
            """Get all valid combinations of n operations that can be merged"""
            valid_merges = []
            for ops in combinations(self.operations, n):
                ops_set = set(ops)
                if is_valid_merge(ops_set):
                    valid_merges.append(ops_set)
            return valid_merges
        
        # Generate all possible granularities
        all_granularities = []
        
        # Start with finest granularity (original operations)
        finest = sorted(self.operations, key=lambda x: (x.depth, x.output_expr))
        all_granularities.append(finest)
        
        # Generate coarser granularities by merging operations
        for size in range(2, len(self.operations) + 1):
            valid_merges = get_valid_merges(size)
            if not valid_merges:
                continue
                
            granularity = []
            used_ops = set()
            
            # For each valid merge
            for merge_group in valid_merges:
                # Skip if any operation in this group has been used
                if any(op in used_ops for op in merge_group):
                    continue
                    
                # Merge the operations
                merged_op = merge_operations(merge_group)
                granularity.append(merged_op)
                used_ops.update(merge_group)
                
            # Add remaining operations that weren't merged
            remaining_ops = [op for op in self.operations if op not in used_ops]
            granularity.extend(remaining_ops)
            
            # Sort operations at this granularity level
            granularity.sort(key=lambda x: (x.depth, x.output_expr))
            
            # Add this granularity level if it's different from previous ones
            if granularity and granularity not in all_granularities:
                all_granularities.append(granularity)
        
        return all_granularities

def generate_granular_problems(expr_str: str, module_content: str) -> Dict[str, Dict[str, str]]:
    """
    Generate problems at all possible granularities.
    Returns a dictionary mapping granularity level to a dictionary of problems.
    """
    decomposer = FunctionDecomposer(module_content)
    decomposer.extract_operations(expr_str)
    
    all_granularities = decomposer.get_all_granularities()
    print(all_granularities)
    problems = {}
    for i, granularity in enumerate(all_granularities):
        level_problems = {}
        for j, op in enumerate(granularity):
            problem_id = f"gran_{i}_step_{j:03d}"
            level_problems[problem_id] = op.function_str
        problems[f"level_{i}"] = level_problems
    
    return problems

def generate_incremental_problems(expr_str: str, module_content: str) -> Dict[str, str]:
    """Generate incremental problems from a function expression."""
    decomposer = FunctionDecomposer(module_content)
    operations = decomposer.extract_operations(expr_str)
    
    problems = {}
    for i, op in enumerate(operations):
        problem_id = f"inc_{i:03d}"
        problems[problem_id] = op.function_str
    
    return problems

# Example usage
# if __name__ == "__main__":
#     expr = "fill(I, THREE, insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))))))"
    
#     # Read module content
#     with open("your_functions_file.py", "r") as f:
#         module_content = f.read()
    
#     print("Original expression:", expr)
#     print("\nGenerated Incremental Problems:")
#     problems = generate_incremental_problems(expr, module_content)
#     for problem_id, function_str in sorted(problems.items()):
#         print(f"{problem_id}: {function_str}")

In [100]:
import importlib.util
module_path = "dsl.py"
spec = importlib.util.spec_from_file_location("function_module", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)


In [101]:
with open("dsl.py", "r") as f:
    module_content = f.read()
expr = "upscale(subtract(NINE, I), subtract(NINE, colorcount(I, ZERO)))"
problems = generate_incremental_problems(expr, module_content)
for problem_id, function_str in problems.items():
    print(f"{problem_id}: {function_str}")

inc_000: lambda x: upscale(temp_1, temp_3)
inc_001: lambda x: subtract(NINE, I)
inc_002: lambda x: subtract(NINE, temp_2)
inc_003: lambda x: colorcount(I, ZERO)


In [102]:
expr = "fill(fill(I, SEVEN, mapply(dneighbors, ofcolor(I, ONE))), FOUR, mapply(ineighbors, ofcolor(I, TWO)))"
problems = generate_incremental_problems(expr, module_content)
for problem_id, function_str in problems.items():
    print(f"{problem_id}: {function_str}")

inc_000: lambda x: fill(temp_3, FOUR, temp_5)
inc_001: lambda x: mapply(ineighbors, temp_4)
inc_002: lambda x: fill(I, SEVEN, temp_2)
inc_003: lambda x: mapply(dneighbors, temp_1)
inc_004: lambda x: ofcolor(I, TWO)
inc_005: lambda x: ofcolor(I, ONE)


In [103]:
expr = "fill(I, THREE, insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))))))"
problems = generate_incremental_problems(expr, module_content)
for problem_id, function_str in problems.items():
    print(f"{problem_id}: {function_str}")

inc_000: lambda x: fill(I, THREE, temp_10)
inc_001: lambda x: insert(temp_4, temp_9)
inc_002: lambda x: halve(temp_3)
inc_003: lambda x: dneighbors(temp_8)
inc_004: lambda x: halve(temp_7)
inc_005: lambda x: apply_func(temp_1, temp_2)
inc_006: lambda x: apply_func(temp_5, temp_6)
inc_007: lambda x: fork(add, first, last)
inc_008: lambda x: ofcolor(I, ONE)
inc_009: lambda x: ofcolor(I, ONE)
inc_010: lambda x: fork(add, first, last)


In [104]:
from itertools import combinations
import networkx as nx
import ast
from dataclasses import dataclass
from typing import List, Set, Dict, Tuple
from itertools import combinations
import networkx as nx

@dataclass(frozen=True)
class Operation:
    """Represents an incremental operation in the function composition"""
    input_expr: str
    output_expr: str
    function_str: str
    depth: int

    def __str__(self):
        return f"{self.input_expr} -> {self.output_expr}"

class FunctionDecomposer:
    def __init__(self, module_content: str):
        self.operations: Set[Operation] = set()
        self.variable_map: Dict[str, str] = {}
        self.counter = 0
        self.function_names = self._extract_function_names(module_content)
    
    def _extract_function_names(self, module_content: str) -> Set[str]:
        """Extract all function names defined in the module"""
        tree = ast.parse(module_content)
        return {node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)}
    
    def get_fresh_variable(self) -> str:
        """Generate a fresh variable name"""
        self.counter += 1
        return f"temp_{self.counter}"
    
    def is_function_ref(self, name: str) -> bool:
        """Check if a name refers to a function defined in the module"""
        return name in self.function_names
    
    def extract_operations(self, expr_str: str) -> List[Operation]:
        """Extract all incremental operations from a function expression"""
        tree = ast.parse(expr_str).body[0].value
        self.operations.clear()
        self.variable_map.clear()
        self.counter = 0
        self._process_node(tree, "I", 0)
        return sorted(self.operations, key=lambda x: (x.depth, x.output_expr))
    
    def _process_node(self, node: ast.AST, input_var: str, depth: int) -> Tuple[str, bool]:
        """
        Process an AST node and extract operations.
        Returns (expression_string, is_function_reference)
        """
        if isinstance(node, ast.Name):
            is_func = self.is_function_ref(node.id)
            return node.id, is_func
        
        if isinstance(node, ast.Call):
            func_name, _ = self._process_node(node.func, input_var, depth)
            
            # Process arguments
            processed_args = []
            for arg in node.args:
                arg_expr, is_func_ref = self._process_node(arg, input_var, depth + 1)
                
                # Only create variables for non-function, non-constant values
                if not (is_func_ref or arg_expr.isupper() or arg_expr == "I"):
                    var_name = self.get_fresh_variable()
                    self.variable_map[arg_expr] = var_name
                    # Create operation for the intermediate result
                    self.operations.add(Operation(
                        input_expr="I",
                        output_expr=arg_expr,
                        function_str=f"lambda x: {arg_expr}",
                        depth=depth + 1
                    ))
                    processed_args.append(var_name)
                else:
                    processed_args.append(arg_expr)
            
            # Create the function call string
            func_call = f"{func_name}({', '.join(processed_args)})"
            
            # Create operation for this function call
            self.operations.add(Operation(
                input_expr="I",
                output_expr=func_call,
                function_str=f"lambda x: {func_call}",
                depth=depth
            ))
            
            return func_call, False  # Function call is not a function reference
        
        return input_var, False

    def _remove_unnecessary_parentheses(self, expr: str) -> str:
        """Remove unnecessary parentheses from expression"""
        stack = []
        indices_to_remove = set()
        last_open = -1
        
        # Find matching parentheses pairs
        for i, c in enumerate(expr):
            if c == '(':
                stack.append(i)
            elif c == ')' and stack:
                start = stack.pop()
                # Check if these parentheses are unnecessary
                inner = expr[start+1:i]
                if not inner.strip():  # Empty
                    indices_to_remove.add(start)
                    indices_to_remove.add(i)
                elif inner.count('(') == inner.count(')'):  # Balanced inner expression
                    if not any(c in inner for c in ', '): # No commas inside
                        # Check if inner contains only a simple expression
                        if not any(op in inner for op in ['+', '-', '*', '/', ' ']):
                            indices_to_remove.add(start)
                            indices_to_remove.add(i)
        
        # Rebuild string without unnecessary parentheses
        return ''.join(c for i, c in enumerate(expr) if i not in indices_to_remove)

    def _substitute_variables(self, expr: str, substitutions: Dict[str, str], seen: Set[str] = None) -> str:
        """
        Substitute variables in expression while avoiding circular dependencies.
        Returns cleaned-up expression with all possible substitutions made.
        """
        if seen is None:
            seen = set()
        
        # Find all variables that need to be substituted
        changes_made = True
        while changes_made:
            changes_made = False
            for var, sub_expr in substitutions.items():
                if var in expr and var not in seen:
                    seen.add(var)
                    # Recursively substitute in the replacement expression first
                    sub_expr = self._substitute_variables(sub_expr, substitutions, seen)
                    # Only add parentheses if the substituted expression contains operators
                    needs_parens = any(op in sub_expr for op in ['+', '-', '*', '/', ' ']) and not (sub_expr.startswith('(') and sub_expr.endswith(')'))
                    wrapped_expr = f"({sub_expr})" if needs_parens else sub_expr
                    expr = expr.replace(var, wrapped_expr)
                    changes_made = True
                    
        return self._remove_unnecessary_parentheses(expr)

    def get_all_granularities(self) -> List[List[Operation]]:
        """
        Generate all possible granularities of the solution by merging operations.
        Returns a list of lists, where each inner list represents one possible granularity level
        ordered from finest (original operations) to coarsest (fully merged).
        """
        # Build dependency graph
        G = nx.DiGraph()
        var_to_op = {}
        
        # Add all operations as nodes
        for op in self.operations:
            G.add_node(op)
            if op.output_expr in self.variable_map:
                var_to_op[self.variable_map[op.output_expr]] = op
        
        # Add edges based on variable dependencies
        for op in self.operations:
            for var_name in self.variable_map.values():
                if var_name in op.function_str:
                    if var_name in var_to_op:
                        G.add_edge(var_to_op[var_name], op)
        
        def merge_operations(ops_to_merge: Set[Operation]) -> Operation:
            """Merge a set of operations into a single operation"""
            sorted_ops = list(nx.topological_sort(G.subgraph(ops_to_merge)))
            
            # Build complete substitutions map
            substitutions = {}
            for op in sorted_ops:
                if op.output_expr in self.variable_map:
                    var_name = self.variable_map[op.output_expr]
                    substitutions[var_name] = op.output_expr
            
            # Get the final expression and substitute all variables
            final_expr = sorted_ops[-1].output_expr
            merged_expr = self._substitute_variables(final_expr, substitutions)
            
            return Operation(
                input_expr="I",
                output_expr=merged_expr,
                function_str=f"lambda x: {merged_expr}",
                depth=min(op.depth for op in ops_to_merge)
            )
        
        def merge_operations(ops_to_merge: Set[Operation]) -> Operation:
            """Merge a set of operations into a single operation"""
            # Sort operations by dependency order
            sorted_ops = list(nx.topological_sort(G.subgraph(ops_to_merge)))
            
            # Create merged function string by substituting variables
            final_expr = sorted_ops[-1].output_expr
            substitutions = {}
            
            # Build substitutions map
            for op in sorted_ops[:-1]:
                if op.output_expr in self.variable_map:
                    var_name = self.variable_map[op.output_expr]
                    substitutions[var_name] = op.output_expr
            
            # Apply substitutions from deepest to shallowest
            merged_expr = final_expr
            while any(var in merged_expr for var in substitutions):
                for var, expr in substitutions.items():
                    merged_expr = merged_expr.replace(var, f"({expr})")
            
            return Operation(
                input_expr="I",
                output_expr=merged_expr,
                function_str=f"lambda x: {merged_expr}",
                depth=min(op.depth for op in ops_to_merge)
            )
        
        def is_valid_merge(ops_to_merge: Set[Operation]) -> bool:
            """Check if a set of operations can be validly merged"""
            # Check if subgraph is connected
            subgraph = G.subgraph(ops_to_merge)
            return nx.is_weakly_connected(subgraph)
        
        def get_valid_merges(n: int) -> List[Set[Operation]]:
            """Get all valid combinations of n operations that can be merged"""
            valid_merges = []
            for ops in combinations(self.operations, n):
                ops_set = set(ops)
                if is_valid_merge(ops_set):
                    valid_merges.append(ops_set)
            return valid_merges
        
        # Generate all possible granularities
        all_granularities = []
        
        # Start with finest granularity (original operations)
        finest = sorted(self.operations, key=lambda x: (x.depth, x.output_expr))
        all_granularities.append(finest)
        
        # Generate coarser granularities by merging operations
        for size in range(2, len(self.operations) + 1):
            valid_merges = get_valid_merges(size)
            if not valid_merges:
                continue
                
            granularity = []
            used_ops = set()
            
            # For each valid merge
            for merge_group in valid_merges:
                # Skip if any operation in this group has been used
                if any(op in used_ops for op in merge_group):
                    continue
                    
                # Merge the operations
                merged_op = merge_operations(merge_group)
                granularity.append(merged_op)
                used_ops.update(merge_group)
                
            # Add remaining operations that weren't merged
            remaining_ops = [op for op in self.operations if op not in used_ops]
            granularity.extend(remaining_ops)
            
            # Sort operations at this granularity level
            granularity.sort(key=lambda x: (x.depth, x.output_expr))
            
            # Add this granularity level if it's different from previous ones
            if granularity and granularity not in all_granularities:
                all_granularities.append(granularity)
        
        return all_granularities

def generate_granular_problems(expr_str: str, module_content: str) -> Dict[str, Dict[str, str]]:
    """
    Generate problems at all possible granularities.
    Returns a dictionary mapping granularity level to a dictionary of problems.
    """
    decomposer = FunctionDecomposer(module_content)
    decomposer.extract_operations(expr_str)
    
    all_granularities = decomposer.get_all_granularities()
    
    problems = {}
    for i, granularity in enumerate(all_granularities):
        level_problems = {}
        for j, op in enumerate(granularity):
            problem_id = f"gran_{i}_step_{j:03d}"
            level_problems[problem_id] = op.function_str
        problems[f"level_{i}"] = level_problems
    
    return problems


In [105]:
expr = "fill(I, THREE, insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))))))"
generate_granular_problems(expr, module_content)

{'level_0': {'gran_0_step_000': 'lambda x: fill(I, THREE, temp_10)',
  'gran_0_step_001': 'lambda x: insert(temp_4, temp_9)',
  'gran_0_step_002': 'lambda x: dneighbors(temp_8)',
  'gran_0_step_003': 'lambda x: halve(temp_3)',
  'gran_0_step_004': 'lambda x: apply_func(temp_1, temp_2)',
  'gran_0_step_005': 'lambda x: halve(temp_7)',
  'gran_0_step_006': 'lambda x: apply_func(temp_5, temp_6)',
  'gran_0_step_007': 'lambda x: fork(add, first, last)',
  'gran_0_step_008': 'lambda x: ofcolor(I, ONE)',
  'gran_0_step_009': 'lambda x: fork(add, first, last)',
  'gran_0_step_010': 'lambda x: ofcolor(I, ONE)'},
 'level_1': {'gran_1_step_000': 'lambda x: fill(I, THREE, temp_10)',
  'gran_1_step_001': 'lambda x: insert((halve(temp_3)), temp_9)',
  'gran_1_step_002': 'lambda x: dneighbors(temp_8)',
  'gran_1_step_003': 'lambda x: apply_func(temp_1, temp_2)',
  'gran_1_step_004': 'lambda x: halve((apply_func(temp_5, temp_6)))',
  'gran_1_step_005': 'lambda x: fork(add, first, last)',
  'gran_1_st

In [106]:
import ast
from dataclasses import dataclass
from typing import List, Dict, Set

@dataclass
class Step:
    expression: str
    result_var: str

class Solution:
    def __init__(self):
        self.steps: List[Step] = []
        self.temp_counter = 0
        self.var_map: Dict[str, str] = {}
        
    def add_step(self, expr: str) -> str:
        if expr in self.var_map:
            return self.var_map[expr]
        result_var = f"temp_{self.temp_counter}"
        self.temp_counter += 1
        self.steps.append(Step(expr, result_var))
        self.var_map[expr] = result_var
        return result_var

    def __str__(self) -> str:
        return "\n".join(f"step {i}: {step.result_var} = {step.expression}" 
                        for i, step in enumerate(self.steps))

def get_expression_str(node: ast.AST) -> str:
    """Convert an AST node back to its string representation."""
    if isinstance(node, ast.Name):
        return node.id
    elif isinstance(node, ast.Call):
        args = [get_expression_str(arg) for arg in node.args]
        return f"{node.func.id}({', '.join(args)})"
    else:
        raise ValueError(f"Unexpected node type: {type(node)}")

def collect_calls(node: ast.AST) -> List[ast.Call]:
    """Collect all Call nodes in order from innermost to outermost."""
    calls = []
    
    def visit(node):
        if isinstance(node, ast.Call):
            # First visit all arguments
            for arg in node.args:
                visit(arg)
            # Then add this call
            calls.append(node)
        return node
        
    visit(node)
    return calls

def substitute_expressions(expr: str, expr_map: Dict[str, str]) -> str:
    """Repeatedly substitute expressions until no more substitutions can be made."""
    prev_expr = None
    current_expr = expr
    
    while prev_expr != current_expr:
        prev_expr = current_expr
        # Sort expressions by length (longest first) to avoid partial replacements
        for old_expr, var in sorted(expr_map.items(), key=lambda x: len(x[0]), reverse=True):
            if old_expr in current_expr and old_expr != current_expr:
                current_expr = current_expr.replace(old_expr, var)
    
    return current_expr

def generate_solutions(expr: str) -> List[Solution]:
    # Parse the expression
    tree = ast.parse(expr, mode='eval')
    
    # Get all function calls from innermost to outermost
    calls = collect_calls(tree.body)
    
    # Create solutions
    solutions = []
    
    # Solution 0: Just the complete expression
    solution = Solution()
    solution.add_step(get_expression_str(tree.body))
    solutions.append(solution)
    
    # Create increasingly granular solutions
    for i in range(len(calls) - 1, -1, -1):
        solution = Solution()
        expr_map = {}
        
        # Process all calls from index i onwards
        for call in calls[i:]:
            # Convert the call to string form and substitute all known expressions
            current_expr = get_expression_str(call)
            current_expr = substitute_expressions(current_expr, expr_map)
            
            # Add this step and remember the mapping
            temp_var = solution.add_step(current_expr)
            expr_map[get_expression_str(call)] = temp_var
            
            # Update any previous steps that could use this new variable
            for step in solution.steps[:-1]:  # Skip the step we just added
                step.expression = substitute_expressions(step.expression, expr_map)
        
        if len(solution.steps) > 1:  # Only add if we've broken something down
            solutions.append(solution)
    
    return solutions

def print_all_solutions(expr: str):
    solutions = generate_solutions(expr)
    for i, solution in enumerate(solutions):
        print(f"\nSolution {i} ({'Most' if i == len(solutions)-1 else 'Least' if i == 0 else 'More'} Granular):")
        print(solution)

In [107]:
expr = "upscale(subtract(NINE, I), subtract(NINE, colorcount(I, ZERO)))"
print_all_solutions(expr)


Solution 0 (Least Granular):
step 0: temp_0 = upscale(subtract(NINE, I), subtract(NINE, colorcount(I, ZERO)))

Solution 1 (More Granular):
step 0: temp_0 = subtract(NINE, colorcount(I, ZERO))
step 1: temp_1 = upscale(subtract(NINE, I), temp_0)

Solution 2 (More Granular):
step 0: temp_0 = colorcount(I, ZERO)
step 1: temp_1 = subtract(NINE, temp_0)
step 2: temp_2 = upscale(subtract(NINE, I), temp_1)

Solution 3 (Most Granular):
step 0: temp_0 = subtract(NINE, I)
step 1: temp_1 = colorcount(I, ZERO)
step 2: temp_2 = subtract(NINE, temp_1)
step 3: temp_3 = upscale(temp_0, temp_2)


In [108]:
expr = "fill(I, THREE, insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))))))"
print_all_solutions(expr)


Solution 0 (Least Granular):
step 0: temp_0 = fill(I, THREE, insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))))))

Solution 1 (More Granular):
step 0: temp_0 = insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE)))))
step 1: temp_1 = fill(I, THREE, temp_0)

Solution 2 (More Granular):
step 0: temp_0 = dneighbors(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))))
step 1: temp_1 = insert(halve(apply_func(fork(add, first, last), ofcolor(I, ONE))), temp_0)
step 2: temp_2 = fill(I, THREE, temp_1)

Solution 3 (More Granular):
step 0: temp_0 = halve(apply_func(fork(add, first, last), ofcolor(I, ONE)))
step 1: temp_1 = dneighbors(temp_0)
step 2: temp_2 = insert(temp_0, temp_1)
step 3: temp_3 = fill(I, THREE, temp_2)

Solution 4 (More Granular):
step 0: temp_0 = apply_func(fork(add, first, last), ofcolor(I, ONE))
step 

In [109]:
expr = "canvas(branch(greater(size(sfilter(objects(I, F, T, T), compose(lbind(contained, TWO), palette))), ONE), ZERO, EIGHT), UNITY)"
# print_all_solutions(expr)

In [110]:
expr = "paint(I, shift(asobject(replace(subgrid(first(objects(I, T, F, T)), I), ZERO, TWO)), ulcorner(first(objects(I, T, F, T)))))"
# print_all_solutions(expr)

In [111]:
expr = """
fill(fill(I, branch(equality(mostcolor(toobject(neighbors(center
        (first(sizefilter(objects(I, T, F, F), ONE)))), I)), first(apply(
        color, difference(objects(I, T, F, F), sizefilter(objects(I, T, F,
        F), ONE))))), color(first(sizefilter(objects(I, T, F, F), ONE))),
        other(apply(color, sizefilter(objects(I, T, F, F), ONE)), color(
        first(sizefilter(objects(I, T, F, F), ONE))))), intersection(
        ofcolor(I, first(apply(color, difference(objects(I, T, F, F),
        sizefilter(objects(I, T, F, F), ONE))))), mapply(compose(fork(
        combine, fork(combine, rbind(shoot, UNITY), rbind(shoot, NEG_UNITY)
        ), fork(combine, rbind(shoot, DOWN_LEFT), rbind(shoot, UP_RIGHT))),
        center), sizefilter(objects(I, T, F, F), ONE)))), branch(equality(
        mostcolor(toobject(neighbors(center(first(sizefilter(objects(I, T,
        F, F), ONE)))), I)), first(apply(color, difference(objects(I, T, F,
        F), sizefilter(objects(I, T, F, F), ONE))))), other(apply(color,
        sizefilter(objects(I, T, F, F), ONE)), color(first(sizefilter(
        objects(I, T, F, F), ONE)))), color(first(sizefilter(objects(I, T,
        F, F), ONE)))), intersection(ofcolor(I, last(apply(color,
        difference(objects(I, T, F, F), sizefilter(objects(I, T, F, F), ONE
        ))))), mapply(compose(fork(combine, fork(combine, rbind(shoot,
        UNITY), rbind(shoot, NEG_UNITY)), fork(combine, rbind(shoot,
        DOWN_LEFT), rbind(shoot, UP_RIGHT))), center), sizefilter(objects(I,
        T, F, F), ONE))))
"""
# print_all_solutions(expr)

In [112]:
import ast
from typing import Dict, Any
import inspect
import dsl
from pprint import pformat

class StepEvaluator:
    def __init__(self):
        # Automatically collect all functions from the DSL module
        self.functions = {
            name: func for name, func in inspect.getmembers(dsl, inspect.isfunction)
            if not name.startswith('_')  # Skip private functions
        }
        
        # All constants remain the same as before
        self.constants = {
            'NEG_ONE': -1, 'NEG_TWO': -2, 'ZERO': 0, 'ONE': 1, 'TWO': 2,
            'THREE': 3, 'FOUR': 4, 'FIVE': 5, 'SIX': 6, 'SEVEN': 7,
            'EIGHT': 8, 'NINE': 9, 'DOWN': (1, 0), 'RIGHT': (0, 1),
            'UP': (-1, 0), 'LEFT': (0, -1), 'ORIGIN': (0, 0),
            'UNITY': (1, 1), 'NEG_UNITY': (-1, -1), 'UP_RIGHT': (-1, 1),
            'DOWN_LEFT': (1, -1), 'ZERO_BY_TWO': (0, 2),
            'TWO_BY_ZERO': (2, 0), 'TWO_BY_TWO': (2, 2),
            'THREE_BY_THREE': (3, 3), 'T': True, 'F': False
        }
        self.temp_values: Dict[str, Any] = {}

    def evaluate_name(self, node: ast.Name) -> Any:
        if node.id in self.temp_values:
            return self.temp_values[node.id]
        elif node.id == 'I':
            return self.input_grid
        elif node.id in self.constants:
            return self.constants[node.id]
        elif node.id in self.functions:
            return self.functions[node.id]
        else:
            raise ValueError(f"Unknown variable or function: {node.id}")
        
    def evaluate_call(self, node: ast.Call) -> Any:
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
            if func_name not in self.functions:
                raise ValueError(f"Unknown function: {func_name}")
            func = self.functions[func_name]
        else:
            raise ValueError(f"Unsupported function reference type: {type(node.func)}")
            
        args = []
        for arg in node.args:
            if isinstance(arg, ast.Name):
                args.append(self.evaluate_name(arg))
            elif isinstance(arg, ast.Call):
                args.append(self.evaluate_call(arg))
            else:
                raise ValueError(f"Unsupported argument type: {type(arg)}")
                
        return func(*args)
    
    def evaluate_step(self, step_expr: str, result_var: str) -> Any:
        tree = ast.parse(step_expr, mode='eval')
        result = self.evaluate_call(tree.body)
        self.temp_values[result_var] = result
        return result

    def evaluate_solution(self, solution, input_grid) -> Dict[str, Any]:
        self.input_grid = input_grid
        self.temp_values.clear()
        results = {}
        for step in solution.steps:
            result = self.evaluate_step(step.expression, step.result_var)
            results[step.result_var] = result
        return results

def print_grid(grid):
    if not isinstance(grid, tuple) or not grid:
        print(grid)
        return
        
    if not isinstance(grid[0], tuple):
        print(grid)
        return
        
    print(f"Grid {len(grid)}x{len(grid[0])}:")
    for row in grid:
        print(" ".join(str(x) for x in row))
    print()

def format_collection(collection, indent=0):
    """Format a collection (set, frozenset, tuple) with readable output"""
    if len(collection) == 0:
        return "empty"
        
    items = list(collection)
    try:
        items.sort()
    except TypeError:
        pass
    
    formatted_items = []
    indent_str = " " * (indent + 2)
    
    for item in items:
        
        if isinstance(item, frozenset):
            formatted_items.append(format_collection(item, indent + 2))
        else:
            formatted_items.append(str(item))
            
    items_str = f",\n{indent_str}".join(formatted_items)
    return "{\n" + indent_str + items_str + "\n" + (" " * indent) + "}"

def format_result(result):
    if isinstance(result, tuple):
        if len(result) == 0:
            return "()"
        if isinstance(result[0], tuple):
            return f"Grid {len(result)}x{len(result[0])}"
        if len(result) == 2 and all(isinstance(x, int) for x in result):
            return f"Point/Vector {result}"
        return f"Tuple {format_collection(result)}"
    elif isinstance(result, (set, frozenset)):
        size_str = f"FrozenSet of size {len(result)}"
        elements_str = format_collection(result)
        return f"{size_str}\nElements: {elements_str}"
    elif callable(result):
        return f"Function {result.__name__}"
    return str(result)

def evaluate_all_solutions(solutions, input_grid):
    evaluator = StepEvaluator()
    all_results = []
    
    for i, solution in enumerate(solutions):
        print(f"\nSolution {i} ({'Most' if i == len(solutions)-1 else 'Least' if i == 0 else 'More'} Granular):")
        print(solution)
        print("\nIntermediate Results:")
        
        results = evaluator.evaluate_solution(solution, input_grid)
        for step in solution.steps:
            result = results[step.result_var]
            print(f"\n{step.result_var} = {step.expression}")
            print(f"Result type: {type(result).__name__}")
            print("Value:", end=" ")
            if isinstance(result, tuple) and result and isinstance(result[0], tuple):
                print_grid(result)
            else:
                print("\n" + format_result(result))
                
        all_results.append(results)
    
    return all_results

In [113]:
import ast
from typing import Dict, Any, List, Tuple
import inspect
import dsl

class StepEvaluator:
    def __init__(self):
        self.functions = {
            name: func for name, func in inspect.getmembers(dsl, inspect.isfunction)
            if not name.startswith('_')
        }
        self.constants = {
            'NEG_ONE': -1, 'NEG_TWO': -2, 'ZERO': 0, 'ONE': 1, 'TWO': 2,
            'THREE': 3, 'FOUR': 4, 'FIVE': 5, 'SIX': 6, 'SEVEN': 7,
            'EIGHT': 8, 'NINE': 9, 'DOWN': (1, 0), 'RIGHT': (0, 1),
            'UP': (-1, 0), 'LEFT': (0, -1), 'ORIGIN': (0, 0),
            'UNITY': (1, 1), 'NEG_UNITY': (-1, -1), 'UP_RIGHT': (-1, 1),
            'DOWN_LEFT': (1, -1), 'ZERO_BY_TWO': (0, 2),
            'TWO_BY_ZERO': (2, 0), 'TWO_BY_TWO': (2, 2),
            'THREE_BY_THREE': (3, 3), 'T': True, 'F': False
        }
        self.temp_values: Dict[str, Any] = {}
        self.current_step_inputs: Dict[str, Any] = {}
        
    def evaluate_name(self, node: ast.Name) -> Any:
        if node.id in self.temp_values:
            self.current_step_inputs[node.id] = self.temp_values[node.id]
            return self.temp_values[node.id]
        elif node.id == 'I':
            self.current_step_inputs['I'] = self.input_grid
            return self.input_grid
        elif node.id in self.constants:
            return self.constants[node.id]
        elif node.id in self.functions:
            return self.functions[node.id]
        else:
            raise ValueError(f"Unknown variable or function: {node.id}")
        
    def evaluate_call(self, node: ast.Call) -> Any:
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
            if func_name not in self.functions:
                raise ValueError(f"Unknown function: {func_name}")
            func = self.functions[func_name]
        else:
            raise ValueError(f"Unsupported function reference type: {type(node.func)}")
            
        args = []
        for arg in node.args:
            if isinstance(arg, ast.Name):
                args.append(self.evaluate_name(arg))
            elif isinstance(arg, ast.Call):
                args.append(self.evaluate_call(arg))
            else:
                raise ValueError(f"Unsupported argument type: {type(arg)}")
                
        return func(*args)
    
    def evaluate_step(self, step_expr: str, result_var: str) -> Tuple[Any, Dict[str, Any]]:
        """Evaluate a step and return both result and inputs used"""
        self.current_step_inputs = {}
        tree = ast.parse(step_expr, mode='eval')
        result = self.evaluate_call(tree.body)
        self.temp_values[result_var] = result
        return result, dict(self.current_step_inputs)

    def evaluate_solution(self, solution, input_grid) -> List[Tuple[str, str, Any, Dict[str, Any]]]:
        """Evaluate solution and return list of (var, expr, result, inputs) tuples"""
        self.input_grid = input_grid
        self.temp_values.clear()
        
        step_results = []
        for step in solution.steps:
            result, inputs = self.evaluate_step(step.expression, step.result_var)
            step_results.append((step.result_var, step.expression, result, inputs))
            
        return step_results

def format_collection(collection, indent=0):
    """Format a collection (set, frozenset, tuple) with readable output"""
    if len(collection) == 0:
        return "empty"
        
    items = list(collection)
    try:
        items.sort()
    except TypeError:
        pass
    
    formatted_items = []
    indent_str = " " * (indent + 2)
    
    for item in items:
        
        if isinstance(item, frozenset):
            formatted_items.append(format_collection(item, indent + 2))
        else:
            formatted_items.append(str(item))
            
    items_str = f",\n{indent_str}".join(formatted_items)
    return "{\n" + indent_str + items_str + "\n" + (" " * indent) + "}"

def format_result(result):
    if isinstance(result, tuple):
        if len(result) == 0:
            return "()"
        if isinstance(result[0], tuple):
            return f"Grid {len(result)}x{len(result[0])}"
        if len(result) == 2 and all(isinstance(x, int) for x in result):
            return f"Point/Vector {result}"
        return f"Tuple {format_collection(result)}"
    elif isinstance(result, (set, frozenset)):
        unique_elements = set(result)
        size_str = f"FrozenSet of size {len(unique_elements)}"
        elements_str = format_collection(unique_elements)
        return f"{size_str}\nElements: {elements_str}"
    elif callable(result):
        return f"Function {result.__name__}"
    return str(result)

def print_grid(grid):
    if not isinstance(grid, tuple) or not grid:
        print(grid)
        return
    
    if not isinstance(grid[0], tuple):
        print(grid)
        return
    
    print(f"Grid {len(grid)}x{len(grid[0])}:")
    for row in grid:
        print(" ".join(str(x) for x in row))
    print()

def format_input_value(value):
    """Format an input value for display"""
    if isinstance(value, tuple) and value and isinstance(value[0], tuple):
        return f"Grid {len(value)}x{len(value[0])}"
    return format_result(value)

def evaluate_all_solutions(solutions, input_grid):
    evaluator = StepEvaluator()
    all_results = []
    
    for i, solution in enumerate(solutions):
        print(f"\nSolution {i} ({'Most' if i == len(solutions)-1 else 'Least' if i == 0 else 'More'} Granular):")
        print(solution)
        print("\nIntermediate Results:")
        
        step_results = evaluator.evaluate_solution(solution, input_grid)
        for var, expr, result, inputs in step_results:
            print(f"\n{var} = {expr}")
            if inputs:
                print("Inputs:")
                for input_var, input_val in inputs.items():
                    print(f"  {input_var}: {format_input_value(input_val)}")
            print(f"Result type: {type(result).__name__}")
            print("Value:", end=" ")
            if isinstance(result, tuple) and result and isinstance(result[0], tuple):
                print_grid(result)
            else:
                print("\n" + format_result(result))
                
        all_results.append(step_results)
    
    return all_results

In [114]:
# Example input grid
I = (
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 4, 0, 4, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 4, 0, 4, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
)

expr = """
replace(fill(underfill(I, NEG_ONE, mfilter(prapply(connect,
        ofcolor(I, FOUR), ofcolor(I, FOUR)), fork(either, vline, hline))),
        TWO, mapply(compose(backdrop, inbox), objects(underfill(I, NEG_ONE,
        mfilter(prapply(connect, ofcolor(I, FOUR), ofcolor(I, FOUR)), fork(
        either, vline, hline))), F, F, T))), NEG_ONE, ZERO)
"""
solutions = generate_solutions(expr)
# results = evaluate_all_solutions(solutions, I)

In [115]:
evaluator = StepEvaluator()
# evaluator.evaluate_solution(solutions[2], I)

In [116]:
list(frozenset([1,1,1]))

[1]

In [117]:
import json

In [118]:
with open("incremental_arc_dataset.json", 'r') as fp:
    incremental = json.load(fp)

In [119]:
# incremental['train'][0]

In [120]:
from dataclasses import dataclass
from typing import List, Dict, Set
import ast

@dataclass
class Step:
    expression: str
    result_var: str
    dependencies: Set[str] = None

    def __post_init__(self):
        if self.dependencies is None:
            self.dependencies = set()

class Solution:
    def __init__(self):
        self.steps: List[Step] = []
        self.temp_counter = 0
        self.var_map: Dict[str, str] = {}
        
    def add_step(self, expr: str, dependencies: Set[str] = None) -> str:
        if expr in self.var_map:
            return self.var_map[expr]
        result_var = f"temp_{self.temp_counter}"
        self.temp_counter += 1
        step = Step(expr, result_var, dependencies or set())
        self.steps.append(step)
        self.var_map[expr] = result_var
        return result_var

    def __str__(self) -> str:
        return "\n".join(f"step {i}: {step.result_var} = {step.expression}" 
                        for i, step in enumerate(self.steps))

    def reorder_steps(self):
        """Reorder steps based on dependencies"""
        dependency_graph = {step.result_var: step.dependencies for step in self.steps}
        visited = set()
        ordered_steps = []

        def visit(var):
            if var in visited:
                return
            visited.add(var)
            for dep in dependency_graph.get(var, []):
                visit(dep)
            step = next(s for s in self.steps if s.result_var == var)
            ordered_steps.append(step)

        for step in self.steps:
            visit(step.result_var)

        self.steps = ordered_steps

def merge_function_steps(step_results):
    """
    Merge steps where intermediate results are functions with the steps that use them,
    substituting function expressions for their variable references.
    Returns: List[Tuple[str, str, Any, Dict[str, Any]]] - merged step results
    """
    # Track which steps produce functions
    function_steps = {
        var: (expr, result, inputs) 
        for var, expr, result, inputs in step_results 
        if callable(result)
    }
    
    # Track which steps should be removed (merged into others)
    steps_to_remove = set()
    merged_steps = []
    
    for i, (var, expr, result, inputs) in enumerate(step_results):
        # Skip if this step has been marked for removal
        if var in steps_to_remove:
            continue
            
        # Check if this step uses any function results
        used_functions = {
            input_var: function_steps[input_var]
            for input_var in inputs.keys()
            if input_var in function_steps
        }
        
        if used_functions:
            # Start with current expression
            merged_expr = expr
            merged_inputs = dict(inputs)
            
            # Replace each function variable with its expression
            for func_var, (func_expr, _, func_inputs) in used_functions.items():
                # Remove the function variable from inputs
                merged_inputs.pop(func_var)
                # Add the function's inputs
                merged_inputs.update(func_inputs)
                # Mark function step for removal
                steps_to_remove.add(func_var)
                # Replace the function variable with its expression
                merged_expr = merged_expr.replace(func_var, f"{func_expr}")
            
            merged_steps.append((var, merged_expr, result, merged_inputs))
        else:
            # If this step produces a function but isn't used by any remaining steps, keep it
            if var in function_steps and any(
                var in later_inputs 
                for _, _, _, later_inputs in step_results[i+1:]
            ):
                continue
            merged_steps.append((var, expr, result, inputs))
    
    return merged_steps

def evaluate_all_solutions(solutions, input_grid):
    """Modified evaluate_all_solutions to include step merging with substitution"""
    evaluator = StepEvaluator()
    all_results = []
    
    for i, solution in enumerate(solutions):
        print(f"\nSolution {i} ({'Most' if i == len(solutions)-1 else 'Least' if i == 0 else 'More'} Granular):")
        print(solution)
        print("\nIntermediate Results:")
        
        # Get original step results
        step_results = evaluator.evaluate_solution(solution, input_grid)
        
        # Merge steps with substitution
        merged_steps = merge_function_steps(step_results)
        
        # Print merged results
        for var, expr, result, inputs in merged_steps:
            print(f"\n{var} = {expr}")
            if inputs:
                print("Inputs:")
                for input_var, input_val in inputs.items():
                    print(f"  {input_var}: {format_input_value(input_val)}")
            print(f"Result type: {type(result).__name__}")
            print("Value:", end=" ")
            if isinstance(result, tuple) and result and isinstance(result[0], tuple):
                print_grid(result)
            else:
                print("\n" + format_result(result))
                
        all_results.append(merged_steps)
    
    return all_results

In [121]:
# Example input grid
I = (
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 4, 0, 4, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 4, 0, 4, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
)

expr = """
replace(fill(underfill(I, NEG_ONE, mfilter(prapply(connect,
        ofcolor(I, FOUR), ofcolor(I, FOUR)), fork(either, vline, hline))),
        TWO, mapply(compose(backdrop, inbox), objects(underfill(I, NEG_ONE,
        mfilter(prapply(connect, ofcolor(I, FOUR), ofcolor(I, FOUR)), fork(
        either, vline, hline))), F, F, T))), NEG_ONE, ZERO)
"""
solutions = generate_solutions(expr)
# solutions = post_process_solutions(solutions, evaluator)
# results = evaluate_all_solutions(solutions, I)


In [122]:
evaluator = StepEvaluator()
step_results = evaluator.evaluate_solution(solutions[6], I)
merged_steps = merge_function_steps(step_results)

In [123]:
# merged_steps

In [124]:
import ast
from typing import Dict, Any, List, Tuple, Set
import inspect
import dsl
import re

class StepEvaluator:
    def __init__(self):
        self.functions = {
            name: func for name, func in inspect.getmembers(dsl, inspect.isfunction)
            if not name.startswith('_')
        }
        self.constants = {
            'NEG_ONE': -1, 'NEG_TWO': -2, 'ZERO': 0, 'ONE': 1, 'TWO': 2,
            'THREE': 3, 'FOUR': 4, 'FIVE': 5, 'SIX': 6, 'SEVEN': 7,
            'EIGHT': 8, 'NINE': 9, 'TEN': 10, 'DOWN': (1, 0), 'RIGHT': (0, 1),
            'UP': (-1, 0), 'LEFT': (0, -1), 'ORIGIN': (0, 0),
            'UNITY': (1, 1), 'NEG_UNITY': (-1, -1), 'UP_RIGHT': (-1, 1),
            'DOWN_LEFT': (1, -1), 'ZERO_BY_TWO': (0, 2),
            'TWO_BY_ZERO': (2, 0), 'TWO_BY_TWO': (2, 2),
            'THREE_BY_THREE': (3, 3), 'T': True, 'F': False
        }
        self.temp_values: Dict[str, Any] = {}
        self.current_step_inputs: Dict[str, Any] = {}
        self.pending_steps: Dict[str, str] = {}
        
    def evaluate_name(self, node: ast.Name) -> Any:
        if node.id in self.temp_values:
            self.current_step_inputs[node.id] = self.temp_values[node.id]
            return self.temp_values[node.id]
        elif node.id == 'I':
            self.current_step_inputs['I'] = self.input_grid
            return self.input_grid
        elif node.id in self.constants:
            return self.constants[node.id]
        elif node.id in self.functions:
            return self.functions[node.id]
        elif node.id in self.pending_steps:
            # Evaluate pending step first
            pending_expr = self.pending_steps[node.id]
            result = self.evaluate_expression(pending_expr)
            self.temp_values[node.id] = result
            self.current_step_inputs[node.id] = result
            return result
        else:
            raise ValueError(f"Unknown variable or function: {node.id}")
        
    def evaluate_call(self, node: ast.Call) -> Any:
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
            if func_name not in self.functions:
                raise ValueError(f"Unknown function: {func_name}")
            func = self.functions[func_name]
        else:
            raise ValueError(f"Unsupported function reference type: {type(node.func)}")
            
        args = []
        for arg in node.args:
            if isinstance(arg, ast.Name):
                args.append(self.evaluate_name(arg))
            elif isinstance(arg, ast.Call):
                args.append(self.evaluate_call(arg))
            else:
                raise ValueError(f"Unsupported argument type: {type(arg)}")
                
        return func(*args)
    
    def evaluate_expression(self, expr: str) -> Any:
        """Evaluate a single expression"""
        tree = ast.parse(expr, mode='eval')
        return self.evaluate_call(tree.body)
    
    def evaluate_step(self, step_expr: str, result_var: str) -> Tuple[Any, Dict[str, Any]]:
        """Evaluate a step and return both result and inputs used"""
        self.current_step_inputs = {}
        result = self.evaluate_expression(step_expr)
        self.temp_values[result_var] = result
        return result, dict(self.current_step_inputs)

    def evaluate_solution(self, solution, input_grid) -> List[Tuple[str, str, Any, Dict[str, Any]]]:
        """Evaluate all steps in a solution"""
        self.input_grid = input_grid
        self.temp_values.clear()
        self.pending_steps.clear()
        
        # First collect all steps
        for step in solution.steps:
            self.pending_steps[step.result_var] = step.expression
            
        step_results = []
        # Then evaluate them in order
        for step in solution.steps:
            try:
                result, inputs = self.evaluate_step(step.expression, step.result_var)
                step_results.append((step.result_var, step.expression, result, inputs))
            except Exception as e:
                print(f"Error evaluating step {step.result_var} = {step.expression}")
                print(f"Current temp_values: {list(self.temp_values.keys())}")
                raise
            
        return step_results

# [Rest of the code remains the same: format_collection, format_result, print_grid, etc.]

def check_callable(r):
    if isinstance(r, (list, tuple, set, frozenset)):
        return all(check_callable(x) for x in r) and len(r) > 0 #empty set would read true
    else:
        return callable(r)

def compute_used(function_steps, inputs, expr):
    return {
        func_var: func_expr 
        for func_var, (func_expr, _) in function_steps.items() 
        if func_var in inputs or exact_pattern_match(func_var, expr)
    }

def exact_pattern_match(pattern_to_find, text):
    # Escape any special regex characters in the pattern
    escaped_pattern = re.escape(pattern_to_find)
    # Add word boundary at the end
    regex_pattern = fr'{escaped_pattern}\b'
    
    return re.search(regex_pattern, text)

def merge_used(function_steps, var, expr, result, inputs):
    # used_in_step = {
    #     func_var: func_expr 
    #     for func_var, (func_expr, _) in function_steps.items() 
    #     if func_var in inputs or func_var in expr
    # }
    # print("starting merge used")
    # print("orig expr", expr)
    # print("function steps", function_steps.keys())
    final_expr = expr
    final_var = var
    final_result = result
    final_inputs = inputs
    used_in_step = compute_used(function_steps, inputs, expr)
    # print("orig used", used_in_step)
    # print(used_in_step)
    while used_in_step:
        # print("in while",final_expr)
        # 3. Substitute expressions
        new_expr = final_expr
        new_inputs = dict(final_inputs)
        # print(used_in_step, "before")
        for func_var, func_expr in used_in_step.items():
            # Replace function variable with its expression
            new_expr = new_expr.replace(func_var, f"({func_expr})")
            # Update inputs
            new_inputs.update(function_steps[func_var][1])
            new_inputs.pop(func_var, None)
            # used_functions.add(func_var)
        
        used_in_step = compute_used(function_steps, new_inputs, new_expr)
        
        # print(used_in_step, "after", new_expr, new_inputs)
        final_expr = new_expr
        final_inputs = new_inputs
        # print("used_after", used_in_step)
        # print("after expr", new_expr)
        # break
            
        # return merge_used(function_steps, var, new_expr, result, new_inputs)
    # print(final_var, final_expr, final_result, final_inputs, "ALL DONE")
    return (final_var, final_expr, final_result, final_inputs)
    


def merge_function_steps(step_results):
    """
    Merge steps by repeatedly:
    1. Finding steps that return callables
    2. Finding steps that use those callables
    3. Substituting expressions
    4. Renaming variables
    5. Repeating until no callables remain
    6. Removing any unused function-producing steps
    """
    current_steps = list(step_results)
    # print(current_steps, "curr here")
    function_steps = {
        var: (expr, inputs) 
        for var, expr, result, inputs in current_steps 
        # if (callable(result) or (isinstance(result, (list, tuple, set, frozenset)) and all(callable(r) for r in result)))
        if check_callable(result)
    }

    while True:
        # 1. Find steps that return callables
        # print(current_steps[0], "curr")
        # print(type(current_steps[0]), "type")
        # try:
        # print(current_steps, "current")
        # print("here")
        # print(current_steps)
        function_steps = {
        var: (expr, inputs) 
        for var, expr, result, inputs in current_steps 
        # if (callable(result) or (isinstance(result, (list, tuple, set, frozenset)) and all(callable(r) for r in result)))
        if check_callable(result)
        }
        # except:
        #     return current_steps
        # print(function_steps)
        
        # print(function_steps.keys())
        if not function_steps:
            break

        # for func in function_steps:
        #     new_function_steps = {}
        #     used_functions = set()
        #     for var, (expr, inputs) in function_steps.items():
        #         used_in_var = {
        #         func_var: func_expr 
        #         for func_var, (func_expr, _) in function_steps.items() 
        #         if func_var in inputs or func_var in expr
        #     }
        #     print("used in func", var, used_in_var)
                
        # 2. Find steps that use those callables and perform substitution
        new_steps = []
        used_functions = set()
        
        for var, expr, result, inputs in current_steps:
            # print("iter through current")
            # Skip function steps that will be merged
            if var in function_steps:
                continue
                
            # Check if this step uses any functions
            # used_in_step = {
            #     func_var: func_expr 
            #     for func_var, (func_expr, _) in function_steps.items() 
            #     if func_var in inputs or func_var in expr
            # }
            # # print("used in", var, used_in_step)
            
            # if used_in_step:
            #     # 3. Substitute expressions
            #     new_expr = expr
            #     new_inputs = dict(inputs)
                
            #     for func_var, func_expr in used_in_step.items():
            #         # Replace function variable with its expression
            #         new_expr = new_expr.replace(func_var, f"({func_expr})")
            #         # Update inputs
            #         new_inputs.update(function_steps[func_var][1])
            #         new_inputs.pop(func_var, None)
            #         used_functions.add(func_var)
                    
            #     new_steps.append((var, new_expr, result, new_inputs))
            # else:
            #     new_steps.append((var, expr, result, inputs))
            new_steps.append(merge_used(function_steps, var, expr, result, inputs))
        
        # If no changes were made, break
        if len(new_steps) == len(current_steps):
            break
            
        current_steps = new_steps
        
        # new_steps = []
        # for (var, expr, result, inputs) in current_steps:
        #     # Skip function steps that will be merged
        #     if var in function_steps:
        #         continue
                
        #     # Check if this step uses any functions
        #     used_in_step = {
        #         func_var: func_expr 
        #         for func_var, (func_expr, _) in function_steps.items() 
        #         if func_var in inputs or func_var in expr
        #     }
        #     print("used in", var, used_in_step)
            
        #     if used_in_step:
        #         # 3. Substitute expressions
        #         new_expr = expr
        #         new_inputs = dict(inputs)
                
        #         for func_var, func_expr in used_in_step.items():
        #             # Replace function variable with its expression
        #             new_expr = new_expr.replace(func_var, f"({func_expr})")
        #             # Update inputs
        #             new_inputs.update(function_steps[func_var][1])
        #             new_inputs.pop(func_var, None)
        #             used_functions.add(func_var)
                    
        #         new_steps.append((var, new_expr, result, new_inputs))
        #     else:
        #         new_steps.append((var, expr, result, inputs))
        
        # # If no changes were made, break
        # if len(new_steps) == len(current_steps):
        #     break
            
        # current_steps = new_steps

    
    # Final cleanup: remove any remaining function-producing steps
    final_steps = []
    function_vars = {var for var, _, result, _ in current_steps if check_callable(result)}
    used_vars = set()
    
    # Collect all variables that are still being used
    for _, expr, _, inputs in current_steps:
        for var in function_vars:
            if var in expr or var in str(inputs):
                used_vars.add(var)
    
    # Keep only non-function steps and function steps that are still used
    final_steps = [
        step for step in current_steps
        if not check_callable(step[2]) or step[0] in used_vars
    ]
    
    return final_steps

def evaluate_all_solutions(solutions, input_grid):
    """Evaluate all solutions with complete cleanup"""
    evaluator = StepEvaluator()
    all_results = []
    
    for i, solution in enumerate(solutions):
        print(f"\nSolution {i} ({'Most' if i == len(solutions)-1 else 'Least' if i == 0 else 'More'} Granular):")
        print(solution)
        print("\nIntermediate Results:")
        
        try:
            # Get original step results
            step_results = evaluator.evaluate_solution(solution, input_grid)
            
            # Merge steps recursively and cleanup
            merged_steps = merge_function_steps(step_results)
            
            # Print merged results
            for var, expr, result, inputs in merged_steps:
                if not check_callable(result):  # Only show non-function results
                    print(f"\n{var} = {expr}")
                    if inputs:
                        print("Inputs:")
                        for input_var, input_val in inputs.items():
                            print(f"  {input_var}: {format_input_value(input_val)}")
                    print(f"Result type: {type(result).__name__}")
                    print("Value:", end=" ")
                    if isinstance(result, tuple) and result and isinstance(result[0], tuple):
                        print_grid(result)
                    else:
                        print("\n" + format_result(result))
                    
            all_results.append(merged_steps)
        except Exception as e:
            print(f"Error in solution {i}:")
            print(str(e))
            print("Continuing with next solution...")
            continue
    
    return all_results

In [125]:
import ast
from typing import Dict

def map_function_returns(file_path: str) -> Dict[str, str]:
    """
    Parse a Python file and map each function name to its return expression.
    
    Args:
        file_path: Path to the Python file to analyze
        
    Returns:
        Dictionary mapping function names to their return expressions as strings
    """
    with open(file_path, 'r') as file:
        source = file.read()
    
    tree = ast.parse(source)
    return_map = {}
    
    class ReturnVisitor(ast.NodeVisitor):
        def visit_FunctionDef(self, node):
            returns = []
            
            # Inner class to find return statements
            class ReturnFinder(ast.NodeVisitor):
                def visit_Return(self, return_node):
                    if return_node.value:  # Ignore bare 'return' statements
                        returns.append(ast.unparse(return_node.value))
            
            # Find all return statements in the function
            ReturnFinder().visit(node)
            
            # Join multiple returns with ' | ' if they exist
            if returns:
                return_map[node.name] = ' | '.join(returns)
            else:
                return_map[node.name] = 'None'  # Functions without explicit returns
    
    ReturnVisitor().visit(tree)
    return return_map

def print_return_map(file_path: str) -> None:
    """
    Print the function return mappings in a readable format.
    
    Args:
        file_path: Path to the Python file to analyze
    """
    return_map = map_function_returns(file_path)
    max_name_length = max(len(name) for name in return_map.keys())
    
    print("\nFunction Return Expressions:")
    print("-" * (max_name_length + 30))
    for func_name, return_expr in sorted(return_map.items()):
        print(f"{func_name:<{max_name_length}} -> {return_expr}")

In [126]:
map = map_function_returns("rewritten_solvers.py")

In [127]:
map.keys()

dict_keys(['solve_67a3c6ac', 'solve_68b16354', 'solve_74dd1130', 'solve_3c9b0459', 'solve_6150a2bd', 'solve_9172f3a0', 'solve_9dfd6313', 'solve_a416b8f3', 'solve_b1948b0a', 'solve_c59eb873', 'solve_c8f0f002', 'solve_d10ecb37', 'solve_d511f180', 'solve_ed36ccf7', 'solve_4c4377d9', 'solve_6d0aefbc', 'solve_6fa7a44f', 'solve_5614dbcf', 'solve_5bd6f4ac', 'solve_5582e5ca', 'solve_8be77c9e', 'solve_c9e6f938', 'solve_2dee498d', 'solve_1cf80156', 'solve_32597951', 'solve_25ff71a9', 'solve_0b148d64', 'solve_1f85a75f', 'solve_23b5c85d', 'solve_9ecd008a', 'solve_ac0a08a4', 'solve_be94b721', 'solve_c909285e', 'solve_f25ffba3', 'solve_c1d99e64', 'solve_b91ae062', 'solve_3aa6fb7a', 'solve_7b7f7511', 'solve_4258a5f9', 'solve_2dc579da', 'solve_28bf18c6', 'solve_3af2c5a8', 'solve_44f52bb0', 'solve_62c24649', 'solve_67e8384a', 'solve_7468f01a', 'solve_662c240a', 'solve_42a50994', 'solve_56ff96f3', 'solve_50cb2852', 'solve_4347f46a', 'solve_46f33fce', 'solve_a740d043', 'solve_a79310a0', 'solve_aabf363d',

In [128]:
import os
def get_data(train=True):
    path = f'arc_original/{"training" if train else "evaluation"}'
    data = {}
    for fn in os.listdir(path):
        with open(f'{path}/{fn}') as f:
            data[fn.rstrip('.json')] = json.load(f)
    ast = lambda g: tuple(tuple(r) for r in g)
    return {
        'train': {k: [{
            'input': ast(e['input']),
            'output': ast(e['output']),
        } for e in v['train']] for k, v in data.items()},
        'test': {k: [{
            'input': ast(e['input']),
            'output': ast(e['output']),
        } for e in v['test']] for k, v in data.items()}
    }
data = get_data()
I = data['train']['44d8ac46'][3]['input']

In [129]:
len(data['train']['44d8ac46'])

4

In [47]:
solutions[1].steps

[Step(expression='mfilter(apply(delta, objects(I, T, F, T)), square)', result_var='temp_0', dependencies=set()),
 Step(expression='fill(I, TWO, temp_0)', result_var='temp_1', dependencies=set())]

In [31]:
evaluator = StepEvaluator()
# merge_function_steps(evaluator.evaluate_solution(solutions[1], I))

In [74]:
check_callable(fs[0][2])

False

In [32]:
# expr = map['solve_9d9215db'] 3ac3eb23
expr = map['solve_44d8ac46']
solutions = generate_solutions(expr)
# results = evaluate_all_solutions(solutions, I)

In [33]:
# for i, res in enumerate(results):
#     print(i, res[0])
#     for tup in res:
#         if callable(tup[2]):
#             print(i)
#             break

In [34]:
evaluator = StepEvaluator()
# evaluator.evaluate_solution(solutions[10], I)

In [35]:
!jupyter notebook --NotebookApp.iopub_data_rate_limit=1.0e10



  _   _          _      _
 | | | |_ __  __| |__ _| |_ ___
 | |_| | '_ \/ _` / _` |  _/ -_)
  \___/| .__/\__,_\__,_|\__\___|
       |_|
                       
Read the migration plan to Notebook 7 to learn about the new features and the actions to take if you are using extensions.

https://jupyter-notebook.readthedocs.io/en/latest/migrate_to_notebook7.html

Please note that updating to Notebook 7 might break some of your extensions.

[32m[I 01:08:22.037 NotebookApp][m Serving notebooks from local directory: /scratch/rst306/arc_project/arc_example_generation/arc_dsl
[32m[I 01:08:22.037 NotebookApp][m Jupyter Notebook 6.5.4 is running at:
[32m[I 01:08:22.037 NotebookApp][m http://localhost:8888/?token=1fd9f5405afaa1856f1ce8def57e9f8f3a15baed306a9fa5
[32m[I 01:08:22.037 NotebookApp][m  or http://127.0.0.1:8888/?token=1fd9f5405afaa1856f1ce8def57e9f8f3a15baed306a9fa5
[32m[I 01:08:22.037 NotebookApp][m Use Control-C to stop this server and shut down all kernels (twice to skip conf

In [34]:
# for x in results[10][1][2]:
#     print(x, check_callable(x))
# check_callable(results[10][1][2])
# c = merge_function_steps(evaluator.evaluate_solution(solutions[10], I))

<class 'tuple'> type
dict_keys(['temp_0', 'temp_2', 'temp_3', 'temp_4'])
temp_1 interval(ONE, FOUR, ONE) (1, 2, 3) {} ALL DONE
temp_5 objects(I, F, T, T) frozenset({frozenset({(3, (3, 15))}), frozenset({(3, (16, 6)), (3, (15, 6)), (3, (16, 5)), (3, (15, 5))}), frozenset({(2, (1, 13))}), frozenset({(3, (6, 7)), (8, (5, 5)), (8, (4, 7)), (8, (4, 6)), (8, (5, 6)), (2, (4, 5)), (8, (6, 5))}), frozenset({(2, (11, 10)), (2, (12, 10)), (2, (11, 9)), (2, (12, 9))})}) {'I': ((1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1), (1, 1, 1, 1, 1, 2, 8, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 8, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 8, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)

In [36]:
step_results = evaluator.evaluate_solution(solutions[7], I)
# merged_steps = merge_function_steps(step_results)

IndexError: list index out of range

In [54]:
merged_steps[2]

('temp_3',
 'paint(I, temp_2)',
 ((0, 2, 0, 0, 0, 8, 0, 0, 0, 0),
  (2, 0, 2, 0, 8, 0, 8, 0, 0, 0),
  (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
  (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
  (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
  (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)),
 {'I': ((0, 2, 0, 0, 0, 8, 0, 0, 0, 0),
   (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
   (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
   (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
   (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
   (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)),
  'temp_2': frozenset({(2, (-1, 0)),
             (2, (-1, 2)),
             (2, (1, 0)),
             (2, (1, 2)),
             (8, (-1, 4)),
             (8, (-1, 6)),
             (8, (1, 4)),
             (8, (1, 6))})})

In [32]:
from tqdm import tqdm

In [33]:
# map_to_solutions = {}
# for func_name in tqdm(map, total=len(map)):
#     exp = map[func_name]
#     solutions = generate_solutions(expr)
#     map_to_solutions[func_name] = solutions

In [54]:
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from functools import partial

def process_function(func_name, map_dict):
    """Process a single function and generate solutions."""
    exp = map_dict[func_name]
    return func_name, generate_solutions(exp)

def parallel_generate_solutions(map_dict, num_processes=None):
    """
    Generate solutions in parallel using multiple processes.
    
    Args:
        map_dict: Dictionary mapping function names to expressions
        num_processes: Number of processes to use (defaults to CPU count)
    
    Returns:
        Dictionary mapping function names to their solutions
    """
    if num_processes is None:
        num_processes = cpu_count()

    # Create a partial function with the map_dict argument fixed
    process_func = partial(process_function, map_dict=map_dict)
    
    # Create a process pool
    with Pool(processes=num_processes) as pool:
        # Use imap_unordered for better performance when order doesn't matter
        # Wrap with tqdm for progress tracking
        results = list(tqdm(
            pool.imap_unordered(process_func, map_dict.keys()),
            total=len(map_dict),
            desc="Generating solutions"
        ))
    
    # Convert results list back to dictionary
    return dict(results)

# Usage example:
# if __name__ == '__main__':
map_to_solutions = parallel_generate_solutions(map)

Generating solutions: 100%|██████████| 400/400 [14:05<00:00,  2.11s/it] 


In [34]:
exp = map['solve_8d510a79']
solutions = generate_solutions(exp)

In [35]:
solutions[-6].steps

[Step(expression='compose(temp_12, compose(lbind(greater, uppermost(temp_2)), first))', result_var='temp_0', dependencies=set()),
 Step(expression='fork(shoot, identity, temp_0)', result_var='temp_1', dependencies=set()),
 Step(expression='ofcolor(I, FIVE)', result_var='temp_2', dependencies=set()),
 Step(expression='uppermost(temp_2)', result_var='temp_3', dependencies=set()),
 Step(expression='lbind(greater, temp_3)', result_var='temp_4', dependencies=set()),
 Step(expression='compose(temp_4, first)', result_var='temp_5', dependencies=set()),
 Step(expression='lbind(matcher, temp_5)', result_var='temp_6', dependencies=set()),
 Step(expression='compose(temp_6, temp_5)', result_var='temp_7', dependencies=set()),
 Step(expression='fork(sfilter, temp_1, temp_7)', result_var='temp_8', dependencies=set()),
 Step(expression='ofcolor(I, TWO)', result_var='temp_9', dependencies=set()),
 Step(expression='mapply(temp_8, temp_9)', result_var='temp_10', dependencies=set()),
 Step(expression='unde

In [97]:
# nme, solns = process_function("solve_6aa20dc0", map)

In [98]:
# for soln in solns:
#     evaluate_

SyntaxError: incomplete input (150792189.py, line 1)

In [62]:
len(map_to_solutions)

400

In [130]:
def process_collection(collection):
    """Format a collection (set, frozenset, tuple) with readable output"""
    items = list(collection)
    try:
        items.sort()
    except TypeError:
        pass
    
    formatted_items = []
    
    for item in items:
        
        if isinstance(item, frozenset):
            formatted_items.append(format_collection(item))
        else:
            formatted_items.append(str(item))
            
    items_str = " ".join(formatted_items)
    return items_str

def process_inputs(inputs):
    return [process_result(inputs[k]) for k in inputs]

def process_outputs(outputs):
    return [process_result(o) for o in outputs]
    
def convert_frozenset_to_set(obj):
    """
    Recursively converts FrozenSets to Sets in a nested data structure.
    First converts everything to lists, then rebuilds with sets.
    
    Args:
        obj: Any Python object that might contain FrozenSets
        
    Returns:
        The same structure with all FrozenSets converted to Sets
    """
    # First pass: convert everything to lists
    def to_lists(obj):
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        elif isinstance(obj, (frozenset, set)):
            return [to_lists(x) for x in obj]
        elif isinstance(obj, (list, tuple)):
            return [to_lists(x) for x in obj]
        elif isinstance(obj, dict):
            return {to_lists(k): to_lists(v) for k, v in obj.items()}
        else:
            return obj

    # Second pass: convert lists back to sets where needed
    def to_sets(obj):
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        elif isinstance(obj, list):
            # Convert all elements first
            converted = [to_sets(x) for x in obj]
            # If this was originally a set/frozenset, make it a set
            if isinstance(obj, (set, frozenset)):
                return set(converted)
            return converted
        elif isinstance(obj, dict):
            return {to_sets(k): to_sets(v) for k, v in obj.items()}
        else:
            return obj

    # Run both passes
    return to_sets(to_lists(obj))
def process_result(result):
    if isinstance(result, (set, frozenset)):
        return convert_frozenset_to_set(result)
    elif callable(result):
        return "function"
    else:
        return result
        

In [131]:
# incremental['train'][0]

In [132]:
def process_solution(solution, name):
    evaluator = StepEvaluator()
    key = name.replace("solve_", "")
    orig_inputs = [ex['input'] for ex in data['train'][key]]
    # step_results = evaluator.evaluate_solution(solution, orig_inputs[0])
    # merged_steps = merge_function_steps(step_results)
    # if need:
    #     return None

    steps_all_inputs = []
    for I in orig_inputs:
        step_results = evaluator.evaluate_solution(solution, I)
        
        merged_steps = merge_function_steps(step_results)
        # print(merged_steps)
        for step in merged_steps:
            if callable(step[2]):
                print(name)
                print(solution)
                print(step)
        steps_all_inputs.append(merged_steps)
    subroutines = {}
    # return steps_all_inputs
    # print(len(steps_all_inputs[0][0]))
    for i in range(len(steps_all_inputs[0])):
        # print(merged[i], i)
        # print(i)
        # for m in steps_all_inputs:
            # print(m, len(m))
            # print(m[i])
        # print([(merged[i], len(merged[i])) for merged in steps_all_inputs], i)
        inputs = [process_inputs(merged[i][3]) for merged in steps_all_inputs]
        outputs = [process_result(merged[i][2]) for merged in steps_all_inputs]
        program = f"{steps_all_inputs[0][i][0]} = {steps_all_inputs[0][i][1]}"
        subroutines[f"subroutine_{i}"] = {"inputs": inputs,
                                          "outputs": outputs,
                                          "program": program}
    return {"original_task_key": key,
            "subroutines": subroutines}

In [84]:
# map_to_solutions.keys()

dict_keys(['solve_67a3c6ac', 'solve_68b16354', 'solve_74dd1130', 'solve_3c9b0459', 'solve_6150a2bd', 'solve_9172f3a0', 'solve_9dfd6313', 'solve_a416b8f3', 'solve_b1948b0a', 'solve_c59eb873', 'solve_c8f0f002', 'solve_d10ecb37', 'solve_d511f180', 'solve_ed36ccf7', 'solve_4c4377d9', 'solve_6d0aefbc', 'solve_6fa7a44f', 'solve_5614dbcf', 'solve_5bd6f4ac', 'solve_5582e5ca', 'solve_8be77c9e', 'solve_c9e6f938', 'solve_2dee498d', 'solve_1cf80156', 'solve_32597951', 'solve_25ff71a9', 'solve_0b148d64', 'solve_1f85a75f', 'solve_23b5c85d', 'solve_9ecd008a', 'solve_ac0a08a4', 'solve_be94b721', 'solve_c909285e', 'solve_f25ffba3', 'solve_c1d99e64', 'solve_b91ae062', 'solve_3aa6fb7a', 'solve_7b7f7511', 'solve_4258a5f9', 'solve_2dc579da', 'solve_28bf18c6', 'solve_3af2c5a8', 'solve_44f52bb0', 'solve_62c24649', 'solve_67e8384a', 'solve_7468f01a', 'solve_662c240a', 'solve_42a50994', 'solve_56ff96f3', 'solve_50cb2852', 'solve_4347f46a', 'solve_46f33fce', 'solve_a740d043', 'solve_a79310a0', 'solve_aabf363d',

In [57]:
s = process_solution(solutions[-6], name)

NameError: name 'name' is not defined

In [156]:
tmp = process_solution(map_to_solutions['solve_bbc9ae5d'][0], "solve_bbc9ae5d")

In [105]:
len(tmp['subroutines']['subroutine_0']['outputs'])

5

In [51]:
map_to_solutions['solve_a699fb00']

[<__main__.Solution at 0x15173728d550>,
 <__main__.Solution at 0x15173728da90>,
 <__main__.Solution at 0x15173728f750>,
 <__main__.Solution at 0x15173728ef50>,
 <__main__.Solution at 0x15173728e050>,
 <__main__.Solution at 0x151735f1efd0>]

In [57]:
all_solutions = []
for name in tqdm(map_to_solutions):
    if name in ['solve_a64e4611', 'solve_2dd70a9a']:
        continue
    solutions = map_to_solutions[name]
    for solution in solutions:
        processed = process_solution(solution, name)
        if processed is not None:
            all_solutions.append(processed)

100%|██████████| 400/400 [03:17<00:00,  2.03it/s] 


In [93]:
solution.steps

[Step(expression='compose(temp_12, compose(lbind(greater, uppermost(temp_2)), first))', result_var='temp_0', dependencies=set()),
 Step(expression='fork(shoot, identity, temp_0)', result_var='temp_1', dependencies=set()),
 Step(expression='ofcolor(I, FIVE)', result_var='temp_2', dependencies=set()),
 Step(expression='uppermost(temp_2)', result_var='temp_3', dependencies=set()),
 Step(expression='lbind(greater, temp_3)', result_var='temp_4', dependencies=set()),
 Step(expression='compose(temp_4, first)', result_var='temp_5', dependencies=set()),
 Step(expression='lbind(matcher, temp_5)', result_var='temp_6', dependencies=set()),
 Step(expression='compose(temp_6, temp_5)', result_var='temp_7', dependencies=set()),
 Step(expression='fork(sfilter, temp_1, temp_7)', result_var='temp_8', dependencies=set()),
 Step(expression='ofcolor(I, TWO)', result_var='temp_9', dependencies=set()),
 Step(expression='mapply(temp_8, temp_9)', result_var='temp_10', dependencies=set()),
 Step(expression='unde

In [53]:
s = process_solution(solutions[-6], "solve_8d510a79")

In [89]:
len(s[1][1])

4

In [112]:
map_to_solutions[name][0].steps

[Step(expression='vmirror(I)', result_var='temp_0', dependencies=set())]

In [117]:
for sol in map_to_solutions[name]:
    sol

In [118]:
sol

<__main__.Solution at 0x151df8a5b950>

In [77]:
all_solutions[1000]

{'original_task_key': '694f12f3',
 'subroutines': {'subroutine_0': {'inputs': [[((0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 0, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 0, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 0, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 0, 0, 0, 0, 0),
      (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
      (0, 0, 0, 4, 4, 4, 4, 4, 4, 0),
      (0, 0, 0, 4, 4, 4, 4, 4, 4, 0),
      (0, 0, 0, 4, 4, 4, 4, 4, 4, 0),
      (0, 0, 0, 4, 4, 4, 4, 4, 4, 0))],
    [((0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 4, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 4, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 4, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 4, 0, 0, 0, 0),
      (0, 4, 4, 4, 4, 4, 0, 0, 0, 0),
      (0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
      (0, 0, 0, 0, 0, 4, 4, 4, 4, 0),
      (0, 0, 0, 0, 0, 4, 4, 4, 4, 0),
      (0, 0, 0, 0, 0, 4, 4, 4, 4, 0))]],
   'outputs': [[[[4, [3, 1]],
      [4, [4, 2]],
      [4, [3, 3]],
      [4, [2, 1]],
      [4, [3, 2]],
      [4, [4, 1]],
      [4, [3, 4]],
      [4, [1, 1]

In [58]:
tests = ['solve_9d9215db', 'solve_150deff5', 'solve_b7249182']
test_keys = [t.replace("solve_", "") for t in tests]
test_examples = [soln for soln in all_solutions if soln['original_task_key'] in test_keys]
train_examples = [soln for soln in all_solutions if soln['original_task_key'] not in test_keys]

In [59]:
len(all_solutions)

13062

In [60]:
len(test_examples)

2589

In [61]:
decomposed = {"train": train_examples, 'test': test_examples}

In [62]:
def find_non_serializable(obj, path=""):
    """
    Recursively traverse an object and find all paths that contain non-JSON-serializable items.
    Specifically looks for function objects that would cause TypeError.
    
    Args:
        obj: The object to inspect
        path: Current path in the object (used recursively)
        
    Returns:
        list: List of paths (strings) where functions were found
    """
    problems = []
    
    # Handle different types
    if isinstance(obj, dict):
        for key, value in obj.items():
            new_path = f"{path}.{key}" if path else str(key)
            if callable(value):
                problems.append((new_path, type(value).__name__))
            else:
                problems.extend(find_non_serializable(value, new_path))
                
    elif isinstance(obj, (list, tuple, set)):
        for i, item in enumerate(obj):
            new_path = f"{path}[{i}]"
            if callable(item):
                problems.append((new_path, type(item).__name__))
            else:
                problems.extend(find_non_serializable(item, new_path))
                
    elif callable(obj):
        problems.append((path, type(obj).__name__))
        
    return problems

def debug_json_serialization(obj):
    """
    Wrapper function that provides a friendly output of all non-serializable paths found.
    
    Args:
        obj: The object to inspect
        
    Returns:
        None: Prints the problems found
    """
    problems = find_non_serializable(obj)
    
    if not problems:
        print("No serialization problems found!")
        return
        
    print("Found the following non-serializable items:")
    for path, type_name in problems:
        print(f"• At path '{path}': Found {type_name}")
    return problems

In [63]:
probs = debug_json_serialization(decomposed)

No serialization problems found!


In [54]:
probs

[('train[5071].subroutines.subroutine_0.outputs[0][0]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][1]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][2]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][3]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][4]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][5]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][6]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][7]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][8]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][9]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][10]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][11]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][12]', 'function'),
 ('train[5071].subroutines.subroutine_0.outputs[0][13]', 'function'),
 ('train[5071].subroutines.sub

In [65]:
decomposed['train'][5071]

{'original_task_key': '6aa20dc0',
 'subroutines': {'subroutine_0': {'inputs': [[], [], []],
   'outputs': [[<function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>],
    [<function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x)>,
     <function dsl.compose.<locals>.<lambda>(x

In [177]:
def set_default(obj):
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    raise TypeError


In [None]:
# with open("arc_decomposed.json", 'w') as fp:
#     json.dump(decomposed, fp, default=list)

In [179]:
import torch


In [180]:
torch.save(decomposed, "decomposed.pth")

AttributeError: Can't pickle local object 'compose.<locals>.<lambda>'

In [133]:
with open("arc_decomposed.json", 'r') as fp:
    solns = json.load(fp)

In [150]:
def convert_to_tuples(list_of_lists):
    return tuple(tuple(t) for t in list_of_lists)


In [152]:
convert_to_tuples(solns['train'][1000]['subroutines']['subroutine_0']['outputs'][0][0])

((4, [3, 1]),
 (4, [4, 2]),
 (4, [3, 3]),
 (4, [2, 1]),
 (4, [3, 2]),
 (4, [4, 1]),
 (4, [3, 4]),
 (4, [1, 1]),
 (4, [4, 3]),
 (4, [2, 2]),
 (4, [1, 3]),
 (4, [4, 4]),
 (4, [1, 2]),
 (4, [2, 3]),
 (4, [1, 4]),
 (4, [2, 4]))

In [153]:
charset = set()
for example in solns['train']:
    for subroutine_name in example['subroutines']:
        charset.update(set(str(example['subroutines'][subroutine_name]['outputs'])))
        if "-" in str(example['subroutines'][subroutine_name]['outputs']):
            print(str(example['subroutines'][subroutine_name]['outputs']))

[[[3, 8], [1, 8], [4, 6], [1, 7], [2, 6], [4, 8], [3, 6], [1, 6], [4, 7], [2, 8]], [[-1, 8], [6, 15], [-1, 14], [3, 16], [5, 16], [0, 14], [1, 9], [2, 14], [1, 12], [-1, 7], [6, 14], [-1, 10], [-1, 13], [4, 14], [3, 15], [0, 7], [0, 10], [1, 8], [1, 14], [2, 13], [-1, 9], [6, 16], [-1, 12], [4, 16], [3, 14], [5, 14], [0, 12], [1, 7], [2, 12], [1, 10]]]
[[[3, 8], [1, 8], [4, 6], [1, 7], [2, 6], [4, 8], [3, 6], [1, 6], [4, 7], [2, 8]], [[-1, 8], [6, 15], [-1, 14], [3, 16], [5, 16], [0, 14], [1, 9], [2, 14], [1, 12], [-1, 7], [6, 14], [-1, 10], [-1, 13], [4, 14], [3, 15], [0, 7], [0, 10], [1, 8], [1, 14], [2, 13], [-1, 9], [6, 16], [-1, 12], [4, 16], [3, 14], [5, 14], [0, 12], [1, 7], [2, 12], [1, 10]]]
[[[3, 8], [1, 8], [4, 6], [1, 7], [2, 6], [4, 8], [3, 6], [1, 6], [4, 7], [2, 8]], [[-1, 8], [6, 15], [-1, 14], [3, 16], [5, 16], [0, 14], [1, 9], [2, 14], [1, 12], [-1, 7], [6, 14], [-1, 10], [-1, 13], [4, 14], [3, 15], [0, 7], [0, 10], [1, 8], [1, 14], [2, 13], [-1, 9], [6, 16], [-1, 12]

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



[[[1, [10, 21]], [2, [3, 17]], [2, [6, -9]], [1, [25, 1]], [3, [1, 22]], [1, [20, 4]], [1, [18, 0]], [1, [5, 25]], [1, [-10, -5]], [1, [23, 20]], [5, [12, 7]], [4, [24, 22]], [2, [23, 27]], [1, [-8, 25]], [2, [27, 3]], [4, [6, 8]], [2, [2, -2]], [5, [1, 4]], [4, [11, 13]], [2, [14, -6]], [1, [-6, 20]], [1, [-3, -5]], [1, [-9, 5]], [5, [2, 7]], [2, [-1, 4]], [4, [17, -6]], [3, [14, 23]], [1, [21, -10]], [5, [8, 13]], [1, [10, 12]], [1, [25, -8]], [1, [20, -5]], [4, [12, -1]], [1, [-1, 5]], [1, [-2, 5]], [3, [-2, -1]], [3, [-1, -2]], [3, [-3, -4]], [3, [-9, 22]], [1, [11, 15]], [1, [-10, 27]], [1, [5, 16]], [4, [-3, 19]], [5, [-1, 21]], [1, [12, 25]], [2, [19, 24]], [4, [19, 17]], [5, [21, 14]], [2, [18, 22]], [2, [22, -2]], [4, [1, 3]], [2, [4, -6]], [2, [-3, -7]], [1, [0, -7]], [2, [2, 13]], [1, [10, 3]], [5, [-3, 2]], [3, [17, 21]], [4, [13, -9]], [3, [19, -7]], [1, [13, 10]], [1, [-9, 20]], [1, [-10, 18]], [5, [3, 8]], [4, [-1, 12]], [1, [5, 7]], [1, [0, 1]], [1, [10, 27]], [5, [4, 1

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



[[[2, [2, 15]], [6, [30, 15]], [6, [-4, 1]], [2, [-11, 40]], [4, [34, -3]], [5, [17, 3]], [6, [14, 25]], [4, [12, 7]], [3, [-1, 7]], [5, [24, 14]], [5, [2, 24]], [4, [-1, 14]], [5, [6, 0]], [5, [-12, -4]], [3, [23, 13]], [1, [21, 7]], [2, [24, 29]], [4, [3, 34]], [2, [39, 18]], [1, [28, 18]], [2, [21, 14]], [6, [38, 1]], [4, [0, 19]], [6, [16, 11]], [2, [-12, 11]], [2, [31, 40]], [2, [3, 0]], [3, [27, -3]], [6, [36, 39]], [5, [37, 13]], [2, [25, 32]], [1, [12, 34]], [5, [30, 12]], [3, [-9, 39]], [4, [-6, 1]], [4, [-1, 38]], [5, [21, 23]], [1, [-3, -5]], [5, [6, 8]], [2, [22, 7]], [4, [1, 12]], [2, [19, -8]], [2, [-1, 4]], [1, [39, 19]], [6, [4, 23]], [2, [-3, 2]], [1, [3, 1]], [5, [35, -9]], [5, [12, -10]], [2, [23, 10]], [1, [10, 12]], [1, [1, 39]], [3, [-11, 17]], [5, [-6, 36]], [5, [-3, 11]], [3, [11, -11]], [4, [36, 1]], [5, [38, -2]], [2, [34, -5]], [4, [13, 0]], [6, [2, 1]], [5, [39, 35]], [1, [-1, 5]], [4, [11, -4]], [4, [-12, -5]], [4, [20, 11]], [5, [3, 17]], [2, [19, 16]], [4


KeyboardInterrupt



In [149]:
charset

{' ',
 ',',
 '-',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 'F',
 'T',
 '[',
 ']',
 'a',
 'e',
 'l',
 'r',
 's',
 'u'}