feat: Add header back to the client (#2016)
* feat: add back header setting

yirutang committed Feb 28, 2023
1 parent 35db0fb commit de00447
Showing 5 changed files with 145 additions and 9 deletions.
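In brief, the change lets a StreamWriter carry a location, which each ConnectionWorker folds into the x-goog-request-params routing header of its dedicated client (write_location=<location> when a location is set, write_stream=<stream name> otherwise). A minimal usage sketch, assuming only the public StreamWriter builder calls that appear in the tests below; the project, dataset, and table names are placeholders, and a real run needs credentials and an existing table:

import com.google.cloud.bigquery.storage.v1.ProtoSchema;
import com.google.cloud.bigquery.storage.v1.StreamWriter;

public class LocationHeaderSketch {
  public static void main(String[] args) throws Exception {
    // Placeholder default stream; replace with a real table path.
    String stream =
        "projects/my-project/datasets/my_dataset/tables/my_table/streams/_default";

    // Normally built from your protobuf descriptor; left as the default instance for brevity.
    ProtoSchema schema = ProtoSchema.getDefaultInstance();

    // Setting a location makes the underlying connection worker send
    // "x-goog-request-params: write_location=us" instead of routing by stream name.
    StreamWriter writer =
        StreamWriter.newBuilder(stream)
            .setWriterSchema(schema)
            .setLocation("us") // must match the region of the destination table
            .build();

    writer.close();
  }
}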
@@ -18,6 +18,7 @@
import com.google.api.core.ApiFuture;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.batching.FlowController;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.auto.value.AutoValue;
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.ProtoData;
import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializtionError;
@@ -77,6 +78,11 @@ class ConnectionWorker implements AutoCloseable {
*/
private String streamName;

/*
* The location of this connection.
*/
private String location = null;

/*
* The proto schema of rows to write. This schema can change during multiplexing.
*/
@@ -211,6 +217,7 @@ public static long getApiMaxRequestBytes() {

public ConnectionWorker(
String streamName,
String location,
ProtoSchema writerSchema,
long maxInflightRequests,
long maxInflightBytes,
@@ -223,6 +230,9 @@ public ConnectionWorker(
this.hasMessageInWaitingQueue = lock.newCondition();
this.inflightReduced = lock.newCondition();
this.streamName = streamName;
if (location != null && !location.isEmpty()) {
this.location = location;
}
this.maxRetryDuration = maxRetryDuration;
if (writerSchema == null) {
throw new StatusRuntimeException(
@@ -236,6 +246,18 @@ public ConnectionWorker(
this.waitingRequestQueue = new LinkedList<AppendRequestAndResponse>();
this.inflightRequestQueue = new LinkedList<AppendRequestAndResponse>();
// Always recreate a client for connection worker.
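// The rebuilt settings attach an "x-goog-request-params" routing header:
// connections created with a location route by write_location, while
// single-stream connections route by write_stream.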
HashMap<String, String> newHeaders = new HashMap<>();
newHeaders.putAll(clientSettings.toBuilder().getHeaderProvider().getHeaders());
if (this.location == null) {
newHeaders.put("x-goog-request-params", "write_stream=" + this.streamName);
} else {
newHeaders.put("x-goog-request-params", "write_location=" + this.location);
}
BigQueryWriteSettings stubSettings =
clientSettings
.toBuilder()
.setHeaderProvider(FixedHeaderProvider.create(newHeaders))
.build();
this.client = BigQueryWriteClient.create(stubSettings);

this.appendThread =
@@ -297,6 +319,24 @@ public void run(Throwable finalStatus) {

/** Schedules the writing of rows at given offset. */
ApiFuture<AppendRowsResponse> append(StreamWriter streamWriter, ProtoRows rows, long offset) {
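// Reject writers that do not belong on this connection: a location-bound worker
// only serves writers with the same location, and a worker without a location
// only serves its own stream.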
if (this.location != null && !this.location.equals(streamWriter.getLocation())) {
throw new StatusRuntimeException(
Status.fromCode(Code.INVALID_ARGUMENT)
.withDescription(
"StreamWriter with location "
+ streamWriter.getLocation()
+ " is scheduled to use a connection with location "
+ this.location));
} else if (this.location == null && !streamWriter.getStreamName().equals(this.streamName)) {
// A null location implies this is a non-multiplexed connection, so the stream name must match.
throw new StatusRuntimeException(
Status.fromCode(Code.INVALID_ARGUMENT)
.withDescription(
"StreamWriter with stream name "
+ streamWriter.getStreamName()
+ " is scheduled to use a connection with stream name "
+ this.streamName));
}
Preconditions.checkNotNull(streamWriter);
AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder();
requestBuilder.setProtoRows(
@@ -322,6 +362,10 @@ Boolean isUserClosed() {
}
}

String getWriteLocation() {
return this.location;
}

private ApiFuture<AppendRowsResponse> appendInternal(
StreamWriter streamWriter, AppendRowsRequest message) {
AppendRequestAndResponse requestWrapper = new AppendRequestAndResponse(message, streamWriter);
@@ -288,7 +288,8 @@ private ConnectionWorker createOrReuseConnectionWorker(
String streamReference = streamWriter.getStreamName();
if (connectionWorkerPool.size() < currentMaxConnectionCount) {
// Always create a new connection if we haven't reached current maximum.
return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
return createConnectionWorker(
streamWriter.getStreamName(), streamWriter.getLocation(), streamWriter.getProtoSchema());
} else {
ConnectionWorker existingBestConnection =
pickBestLoadConnection(
@@ -304,7 +305,10 @@ private ConnectionWorker createOrReuseConnectionWorker(
if (currentMaxConnectionCount > settings.maxConnectionsPerRegion()) {
currentMaxConnectionCount = settings.maxConnectionsPerRegion();
}
return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
return createConnectionWorker(
streamWriter.getStreamName(),
streamWriter.getLocation(),
streamWriter.getProtoSchema());
} else {
// Stick to the original connection if all the connections are overwhelmed.
if (existingConnectionWorker != null) {
@@ -359,15 +363,16 @@ static ConnectionWorker pickBestLoadConnection(
* a single stream reference. This is because createConnectionWorker(...) is called via
* computeIfAbsent(...) which is at most once per key.
*/
private ConnectionWorker createConnectionWorker(String streamName, ProtoSchema writeSchema)
throws IOException {
private ConnectionWorker createConnectionWorker(
String streamName, String location, ProtoSchema writeSchema) throws IOException {
if (enableTesting) {
// Though atomic integer is super lightweight, add extra if check in case adding future logic.
testValueCreateConnectionCount.getAndIncrement();
}
ConnectionWorker connectionWorker =
new ConnectionWorker(
streamName,
location,
writeSchema,
maxInflightRequests,
maxInflightBytes,
@@ -208,6 +208,7 @@ private StreamWriter(Builder builder) throws IOException {
SingleConnectionOrConnectionPool.ofSingleConnection(
new ConnectionWorker(
builder.streamName,
builder.location,
builder.writerSchema,
builder.maxInflightRequest,
builder.maxInflightBytes,
@@ -430,6 +430,7 @@ private StreamWriter getTestStreamWriter(String streamName) throws IOException {
return StreamWriter.newBuilder(streamName)
.setWriterSchema(createProtoSchema())
.setTraceId(TEST_TRACE_ID)
.setLocation("us")
.setCredentialsProvider(NoCredentialsProvider.create())
.setChannelProvider(serviceHelper.createChannelProvider())
.build();
@@ -39,13 +39,15 @@
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.logging.Logger;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class ConnectionWorkerTest {
private static final Logger log = Logger.getLogger(StreamWriter.class.getName());
private static final String TEST_STREAM_1 = "projects/p1/datasets/d1/tables/t1/streams/s1";
private static final String TEST_STREAM_2 = "projects/p2/datasets/d2/tables/t2/streams/s2";
private static final String TEST_TRACE_ID = "DATAFLOW:job_id";
@@ -84,10 +86,12 @@ public void testMultiplexedAppendSuccess() throws Exception {
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setWriterSchema(createProtoSchema("foo"))
.setLocation("us")
.build();
StreamWriter sw2 =
StreamWriter.newBuilder(TEST_STREAM_2, client)
.setWriterSchema(createProtoSchema("complicate"))
.setLocation("us")
.build();
// We do a pattern of:
// send to stream1, string1
@@ -205,11 +209,20 @@ public void testAppendInSameStream_switchSchema() throws Exception {
// send to stream1, schema1
// ...
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema1)
.build();
StreamWriter sw2 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema2).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema2)
.build();
StreamWriter sw3 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema3).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema3)
.build();
for (long i = 0; i < appendCount; i++) {
switch ((int) i % 4) {
case 0:
@@ -305,10 +318,14 @@ public void testAppendInSameStream_switchSchema() throws Exception {
public void testAppendButInflightQueueFull() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema1)
.build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_1,
"us",
createProtoSchema("foo"),
6,
100000,
@@ -356,10 +373,14 @@ public void testAppendButInflightQueueFull() throws Exception {
public void testThrowExceptionWhileWithinAppendLoop() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema1)
.build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_1,
"us",
createProtoSchema("foo"),
100000,
100000,
@@ -411,6 +432,69 @@ public void testThrowExceptionWhileWithinAppendLoop() throws Exception {
assertThat(ex.getCause()).hasMessageThat().contains("Any exception can happen.");
}

@Test
public void testLocationMismatch() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setWriterSchema(schema1)
.setLocation("eu")
.build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_1,
"us",
createProtoSchema("foo"),
100000,
100000,
Duration.ofSeconds(100),
FlowController.LimitExceededBehavior.Block,
TEST_TRACE_ID,
client.getSettings());
StatusRuntimeException ex =
assertThrows(
StatusRuntimeException.class,
() ->
sendTestMessage(
connectionWorker,
sw1,
createFooProtoRows(new String[] {String.valueOf(0)}),
0));
assertEquals(
"INVALID_ARGUMENT: StreamWriter with location eu is scheduled to use a connection with location us",
ex.getMessage());
}

@Test
public void testStreamNameMismatch() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_2,
null,
createProtoSchema("foo"),
100000,
100000,
Duration.ofSeconds(100),
FlowController.LimitExceededBehavior.Block,
TEST_TRACE_ID,
client.getSettings());
StatusRuntimeException ex =
assertThrows(
StatusRuntimeException.class,
() ->
sendTestMessage(
connectionWorker,
sw1,
createFooProtoRows(new String[] {String.valueOf(0)}),
0));
assertEquals(
"INVALID_ARGUMENT: StreamWriter with stream name projects/p1/datasets/d1/tables/t1/streams/s1 is scheduled to use a connection with stream name projects/p2/datasets/d2/tables/t2/streams/s2",
ex.getMessage());
}

@Test
public void testExponentialBackoff() throws Exception {
assertThat(ConnectionWorker.calculateSleepTimeMilli(0)).isEqualTo(1);
@@ -440,6 +524,7 @@ private ConnectionWorker createConnectionWorker(
throws IOException {
return new ConnectionWorker(
streamName,
"us",
createProtoSchema("foo"),
maxRequests,
maxBytes,
