Skip to content

Commit

Permalink
Validate etcd paths
Browse files Browse the repository at this point in the history
Kubernetes-commit: 6775c99cd008c457ce3eed401ac1c60c3812dbfa
  • Loading branch information
tallclair authored and k8s-publishing-bot committed Oct 11, 2022
1 parent c19c82e commit 20e3df6
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 89 deletions.
5 changes: 3 additions & 2 deletions pkg/storage/etcd3/linearized_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ func TestLinearizedReadRevisionInvariant(t *testing.T) {
// [1] https://etcd.io/docs/v3.5/learning/api_guarantees/#isolation-level-and-consistency-of-replicas
ctx, store, etcdClient := testSetup(t)

key := "/testkey"
dir := "/testing"
key := dir + "/testkey"
out := &example.Pod{}
obj := &example.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo", SelfLink: "testlink"}}

Expand All @@ -53,7 +54,7 @@ func TestLinearizedReadRevisionInvariant(t *testing.T) {
}

list := &example.PodList{}
if err := store.GetList(ctx, "/", storage.ListOptions{Predicate: storage.Everything, Recursive: true}, list); err != nil {
if err := store.GetList(ctx, dir, storage.ListOptions{Predicate: storage.Everything, Recursive: true}, list); err != nil {
t.Errorf("Unexpected List error: %v", err)
}
finalRevision := list.ResourceVersion
Expand Down
138 changes: 94 additions & 44 deletions pkg/storage/etcd3/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,21 @@ func New(c *clientv3.Client, codec runtime.Codec, newFunc func() runtime.Object,

func newStore(c *clientv3.Client, codec runtime.Codec, newFunc func() runtime.Object, prefix string, groupResource schema.GroupResource, transformer value.Transformer, pagingEnabled bool, leaseManagerConfig LeaseManagerConfig) *store {
versioner := storage.APIObjectVersioner{}
// for compatibility with etcd2 impl.
// no-op for default prefix of '/registry'.
// keeps compatibility with etcd2 impl for custom prefixes that don't start with '/'
pathPrefix := path.Join("/", prefix)
if !strings.HasSuffix(pathPrefix, "/") {
// Ensure the pathPrefix ends in "/" here to simplify key concatenation later.
pathPrefix += "/"
}
result := &store{
client: c,
codec: codec,
versioner: versioner,
transformer: transformer,
pagingEnabled: pagingEnabled,
// for compatibility with etcd2 impl.
// no-op for default prefix of '/registry'.
// keeps compatibility with etcd2 impl for custom prefixes that don't start with '/'
pathPrefix: path.Join("/", prefix),
client: c,
codec: codec,
versioner: versioner,
transformer: transformer,
pagingEnabled: pagingEnabled,
pathPrefix: pathPrefix,
groupResource: groupResource,
groupResourceString: groupResource.String(),
watcher: newWatcher(c, codec, newFunc, versioner, transformer),
Expand All @@ -123,9 +128,12 @@ func (s *store) Versioner() storage.Versioner {

// Get implements storage.Interface.Get.
func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, out runtime.Object) error {
key = path.Join(s.pathPrefix, key)
preparedKey, err := s.prepareKey(key)
if err != nil {
return err
}
startTime := time.Now()
getResp, err := s.client.KV.Get(ctx, key)
getResp, err := s.client.KV.Get(ctx, preparedKey)
metrics.RecordEtcdRequestLatency("get", getTypeName(out), startTime)
if err != nil {
return err
Expand All @@ -138,11 +146,11 @@ func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, ou
if opts.IgnoreNotFound {
return runtime.SetZeroValue(out)
}
return storage.NewKeyNotFoundError(key, 0)
return storage.NewKeyNotFoundError(preparedKey, 0)
}
kv := getResp.Kvs[0]

data, _, err := s.transformer.TransformFromStorage(ctx, kv.Value, authenticatedDataString(key))
data, _, err := s.transformer.TransformFromStorage(ctx, kv.Value, authenticatedDataString(preparedKey))
if err != nil {
return storage.NewInternalError(err.Error())
}
Expand All @@ -152,6 +160,10 @@ func (s *store) Get(ctx context.Context, key string, opts storage.GetOptions, ou

// Create implements storage.Interface.Create.
func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object, ttl uint64) error {
preparedKey, err := s.prepareKey(key)
if err != nil {
return err
}
trace := utiltrace.New("Create etcd3",
utiltrace.Field{"audit-id", endpointsrequest.GetAuditIDTruncated(ctx)},
utiltrace.Field{"key", key},
Expand All @@ -170,24 +182,23 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object,
if err != nil {
return err
}
key = path.Join(s.pathPrefix, key)

opts, err := s.ttlOpts(ctx, int64(ttl))
if err != nil {
return err
}

newData, err := s.transformer.TransformToStorage(ctx, data, authenticatedDataString(key))
newData, err := s.transformer.TransformToStorage(ctx, data, authenticatedDataString(preparedKey))
trace.Step("TransformToStorage finished", utiltrace.Field{"err", err})
if err != nil {
return storage.NewInternalError(err.Error())
}

startTime := time.Now()
txnResp, err := s.client.KV.Txn(ctx).If(
notFound(key),
notFound(preparedKey),
).Then(
clientv3.OpPut(key, string(newData), opts...),
clientv3.OpPut(preparedKey, string(newData), opts...),
).Commit()
metrics.RecordEtcdRequestLatency("create", getTypeName(obj), startTime)
trace.Step("Txn call finished", utiltrace.Field{"err", err})
Expand All @@ -196,7 +207,7 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object,
}

if !txnResp.Succeeded {
return storage.NewKeyExistsError(key, 0)
return storage.NewKeyExistsError(preparedKey, 0)
}

if out != nil {
Expand All @@ -212,12 +223,15 @@ func (s *store) Create(ctx context.Context, key string, obj, out runtime.Object,
func (s *store) Delete(
ctx context.Context, key string, out runtime.Object, preconditions *storage.Preconditions,
validateDeletion storage.ValidateObjectFunc, cachedExistingObject runtime.Object) error {
preparedKey, err := s.prepareKey(key)
if err != nil {
return err
}
v, err := conversion.EnforcePtr(out)
if err != nil {
return fmt.Errorf("unable to convert output object to pointer: %v", err)
}
key = path.Join(s.pathPrefix, key)
return s.conditionalDelete(ctx, key, out, v, preconditions, validateDeletion, cachedExistingObject)
return s.conditionalDelete(ctx, preparedKey, out, v, preconditions, validateDeletion, cachedExistingObject)
}

func (s *store) conditionalDelete(
Expand Down Expand Up @@ -330,6 +344,10 @@ func (s *store) conditionalDelete(
func (s *store) GuaranteedUpdate(
ctx context.Context, key string, destination runtime.Object, ignoreNotFound bool,
preconditions *storage.Preconditions, tryUpdate storage.UpdateFunc, cachedExistingObject runtime.Object) error {
preparedKey, err := s.prepareKey(key)
if err != nil {
return err
}
trace := utiltrace.New("GuaranteedUpdate etcd3",
utiltrace.Field{"audit-id", endpointsrequest.GetAuditIDTruncated(ctx)},
utiltrace.Field{"key", key},
Expand All @@ -340,16 +358,15 @@ func (s *store) GuaranteedUpdate(
if err != nil {
return fmt.Errorf("unable to convert output object to pointer: %v", err)
}
key = path.Join(s.pathPrefix, key)

getCurrentState := func() (*objState, error) {
startTime := time.Now()
getResp, err := s.client.KV.Get(ctx, key)
getResp, err := s.client.KV.Get(ctx, preparedKey)
metrics.RecordEtcdRequestLatency("get", getTypeName(destination), startTime)
if err != nil {
return nil, err
}
return s.getState(ctx, getResp, key, v, ignoreNotFound)
return s.getState(ctx, getResp, preparedKey, v, ignoreNotFound)
}

var origState *objState
Expand All @@ -365,9 +382,9 @@ func (s *store) GuaranteedUpdate(
}
trace.Step("initial value restored")

transformContext := authenticatedDataString(key)
transformContext := authenticatedDataString(preparedKey)
for {
if err := preconditions.Check(key, origState.obj); err != nil {
if err := preconditions.Check(preparedKey, origState.obj); err != nil {
// If our data is already up to date, return the error
if origStateIsCurrent {
return err
Expand Down Expand Up @@ -453,11 +470,11 @@ func (s *store) GuaranteedUpdate(

startTime := time.Now()
txnResp, err := s.client.KV.Txn(ctx).If(
clientv3.Compare(clientv3.ModRevision(key), "=", origState.rev),
clientv3.Compare(clientv3.ModRevision(preparedKey), "=", origState.rev),
).Then(
clientv3.OpPut(key, string(newData), opts...),
clientv3.OpPut(preparedKey, string(newData), opts...),
).Else(
clientv3.OpGet(key),
clientv3.OpGet(preparedKey),
).Commit()
metrics.RecordEtcdRequestLatency("update", getTypeName(destination), startTime)
trace.Step("Txn call finished", utiltrace.Field{"err", err})
Expand All @@ -467,8 +484,8 @@ func (s *store) GuaranteedUpdate(
trace.Step("Transaction committed")
if !txnResp.Succeeded {
getResp := (*clientv3.GetResponse)(txnResp.Responses[0].GetResponseRange())
klog.V(4).Infof("GuaranteedUpdate of %s failed because of a conflict, going to retry", key)
origState, err = s.getState(ctx, getResp, key, v, ignoreNotFound)
klog.V(4).Infof("GuaranteedUpdate of %s failed because of a conflict, going to retry", preparedKey)
origState, err = s.getState(ctx, getResp, preparedKey, v, ignoreNotFound)
if err != nil {
return err
}
Expand Down Expand Up @@ -502,18 +519,21 @@ func getNewItemFunc(listObj runtime.Object, v reflect.Value) func() runtime.Obje
}

func (s *store) Count(key string) (int64, error) {
key = path.Join(s.pathPrefix, key)
preparedKey, err := s.prepareKey(key)
if err != nil {
return 0, err
}

// We need to make sure the key ended with "/" so that we only get children "directories".
// e.g. if we have key "/a", "/a/b", "/ab", getting keys with prefix "/a" will return all three,
// while with prefix "/a/" will return only "/a/b" which is the correct answer.
if !strings.HasSuffix(key, "/") {
key += "/"
if !strings.HasSuffix(preparedKey, "/") {
preparedKey += "/"
}

startTime := time.Now()
getResp, err := s.client.KV.Get(context.Background(), key, clientv3.WithRange(clientv3.GetPrefixRangeEnd(key)), clientv3.WithCountOnly())
metrics.RecordEtcdRequestLatency("listWithCount", key, startTime)
getResp, err := s.client.KV.Get(context.Background(), preparedKey, clientv3.WithRange(clientv3.GetPrefixRangeEnd(preparedKey)), clientv3.WithCountOnly())
metrics.RecordEtcdRequestLatency("listWithCount", preparedKey, startTime)
if err != nil {
return 0, err
}
Expand All @@ -522,6 +542,10 @@ func (s *store) Count(key string) (int64, error) {

// GetList implements storage.Interface.
func (s *store) GetList(ctx context.Context, key string, opts storage.ListOptions, listObj runtime.Object) error {
preparedKey, err := s.prepareKey(key)
if err != nil {
return err
}
recursive := opts.Recursive
resourceVersion := opts.ResourceVersion
match := opts.ResourceVersionMatch
Expand All @@ -542,16 +566,15 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption
if err != nil || v.Kind() != reflect.Slice {
return fmt.Errorf("need ptr to slice: %v", err)
}
key = path.Join(s.pathPrefix, key)

// For recursive lists, we need to make sure the key ended with "/" so that we only
// get children "directories". e.g. if we have key "/a", "/a/b", "/ab", getting keys
// with prefix "/a" will return all three, while with prefix "/a/" will return only
// "/a/b" which is the correct answer.
if recursive && !strings.HasSuffix(key, "/") {
key += "/"
if recursive && !strings.HasSuffix(preparedKey, "/") {
preparedKey += "/"
}
keyPrefix := key
keyPrefix := preparedKey

// set the appropriate clientv3 options to filter the returned data set
var limitOption *clientv3.OpOption
Expand Down Expand Up @@ -590,7 +613,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption

rangeEnd := clientv3.GetPrefixRangeEnd(keyPrefix)
options = append(options, clientv3.WithRange(rangeEnd))
key = continueKey
preparedKey = continueKey

// If continueRV > 0, the LIST request needs a specific resource version.
// continueRV==0 is invalid.
Expand Down Expand Up @@ -657,7 +680,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption
}()
for {
startTime := time.Now()
getResp, err = s.client.KV.Get(ctx, key, options...)
getResp, err = s.client.KV.Get(ctx, preparedKey, options...)
if recursive {
metrics.RecordEtcdRequestLatency("list", getTypeName(listPtr), startTime)
} else {
Expand Down Expand Up @@ -729,7 +752,7 @@ func (s *store) GetList(ctx context.Context, key string, opts storage.ListOption
}
*limitOption = clientv3.WithLimit(limit)
}
key = string(lastKey) + "\x00"
preparedKey = string(lastKey) + "\x00"
if withRev == 0 {
withRev = returnedRV
options = append(options, clientv3.WithRev(withRev))
Expand Down Expand Up @@ -794,12 +817,15 @@ func growSlice(v reflect.Value, maxCapacity int, sizes ...int) {

// Watch implements storage.Interface.Watch.
func (s *store) Watch(ctx context.Context, key string, opts storage.ListOptions) (watch.Interface, error) {
preparedKey, err := s.prepareKey(key)
if err != nil {
return nil, err
}
rev, err := s.versioner.ParseResourceVersion(opts.ResourceVersion)
if err != nil {
return nil, err
}
key = path.Join(s.pathPrefix, key)
return s.watcher.Watch(ctx, key, int64(rev), opts.Recursive, opts.ProgressNotify, opts.Predicate)
return s.watcher.Watch(ctx, preparedKey, int64(rev), opts.Recursive, opts.ProgressNotify, opts.Predicate)
}

func (s *store) getState(ctx context.Context, getResp *clientv3.GetResponse, key string, v reflect.Value, ignoreNotFound bool) (*objState, error) {
Expand Down Expand Up @@ -911,6 +937,30 @@ func (s *store) validateMinimumResourceVersion(minimumResourceVersion string, ac
return nil
}

func (s *store) prepareKey(key string) (string, error) {
if key == ".." ||
strings.HasPrefix(key, "../") ||
strings.HasSuffix(key, "/..") ||
strings.Contains(key, "/../") {
return "", fmt.Errorf("invalid key: %q", key)
}
if key == "." ||
strings.HasPrefix(key, "./") ||
strings.HasSuffix(key, "/.") ||
strings.Contains(key, "/./") {
return "", fmt.Errorf("invalid key: %q", key)
}
if key == "" || key == "/" {
return "", fmt.Errorf("empty key: %q", key)
}
// We ensured that pathPrefix ends in '/' in construction, so skip any leading '/' in the key now.
startIndex := 0
if key[0] == '/' {
startIndex = 1
}
return s.pathPrefix + key[startIndex:], nil
}

// decode decodes value of bytes into object. It will also set the object resource version to rev.
// On success, objPtr would be set to the object.
func decode(codec runtime.Codec, versioner storage.Versioner, value []byte, objPtr runtime.Object, rev int64) error {
Expand Down
Loading

0 comments on commit 20e3df6

Please sign in to comment.