Skip to content

Commit

Permalink
Merge pull request #328 from metalmatze/middleware-logger
Browse files Browse the repository at this point in the history
Pass logger into all middlewares to log warnings on 5xx
  • Loading branch information
brancz committed Apr 28, 2020
2 parents 958dc61 + ccd04fd commit f2c98f2
Show file tree
Hide file tree
Showing 29 changed files with 1,894 additions and 27 deletions.
8 changes: 5 additions & 3 deletions cmd/telemeter-server/main.go
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/coreos/go-oidc"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/oklog/run"
Expand Down Expand Up @@ -323,6 +324,7 @@ func (o *Options) Run() error {
}
{
external := chi.NewRouter()
external.Use(middleware.RequestID)

// TODO: Refactor HealthRoutes to not take *http.Mux
mux := http.NewServeMux()
Expand Down Expand Up @@ -426,10 +428,10 @@ func (o *Options) Run() error {
external.Post("/upload",
server.InstrumentedHandler("upload",
authorize.NewAuthorizeClientHandler(jwtAuthorizer,
server.ClusterID(o.clusterIDKey,
server.Ratelimit(o.Ratelimit, time.Now,
server.ClusterID(o.Logger, o.clusterIDKey,
server.Ratelimit(o.Logger, o.Ratelimit, time.Now,
server.Snappy(
server.Validate(transforms, 24*time.Hour, o.LimitBytes, time.Now,
server.Validate(o.Logger, transforms, 24*time.Hour, o.LimitBytes, time.Now,
server.ForwardHandler(o.Logger, forwardURL),
),
),
Expand Down
31 changes: 24 additions & 7 deletions pkg/server/forward.go
Expand Up @@ -10,6 +10,7 @@ import (
"net/url"
"time"

"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/gogo/protobuf/proto"
Expand Down Expand Up @@ -59,9 +60,13 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {
client := http.Client{}

return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

clusterID, ok := ClusterIDFromContext(r.Context())
if !ok {
http.Error(w, "failed to retrieve clusterID", http.StatusInternalServerError)
msg := "failed to retrieve clusterID"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -75,7 +80,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {
if err == io.EOF {
break
}
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := err.Error()
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -85,7 +92,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {

timeseries, err := convertToTimeseries(&PartitionedMetrics{ClusterID: clusterID, Families: families}, time.Now())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to convert timeseries"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -98,15 +107,19 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {

data, err := proto.Marshal(wreq)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to marshal proto"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}

compressed := snappy.Encode(nil, data)

req, err := http.NewRequest(http.MethodPost, forwardURL.String(), bytes.NewBuffer(compressed))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to create forwarding request"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
req.Header.Add("THANOS-TENANT", clusterID)
Expand All @@ -119,7 +132,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {
begin := time.Now()
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
msg := "failed to forward request"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusBadGateway)
return
}

Expand All @@ -134,7 +149,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {

if resp.StatusCode/100 != 2 {
// surfacing upstreams error to our users too
http.Error(w, fmt.Errorf("response status code is %s", resp.Status).Error(), resp.StatusCode)
msg := fmt.Sprintf("response status code is %s", resp.Status)
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, resp.StatusCode)
return
}

Expand Down
12 changes: 10 additions & 2 deletions pkg/server/ratelimited.go
Expand Up @@ -6,6 +6,9 @@ import (
"sync"
"time"

"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"golang.org/x/time/rate"
)

Expand All @@ -17,20 +20,25 @@ func (e ErrWriteLimitReached) Error() string {
}

// Ratelimit is a middleware that rate limits requests based on a cluster ID.
func Ratelimit(limit time.Duration, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
func Ratelimit(logger log.Logger, limit time.Duration, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
s := ratelimitStore{
limits: make(map[string]*rate.Limiter),
mu: sync.Mutex{},
}

return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

clusterID, ok := ClusterIDFromContext(r.Context())
if !ok {
http.Error(w, "failed to get cluster ID from request", http.StatusInternalServerError)
msg := "failed to get cluster ID from request"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

if err := s.limit(limit, now(), clusterID); err != nil {
level.Debug(logger).Log("msg", "rate limited", "err", err)
http.Error(w, err.Error(), http.StatusTooManyRequests)
return
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/server/ratelimited_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/go-kit/kit/log"
"golang.org/x/time/rate"

"github.com/openshift/telemeter/pkg/authorize"
Expand All @@ -27,8 +28,8 @@ func TestRatelimit(t *testing.T) {
}
server := httptest.NewServer(
fakeAuthorizeHandler(
ClusterID("_id",
Ratelimit(time.Minute, time.Now,
ClusterID(log.NewNopLogger(), "_id",
Ratelimit(log.NewNopLogger(), time.Minute, time.Now,
func(w http.ResponseWriter, r *http.Request) {},
),
),
Expand Down
42 changes: 33 additions & 9 deletions pkg/server/validator.go
Expand Up @@ -10,6 +10,9 @@ import (
"net/http"
"time"

"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
clientmodel "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"

Expand All @@ -36,15 +39,21 @@ func ClusterIDFromContext(ctx context.Context) (string, bool) {
}

// ClusterID is an HTTP middleware that extracts the cluster's ID and passes it on via context.
func ClusterID(key string, next http.HandlerFunc) http.HandlerFunc {
func ClusterID(logger log.Logger, key string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

client, ok := authorize.FromContext(r.Context())
if !ok {
http.Error(w, "unable to find user info", http.StatusInternalServerError)
msg := "unable to find user info"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}
if len(client.Labels[key]) == 0 {
http.Error(w, fmt.Sprintf("user data must contain a '%s' label", key), http.StatusInternalServerError)
msg := fmt.Sprintf("user data must contain a '%s' label", key)
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -55,11 +64,15 @@ func ClusterID(key string, next http.HandlerFunc) http.HandlerFunc {
}

// Validate the payload of a request against given and required rules.
func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, limitBytes int64, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
func Validate(logger log.Logger, baseTransforms metricfamily.Transformer, maxAge time.Duration, limitBytes int64, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

client, ok := authorize.FromContext(r.Context())
if !ok {
http.Error(w, "unable to find user info", http.StatusInternalServerError)
msg := "unable to find user info"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -69,7 +82,9 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim

body, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to read request body"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
defer r.Body.Close()
Expand All @@ -96,7 +111,9 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim
if err == io.EOF {
break
}
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to decode metrics"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
families = append(families, family)
Expand All @@ -106,22 +123,28 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim

if err := metricfamily.Filter(families, transforms); err != nil {
if errors.Is(err, metricfamily.ErrNoTimestamp) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if errors.Is(err, metricfamily.ErrUnsorted) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if errors.Is(err, metricfamily.ErrTimestampTooOld) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if errors.Is(err, metricfamily.ErrRequiredLabelMissing) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

msg := "unexpected error during metrics transforming"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -130,9 +153,10 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim
encoder := expfmt.NewEncoder(buf, expfmt.ResponseFormat(r.Header))
for _, f := range families {
if err := encoder.Encode(f); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to encode transformed metrics again"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return

}
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/server/validator_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/go-kit/kit/log"
clientmodel "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"

Expand Down Expand Up @@ -40,7 +41,7 @@ func TestValidate(t *testing.T) {

s := httptest.NewServer(
fakeAuthorizeHandler(
Validate(metricfamily.MultiTransformer{}, time.Hour, 512*1024, now,
Validate(log.NewNopLogger(), metricfamily.MultiTransformer{}, time.Hour, 512*1024, now,
func(w http.ResponseWriter, r *http.Request) {
// TODO: Make the check proper to changing timestamps?
body, err := ioutil.ReadAll(r.Body)
Expand Down
6 changes: 3 additions & 3 deletions test/e2e/forward_test.go
Expand Up @@ -77,10 +77,10 @@ func TestForward(t *testing.T) {

telemeterServer = httptest.NewServer(
fakeAuthorizeHandler(
server.ClusterID("cluster",
server.Ratelimit(4*time.Minute+30*time.Second, time.Now,
server.ClusterID(log.NewNopLogger(), "cluster",
server.Ratelimit(log.NewNopLogger(), 4*time.Minute+30*time.Second, time.Now,
server.Snappy(
server.Validate(metricfamily.MultiTransformer{}, 10*365*24*time.Hour, 500*1024, time.Now,
server.Validate(log.NewNopLogger(), metricfamily.MultiTransformer{}, 10*365*24*time.Hour, 500*1024, time.Now,
server.ForwardHandler(logger, receiveURL),
),
),
Expand Down
32 changes: 32 additions & 0 deletions vendor/github.com/go-chi/chi/middleware/basic_auth.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f2c98f2

Please sign in to comment.