Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v1.15] Prevent Cilium agents from incorrectly restarting an etcd watch against a different clustermesh-apiserver instance. #32005

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ spec:
{{- end }}
# These need to match the equivalent arguments to etcd in the main container.
- --etcd-cluster-name=clustermesh-apiserver
- --etcd-initial-cluster-token=clustermesh-apiserver
- --etcd-initial-cluster-token=$(INITIAL_CLUSTER_TOKEN)
- --etcd-data-dir=/var/run/etcd
{{- with .Values.clustermesh.apiserver.etcd.init.extraArgs }}
{{- toYaml . | trim | nindent 8 }}
Expand All @@ -76,6 +76,10 @@ spec:
configMapKeyRef:
name: cilium-config
key: cluster-name
- name: INITIAL_CLUSTER_TOKEN
valueFrom:
fieldRef:
fieldPath: metadata.uid
{{- with .Values.clustermesh.apiserver.etcd.init.extraEnv }}
{{- toYaml . | trim | nindent 8 }}
{{- end }}
Expand Down Expand Up @@ -108,7 +112,7 @@ spec:
# uses net.SplitHostPort() internally, and it accepts that format.
- --listen-client-urls=https://127.0.0.1:2379,https://[$(HOSTNAME_IP)]:2379
- --advertise-client-urls=https://[$(HOSTNAME_IP)]:2379
- --initial-cluster-token=clustermesh-apiserver
- --initial-cluster-token=$(INITIAL_CLUSTER_TOKEN)
- --auto-compaction-retention=1
{{- if .Values.clustermesh.apiserver.metrics.etcd.enabled }}
- --listen-metrics-urls=http://[$(HOSTNAME_IP)]:{{ .Values.clustermesh.apiserver.metrics.etcd.port }}
Expand All @@ -121,6 +125,10 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: INITIAL_CLUSTER_TOKEN
valueFrom:
fieldRef:
fieldPath: metadata.uid
ports:
- name: etcd
containerPort: 2379
Expand Down
112 changes: 112 additions & 0 deletions pkg/clustermesh/common/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package common

import (
"context"
"errors"
"fmt"
"sync/atomic"

"go.etcd.io/etcd/api/v3/etcdserverpb"
"google.golang.org/grpc"
)

var (
	// ErrClusterIDChanged indicates that a response was received from an etcd
	// cluster whose ID differs from the one first observed on this connection,
	// i.e. the endpoint now resolves to a different clustermesh-apiserver.
	ErrClusterIDChanged = errors.New("etcd cluster ID has changed")
	// ErrEtcdInvalidResponse indicates that a received message could not be
	// interpreted as an etcd response carrying a ResponseHeader.
	ErrEtcdInvalidResponse = errors.New("received an invalid etcd response")
)

// newUnaryInterceptor builds a unary client interceptor that, after each
// successful RPC, checks the etcd cluster ID embedded in the reply against
// the cluster ID latched by cl.
func newUnaryInterceptor(cl *clusterLock) grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		err := invoker(ctx, method, req, reply, cc, opts...)
		if err != nil {
			return err
		}
		// The RPC itself succeeded; now verify the reply came from the
		// expected etcd cluster.
		return validateReply(cl, reply)
	}
}

// newStreamInterceptor builds a stream client interceptor that wraps every
// opened stream so each received message has its etcd cluster ID checked
// against the cluster ID latched by cl.
func newStreamInterceptor(cl *clusterLock) grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		stream, err := streamer(ctx, desc, cc, method, opts...)
		if err != nil {
			return nil, err
		}
		wrapped := &wrappedClientStream{
			ClientStream: stream,
			clusterLock:  cl,
		}
		return wrapped, nil
	}
}

// wrappedClientStream decorates a grpc.ClientStream so that every message
// received on the stream has its etcd cluster ID validated.
type wrappedClientStream struct {
	grpc.ClientStream
	clusterLock *clusterLock
}

// RecvMsg implements grpc.ClientStream. It delegates to the underlying
// stream and, on success, validates the etcd cluster ID of the message.
func (w *wrappedClientStream) RecvMsg(m interface{}) error {
	err := w.ClientStream.RecvMsg(m)
	if err != nil {
		return err
	}
	return validateReply(w.clusterLock, m)
}

// etcdResponse is implemented by every etcd RPC response type that carries a
// ResponseHeader, and therefore exposes the ID of the cluster that produced it.
type etcdResponse interface {
	GetHeader() *etcdserverpb.ResponseHeader
}

// validateReply checks that reply is an etcd response carrying a header and
// that the header's cluster ID matches the one latched by cl. Any failure is
// both published on cl.errors (non-blocking; the channel buffers at most one
// error) and returned to the caller.
func validateReply(cl *clusterLock, reply any) error {
	resp, ok := reply.(etcdResponse)
	if !ok || resp.GetHeader() == nil {
		return reportError(cl, ErrEtcdInvalidResponse)
	}
	if err := cl.validateClusterId(resp.GetHeader().ClusterId); err != nil {
		return reportError(cl, err)
	}
	return nil
}

