diff --git a/camel/toolkits/github_toolkit.py b/camel/toolkits/github_toolkit.py index f1eb0d30f..caa8252fd 100644 --- a/camel/toolkits/github_toolkit.py +++ b/camel/toolkits/github_toolkit.py @@ -14,6 +14,7 @@ import os from dataclasses import dataclass +from datetime import datetime, timedelta from typing import List, Optional from camel.functions import OpenAIFunction @@ -57,8 +58,8 @@ def __init__( self.file_path = file_path self.file_content = file_content - def summary(self) -> str: - r"""Returns a summary of the issue. + def __str__(self) -> str: + r"""Returns a string representation of the issue. Returns: str: A string containing the title, body, number, file path, and @@ -73,6 +74,48 @@ def summary(self) -> str: ) +@dataclass +class GithubPullRequestDiff: + r"""Represents a single diff of a pull request on Github. + + Attributes: + filename (str): The name of the file that was changed. + patch (str): The diff patch for the file. + """ + + filename: str + patch: str + + def __str__(self) -> str: + r"""Returns a string representation of this diff.""" + return f"Filename: {self.filename}\nPatch: {self.patch}" + + +@dataclass +class GithubPullRequest: + r"""Represents a pull request on Github. + + Attributes: + title (str): The title of the GitHub pull request. + body (str): The body/content of the GitHub pull request. + diffs (List[GithubPullRequestDiff]): A list of diffs for the pull + request. + """ + + title: str + body: str + diffs: List[GithubPullRequestDiff] + + def __str__(self) -> str: + r"""Returns a string representation of the pull request.""" + diff_summaries = '\n'.join(str(diff) for diff in self.diffs) + return ( + f"Title: {self.title}\n" + f"Body: {self.body}\n" + f"Diffs: {diff_summaries}\n" + ) + + class GithubToolkit(BaseToolkit): r"""A class representing a toolkit for interacting with GitHub repositories. @@ -106,7 +149,7 @@ def __init__( except ImportError: raise ImportError( "Please install `github` first. You can install it by running " - "`pip install wikipedia`." + "`pip install pygithub`." ) self.github = Github(auth=Auth.Token(access_token)) self.repo = self.github.get_repo(repo_name) @@ -123,6 +166,7 @@ def get_tools(self) -> List[OpenAIFunction]: OpenAIFunction(self.retrieve_issue_list), OpenAIFunction(self.retrieve_issue), OpenAIFunction(self.create_pull_request), + OpenAIFunction(self.retrieve_pull_requests), ] def get_github_access_token(self) -> str: @@ -181,9 +225,47 @@ def retrieve_issue(self, issue_number: int) -> Optional[str]: issues = self.retrieve_issue_list() for issue in issues: if issue.number == issue_number: - return issue.summary() + return str(issue) return None + def retrieve_pull_requests( + self, days: int, state: str, max_prs: int + ) -> List[str]: + r"""Retrieves a summary of merged pull requests from the repository. + The summary will be provided for the last specified number of days. + + Args: + days (int): The number of days to retrieve merged pull requests + for. + state (str): A specific state of PRs to retrieve. Can be open or + closed. + max_prs (int): The maximum number of PRs to retrieve. + + Returns: + List[str]: A list of merged pull request summaries. + """ + pull_requests = self.repo.get_pulls(state=state) + merged_prs = [] + earliest_date: datetime = datetime.utcnow() - timedelta(days=days) + + for pr in pull_requests[:max_prs]: + if ( + pr.merged + and pr.merged_at is not None + and pr.merged_at.timestamp() > earliest_date.timestamp() + ): + pr_details = GithubPullRequest(pr.title, pr.body, []) + + # Get files changed in the PR + files = pr.get_files() + + for file in files: + diff = GithubPullRequestDiff(file.filename, file.patch) + pr_details.diffs.append(diff) + + merged_prs.append(str(pr_details)) + return merged_prs + def create_pull_request( self, file_path: str, diff --git a/examples/function_call/github_examples.py b/examples/function_call/github_examples.py index d87dfba1a..cfd9ea657 100644 --- a/examples/function_call/github_examples.py +++ b/examples/function_call/github_examples.py @@ -11,15 +11,74 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import argparse + from colorama import Fore from camel.agents import ChatAgent from camel.configs import ChatGPTConfig +from camel.functions import OpenAIFunction from camel.messages import BaseMessage from camel.toolkits import GithubToolkit from camel.utils import print_text_animated +def write_weekly_pr_summary(repo_name, model=None): + prompt = """ + You need to write a summary of the pull requests that were merged in the + last week. + You can use the provided github function retrieve_pull_requests to + retrieve the list of pull requests that were merged in the last week. + The maximum amount of PRs to analyze is 3. + You have to pass the number of days as the first parameter to + retrieve_pull_requests and state='closed' as the second parameter. + The function will return a list of pull requests with the following + properties: title, body, and diffs. + Diffs is a list of dictionaries with the following properties: filename, + diff. + You will have to look closely at each diff to understand the changes that + were made in each pull request. + Output a twitter post that describes recent changes in the project and + thanks the contributors. + + Here is an example of a summary for one pull request: + 📢 We've improved function calling in the 🐪 CAMEL-AI framework! + This update enhances the handling of various docstring styles and supports + enum types, ensuring more accurate and reliable function calls. + Thanks to our contributor Jiahui Zhang for making this possible. + """ + print(Fore.YELLOW + f"Final prompt:\n{prompt}\n") + + toolkit = GithubToolkit(repo_name=repo_name) + assistant_sys_msg = BaseMessage.make_assistant_message( + role_name="Marketing Manager", + content=f""" + You are an experienced marketing manager responsible for posting + weekly updates about the status + of an open source project {repo_name} on the project's blog. + """, + ) + assistant_model_config = ChatGPTConfig( + tools=[OpenAIFunction(toolkit.retrieve_pull_requests)], + temperature=0.0, + ) + agent = ChatAgent( + assistant_sys_msg, + model_type=model, + model_config=assistant_model_config, + tools=[OpenAIFunction(toolkit.retrieve_pull_requests)], + ) + agent.reset() + + user_msg = BaseMessage.make_user_message(role_name="User", content=prompt) + assistant_response = agent.step(user_msg) + + if len(assistant_response.msgs) > 0: + print_text_animated( + Fore.GREEN + f"Agent response:\n{assistant_response.msg.content}\n" + ) + + def solve_issue( repo_name, issue_number, @@ -70,8 +129,12 @@ def solve_issue( def main(model=None) -> None: - repo_name = "camel-ai/test-github-agent" - solve_issue(repo_name=repo_name, issue_number=1, model=model) + parser = argparse.ArgumentParser(description='Enter repo name.') + parser.add_argument('repo_name', type=str, help='Name of the repository') + args = parser.parse_args() + + repo_name = args.repo_name + write_weekly_pr_summary(repo_name=repo_name, model=model) if __name__ == "__main__": diff --git a/test/toolkits/test_github_toolkit.py b/test/toolkits/test_github_toolkit.py index 075590d2f..714697aff 100644 --- a/test/toolkits/test_github_toolkit.py +++ b/test/toolkits/test_github_toolkit.py @@ -12,12 +12,18 @@ # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from datetime import datetime from unittest.mock import MagicMock, patch from github import Auth, Github from github.ContentFile import ContentFile -from camel.toolkits.github_toolkit import GithubIssue, GithubToolkit +from camel.toolkits.github_toolkit import ( + GithubIssue, + GithubPullRequest, + GithubPullRequestDiff, + GithubToolkit, +) @patch.object(Github, '__init__', lambda self, *args, **kwargs: None) @@ -120,9 +126,9 @@ def test_retrieve_issue(monkeypatch): file_path="path/to/file", file_content="This is the content of the file", ) - assert ( - issue == expected_issue.summary() - ), f"Expected {expected_issue.summary()}, but got {issue}" + assert issue == str( + expected_issue + ), f"Expected {expected_issue}, but got {issue}" @patch.object(Github, 'get_repo', return_value=MagicMock()) @@ -165,6 +171,54 @@ def test_create_pull_request(monkeypatch): ), f"Expected {expected_response}, but got {pr}" +@patch.object(Github, 'get_repo', return_value=MagicMock()) +@patch.object(Auth.Token, '__init__', lambda self, *args, **kwargs: None) +def test_retrieve_pull_requests(monkeypatch): + # Call the constructor of the GithubToolkit class + github_toolkit = GithubToolkit("repo_name", "token") + + # Create a mock file + mock_file = MagicMock() + mock_file.filename = "path/to/file" + mock_file.diff = "This is the diff of the file" + + # Create a mock pull request + mock_pull_request = MagicMock() + mock_pull_request.title = "Test PR" + mock_pull_request.body = "This is a test issue" + mock_pull_request.merged_at = datetime.utcnow() + + # Create a mock file + mock_file = MagicMock() + mock_file.filename = "path/to/file" + mock_file.patch = "This is the diff of the file" + + # Mock the get_files method of the mock_pull_request instance to return a + # list containing the mock file object + mock_pull_request.get_files.return_value = [mock_file] + + # Mock the get_issues method of the mock repo instance to return a list + # containing the mock issue object + github_toolkit.repo.get_pulls.return_value = [mock_pull_request] + + pull_requests = github_toolkit.retrieve_pull_requests( + days=7, state='closed', max_prs=3 + ) + # Assert the returned issue list + expected_pull_request = GithubPullRequest( + title="Test PR", + body="This is a test issue", + diffs=[ + GithubPullRequestDiff( + filename="path/to/file", patch="This is the diff of the file" + ) + ], + ) + assert pull_requests == [ + str(expected_pull_request) + ], f"Expected {expected_pull_request}, but got {pull_requests}" + + def test_github_issue(): # Create a GithubIssue object issue = GithubIssue( @@ -183,7 +237,7 @@ def test_github_issue(): assert issue.file_content == "This is the content of the file" # Test the summary method - summary = issue.summary() + summary = str(issue) expected_summary = ( f"Title: {issue.title}\n" f"Body: {issue.body}\n"