ClusterMesh: validate etcd cluster ID
[ upstream commit 174e721 ]

[ backporter's notes: backported a stripped-down version of the upstream
  commit, including only the introduction of the interceptors, as this
  also fixes a bug occurring in a single clustermesh-apiserver
  configuration (during rollouts), by preventing Cilium agents from
  incorrectly restarting an etcd watch against a different
  clustermesh-apiserver instance. ]

In a configuration where there are multiple replicas of the
clustermesh-apiserver, each Pod runs its own etcd instance with a unique
cluster ID. This commit adds a `clusterLock` type, which is a wrapper
around a uint64 that can only be set once. `clusterLock` is used to
create gRPC unary and stream interceptors that are provided to the etcd
client to intercept and validate the cluster ID in the header of all
responses from the etcd server.

If the client receives a response from a different cluster, the
connection is terminated and restarted. This is designed to prevent
accepting responses from another cluster and potentially missing events
or retaining invalid data.
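
For illustration, a minimal sketch of how the owner of the etcd session
could consume the errors surfaced by the interceptors (the restart
callback is hypothetical and not part of this commit):

func watchClusterLock(ctx context.Context, cl *clusterLock, restart func(error)) {
	select {
	case err := <-cl.errors:
		// A response arrived from a different etcd cluster, or was
		// missing its header: the connection can no longer be trusted.
		restart(err)
	case <-ctx.Done():
	}
}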

Since the addition of the interceptors allows quick detection of a
failover event, we no longer need to rely on endpoint status checks to
determine if the connection is healthy. Additionally, since service session
affinity can be unreliable, the status checks could trigger a false
failover event and cause a connection restart. To allow creating etcd
clients for ClusterMesh that do not perform endpoint status checks, the
option NoEndpointStatusChecks was added to ExtraOptions.
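
The ExtraOptions change itself is in files not shown below;
illustratively, a ClusterMesh etcd client would be requested roughly as
follows (the field name is per this commit, the surrounding literal is a
sketch assuming Cilium's kvstore.ExtraOptions type):

extraOpts := &kvstore.ExtraOptions{
	// Skip endpoint status checks: failovers are detected by the
	// cluster ID interceptors, and session affinity may be unreliable.
	NoEndpointStatusChecks: true,
}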

Signed-off-by: Tim Horner <timothy.horner@isovalent.com>
Signed-off-by: Marco Iorio <marco.iorio@isovalent.com>
thorn3r authored and aanm committed Apr 18, 2024
1 parent 88be4ec commit 1e37ce6
Showing 4 changed files with 374 additions and 22 deletions.
112 changes: 112 additions & 0 deletions pkg/clustermesh/common/interceptor.go
@@ -0,0 +1,112 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package common

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"

	"go.etcd.io/etcd/api/v3/etcdserverpb"
	"google.golang.org/grpc"
)

var (
	ErrClusterIDChanged    = errors.New("etcd cluster ID has changed")
	ErrEtcdInvalidResponse = errors.New("received an invalid etcd response")
)

// newUnaryInterceptor returns a new unary client interceptor that validates the
// cluster ID of any received etcd responses.
func newUnaryInterceptor(cl *clusterLock) grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
			return err
		}
		return validateReply(cl, reply)
	}
}

// newStreamInterceptor returns a new stream client interceptor that validates
// the cluster ID of any received etcd responses.
func newStreamInterceptor(cl *clusterLock) grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		s, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			return nil, err
		}
		return &wrappedClientStream{
			ClientStream: s,
			clusterLock:  cl,
		}, nil
	}
}

// wrappedClientStream is a wrapper around a grpc.ClientStream that adds
// validation for the etcd cluster ID.
type wrappedClientStream struct {
	grpc.ClientStream
	clusterLock *clusterLock
}

// RecvMsg implements the grpc.ClientStream interface, adding validation for
// the etcd cluster ID.
func (w *wrappedClientStream) RecvMsg(m interface{}) error {
	if err := w.ClientStream.RecvMsg(m); err != nil {
		return err
	}

	return validateReply(w.clusterLock, m)
}

type etcdResponse interface {
	GetHeader() *etcdserverpb.ResponseHeader
}

func validateReply(cl *clusterLock, reply any) error {
	resp, ok := reply.(etcdResponse)
	if !ok || resp.GetHeader() == nil {
		select {
		case cl.errors <- ErrEtcdInvalidResponse:
		default:
		}
		return ErrEtcdInvalidResponse
	}

	if err := cl.validateClusterId(resp.GetHeader().ClusterId); err != nil {
		select {
		case cl.errors <- err:
		default:
		}
		return err
	}
	return nil
}

// clusterLock is a wrapper around an atomic uint64 that can only be set once. It
// provides validation for an etcd connection to ensure that it is only used
// for the same etcd cluster it was initially connected to. This is to prevent
// accidentally connecting to the wrong cluster in a high availability
// configuration utilizing multiple active clusters.
type clusterLock struct {
	etcdClusterID atomic.Uint64
	errors        chan error
}

func newClusterLock() *clusterLock {
	return &clusterLock{
		etcdClusterID: atomic.Uint64{},
		errors:        make(chan error, 1),
	}
}

func (c *clusterLock) validateClusterId(clusterId uint64) error {
	// If the cluster ID has not been set, set it to the received cluster ID
	c.etcdClusterID.CompareAndSwap(0, clusterId)

	if clusterId != c.etcdClusterID.Load() {
		return fmt.Errorf("%w: expected %d, got %d", ErrClusterIDChanged, c.etcdClusterID.Load(), clusterId)
	}
	return nil
}
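
For context, these interceptors attach to an etcd client through standard
gRPC dial options; a minimal sketch, assuming the go.etcd.io/etcd/client/v3
package (the helper name and endpoints are illustrative, not part of this
commit):

func newValidatedEtcdClient(endpoints []string) (*clientv3.Client, *clusterLock, error) {
	cl := newClusterLock()
	c, err := clientv3.New(clientv3.Config{
		Endpoints: endpoints,
		DialOptions: []grpc.DialOption{
			// Every unary and stream response is checked against the
			// first cluster ID observed on this connection.
			grpc.WithUnaryInterceptor(newUnaryInterceptor(cl)),
			grpc.WithStreamInterceptor(newStreamInterceptor(cl)),
		},
	})
	return c, cl, err
}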
221 changes: 221 additions & 0 deletions pkg/clustermesh/common/interceptor_test.go
@@ -0,0 +1,221 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package common

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"go.etcd.io/etcd/api/v3/etcdserverpb"
	"google.golang.org/grpc"
)

type responseType int

const (
	status responseType = iota
	watch
	leaseKeepAlive
	leaseGrant
	invalid
)

type mockClientStream struct {
	grpc.ClientStream
	toClient chan *etcdResponse
}

func newMockClientStream() mockClientStream {
	return mockClientStream{
		toClient: make(chan *etcdResponse),
	}
}

func (c mockClientStream) RecvMsg(msg interface{}) error {
	return nil
}

func (c mockClientStream) Send(resp *etcdResponse) error {
	return nil
}

func newStreamerMock(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	return newMockClientStream(), nil
}

func (u unaryResponder) recv() etcdResponse {
	var resp unaryResponse
	switch u.rt {
	case status:
		resp = unaryResponse{&etcdserverpb.StatusResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: u.cid}}}
	case leaseGrant:
		resp = unaryResponse{&etcdserverpb.LeaseGrantResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: u.cid}}}
	case invalid:
		resp = unaryResponse{&etcdserverpb.StatusResponse{}}
	}

	return resp
}

func (s streamResponder) recv() etcdResponse {
	var resp streamResponse
	switch s.rt {
	case watch:
		resp = streamResponse{&etcdserverpb.WatchResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: s.cid}}}
	case leaseKeepAlive:
		resp = streamResponse{&etcdserverpb.LeaseKeepAliveResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: s.cid}}}
	case invalid:
		resp = streamResponse{&etcdserverpb.WatchResponse{}}
	}
	return resp
}

func noopInvoker(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
	return nil
}

type unaryResponder struct {
	rt       responseType
	cid      uint64
	expError error
}

func (u unaryResponder) expectedErr() error {
	return u.expError
}

type unaryResponse struct {
	etcdResponse
}

type streamResponder struct {
	rt       responseType
	cid      uint64
	expError error
}

func (s streamResponder) expectedErr() error {
	return s.expError
}

type streamResponse struct {
	etcdResponse
}

type mockResponder interface {
	recv() etcdResponse
	expectedErr() error
}

var maxId uint64 = 0xFFFFFFFFFFFFFFFF

func TestInterceptors(t *testing.T) {
	tests := []struct {
		name             string
		initialClusterId uint64
		r                []mockResponder
	}{
		{
			name:             "healthy stream responses",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: watch, cid: 1, expError: nil},
				streamResponder{rt: watch, cid: 1, expError: nil},
				streamResponder{rt: watch, cid: 1, expError: nil},
			},
		},
		{
			name:             "healthy unary responses",
			initialClusterId: 1,
			r: []mockResponder{
				unaryResponder{rt: leaseGrant, cid: 1, expError: nil},
				unaryResponder{rt: status, cid: 1, expError: nil},
			},
		},
		{
			name:             "healthy stream and unary responses",
			initialClusterId: maxId,
			r: []mockResponder{
				unaryResponder{rt: leaseGrant, cid: maxId, expError: nil},
				unaryResponder{rt: status, cid: maxId, expError: nil},
				streamResponder{rt: watch, cid: maxId, expError: nil},
				unaryResponder{rt: status, cid: maxId, expError: nil},
				streamResponder{rt: watch, cid: maxId, expError: nil},
			},
		},
		{
			name:             "watch response from another cluster",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: watch, cid: 1, expError: nil},
				streamResponder{rt: watch, cid: 2, expError: ErrClusterIDChanged},
				streamResponder{rt: watch, cid: 1, expError: nil},
			},
		},
		{
			name:             "status response from another cluster",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: watch, cid: 1, expError: nil},
				unaryResponder{rt: status, cid: maxId, expError: ErrClusterIDChanged},
				streamResponder{rt: watch, cid: 1, expError: nil},
			},
		},
		{
			name:             "receive an invalid response with no header",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: leaseKeepAlive, cid: 1, expError: nil},
				streamResponder{rt: invalid, cid: 0, expError: ErrEtcdInvalidResponse},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			cl := newClusterLock()
			checkForError := func() error {
				select {
				case err := <-cl.errors:
					return err
				default:
					return nil
				}
			}

			si := newStreamInterceptor(cl)
			desc := &grpc.StreamDesc{
				StreamName:    "test",
				Handler:       nil,
				ServerStreams: true,
				ClientStreams: true,
			}

			cc := &grpc.ClientConn{}

			stream, err := si(ctx, desc, cc, "test", newStreamerMock)
			require.NoError(t, err)

			unaryRecvMsg := newUnaryInterceptor(cl)
			for _, responder := range tt.r {
				switch response := responder.recv().(type) {
				case unaryResponse:
					unaryRecvMsg(ctx, "test", nil, response, cc, noopInvoker)
				case streamResponse:
					stream.RecvMsg(responder.recv())
				}
				require.ErrorIs(t, checkForError(), responder.expectedErr())
				require.Equal(t, tt.initialClusterId, cl.etcdClusterID.Load())
			}
		})
	}
}
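
The latching behavior these tests assert can be reproduced directly on a
clusterLock; for example:

cl := newClusterLock()
_ = cl.validateClusterId(7)    // first ID observed wins: latches to 7
err := cl.validateClusterId(8) // ErrClusterIDChanged: expected 7, got 8
_ = cl.validateClusterId(7)    // nil again: the latched ID is unchanged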