From 5d85b0a5a125193708f9508d5598c25da44a626b Mon Sep 17 00:00:00 2001 From: tonytrg Date: Mon, 8 Dec 2025 11:05:06 +0100 Subject: [PATCH] adding review comments grouped as threads --- README.md | 2 +- .../__toolsnaps__/pull_request_read.snap | 2 +- pkg/github/pullrequests.go | 152 ++++++-- pkg/github/pullrequests_test.go | 344 ++++++++++++------ pkg/github/tools.go | 2 +- 5 files changed, 358 insertions(+), 144 deletions(-) diff --git a/README.md b/README.md index c7243033b..54dff7a03 100644 --- a/README.md +++ b/README.md @@ -991,7 +991,7 @@ Possible options: 2. get_diff - Get the diff of a pull request. 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks. 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned. - 5. get_review_comments - Get the review comments on a pull request. They are comments made on a portion of the unified diff during a pull request review. Use with pagination parameters to control the number of results returned. + 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Use cursor-based pagination (perPage, after) to control results. 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned. (string, required) diff --git a/pkg/github/__toolsnaps__/pull_request_read.snap b/pkg/github/__toolsnaps__/pull_request_read.snap index 434fba348..69b1bd901 100644 --- a/pkg/github/__toolsnaps__/pull_request_read.snap +++ b/pkg/github/__toolsnaps__/pull_request_read.snap @@ -15,7 +15,7 @@ "properties": { "method": { "type": "string", - "description": "Action to specify what pull request data needs to be retrieved from GitHub. \nPossible options: \n 1. get - Get details of a specific pull request.\n 2. get_diff - Get the diff of a pull request.\n 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks.\n 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned.\n 5. get_review_comments - Get the review comments on a pull request. They are comments made on a portion of the unified diff during a pull request review. Use with pagination parameters to control the number of results returned.\n 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method.\n 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned.\n", + "description": "Action to specify what pull request data needs to be retrieved from GitHub. \nPossible options: \n 1. get - Get details of a specific pull request.\n 2. get_diff - Get the diff of a pull request.\n 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks.\n 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned.\n 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Use cursor-based pagination (perPage, after) to control results.\n 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method.\n 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned.\n", "enum": [ "get", "get_diff", diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 661384529..b9ecdc1ee 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -21,7 +21,7 @@ import ( ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -33,7 +33,7 @@ Possible options: 2. get_diff - Get the diff of a pull request. 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks. 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned. - 5. get_review_comments - Get the review comments on a pull request. They are comments made on a portion of the unified diff during a pull request review. Use with pagination parameters to control the number of results returned. + 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Use cursor-based pagination (perPage, after) to control results. 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned. `, @@ -107,7 +107,11 @@ Possible options: result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) return result, nil, err case "get_review_comments": - result, err := GetPullRequestReviewComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) + gqlClient, err := getGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } + result, err := GetPullRequestReviewComments(ctx, gqlClient, cache, owner, repo, pullNumber, pagination, flags) return result, nil, err case "get_reviews": result, err := GetPullRequestReviews(ctx, client, cache, owner, repo, pullNumber, flags) @@ -282,54 +286,130 @@ func GetPullRequestFiles(ctx context.Context, client *github.Client, owner, repo return utils.NewToolResultText(string(r)), nil } -func GetPullRequestReviewComments(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, pagination PaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { - opts := &github.PullRequestListCommentsOptions{ - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, +// GraphQL types for review threads query +type reviewThreadsQuery struct { + Repository struct { + PullRequest struct { + ReviewThreads struct { + Nodes []reviewThreadNode + PageInfo pageInfoFragment + TotalCount githubv4.Int + } `graphql:"reviewThreads(first: $first, after: $after)"` + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` +} + +type reviewThreadNode struct { + ID githubv4.ID + IsResolved githubv4.Boolean + IsOutdated githubv4.Boolean + IsCollapsed githubv4.Boolean + Comments struct { + Nodes []reviewCommentNode + TotalCount githubv4.Int + } `graphql:"comments(first: $commentsPerThread)"` +} + +type reviewCommentNode struct { + ID githubv4.ID + Body githubv4.String + Path githubv4.String + Line *githubv4.Int + Author struct { + Login githubv4.String } + CreatedAt githubv4.DateTime + UpdatedAt githubv4.DateTime + URL githubv4.URI +} + +type pageInfoFragment struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String +} - comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) +func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, pagination PaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { + // Convert pagination parameters to GraphQL format + gqlParams, err := pagination.ToGraphQLParams() if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get pull request review comments", - resp, - err, - ), nil + return utils.NewToolResultError(fmt.Sprintf("invalid pagination parameters: %v", err)), nil } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return utils.NewToolResultError(fmt.Sprintf("failed to get pull request review comments: %s", string(body))), nil + // Default to 100 threads if not specified, max is 100 for GraphQL + perPage := int32(100) + if gqlParams.First != nil && *gqlParams.First > 0 { + perPage = *gqlParams.First } + // Build variables for GraphQL query + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(int32(pullNumber)), //nolint:gosec // pullNumber is controlled by user input validation + "first": githubv4.Int(perPage), + "commentsPerThread": githubv4.Int(50), // Max 50 comments per thread + } + + // Add cursor if provided + if gqlParams.After != nil && *gqlParams.After != "" { + vars["after"] = githubv4.String(*gqlParams.After) + } else { + vars["after"] = (*githubv4.String)(nil) + } + + // Execute GraphQL query + var query reviewThreadsQuery + if err := gqlClient.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get pull request review threads", + err, + ), nil + } + + // Lockdown mode filtering if ff.LockdownMode { if cache == nil { return nil, fmt.Errorf("lockdown cache is not configured") } - filteredComments := make([]*github.PullRequestComment, 0, len(comments)) - for _, comment := range comments { - user := comment.GetUser() - if user == nil { - continue - } - isSafeContent, err := cache.IsSafeContent(ctx, user.GetLogin(), owner, repo) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil - } - if isSafeContent { - filteredComments = append(filteredComments, comment) + + // Iterate through threads and filter comments + for i := range query.Repository.PullRequest.ReviewThreads.Nodes { + thread := &query.Repository.PullRequest.ReviewThreads.Nodes[i] + filteredComments := make([]reviewCommentNode, 0, len(thread.Comments.Nodes)) + + for _, comment := range thread.Comments.Nodes { + login := string(comment.Author.Login) + if login != "" { + isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check lockdown mode: %w", err) + } + if isSafeContent { + filteredComments = append(filteredComments, comment) + } + } } + + thread.Comments.Nodes = filteredComments + thread.Comments.TotalCount = githubv4.Int(int32(len(filteredComments))) //nolint:gosec // comment count is bounded by API limits } - comments = filteredComments } - r, err := json.Marshal(comments) + // Build response with review threads and pagination info + response := map[string]interface{}{ + "reviewThreads": query.Repository.PullRequest.ReviewThreads.Nodes, + "pageInfo": map[string]interface{}{ + "hasNextPage": query.Repository.PullRequest.ReviewThreads.PageInfo.HasNextPage, + "hasPreviousPage": query.Repository.PullRequest.ReviewThreads.PageInfo.HasPreviousPage, + "startCursor": string(query.Repository.PullRequest.ReviewThreads.PageInfo.StartCursor), + "endCursor": string(query.Repository.PullRequest.ReviewThreads.PageInfo.EndCursor), + }, + "totalCount": int(query.Repository.PullRequest.ReviewThreads.TotalCount), + } + + r, err := json.Marshal(response) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 94313d4e3..80f859297 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/internal/githubv4mock" "github.com/github/github-mcp-server/internal/toolsnaps" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" @@ -22,7 +23,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -104,7 +105,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1142,7 +1143,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1246,7 +1247,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1287,7 +1288,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1415,7 +1416,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1578,7 +1579,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1590,52 +1591,80 @@ func Test_GetPullRequestComments(t *testing.T) { assert.Contains(t, schema.Properties, "pullNumber") assert.ElementsMatch(t, schema.Required, []string{"method", "owner", "repo", "pullNumber"}) - // Setup mock PR comments for success case - mockComments := []*github.PullRequestComment{ - { - ID: github.Ptr(int64(101)), - Body: github.Ptr("This looks good"), - HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#discussion_r101"), - User: &github.User{ - Login: github.Ptr("reviewer1"), - }, - Path: github.Ptr("file1.go"), - Position: github.Ptr(5), - CommitID: github.Ptr("abcdef123456"), - CreatedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, - UpdatedAt: &github.Timestamp{Time: time.Now().Add(-24 * time.Hour)}, - }, - { - ID: github.Ptr(int64(102)), - Body: github.Ptr("Please fix this"), - HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#discussion_r102"), - User: &github.User{ - Login: github.Ptr("reviewer2"), - }, - Path: github.Ptr("file2.go"), - Position: github.Ptr(10), - CommitID: github.Ptr("abcdef123456"), - CreatedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, - UpdatedAt: &github.Timestamp{Time: time.Now().Add(-12 * time.Hour)}, - }, - } - tests := []struct { - name string - mockedClient *http.Client - gqlHTTPClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedComments []*github.PullRequestComment - expectedErrMsg string - lockdownEnabled bool + name string + gqlHTTPClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedErrMsg string + lockdownEnabled bool + validateResult func(t *testing.T, textContent string) }{ { - name: "successful comments fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, - mockComments, + name: "successful review threads fetch", + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + reviewThreadsQuery{}, + map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + "first": githubv4.Int(30), + "commentsPerThread": githubv4.Int(50), + "after": (*githubv4.String)(nil), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "reviewThreads": map[string]any{ + "nodes": []map[string]any{ + { + "id": "RT_kwDOA0xdyM4AX1Yz", + "isResolved": false, + "isOutdated": false, + "isCollapsed": false, + "comments": map[string]any{ + "totalCount": 2, + "nodes": []map[string]any{ + { + "id": "PRRC_kwDOA0xdyM4AX1Y0", + "body": "This looks good", + "path": "file1.go", + "line": 5, + "author": map[string]any{ + "login": "reviewer1", + }, + "createdAt": "2024-01-01T12:00:00Z", + "updatedAt": "2024-01-01T12:00:00Z", + "url": "https://github.com/owner/repo/pull/42#discussion_r101", + }, + { + "id": "PRRC_kwDOA0xdyM4AX1Y1", + "body": "Please fix this", + "path": "file1.go", + "line": 10, + "author": map[string]any{ + "login": "reviewer2", + }, + "createdAt": "2024-01-01T13:00:00Z", + "updatedAt": "2024-01-01T13:00:00Z", + "url": "https://github.com/owner/repo/pull/42#discussion_r102", + }, + }, + }, + }, + }, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "cursor1", + "endCursor": "cursor2", + }, + "totalCount": 1, + }, + }, + }, + }), ), ), requestArgs: map[string]interface{}{ @@ -1644,18 +1673,63 @@ func Test_GetPullRequestComments(t *testing.T) { "repo": "repo", "pullNumber": float64(42), }, - expectError: false, - expectedComments: mockComments, + expectError: false, + validateResult: func(t *testing.T, textContent string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(textContent), &result) + require.NoError(t, err) + + // Validate response structure + assert.Contains(t, result, "reviewThreads") + assert.Contains(t, result, "pageInfo") + assert.Contains(t, result, "totalCount") + + // Validate review threads + threads := result["reviewThreads"].([]interface{}) + assert.Len(t, threads, 1) + + thread := threads[0].(map[string]interface{}) + assert.Equal(t, "RT_kwDOA0xdyM4AX1Yz", thread["ID"]) + assert.Equal(t, false, thread["IsResolved"]) + assert.Equal(t, false, thread["IsOutdated"]) + assert.Equal(t, false, thread["IsCollapsed"]) + + // Validate comments within thread + comments := thread["Comments"].(map[string]interface{}) + commentNodes := comments["Nodes"].([]interface{}) + assert.Len(t, commentNodes, 2) + + // Validate first comment + comment1 := commentNodes[0].(map[string]interface{}) + assert.Equal(t, "PRRC_kwDOA0xdyM4AX1Y0", comment1["ID"]) + assert.Equal(t, "This looks good", comment1["Body"]) + assert.Equal(t, "file1.go", comment1["Path"]) + + // Validate pagination info + pageInfo := result["pageInfo"].(map[string]interface{}) + assert.Equal(t, false, pageInfo["hasNextPage"]) + assert.Equal(t, false, pageInfo["hasPreviousPage"]) + assert.Equal(t, "cursor1", pageInfo["startCursor"]) + assert.Equal(t, "cursor2", pageInfo["endCursor"]) + + // Validate total count + assert.Equal(t, float64(1), result["totalCount"]) + }, }, { - name: "comments fetch fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), + name: "review threads fetch fails", + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + reviewThreadsQuery{}, + map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(999), + "first": githubv4.Int(30), + "commentsPerThread": githubv4.Int(50), + "after": (*githubv4.String)(nil), + }, + githubv4mock.ErrorResponse("Could not resolve to a PullRequest with the number of 999."), ), ), requestArgs: map[string]interface{}{ @@ -1665,59 +1739,129 @@ func Test_GetPullRequestComments(t *testing.T) { "pullNumber": float64(999), }, expectError: true, - expectedErrMsg: "failed to get pull request review comments", + expectedErrMsg: "failed to get pull request review threads", }, { name: "lockdown enabled filters review comments without push access", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, - []*github.PullRequestComment{ - { - ID: github.Ptr(int64(2010)), - Body: github.Ptr("Maintainer review comment"), - User: &github.User{Login: github.Ptr("maintainer")}, - }, - { - ID: github.Ptr(int64(2011)), - Body: github.Ptr("External review comment"), - User: &github.User{Login: github.Ptr("testuser")}, - }, + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + reviewThreadsQuery{}, + map[string]interface{}{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + "first": githubv4.Int(30), + "commentsPerThread": githubv4.Int(50), + "after": (*githubv4.String)(nil), }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "reviewThreads": map[string]any{ + "nodes": []map[string]any{ + { + "id": "RT_kwDOA0xdyM4AX1Yz", + "isResolved": false, + "isOutdated": false, + "isCollapsed": false, + "comments": map[string]any{ + "totalCount": 2, + "nodes": []map[string]any{ + { + "id": "PRRC_kwDOA0xdyM4AX1Y0", + "body": "Maintainer review comment", + "path": "file1.go", + "line": 5, + "author": map[string]any{ + "login": "maintainer", + }, + "createdAt": "2024-01-01T12:00:00Z", + "updatedAt": "2024-01-01T12:00:00Z", + "url": "https://github.com/owner/repo/pull/42#discussion_r2010", + }, + { + "id": "PRRC_kwDOA0xdyM4AX1Y1", + "body": "External review comment", + "path": "file1.go", + "line": 10, + "author": map[string]any{ + "login": "testuser", + }, + "createdAt": "2024-01-01T13:00:00Z", + "updatedAt": "2024-01-01T13:00:00Z", + "url": "https://github.com/owner/repo/pull/42#discussion_r2011", + }, + }, + }, + }, + }, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "cursor1", + "endCursor": "cursor2", + }, + "totalCount": 1, + }, + }, + }, + }), ), ), - gqlHTTPClient: newRepoAccessHTTPClient(), requestArgs: map[string]interface{}{ "method": "get_review_comments", "owner": "owner", "repo": "repo", "pullNumber": float64(42), }, - expectError: false, - expectedComments: []*github.PullRequestComment{ - { - ID: github.Ptr(int64(2010)), - Body: github.Ptr("Maintainer review comment"), - User: &github.User{Login: github.Ptr("maintainer")}, - }, - }, + expectError: false, lockdownEnabled: true, + validateResult: func(t *testing.T, textContent string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(textContent), &result) + require.NoError(t, err) + + // Validate that only maintainer comment is returned + threads := result["reviewThreads"].([]interface{}) + assert.Len(t, threads, 1) + + thread := threads[0].(map[string]interface{}) + comments := thread["Comments"].(map[string]interface{}) + + // Should only have 1 comment (maintainer) after filtering + assert.Equal(t, float64(1), comments["TotalCount"]) + + commentNodes := comments["Nodes"].([]interface{}) + assert.Len(t, commentNodes, 1) + + comment := commentNodes[0].(map[string]interface{}) + author := comment["Author"].(map[string]interface{}) + assert.Equal(t, "maintainer", author["Login"]) + assert.Equal(t, "Maintainer review comment", comment["Body"]) + }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Setup client with mock - client := github.NewClient(tc.mockedClient) + // Setup GraphQL client with mock var gqlClient *githubv4.Client if tc.gqlHTTPClient != nil { gqlClient = githubv4.NewClient(tc.gqlHTTPClient) } else { gqlClient = githubv4.NewClient(nil) } - cache := stubRepoAccessCache(gqlClient, 5*time.Minute) + + // Setup cache for lockdown mode + var cache *lockdown.RepoAccessCache + if tc.lockdownEnabled { + cache = stubRepoAccessCache(githubv4.NewClient(newRepoAccessHTTPClient()), 5*time.Minute) + } else { + cache = stubRepoAccessCache(gqlClient, 5*time.Minute) + } + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := PullRequestRead(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + _, handler := PullRequestRead(stubGetClientFn(github.NewClient(nil)), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1740,19 +1884,9 @@ func Test_GetPullRequestComments(t *testing.T) { // Parse the result and get the text content if no error textContent := getTextResult(t, result) - // Unmarshal and verify the result - var returnedComments []*github.PullRequestComment - err = json.Unmarshal([]byte(textContent.Text), &returnedComments) - require.NoError(t, err) - assert.Len(t, returnedComments, len(tc.expectedComments)) - for i, comment := range returnedComments { - require.NotNil(t, tc.expectedComments[i].User) - require.NotNil(t, comment.User) - assert.Equal(t, tc.expectedComments[i].GetID(), comment.GetID()) - assert.Equal(t, tc.expectedComments[i].GetBody(), comment.GetBody()) - assert.Equal(t, tc.expectedComments[i].GetUser().GetLogin(), comment.GetUser().GetLogin()) - assert.Equal(t, tc.expectedComments[i].GetPath(), comment.GetPath()) - assert.Equal(t, tc.expectedComments[i].GetHTMLURL(), comment.GetHTMLURL()) + // Use custom validation if provided + if tc.validateResult != nil { + tc.validateResult(t, textContent.Text) } }) } @@ -1761,7 +1895,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1899,7 +2033,7 @@ func Test_GetPullRequestReviews(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := PullRequestRead(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), cache, translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) @@ -2974,7 +3108,7 @@ func TestGetPullRequestDiff(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -3033,7 +3167,7 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index d37af98b8..57529fb04 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -225,7 +225,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, cache, t, flags)), + toolsets.NewServerTool(PullRequestRead(getClient, getGQLClient, cache, t, flags)), toolsets.NewServerTool(ListPullRequests(getClient, t)), toolsets.NewServerTool(SearchPullRequests(getClient, t)), ).