diff --git a/client_test.go b/client_test.go index f541f19745..f1d4c0fe61 100644 --- a/client_test.go +++ b/client_test.go @@ -4,11 +4,13 @@ package cloudflare_test import ( "context" + "fmt" "net/http" "testing" "time" "github.com/cloudflare/cloudflare-go/v2" + "github.com/cloudflare/cloudflare-go/v2/internal" "github.com/cloudflare/cloudflare-go/v2/option" "github.com/cloudflare/cloudflare-go/v2/zones" ) @@ -21,6 +23,32 @@ func (t *closureTransport) RoundTrip(req *http.Request) (*http.Response, error) return t.fn(req) } +func TestUserAgentHeader(t *testing.T) { + var userAgent string + client := cloudflare.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + userAgent = req.Header.Get("User-Agent") + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + }), + ) + client.Zones.New(context.Background(), zones.ZoneNewParams{ + Account: cloudflare.F(zones.ZoneNewParamsAccount{ + ID: cloudflare.F("023e105f4ecef8ad9ca31a8372d0c353"), + }), + Name: cloudflare.F("example.com"), + Type: cloudflare.F(zones.ZoneNewParamsTypeFull), + }) + if userAgent != fmt.Sprintf("Cloudflare/Go %s", internal.PackageVersion) { + t.Errorf("Expected User-Agent to be correct, but got: %#v", userAgent) + } +} + func TestRetryAfter(t *testing.T) { attempts := 0 client := cloudflare.NewClient( diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go index 7f5fc8730d..f946609179 100644 --- a/internal/requestconfig/requestconfig.go +++ b/internal/requestconfig/requestconfig.go @@ -23,6 +23,12 @@ import ( "github.com/cloudflare/cloudflare-go/v2/internal/apiquery" ) +func getDefaultHeaders() map[string]string { + return map[string]string{ + "User-Agent": fmt.Sprintf("Cloudflare/Go %s", internal.PackageVersion), + } +} + func getNormalizedOS() string { switch runtime.GOOS { case "ios": @@ -118,6 +124,9 @@ func NewRequestConfig(ctx context.Context, method string, u string, body interfa } req.Header.Set("Accept", "application/json") + for k, v := range getDefaultHeaders() { + req.Header.Add(k, v) + } for k, v := range getPlatformProperties() { req.Header.Add(k, v)