diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index cbce0c1cc..8ab951b2d 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -10,8 +10,11 @@ from codegen.extensions.linear.linear_client import LinearClient from codegen.extensions.tools.linear_tools import ( linear_comment_on_issue_tool, + linear_create_issue_tool, linear_get_issue_comments_tool, linear_get_issue_tool, + linear_get_teams_tool, + linear_search_issues_tool, ) from ..tools import ( @@ -532,6 +535,68 @@ def _run(self, issue_id: str, body: str) -> str: return json.dumps(result, indent=2) +class LinearSearchIssuesInput(BaseModel): + """Input for searching Linear issues.""" + + query: str = Field(..., description="Search query string") + limit: int = Field(default=10, description="Maximum number of issues to return") + + +class LinearSearchIssuesTool(BaseTool): + """Tool for searching Linear issues.""" + + name: ClassVar[str] = "linear_search_issues" + description: ClassVar[str] = "Search for Linear issues using a query string" + args_schema: ClassVar[type[BaseModel]] = LinearSearchIssuesInput + client: LinearClient = Field(exclude=True) + + def __init__(self, client: LinearClient) -> None: + super().__init__(client=client) + + def _run(self, query: str, limit: int = 10) -> str: + result = linear_search_issues_tool(self.client, query, limit) + return json.dumps(result, indent=2) + + +class LinearCreateIssueInput(BaseModel): + """Input for creating a Linear issue.""" + + title: str = Field(..., description="Title of the issue") + description: str | None = Field(None, description="Optional description of the issue") + team_id: str | None = Field(None, description="Optional team ID. If not provided, uses the default team_id (recommended)") + + +class LinearCreateIssueTool(BaseTool): + """Tool for creating Linear issues.""" + + name: ClassVar[str] = "linear_create_issue" + description: ClassVar[str] = "Create a new Linear issue" + args_schema: ClassVar[type[BaseModel]] = LinearCreateIssueInput + client: LinearClient = Field(exclude=True) + + def __init__(self, client: LinearClient) -> None: + super().__init__(client=client) + + def _run(self, title: str, description: str | None = None, team_id: str | None = None) -> str: + result = linear_create_issue_tool(self.client, title, description, team_id) + return json.dumps(result, indent=2) + + +class LinearGetTeamsTool(BaseTool): + """Tool for getting Linear teams.""" + + name: ClassVar[str] = "linear_get_teams" + description: ClassVar[str] = "Get all Linear teams the authenticated user has access to" + client: LinearClient = Field(exclude=True) + + def __init__(self, client: LinearClient) -> None: + super().__init__(client=client) + + def _run(self) -> str: + result = linear_get_teams_tool(self.client) + return json.dumps(result, indent=2) + + def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: """Get all workspace tools initialized with a codebase. @@ -561,4 +626,7 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: LinearGetIssueTool(codebase), LinearGetIssueCommentsTool(codebase), LinearCommentOnIssueTool(codebase), + LinearSearchIssuesTool(codebase), + LinearCreateIssueTool(codebase), + LinearGetTeamsTool(codebase), ] diff --git a/src/codegen/extensions/linear/linear_client.py b/src/codegen/extensions/linear/linear_client.py index ef0e1b96a..e27867dca 100644 --- a/src/codegen/extensions/linear/linear_client.py +++ b/src/codegen/extensions/linear/linear_client.py @@ -17,6 +17,14 @@ class LinearUser(BaseModel): name: str +class LinearTeam(BaseModel): + """Represents a Linear team.""" + + id: str + name: str + key: str + + class LinearComment(BaseModel): id: str body: str @@ -33,13 +41,18 @@ class LinearClient: api_headers: dict api_endpoint = "https://api.linear.app/graphql" - def __init__(self, access_token: Optional[str] = None): + def __init__(self, access_token: Optional[str] = None, team_id: Optional[str] = None): if not access_token: access_token = os.getenv("LINEAR_ACCESS_TOKEN") if not access_token: msg = "access_token is required" raise ValueError(msg) self.access_token = access_token + + if not team_id: + team_id = os.getenv("LINEAR_TEAM_ID") + self.team_id = team_id + self.api_headers = { "Content-Type": "application/json", "Authorization": self.access_token, @@ -194,3 +207,101 @@ def search_issues(self, query: str, limit: int = 10) -> list[LinearIssue]: except Exception as e: msg = f"Error searching issues\n{data}\n{e}" raise Exception(msg) + + def create_issue(self, title: str, description: str | None = None, team_id: str | None = None) -> LinearIssue: + """Create a new issue. + + Args: + title: Title of the issue + description: Optional description of the issue + team_id: Optional team ID. If not provided, uses the client's configured team_id + + Returns: + The created LinearIssue object + + Raises: + ValueError: If no team_id is provided or configured + """ + if not team_id: + team_id = self.team_id + if not team_id: + msg = "team_id must be provided either during client initialization or in the create_issue call" + raise ValueError(msg) + + mutation = """ + mutation createIssue($input: IssueCreateInput!) { + issueCreate(input: $input) { + success + issue { + id + title + description + } + } + } + """ + + variables = { + "input": { + "teamId": team_id, + "title": title, + "description": description, + } + } + + response = requests.post( + self.api_endpoint, + headers=self.api_headers, + json={"query": mutation, "variables": variables}, + ) + data = response.json() + + try: + issue_data = data["data"]["issueCreate"]["issue"] + return LinearIssue( + id=issue_data["id"], + title=issue_data["title"], + description=issue_data["description"], + ) + except Exception as e: + msg = f"Error creating issue\n{data}\n{e}" + raise Exception(msg) + + def get_teams(self) -> list[LinearTeam]: + """Get all teams the authenticated user has access to. + + Returns: + List of LinearTeam objects + """ + query = """ + query { + teams { + nodes { + id + name + key + } + } + } + """ + + response = requests.post( + self.api_endpoint, + headers=self.api_headers, + json={"query": query}, + ) + data = response.json() + + try: + teams_data = data["data"]["teams"]["nodes"] + return [ + LinearTeam( + id=team["id"], + name=team["name"], + key=team["key"], + ) + for team in teams_data + ] + except Exception as e: + msg = f"Error getting teams\n{data}\n{e}" + raise Exception(msg) diff --git a/src/codegen/extensions/tools/linear_tools.py b/src/codegen/extensions/tools/linear_tools.py index 9bc02f5fe..c15cb5d10 100644 --- a/src/codegen/extensions/tools/linear_tools.py +++ b/src/codegen/extensions/tools/linear_tools.py @@ -37,3 +37,30 @@ def linear_register_webhook_tool(client: LinearClient, webhook_url: str, team_id return {"status": "success", "response": response} except Exception as e: return {"error": f"Failed to register webhook: {e!s}"} + + +def linear_search_issues_tool(client: LinearClient, query: str, limit: int = 10) -> dict[str, Any]: + """Search for issues using a query string.""" + try: + issues = client.search_issues(query, limit) + return {"status": "success", "issues": [issue.dict() for issue in issues]} + except Exception as e: + return {"error": f"Failed to search issues: {e!s}"} + + +def linear_create_issue_tool(client: LinearClient, title: str, description: str | None = None, team_id: str | None = None) -> dict[str, Any]: + """Create a new issue.""" + try: + issue = client.create_issue(title, description, team_id) + return {"status": "success", "issue": issue.dict()} + except Exception as e: + return {"error": f"Failed to create issue: {e!s}"} + + +def linear_get_teams_tool(client: LinearClient) -> dict[str, Any]: + """Get all teams the authenticated user has access to.""" + try: + teams = client.get_teams() + return {"status": "success", "teams": [team.dict() for team in teams]} + except Exception as e: + return {"error": f"Failed to get teams: {e!s}"} diff --git a/tests/integration/extension/test_linear.py b/tests/integration/extension/test_linear.py index c6eb3606d..5a758f2cb 100644 --- a/tests/integration/extension/test_linear.py +++ b/tests/integration/extension/test_linear.py @@ -7,8 +7,11 @@ from codegen.extensions.linear.linear_client import LinearClient from codegen.extensions.tools.linear_tools import ( linear_comment_on_issue_tool, + linear_create_issue_tool, linear_get_issue_comments_tool, linear_get_issue_tool, + linear_get_teams_tool, + linear_search_issues_tool, ) @@ -18,7 +21,10 @@ def client() -> LinearClient: token = os.getenv("LINEAR_ACCESS_TOKEN") if not token: pytest.skip("LINEAR_ACCESS_TOKEN environment variable not set") - return LinearClient(token) + team_id = os.getenv("LINEAR_TEAM_ID") + if not team_id: + pytest.skip("LINEAR_TEAM_ID environment variable not set") + return LinearClient(token, team_id) def test_linear_get_issue(client: LinearClient) -> None: @@ -45,6 +51,45 @@ def test_linear_comment_on_issue(client: LinearClient) -> None: def test_search_issues(client: LinearClient) -> None: """Test searching for issues in Linear.""" - issues = client.search_issues("REVEAL_SYMBOL") + issues = linear_search_issues_tool(client, "REVEAL_SYMBOL") assert issues["status"] == "success" assert len(issues["issues"]) > 0 + + +def test_create_issue(client: LinearClient) -> None: + """Test creating an issue in Linear.""" + # Test creating an issue with explicit team_id + title = "Test Issue - Automated Testing (Explicit Team)" + description = "This is a test issue created by automated testing with explicit team_id" + + issue = client.create_issue(title, description) + assert issue.title == title + assert issue.description == description + + # Test creating an issue using default team_id from environment + title2 = "Test Issue - Automated Testing (Default Team)" + description2 = "This is a test issue created by automated testing with default team_id" + + issue2 = client.create_issue(title2, description2) + assert issue2.title == title2 + assert issue2.description == description2 + + # Test the tool wrapper with default team_id + result = linear_create_issue_tool(client, "Test Tool Issue", "Test description from tool") + assert result["status"] == "success" + assert result["issue"]["title"] == "Test Tool Issue" + assert result["issue"]["description"] == "Test description from tool" + + +def test_get_teams(client: LinearClient) -> None: + """Test getting teams from Linear.""" + result = linear_get_teams_tool(client) + assert result["status"] == "success" + assert len(result["teams"]) > 0 + + # Verify team structure + team = result["teams"][0] + print(result) + assert "id" in team + assert "name" in team + assert "key" in team