Skip to content

Commit

Permalink
Add otelhttp.Middleware for better compatiblity with third-party HTTP…
Browse files Browse the repository at this point in the history
… routers and middleware chainers

The key difference between the existing otelhttp.Handler and the new
otelhttp.Middleware is that `Middleware` takes the `next` handler as an
argument after construction, wheres the existing `Handler` works by
wrapping one specific http.Handler at construction time.
  • Loading branch information
alnr committed Nov 2, 2022
1 parent 0893cda commit e3d5d21
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 57 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -8,6 +8,9 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

## [Unreleased]

### Added
- An HTTP middleware in addition to the existing HTTP handler. The middleware is constructed once and afterwards can take different handlers as arguments, whereas the handler wraps a single http.Handler during construction. This allows better compatiblity with packages like github.com/urfave/negroni and github.com/go-chi/chi.

## [1.11.1/0.36.4/0.5.2]

### Added
Expand Down
131 changes: 74 additions & 57 deletions instrumentation/net/http/otelhttp/handler.go
Expand Up @@ -33,14 +33,36 @@ import (

var _ http.Handler = &Handler{}

// Handler is http middleware that corresponds to the http.Handler interface and
// is designed to wrap a http.Mux (or equivalent), while individual routes on
// the mux are wrapped with WithRouteTag. A Handler will add various attributes
// to the span using the attribute.Keys defined in this package.
// Handler is an http.Handler which wraps a http.Mux (or equivalent), while
// individual routes on the mux are wrapped with WithRouteTag. A Handler will
// add various attributes to the span using the attribute.Keys defined in this
// package.
type Handler struct {
operation string
handler http.Handler
next http.Handler
middleware *Middleware
}

func defaultHandlerFormatter(operation string, _ *http.Request) string {
return operation
}

// NewHandler wraps the passed handler in a span named after the operation and
// with any provided Options.
func NewHandler(handler http.Handler, operation string, opts ...Option) http.Handler {
return &Handler{
next: handler,
middleware: NewMiddleware(operation, opts...),
}
}

// ServeHTTP serves HTTP requests (implements http.Handler).
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.middleware.ServeHTTP(w, r, h.next)
}

