diff --git a/README.md b/README.md
index 3a83102c76..61faf06bac 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ If you are using Maven without BOM, add this to your dependencies:
If you are using Gradle 5.x or later, add this to your dependencies:

```Groovy
-implementation platform('com.google.cloud:libraries-bom:26.1.1')
+implementation platform('com.google.cloud:libraries-bom:26.1.2')

implementation 'com.google.cloud:google-cloud-bigquerystorage'
```

diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java
index a4642a96b0..f1747bd31a 100644
--- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java
+++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java
@@ -18,9 +18,26 @@
import com.google.api.core.ApiFuture;
import com.google.api.gax.batching.FlowController;
import com.google.auto.value.AutoValue;
+import com.google.cloud.bigquery.storage.v1.ConnectionWorker.Load;
+import com.google.common.base.Stopwatch;
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.logging.Logger;
import javax.annotation.concurrent.GuardedBy;

+/** Pool of connections that accepts appends and distributes them to different connections. */
public class ConnectionWorkerPool {
+  private static final Logger log = Logger.getLogger(ConnectionWorkerPool.class.getName());
  /*
   * Max allowed inflight requests in the stream. Method append is blocked at this.
   */
@@ -36,11 +53,29 @@ public class ConnectionWorkerPool {
   */
  private final FlowController.LimitExceededBehavior limitExceededBehavior;

+  /** Map from write stream to its corresponding connection. */
+  private final Map<StreamWriter, ConnectionWorker> streamWriterToConnection =
+      new ConcurrentHashMap<>();
+
+  /** Map from a connection to the set of write streams that have sent requests onto it. */
+  private final Map<ConnectionWorker, Set<StreamWriter>> connectionToWriteStream =
+      new ConcurrentHashMap<>();
+
+  /** Collection of all the created connections. */
+  private final Set<ConnectionWorker> connectionWorkerPool =
+      Collections.synchronizedSet(new HashSet<>());
+
+  /** Enables test-related logic. */
+  private static boolean enableTesting = false;
+
  /*
   * TraceId for debugging purpose.
   */
  private final String traceId;

+  /** Used by tests to count how many times createConnectionWorker is called. */
+  private final AtomicInteger testValueCreateConnectionCount = new AtomicInteger(0);
+
  /*
   * Tracks current inflight requests in the stream.
   */
@@ -102,6 +137,15 @@ public class ConnectionWorkerPool {
   */
  private boolean ownsBigQueryWriteClient = false;

+  /**
+   * The current maximum connection count. This value is gradually increased until the
+   * user-defined maximum connection count is reached.
+   */
+  private int currentMaxConnectionCount;
+
+  /** Lock controlling concurrent add / delete operations on connections. */
+  private final Lock lock = new ReentrantLock();
+
  /** Settings for connection pool. */
  @AutoValue
  public abstract static class Settings {
@@ -147,6 +191,7 @@ public ConnectionWorkerPool(
    this.traceId = traceId;
    this.client = client;
    this.ownsBigQueryWriteClient = ownsBigQueryWriteClient;
+    this.currentMaxConnectionCount = settings.minConnectionsPerPool();
  }

  /**
@@ -160,13 +205,149 @@ public static void setOptions(Settings settings) {

  /** Distributes the writing of a message to an underlying connection. */
  public ApiFuture<AppendRowsResponse> append(StreamWriter streamWriter, ProtoRows rows) {
-    throw new RuntimeException("Append is not implemented!");
+    return append(streamWriter, rows, -1);
  }

  /** Distributes the writing of a message to an underlying connection. */
  public ApiFuture<AppendRowsResponse> append(
      StreamWriter streamWriter, ProtoRows rows, long offset) {
-    throw new RuntimeException("append with offset is not implemented on connection pool!");
+    // We are in multiplexing mode after entering the following logic.
+    ConnectionWorker connectionWorker =
+        streamWriterToConnection.compute(
+            streamWriter,
+            (key, existingStream) -> {
+              // Though compute on a concurrent map is atomic, we still do explicit locking as a
+              // concurrent close(...) may be triggered.
+              lock.lock();
+              try {
+                // Stick to the existing stream if it's not overwhelmed.
+                if (existingStream != null && !existingStream.getLoad().isOverwhelmed()) {
+                  return existingStream;
+                }
+                // Try to create or find another existing stream to reuse.
+                ConnectionWorker createdOrExistingConnection = null;
+                try {
+                  createdOrExistingConnection =
+                      createOrReuseConnectionWorker(streamWriter, existingStream);
+                } catch (IOException e) {
+                  throw new IllegalStateException(e);
+                }
+                // Update the connection-to-write-stream relationship.
+                connectionToWriteStream.computeIfAbsent(
+                    createdOrExistingConnection, (ConnectionWorker k) -> new HashSet<>());
+                connectionToWriteStream.get(createdOrExistingConnection).add(streamWriter);
+                return createdOrExistingConnection;
+              } finally {
+                lock.unlock();
+              }
+            });
+    Stopwatch stopwatch = Stopwatch.createStarted();
+    ApiFuture<AppendRowsResponse> responseFuture =
+        connectionWorker.append(
+            streamWriter.getStreamName(), streamWriter.getProtoSchema(), rows, offset);
+    return responseFuture;
+  }
+
+  /**
+   * Creates a new connection if we haven't reached the current maximum, or reuses an existing
+   * connection with the least load.
+   */
+  private ConnectionWorker createOrReuseConnectionWorker(
+      StreamWriter streamWriter, ConnectionWorker existingConnectionWorker) throws IOException {
+    String streamReference = streamWriter.getStreamName();
+    if (connectionWorkerPool.size() < currentMaxConnectionCount) {
+      // Always create a new connection if we haven't reached the current maximum.
+      return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
+    } else {
+      ConnectionWorker existingBestConnection =
+          pickBestLoadConnection(
+              enableTesting ? Load.TEST_LOAD_COMPARATOR : Load.LOAD_COMPARATOR,
+              ImmutableList.copyOf(connectionWorkerPool));
+      if (!existingBestConnection.getLoad().isOverwhelmed()) {
+        return existingBestConnection;
+      } else if (currentMaxConnectionCount < settings.maxConnectionsPerPool()) {
+        // At this point we have reached the current connection cap and the selected connection
+        // is overwhelmed, so we can try to scale up the connection pool.
+        // The connection count will go up one by one until `maxConnectionsPerPool` is reached.
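+        // Illustrative example: with minConnectionsPerPool = 2 and maxConnectionsPerPool = 8,
+        // the cap starts at 2 and is raised one step at a time (3, 4, ..., 8) each time the
+        // least-loaded connection is still overwhelmed; it is never raised past 8.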
+        currentMaxConnectionCount += 1;
+        if (currentMaxConnectionCount > settings.maxConnectionsPerPool()) {
+          currentMaxConnectionCount = settings.maxConnectionsPerPool();
+        }
+        return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
+      } else {
+        // Stick to the original connection if all the connections are overwhelmed.
+        if (existingConnectionWorker != null) {
+          return existingConnectionWorker;
+        }
+        // If we reach this branch, it means we have hit the maximum connection count.
+        return existingBestConnection;
+      }
+    }
+  }
+
+  /** Selects the best connection worker among the given connection workers. */
+  static ConnectionWorker pickBestLoadConnection(
+      Comparator<Load> comparator, List<ConnectionWorker> connectionWorkerList) {
+    if (connectionWorkerList.isEmpty()) {
+      throw new IllegalStateException(
+          String.format(
+              "Bug in code! At least one connection worker should be passed in "
+                  + "pickBestLoadConnection(...)"));
+    }
+    // Loop over all connection workers and pick the one with the least load, as defined by the
+    // given comparator.
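+    // For example, if the candidates report loads equivalent to 8, 3, and 5 in-flight requests,
+    // the worker reporting 3 is returned. Note that with `<= 0` below, a later worker that ties
+    // the current best replaces it.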

+    int currentBestIndex = 0;
+    Load currentBestLoad = connectionWorkerList.get(currentBestIndex).getLoad();
+    for (int i = 1; i < connectionWorkerList.size(); i++) {
+      Load loadToCompare = connectionWorkerList.get(i).getLoad();
+      if (comparator.compare(loadToCompare, currentBestLoad) <= 0) {
+        currentBestIndex = i;
+        currentBestLoad = loadToCompare;
+      }
+    }
+    return connectionWorkerList.get(currentBestIndex);
+  }
+
+  /**
+   * Creates a single connection worker.
+   *
+   * <p>Note this function needs to be thread-safe across different stream references, but not
+   * for a single stream reference, because it is invoked from within compute(...), which runs
+   * atomically per key.
+   */
+  private ConnectionWorker createConnectionWorker(String streamName, ProtoSchema writeSchema)
+      throws IOException {
+    if (enableTesting) {
+      // Though AtomicInteger is lightweight, keep the extra if-check in case more logic is
+      // added here later.
+      testValueCreateConnectionCount.getAndIncrement();
+    }
+    ConnectionWorker connectionWorker =
+        new ConnectionWorker(
+            streamName,
+            writeSchema,
+            maxInflightRequests,
+            maxInflightBytes,
+            limitExceededBehavior,
+            traceId,
+            client,
+            ownsBigQueryWriteClient);
+    connectionWorkerPool.add(connectionWorker);
+    log.info(
+        String.format(
+            "Scaling up new connection for stream name: %s, pool size after scaling up: %s",
+            streamName, connectionWorkerPool.size()));
+    return connectionWorker;
+  }
+
+  /** Enables test-related logic. */
+  public static void enableTestingLogic() {
+    enableTesting = true;
+  }
+
+  /** Returns how many times createConnectionWorker(...) has been called. */
+  int getCreateConnectionCount() {
+    return testValueCreateConnectionCount.get();
  }

  /** Close the stream writer. Shut down all resources. */
diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java
index 922dd66e81..e869668818 100644
--- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java
+++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java
@@ -152,6 +152,11 @@ public String getStreamName() {
    return streamName;
  }

+  /** @return the user-provided writer schema. */
+  public ProtoSchema getProtoSchema() {
+    return writerSchema;
+  }
+
  /** Close the stream writer. Shut down all resources. */
  @Override
  public void close() {
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java
new file mode 100644
index 0000000000..02bc075f1d
--- /dev/null
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPoolTest.java
@@ -0,0 +1,222 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.bigquery.storage.v1;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.api.core.ApiFuture;
+import com.google.api.gax.batching.FlowController;
+import com.google.api.gax.core.NoCredentialsProvider;
+import com.google.api.gax.grpc.testing.MockGrpcService;
+import com.google.api.gax.grpc.testing.MockServiceHelper;
+import com.google.cloud.bigquery.storage.test.Test.FooType;
+import com.google.cloud.bigquery.storage.v1.ConnectionWorkerPool.Settings;
+import com.google.protobuf.DescriptorProtos;
+import com.google.protobuf.Int64Value;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.threeten.bp.Duration;
+
+@RunWith(JUnit4.class)
+public class ConnectionWorkerPoolTest {
+
+  private FakeBigQueryWrite testBigQueryWrite;
+  private FakeScheduledExecutorService fakeExecutor;
+  private static MockServiceHelper serviceHelper;
+  private BigQueryWriteClient client;
+
+  private static final String TEST_TRACE_ID = "home:job1";
+  private static final String TEST_STREAM_1 = "projects/p1/datasets/d1/tables/t1/streams/_default";
+  private static final String TEST_STREAM_2 = "projects/p1/datasets/d1/tables/t2/streams/_default";
+
+  @Before
+  public void setUp() throws Exception {
+    testBigQueryWrite = new FakeBigQueryWrite();
+    serviceHelper =
+        new MockServiceHelper(
+            UUID.randomUUID().toString(), Arrays.<MockGrpcService>asList(testBigQueryWrite));
+    serviceHelper.start();
+    fakeExecutor = new FakeScheduledExecutorService();
+    testBigQueryWrite.setExecutor(fakeExecutor);
+    client =
+        BigQueryWriteClient.create(
+            BigQueryWriteSettings.newBuilder()
+                .setCredentialsProvider(NoCredentialsProvider.create())
+                .setTransportChannelProvider(serviceHelper.createChannelProvider())
+                .build());
+  }
+
+  @Test
+  public void testSingleTableConnection_noOverwhelmedConnection() throws Exception {
+    // Set the max request count to a large value so we will not scale up.
+    testSend100RequestsToMultiTable(
+        /*maxRequests=*/ 100000,
+        /*maxConnections=*/ 8,
+        /*expectedConnectionCount=*/ 1,
+        /*tableCount=*/ 1);
+  }
+
+  @Test
+  public void testSingleTableConnections_overwhelmed() throws Exception {
+    // A connection is considered overwhelmed when the request count reaches 5 (max is 10).
+    testSend100RequestsToMultiTable(
+        /*maxRequests=*/ 10,
+        /*maxConnections=*/ 8,
+        /*expectedConnectionCount=*/ 8,
+        /*tableCount=*/ 1);
+  }
+
+  @Test
+  public void testMultiTableConnection_noOverwhelmedConnection() throws Exception {
+    // Set the max request count to a large value so we will not scale up.
+    // All tables will share two connections (2 because the min connection count is set to 2).
+    testSend100RequestsToMultiTable(
+        /*maxRequests=*/ 100000,
+        /*maxConnections=*/ 8,
+        /*expectedConnectionCount=*/ 2,
+        /*tableCount=*/ 4);
+  }
+
+  @Test
+  public void testMultiTableConnections_overwhelmed() throws Exception {
+    // A connection is considered overwhelmed when the request count reaches 5 (max is 10).
+    testSend100RequestsToMultiTable(
+        /*maxRequests=*/ 10,
+        /*maxConnections=*/ 8,
+        /*expectedConnectionCount=*/ 8,
+        /*tableCount=*/ 4);
+  }
+
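+  // Summary of the helper below: queue one fake response per append, spread 100 appends
+  // round-robin across `tableCount` stream writers through a single pool, then verify the
+  // per-offset responses and the number of connections the pool actually created.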
+  private void testSend100RequestsToMultiTable(
+      int maxRequests, int maxConnections, int expectedConnectionCount, int tableCount)
+      throws IOException, ExecutionException, InterruptedException {
+    ConnectionWorkerPool connectionWorkerPool =
+        createConnectionWorkerPool(maxRequests, /*maxBytes=*/ 100000);
+    ConnectionWorkerPool.setOptions(
+        Settings.builder().setMaxConnectionsPerPool(maxConnections).build());
+
+    // Set a sleep time to simulate requests being stuck in the connection.
+    testBigQueryWrite.setResponseSleep(Duration.ofMillis(50L));
+
+    // Queue a response for each of the 100 append requests.
+    long appendCount = 100;
+    for (long i = 0; i < appendCount; i++) {
+      testBigQueryWrite.addResponse(createAppendResponse(i));
+    }
+    List<ApiFuture<AppendRowsResponse>> futures = new ArrayList<>();
+
+    // Create one stream writer per table.
+    List<StreamWriter> streamWriterList = new ArrayList<>();
+    for (int i = 0; i < tableCount; i++) {
+      streamWriterList.add(
+          getTestStreamWriter(
+              String.format("projects/p1/datasets/d1/tables/t%s/streams/_default", i)));
+    }
+
+    for (long i = 0; i < appendCount; i++) {
+      // Insert requests into the tables in round-robin order.
+      futures.add(
+          sendFooStringTestMessage(
+              streamWriterList.get((int) (i % streamWriterList.size())),
+              connectionWorkerPool,
+              new String[] {String.valueOf(i)},
+              i));
+    }
+
+    for (int i = 0; i < appendCount; i++) {
+      AppendRowsResponse response = futures.get(i).get();
+      assertThat(response.getAppendResult().getOffset().getValue()).isEqualTo(i);
+    }
+    // By the end we should have scaled up to the expected connection count.
+    assertThat(connectionWorkerPool.getCreateConnectionCount()).isEqualTo(expectedConnectionCount);
+
+    assertThat(testBigQueryWrite.getAppendRequests().size()).isEqualTo(appendCount);
+    // The order in which the server receives requests is no longer guaranteed, so collect the
+    // offsets and verify that each one arrived exactly once.
+    HashSet<Long> offsets = new HashSet<>();
+    for (int i = 0; i < appendCount; i++) {
+      AppendRowsRequest serverRequest = testBigQueryWrite.getAppendRequests().get(i);
+      assertThat(serverRequest.getProtoRows().getRows().getSerializedRowsCount()).isGreaterThan(0);
+      offsets.add(serverRequest.getOffset().getValue());
+    }
+    assertThat(offsets.size()).isEqualTo(appendCount);
+  }
+
+  private AppendRowsResponse createAppendResponse(long offset) {
+    return AppendRowsResponse.newBuilder()
+        .setAppendResult(
+            AppendRowsResponse.AppendResult.newBuilder().setOffset(Int64Value.of(offset)).build())
+        .build();
+  }
+
+  private StreamWriter getTestStreamWriter(String streamName) throws IOException {
+    return StreamWriter.newBuilder(streamName, client)
+        .setWriterSchema(createProtoSchema())
+        .setTraceId(TEST_TRACE_ID)
+        .build();
+  }
+
+  private ProtoSchema createProtoSchema() {
+    return ProtoSchema.newBuilder()
+        .setProtoDescriptor(
+            DescriptorProtos.DescriptorProto.newBuilder()
+                .setName("Message")
+                .addField(
+                    DescriptorProtos.FieldDescriptorProto.newBuilder()
+                        .setName("foo")
+                        .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING)
+                        .setNumber(1)
+                        .build())
+                .build())
+        .build();
+  }
+
+  private ApiFuture<AppendRowsResponse> sendFooStringTestMessage(
+      StreamWriter writeStream,
+      ConnectionWorkerPool connectionWorkerPool,
+      String[] messages,
+      long offset) {
+    return connectionWorkerPool.append(writeStream, createProtoRows(messages), offset);
+  }
+
+  private ProtoRows createProtoRows(String[] messages) {
+    ProtoRows.Builder rowsBuilder = ProtoRows.newBuilder();
+    for (String message : messages) {
+      FooType foo = FooType.newBuilder().setFoo(message).build();
+      rowsBuilder.addSerializedRows(foo.toByteString());
+    }
+    return rowsBuilder.build();
+  }
+
+  ConnectionWorkerPool createConnectionWorkerPool(long maxRequests, long maxBytes) {
+    ConnectionWorkerPool.enableTestingLogic();
+    return new ConnectionWorkerPool(
+        maxRequests,
+        maxBytes,
+        FlowController.LimitExceededBehavior.Block,
+        TEST_TRACE_ID,
+        client,
+        /*ownsBigQueryWriteClient=*/ false);
+  }
+}
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/FakeBigQueryWriteImpl.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/FakeBigQueryWriteImpl.java
index 5d8f05fff5..02223ace82 100644
--- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/FakeBigQueryWriteImpl.java
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/FakeBigQueryWriteImpl.java
@@ -20,7 +20,10 @@
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Semaphore;
@@ -40,7 +43,7 @@ class FakeBigQueryWriteImpl extends BigQueryWriteGrpc.BigQueryWriteImplBase {
  private final LinkedBlockingQueue<GetWriteStreamRequest> writeRequests = new LinkedBlockingQueue<>();
  private final LinkedBlockingQueue<FlushRowsRequest> flushRequests = new LinkedBlockingQueue<>();
-  private final LinkedBlockingQueue<Response> responses = new LinkedBlockingQueue<>();
+  private final List<Response> responses = Collections.synchronizedList(new ArrayList<>());
  private final LinkedBlockingQueue<WriteStream> writeResponses = new LinkedBlockingQueue<>();
  private final LinkedBlockingQueue<FlushRowsResponse> flushResponses = new LinkedBlockingQueue<>();
  private final AtomicInteger nextMessageId = new AtomicInteger(1);
@@ -54,7 +57,10 @@ class FakeBigQueryWriteImpl extends BigQueryWriteGrpc.BigQueryWriteImplBase {
  private long closeAfter = 0;
  private long recordCount = 0;
  private long connectionCount = 0;
-  private boolean firstRecord = false;
+
+  // Records whether the first record has been seen on a connection.
+  private final Map<StreamObserver<AppendRowsResponse>, Boolean> connectionToFirstRequest =
+      new ConcurrentHashMap<>();

  /** Class used to save the state of a possible response. */
  private static class Response {
@@ -135,7 +141,7 @@ public long getConnectionCount() {
  public StreamObserver<AppendRowsRequest> appendRows(
      final StreamObserver<AppendRowsResponse> responseObserver) {
    this.connectionCount++;
-    this.firstRecord = true;
+    connectionToFirstRequest.put(responseObserver, true);
    StreamObserver<AppendRowsRequest> requestObserver =
        new StreamObserver<AppendRowsRequest>() {
          @Override
          public void onNext(AppendRowsRequest value) {
            LOG.fine("Get request:" + value.toString());
            requests.add(value);
            recordCount++;
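+            // Pick the response by the request's offset when one is provided: with multiplexed
+            // connections, requests from different writers can reach the fake server out of
+            // order, so responses can no longer simply be consumed in arrival order.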
+            int offset = (int) (recordCount - 1);
+            if (value.hasOffset() && value.getOffset().getValue() != -1) {
+              offset = (int) value.getOffset().getValue();
+            }
            if (responseSleep.compareTo(Duration.ZERO) > 0) {
              LOG.fine("Sleeping before response for " + responseSleep.toString());
              Uninterruptibles.sleepUninterruptibly(
                  responseSleep.toMillis(), TimeUnit.MILLISECONDS);
            }
-            if (firstRecord) {
+            if (connectionToFirstRequest.get(responseObserver)) {
              if (!value.getProtoRows().hasWriterSchema() || value.getWriteStream().isEmpty()) {
                LOG.info(
                    String.valueOf(
@@ -161,14 +171,14 @@ public void onNext(AppendRowsRequest value) {
                return;
              }
            }
-            firstRecord = false;
+            connectionToFirstRequest.put(responseObserver, false);
            if (closeAfter > 0
                && recordCount % closeAfter == 0
                && (numberTimesToClose == 0 || connectionCount <= numberTimesToClose)) {
              LOG.info("Shutting down connection from test...");
              responseObserver.onError(Status.ABORTED.asException());
            } else {
-              final Response response = responses.remove();
+              final Response response = responses.get(offset);
              sendResponse(response, responseObserver);
            }
          }
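For reviewers: below is a minimal sketch (not part of this change) of how the new multiplexing path is driven, mirroring `ConnectionWorkerPoolTest`. The `client`, `schema`, and `rows` arguments are assumed to be built the same way the test builds them, and the pool is created through the package-private constructor the test uses:

```java
// Illustrative sketch only; assumes the fakes and arguments from ConnectionWorkerPoolTest.
ApiFuture<AppendRowsResponse> appendViaPool(
    BigQueryWriteClient client, ProtoSchema schema, ProtoRows rows) throws IOException {
  ConnectionWorkerPool pool =
      new ConnectionWorkerPool(
          /*maxInflightRequests=*/ 1000,
          /*maxInflightBytes=*/ 100000,
          FlowController.LimitExceededBehavior.Block,
          /*traceId=*/ "home:job1",
          client,
          /*ownsBigQueryWriteClient=*/ false);
  StreamWriter writer =
      StreamWriter.newBuilder("projects/p/datasets/d/tables/t/streams/_default", client)
          .setWriterSchema(schema)
          .build();
  // The pool sticks to the writer's current connection unless it is overwhelmed, in which
  // case it reuses the least-loaded connection or scales up toward maxConnectionsPerPool.
  return pool.append(writer, rows, /*offset=*/ 0);
}
```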