Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
class ViewFileTool(BaseTool):
"""Tool for viewing file contents and metadata."""

name: ClassVar[str] = "view_file"

Check failure on line 54 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "View the contents and metadata of a file in the codebase"

Check failure on line 55 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
args_schema: ClassVar[type[BaseModel]] = ViewFileInput
codebase: Codebase = Field(exclude=True)

Expand All @@ -61,7 +61,7 @@

def _run(self, filepath: str) -> str:
result = view_file(self.codebase, filepath)
return json.dumps(result, indent=2)
return result.render()


class ListDirectoryInput(BaseModel):
Expand All @@ -84,7 +84,7 @@

def _run(self, dirpath: str = "./", depth: int = 1) -> str:
result = list_directory(self.codebase, dirpath, depth)
return json.dumps(result, indent=2)
return result.render()


class SearchInput(BaseModel):
Expand All @@ -107,7 +107,7 @@

def _run(self, query: str, target_directories: Optional[list[str]] = None) -> str:
result = search(self.codebase, query, target_directories)
return json.dumps(result, indent=2)
return result.render()


class EditFileInput(BaseModel):
Expand All @@ -130,7 +130,7 @@

def _run(self, filepath: str, content: str) -> str:
result = edit_file(self.codebase, filepath, content)
return json.dumps(result, indent=2)
return result.render()


class CreateFileInput(BaseModel):
Expand All @@ -153,7 +153,7 @@

def _run(self, filepath: str, content: str = "") -> str:
result = create_file(self.codebase, filepath, content)
return json.dumps(result, indent=2)
return result.render()


class DeleteFileInput(BaseModel):
Expand All @@ -175,7 +175,7 @@

def _run(self, filepath: str) -> str:
result = delete_file(self.codebase, filepath)
return json.dumps(result, indent=2)
return result.render()


class CommitTool(BaseTool):
Expand All @@ -190,7 +190,7 @@

def _run(self) -> str:
result = commit(self.codebase)
return json.dumps(result, indent=2)
return result.render()


class RevealSymbolInput(BaseModel):
Expand Down Expand Up @@ -233,7 +233,7 @@
collect_dependencies=collect_dependencies,
collect_usages=collect_usages,
)
return json.dumps(result, indent=2)
return result.render()


_SEMANTIC_EDIT_BRIEF = """Tool for file editing via an LLM delegate. Describe the changes you want to make and an expert will apply them to the file.
Expand Down Expand Up @@ -278,7 +278,7 @@
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
# Create the the draft editor mini llm
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
return json.dumps(result, indent=2)
return result.render()


class RenameFileInput(BaseModel):
Expand All @@ -301,7 +301,7 @@

def _run(self, filepath: str, new_filepath: str) -> str:
result = rename_file(self.codebase, filepath, new_filepath)
return json.dumps(result, indent=2)
return result.render()


class MoveSymbolInput(BaseModel):
Expand Down Expand Up @@ -344,7 +344,7 @@
strategy=strategy,
include_dependencies=include_dependencies,
)
return json.dumps(result, indent=2)
return result.render()


class SemanticSearchInput(BaseModel):
Expand All @@ -368,7 +368,7 @@

def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str:
result = semantic_search(self.codebase, query, k=k, preview_length=preview_length)
return json.dumps(result, indent=2)
return result.render()


########################################################################################################################
Expand All @@ -392,7 +392,7 @@

def _run(self, command: str, is_background: bool = False) -> str:
result = run_bash_command(command, is_background)
return json.dumps(result, indent=2)
return result.render()


########################################################################################################################
Expand Down Expand Up @@ -420,7 +420,7 @@

def _run(self, title: str, body: str) -> str:
result = create_pr(self.codebase, title, body)
return json.dumps(result, indent=2)
return result.render()


class GithubViewPRInput(BaseModel):
Expand All @@ -442,6 +442,7 @@

def _run(self, pr_id: int) -> str:
result = view_pr(self.codebase, pr_id)
return result.render()
return json.dumps(result, indent=2)


Expand All @@ -465,7 +466,7 @@

def _run(self, pr_number: int, body: str) -> str:
result = create_pr_comment(self.codebase, pr_number, body)
return json.dumps(result, indent=2)
return result.render()


class GithubCreatePRReviewCommentInput(BaseModel):
Expand Down Expand Up @@ -511,7 +512,7 @@
side=side,
start_line=start_line,
)
return json.dumps(result, indent=2)
return result.render()


########################################################################################################################
Expand All @@ -538,7 +539,7 @@

def _run(self, issue_id: str) -> str:
result = linear_get_issue_tool(self.client, issue_id)
return json.dumps(result, indent=2)
return result.render()


class LinearGetIssueCommentsInput(BaseModel):
Expand All @@ -560,7 +561,7 @@

def _run(self, issue_id: str) -> str:
result = linear_get_issue_comments_tool(self.client, issue_id)
return json.dumps(result, indent=2)
return result.render()


class LinearCommentOnIssueInput(BaseModel):
Expand All @@ -583,7 +584,7 @@

def _run(self, issue_id: str, body: str) -> str:
result = linear_comment_on_issue_tool(self.client, issue_id, body)
return json.dumps(result, indent=2)
return result.render()


class LinearSearchIssuesInput(BaseModel):
Expand All @@ -606,7 +607,7 @@

def _run(self, query: str, limit: int = 10) -> str:
result = linear_search_issues_tool(self.client, query, limit)
return json.dumps(result, indent=2)
return result.render()


class LinearCreateIssueInput(BaseModel):
Expand All @@ -630,7 +631,7 @@

def _run(self, title: str, description: str | None = None, team_id: str | None = None) -> str:
result = linear_create_issue_tool(self.client, title, description, team_id)
return json.dumps(result, indent=2)
return result.render()


class LinearGetTeamsTool(BaseTool):
Expand All @@ -645,7 +646,7 @@

def _run(self) -> str:
result = linear_get_teams_tool(self.client)
return json.dumps(result, indent=2)
return result.render()


########################################################################################################################
Expand Down Expand Up @@ -678,6 +679,7 @@
self.codebase = codebase

def _run(self, content: str) -> str:
# TODO - pull this out into a separate function
print("> Adding links to message")
content_formatted = add_links_to_message(content, self.codebase)
print("> Sending message to Slack")
Expand Down
84 changes: 58 additions & 26 deletions src/codegen/extensions/tools/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import re
import shlex
import subprocess
from typing import Any
from typing import ClassVar, Optional

from pydantic import Field

from .observation import Observation

# Whitelist of allowed commands and their flags
ALLOWED_COMMANDS = {
Expand All @@ -22,6 +26,28 @@
}


class RunBashCommandObservation(Observation):
"""Response from running a bash command."""

stdout: Optional[str] = Field(
default=None,
description="Standard output from the command",
)
stderr: Optional[str] = Field(
default=None,
description="Standard error from the command",
)
command: str = Field(
description="The command that was executed",
)
pid: Optional[int] = Field(
default=None,
description="Process ID for background commands",
)

str_template: ClassVar[str] = "Command '{command}' completed"


def validate_command(command: str) -> tuple[bool, str]:
"""Validate if a command is safe to execute.

