Skip to content

Commit

Permalink
Merge pull request #289 from axw/apmhttp-client-closebody-endspan
Browse files Browse the repository at this point in the history
module/apmhttp: end client span on body closure
  • Loading branch information
axw committed Oct 31, 2018
2 parents d17ddd2 + 6d37ee2 commit 1e3afe4
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -29,6 +29,7 @@
- Reuse memory for tags (#286)
- Return a more helpful error message when /intake/v2/events 404s, to detect old servers (#290)
- Implement test service for w3c/distributed-tracing test harness (#293)
- End HTTP client spans on response body closure (#289)

## [v0.5.2](https://github.com/elastic/apm-agent-go/releases/tag/v0.5.2)

Expand Down
49 changes: 45 additions & 4 deletions module/apmhttp/client.go
@@ -1,7 +1,10 @@
package apmhttp

import (
"io"
"net/http"
"sync/atomic"
"unsafe"

"go.elastic.co/apm"
)
Expand Down Expand Up @@ -80,15 +83,53 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
spanType := "ext.http"
span := tx.StartSpan(name, spanType, apm.SpanFromContext(ctx))
span.Context.SetHTTPRequest(req)
defer span.End()
if !span.Dropped() {
traceContext = span.TraceContext()
ctx = apm.ContextWithSpan(ctx, span)
req = RequestWithContext(ctx, req)
} else {
span.End()
span = nil
}

req.Header.Set(TraceparentHeader, FormatTraceparentHeader(traceContext))
ctx = apm.ContextWithSpan(ctx, span)
req = RequestWithContext(ctx, req)
return r.r.RoundTrip(req)
resp, err := r.r.RoundTrip(req)
if span != nil {
if err != nil {
span.End()
} else {
resp.Body = &responseBody{span: span, body: resp.Body}
}
}
return resp, err
}

type responseBody struct {
span *apm.Span
body io.ReadCloser
}

// Close closes the response body, and ends the span if it hasn't already been ended.
func (b *responseBody) Close() error {
b.endSpan()
return b.body.Close()
}

// Read reads from the response body, and ends the span when io.EOF is returend if
// the span hasn't already been ended.
func (b *responseBody) Read(p []byte) (n int, err error) {
n, err = b.body.Read(p)
if err == io.EOF {
b.endSpan()
}
return n, err
}

func (b *responseBody) endSpan() {
addr := (*unsafe.Pointer)(unsafe.Pointer(&b.span))
if old := atomic.SwapPointer(addr, nil); old != nil {
(*apm.Span)(old).End()
}
}

// ClientOption sets options for tracing client requests.
Expand Down
85 changes: 85 additions & 0 deletions module/apmhttp/client_test.go
Expand Up @@ -2,17 +2,20 @@ package apmhttp_test

import (
"context"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/context/ctxhttp"

"go.elastic.co/apm"
"go.elastic.co/apm/apmtest"
"go.elastic.co/apm/model"
"go.elastic.co/apm/module/apmhttp"
"go.elastic.co/apm/transport/transporttest"
Expand Down Expand Up @@ -71,3 +74,85 @@ func TestClient(t *testing.T) {
assert.Equal(t, span.ID, model.SpanID(clientTraceContext.Span))
assert.Equal(t, transaction.ID, span.ParentID)
}

func TestClientSpanDropped(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Header.Get("Elastic-Apm-Traceparent")))
}))
defer server.Close()

tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()

tracer.SetMaxSpans(1)
tx := tracer.StartTransaction("name", "type")
ctx := apm.ContextWithTransaction(context.Background(), tx)

var responseBodies []string
for i := 0; i < 2; i++ {
client := apmhttp.WrapClient(http.DefaultClient)
resp, err := ctxhttp.Get(ctx, client, server.URL)
assert.NoError(t, err)
responseBody, err := ioutil.ReadAll(resp.Body)
if !assert.NoError(t, err) {
resp.Body.Close()
return
}
responseBodies = append(responseBodies, string(responseBody))
}

tx.End()
tracer.Flush(nil)
payloads := transport.Payloads()
require.Len(t, payloads.Spans, 1)
transaction := payloads.Transactions[0]
span := payloads.Spans[0] // for first request

clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(responseBodies[0]))
require.NoError(t, err)
assert.Equal(t, span.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, span.ID, model.SpanID(clientTraceContext.Span))

clientTraceContext, err = apmhttp.ParseTraceparentHeader(string(responseBodies[1]))
require.NoError(t, err)
assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span))
}

func TestClientError(t *testing.T) {
_, spans, _ := apmtest.WithTransaction(func(ctx context.Context) {
client := apmhttp.WrapClient(http.DefaultClient)
resp, err := ctxhttp.Get(ctx, client, "http://testing.invalid")
if !assert.Error(t, err) {
resp.Body.Close()
}
})
require.Len(t, spans, 1)
}

func TestClientDuration(t *testing.T) {
const delay = 500 * time.Millisecond
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("hello"))
w.(http.Flusher).Flush()
time.Sleep(delay)
w.Write([]byte("world"))
}))
defer server.Close()

_, spans, _ := apmtest.WithTransaction(func(ctx context.Context) {
client := apmhttp.WrapClient(http.DefaultClient)

resp, err := ctxhttp.Get(ctx, client, server.URL)
assert.NoError(t, err)
defer resp.Body.Close()
io.Copy(ioutil.Discard, resp.Body)
})

require.Len(t, spans, 1)
span := spans[0]

assert.Equal(t, "GET "+server.Listener.Addr().String(), span.Name)
assert.Equal(t, "ext.http", span.Type)
assert.InDelta(t, delay/time.Millisecond, span.Duration, 100)
}

0 comments on commit 1e3afe4

Please sign in to comment.