Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions pkg/github/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,15 @@ func expectRequestBody(t *testing.T, expectedRequestBody any) *partialMock {
type partialMock struct {
t *testing.T

expectedPath string
expectedQueryParams map[string]string
expectedRequestBody any
expectedPath string
expectedQueryParams map[string]string
expectedRequestBody any
expectedHeaderContains map[string]string
}

func (p *partialMock) withHeaders(headers map[string]string) *partialMock {
p.expectedHeaderContains = headers
return p
}

func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc {
Expand All @@ -247,6 +253,12 @@ func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc
require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody)
}

if p.expectedHeaderContains != nil {
for k, v := range p.expectedHeaderContains {
require.Contains(p.t, r.Header.Get(k), v, "expected header %q to contain %q", k, v)
}
}

responseHandler(w, r)
}
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/github/minimal_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ type MinimalSearchRepositoriesResult struct {
Items []MinimalRepository `json:"items"`
}

// MinimalCodeSearchResult is the trimmed output type for code search results.
type MinimalCodeSearchResult struct {
TotalCount int `json:"total_count"`
IncompleteResults bool `json:"incomplete_results"`
Items []MinimalCodeResult `json:"items"`
}

// MinimalCodeResult is the trimmed output type for a single code search hit.
type MinimalCodeResult struct {
Name string `json:"name"`
Path string `json:"path"`
SHA string `json:"sha"`
Repository string `json:"repository"`
TextMatches []*github.TextMatch `json:"text_matches,omitempty"`
}

// MinimalCommitAuthor represents commit author information.
type MinimalCommitAuthor struct {
Name string `json:"name,omitempty"`
Expand Down
27 changes: 24 additions & 3 deletions pkg/github/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,9 @@ func SearchCode(t translations.TranslationHelperFunc) inventory.ServerTool {
}

opts := &github.SearchOptions{
Sort: sort,
Order: order,
Sort: sort,
Order: order,
TextMatch: true,
Comment thread
SamMorrowDrums marked this conversation as resolved.
ListOptions: github.ListOptions{
PerPage: pagination.PerPage,
Page: pagination.Page,
Expand Down Expand Up @@ -301,7 +302,27 @@ func SearchCode(t translations.TranslationHelperFunc) inventory.ServerTool {
return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search code", resp, body), nil, nil
}

r, err := json.Marshal(result)
minimalItems := make([]MinimalCodeResult, 0, len(result.CodeResults))
for _, code := range result.CodeResults {
item := MinimalCodeResult{
Name: code.GetName(),
Path: code.GetPath(),
SHA: code.GetSHA(),
TextMatches: code.TextMatches,
}
if code.Repository != nil {
item.Repository = code.Repository.GetFullName()
}
minimalItems = append(minimalItems, item)
}

minimalResult := &MinimalCodeSearchResult{
TotalCount: result.GetTotal(),
IncompleteResults: result.GetIncompleteResults(),
Items: minimalItems,
}

r, err := json.Marshal(minimalResult)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil
}
Expand Down
67 changes: 43 additions & 24 deletions pkg/github/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,22 +430,35 @@ func Test_SearchCode(t *testing.T) {
IncompleteResults: github.Ptr(false),
CodeResults: []*github.CodeResult{
{
Name: github.Ptr("file1.go"),
Path: github.Ptr("path/to/file1.go"),
SHA: github.Ptr("abc123def456"),
HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file1.go"),
Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")},
Name: github.Ptr("file1.go"),
Path: github.Ptr("path/to/file1.go"),
SHA: github.Ptr("abc123def456"),
Repository: &github.Repository{
Name: github.Ptr("repo"),
FullName: github.Ptr("owner/repo"),
},
TextMatches: []*github.TextMatch{
{
Fragment: github.Ptr("func main() { fmt.Println(\"hello\") }"),
},
},
},
{
Name: github.Ptr("file2.go"),
Path: github.Ptr("path/to/file2.go"),
SHA: github.Ptr("def456abc123"),
HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file2.go"),
Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")},
Name: github.Ptr("file2.go"),
Path: github.Ptr("path/to/file2.go"),
SHA: github.Ptr("def456abc123"),
Repository: &github.Repository{
Name: github.Ptr("repo"),
FullName: github.Ptr("owner/repo"),
},
},
},
}

textMatchAcceptHeader := map[string]string{
"Accept": "text-match",
}

tests := []struct {
name string
mockedClient *http.Client
Expand All @@ -463,7 +476,7 @@ func Test_SearchCode(t *testing.T) {
"order": "desc",
"page": "1",
"per_page": "30",
}).andThen(
}).withHeaders(textMatchAcceptHeader).andThen(
mockResponse(t, http.StatusOK, mockSearchResult),
),
}),
Expand All @@ -484,7 +497,7 @@ func Test_SearchCode(t *testing.T) {
"q": "fmt.Println language:go",
"page": "1",
"per_page": "30",
}).andThen(
}).withHeaders(textMatchAcceptHeader).andThen(
mockResponse(t, http.StatusOK, mockSearchResult),
),
}),
Expand Down Expand Up @@ -537,22 +550,28 @@ func Test_SearchCode(t *testing.T) {
require.NoError(t, err)
require.False(t, result.IsError)

// Parse the result and get the text content if no error
textContent := getTextResult(t, result)

// Unmarshal and verify the result
var returnedResult github.CodeSearchResult
var returnedResult MinimalCodeSearchResult
err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
require.NoError(t, err)
assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total)
assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults)
assert.Len(t, returnedResult.CodeResults, len(tc.expectedResult.CodeResults))
for i, code := range returnedResult.CodeResults {
assert.Equal(t, *tc.expectedResult.CodeResults[i].Name, *code.Name)
assert.Equal(t, *tc.expectedResult.CodeResults[i].Path, *code.Path)
assert.Equal(t, *tc.expectedResult.CodeResults[i].SHA, *code.SHA)
assert.Equal(t, *tc.expectedResult.CodeResults[i].HTMLURL, *code.HTMLURL)
assert.Equal(t, *tc.expectedResult.CodeResults[i].Repository.FullName, *code.Repository.FullName)
assert.Equal(t, *tc.expectedResult.Total, returnedResult.TotalCount)
assert.Equal(t, *tc.expectedResult.IncompleteResults, returnedResult.IncompleteResults)
assert.Len(t, returnedResult.Items, len(tc.expectedResult.CodeResults))
for i, code := range returnedResult.Items {
assert.Equal(t, tc.expectedResult.CodeResults[i].GetName(), code.Name)
assert.Equal(t, tc.expectedResult.CodeResults[i].GetPath(), code.Path)
assert.Equal(t, tc.expectedResult.CodeResults[i].GetSHA(), code.SHA)
assert.Equal(t, tc.expectedResult.CodeResults[i].Repository.GetFullName(), code.Repository)
}

// Verify text matches are included when present
if len(tc.expectedResult.CodeResults[0].TextMatches) > 0 {
require.NotEmpty(t, returnedResult.Items[0].TextMatches)
assert.Equal(t,
tc.expectedResult.CodeResults[0].TextMatches[0].GetFragment(),
returnedResult.Items[0].TextMatches[0].GetFragment(),
)
}
})
}
Expand Down
Loading