Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding call options for PerRPCCredentials #1225

Merged
merged 4 commits into from May 11, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions call.go
Expand Up @@ -219,6 +219,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil {
callHdr.Creds = c.creds
}

gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
Expand Down
11 changes: 11 additions & 0 deletions rpc_util.go
Expand Up @@ -45,6 +45,7 @@ import (

"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
Expand Down Expand Up @@ -116,6 +117,7 @@ type callInfo struct {
trailerMD metadata.MD
peer *peer.Peer
traceInfo traceInfo // in trace.go
creds credentials.PerRPCCredentials
}

var defaultCallInfo = callInfo{failFast: true}
Expand Down Expand Up @@ -182,6 +184,15 @@ func FailFast(failFast bool) CallOption {
})
}

// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call.
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
return beforeCall(func(c *callInfo) error {
c.creds = creds
return nil
})
}

// The format of the payload: compressed or not?
type payloadFormat uint8

Expand Down
3 changes: 3 additions & 0 deletions stream.go
Expand Up @@ -132,6 +132,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil {
callHdr.Creds = c.creds
}
var trInfo traceInfo
if EnableTracing {
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
Expand Down
121 changes: 121 additions & 0 deletions test/end2end_test.go
Expand Up @@ -449,6 +449,7 @@ type test struct {
serverInitialConnWindowSize int32
clientInitialWindowSize int32
clientInitialConnWindowSize int32
perRPCCreds credentials.PerRPCCredentials

// srv and srvAddr are set once startServer is called.
srv *grpc.Server
Expand Down Expand Up @@ -621,6 +622,9 @@ func (te *test) clientConn() *grpc.ClientConn {
if te.clientInitialConnWindowSize > 0 {
opts = append(opts, grpc.WithInitialConnWindowSize(te.clientInitialConnWindowSize))
}
if te.perRPCCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
}
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
Expand Down Expand Up @@ -3984,3 +3988,120 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) {
t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
}
}

var (
// test authdata
authdata = map[string]string{
"test-key": "test-value",
"test-key2-bin": string([]byte{1, 2, 3}),
}
)

type testPerRPCCredentials struct{}

func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return authdata, nil
}

func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
return false
}

func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ctx, fmt.Errorf("didn't find metadata in context")
}
for k, vwant := range authdata {
vgot, ok := md[k]
if !ok {
return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
}
if vgot[0] != vwant {
return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
}
}
return ctx, nil
}

func TestPerRPCCredentialsViaDialOptions(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testPerRPCCredentialsViaDialOptions(t, e)
}
}

func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
te := newTest(t, e)
te.tapHandle = authHandle
te.perRPCCreds = testPerRPCCredentials{}
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("Test failed. Reason: %v", err)
}
}

func TestPerRPCCredentialsViaCallOptions(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testPerRPCCredentialsViaCallOptions(t, e)
}
}

func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
te := newTest(t, e)
te.tapHandle = authHandle
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
t.Fatalf("Test failed. Reason: %v", err)
}
}

func TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
}
}

func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
te := newTest(t, e)
te.perRPCCreds = testPerRPCCredentials{}
// When credentials are provided via both dial options and call options,
// we apply both sets.
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ctx, fmt.Errorf("couldn't find metadata in context")
}
for k, vwant := range authdata {
vgot, ok := md[k]
if !ok {
return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
}
if len(vgot) != 2 {
return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
}
if vgot[0] != vwant || vgot[1] != vwant {
return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
}
}
return ctx, nil
}
te.startServer(&testServer{security: e.security})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
t.Fatalf("Test failed. Reason: %v", err)
}
}
32 changes: 29 additions & 3 deletions transport/http2_client.go
Expand Up @@ -101,6 +101,8 @@ type http2Client struct {
// The scheme used: https if TLS is on, http otherwise.
scheme string

isSecure bool

creds []credentials.PerRPCCredentials

// Boolean to keep track of reading activity on transport.
Expand Down Expand Up @@ -181,6 +183,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
conn.Close()
}
}(conn)
var isSecure bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group these two with

var(
  ...
)

var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
Expand All @@ -191,6 +194,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: %v", err)
}
isSecure = true
}
kp := opts.KeepaliveParams
// Validate keepalive parameters.
Expand Down Expand Up @@ -230,6 +234,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
scheme: scheme,
state: reachable,
activeStreams: make(map[uint32]*Stream),
isSecure: isSecure,
creds: opts.PerRPCCredentials,
maxStreams: defaultMaxStreamsClient,
streamsQuota: newQuotaPool(defaultMaxStreamsClient),
Expand Down Expand Up @@ -356,9 +361,29 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2.
k = strings.ToLower(k)
authData[k] = v
}
}
// Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
callAuthData := make(map[string]string)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make authData a map[string][]string and reuse authData?

if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure channel")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/channel/connection
The word channel is misleading in golang...

}
data, err := callCreds.GetRequestMetadata(ctx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetRequestMetadata also takes parameter uri.

if err != nil {
return nil, streamErrorf(codes.Internal, "transport: %v", err)
}
for k, v := range data {
// Capital header names are illegal in HTTP/2
k = strings.ToLower(k)
callAuthData[k] = v
}
}
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
Expand Down Expand Up @@ -437,9 +462,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}

for k, v := range authData {
// Capital header names are illegal in HTTP/2.
k = strings.ToLower(k)
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
for k, v := range callAuthData {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
var (
hasMD bool
Expand Down
3 changes: 3 additions & 0 deletions transport/transport.go
Expand Up @@ -474,6 +474,9 @@ type CallHdr struct {
// outbound message.
SendCompress string

// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials

// Flush indicates whether a new stream command should be sent
// to the peer without waiting for the first data. This is
// only a hint. The transport may modify the flush decision
Expand Down