Skip to content
This repository has been archived by the owner on Feb 28, 2019. It is now read-only.

Commit

Permalink
Address review comments. Add authType functionality to AuthHandler.
Browse files Browse the repository at this point in the history
  • Loading branch information
m-sandusky committed Jan 10, 2018
1 parent abc04c3 commit be98154
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 63 deletions.
16 changes: 15 additions & 1 deletion auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,34 @@ import (

type keyType int

// AuthorizationType designates a type of authorization.
type AuthorizationType int

type errorResponseHandler func(w http.ResponseWriter, code int, msg string) error

const (
// UserIDField is a key
UserIDField keyType = iota
)

const (
// AuthorizationTypeNone is the no authorizationType case.
AuthorizationTypeNone AuthorizationType = iota
// AuthorizationTypeReadOnly is the read only authorizationType case.
AuthorizationTypeReadOnly
// AuthorizationTypeWriteOnly is the write only authorizationType case.
AuthorizationTypeWriteOnly
// AuthorizationTypeReadWrite is the read and write authorizationType case.
AuthorizationTypeReadWrite
)

// HTTPAuthService defines how to handle requests for various http authentication and authorization methods.
type HTTPAuthService interface {
// NewAuthHandler should return a handler that performs some check on the request coming into the given handler
// and then runs the handler if it is. If the request passes authentication/authorization successfully, it should call SetUser
// to make the callers id available to the service in a global context. errHandler should be passed in to properly format the
// the error and respond the the request in the event of bad auth.
NewAuthHandler(next http.Handler, errHandler errorResponseHandler) http.Handler
NewAuthHandler(authType AuthorizationType, next http.Handler, errHandler errorResponseHandler) http.Handler

// SetUser sets a userID that identifies the api caller in the global context.
SetUser(parent context.Context, userID string) context.Context
Expand Down
2 changes: 1 addition & 1 deletion auth/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewNoopAuth() HTTPAuthService {
return noopAuth{}
}

func (a noopAuth) NewAuthHandler(next http.Handler, errHandler errorResponseHandler) http.Handler {
func (a noopAuth) NewAuthHandler(authType AuthorizationType, next http.Handler, errHandler errorResponseHandler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
Expand Down
22 changes: 15 additions & 7 deletions auth/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,22 @@ type simpleAuthorization struct {
writeWhitelistedUserIDs []string
}

func (a simpleAuthorization) authorize(httpMethod, userID string) error {
switch httpMethod {
case http.MethodGet:
func (a simpleAuthorization) authorize(authType AuthorizationType, userID string) error {
switch authType {
case AuthorizationTypeNone:
return nil
case AuthorizationTypeReadOnly:
return a.authorizeUser(a.readWhitelistEnabled, a.readWhitelistedUserIDs, userID)
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
case AuthorizationTypeWriteOnly:
return a.authorizeUser(a.writeWhitelistEnabled, a.writeWhitelistedUserIDs, userID)
case AuthorizationTypeReadWrite:
err := a.authorizeUser(a.readWhitelistEnabled, a.readWhitelistedUserIDs, userID)
if err != nil {
return err
}
return a.authorizeUser(a.writeWhitelistEnabled, a.writeWhitelistedUserIDs, userID)
default:
return fmt.Errorf("unsupported request method: %s", httpMethod)
return fmt.Errorf("unsupported authorizationType %v passed to handler", authType)
}
}

Expand All @@ -118,7 +126,7 @@ func (a simpleAuthorization) authorizeUser(useWhitelist bool, whitelistedUsers [

// Authenticate looks for a header defining a user name. If it finds it, runs the actual http handler passed as a parameter.
// Otherwise, it returns an Unauthorized http response.
func (a simpleAuth) NewAuthHandler(next http.Handler, errHandler errorResponseHandler) http.Handler {
func (a simpleAuth) NewAuthHandler(authType AuthorizationType, next http.Handler, errHandler errorResponseHandler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var (
userID = r.Header.Get(a.authentication.userIDHeader)
Expand All @@ -133,7 +141,7 @@ func (a simpleAuth) NewAuthHandler(next http.Handler, errHandler errorResponseHa
return
}

err = a.authorization.authorize(r.Method, userID)
err = a.authorization.authorize(authType, userID)
if err != nil {
errHandler(w, http.StatusForbidden, err.Error())
return
Expand Down
21 changes: 12 additions & 9 deletions auth/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,16 @@ func TestSimpleAuthorizationAuthorize(t *testing.T) {
readWhitelistEnabled: true,
writeWhitelistEnabled: false,
readWhitelistedUserIDs: []string{"foo", "bar"},
writeWhitelistedUserIDs: []string{"foo", "bar"},
writeWhitelistedUserIDs: []string{"foo", "bar", "baz"},
}

require.Nil(t, authorization.authorize("GET", "foo"))
require.Nil(t, authorization.authorize("POST", "foo"))
require.EqualError(t, authorization.authorize("OPTIONS", "foo"), "unsupported request method: OPTIONS")
require.EqualError(t, authorization.authorize("GET", "baz"), "supplied userID: [baz] is not authorized")
require.Nil(t, authorization.authorize(AuthorizationTypeReadOnly, "foo"))
require.Nil(t, authorization.authorize(AuthorizationTypeWriteOnly, "foo"))
require.Nil(t, authorization.authorize(AuthorizationTypeNone, "foo"))
require.Nil(t, authorization.authorize(AuthorizationTypeWriteOnly, "baz"))
require.EqualError(t, authorization.authorize(AuthorizationTypeReadOnly, "baz"), "supplied userID: [baz] is not authorized")
require.EqualError(t, authorization.authorize(AuthorizationTypeReadWrite, "baz"), "supplied userID: [baz] is not authorized")
require.EqualError(t, authorization.authorize(AuthorizationType(100), "baz"), "unsupported authorizationType 100 passed to handler")
}

func TestHealthCheck(t *testing.T) {
Expand All @@ -147,7 +150,7 @@ func TestHealthCheck(t *testing.T) {
require.Equal(t, "testHeader", v)
})

wrappedCall := a.NewAuthHandler(f, writeAPIResponse)
wrappedCall := a.NewAuthHandler(AuthorizationTypeNone, f, writeAPIResponse)
wrappedCall.ServeHTTP(httptest.NewRecorder(), &http.Request{})
}

Expand All @@ -160,7 +163,7 @@ func TestAuthenticateFailure(t *testing.T) {
})
recorder := httptest.NewRecorder()

wrappedCall := a.NewAuthHandler(f, writeAPIResponse)
wrappedCall := a.NewAuthHandler(AuthorizationTypeNone, f, writeAPIResponse)
wrappedCall.ServeHTTP(recorder, &http.Request{})
require.Equal(t, http.StatusUnauthorized, recorder.Code)
require.Equal(t, "application/json", recorder.HeaderMap["Content-Type"][0])
Expand All @@ -180,7 +183,7 @@ func TestAuthenticateWithOriginatorID(t *testing.T) {
writeAPIResponse(w, http.StatusOK, "success!")
})
recorder := httptest.NewRecorder()
wrappedCall := a.NewAuthHandler(f, writeAPIResponse)
wrappedCall := a.NewAuthHandler(AuthorizationTypeNone, f, writeAPIResponse)
wrappedCall.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, "application/json", recorder.HeaderMap["Content-Type"][0])
Expand All @@ -198,7 +201,7 @@ func TestAuthorizeFailure(t *testing.T) {
require.NoError(t, err)
req.Header.Add("testHeader", "validUserID")

wrappedCall := a.NewAuthHandler(f, writeAPIResponse)
wrappedCall := a.NewAuthHandler(AuthorizationTypeReadOnly, f, writeAPIResponse)
wrappedCall.ServeHTTP(recorder, req)
require.Equal(t, http.StatusForbidden, recorder.Code)
require.Equal(t, "application/json", recorder.HeaderMap["Content-Type"][0])
Expand Down
11 changes: 6 additions & 5 deletions service/r2/kv/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
package kv

import (
"errors"
"fmt"

"github.com/m3db/m3ctl/service/r2"
"github.com/m3db/m3metrics/errors"
merrors "github.com/m3db/m3metrics/errors"
"github.com/m3db/m3metrics/rules"
"github.com/m3db/m3x/clock"
xerrors "github.com/m3db/m3x/errors"
Expand Down Expand Up @@ -75,9 +76,9 @@ func (s *store) FetchNamespaces() (*rules.NamespacesView, error) {

func (s *store) ValidateRuleSet(rs *rules.RuleSetSnapshot) error {
validator := s.opts.Validator()
// No validator is set so by default the ruleSetSnapshot is valid
// If no validator is set, then the validation functionality is not applicable
if validator == nil {
return nil
return errors.New("no validator set on StoreOptions so validation is not applicable")
}
return validator.ValidateSnapshot(rs)
}
Expand Down Expand Up @@ -411,9 +412,9 @@ func (s *store) handleUpstreamError(err error) error {
}

switch err.(type) {
case errors.RuleConflictError:
case merrors.RuleConflictError:
return r2.NewConflictError(err.Error())
case errors.ValidationError:
case merrors.ValidationError:
return r2.NewBadInputError(err.Error())
default:
return r2.NewInternalError(err.Error())
Expand Down
5 changes: 4 additions & 1 deletion service/r2/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ func validateRuleSet(s *service, r *http.Request) (data interface{}, err error)
)
}

rss := rsj.ruleSetSnapshot()
rss, err := rsj.ruleSetSnapshot(genIDTrue)
if err != nil {
return nil, err
}
if err := s.store.ValidateRuleSet(rss); err != nil {
return nil, err
}
Expand Down
119 changes: 85 additions & 34 deletions service/r2/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/m3db/m3metrics/rules"
"github.com/m3db/m3x/clock"
"github.com/m3db/m3x/instrument"

"github.com/pborman/uuid"

"github.com/gorilla/mux"
Expand All @@ -54,8 +55,8 @@ const (
)

var (
namespacePrefix = fmt.Sprintf("%s/{%s}", namespacePath, namespaceIDVar)
namespaceValidatePath = fmt.Sprintf("%s/{%s}/validate", namespacePath, namespaceIDVar)
namespacePrefix = fmt.Sprintf("%s/{%s}", namespacePath, namespaceIDVar)
validateRuleSetPath = fmt.Sprintf("%s/{%s}/ruleset/validate", namespacePath, namespaceIDVar)

mappingRuleRoot = fmt.Sprintf("%s/%s", namespacePrefix, mappingRulePrefix)
mappingRuleWithIDPath = fmt.Sprintf("%s/{%s}", mappingRuleRoot, ruleIDVar)
Expand All @@ -68,6 +69,15 @@ var (
errNilRequest = errors.New("Nil request")
)

type endpoint struct {
path string
method string
}

var authorizationRegistry = map[endpoint]auth.AuthorizationType{
{path: validateRuleSetPath, method: http.MethodPost}: auth.AuthorizationTypeReadOnly,
}

func sendResponse(w http.ResponseWriter, data []byte, status int) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
Expand Down Expand Up @@ -121,13 +131,13 @@ type r2Handler struct {
auth auth.HTTPAuthService
}

func (h r2Handler) wrap(fn r2HandlerFunc) http.Handler {
func (h r2Handler) wrap(authType auth.AuthorizationType, fn r2HandlerFunc) http.Handler {
f := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil {
h.handleError(w, err)
}
})
return h.auth.NewAuthHandler(f, writeAPIResponse)
return h.auth.NewAuthHandler(authType, f, writeAPIResponse)
}

func (h r2Handler) handleError(w http.ResponseWriter, opError error) {
Expand Down Expand Up @@ -157,6 +167,32 @@ func (h r2Handler) handleError(w http.ResponseWriter, opError error) {
}
}

func getDefaultAuthorizationTypeForHTTPMethod(method string) (auth.AuthorizationType, error) {
switch method {
case http.MethodGet:
return auth.AuthorizationTypeReadOnly, nil
case http.MethodPost, http.MethodPut, http.MethodDelete:
return auth.AuthorizationTypeWriteOnly, nil
default:
return auth.AuthorizationTypeNone, fmt.Errorf("unsupported http method %s for getting authorization type", method)
}
}

func registerRoute(router *mux.Router, path, method string, h r2Handler, hf r2HandlerFunc) {
authType, exists := authorizationRegistry[endpoint{path: path, method: method}]
if !exists {
var err error
authType, err = getDefaultAuthorizationTypeForHTTPMethod(method)
if err != nil {
// Panic if cannot find an authorization type for the route and/or method. This indicates an
// unrecognized route and service should panic on startup when registering routes.
panic(err)
}
}
fn := h.wrap(authType, hf)
router.Handle(path, fn).Methods(method)
}

// service handles all of the endpoints for r2.
type service struct {
rootPrefix string
Expand Down Expand Up @@ -189,36 +225,35 @@ func (s *service) URLPrefix() string { return s.rootPrefix }

func (s *service) RegisterHandlers(router *mux.Router) {
log := s.iOpts.Logger()
// Namespaces action
h := r2Handler{s.iOpts, s.authService}

router.Handle(namespacePath, h.wrap(s.fetchNamespaces)).Methods(http.MethodGet)
router.Handle(namespacePath, h.wrap(s.createNamespace)).Methods(http.MethodPost)
router.Handle(namespaceValidatePath, h.wrap(s.validateNamespace)).Methods(http.MethodPost)
// Namespaces actions
registerRoute(router, namespacePath, http.MethodGet, h, s.fetchNamespaces)
registerRoute(router, namespacePath, http.MethodPost, h, s.createNamespace)

// Ruleset actions
router.Handle(namespacePrefix, h.wrap(s.fetchNamespace)).Methods(http.MethodGet)
router.Handle(namespacePrefix, h.wrap(s.deleteNamespace)).Methods(http.MethodDelete)
registerRoute(router, namespacePrefix, http.MethodGet, h, s.fetchNamespace)
registerRoute(router, namespacePrefix, http.MethodDelete, h, s.deleteNamespace)
registerRoute(router, validateRuleSetPath, http.MethodPost, h, s.validateNamespace)

// Mapping Rule actions
router.Handle(mappingRuleRoot, h.wrap(s.createMappingRule)).Methods(http.MethodPost)
registerRoute(router, mappingRuleRoot, http.MethodPost, h, s.createMappingRule)

router.Handle(mappingRuleWithIDPath, h.wrap(s.fetchMappingRule)).Methods(http.MethodGet)
router.Handle(mappingRuleWithIDPath, h.wrap(s.updateMappingRule)).Methods(http.MethodPut, http.MethodPatch)
router.Handle(mappingRuleWithIDPath, h.wrap(s.deleteMappingRule)).Methods(http.MethodDelete)
registerRoute(router, mappingRuleWithIDPath, http.MethodGet, h, s.fetchMappingRule)
registerRoute(router, mappingRuleWithIDPath, http.MethodPut, h, s.updateMappingRule)
registerRoute(router, mappingRuleWithIDPath, http.MethodDelete, h, s.deleteMappingRule)

// Mapping Rule history
router.Handle(mappingRuleHistoryPath, h.wrap(s.fetchMappingRuleHistory)).Methods(http.MethodGet)
registerRoute(router, mappingRuleHistoryPath, http.MethodGet, h, s.fetchMappingRuleHistory)

// Rollup Rule actions
router.Handle(rollupRuleRoot, h.wrap(s.createRollupRule)).Methods(http.MethodPost)
registerRoute(router, rollupRuleRoot, http.MethodPost, h, s.createRollupRule)

router.Handle(rollupRuleWithIDPath, h.wrap(s.fetchRollupRule)).Methods(http.MethodGet)
router.Handle(rollupRuleWithIDPath, h.wrap(s.updateRollupRule)).Methods(http.MethodPut, http.MethodPatch)
router.Handle(rollupRuleWithIDPath, h.wrap(s.deleteRollupRule)).Methods(http.MethodDelete)
registerRoute(router, rollupRuleWithIDPath, http.MethodGet, h, s.fetchRollupRule)
registerRoute(router, rollupRuleWithIDPath, http.MethodPut, h, s.updateRollupRule)
registerRoute(router, rollupRuleWithIDPath, http.MethodDelete, h, s.deleteRollupRule)

// Rollup Rule history
router.Handle(rollupRuleHistoryPath, h.wrap(s.fetchRollupRuleHistory)).Methods(http.MethodGet)
registerRoute(router, rollupRuleHistoryPath, http.MethodGet, h, s.fetchRollupRuleHistory)

log.Infof("Registered rules endpoints")
}
Expand Down Expand Up @@ -566,32 +601,48 @@ type ruleSetJSON struct {
RollupRules []rollupRuleJSON `json:"rollupRules"`
}

// Creates a new RuleSetSnapshot from a rulesetJSON. If the ruleSetJSON has no IDs for any of its
// mapping rules or rollup rules, it generates missing IDs and sets as a string UUID string so they
// can be stored in a mapping (id -> rule).
func (r ruleSetJSON) ruleSetSnapshot() *rules.RuleSetSnapshot {
rss := rules.RuleSetSnapshot{
Namespace: r.Namespace,
Version: r.Version,
MappingRules: map[string]*rules.MappingRuleView{},
RollupRules: map[string]*rules.RollupRuleView{},
}
type genID int

const (
genIDFalse genID = iota
genIDTrue
)

// ruleSetSnapshot create a RuleSetSnapshot from a rulesetJSON. If the ruleSetJSON has no IDs
// for any of its mapping rules or rollup rules, it generates missing IDs and sets as a string UUID
// string so they can be stored in a mapping (id -> rule).
func (r ruleSetJSON) ruleSetSnapshot(genID genID) (*rules.RuleSetSnapshot, error) {

mappingRules := make(map[string]*rules.MappingRuleView, len(r.MappingRules))
for _, mr := range r.MappingRules {
id := mr.ID
if id == "" {
if genID == genIDFalse {
return nil, fmt.Errorf("can't convert rulesetJSON to ruleSetSnapshot, no mapping rule id for %v", mr)
}
id = uuid.New()
mr.ID = id
}
rss.MappingRules[id] = mr.mappingRuleView()
mappingRules[id] = mr.mappingRuleView()
}

rollupRules := make(map[string]*rules.RollupRuleView, len(r.RollupRules))
for _, rr := range r.RollupRules {
id := rr.ID
if id == "" {
if genID == genIDFalse {
return nil, fmt.Errorf("can't convert rulesetJSON to ruleSetSnapshot, no rollup rule id for %v", rr)
}
id = uuid.New()
rr.ID = id
}
rss.RollupRules[id] = rr.rollupRuleView()
rollupRules[id] = rr.rollupRuleView()
}
return &rss

return &rules.RuleSetSnapshot{
Namespace: r.Namespace,
Version: r.Version,
MappingRules: mappingRules,
RollupRules: rollupRules,
}, nil
}

0 comments on commit be98154

Please sign in to comment.