Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass logger into all middlewares to log warnings on 5xx #328

Merged
Merged 2 commits on Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions cmd/telemeter-server/main.go
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/coreos/go-oidc"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/oklog/run"
Expand Down Expand Up @@ -323,6 +324,7 @@ func (o *Options) Run() error {
}
{
external := chi.NewRouter()
external.Use(middleware.RequestID)

// TODO: Refactor HealthRoutes to not take *http.Mux
mux := http.NewServeMux()
Expand Down Expand Up @@ -426,10 +428,10 @@ func (o *Options) Run() error {
external.Post("/upload",
server.InstrumentedHandler("upload",
authorize.NewAuthorizeClientHandler(jwtAuthorizer,
server.ClusterID(o.clusterIDKey,
server.Ratelimit(o.Ratelimit, time.Now,
server.ClusterID(o.Logger, o.clusterIDKey,
server.Ratelimit(o.Logger, o.Ratelimit, time.Now,
server.Snappy(
server.Validate(transforms, 24*time.Hour, o.LimitBytes, time.Now,
server.Validate(o.Logger, transforms, 24*time.Hour, o.LimitBytes, time.Now,
server.ForwardHandler(o.Logger, forwardURL),
),
),
Expand Down
31 changes: 24 additions & 7 deletions pkg/server/forward.go
Expand Up @@ -10,6 +10,7 @@ import (
"net/url"
"time"

"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/gogo/protobuf/proto"
Expand Down Expand Up @@ -59,9 +60,13 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {
client := http.Client{}

return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

clusterID, ok := ClusterIDFromContext(r.Context())
if !ok {
http.Error(w, "failed to retrieve clusterID", http.StatusInternalServerError)
msg := "failed to retrieve clusterID"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -75,7 +80,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {
if err == io.EOF {
break
}
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := err.Error()
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -85,7 +92,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {

timeseries, err := convertToTimeseries(&PartitionedMetrics{ClusterID: clusterID, Families: families}, time.Now())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to convert timeseries"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -98,15 +107,19 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {

data, err := proto.Marshal(wreq)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to marshal proto"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}

compressed := snappy.Encode(nil, data)

req, err := http.NewRequest(http.MethodPost, forwardURL.String(), bytes.NewBuffer(compressed))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to create forwarding request"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
req.Header.Add("THANOS-TENANT", clusterID)
Expand All @@ -119,7 +132,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {
begin := time.Now()
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
msg := "failed to forward request"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusBadGateway)
return
}

Expand All @@ -134,7 +149,9 @@ func ForwardHandler(logger log.Logger, forwardURL *url.URL) http.HandlerFunc {

if resp.StatusCode/100 != 2 {
// surfacing upstreams error to our users too
http.Error(w, fmt.Errorf("response status code is %s", resp.Status).Error(), resp.StatusCode)
msg := fmt.Sprintf("response status code is %s", resp.Status)
level.Warn(logger).Log("msg", msg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would log 4xx responses as well, right? Is that what we want?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I guess it's fine for now. If it becomes too spammy, we might want to add another `if` in there to log only 5xx responses :)

http.Error(w, msg, resp.StatusCode)
return
}

Expand Down
12 changes: 10 additions & 2 deletions pkg/server/ratelimited.go
Expand Up @@ -6,6 +6,9 @@ import (
"sync"
"time"

"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"golang.org/x/time/rate"
)

Expand All @@ -17,20 +20,25 @@ func (e ErrWriteLimitReached) Error() string {
}

// Ratelimit is a middleware that rate limits requests based on a cluster ID.
func Ratelimit(limit time.Duration, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
func Ratelimit(logger log.Logger, limit time.Duration, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
s := ratelimitStore{
limits: make(map[string]*rate.Limiter),
mu: sync.Mutex{},
}

return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

clusterID, ok := ClusterIDFromContext(r.Context())
if !ok {
http.Error(w, "failed to get cluster ID from request", http.StatusInternalServerError)
msg := "failed to get cluster ID from request"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

if err := s.limit(limit, now(), clusterID); err != nil {
level.Debug(logger).Log("msg", "rate limited", "err", err)
http.Error(w, err.Error(), http.StatusTooManyRequests)
return
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/server/ratelimited_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/go-kit/kit/log"
"golang.org/x/time/rate"

"github.com/openshift/telemeter/pkg/authorize"
Expand All @@ -27,8 +28,8 @@ func TestRatelimit(t *testing.T) {
}
server := httptest.NewServer(
fakeAuthorizeHandler(
ClusterID("_id",
Ratelimit(time.Minute, time.Now,
ClusterID(log.NewNopLogger(), "_id",
Ratelimit(log.NewNopLogger(), time.Minute, time.Now,
func(w http.ResponseWriter, r *http.Request) {},
),
),
Expand Down
42 changes: 33 additions & 9 deletions pkg/server/validator.go
Expand Up @@ -10,6 +10,9 @@ import (
"net/http"
"time"

"github.com/go-chi/chi/middleware"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
clientmodel "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"

Expand All @@ -36,15 +39,21 @@ func ClusterIDFromContext(ctx context.Context) (string, bool) {
}

// ClusterID is a HTTP middleware that extracts the cluster's ID and passes it on via context.
func ClusterID(key string, next http.HandlerFunc) http.HandlerFunc {
func ClusterID(logger log.Logger, key string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

client, ok := authorize.FromContext(r.Context())
if !ok {
http.Error(w, "unable to find user info", http.StatusInternalServerError)
msg := "unable to find user info"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}
if len(client.Labels[key]) == 0 {
http.Error(w, fmt.Sprintf("user data must contain a '%s' label", key), http.StatusInternalServerError)
msg := fmt.Sprintf("user data must contain a '%s' label", key)
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -55,11 +64,15 @@ func ClusterID(key string, next http.HandlerFunc) http.HandlerFunc {
}

// Validate the payload of a request against given and required rules.
func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, limitBytes int64, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
func Validate(logger log.Logger, baseTransforms metricfamily.Transformer, maxAge time.Duration, limitBytes int64, now func() time.Time, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger = log.With(logger, "request", middleware.GetReqID(r.Context()))

client, ok := authorize.FromContext(r.Context())
if !ok {
http.Error(w, "unable to find user info", http.StatusInternalServerError)
msg := "unable to find user info"
level.Warn(logger).Log("msg", msg)
http.Error(w, msg, http.StatusInternalServerError)
return
}

Expand All @@ -69,7 +82,9 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim

body, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to read request body"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
defer r.Body.Close()
Expand All @@ -96,7 +111,9 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim
if err == io.EOF {
break
}
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to decode metrics"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
families = append(families, family)
Expand All @@ -106,22 +123,28 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim

if err := metricfamily.Filter(families, transforms); err != nil {
if errors.Is(err, metricfamily.ErrNoTimestamp) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if errors.Is(err, metricfamily.ErrUnsorted) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if errors.Is(err, metricfamily.ErrTimestampTooOld) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if errors.Is(err, metricfamily.ErrRequiredLabelMissing) {
level.Debug(logger).Log("msg", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

msg := "unexpected error during metrics transforming"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -130,9 +153,10 @@ func Validate(baseTransforms metricfamily.Transformer, maxAge time.Duration, lim
encoder := expfmt.NewEncoder(buf, expfmt.ResponseFormat(r.Header))
for _, f := range families {
if err := encoder.Encode(f); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
msg := "failed to encode transformed metrics again"
level.Warn(logger).Log("msg", msg, "err", err)
http.Error(w, msg, http.StatusInternalServerError)
return

}
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/server/validator_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/go-kit/kit/log"
clientmodel "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"

Expand Down Expand Up @@ -40,7 +41,7 @@ func TestValidate(t *testing.T) {

s := httptest.NewServer(
fakeAuthorizeHandler(
Validate(metricfamily.MultiTransformer{}, time.Hour, 512*1024, now,
Validate(log.NewNopLogger(), metricfamily.MultiTransformer{}, time.Hour, 512*1024, now,
func(w http.ResponseWriter, r *http.Request) {
// TODO: Make the check proper to changing timestamps?
body, err := ioutil.ReadAll(r.Body)
Expand Down
6 changes: 3 additions & 3 deletions test/e2e/forward_test.go
Expand Up @@ -77,10 +77,10 @@ func TestForward(t *testing.T) {

telemeterServer = httptest.NewServer(
fakeAuthorizeHandler(
server.ClusterID("cluster",
server.Ratelimit(4*time.Minute+30*time.Second, time.Now,
server.ClusterID(log.NewNopLogger(), "cluster",
server.Ratelimit(log.NewNopLogger(), 4*time.Minute+30*time.Second, time.Now,
server.Snappy(
server.Validate(metricfamily.MultiTransformer{}, 10*365*24*time.Hour, 500*1024, time.Now,
server.Validate(log.NewNopLogger(), metricfamily.MultiTransformer{}, 10*365*24*time.Hour, 500*1024, time.Now,
server.ForwardHandler(logger, receiveURL),
),
),
Expand Down
32 changes: 32 additions & 0 deletions vendor/github.com/go-chi/chi/middleware/basic_auth.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.