Skip to content

Commit

Permalink
etcdserver: partial watches
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheph committed May 17, 2019
1 parent fca8add commit 4551c14
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 83 deletions.
33 changes: 25 additions & 8 deletions auth/range_perm_cache.go
Expand Up @@ -70,21 +70,38 @@ func getMergedPerms(tx backend.BatchTx, userName string) *unifiedRangePermission
}
}

func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool {
func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd []byte, permtyp authpb.Permission_Type, partial bool) (*adt.IntervalTree, bool) {
if len(rangeEnd) == 1 && rangeEnd[0] == 0 {
rangeEnd = nil
}

ivl := adt.NewBytesAffineInterval(key, rangeEnd)

switch permtyp {
case authpb.READ:
return cachedPerms.readPerms.Contains(ivl)
if partial {
ranges := &adt.IntervalTree{}
ranges.Union(*cachedPerms.readPerms, ivl, true)
if ranges.Len() > 0 {
return ranges, true
}
} else {
return nil, cachedPerms.readPerms.Contains(ivl)
}
case authpb.WRITE:
return cachedPerms.writePerms.Contains(ivl)
if partial {
ranges := &adt.IntervalTree{}
ranges.Union(*cachedPerms.writePerms, ivl, true)
if ranges.Len() > 0 {
return ranges, true
}
} else {
return nil, cachedPerms.writePerms.Contains(ivl)
}
default:
plog.Panicf("unknown auth type: %v", permtyp)
}
return false
return nil, false
}

