Skip to content

Commit

Permalink
Merge branch 'master' into 4487-metadata-internal-method
Browse files Browse the repository at this point in the history
  • Loading branch information
Aditya-Sood committed Jan 18, 2024
2 parents 26a8b06 + ddd377f commit 4494626
Show file tree
Hide file tree
Showing 79 changed files with 3,783 additions and 2,194 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ jobs:

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@407ffafae6a767df3e0230c3df91b6443ae8df75 # v2.22.8
uses: github/codeql-action/init@1500a131381b66de0c52ac28abb13cd79f4b7ecc # v2.22.12
with:
languages: go

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@407ffafae6a767df3e0230c3df91b6443ae8df75 # v2.22.8
uses: github/codeql-action/analyze@1500a131381b66de0c52ac28abb13cd79f4b7ecc # v2.22.12
2 changes: 1 addition & 1 deletion binarylog/grpc_binarylog_v1/binarylog.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions credentials/alts/alts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
)

const (
defaultTestLongTimeout = 10 * time.Second
defaultTestLongTimeout = 60 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
)

Expand Down Expand Up @@ -392,17 +392,23 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
defer cancel()
c := testgrpc.NewTestServiceClient(conn)
success := false
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{})
if err == nil {
success = true
break
}
if code := status.Code(err); code == codes.Unavailable {
// The server is not ready yet. Try again.
if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
// The server is not ready yet or there were too many concurrent handshakes.
// Try again.
continue
}
t.Fatalf("c.UnaryCall() failed: %v", err)
}
if !success {
t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout)
}
}

func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
Expand Down
19 changes: 12 additions & 7 deletions credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"io"
"net"
"time"

"golang.org/x/sync/semaphore"
grpc "google.golang.org/grpc"
Expand Down Expand Up @@ -60,8 +61,6 @@ var (
// control number of concurrent created (but not closed) handshakes.
clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
// errOutOfBound occurs when the handshake service returns a consumed
// bytes value larger than the buffer that was passed to it originally.
errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
Expand Down Expand Up @@ -155,8 +154,8 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn,
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !clientHandshakes.TryAcquire(1) {
return nil, nil, errDropped
if err := clientHandshakes.Acquire(ctx, 1); err != nil {
return nil, nil, err
}
defer clientHandshakes.Release(1)

Expand Down Expand Up @@ -208,8 +207,8 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !serverHandshakes.TryAcquire(1) {
return nil, nil, errDropped
if err := serverHandshakes.Acquire(ctx, 1); err != nil {
return nil, nil, err
}
defer serverHandshakes.Release(1)

Expand Down Expand Up @@ -308,8 +307,10 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
// the results. Handshaker service takes care of frame parsing, so we read
// whatever received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
var lastWriteTime time.Time
for {
if len(resp.OutFrames) > 0 {
lastWriteTime = time.Now()
if _, err := h.conn.Write(resp.OutFrames); err != nil {
return nil, nil, err
}
Expand All @@ -333,11 +334,15 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
// Append extra bytes from the previous interaction with the
// handshaker service with the current buffer read from conn.
p := append(extra, buf[:n]...)
// Compute the time elapsed since the last write to the peer.
timeElapsed := time.Since(lastWriteTime)
timeElapsedMs := uint32(timeElapsed.Milliseconds())
// From here on, p and extra point to the same slice.
resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_Next{
Next: &altspb.NextHandshakeMessageReq{
InBytes: p,
InBytes: p,
NetworkLatencyMs: timeElapsedMs,
},
},
})
Expand Down
40 changes: 29 additions & 11 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -74,6 +75,9 @@ type testRPCStream struct {
first bool
// useful for testing concurrent calls.
delay time.Duration
// The minimum expected value of the network_latency_ms field in a
// NextHandshakeMessageReq.
minExpectedNetworkLatency time.Duration
}

func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
Expand Down Expand Up @@ -102,6 +106,17 @@ func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
}
}
} else {
switch req := req.ReqOneof.(type) {
case *altspb.HandshakerReq_Next:
// Compare the network_latency_ms field to the minimum expected network
// latency.
if nl := time.Duration(req.Next.NetworkLatencyMs) * time.Millisecond; nl < t.minExpectedNetworkLatency {
return fmt.Errorf("networkLatency (%v) is smaller than expected min network latency (%v)", nl, t.minExpectedNetworkLatency)
}
default:
return fmt.Errorf("handshake request has unexpected type: %v", req)
}

// Add delay to test concurrent calls.
cleanup := stat.Update()
defer cleanup()
Expand Down Expand Up @@ -133,9 +148,11 @@ func (s) TestClientHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
readLatency time.Duration
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)},
{0 * time.Millisecond, 1, time.Duration(0)},
{0 * time.Millisecond, 1, 2 * time.Millisecond},
{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes), time.Duration(0)},
} {
errc := make(chan error)
stat.Reset()
Expand All @@ -145,16 +162,17 @@ func (s) TestClientHandshake(t *testing.T) {

for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: true,
t: t,
isClient: true,
minExpectedNetworkLatency: testCase.readLatency,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ServerInit")
f2 := testutil.MakeFrame("ServerFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
tc := testutil.NewTestConnWithReadLatency(in, out, testCase.readLatency)
chs := &altsHandshaker{
stream: stream,
conn: tc,
Expand All @@ -175,10 +193,10 @@ func (s) TestClientHandshake(t *testing.T) {
}()
}

// Ensure all errors are expected.
// Ensure that there are no errors.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
if err := <-errc; err != nil {
t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err)
}
}

Expand Down Expand Up @@ -232,10 +250,10 @@ func (s) TestServerHandshake(t *testing.T) {
}()
}

// Ensure all errors are expected.
// Ensure that there are no errors.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
if err := <-errc; err != nil {
t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err)
}
}

Expand Down

0 comments on commit 4494626

Please sign in to comment.