@@ -16,12 +16,14 @@
package com .google .cloud .spanner .spi .v1 ;
import static com .google .common .truth .Truth .assertThat ;
import static org .hamcrest .CoreMatchers .equalTo ;
import static org .hamcrest .CoreMatchers .is ;
import static org .hamcrest .MatcherAssert .assertThat ;
import com .google .api .core .ApiFunction ;
import com .google .cloud .NoCredentials ;
import com .google .auth .oauth2 .AccessToken ;
import com .google .auth .oauth2 .OAuth2Credentials ;
import com .google .cloud .spanner .DatabaseAdminClient ;
import com .google .cloud .spanner .DatabaseClient ;
import com .google .cloud .spanner .DatabaseId ;
@@ -31,9 +33,11 @@
import com .google .cloud .spanner .ResultSet ;
import com .google .cloud .spanner .Spanner ;
import com .google .cloud .spanner .SpannerOptions ;
import com .google .cloud .spanner .SpannerOptions .CallCredentialsProvider ;
import com .google .cloud .spanner .Statement ;
import com .google .cloud .spanner .admin .database .v1 .MockDatabaseAdminImpl ;
import com .google .cloud .spanner .admin .instance .v1 .MockInstanceAdminImpl ;
import com .google .cloud .spanner .spi .v1 .SpannerRpc .Option ;
import com .google .common .base .Stopwatch ;
import com .google .protobuf .ListValue ;
import com .google .spanner .admin .database .v1 .Database ;
@@ -45,13 +49,24 @@
import com .google .spanner .v1 .StructType ;
import com .google .spanner .v1 .StructType .Field ;
import com .google .spanner .v1 .TypeCode ;
import io .grpc .CallCredentials ;
import io .grpc .Context ;
import io .grpc .Contexts ;
import io .grpc .ManagedChannelBuilder ;
import io .grpc .Metadata ;
import io .grpc .Metadata .Key ;
import io .grpc .Server ;
import io .grpc .ServerCall ;
import io .grpc .ServerCallHandler ;
import io .grpc .ServerInterceptor ;
import io .grpc .auth .MoreCallCredentials ;
import io .grpc .netty .shaded .io .grpc .netty .NettyServerBuilder ;
import java .io .IOException ;
import java .net .InetSocketAddress ;
import java .util .ArrayList ;
import java .util .HashMap ;
import java .util .List ;
import java .util .Map ;
import java .util .concurrent .TimeUnit ;
import java .util .regex .Pattern ;
import org .junit .After ;
@@ -91,11 +106,27 @@ public class GapicSpannerRpcTest {
.build ())
.setMetadata (SELECT1AND2_METADATA )
.build ();
private static final String STATIC_OAUTH_TOKEN = "STATIC_TEST_OAUTH_TOKEN" ;
private static final String VARIABLE_OAUTH_TOKEN = "VARIABLE_TEST_OAUTH_TOKEN" ;
private static final OAuth2Credentials STATIC_CREDENTIALS =
OAuth2Credentials .create (
new AccessToken (
STATIC_OAUTH_TOKEN ,
new java .util .Date (
System .currentTimeMillis () + TimeUnit .MILLISECONDS .convert (1L , TimeUnit .DAYS ))));
private static final OAuth2Credentials VARIABLE_CREDENTIALS =
OAuth2Credentials .create (
new AccessToken (
VARIABLE_OAUTH_TOKEN ,
new java .util .Date (
System .currentTimeMillis () + TimeUnit .MILLISECONDS .convert (1L , TimeUnit .DAYS ))));
private MockSpannerServiceImpl mockSpanner ;
private MockInstanceAdminImpl mockInstanceAdmin ;
private MockDatabaseAdminImpl mockDatabaseAdmin ;
private Server server ;
private InetSocketAddress address ;
private final Map <SpannerRpc .Option , Object > optionsMap = new HashMap <>();
@ Before
public void startServer () throws IOException {
@@ -111,8 +142,24 @@ public void startServer() throws IOException {
.addService (mockSpanner )
.addService (mockInstanceAdmin )
.addService (mockDatabaseAdmin )
// Add a server interceptor that will check that we receive the variable OAuth token
// from the CallCredentials, and not the one set as static credentials.
.intercept (
new ServerInterceptor () {
@ Override
public <ReqT , RespT > ServerCall .Listener <ReqT > interceptCall (
ServerCall <ReqT , RespT > call ,
Metadata headers ,
ServerCallHandler <ReqT , RespT > next ) {
String auth =
headers .get (Key .of ("authorization" , Metadata .ASCII_STRING_MARSHALLER ));
assertThat (auth ).isEqualTo ("Bearer " + VARIABLE_OAUTH_TOKEN );
return Contexts .interceptCall (Context .current (), call , headers , next );
}
})
.build ()
.start ();
optionsMap .put (Option .CHANNEL_HINT , Long .valueOf (1L ));
}
@ After
@@ -229,6 +276,55 @@ && getNumberOfThreadsWithName(SPANNER_THREAD_NAME, false)
assertThat (getNumberOfThreadsWithName (SPANNER_THREAD_NAME , true ), is (equalTo (0 )));
}
@ Test
public void testCallCredentialsProviderPreferenceAboveCredentials () {
SpannerOptions options =
SpannerOptions .newBuilder ()
.setCredentials (STATIC_CREDENTIALS )
.setCallCredentialsProvider (
new CallCredentialsProvider () {
@ Override
public CallCredentials getCallCredentials () {
return MoreCallCredentials .from (VARIABLE_CREDENTIALS );
}
})
.build ();
GapicSpannerRpc rpc = new GapicSpannerRpc (options );
// GoogleAuthLibraryCallCredentials doesn't implement equals, so we can only check for the
// existence.
assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
.isNotNull ();
rpc .shutdown ();
}
@ Test
public void testCallCredentialsProviderReturnsNull () {
SpannerOptions options =
SpannerOptions .newBuilder ()
.setCredentials (STATIC_CREDENTIALS )
.setCallCredentialsProvider (
new CallCredentialsProvider () {
@ Override
public CallCredentials getCallCredentials () {
return null ;
}
})
.build ();
GapicSpannerRpc rpc = new GapicSpannerRpc (options );
assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
.isNull ();
rpc .shutdown ();
}
@ Test
public void testNoCallCredentials () {
SpannerOptions options = SpannerOptions .newBuilder ().setCredentials (STATIC_CREDENTIALS ).build ();
GapicSpannerRpc rpc = new GapicSpannerRpc (options );
assertThat (rpc .newCallContext (optionsMap , "/some/resource" ).getCallOptions ().getCredentials ())
.isNull ();
rpc .shutdown ();
}
@ SuppressWarnings ("rawtypes" )
private SpannerOptions createSpannerOptions () {
String endpoint = address .getHostString () + ":" + server .getPort ();
@@ -244,7 +340,17 @@ public ManagedChannelBuilder apply(ManagedChannelBuilder input) {
}
})
.setHost ("http://" + endpoint )
.setCredentials (NoCredentials .getInstance ())
// Set static credentials that will return the static OAuth test token.
.setCredentials (STATIC_CREDENTIALS )
// Also set a CallCredentialsProvider. These credentials should take precedence above
// the static credentials.
.setCallCredentialsProvider (
new CallCredentialsProvider () {
@ Override
public CallCredentials getCallCredentials () {
return MoreCallCredentials .from (VARIABLE_CREDENTIALS );
}
})
.build ();
}