Expand Down Expand Up @@ -90,23 +116,24 @@ def validate_command(command: str) -> tuple[bool, str]:
return False, f"Failed to validate command: {e!s}"


def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any]:
def run_bash_command(command: str, is_background: bool = False) -> RunBashCommandObservation:
"""Run a bash command and return its output.

Args:
command: The command to run
is_background: Whether to run the command in the background

Returns:
Dictionary containing the command output or error
RunBashCommandObservation containing the command output or error
"""
# First validate the command
is_valid, error_message = validate_command(command)
if not is_valid:
return {
"status": "error",
"error": f"Invalid command: {error_message}",
}
return RunBashCommandObservation(
status="error",
error=f"Invalid command: {error_message}",
command=command,
)

try:
if is_background:
Expand All @@ -118,10 +145,11 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any
stderr=subprocess.PIPE,
text=True,
)
return {
"status": "success",
"message": f"Command '{command}' started in background with PID {process.pid}",
}
return RunBashCommandObservation(
status="success",
command=command,
pid=process.pid,
)

# For foreground processes, we wait for completion
result = subprocess.run(
Expand All @@ -132,20 +160,24 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any
check=True, # This will raise CalledProcessError if command fails
)

return {
"status": "success",
"stdout": result.stdout,
"stderr": result.stderr,
}
return RunBashCommandObservation(
status="success",
command=command,
stdout=result.stdout,
stderr=result.stderr,
)

except subprocess.CalledProcessError as e:
return {
"status": "error",
"error": f"Command failed with exit code {e.returncode}",
"stdout": e.stdout,
"stderr": e.stderr,
}
return RunBashCommandObservation(
status="error",
error=f"Command failed with exit code {e.returncode}",
command=command,
stdout=e.stdout,
stderr=e.stderr,
)
except Exception as e:
return {
"status": "error",
"error": f"Failed to run command: {e!s}",
}
return RunBashCommandObservation(
status="error",
error=f"Failed to run command: {e!s}",
command=command,
)
34 changes: 29 additions & 5 deletions src/codegen/extensions/tools/commit.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,42 @@
"""Tool for committing changes to disk."""

from typing import Any
from typing import ClassVar

from pydantic import Field

from codegen import Codebase

from .observation import Observation


class CommitObservation(Observation):
"""Response from committing changes to disk."""

message: str = Field(
description="Message describing the commit result",
)

str_template: ClassVar[str] = "{message}"


def commit(codebase: Codebase) -> dict[str, Any]:
def commit(codebase: Codebase) -> CommitObservation:
"""Commit any pending changes to disk.

Args:
codebase: The codebase to operate on

Returns:
Dict containing commit status
CommitObservation containing commit status
"""
codebase.commit()
return {"status": "success", "message": "Changes committed to disk"}
try:
codebase.commit()
return CommitObservation(
status="success",
message="Changes committed to disk",
)
except Exception as e:
return CommitObservation(
status="error",
error=f"Failed to commit changes: {e!s}",
message="Failed to commit changes",
)
Loading