Skip to content
Permalink
Browse files
fix: allows user-agent header with header provider (#871)
* fix: allows user-agent header with header provider

A bug was introduced, where if the caller tried to set a custom user
agent with a header provider an exception would be thrown (for duplicate
keys). Here, we merge the user agent set by the client along with the
one set by the library, instead of throwing such exception.

* test: adds test for default user agent

Tests if the default user agent is present in the user-agent header set
in the GapicSpannerRpc class.
  • Loading branch information
thiagotnunes committed Feb 17, 2021
1 parent ab14a5e commit 3de7e2a91349cac5d79a32d2cda7ca727140f0bf
@@ -77,7 +77,6 @@
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.RateLimiter;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
@@ -161,6 +160,7 @@
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -244,6 +244,8 @@ private void awaitTermination() throws InterruptedException {
private static final int GRPC_KEEPALIVE_SECONDS = 2 * 60;
private static final String USER_AGENT_KEY = "user-agent";
private static final String CLIENT_LIBRARY_LANGUAGE = "spanner-java";
public static final String DEFAULT_USER_AGENT =
CLIENT_LIBRARY_LANGUAGE + "/" + GaxProperties.getLibraryVersion(GapicSpannerRpc.class);

private final ManagedInstantiatingExecutorProvider executorProvider;
private boolean rpcIsClosed;
@@ -305,18 +307,11 @@ public GapicSpannerRpc(final SpannerOptions options) {
GaxGrpcProperties.getGrpcTokenName(), GaxGrpcProperties.getGrpcVersion())
.build();

HeaderProvider mergedHeaderProvider = options.getMergedHeaderProvider(internalHeaderProvider);
Map<String, String> headersWithUserAgent =
ImmutableMap.<String, String>builder()
.put(
USER_AGENT_KEY,
CLIENT_LIBRARY_LANGUAGE
+ "/"
+ GaxProperties.getLibraryVersion(GapicSpannerRpc.class))
.putAll(mergedHeaderProvider.getHeaders())
.build();
final HeaderProvider mergedHeaderProvider =
options.getMergedHeaderProvider(internalHeaderProvider);
final HeaderProvider headerProviderWithUserAgent =
FixedHeaderProvider.create(headersWithUserAgent);
headerProviderWithUserAgentFrom(mergedHeaderProvider);

this.metadataProvider =
SpannerMetadataProvider.create(
headerProviderWithUserAgent.getHeaders(),
@@ -494,6 +489,16 @@ public <RequestT, ResponseT> UnaryCallable<RequestT, ResponseT> createUnaryCalla
}
}

private static HeaderProvider headerProviderWithUserAgentFrom(HeaderProvider headerProvider) {
final Map<String, String> headersWithUserAgent = new HashMap<>(headerProvider.getHeaders());
final String userAgent = headersWithUserAgent.get(USER_AGENT_KEY);
headersWithUserAgent.put(
USER_AGENT_KEY,
userAgent == null ? DEFAULT_USER_AGENT : userAgent + " " + DEFAULT_USER_AGENT);

return FixedHeaderProvider.create(headersWithUserAgent);
}

private static void checkEmulatorConnection(
SpannerOptions options,
TransportChannelProvider channelProvider,
@@ -24,7 +24,9 @@
import static org.junit.Assume.assumeTrue;

import com.google.api.core.ApiFunction;
import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.spanner.DatabaseAdminClient;
@@ -151,6 +153,8 @@ public class GapicSpannerRpcTest {
private Server server;
private InetSocketAddress address;
private final Map<SpannerRpc.Option, Object> optionsMap = new HashMap<>();
private Metadata seenHeaders;
private String defaultUserAgent;

@BeforeClass
public static void checkNotEmulator() {
@@ -161,6 +165,7 @@ public static void checkNotEmulator() {

@Before
public void startServer() throws IOException {
defaultUserAgent = "spanner-java/" + GaxProperties.getLibraryVersion(GapicSpannerRpc.class);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.query(SELECT1AND2, SELECT1_RESULTSET));
@@ -183,6 +188,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
seenHeaders = headers;
String auth =
headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER));
assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN);
@@ -502,6 +508,46 @@ public void testAdminRequestsLimitExceededRetryAlgorithm() {
assertThat(alg.shouldRetry(new Exception("random exception"), null)).isFalse();
}

@Test
public void testDefaultUserAgent() {
final SpannerOptions options = createSpannerOptions();
final Spanner spanner = options.getService();
final DatabaseClient databaseClient =
spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]"));

try (final ResultSet rs = databaseClient.singleUse().executeQuery(SELECT1AND2)) {
rs.next();
}

assertThat(seenHeaders.get(Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER)))
.contains(defaultUserAgent);
}

@Test
public void testCustomUserAgent() {
final HeaderProvider userAgentHeaderProvider =
new HeaderProvider() {
@Override
public Map<String, String> getHeaders() {
final Map<String, String> headers = new HashMap<>();
headers.put("user-agent", "test-agent");
return headers;
}
};
final SpannerOptions options =
createSpannerOptions().toBuilder().setHeaderProvider(userAgentHeaderProvider).build();
final Spanner spanner = options.getService();
final DatabaseClient databaseClient =
spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]"));

try (final ResultSet rs = databaseClient.singleUse().executeQuery(SELECT1AND2)) {
rs.next();
}

assertThat(seenHeaders.get(Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER)))
.contains("test-agent " + defaultUserAgent);
}

@SuppressWarnings("rawtypes")
private SpannerOptions createSpannerOptions() {
String endpoint = address.getHostString() + ":" + server.getPort();

0 comments on commit 3de7e2a

Please sign in to comment.