diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index e5116630f..3da422560 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -71,9 +71,250 @@ def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False jitter=True, ) + # Custom error handler for tool validation errors + def handle_tool_errors(exception): + error_msg = str(exception) + + # Extract tool name and input from the exception if possible + tool_name = "unknown" + tool_input = {} + + # Helper function to get field descriptions from any tool + def get_field_descriptions(tool_obj): + field_descriptions = {} + if not tool_obj or not hasattr(tool_obj, "args_schema"): + return field_descriptions + + try: + schema_cls = tool_obj.args_schema + + # Handle Pydantic v2 + if hasattr(schema_cls, "model_fields"): + for field_name, field in schema_cls.model_fields.items(): + field_descriptions[field_name] = field.description or f"Required parameter for {tool_obj.name}" + + # Handle Pydantic v1 with warning suppression + elif hasattr(schema_cls, "__fields__"): + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + for field_name, field in schema_cls.__fields__.items(): + field_descriptions[field_name] = field.field_info.description or f"Required parameter for {tool_obj.name}" + except Exception: + pass + + return field_descriptions + + # Try to extract tool name and input from the exception + import re + + tool_match = re.search(r"for (\w+)Input", error_msg) + if tool_match: + # Get the extracted name but preserve original case by finding the matching tool + extracted_name = tool_match.group(1).lower() + for t in self.tools: + if t.name.lower() == extracted_name: + tool_name = t.name # Use the original case from the tool + break + + # Try to extract the input values + input_match = re.search(r"input_value=(\{.*?\})", error_msg) + if input_match: + input_str = input_match.group(1) + # Simple parsing of the dict-like string + try: + # Clean up the string to make it more parseable + input_str = input_str.replace("'", '"') + import json + + tool_input = json.loads(input_str) + except: + pass + + # Handle validation errors with more helpful messages + if "validation error" in error_msg.lower(): + # Find the tool in our tools list to get its schema + tool = next((t for t in self.tools if t.name == tool_name), None) + + # If we couldn't find the tool by extracted name, try to find it by looking at all tools + if tool is None: + # Try to extract tool name from the error message + for t in self.tools: + if t.name.lower() in error_msg.lower(): + tool = t + tool_name = t.name + break + + # If still not found, check if any tool's schema name matches + if tool is None: + for t in self.tools: + if hasattr(t, "args_schema") and t.args_schema.__name__.lower() in error_msg.lower(): + tool = t + tool_name = t.name + break + + # Check for type errors + type_errors = [] + if "type_error" in error_msg.lower(): + import re + + # Try to extract type error information + type_error_matches = re.findall(r"'(\w+)'.*?type_error\.(.*?)(?:;|$)", error_msg, re.IGNORECASE) + for field_name, error_type in type_error_matches: + if "json" in error_type: + type_errors.append(f"'{field_name}' must be a string, not a JSON object or dictionary") + elif "str_type" in error_type: + type_errors.append(f"'{field_name}' must be a string") + elif "int_type" in error_type: + type_errors.append(f"'{field_name}' must be an integer") + elif "bool_type" in error_type: + type_errors.append(f"'{field_name}' must be a boolean") + elif "list_type" in error_type: + type_errors.append(f"'{field_name}' must be a list") + else: + type_errors.append(f"'{field_name}' has an incorrect type") + + if type_errors: + errors_str = "\n- ".join(type_errors) + return f"Error using {tool_name} tool: Parameter type errors:\n- {errors_str}\n\nYou provided: {tool_input}\n\nPlease try again with the correct parameter types." + + # Get missing fields by comparing tool input with required fields + missing_fields = [] + if tool and hasattr(tool, "args_schema"): + try: + # Get the schema class + schema_cls = tool.args_schema + + # Handle Pydantic v2 (preferred) or v1 with warning suppression + if hasattr(schema_cls, "model_fields"): # Pydantic v2 + for field_name, field in schema_cls.model_fields.items(): + # Check if field is required and missing from input + if field.is_required() and field_name not in tool_input: + missing_fields.append(field_name) + else: # Pydantic v1 with warning suppression + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + for field_name, field in schema_cls.__fields__.items(): + # Check if field is required and missing from input + if field.required and field_name not in tool_input: + missing_fields.append(field_name) + except Exception as e: + # If we can't extract schema info, we'll fall back to regex + pass + + # If we couldn't get missing fields from schema, try to extract from error message + if not missing_fields: + # Extract the missing field name if possible using regex + import re + + field_matches = re.findall(r"'(\w+)'(?:\s+|.*?)field required", error_msg, re.IGNORECASE) + if field_matches: + missing_fields = field_matches + else: + # Try another pattern + field_match = re.search(r"(\w+)\s+Field required", error_msg) + if field_match: + missing_fields = [field_match.group(1)] + + # If we have identified missing fields, create a helpful error message + if missing_fields: + fields_str = ", ".join([f"'{f}'" for f in missing_fields]) + + # Get tool documentation if available + tool_docs = "" + if tool: + if hasattr(tool, "description") and tool.description: + tool_docs = f"\nTool description: {tool.description}\n" + + # Try to get parameter descriptions from the schema + param_docs = [] + try: + # Get all field descriptions from the tool + field_descriptions = get_field_descriptions(tool) + + # Add descriptions for missing fields + for field_name in missing_fields: + if field_name in field_descriptions: + param_docs.append(f"- {field_name}: {field_descriptions[field_name]}") + else: + param_docs.append(f"- {field_name}: Required parameter") + + if param_docs: + tool_docs += "\nParameter descriptions:\n" + "\n".join(param_docs) + except Exception: + # Fallback to simple parameter list + param_docs = [f"- {field}: Required parameter" for field in missing_fields] + if param_docs: + tool_docs += "\nMissing parameters:\n" + "\n".join(param_docs) + + # Add usage examples for common tools + example = "" + if tool_name == "create_file": + example = "\nExample: create_file(filepath='path/to/file.py', content='print(\"Hello world\")')" + elif tool_name == "replace_edit": + example = "\nExample: replace_edit(filepath='path/to/file.py', old_text='def old_function()', new_text='def new_function()')" + elif tool_name == "view_file": + example = "\nExample: view_file(filepath='path/to/file.py')" + elif tool_name == "search": + example = "\nExample: search(query='function_name', file_extensions=['.py'])" + + return ( + f"Error using {tool_name} tool: Missing required parameter(s): {fields_str}\n\nYou provided: {tool_input}\n{tool_docs}{example}\nPlease try again with all required parameters." + ) + + # Common error patterns for specific tools (as fallback) + if tool_name == "create_file": + if "content" not in tool_input: + return ( + "Error: When using the create_file tool, you must provide both 'filepath' and 'content' parameters.\n" + "The 'content' parameter is missing. Please try again with both parameters.\n\n" + "Example: create_file(filepath='path/to/file.py', content='print(\"Hello world\")')" + ) + elif "filepath" not in tool_input: + return ( + "Error: When using the create_file tool, you must provide both 'filepath' and 'content' parameters.\n" + "The 'filepath' parameter is missing. Please try again with both parameters.\n\n" + "Example: create_file(filepath='path/to/file.py', content='print(\"Hello world\")')" + ) + + elif tool_name == "replace_edit": + if "filepath" not in tool_input: + return ( + "Error: When using the replace_edit tool, you must provide 'filepath', 'old_text', and 'new_text' parameters.\n" + "The 'filepath' parameter is missing. Please try again with all required parameters." + ) + elif "old_text" not in tool_input: + return ( + "Error: When using the replace_edit tool, you must provide 'filepath', 'old_text', and 'new_text' parameters.\n" + "The 'old_text' parameter is missing. Please try again with all required parameters." + ) + elif "new_text" not in tool_input: + return ( + "Error: When using the replace_edit tool, you must provide 'filepath', 'old_text', and 'new_text' parameters.\n" + "The 'new_text' parameter is missing. Please try again with all required parameters." + ) + + # Generic validation error with better formatting + if tool: + return ( + f"Error using {tool_name} tool: {error_msg}\n\n" + f"You provided these parameters: {tool_input}\n\n" + f"Please check the tool's required parameters and try again with all required fields." + ) + else: + # If we couldn't identify the tool, list all available tools + available_tools = "\n".join([f"- {t.name}" for t in self.tools]) + return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters." + + # For other types of errors + return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters." + # Add nodes builder.add_node("reasoner", self.reasoner, retry=retry_policy) - builder.add_node("tools", ToolNode(self.tools), retry=retry_policy) + builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy) # Add edges builder.add_edge(START, "reasoner") diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 0ce6b97a1..877b59f05 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -136,7 +136,7 @@ class SearchTool(BaseTool): def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, query: str, target_directories: Optional[list[str]] = None, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str: + def _run(self, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str: result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex) return result.render() @@ -171,7 +171,6 @@ class EditFileTool(BaseTool): 1. Simple text: "function calculateTotal" (matches exactly, case-insensitive) 2. Regex: "def.*calculate.*\(.*\)" (with use_regex=True) 3. File-specific: "TODO" with file_extensions=[".py", ".ts"] - 4. Directory-specific: "api" with target_directories=["src/backend"] """ args_schema: ClassVar[type[BaseModel]] = EditFileInput codebase: Codebase = Field(exclude=True) @@ -188,21 +187,45 @@ class CreateFileInput(BaseModel): """Input for creating a file.""" filepath: str = Field(..., description="Path where to create the file") - content: str = Field(default="", description="Initial file content") + content: str = Field( + ..., + description=""" +Content for the new file (REQUIRED). + +⚠️ IMPORTANT: This parameter MUST be a STRING, not a dictionary, JSON object, or any other data type. +Example: content="print('Hello world')" +NOT: content={"code": "print('Hello world')"} + """, + ) class CreateFileTool(BaseTool): """Tool for creating files.""" name: ClassVar[str] = "create_file" - description: ClassVar[str] = "Create a new file in the codebase" + description: ClassVar[str] = """ +Create a new file in the codebase. Always provide content for the new file, even if minimal. + +⚠️ CRITICAL WARNING ⚠️ +Both parameters MUST be provided as STRINGS: +The content for the new file always needs to be provided. + +1. filepath: The path where to create the file (as a string) +2. content: The content for the new file (as a STRING, NOT as a dictionary or JSON object) + +✅ CORRECT usage: +create_file(filepath="path/to/file.py", content="print('Hello world')") + +The content parameter is REQUIRED and MUST be a STRING. If you receive a validation error about +missing content, you are likely trying to pass a dictionary instead of a string. +""" args_schema: ClassVar[type[BaseModel]] = CreateFileInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, filepath: str, content: str = "") -> str: + def _run(self, filepath: str, content: str) -> str: result = create_file(self.codebase, filepath, content) return result.render() diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py index b10d01f52..3a54303ff 100644 --- a/src/codegen/extensions/tools/create_file.py +++ b/src/codegen/extensions/tools/create_file.py @@ -23,13 +23,13 @@ class CreateFileObservation(Observation): str_template: ClassVar[str] = "Created file {filepath}" -def create_file(codebase: Codebase, filepath: str, content: str = "") -> CreateFileObservation: +def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileObservation: """Create a new file. Args: codebase: The codebase to operate on filepath: Path where to create the file - content: Initial file content + content: Content for the new file (required) Returns: CreateFileObservation containing new file state, or error if file exists