// reportError publishes err on the clusterLock's error channel without
// blocking (further errors are dropped while one is already buffered) and
// returns err unchanged for convenient use in return statements.
func reportError(cl *clusterLock, err error) error {
	select {
	case cl.errors <- err:
	default:
	}
	return err
}

// clusterLock is a wrapper around an atomic uint64 that can only be set once. It
// provides validation for an etcd connection to ensure that it is only used
// for the same etcd cluster it was initially connected to. This is to prevent
// accidentally connecting to the wrong cluster in a high availability
// configuration utilizing multiple active clusters.
type clusterLock struct {
	// etcdClusterID holds the cluster ID observed on the first validated
	// response; zero means no response has been validated yet.
	etcdClusterID atomic.Uint64
	// errors receives validation failures (capacity 1, non-blocking sends).
	// NOTE(review): presumably drained by a consumer to tear down and restart
	// the connection — confirm against the caller.
	errors chan error
}

// newClusterLock returns a clusterLock ready for use: the cluster ID starts
// unset (zero) and the error channel has a buffer of one so a validation
// failure can be reported without blocking the interceptor.
func newClusterLock() *clusterLock {
	// etcdClusterID deliberately starts at its zero value (the explicit
	// atomic.Uint64{} initializer was redundant); the first validated
	// response latches it via CompareAndSwap.
	return &clusterLock{
		errors: make(chan error, 1),
	}
}

// validateClusterId latches clusterId on first use and thereafter rejects
// any cluster ID that differs from the latched one, returning an error that
// wraps ErrClusterIDChanged.
func (c *clusterLock) validateClusterId(clusterId uint64) error {
	// First caller wins: atomically record the initial cluster ID.
	c.etcdClusterID.CompareAndSwap(0, clusterId)

	expected := c.etcdClusterID.Load()
	if expected != clusterId {
		return fmt.Errorf("%w: expected %d, got %d", ErrClusterIDChanged, expected, clusterId)
	}
	return nil
}
221 changes: 221 additions & 0 deletions pkg/clustermesh/common/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package common

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.etcd.io/etcd/api/v3/etcdserverpb"
"google.golang.org/grpc"
)

// responseType enumerates the kinds of mock etcd responses the test
// responders can fabricate.
type responseType int

const (
	status         responseType = iota // unary Status response
	watch                              // streaming Watch response
	leaseKeepAlive                     // streaming LeaseKeepAlive response
	leaseGrant                         // unary LeaseGrant response
	invalid                            // response without a header, to trigger validation failure
)

// mockClientStream is a minimal grpc.ClientStream stand-in used to exercise
// the stream interceptor without a real connection.
type mockClientStream struct {
	grpc.ClientStream
	toClient chan *etcdResponse
}

// newMockClientStream builds a mockClientStream with its response channel
// allocated.
func newMockClientStream() mockClientStream {
	stream := mockClientStream{}
	stream.toClient = make(chan *etcdResponse)
	return stream
}

// RecvMsg is a no-op; the test feeds messages directly into the wrapper.
func (c mockClientStream) RecvMsg(msg interface{}) error {
	return nil
}

// Send is the no-op counterpart to RecvMsg.
func (c mockClientStream) Send(resp *etcdResponse) error {
	return nil
}

// newStreamerMock satisfies grpc.Streamer and hands back a fresh mock stream.
func newStreamerMock(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	return newMockClientStream(), nil
}

// recv fabricates the unary etcd response described by the responder's
// response type and cluster ID.
func (u unaryResponder) recv() etcdResponse {
	header := &etcdserverpb.ResponseHeader{ClusterId: u.cid}
	switch u.rt {
	case status:
		return unaryResponse{&etcdserverpb.StatusResponse{Header: header}}
	case leaseGrant:
		return unaryResponse{&etcdserverpb.LeaseGrantResponse{Header: header}}
	case invalid:
		// No header at all: exercises the invalid-response path.
		return unaryResponse{&etcdserverpb.StatusResponse{}}
	}
	return unaryResponse{}
}

// recv fabricates the streaming etcd response described by the responder's
// response type and cluster ID.
func (s streamResponder) recv() etcdResponse {
	header := &etcdserverpb.ResponseHeader{ClusterId: s.cid}
	switch s.rt {
	case watch:
		return streamResponse{&etcdserverpb.WatchResponse{Header: header}}
	case leaseKeepAlive:
		return streamResponse{&etcdserverpb.LeaseKeepAliveResponse{Header: header}}
	case invalid:
		// No header at all: exercises the invalid-response path.
		return streamResponse{&etcdserverpb.WatchResponse{}}
	}
	return streamResponse{}
}

// noopInvoker is a grpc.UnaryInvoker that performs no RPC and always
// succeeds, letting the test drive the interceptor's reply validation alone.
func noopInvoker(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
	return nil
}

