diff --git a/cmd/utils_test.go b/cmd/utils_test.go index 2077c5d..0d0005f 100644 --- a/cmd/utils_test.go +++ b/cmd/utils_test.go @@ -249,6 +249,8 @@ func TestParsePRURL(t *testing.T) { {"standard URL", "https://github.com/owner/repo/pull/42", 42, true}, {"with trailing slash", "https://github.com/owner/repo/pull/42/", 42, true}, {"with files tab", "https://github.com/owner/repo/pull/42/files", 42, true}, + {"GHES URL", "https://ghes.example.com/owner/repo/pull/99", 99, true}, + {"GHES URL with trailing slash", "https://ghes.example.com/owner/repo/pull/7/", 7, true}, {"not a PR URL", "https://github.com/owner/repo/issues/42", 0, false}, {"plain number", "42", 0, false}, {"branch name", "feat-1", 0, false}, diff --git a/cmd/view.go b/cmd/view.go index 1aa22b3..11116ac 100644 --- a/cmd/view.go +++ b/cmd/view.go @@ -12,6 +12,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/github/gh-stack/internal/config" "github.com/github/gh-stack/internal/git" + ghapi "github.com/github/gh-stack/internal/github" "github.com/github/gh-stack/internal/stack" "github.com/github/gh-stack/internal/tui/stackview" "github.com/spf13/cobra" @@ -65,8 +66,9 @@ func runView(cfg *config.Config, opts *viewOptions) error { } func viewShort(cfg *config.Config, s *stack.Stack, currentBranch string) error { - var repoOwner, repoName string + var repoHost, repoOwner, repoName string if repo, err := cfg.Repo(); err == nil { + repoHost = repo.Host repoOwner = repo.Owner repoName = repo.Name } @@ -81,7 +83,7 @@ func viewShort(cfg *config.Config, s *stack.Stack, currentBranch string) error { } indicator := branchStatusIndicator(cfg, s, b) - prSuffix := shortPRSuffix(cfg, b, repoOwner, repoName) + prSuffix := shortPRSuffix(cfg, b, repoHost, repoOwner, repoName) if b.Branch == currentBranch { cfg.Outf("ยป %s%s%s %s\n", cfg.ColorBold(b.Branch), indicator, prSuffix, cfg.ColorCyan("(current)")) } else if merged { @@ -187,13 +189,13 @@ func viewJSON(cfg *config.Config, s *stack.Stack, currentBranch string) error { return err } -func shortPRSuffix(cfg *config.Config, b stack.BranchRef, owner, repo string) string { +func shortPRSuffix(cfg *config.Config, b stack.BranchRef, host, owner, repo string) string { if b.PullRequest == nil || b.PullRequest.Number == 0 { return "" } url := b.PullRequest.URL if url == "" && owner != "" && repo != "" { - url = fmt.Sprintf("https://github.com/%s/%s/pull/%d", owner, repo, b.PullRequest.Number) + url = ghapi.PRURL(host, owner, repo, b.PullRequest.Number) } prNum := cfg.PRLink(b.PullRequest.Number, url) colorFn := cfg.ColorSuccess // green for open @@ -251,9 +253,10 @@ func viewFullTUI(cfg *config.Config, s *stack.Stack, currentBranch string) error func viewFullStatic(cfg *config.Config, s *stack.Stack, currentBranch string) error { client, clientErr := cfg.GitHubClient() - var repoOwner, repoName string + var repoHost, repoOwner, repoName string repo, repoErr := cfg.Repo() if repoErr == nil { + repoHost = repo.Host repoOwner = repo.Owner repoName = repo.Name } @@ -285,7 +288,7 @@ func viewFullStatic(cfg *config.Config, s *stack.Stack, currentBranch string) er } else if clientErr == nil && repoErr == nil { pr, err := client.FindPRForBranch(b.Branch) if err == nil && pr != nil { - prInfo = fmt.Sprintf(" https://github.com/%s/%s/pull/%d", repoOwner, repoName, pr.Number) + prInfo = " " + ghapi.PRURL(repoHost, repoOwner, repoName, pr.Number) } } diff --git a/internal/config/config.go b/internal/config/config.go index 27a0f4a..f4ea5c7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -121,5 +121,5 @@ func (c *Config) GitHubClient() (ghapi.ClientOps, error) { if err != nil { return nil, fmt.Errorf("determining repository: %w", err) } - return ghapi.NewClient(repo.Owner, repo.Name) + return ghapi.NewClient(repo.Host, repo.Owner, repo.Name) } diff --git a/internal/github/github.go b/internal/github/github.go index 6cd8861..73b6126 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -24,30 +24,47 @@ type PullRequest struct { type Client struct { gql *api.GraphQLClient rest *api.RESTClient + host string owner string repo string slug string } // NewClient creates a new GitHub API client for the given repository. -func NewClient(owner, repo string) (*Client, error) { - gql, err := api.DefaultGraphQLClient() +// The host parameter specifies the GitHub hostname (e.g. "github.com" or a +// GHES hostname like "github.mycompany.com"). If empty, it defaults to +// "github.com". +func NewClient(host, owner, repo string) (*Client, error) { + if host == "" { + host = "github.com" + } + opts := api.ClientOptions{Host: host} + gql, err := api.NewGraphQLClient(opts) if err != nil { return nil, fmt.Errorf("creating GraphQL client: %w", err) } - rest, err := api.DefaultRESTClient() + rest, err := api.NewRESTClient(opts) if err != nil { return nil, fmt.Errorf("creating REST client: %w", err) } return &Client{ gql: gql, rest: rest, + host: host, owner: owner, repo: repo, slug: owner + "/" + repo, }, nil } +// PRURL constructs the web URL for a pull request on the given host. +func PRURL(host, owner, repo string, number int) string { + if host == "" { + host = "github.com" + } + return fmt.Sprintf("https://%s/%s/%s/pull/%d", host, owner, repo, number) +} + // FindPRForBranch finds an open PR by head branch name. func (c *Client) FindPRForBranch(branch string) (*PullRequest, error) { var query struct { diff --git a/internal/github/github_test.go b/internal/github/github_test.go new file mode 100644 index 0000000..8ba93b8 --- /dev/null +++ b/internal/github/github_test.go @@ -0,0 +1,28 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPRURL(t *testing.T) { + tests := []struct { + name string + host string + owner string + repo string + number int + want string + }{ + {"github.com", "github.com", "owner", "repo", 42, "https://github.com/owner/repo/pull/42"}, + {"GHES host", "ghes.example.com", "myorg", "myrepo", 99, "https://ghes.example.com/myorg/myrepo/pull/99"}, + {"empty host defaults to github.com", "", "owner", "repo", 1, "https://github.com/owner/repo/pull/1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := PRURL(tt.host, tt.owner, tt.repo, tt.number) + assert.Equal(t, tt.want, got) + }) + } +}