Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/cli/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var apiCmd = &cobra.Command{
req.Header.Set(strings.TrimSpace(k), strings.TrimSpace(v))
}

resp, err := http.DefaultClient.Do(req)
resp, err := resolve.HTTPClient().Do(req)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions internal/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
forges "github.com/git-pkgs/forge"
"github.com/git-pkgs/forge/internal/config"
"github.com/git-pkgs/forge/internal/output"
"github.com/git-pkgs/forge/internal/resolve"
"github.com/spf13/cobra"
)

Expand All @@ -33,6 +34,7 @@ var rootCmd = &cobra.Command{
}

func Execute() error {
resolve.SetUserAgent("forge/" + Version)
return rootCmd.Execute()
}

Expand Down
13 changes: 7 additions & 6 deletions internal/resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ func newClient(domain string) *forges.Client {
}

// Register default forges first, so config-based registrations override them.
hc := HTTPClient()
defaults := map[string]forges.Forge{
"github.com": ghforge.New(TokenForDomain("github.com"), nil),
"gitlab.com": glforge.New("https://gitlab.com", TokenForDomain("gitlab.com"), nil),
"codeberg.org": gitea.New("https://codeberg.org", TokenForDomain("codeberg.org"), nil),
"bitbucket.org": bitbucket.New(TokenForDomain("bitbucket.org"), nil),
"github.com": ghforge.New(TokenForDomain("github.com"), hc),
"gitlab.com": glforge.New("https://gitlab.com", TokenForDomain("gitlab.com"), hc),
"codeberg.org": gitea.New("https://codeberg.org", TokenForDomain("codeberg.org"), hc),
"bitbucket.org": bitbucket.New(TokenForDomain("bitbucket.org"), hc),
}
for d, f := range defaults {
opts = append(opts, forges.WithForge(d, f))
Expand All @@ -99,9 +100,9 @@ func newClient(domain string) *forges.Client {
if ft := configForgeType(domain); ft != "" {
switch ft {
case "gitea", "forgejo":
opts = append(opts, forges.WithForge(domain, gitea.New("https://"+domain, token, nil)))
opts = append(opts, forges.WithForge(domain, gitea.New("https://"+domain, token, hc)))
case "gitlab":
opts = append(opts, forges.WithForge(domain, glforge.New("https://"+domain, token, nil)))
opts = append(opts, forges.WithForge(domain, glforge.New("https://"+domain, token, hc)))
}
}

Expand Down
31 changes: 31 additions & 0 deletions internal/resolve/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package resolve

import "net/http"

var userAgent = "forge/dev"

// SetUserAgent sets the User-Agent string sent on every HTTP request made
// through HTTPClient. The CLI calls this at startup with the build version.
func SetUserAgent(ua string) {
userAgent = ua
}

// HTTPClient returns an http.Client whose transport sets the User-Agent
// header on outbound requests. Used for all forge API traffic so requests
// are identifiable in server logs.
func HTTPClient() *http.Client {
return &http.Client{Transport: &userAgentTransport{base: http.DefaultTransport}}
}

type userAgentTransport struct {
base http.RoundTripper
}

func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Header.Get("User-Agent") != "" {
return t.base.RoundTrip(req)
}
r := req.Clone(req.Context())
r.Header.Set("User-Agent", userAgent)
return t.base.RoundTrip(r)
}
94 changes: 94 additions & 0 deletions internal/resolve/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package resolve

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestHTTPClientSetsUserAgent(t *testing.T) {
var got string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = r.Header.Get("User-Agent")
}))
defer srv.Close()

c := HTTPClient()
resp, err := c.Get(srv.URL)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()

if got != userAgent {
t.Errorf("expected User-Agent %q, got %q", userAgent, got)
}
if got == "" {
t.Error("User-Agent was empty")
}
}

func TestSetUserAgent(t *testing.T) {
old := userAgent
defer func() { userAgent = old }()

SetUserAgent("forge/1.2.3")

var got string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = r.Header.Get("User-Agent")
}))
defer srv.Close()

c := HTTPClient()
resp, err := c.Get(srv.URL)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()

if got != "forge/1.2.3" {
t.Errorf("expected User-Agent forge/1.2.3, got %q", got)
}
}

func TestUserAgentTransportPreservesExisting(t *testing.T) {
// If a caller has already set User-Agent (e.g. forge api -H "User-Agent: x"),
// the transport must not stomp on it.
var got string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = r.Header.Get("User-Agent")
}))
defer srv.Close()

c := HTTPClient()
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
req.Header.Set("User-Agent", "custom/1.0")
resp, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()

if got != "custom/1.0" {
t.Errorf("expected explicit User-Agent to be preserved, got %q", got)
}
}

func TestUserAgentTransportDoesNotMutateRequest(t *testing.T) {
// RoundTrippers must not modify the original request.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer srv.Close()

c := HTTPClient()
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
resp, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
_ = resp.Body.Close()

if req.Header.Get("User-Agent") != "" {
t.Errorf("transport mutated the original request: User-Agent=%q", req.Header.Get("User-Agent"))
}
}