Skip to content

Commit

Permalink
cmd/telemeter-server,pkg/cache: cache auth resps
Browse files Browse the repository at this point in the history
This commit adds new functionality to Telemeter server to enable it to
cache the responses to authentication requests made while handling
remote-write requests. The only currently implemented backend is
Memcached, because it is so simple, but any shared k-v would serve our
purposes.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
  • Loading branch information
squat committed Dec 10, 2019
1 parent 18dfe23 commit a5af65a
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 10 deletions.
34 changes: 25 additions & 9 deletions cmd/telemeter-server/main.go
Expand Up @@ -25,6 +25,7 @@ import (

oidc "github.com/coreos/go-oidc"
"github.com/oklog/run"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
Expand All @@ -36,6 +37,8 @@ import (
"github.com/openshift/telemeter/pkg/authorize/jwt"
"github.com/openshift/telemeter/pkg/authorize/stub"
"github.com/openshift/telemeter/pkg/authorize/tollbooth"
"github.com/openshift/telemeter/pkg/cache"
"github.com/openshift/telemeter/pkg/cache/memcached"
"github.com/openshift/telemeter/pkg/cluster"
telemeter_http "github.com/openshift/telemeter/pkg/http"
httpserver "github.com/openshift/telemeter/pkg/http/server"
Expand Down Expand Up @@ -86,6 +89,7 @@ func main() {
PartitionKey: "_id",
Ratelimit: 4*time.Minute + 30*time.Second,
TTL: 10 * time.Minute,
MemcachedExpire: 24 * 60 * 60,
}
cmd := &cobra.Command{
Short: "Aggregate federated metrics pushes",
Expand Down Expand Up @@ -122,6 +126,8 @@ func main() {
cmd.Flags().StringVar(&opt.ClientSecret, "client-secret", opt.ClientSecret, "The OIDC client secret, see https://tools.ietf.org/html/rfc6749#section-2.3.")
cmd.Flags().StringVar(&opt.ClientID, "client-id", opt.ClientID, "The OIDC client ID, see https://tools.ietf.org/html/rfc6749#section-2.3.")
cmd.Flags().StringVar(&opt.TenantKey, "tenant-key", opt.TenantKey, "The JSON key in the bearer token whose value to use as the tenant ID.")
cmd.Flags().StringSliceVar(&opt.Memcacheds, "memcached", opt.Memcacheds, "One or more Memcached server addresses.")
cmd.Flags().Int32Var(&opt.MemcachedExpire, "memcached-expire", opt.MemcachedExpire, "Time after which keys stored in Memcached should expire, given in seconds.")

cmd.Flags().DurationVar(&opt.Ratelimit, "ratelimit", opt.Ratelimit, "The rate limit of metric uploads per cluster ID. Uploads happening more often than this limit will be rejected.")
cmd.Flags().DurationVar(&opt.TTL, "ttl", opt.TTL, "The TTL for metrics to be held in memory.")
Expand Down Expand Up @@ -174,10 +180,12 @@ type Options struct {

AuthorizeEndpoint string

OIDCIssuer string
ClientID string
ClientSecret string
TenantKey string
OIDCIssuer string
ClientID string
ClientSecret string
TenantKey string
Memcacheds []string
MemcachedExpire int32

PartitionKey string
LabelFlag []string
Expand Down Expand Up @@ -236,7 +244,7 @@ func (o *Options) Run() error {

// set up the upstream authorization
var authorizeURL *url.URL
var authorizeClient *http.Client
var authorizeClient http.Client
ctx := context.Background()
if len(o.AuthorizeEndpoint) > 0 {
u, err := url.Parse(o.AuthorizeEndpoint)
Expand All @@ -255,7 +263,7 @@ func (o *Options) Run() error {
transport = telemeter_http.NewDebugRoundTripper(o.Logger, transport)
}

authorizeClient = &http.Client{
authorizeClient = http.Client{
Timeout: 20 * time.Second,
Transport: telemeter_http.NewInstrumentedRoundTripper("authorize", transport),
}
Expand Down Expand Up @@ -389,7 +397,7 @@ func (o *Options) Run() error {
// configure the authenticator and incoming data validator
var clusterAuth authorize.ClusterAuthorizer = authorize.ClusterAuthorizerFunc(stub.Authorize)
if authorizeURL != nil {
clusterAuth = tollbooth.NewAuthorizer(o.Logger, authorizeClient, authorizeURL)
clusterAuth = tollbooth.NewAuthorizer(o.Logger, &authorizeClient, authorizeURL)
}

auth := jwt.NewAuthorizeClusterHandler(o.Logger, o.PartitionKey, o.TokenExpireSeconds, signer, o.RequiredLabels, clusterAuth)
Expand Down Expand Up @@ -501,10 +509,18 @@ func (o *Options) Run() error {
),
)

// v1 routes
// v2 routes
v2AuthorizeClient := authorizeClient

if len(o.Memcacheds) > 0 {
mc := memcached.New(o.MemcachedExpire, o.Memcacheds...)
l := log.With(o.Logger, "component", "cache")
v2AuthorizeClient.Transport = cache.NewRoundTripper(mc, tollbooth.ExtractToken, v2AuthorizeClient.Transport, l, prometheus.DefaultRegisterer)
}

external.Handle("/metrics/v1/receive",
telemeter_http.NewInstrumentedHandler("receive",
authorize.NewHandler(o.Logger, authorizeClient, authorizeURL, o.TenantKey,
authorize.NewHandler(o.Logger, &v2AuthorizeClient, authorizeURL, o.TenantKey,
http.HandlerFunc(receiver.Receive),
),
),
Expand Down
6 changes: 5 additions & 1 deletion pkg/authorize/handler.go
Expand Up @@ -117,6 +117,9 @@ func AgainstEndpoint(logger log.Logger, client *http.Client, endpoint *url.URL,
}
}

// NewHandler returns an http.HandlerFunc that is able to authorize requests against Tollbooth.
// The handler function expects a bearer token in the Authorization header consisting of a
// base64-encoded JSON object containing "authorization_token" and "cluster_id" fields.
func NewHandler(logger log.Logger, client *http.Client, endpoint *url.URL, tenantKey string, next http.Handler) http.HandlerFunc {
logger = log.With(logger, "component", "authorize")
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -129,7 +132,8 @@ func NewHandler(logger log.Logger, client *http.Client, endpoint *url.URL, tenan

token, err := base64.StdEncoding.DecodeString(authParts[1])
if err != nil {
http.Error(w, "bad authorization header", http.StatusBadRequest)
level.Warn(logger).Log("msg", "failed to extract token", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
var tenant string
Expand Down
11 changes: 11 additions & 0 deletions pkg/authorize/tollbooth/tollbooth.go
@@ -1,8 +1,10 @@
package tollbooth

import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"mime"
"net/http"
"net/url"
Expand Down Expand Up @@ -75,3 +77,12 @@ func (a *authorizer) AuthorizeCluster(token, cluster string) (string, error) {

return response.AccountID, nil
}

// ExtractToken extracts the token from an auth request.
// In the case of a request to Tollbooth, the token
// is the entire contents of the request body.
func ExtractToken(r *http.Request) (string, error) {
body, err := ioutil.ReadAll(r.Body)
r.Body = ioutil.NopCloser(bytes.NewBuffer(body))
return string(body), err
}
104 changes: 104 additions & 0 deletions pkg/cache/cache.go
@@ -0,0 +1,104 @@
package cache

import (
"bufio"
"bytes"
"net/http"
"net/http/httputil"

"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
)

// Cacher is able to get and set key value pairs.
type Cacher interface {
Get(string) ([]byte, bool, error)
Set(string, []byte) error
}

// KeyFunc generates a cache key from a http.Request.
type KeyFunc func(*http.Request) (string, error)

// RoundTripper is a http.RoundTripper than can get and set responses from a cache.
type RoundTripper struct {
c Cacher
key KeyFunc
next http.RoundTripper

l log.Logger

// Metrics.
cacheReadsTotal *prometheus.CounterVec
cacheWritesTotal *prometheus.CounterVec
}

// RoundTrip implements the RoundTripper interface.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
key, err := r.key(req)
if err != nil {
return nil, errors.Wrap(err, "failed to generate key from request")
}

raw, ok, err := r.c.Get(key)
if err != nil {
r.cacheReadsTotal.WithLabelValues("error").Inc()
return nil, errors.Wrap(err, "failed to retrieve value from cache")
}

if !ok {
r.cacheReadsTotal.WithLabelValues("miss").Inc()
resp, err := r.next.RoundTrip(req)
if err == nil && resp.StatusCode/200 == 1 {
// Try to cache the response but don't block.
defer func() {
raw, err := httputil.DumpResponse(resp, true)
if err != nil {
level.Error(r.l).Log("msg", "failed to dump response", "err", err)
return
}
if err := r.c.Set(key, raw); err != nil {
r.cacheWritesTotal.WithLabelValues("error").Inc()
level.Error(r.l).Log("msg", "failed to set value in cache", "err", err)
return
}
r.cacheWritesTotal.WithLabelValues("success").Inc()
}()
}
return resp, err
}

r.cacheReadsTotal.WithLabelValues("hit").Inc()
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(raw)), req)
return resp, errors.Wrap(err, "failed to read response")
}

// NewRoundTripper creates a new http.RoundTripper that returns http.Responses
// from a cache.
func NewRoundTripper(c Cacher, key KeyFunc, next http.RoundTripper, l log.Logger, reg prometheus.Registerer) http.RoundTripper {
rt := &RoundTripper{
c: c,
key: key,
next: next,
l: l,
cacheReadsTotal: prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_reads_total",
Help: "The number of read requests made to the cache.",
}, []string{"result"},
),
cacheWritesTotal: prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_writes_total",
Help: "The number of write requests made to the cache.",
}, []string{"result"},
),
}

if reg != nil {
reg.MustRegister(rt.cacheReadsTotal, rt.cacheWritesTotal)
}

return rt
}
67 changes: 67 additions & 0 deletions pkg/cache/memcached/memcached.go
@@ -0,0 +1,67 @@
package memcached

import (
"crypto/sha256"
"fmt"

"github.com/bradfitz/gomemcache/memcache"
"github.com/pkg/errors"

tcache "github.com/openshift/telemeter/pkg/cache"
)

// cache is a Cacher implemented on top of Memcached.
type cache struct {
*memcache.Client
expiration int32
}

// New creates a new Cache from a list of Memcached servers
// and key expiration time given in seconds.
func New(expiration int32, servers ...string) tcache.Cacher {
return &cache{
memcache.New(servers...),
expiration,
}
}

// Get returns a value from Memcached.
func (c *cache) Get(key string) ([]byte, bool, error) {
key, err := hash(key)
if err != nil {
return nil, false, err
}
i, err := c.Client.Get(key)
if err != nil {
if err == memcache.ErrCacheMiss {
return nil, false, nil
}
return nil, false, err
}

return i.Value, true, nil
}

// Set sets a value in Memcached.
func (c *cache) Set(key string, value []byte) error {
key, err := hash(key)
if err != nil {
return err
}
i := memcache.Item{
Key: key,
Value: value,
Expiration: c.expiration,
}
return c.Client.Set(&i)
}

// hashKey hashes the given key to ensure that it is less than 250 bytes,
// as Memcached cannot handler longer keys.
func hash(key string) (string, error) {
h := sha256.New()
if _, err := h.Write([]byte(key)); err != nil {
return "", errors.Wrap(err, "failed to hash key")
}
return fmt.Sprintf("%x", (h.Sum(nil))), nil
}

0 comments on commit a5af65a

Please sign in to comment.