Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 0 additions & 280 deletions api/queries_pr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,12 @@ import (
"context"
"fmt"
"net/http"
"strings"
"time"

"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/set"
"github.com/shurcooL/githubv4"
"golang.org/x/sync/errgroup"
)

type PullRequestsPayload struct {
ViewerCreated PullRequestAndTotalCount
ReviewRequested PullRequestAndTotalCount
CurrentPR *PullRequest
DefaultBranch string
}

type PullRequestAndTotalCount struct {
TotalCount int
PullRequests []PullRequest
Expand Down Expand Up @@ -269,275 +258,6 @@ func (pr *PullRequest) DisplayableReviews() PullRequestReviews {
return PullRequestReviews{Nodes: published, TotalCount: len(published)}
}

type pullRequestFeature struct {
HasReviewDecision bool
HasStatusCheckRollup bool
HasBranchProtectionRule bool
}

func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prFeatures pullRequestFeature, err error) {
if !ghinstance.IsEnterprise(hostname) {
prFeatures.HasReviewDecision = true
prFeatures.HasStatusCheckRollup = true
prFeatures.HasBranchProtectionRule = true
return
}

var featureDetection struct {
PullRequest struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"PullRequest: __type(name: \"PullRequest\")"`
Commit struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Commit: __type(name: \"Commit\")"`
}

// needs to be a separate query because the backend only supports 2 `__type` expressions in one query
var featureDetection2 struct {
Ref struct {
Fields []struct {
Name string
} `graphql:"fields(includeDeprecated: true)"`
} `graphql:"Ref: __type(name: \"Ref\")"`
}

v4 := graphQLClient(httpClient, hostname)

g := new(errgroup.Group)
g.Go(func() error {
return v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
})
g.Go(func() error {
return v4.QueryNamed(context.Background(), "PullRequest_fields2", &featureDetection2, nil)
})

err = g.Wait()
if err != nil {
return
}

for _, field := range featureDetection.PullRequest.Fields {
switch field.Name {
case "reviewDecision":
prFeatures.HasReviewDecision = true
}
}
for _, field := range featureDetection.Commit.Fields {
switch field.Name {
case "statusCheckRollup":
prFeatures.HasStatusCheckRollup = true
}
}
for _, field := range featureDetection2.Ref.Fields {
switch field.Name {
case "branchProtectionRule":
prFeatures.HasBranchProtectionRule = true
}
}
return
}

type StatusOptions struct {
CurrentPR int
HeadRef string
Username string
Fields []string
}

func PullRequestStatus(client *Client, repo ghrepo.Interface, options StatusOptions) (*PullRequestsPayload, error) {
type edges struct {
TotalCount int
Edges []struct {
Node PullRequest
}
}

type response struct {
Repository struct {
DefaultBranchRef struct {
Name string
}
PullRequests edges
PullRequest *PullRequest
}
ViewerCreated edges
ReviewRequested edges
}

var fragments string
if len(options.Fields) > 0 {
fields := set.NewStringSet()
fields.AddValues(options.Fields)
// these are always necessary to find the PR for the current branch
fields.AddValues([]string{"isCrossRepository", "headRepositoryOwner", "headRefName"})
gr := PullRequestGraphQL(fields.ToSlice())
fragments = fmt.Sprintf("fragment pr on PullRequest{%s}fragment prWithReviews on PullRequest{...pr}", gr)
} else {
var err error
fragments, err = pullRequestFragment(client.http, repo.RepoHost())
if err != nil {
return nil, err
}
}

queryPrefix := `
query PullRequestStatus($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequests(headRefName: $headRefName, first: $per_page, orderBy: { field: CREATED_AT, direction: DESC }) {
totalCount
edges {
node {
...prWithReviews
}
}
}
}
`
if options.CurrentPR > 0 {
queryPrefix = `
query PullRequestStatus($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
repository(owner: $owner, name: $repo) {
defaultBranchRef {
name
}
pullRequest(number: $number) {
...prWithReviews
baseRef {
branchProtectionRule {
requiredApprovingReviewCount
}
}
}
}
`
}

query := fragments + queryPrefix + `
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...prWithReviews
}
}
}
reviewRequested: search(query: $reviewerQuery, type: ISSUE, first: $per_page) {
totalCount: issueCount
edges {
node {
...pr
}
}
}
}
`

currentUsername := options.Username
if currentUsername == "@me" && ghinstance.IsEnterprise(repo.RepoHost()) {
var err error
currentUsername, err = CurrentLoginName(client, repo.RepoHost())
if err != nil {
return nil, err
}
}

viewerQuery := fmt.Sprintf("repo:%s state:open is:pr author:%s", ghrepo.FullName(repo), currentUsername)
reviewerQuery := fmt.Sprintf("repo:%s state:open review-requested:%s", ghrepo.FullName(repo), currentUsername)

currentPRHeadRef := options.HeadRef
branchWithoutOwner := currentPRHeadRef
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
branchWithoutOwner = currentPRHeadRef[idx+1:]
}

variables := map[string]interface{}{
"viewerQuery": viewerQuery,
"reviewerQuery": reviewerQuery,
"owner": repo.RepoOwner(),
"repo": repo.RepoName(),
"headRefName": branchWithoutOwner,
"number": options.CurrentPR,
}

var resp response
err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
if err != nil {
return nil, err
}

var viewerCreated []PullRequest
for _, edge := range resp.ViewerCreated.Edges {
viewerCreated = append(viewerCreated, edge.Node)
}

var reviewRequested []PullRequest
for _, edge := range resp.ReviewRequested.Edges {
reviewRequested = append(reviewRequested, edge.Node)
}

var currentPR = resp.Repository.PullRequest
if currentPR == nil {
for _, edge := range resp.Repository.PullRequests.Edges {
if edge.Node.HeadLabel() == currentPRHeadRef {
currentPR = &edge.Node
break // Take the most recent PR for the current branch
}
}
}

payload := PullRequestsPayload{
ViewerCreated: PullRequestAndTotalCount{
PullRequests: viewerCreated,
TotalCount: resp.ViewerCreated.TotalCount,
},
ReviewRequested: PullRequestAndTotalCount{
PullRequests: reviewRequested,
TotalCount: resp.ReviewRequested.TotalCount,
},
CurrentPR: currentPR,
DefaultBranch: resp.Repository.DefaultBranchRef.Name,
}

return &payload, nil
}

func pullRequestFragment(httpClient *http.Client, hostname string) (string, error) {
cachedClient := NewCachedClient(httpClient, time.Hour*24)
prFeatures, err := determinePullRequestFeatures(cachedClient, hostname)
if err != nil {
return "", err
}

fields := []string{
"number", "title", "state", "url", "isDraft", "isCrossRepository",
"headRefName", "headRepositoryOwner", "mergeStateStatus",
}
if prFeatures.HasStatusCheckRollup {
fields = append(fields, "statusCheckRollup")
}
if prFeatures.HasBranchProtectionRule {
fields = append(fields, "requiresStrictStatusChecks")
}

var reviewFields []string
if prFeatures.HasReviewDecision {
reviewFields = append(reviewFields, "reviewDecision", "latestReviews")
}

fragments := fmt.Sprintf(`
fragment pr on PullRequest {%s}
fragment prWithReviews on PullRequest {...pr,%s}
`, PullRequestGraphQL(fields), PullRequestGraphQL(reviewFields))
return fragments, nil
}

// CreatePullRequest creates a pull request in a GitHub repository
func CreatePullRequest(client *Client, repo *Repository, params map[string]interface{}) (*PullRequest, error) {
query := `
Expand Down
Loading