Skip to content

Commit

Permalink
feat: Retrive recent pull requests for GithubToolkit (#582)
Browse files Browse the repository at this point in the history
Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
Co-authored-by: Wendong <w3ndong.fan@gmail.com>
  • Loading branch information
3 people committed Jun 16, 2024
1 parent d0e3722 commit 26100a9
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 11 deletions.
90 changes: 86 additions & 4 deletions camel/toolkits/github_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Optional

from camel.functions import OpenAIFunction
Expand Down Expand Up @@ -57,8 +58,8 @@ def __init__(
self.file_path = file_path
self.file_content = file_content

def summary(self) -> str:
r"""Returns a summary of the issue.
def __str__(self) -> str:
r"""Returns a string representation of the issue.
Returns:
str: A string containing the title, body, number, file path, and
Expand All @@ -73,6 +74,48 @@ def summary(self) -> str:
)


@dataclass
class GithubPullRequestDiff:
r"""Represents a single diff of a pull request on Github.
Attributes:
filename (str): The name of the file that was changed.
patch (str): The diff patch for the file.
"""

filename: str
patch: str

def __str__(self) -> str:
r"""Returns a string representation of this diff."""
return f"Filename: {self.filename}\nPatch: {self.patch}"


@dataclass
class GithubPullRequest:
r"""Represents a pull request on Github.
Attributes:
title (str): The title of the GitHub pull request.
body (str): The body/content of the GitHub pull request.
diffs (List[GithubPullRequestDiff]): A list of diffs for the pull
request.
"""

title: str
body: str
diffs: List[GithubPullRequestDiff]

def __str__(self) -> str:
r"""Returns a string representation of the pull request."""
diff_summaries = '\n'.join(str(diff) for diff in self.diffs)
return (
f"Title: {self.title}\n"
f"Body: {self.body}\n"
f"Diffs: {diff_summaries}\n"
)


class GithubToolkit(BaseToolkit):
r"""A class representing a toolkit for interacting with GitHub
repositories.
Expand Down Expand Up @@ -106,7 +149,7 @@ def __init__(
except ImportError:
raise ImportError(
"Please install `github` first. You can install it by running "
"`pip install wikipedia`."
"`pip install pygithub`."
)
self.github = Github(auth=Auth.Token(access_token))
self.repo = self.github.get_repo(repo_name)
Expand All @@ -123,6 +166,7 @@ def get_tools(self) -> List[OpenAIFunction]:
OpenAIFunction(self.retrieve_issue_list),
OpenAIFunction(self.retrieve_issue),
OpenAIFunction(self.create_pull_request),
OpenAIFunction(self.retrieve_pull_requests),
]

def get_github_access_token(self) -> str:
Expand Down Expand Up @@ -181,9 +225,47 @@ def retrieve_issue(self, issue_number: int) -> Optional[str]:
issues = self.retrieve_issue_list()
for issue in issues:
if issue.number == issue_number:
return issue.summary()
return str(issue)
return None

def retrieve_pull_requests(
self, days: int, state: str, max_prs: int
) -> List[str]:
r"""Retrieves a summary of merged pull requests from the repository.
The summary will be provided for the last specified number of days.
Args:
days (int): The number of days to retrieve merged pull requests
for.
state (str): A specific state of PRs to retrieve. Can be open or
closed.
max_prs (int): The maximum number of PRs to retrieve.
Returns:
List[str]: A list of merged pull request summaries.
"""
pull_requests = self.repo.get_pulls(state=state)
merged_prs = []
earliest_date: datetime = datetime.utcnow() - timedelta(days=days)

for pr in pull_requests[:max_prs]:
if (
pr.merged
and pr.merged_at is not None
and pr.merged_at.timestamp() > earliest_date.timestamp()
):
pr_details = GithubPullRequest(pr.title, pr.body, [])

# Get files changed in the PR
files = pr.get_files()

for file in files:
diff = GithubPullRequestDiff(file.filename, file.patch)
pr_details.diffs.append(diff)

merged_prs.append(str(pr_details))
return merged_prs

def create_pull_request(
self,
file_path: str,
Expand Down
67 changes: 65 additions & 2 deletions examples/function_call/github_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import argparse

from colorama import Fore

from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.functions import OpenAIFunction
from camel.messages import BaseMessage
from camel.toolkits import GithubToolkit
from camel.utils import print_text_animated


def write_weekly_pr_summary(repo_name, model=None):
prompt = """
You need to write a summary of the pull requests that were merged in the
last week.
You can use the provided github function retrieve_pull_requests to
retrieve the list of pull requests that were merged in the last week.
The maximum amount of PRs to analyze is 3.
You have to pass the number of days as the first parameter to
retrieve_pull_requests and state='closed' as the second parameter.
The function will return a list of pull requests with the following
properties: title, body, and diffs.
Diffs is a list of dictionaries with the following properties: filename,
diff.
You will have to look closely at each diff to understand the changes that
were made in each pull request.
Output a twitter post that describes recent changes in the project and
thanks the contributors.
Here is an example of a summary for one pull request:
馃摙 We've improved function calling in the 馃惇 CAMEL-AI framework!
This update enhances the handling of various docstring styles and supports
enum types, ensuring more accurate and reliable function calls.
Thanks to our contributor Jiahui Zhang for making this possible.
"""
print(Fore.YELLOW + f"Final prompt:\n{prompt}\n")

toolkit = GithubToolkit(repo_name=repo_name)
assistant_sys_msg = BaseMessage.make_assistant_message(
role_name="Marketing Manager",
content=f"""
You are an experienced marketing manager responsible for posting
weekly updates about the status
of an open source project {repo_name} on the project's blog.
""",
)
assistant_model_config = ChatGPTConfig(
tools=[OpenAIFunction(toolkit.retrieve_pull_requests)],
temperature=0.0,
)
agent = ChatAgent(
assistant_sys_msg,
model_type=model,
model_config=assistant_model_config,
tools=[OpenAIFunction(toolkit.retrieve_pull_requests)],
)
agent.reset()

user_msg = BaseMessage.make_user_message(role_name="User", content=prompt)
assistant_response = agent.step(user_msg)

if len(assistant_response.msgs) > 0:
print_text_animated(
Fore.GREEN + f"Agent response:\n{assistant_response.msg.content}\n"
)


def solve_issue(
repo_name,
issue_number,
Expand Down Expand Up @@ -70,8 +129,12 @@ def solve_issue(


def main(model=None) -> None:
repo_name = "camel-ai/test-github-agent"
solve_issue(repo_name=repo_name, issue_number=1, model=model)
parser = argparse.ArgumentParser(description='Enter repo name.')
parser.add_argument('repo_name', type=str, help='Name of the repository')
args = parser.parse_args()

repo_name = args.repo_name
write_weekly_pr_summary(repo_name=repo_name, model=model)


if __name__ == "__main__":
Expand Down
64 changes: 59 additions & 5 deletions test/toolkits/test_github_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from datetime import datetime
from unittest.mock import MagicMock, patch

from github import Auth, Github
from github.ContentFile import ContentFile

from camel.toolkits.github_toolkit import GithubIssue, GithubToolkit
from camel.toolkits.github_toolkit import (
GithubIssue,
GithubPullRequest,
GithubPullRequestDiff,
GithubToolkit,
)


@patch.object(Github, '__init__', lambda self, *args, **kwargs: None)
Expand Down Expand Up @@ -120,9 +126,9 @@ def test_retrieve_issue(monkeypatch):
file_path="path/to/file",
file_content="This is the content of the file",
)
assert (
issue == expected_issue.summary()
), f"Expected {expected_issue.summary()}, but got {issue}"
assert issue == str(
expected_issue
), f"Expected {expected_issue}, but got {issue}"


@patch.object(Github, 'get_repo', return_value=MagicMock())
Expand Down Expand Up @@ -165,6 +171,54 @@ def test_create_pull_request(monkeypatch):
), f"Expected {expected_response}, but got {pr}"


@patch.object(Github, 'get_repo', return_value=MagicMock())
@patch.object(Auth.Token, '__init__', lambda self, *args, **kwargs: None)
def test_retrieve_pull_requests(monkeypatch):
# Call the constructor of the GithubToolkit class
github_toolkit = GithubToolkit("repo_name", "token")

# Create a mock file
mock_file = MagicMock()
mock_file.filename = "path/to/file"
mock_file.diff = "This is the diff of the file"

# Create a mock pull request
mock_pull_request = MagicMock()
mock_pull_request.title = "Test PR"
mock_pull_request.body = "This is a test issue"
mock_pull_request.merged_at = datetime.utcnow()

# Create a mock file
mock_file = MagicMock()
mock_file.filename = "path/to/file"
mock_file.patch = "This is the diff of the file"

# Mock the get_files method of the mock_pull_request instance to return a
# list containing the mock file object
mock_pull_request.get_files.return_value = [mock_file]

# Mock the get_issues method of the mock repo instance to return a list
# containing the mock issue object
github_toolkit.repo.get_pulls.return_value = [mock_pull_request]

pull_requests = github_toolkit.retrieve_pull_requests(
days=7, state='closed', max_prs=3
)
# Assert the returned issue list
expected_pull_request = GithubPullRequest(
title="Test PR",
body="This is a test issue",
diffs=[
GithubPullRequestDiff(
filename="path/to/file", patch="This is the diff of the file"
)
],
)
assert pull_requests == [
str(expected_pull_request)
], f"Expected {expected_pull_request}, but got {pull_requests}"


def test_github_issue():
# Create a GithubIssue object
issue = GithubIssue(
Expand All @@ -183,7 +237,7 @@ def test_github_issue():
assert issue.file_content == "This is the content of the file"

# Test the summary method
summary = issue.summary()
summary = str(issue)
expected_summary = (
f"Title: {issue.title}\n"
f"Body: {issue.body}\n"
Expand Down

0 comments on commit 26100a9

Please sign in to comment.