func checkKeyPoint(cachedPerms *unifiedRangePermissions, key []byte, permtyp authpb.Permission_Type) bool {
Expand All @@ -100,23 +117,23 @@ func checkKeyPoint(cachedPerms *unifiedRangePermissions, key []byte, permtyp aut
return false
}

func (as *authStore) isRangeOpPermitted(tx backend.BatchTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool {
func (as *authStore) isRangeOpPermitted(tx backend.BatchTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type, partial bool) (*adt.IntervalTree, bool) {
// assumption: tx is Lock()ed
_, ok := as.rangePermCache[userName]
if !ok {
perms := getMergedPerms(tx, userName)
if perms == nil {
plog.Errorf("failed to create a unified permission of user %s", userName)
return false
return nil, false
}
as.rangePermCache[userName] = perms
}

if len(rangeEnd) == 0 {
return checkKeyPoint(as.rangePermCache[userName], key, permtyp)
return nil, checkKeyPoint(as.rangePermCache[userName], key, permtyp)
}

return checkKeyInterval(as.rangePermCache[userName], key, rangeEnd, permtyp)
return checkKeyInterval(as.rangePermCache[userName], key, rangeEnd, permtyp, partial)
}

func (as *authStore) clearCachedPerm() {
Expand Down
2 changes: 1 addition & 1 deletion auth/range_perm_cache_test.go
Expand Up @@ -51,7 +51,7 @@ func TestRangePermission(t *testing.T) {
readPerms.Insert(p, struct{}{})
}

result := checkKeyInterval(&unifiedRangePermissions{readPerms: readPerms}, tt.begin, tt.end, authpb.READ)
_, result := checkKeyInterval(&unifiedRangePermissions{readPerms: readPerms}, tt.begin, tt.end, authpb.READ, false)
if result != tt.want {
t.Errorf("#%d: result=%t, want=%t", i, result, tt.want)
}
Expand Down
43 changes: 31 additions & 12 deletions auth/store.go
Expand Up @@ -26,7 +26,9 @@ import (

"github.com/coreos/etcd/auth/authpb"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/mvcc"
"github.com/coreos/etcd/mvcc/backend"
"github.com/coreos/etcd/pkg/adt"

"github.com/coreos/pkg/capnslog"
"golang.org/x/crypto/bcrypt"
Expand Down Expand Up @@ -148,6 +150,9 @@ type AuthStore interface {
// IsRangePermitted checks range permission of the user
IsRangePermitted(authInfo *AuthInfo, key, rangeEnd []byte) error

// Same as IsRangePermitted, but returns success if at least some keys within [key, rangeEnd) are permitted
IsRangePartiallyPermitted(authInfo *AuthInfo, key, rangeEnd []byte) (*adt.IntervalTree, error)

// IsDeleteRangePermitted checks delete-range permission of the user
IsDeleteRangePermitted(authInfo *AuthInfo, key, rangeEnd []byte) error

Expand Down Expand Up @@ -733,19 +738,26 @@ func (as *authStore) RoleGrantPermission(r *pb.AuthRoleGrantPermissionRequest) (
return &pb.AuthRoleGrantPermissionResponse{}, nil
}

func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeEnd []byte, permTyp authpb.Permission_Type) error {
func makePermittedRange(key, rangeEnd []byte, partial bool) *adt.IntervalTree {
if !partial || (len(rangeEnd) == 0) {
return nil
}
return mvcc.MakeWatchRange(key, rangeEnd)
}

func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeEnd []byte, permTyp authpb.Permission_Type, partial bool) (*adt.IntervalTree, error) {
// TODO(mitake): this function would be costly so we need a caching mechanism
if !as.isAuthEnabled() {
return nil
return makePermittedRange(key, rangeEnd, partial), nil
}

// only gets rev == 0 when passed AuthInfo{}; no user given
if revision == 0 {
return ErrUserEmpty
return nil, ErrUserEmpty
}

if revision < as.Revision() {
return ErrAuthOldRevision
return nil, ErrAuthOldRevision
}

tx := as.be.BatchTx()
Expand All @@ -755,31 +767,38 @@ func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeE
user := getUser(tx, userName)
if user == nil {
plog.Errorf("invalid user name %s for permission checking", userName)
return ErrPermissionDenied
return nil, ErrPermissionDenied
}

// root role should have permission on all ranges
if hasRootRole(user) {
return nil
return makePermittedRange(key, rangeEnd, partial), nil
}

if as.isRangeOpPermitted(tx, userName, key, rangeEnd, permTyp) {
return nil
if ranges, ok := as.isRangeOpPermitted(tx, userName, key, rangeEnd, permTyp, partial); ok {
return ranges, nil
}

return ErrPermissionDenied
return nil, ErrPermissionDenied
}

func (as *authStore) IsPutPermitted(authInfo *AuthInfo, key []byte) error {
return as.isOpPermitted(authInfo.Username, authInfo.Revision, key, nil, authpb.WRITE)
_, err := as.isOpPermitted(authInfo.Username, authInfo.Revision, key, nil, authpb.WRITE, false)
return err
}

func (as *authStore) IsRangePermitted(authInfo *AuthInfo, key, rangeEnd []byte) error {
return as.isOpPermitted(authInfo.Username, authInfo.Revision, key, rangeEnd, authpb.READ)
_, err := as.isOpPermitted(authInfo.Username, authInfo.Revision, key, rangeEnd, authpb.READ, false)
return err
}

func (as *authStore) IsRangePartiallyPermitted(authInfo *AuthInfo, key, rangeEnd []byte) (*adt.IntervalTree, error) {
return as.isOpPermitted(authInfo.Username, authInfo.Revision, key, rangeEnd, authpb.READ, true)
}

func (as *authStore) IsDeleteRangePermitted(authInfo *AuthInfo, key, rangeEnd []byte) error {
return as.isOpPermitted(authInfo.Username, authInfo.Revision, key, rangeEnd, authpb.WRITE)
_, err := as.isOpPermitted(authInfo.Username, authInfo.Revision, key, rangeEnd, authpb.WRITE, false)
return err
}

func (as *authStore) IsAdminPermitted(authInfo *AuthInfo) error {
Expand Down
4 changes: 2 additions & 2 deletions etcdserver/api/v3rpc/key.go
Expand Up @@ -237,8 +237,8 @@ func checkIntervals(reqs []*pb.RequestOp) (map[string]struct{}, adt.IntervalTree
}
puts[k] = struct{}{}
}
dels.Union(delsThen, adt.NewStringAffineInterval("\x00", ""))
dels.Union(delsElse, adt.NewStringAffineInterval("\x00", ""))
dels.Union(delsThen, adt.NewStringAffineInterval("\x00", ""), false)
dels.Union(delsElse, adt.NewStringAffineInterval("\x00", ""), false)
}

// collect and check this level's puts
Expand Down
20 changes: 10 additions & 10 deletions etcdserver/api/v3rpc/watch.go
Expand Up @@ -26,6 +26,7 @@ import (
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/mvcc"
"github.com/coreos/etcd/mvcc/mvccpb"
"github.com/coreos/etcd/pkg/adt"
)

type watchServer struct {
Expand Down Expand Up @@ -162,17 +163,18 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
return err
}

func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) (*adt.IntervalTree, bool) {
authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
if err != nil {
return false
return nil, false
}
if authInfo == nil {
// if auth is enabled, IsRangePermitted() can cause an error
// if auth is enabled, IsRangePartiallyPermitted() can cause an error
authInfo = &auth.AuthInfo{}
}

return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
ranges, err := sws.ag.AuthStore().IsRangePartiallyPermitted(authInfo, wcr.Key, wcr.RangeEnd)
return ranges, err == nil
}

func (sws *serverWatchStream) recvLoop() error {
Expand Down Expand Up @@ -201,12 +203,9 @@ func (sws *serverWatchStream) recvLoop() error {
// between nil and []byte{} for single key / >=
creq.RangeEnd = nil
}
if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
// support >= key queries
creq.RangeEnd = []byte{}
}
ranges, ok := sws.isWatchPermitted(creq)

if !sws.isWatchPermitted(creq) {
if !ok {
wr := &pb.WatchResponse{
Header: sws.newResponseHeader(sws.watchStream.Rev()),
WatchId: -1,
Expand All @@ -229,7 +228,8 @@ func (sws *serverWatchStream) recvLoop() error {
if rev == 0 {
rev = wsrev + 1
}
id := sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev, filters...)

id := sws.watchStream.Watch(creq.Key, ranges, rev, filters...)
if id != -1 {
sws.mu.Lock()
if creq.ProgressNotify {
Expand Down
4 changes: 2 additions & 2 deletions mvcc/kv_test.go
Expand Up @@ -716,7 +716,7 @@ func TestWatchableKVWatch(t *testing.T) {
w := s.NewWatchStream()
defer w.Close()

wid := w.Watch([]byte("foo"), []byte("fop"), 0)
wid := w.Watch(nil, MakeWatchRange([]byte("foo"), []byte("fop")), 0)

wev := []mvccpb.Event{
{Type: mvccpb.PUT,
Expand Down Expand Up @@ -783,7 +783,7 @@ func TestWatchableKVWatch(t *testing.T) {
}

w = s.NewWatchStream()
wid = w.Watch([]byte("foo1"), []byte("foo2"), 3)
wid = w.Watch(nil, MakeWatchRange([]byte("foo1"), []byte("foo2")), 3)

select {
case resp := <-w.Chan():
Expand Down
13 changes: 7 additions & 6 deletions mvcc/watchable_store.go
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/coreos/etcd/lease"
"github.com/coreos/etcd/mvcc/backend"
"github.com/coreos/etcd/mvcc/mvccpb"
"github.com/coreos/etcd/pkg/adt"
)

// non-const so modifiable by tests
Expand All @@ -36,7 +37,7 @@ var (
)

type watchable interface {
watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse, fcs ...FilterFunc) (*watcher, cancelFunc)
watch(key []byte, ranges *adt.IntervalTree, startRev int64, id WatchID, ch chan<- WatchResponse, fcs ...FilterFunc) (*watcher, cancelFunc)
progress(w *watcher)
rev() int64
}
Expand Down Expand Up @@ -107,10 +108,10 @@ func (s *watchableStore) NewWatchStream() WatchStream {
}
}

func (s *watchableStore) watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse, fcs ...FilterFunc) (*watcher, cancelFunc) {
func (s *watchableStore) watch(key []byte, ranges *adt.IntervalTree, startRev int64, id WatchID, ch chan<- WatchResponse, fcs ...FilterFunc) (*watcher, cancelFunc) {
wa := &watcher{
key: key,
end: end,
ranges: ranges,
minRev: startRev,
id: id,
ch: ch,
Expand Down Expand Up @@ -473,9 +474,9 @@ func (s *watchableStore) progress(w *watcher) {
type watcher struct {
// the watcher key
key []byte
// end indicates the end of the range to watch.
// If end is set, the watcher is on a range.
end []byte
// If ranges is set, the watcher is on a range. ranges contain subranges
// of the desired range, carved out by read permissions.
ranges *adt.IntervalTree

// victim is set when ch is blocked and undergoing victim processing
victim bool
Expand Down
35 changes: 27 additions & 8 deletions mvcc/watcher.go
Expand Up @@ -15,11 +15,11 @@
package mvcc

import (
"bytes"
"errors"
"sync"

"github.com/coreos/etcd/mvcc/mvccpb"
"github.com/coreos/etcd/pkg/adt"
)

var (
Expand All @@ -33,15 +33,15 @@ type FilterFunc func(e mvccpb.Event) bool

type WatchStream interface {
// Watch creates a watcher. The watcher watches the events happening or
// happened on the given key or range [key, end) from the given startRev.
// happened on the given key or ranges from the given startRev.
//
// The whole event history can be watched unless compacted.
// If `startRev` <=0, watch observes events after currentRev.
//
// The returned `id` is the ID of this watcher. It appears as WatchID
// in events that are sent to the created watcher through stream channel.
//
Watch(key, end []byte, startRev int64, fcs ...FilterFunc) WatchID
Watch(key []byte, ranges *adt.IntervalTree, startRev int64, fcs ...FilterFunc) WatchID

// Chan returns a chan. All watch response will be sent to the returned chan.
Chan() <-chan WatchResponse
Expand Down Expand Up @@ -97,13 +97,32 @@ type watchStream struct {
watchers map[WatchID]*watcher
}

func MakeWatchRange(key, rangeEnd []byte) *adt.IntervalTree {
if len(rangeEnd) == 1 && rangeEnd[0] == 0 {
rangeEnd = nil
}
ranges := &adt.IntervalTree{}
ranges.Insert(adt.NewBytesAffineInterval(key, rangeEnd), struct{}{})
return ranges
}

// Watch creates a new watcher in the stream and returns its WatchID.
// TODO: return error if ws is closed?
func (ws *watchStream) Watch(key, end []byte, startRev int64, fcs ...FilterFunc) WatchID {
// prevent wrong range where key >= end lexicographically
func (ws *watchStream) Watch(key []byte, ranges *adt.IntervalTree, startRev int64, fcs ...FilterFunc) WatchID {
// prevent wrong ranges where start >= end lexicographically
// watch request with 'WithFromKey' has empty-byte range end
if len(end) != 0 && bytes.Compare(key, end) != -1 {
return -1
if ranges != nil {
numGoodRanges := 0
ranges.Visit(adt.NewBytesAffineInterval([]byte{0}, nil),
func(iv *adt.IntervalValue) bool {
if iv.Ivl.Begin.Compare(iv.Ivl.End) < 0 {
numGoodRanges++
}
return true
})
if (numGoodRanges != ranges.Len()) || (numGoodRanges == 0) {
return -1
}
}

ws.mu.Lock()
Expand All @@ -115,7 +134,7 @@ func (ws *watchStream) Watch(key, end []byte, startRev int64, fcs ...FilterFunc)
id := ws.nextID
ws.nextID++

w, c := ws.watchable.watch(key, end, startRev, id, ws.ch, fcs...)
w, c := ws.watchable.watch(key, ranges, startRev, id, ws.ch, fcs...)

ws.cancels[id] = c
ws.watchers[id] = w
Expand Down

0 comments on commit 4551c14

Please sign in to comment.