// unaryResponder scripts one unary response: which response type to build,
// which cluster ID to stamp on it, and the error the test expects the
// interceptor to report.
type unaryResponder struct {
	rt       responseType
	cid      uint64
	expError error
}

// expectedErr returns the error the interceptor is expected to surface for
// this scripted response (nil for a healthy response).
func (u unaryResponder) expectedErr() error {
	return u.expError
}

// unaryResponse tags an etcdResponse as having arrived via a unary RPC so the
// test's type switch can route it to the unary interceptor.
type unaryResponse struct {
	etcdResponse
}

// streamResponder scripts one streaming response: which response type to
// build, which cluster ID to stamp on it, and the error the test expects the
// interceptor to report.
type streamResponder struct {
	rt       responseType
	cid      uint64
	expError error
}

// expectedErr returns the error the interceptor is expected to surface for
// this scripted response (nil for a healthy response).
func (s streamResponder) expectedErr() error {
	return s.expError
}

// streamResponse tags an etcdResponse as having arrived on a stream so the
// test's type switch can route it through the wrapped client stream.
type streamResponse struct {
	etcdResponse
}

// mockResponder is the common interface of the scripted unary and stream
// responders: build the next response and state the error it should provoke.
type mockResponder interface {
	recv() etcdResponse
	expectedErr() error
}

var maxId uint64 = 0xFFFFFFFFFFFFFFFF

// TestInterceptors drives both the unary and the stream client interceptor
// with scripted etcd responses and verifies that cluster-ID mismatches and
// header-less responses are reported on the clusterLock's error channel,
// while the initially observed cluster ID stays latched.
func TestInterceptors(t *testing.T) {
	tests := []struct {
		name             string
		initialClusterId uint64
		r                []mockResponder
	}{
		{
			name:             "healthy stream responses",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: watch, cid: 1, expError: nil},
				streamResponder{rt: watch, cid: 1, expError: nil},
				streamResponder{rt: watch, cid: 1, expError: nil},
			},
		},
		{
			name:             "healthy unary responses",
			initialClusterId: 1,
			r: []mockResponder{
				unaryResponder{rt: leaseGrant, cid: 1, expError: nil},
				unaryResponder{rt: status, cid: 1, expError: nil},
			},
		},
		{
			name:             "healthy stream and unary responses",
			initialClusterId: maxId,
			r: []mockResponder{
				unaryResponder{rt: leaseGrant, cid: maxId, expError: nil},
				unaryResponder{rt: status, cid: maxId, expError: nil},
				streamResponder{rt: watch, cid: maxId, expError: nil},
				unaryResponder{rt: status, cid: maxId, expError: nil},
				streamResponder{rt: watch, cid: maxId, expError: nil},
			},
		},
		{
			name:             "watch response from another cluster",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: watch, cid: 1, expError: nil},
				streamResponder{rt: watch, cid: 2, expError: ErrClusterIDChanged},
				streamResponder{rt: watch, cid: 1, expError: nil},
			},
		},
		{
			name:             "status response from another cluster",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: watch, cid: 1, expError: nil},
				unaryResponder{rt: status, cid: maxId, expError: ErrClusterIDChanged},
				streamResponder{rt: watch, cid: 1, expError: nil},
			},
		},
		{
			name:             "receive an invalid response with no header",
			initialClusterId: 1,
			r: []mockResponder{
				streamResponder{rt: leaseKeepAlive, cid: 1, expError: nil},
				streamResponder{rt: invalid, cid: 0, expError: ErrEtcdInvalidResponse},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			cl := newClusterLock()
			// checkForError drains at most one error from the clusterLock's
			// buffered channel without blocking.
			checkForError := func() error {
				select {
				case err := <-cl.errors:
					return err
				default:
					return nil
				}
			}

			si := newStreamInterceptor(cl)
			desc := &grpc.StreamDesc{
				StreamName:    "test",
				Handler:       nil,
				ServerStreams: true,
				ClientStreams: true,
			}

			cc := &grpc.ClientConn{}

			stream, err := si(ctx, desc, cc, "test", newStreamerMock)
			require.NoError(t, err)

			unaryRecvMsg := newUnaryInterceptor(cl)
			for _, responder := range tt.r {
				// Build the scripted response exactly once and route it
				// through the matching interceptor. (The stream case
				// previously called recv() a second time instead of
				// reusing the response bound by the type switch.)
				// Interceptor return values are intentionally ignored here;
				// errors are observed via the clusterLock's channel.
				switch response := responder.recv().(type) {
				case unaryResponse:
					unaryRecvMsg(ctx, "test", nil, response, cc, noopInvoker)
				case streamResponse:
					stream.RecvMsg(response)
				}
				require.ErrorIs(t, checkForError(), responder.expectedErr())
				// The originally latched cluster ID must never change.
				require.Equal(t, tt.initialClusterId, cl.etcdClusterID.Load())
			}
		})
	}
}
Loading
Loading