diff --git a/module/apmhttp/client.go b/module/apmhttp/client.go index e4c3eee79..f2fc6a6c2 100644 --- a/module/apmhttp/client.go +++ b/module/apmhttp/client.go @@ -176,3 +176,15 @@ func (b *responseBody) endSpan() { // ClientOption sets options for tracing client requests. type ClientOption func(*roundTripper) + +// WithClientRequestName returns a ClientOption which sets r as the function +// to use to obtain the span name for the given http request. +func WithClientRequestName(r RequestNameFunc) ClientOption { + if r == nil { + panic("r == nil") + } + + return ClientOption(func(rt *roundTripper) { + rt.requestName = r + }) +} diff --git a/module/apmhttp/client_test.go b/module/apmhttp/client_test.go index bf0ab92ec..a2cfae5c1 100644 --- a/module/apmhttp/client_test.go +++ b/module/apmhttp/client_test.go @@ -217,8 +217,27 @@ func TestClientCancelRequest(t *testing.T) { } } -func mustGET(ctx context.Context, url string) (statusCode int, responseBody string) { - client := apmhttp.WrapClient(http.DefaultClient) +func TestWithClientRequestName(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + option := apmhttp.WithClientRequestName(func(_ *http.Request) string { + return "http://test" + }) + + _, spans, _ := apmtest.WithTransaction(func(ctx context.Context) { + mustGET(ctx, server.URL, option) + }) + + require.Len(t, spans, 1) + span := spans[0] + assert.Equal(t, "http://test", span.Name) +} + +func mustGET(ctx context.Context, url string, o ...apmhttp.ClientOption) (statusCode int, responseBody string) { + client := apmhttp.WrapClient(http.DefaultClient, o...) resp, err := ctxhttp.Get(ctx, client, url) if err != nil { panic(err)