// Middleware is an http middleware which wraps the next handler in a span.
type Middleware struct {
operation string
tracer trace.Tracer
meter metric.Meter
propagators propagation.TextMapPropagator
Expand All @@ -55,41 +77,35 @@ type Handler struct {
publicEndpointFn func(*http.Request) bool
}

func defaultHandlerFormatter(operation string, _ *http.Request) string {
return operation
}

// NewHandler wraps the passed handler, functioning like middleware, in a span
// named after the operation and with any provided Options.
func NewHandler(handler http.Handler, operation string, opts ...Option) http.Handler {
h := Handler{
handler: handler,
// NewMiddleware returns a tracing middleware from the given operation name and
// options.
func NewMiddleware(operation string, opts ...Option) *Middleware {
m := Middleware{
operation: operation,
}

defaultOpts := []Option{
WithSpanOptions(trace.WithSpanKind(trace.SpanKindServer)),
WithSpanNameFormatter(defaultHandlerFormatter),
}

c := newConfig(append(defaultOpts, opts...)...)
h.configure(c)
h.createMeasures()
m.configure(c)
m.createMeasures()

return &h
return &m
}

func (h *Handler) configure(c *config) {
h.tracer = c.Tracer
h.meter = c.Meter
h.propagators = c.Propagators
h.spanStartOptions = c.SpanStartOptions
h.readEvent = c.ReadEvent
h.writeEvent = c.WriteEvent
h.filters = c.Filters
h.spanNameFormatter = c.SpanNameFormatter
h.publicEndpoint = c.PublicEndpoint
h.publicEndpointFn = c.PublicEndpointFn
func (m *Middleware) configure(c *config) {
m.tracer = c.Tracer
m.meter = c.Meter
m.propagators = c.Propagators
m.spanStartOptions = c.SpanStartOptions
m.readEvent = c.ReadEvent
m.writeEvent = c.WriteEvent
m.filters = c.Filters
m.spanNameFormatter = c.SpanNameFormatter
m.publicEndpoint = c.PublicEndpoint
m.publicEndpointFn = c.PublicEndpointFn
}

func handleErr(err error) {
Expand All @@ -98,38 +114,39 @@ func handleErr(err error) {
}
}

func (h *Handler) createMeasures() {
h.counters = make(map[string]syncint64.Counter)
h.valueRecorders = make(map[string]syncfloat64.Histogram)
func (m *Middleware) createMeasures() {
m.counters = make(map[string]syncint64.Counter)
m.valueRecorders = make(map[string]syncfloat64.Histogram)

requestBytesCounter, err := h.meter.SyncInt64().Counter(RequestContentLength)
requestBytesCounter, err := m.meter.SyncInt64().Counter(RequestContentLength)
handleErr(err)

responseBytesCounter, err := h.meter.SyncInt64().Counter(ResponseContentLength)
responseBytesCounter, err := m.meter.SyncInt64().Counter(ResponseContentLength)
handleErr(err)

serverLatencyMeasure, err := h.meter.SyncFloat64().Histogram(ServerLatency)
serverLatencyMeasure, err := m.meter.SyncFloat64().Histogram(ServerLatency)
handleErr(err)

h.counters[RequestContentLength] = requestBytesCounter
h.counters[ResponseContentLength] = responseBytesCounter
h.valueRecorders[ServerLatency] = serverLatencyMeasure
m.counters[RequestContentLength] = requestBytesCounter
m.counters[ResponseContentLength] = responseBytesCounter
m.valueRecorders[ServerLatency] = serverLatencyMeasure
}

// ServeHTTP serves HTTP requests (http.Handler).
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// ServeHTTP sets up tracing and calls the given next http.Handler with the span
// context injected into the request context.
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.Handler) {
requestStartTime := time.Now()
for _, f := range h.filters {
for _, f := range m.filters {
if !f(r) {
// Simply pass through to the handler if a filter rejects the request
h.handler.ServeHTTP(w, r)
next.ServeHTTP(w, r)
return
}
}

ctx := h.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
opts := h.spanStartOptions
if h.publicEndpoint || (h.publicEndpointFn != nil && h.publicEndpointFn(r.WithContext(ctx))) {
ctx := m.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
opts := m.spanStartOptions
if m.publicEndpoint || (m.publicEndpointFn != nil && m.publicEndpointFn(r.WithContext(ctx))) {
opts = append(opts, trace.WithNewRoot())
// Linking incoming span context if any for public endpoint.
if s := trace.SpanContextFromContext(ctx); s.IsValid() && s.IsRemote() {
Expand All @@ -140,10 +157,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
opts = append([]trace.SpanStartOption{
trace.WithAttributes(semconv.NetAttributesFromHTTPRequest("tcp", r)...),
trace.WithAttributes(semconv.EndUserAttributesFromHTTPRequest(r)...),
trace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(h.operation, "", r)...),
trace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(m.operation, "", r)...),
}, opts...) // start with the configured options

tracer := h.tracer
tracer := m.tracer

if tracer == nil {
if span := trace.SpanFromContext(r.Context()); span.SpanContext().IsValid() {
Expand All @@ -153,11 +170,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

ctx, span := tracer.Start(ctx, h.spanNameFormatter(h.operation, r), opts...)
ctx, span := tracer.Start(ctx, m.spanNameFormatter(m.operation, r), opts...)
defer span.End()

readRecordFunc := func(int64) {}
if h.readEvent {
if m.readEvent {
readRecordFunc = func(n int64) {
span.AddEvent("read", trace.WithAttributes(ReadBytesKey.Int64(n)))
}
Expand All @@ -174,13 +191,13 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

writeRecordFunc := func(int64) {}
if h.writeEvent {
if m.writeEvent {
writeRecordFunc = func(n int64) {
span.AddEvent("write", trace.WithAttributes(WroteBytesKey.Int64(n)))
}
}

rww := &respWriterWrapper{ResponseWriter: w, record: writeRecordFunc, ctx: ctx, props: h.propagators}
rww := &respWriterWrapper{ResponseWriter: w, record: writeRecordFunc, ctx: ctx, props: m.propagators}

// Wrap w to use our ResponseWriter methods while also exposing
// other interfaces that w may implement (http.CloseNotifier,
Expand All @@ -201,19 +218,19 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
labeler := &Labeler{}
ctx = injectLabeler(ctx, labeler)

h.handler.ServeHTTP(w, r.WithContext(ctx))
next.ServeHTTP(w, r.WithContext(ctx))

setAfterServeAttributes(span, bw.read, rww.written, rww.statusCode, bw.err, rww.err)

// Add metrics
attributes := append(labeler.Get(), semconv.HTTPServerMetricAttributesFromHTTPRequest(h.operation, r)...)
h.counters[RequestContentLength].Add(ctx, bw.read, attributes...)
h.counters[ResponseContentLength].Add(ctx, rww.written, attributes...)
attributes := append(labeler.Get(), semconv.HTTPServerMetricAttributesFromHTTPRequest(m.operation, r)...)
m.counters[RequestContentLength].Add(ctx, bw.read, attributes...)
m.counters[ResponseContentLength].Add(ctx, rww.written, attributes...)

// Use floating point division here for higher precision (instead of Millisecond method).
elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond)

h.valueRecorders[ServerLatency].Record(ctx, elapsedTime, attributes...)
m.valueRecorders[ServerLatency].Record(ctx, elapsedTime, attributes...)
}

func setAfterServeAttributes(span trace.Span, read, wrote int64, statusCode int, rerr, werr error) {
Expand Down

0 comments on commit e3d5d21

Please sign in to comment.