From 907bdaa1eb138ac31922a41fe7477476f583bddb Mon Sep 17 00:00:00 2001 From: Matthew Stevenson <52979934+matthewstevenson88@users.noreply.github.com> Date: Wed, 7 Jun 2023 18:54:06 -0700 Subject: [PATCH] alts: Read max number of concurrent ALTS handshakes from environment variable. (#6267) * Read max number of concurrent ALTS handshakes from environment variable. * Refactor to use new envconfig file. * Remove impossible if condition in acquire(). * Use weighted semaphore. * Add e2e test for concurrent ALTS handshakes. * Separate into client and server semaphores. * Use TryAcquire instead of Acquire. * Attempt to fix go.sum error. * Run go mod tidy compat=1.17. * Update go.mod for examples subdirectory. * Run go mod tidy -compat=1.17 on examples subdirectory. * Update go.mod in subdirectories. * Update go.mod in security/advancedtls/examples. * Missed another go.mod update. * Do not upgrade glog because it requires Golang 1.19. * Fix glog version in examples/go.sum. * More glog cleanup. * Fix glog issue in gcp/observability/go.sum. * Move ALTS env var into envconfig.go. * Fix go.mod files. * Revert go.sum files. * Revert interop/observability/go.mod change. * Run go mod tidy -compat=1.17 on examples/. * Run gofmt. * Add comment describing test init function. --- credentials/alts/alts_test.go | 91 +++++++++++++------ .../alts/internal/handshaker/handshaker.go | 57 ++++-------- .../internal/handshaker/handshaker_test.go | 13 +-- examples/go.mod | 1 + examples/go.sum | 1 + go.mod | 1 + go.sum | 1 + internal/envconfig/envconfig.go | 3 + 8 files changed, 95 insertions(+), 73 deletions(-) diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go index 9a95d462806..20062fe7753 100644 --- a/credentials/alts/alts_test.go +++ b/credentials/alts/alts_test.go @@ -31,6 +31,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/alts/internal/handshaker" "google.golang.org/grpc/credentials/alts/internal/handshaker/service" altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" @@ -51,6 +52,14 @@ type s struct { grpctest.Tester } +func init() { + // The vmOnGCP global variable MUST be forced to true. Otherwise, if + // this test is run anywhere except on a GCP VM, then an ALTS handshake + // will immediately fail. + once.Do(func() {}) + vmOnGCP = true +} + func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } @@ -308,14 +317,6 @@ func (s) TestCheckRPCVersions(t *testing.T) { // server, where both client and server offload to a local, fake handshaker // service. func (s) TestFullHandshake(t *testing.T) { - // The vmOnGCP global variable MUST be reset to true after the client - // or server credentials have been created, but before the ALTS - // handshake begins. If vmOnGCP is not reset and this test is run - // anywhere except for a GCP VM, then the ALTS handshake will - // immediately fail. - once.Do(func() {}) - vmOnGCP = true - // Start the fake handshaker service and the server. var wait sync.WaitGroup defer wait.Wait() @@ -325,26 +326,41 @@ func (s) TestFullHandshake(t *testing.T) { defer stopServer() // Ping the server, authenticating with ALTS. - clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress}) - conn, err := grpc.Dial(serverAddress, grpc.WithTransportCredentials(clientCreds)) - if err != nil { - t.Fatalf("grpc.Dial(%v) failed: %v", serverAddress, err) + establishAltsConnection(t, handshakerAddress, serverAddress) + + // Close open connections to the fake handshaker service. + if err := service.CloseForTesting(); err != nil { + t.Errorf("service.CloseForTesting() failed: %v", err) } - defer conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout) - defer cancel() - c := testgrpc.NewTestServiceClient(conn) - for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) { - _, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}) - if err == nil { - break - } - if code := status.Code(err); code == codes.Unavailable { - // The server is not ready yet. Try again. - continue - } - t.Fatalf("c.UnaryCall() failed: %v", err) +} + +// TestConcurrentHandshakes performs a several, concurrent ALTS handshakes +// between a test client and server, where both client and server offload to a +// local, fake handshaker service. +func (s) TestConcurrentHandshakes(t *testing.T) { + // Set the max number of concurrent handshakes to 3, so that we can + // test the handshaker behavior when handshakes are queued by + // performing more than 3 concurrent handshakes (specifically, 10). + handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3) + + // Start the fake handshaker service and the server. + var wait sync.WaitGroup + defer wait.Wait() + stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait) + defer stopHandshaker() + stopServer, serverAddress := startServer(t, handshakerAddress, &wait) + defer stopServer() + + // Ping the server, authenticating with ALTS. + var waitForConnections sync.WaitGroup + for i := 0; i < 10; i++ { + waitForConnections.Add(1) + go func() { + establishAltsConnection(t, handshakerAddress, serverAddress) + waitForConnections.Done() + }() } + waitForConnections.Wait() // Close open connections to the fake handshaker service. if err := service.CloseForTesting(); err != nil { @@ -366,6 +382,29 @@ func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocol } } +func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) { + clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress}) + conn, err := grpc.Dial(serverAddress, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", serverAddress, err) + } + defer conn.Close() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout) + defer cancel() + c := testgrpc.NewTestServiceClient(conn) + for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) { + _, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}) + if err == nil { + break + } + if code := status.Code(err); code == codes.Unavailable { + // The server is not ready yet. Try again. + continue + } + t.Fatalf("c.UnaryCall() failed: %v", err) + } +} + func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) { listener, err := testutils.LocalTCPListener() if err != nil { diff --git a/credentials/alts/internal/handshaker/handshaker.go b/credentials/alts/internal/handshaker/handshaker.go index 150ae557676..0854e7af651 100644 --- a/credentials/alts/internal/handshaker/handshaker.go +++ b/credentials/alts/internal/handshaker/handshaker.go @@ -25,8 +25,8 @@ import ( "fmt" "io" "net" - "sync" + "golang.org/x/sync/semaphore" grpc "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -35,15 +35,13 @@ import ( "google.golang.org/grpc/credentials/alts/internal/conn" altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" + "google.golang.org/grpc/internal/envconfig" ) const ( // The maximum byte size of receive frames. frameLimit = 64 * 1024 // 64 KB rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY" - // maxPendingHandshakes represents the maximum number of concurrent - // handshakes. - maxPendingHandshakes = 100 ) var ( @@ -59,9 +57,9 @@ var ( return conn.NewAES128GCMRekey(s, keyData) }, } - // control number of concurrent created (but not closed) handshakers. - mu sync.Mutex - concurrentHandshakes = int64(0) + // control number of concurrent created (but not closed) handshakes. + clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes)) + serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes)) // errDropped occurs when maxPendingHandshakes is reached. errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached") // errOutOfBound occurs when the handshake service returns a consumed @@ -77,30 +75,6 @@ func init() { } } -func acquire() bool { - mu.Lock() - // If we need n to be configurable, we can pass it as an argument. - n := int64(1) - success := maxPendingHandshakes-concurrentHandshakes >= n - if success { - concurrentHandshakes += n - } - mu.Unlock() - return success -} - -func release() { - mu.Lock() - // If we need n to be configurable, we can pass it as an argument. - n := int64(1) - concurrentHandshakes -= n - if concurrentHandshakes < 0 { - mu.Unlock() - panic("bad release") - } - mu.Unlock() -} - // ClientHandshakerOptions contains the client handshaker options that can // provided by the caller. type ClientHandshakerOptions struct { @@ -134,10 +108,6 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions { return &ServerHandshakerOptions{} } -// TODO: add support for future local and remote endpoint in both client options -// and server options (server options struct does not exist now. When -// caller can provide endpoints, it should be created. - // altsHandshaker is used to complete an ALTS handshake between client and // server. This handshaker talks to the ALTS handshaker service in the metadata // server. @@ -185,10 +155,10 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, // ClientHandshake starts and completes a client ALTS handshake for GCP. Once // done, ClientHandshake returns a secure connection. func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { - if !acquire() { + if !clientHandshakes.TryAcquire(1) { return nil, nil, errDropped } - defer release() + defer clientHandshakes.Release(1) if h.side != core.ClientSide { return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker") @@ -238,10 +208,10 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent // ServerHandshake starts and completes a server ALTS handshake for GCP. Once // done, ServerHandshake returns a secure connection. func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { - if !acquire() { + if !serverHandshakes.TryAcquire(1) { return nil, nil, errDropped } - defer release() + defer serverHandshakes.Release(1) if h.side != core.ServerSide { return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker") @@ -264,8 +234,6 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent } // Prepare server parameters. - // TODO: currently only ALTS parameters are provided. Might need to use - // more options in the future. params := make(map[int32]*altspb.ServerHandshakeParameters) params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{ RecordProtocols: recordProtocols, @@ -391,3 +359,10 @@ func (h *altsHandshaker) Close() { h.stream.CloseSend() } } + +// ResetConcurrentHandshakeSemaphoreForTesting resets the handshake semaphores +// to allow numberOfAllowedHandshakes concurrent handshakes each. +func ResetConcurrentHandshakeSemaphoreForTesting(numberOfAllowedHandshakes int64) { + clientHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes) + serverHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes) +} diff --git a/credentials/alts/internal/handshaker/handshaker_test.go b/credentials/alts/internal/handshaker/handshaker_test.go index 49f07caf8de..40d66161c7b 100644 --- a/credentials/alts/internal/handshaker/handshaker_test.go +++ b/credentials/alts/internal/handshaker/handshaker_test.go @@ -31,6 +31,7 @@ import ( core "google.golang.org/grpc/credentials/alts/internal" altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp" "google.golang.org/grpc/credentials/alts/internal/testutil" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpctest" ) @@ -134,7 +135,7 @@ func (s) TestClientHandshake(t *testing.T) { numberOfHandshakes int }{ {0 * time.Millisecond, 1}, - {100 * time.Millisecond, 10 * maxPendingHandshakes}, + {100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)}, } { errc := make(chan error) stat.Reset() @@ -182,8 +183,8 @@ func (s) TestClientHandshake(t *testing.T) { } // Ensure that there are no concurrent calls more than the limit. - if stat.MaxConcurrentCalls > maxPendingHandshakes { - t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes) + if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) { + t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes) } } } @@ -194,7 +195,7 @@ func (s) TestServerHandshake(t *testing.T) { numberOfHandshakes int }{ {0 * time.Millisecond, 1}, - {100 * time.Millisecond, 10 * maxPendingHandshakes}, + {100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)}, } { errc := make(chan error) stat.Reset() @@ -239,8 +240,8 @@ func (s) TestServerHandshake(t *testing.T) { } // Ensure that there are no concurrent calls more than the limit. - if stat.MaxConcurrentCalls > maxPendingHandshakes { - t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes) + if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) { + t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes) } } } diff --git a/examples/go.mod b/examples/go.mod index e6aa00e7c62..0bd97db7875 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -20,6 +20,7 @@ require ( github.com/envoyproxy/go-control-plane v0.11.1-0.20230524094728-9239064ad72f // indirect github.com/envoyproxy/protoc-gen-validate v0.10.1 // indirect golang.org/x/net v0.9.0 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.7.0 // indirect golang.org/x/text v0.9.0 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/examples/go.sum b/examples/go.sum index 15496ae7e61..6511b1b756a 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1005,6 +1005,7 @@ golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/go.mod b/go.mod index d78084a3ae0..acd6f919f79 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/google/uuid v1.3.0 golang.org/x/net v0.9.0 golang.org/x/oauth2 v0.7.0 + golang.org/x/sync v0.0.0-20190423024810-112230192c58 golang.org/x/sys v0.7.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 google.golang.org/protobuf v1.30.0 diff --git a/go.sum b/go.sum index 1907f1aa632..98a106b2a17 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,7 @@ golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 80fd5c7d2a4..77c2c0b89f6 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -40,6 +40,9 @@ var ( // pick_first LB policy, which can be enabled by setting the environment // variable "GRPC_EXPERIMENTAL_PICKFIRST_LB_CONFIG" to "true". PickFirstLBConfig = boolFromEnv("GRPC_EXPERIMENTAL_PICKFIRST_LB_CONFIG", false) + // ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS + // handshakes that can be performed. + ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100) ) func boolFromEnv(envVar string, def bool) bool {