From db751ceebc8b6981d00cd07ce4742196cc1dd50d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 11 Sep 2023 14:09:59 +0200 Subject: [PATCH] fix: avoid unbalanced session pool creation (#2442) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: avoid unbalanced session pool creation A query storm at the startup of the client library before the session pool had initialized could cause the creation of an unbalanced session pool. This again would put a large batch of sessions using the same gRPC channel at the head of the pool, which could then continously be used by the application. * fix: automatically balance pool * fix: skip empty pool * fix: shuffle if unbalanced * fix: only reset randomness if actually randomized * test: randomize if many sessions are checked out * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * test: try with more channels * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fix: also consider checked out sessions for unbalanced pool * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * docs: add javadoc for property * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * perf: optimize low-QPS workloads * test: only randomize if more than 2 sessions are checked out * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * test: only skip randomization for existing sessions * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * chore: run formatter * chore: address review comments * docs: update comment on how session is added to the pool --------- Co-authored-by: Owl Bot --- .../com/google/cloud/spanner/SessionPool.java | 170 ++++++++++-- .../cloud/spanner/BaseSessionPoolTest.java | 8 + .../spanner/SessionPoolMaintainerTest.java | 6 +- .../cloud/spanner/SessionPoolStressTest.java | 29 ++- .../google/cloud/spanner/SessionPoolTest.java | 20 +- .../spanner/SessionPoolUnbalancedTest.java | 241 ++++++++++++++++++ 6 files changed, 439 insertions(+), 35 deletions(-) create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolUnbalancedTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index f5fa0ebdc4..c97fbf9b64 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -54,6 +54,7 @@ import com.google.cloud.spanner.SessionPoolOptions.InactiveTransactionRemovalOptions; import com.google.cloud.spanner.SpannerException.ResourceNotFoundException; import com.google.cloud.spanner.SpannerImpl.ClosedException; +import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.MoreObjects; @@ -1366,12 +1367,19 @@ PooledSession get(final boolean eligibleForLongRunning) { } } - final class PooledSession implements Session { + class PooledSession implements Session { @VisibleForTesting SessionImpl delegate; private volatile Instant lastUseTime; private volatile SpannerException lastException; private volatile boolean allowReplacing = true; + /** + * This ensures that the session is added at a random position in the pool the first time it is + * actually added to the pool. + */ + @GuardedBy("lock") + private Position releaseToPosition = initialReleasePosition; + /** * Property to mark if the session is eligible to be long-running. This can only be true if the * session is executing certain types of transactions (for ex - Partitioned DML) which can be @@ -1403,6 +1411,13 @@ private PooledSession(SessionImpl delegate) { this.lastUseTime = clock.instant(); } + int getChannel() { + Long channelHint = (Long) delegate.getOptions().get(SpannerRpc.Option.CHANNEL_HINT); + return channelHint == null + ? 0 + : (int) (channelHint % sessionClient.getSpanner().getOptions().getNumChannels()); + } + @Override public String toString() { return getName(); @@ -1536,7 +1551,7 @@ public void close() { if (state != SessionState.CLOSING) { state = SessionState.AVAILABLE; } - releaseSession(this, Position.FIRST); + releaseSession(this, false); } } @@ -1576,7 +1591,7 @@ private void determineDialectAsync(final SettableFuture dialect) { // in the database dialect, and there's nothing sensible that we can do with it here. dialect.setException(t); } finally { - releaseSession(this, Position.FIRST); + releaseSession(this, false); } }); } @@ -1830,7 +1845,7 @@ private void keepAliveSessions(Instant currTime) { logger.log(Level.FINE, "Keeping alive session " + sessionToKeepAlive.getName()); numSessionsToKeepAlive--; sessionToKeepAlive.keepAlive(); - releaseSession(sessionToKeepAlive, Position.FIRST); + releaseSession(sessionToKeepAlive, false); } catch (SpannerException e) { handleException(e, sessionToKeepAlive); } @@ -1929,7 +1944,7 @@ private void removeLongRunningSessions( } } - private enum Position { + enum Position { FIRST, RANDOM } @@ -1962,6 +1977,15 @@ private enum Position { final PoolMaintainer poolMaintainer; private final Clock clock; + /** + * initialReleasePosition determines where in the pool sessions are added when they are released + * into the pool the first time. This is always RANDOM in production, but some tests use FIRST to + * be able to verify the order of sessions in the pool. Using RANDOM ensures that we do not get an + * unbalanced session pool where all sessions belonging to one gRPC channel are added to the same + * region in the pool. + */ + private final Position initialReleasePosition; + private final Object lock = new Object(); private final Random random = new Random(); @@ -2045,6 +2069,7 @@ static SessionPool createPool( ((GrpcTransportOptions) spannerOptions.getTransportOptions()).getExecutorFactory(), sessionClient, poolMaintainerClock == null ? new Clock() : poolMaintainerClock, + Position.RANDOM, Metrics.getMetricRegistry(), labelValues); } @@ -2053,20 +2078,22 @@ static SessionPool createPool( SessionPoolOptions poolOptions, ExecutorFactory executorFactory, SessionClient sessionClient) { - return createPool(poolOptions, executorFactory, sessionClient, new Clock()); + return createPool(poolOptions, executorFactory, sessionClient, new Clock(), Position.RANDOM); } static SessionPool createPool( SessionPoolOptions poolOptions, ExecutorFactory executorFactory, SessionClient sessionClient, - Clock clock) { + Clock clock, + Position initialReleasePosition) { return createPool( poolOptions, null, executorFactory, sessionClient, clock, + initialReleasePosition, Metrics.getMetricRegistry(), SPANNER_DEFAULT_LABEL_VALUES); } @@ -2077,6 +2104,7 @@ static SessionPool createPool( ExecutorFactory executorFactory, SessionClient sessionClient, Clock clock, + Position initialReleasePosition, MetricRegistry metricRegistry, List labelValues) { SessionPool pool = @@ -2087,6 +2115,7 @@ static SessionPool createPool( executorFactory.get(), sessionClient, clock, + initialReleasePosition, metricRegistry, labelValues); pool.initPool(); @@ -2100,6 +2129,7 @@ private SessionPool( ScheduledExecutorService executor, SessionClient sessionClient, Clock clock, + Position initialReleasePosition, MetricRegistry metricRegistry, List labelValues) { this.options = options; @@ -2108,6 +2138,7 @@ private SessionPool( this.executor = executor; this.sessionClient = sessionClient; this.clock = clock; + this.initialReleasePosition = initialReleasePosition; this.poolMaintainer = new PoolMaintainer(); this.initMetricsCollection(metricRegistry, labelValues); this.waitOnMinSessionsLatch = @@ -2233,7 +2264,7 @@ private void handleException(SpannerException e, PooledSession session) { if (isSessionNotFound(e)) { invalidateSession(session); } else { - releaseSession(session, Position.FIRST); + releaseSession(session, false); } } @@ -2396,26 +2427,38 @@ private void maybeCreateSession() { } } } + /** Releases a session back to the pool. This might cause one of the waiters to be unblocked. */ - private void releaseSession(PooledSession session, Position position) { + private void releaseSession(PooledSession session, boolean isNewSession) { Preconditions.checkNotNull(session); synchronized (lock) { if (closureFuture != null) { return; } if (waiters.size() == 0) { - // No pending waiters - switch (position) { - case RANDOM: - if (!sessions.isEmpty()) { - int pos = random.nextInt(sessions.size() + 1); - sessions.add(pos, session); - break; - } - // fallthrough - case FIRST: - default: - sessions.addFirst(session); + // There are no pending waiters. + // Add to a random position if the head of the session pool already contains many sessions + // with the same channel as this one. + if (session.releaseToPosition == Position.FIRST && isUnbalanced(session)) { + session.releaseToPosition = Position.RANDOM; + } else if (session.releaseToPosition == Position.RANDOM + && !isNewSession + && checkedOutSessions.size() <= 2) { + // Do not randomize if there are few other sessions checked out and this session has been + // used. This ensures that this session will be re-used for the next transaction, which is + // more efficient. + session.releaseToPosition = Position.FIRST; + } + if (session.releaseToPosition == Position.RANDOM && !sessions.isEmpty()) { + // A session should only be added at a random position the first time it is added to + // the pool or if the pool was deemed unbalanced. All following releases into the pool + // should normally happen at the front of the pool (unless the pool is again deemed to be + // unbalanced). + session.releaseToPosition = Position.FIRST; + int pos = random.nextInt(sessions.size() + 1); + sessions.add(pos, session); + } else { + sessions.addFirst(session); } } else { waiters.poll().put(session); @@ -2423,6 +2466,89 @@ private void releaseSession(PooledSession session, Position position) { } } + private boolean isUnbalanced(PooledSession session) { + int channel = session.getChannel(); + int numChannels = sessionClient.getSpanner().getOptions().getNumChannels(); + return isUnbalanced(channel, this.sessions, this.checkedOutSessions, numChannels); + } + + /** + * Returns true if the given list of sessions is considered unbalanced when compared to the + * sessionChannel that is about to be added to the pool. + * + *

The method returns true if all the following is true: + * + *

    + *
  1. The list of sessions is not empty. + *
  2. The number of checked out sessions is > 2. + *
  3. The number of channels being used by the pool is > 1. + *
  4. And at least one of the following is true: + *
      + *
    1. The first numChannels sessions in the list of sessions contains more than 2 + * sessions that use the same channel as the one being added. + *
    2. The list of currently checked out sessions contains more than 2 times the the + * number of sessions with the same channel as the one being added than it should in + * order for it to be perfectly balanced. Perfectly balanced in this case means that + * the list should preferably contain size/numChannels sessions of each channel. + *
    + *
+ * + * @param channelOfSessionBeingAdded the channel number being used by the session that is about to + * be released into the pool + * @param sessions the list of all sessions in the pool + * @param checkedOutSessions the currently checked out sessions of the pool + * @param numChannels the number of channels in use + * @return true if the pool is considered unbalanced, and false otherwise + */ + @VisibleForTesting + static boolean isUnbalanced( + int channelOfSessionBeingAdded, + List sessions, + Set checkedOutSessions, + int numChannels) { + // Do not re-balance the pool if the number of checked out sessions is low, as it is + // better to re-use sessions as much as possible in a low-QPS scenario. + if (sessions.isEmpty() || checkedOutSessions.size() <= 2) { + return false; + } + if (numChannels == 1) { + return false; + } + + // Ideally, the first numChannels sessions in the pool should contain exactly one session for + // each channel. + // Check if the first numChannels sessions at the head of the pool already contain more than 2 + // sessions that use the same channel as this one. If so, we re-balance. + // We also re-balance the pool in the specific case that the pool uses 2 channels and the first + // two sessions use those two channels. + int maxSessionsAtHeadOfPool = Math.min(numChannels, 3); + int count = 0; + for (int i = 0; i < Math.min(numChannels, sessions.size()); i++) { + PooledSession otherSession = sessions.get(i); + if (channelOfSessionBeingAdded == otherSession.getChannel()) { + count++; + if (count >= maxSessionsAtHeadOfPool) { + return true; + } + } + } + // Ideally, the use of a channel in the checked out sessions is exactly + // numCheckedOut / numChannels + // We check whether we are more than a factor two away from that perfect distribution. + // If we are, then we re-balance. + count = 0; + int checkedOutThreshold = Math.max(2, 2 * checkedOutSessions.size() / numChannels); + for (PooledSessionFuture otherSession : checkedOutSessions) { + if (otherSession.isDone() && channelOfSessionBeingAdded == otherSession.get().getChannel()) { + count++; + if (count > checkedOutThreshold) { + return true; + } + } + } + return false; + } + private void handleCreateSessionsFailure(SpannerException e, int count) { synchronized (lock) { for (int i = 0; i < count; i++) { @@ -2622,7 +2748,7 @@ public void onSessionReady(SessionImpl session) { // Release the session to a random position in the pool to prevent the case that a batch // of sessions that are affiliated with the same channel are all placed sequentially in // the pool. - releaseSession(pooledSession, Position.RANDOM); + releaseSession(pooledSession, true); } } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java index 3a595358fe..7f8cf5cc1b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java @@ -26,16 +26,21 @@ import com.google.api.core.ApiFutures; import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory; import com.google.cloud.spanner.SessionPool.Clock; +import com.google.cloud.spanner.spi.v1.SpannerRpc.Option; import com.google.protobuf.Empty; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import org.threeten.bp.Instant; abstract class BaseSessionPoolTest { ScheduledExecutorService mockExecutor; int sessionIndex; + AtomicLong channelHint = new AtomicLong(0L); final class TestExecutorFactory implements ExecutorFactory { @@ -64,6 +69,9 @@ public void release(ScheduledExecutorService executor) { @SuppressWarnings("unchecked") SessionImpl mockSession() { final SessionImpl session = mock(SessionImpl.class); + Map options = new HashMap<>(); + options.put(Option.CHANNEL_HINT, channelHint.getAndIncrement()); + when(session.getOptions()).thenReturn(options); when(session.getName()) .thenReturn( "projects/dummy/instances/dummy/database/dummy/sessions/session" + sessionIndex); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java index ab6a51c926..2f7b14fdad 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolMaintainerTest.java @@ -26,6 +26,7 @@ import com.google.cloud.spanner.SessionClient.SessionConsumer; import com.google.cloud.spanner.SessionPool.PooledSession; import com.google.cloud.spanner.SessionPool.PooledSessionFuture; +import com.google.cloud.spanner.SessionPool.Position; import com.google.cloud.spanner.SessionPool.SessionConsumerImpl; import java.util.ArrayList; import java.util.HashMap; @@ -58,6 +59,7 @@ public void setUp() { initMocks(this); when(client.getOptions()).thenReturn(spannerOptions); when(client.getSessionClient(db)).thenReturn(sessionClient); + when(sessionClient.getSpanner()).thenReturn(client); when(spannerOptions.getNumChannels()).thenReturn(4); when(spannerOptions.getDatabaseRole()).thenReturn("role"); setupMockSessionCreation(); @@ -111,9 +113,11 @@ private SessionImpl setupMockSession(final SessionImpl session) { } private SessionPool createPool() throws Exception { + // Allow sessions to be added to the head of the pool in all cases in this test, as it is + // otherwise impossible to know which session exactly is getting pinged at what point in time. SessionPool pool = SessionPool.createPool( - options, new TestExecutorFactory(), client.getSessionClient(db), clock); + options, new TestExecutorFactory(), client.getSessionClient(db), clock, Position.FIRST); pool.idleSessionRemovedListener = input -> { idledSessions.add(input); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java index d1aab02d32..c9ba9b360a 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolStressTest.java @@ -25,6 +25,7 @@ import com.google.api.core.ApiFutures; import com.google.cloud.spanner.SessionClient.SessionConsumer; import com.google.cloud.spanner.SessionPool.PooledSessionFuture; +import com.google.cloud.spanner.SessionPool.Position; import com.google.cloud.spanner.SessionPool.SessionConsumerImpl; import com.google.cloud.spanner.SessionPoolOptions.ActionOnInactiveTransaction; import com.google.cloud.spanner.SessionPoolOptions.InactiveTransactionRemovalOptions; @@ -39,6 +40,7 @@ import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -65,10 +67,10 @@ public class SessionPoolStressTest extends BaseSessionPoolTest { SessionPool pool; SessionPoolOptions options; ExecutorService createExecutor = Executors.newSingleThreadExecutor(); - Object lock = new Object(); + final Object lock = new Object(); Random random = new Random(); FakeClock clock = new FakeClock(); - Map sessions = new HashMap<>(); + final Map sessions = new ConcurrentHashMap<>(); // Exception keeps track of where the session was closed at. Map closedSessions = new HashMap<>(); Set expiredSessions = new HashSet<>(); @@ -92,6 +94,7 @@ private void setupSpanner(DatabaseId db) { when(spannerOptions.getNumChannels()).thenReturn(4); when(spannerOptions.getDatabaseRole()).thenReturn("role"); SessionClient sessionClient = mock(SessionClient.class); + when(sessionClient.getSpanner()).thenReturn(mockSpanner); when(mockSpanner.getSessionClient(db)).thenReturn(sessionClient); when(mockSpanner.getOptions()).thenReturn(spannerOptions); doAnswer( @@ -226,22 +229,26 @@ public void stressTest() throws Exception { } pool = SessionPool.createPool( - builder.build(), new TestExecutorFactory(), mockSpanner.getSessionClient(db), clock); + builder.build(), + new TestExecutorFactory(), + mockSpanner.getSessionClient(db), + clock, + Position.RANDOM); pool.idleSessionRemovedListener = pooled -> { String name = pooled.getName(); - synchronized (lock) { - sessions.remove(name); - return null; - } + // We do not take the test lock here, as we already hold the session pool lock. Taking the + // test lock as well here can cause a deadlock. + sessions.remove(name); + return null; }; pool.longRunningSessionRemovedListener = pooled -> { String name = pooled.getName(); - synchronized (lock) { - sessions.remove(name); - return null; - } + // We do not take the test lock here, as we already hold the session pool lock. Taking the + // test lock as well here can cause a deadlock. + sessions.remove(name); + return null; }; for (int i = 0; i < concurrentThreads; i++) { new Thread( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java index b20d5dd652..8949ba6afa 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java @@ -52,6 +52,7 @@ import com.google.cloud.spanner.SessionPool.Clock; import com.google.cloud.spanner.SessionPool.PooledSession; import com.google.cloud.spanner.SessionPool.PooledSessionFuture; +import com.google.cloud.spanner.SessionPool.Position; import com.google.cloud.spanner.SessionPool.SessionConsumerImpl; import com.google.cloud.spanner.SpannerImpl.ClosedException; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; @@ -126,7 +127,7 @@ private SessionPool createPool() { private SessionPool createPool(Clock clock) { return SessionPool.createPool( - options, new TestExecutorFactory(), client.getSessionClient(db), clock); + options, new TestExecutorFactory(), client.getSessionClient(db), clock, Position.RANDOM); } private SessionPool createPool( @@ -137,6 +138,7 @@ private SessionPool createPool( new TestExecutorFactory(), client.getSessionClient(db), clock, + Position.RANDOM, metricRegistry, labelValues); } @@ -146,6 +148,7 @@ public void setUp() { initMocks(this); when(client.getOptions()).thenReturn(spannerOptions); when(client.getSessionClient(db)).thenReturn(sessionClient); + when(sessionClient.getSpanner()).thenReturn(client); when(spannerOptions.getNumChannels()).thenReturn(4); when(spannerOptions.getDatabaseRole()).thenReturn("role"); options = @@ -204,13 +207,27 @@ public void sessionCreation() { @Test public void poolLifo() { setupMockSessionCreation(); + options = + options + .toBuilder() + .setMinSessions(2) + .setWaitForMinSessions(Duration.ofSeconds(10L)) + .build(); pool = createPool(); + pool.maybeWaitOnMinSessions(); Session session1 = pool.getSession().get(); Session session2 = pool.getSession().get(); assertThat(session1).isNotEqualTo(session2); session2.close(); session1.close(); + + // Check the session out and back in once more to finalize their positions. + session1 = pool.getSession().get(); + session2 = pool.getSession().get(); + session2.close(); + session1.close(); + Session session3 = pool.getSession().get(); Session session4 = pool.getSession().get(); assertThat(session3).isEqualTo(session1); @@ -1187,6 +1204,7 @@ public void testSessionNotFoundReadWriteTransaction() { SpannerImpl spanner = mock(SpannerImpl.class); SessionClient sessionClient = mock(SessionClient.class); when(spanner.getSessionClient(db)).thenReturn(sessionClient); + when(sessionClient.getSpanner()).thenReturn(spanner); doAnswer( invocation -> { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolUnbalancedTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolUnbalancedTest.java new file mode 100644 index 0000000000..5a9365eaed --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolUnbalancedTest.java @@ -0,0 +1,241 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SessionPool.isUnbalanced; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.SessionPool.PooledSessionFuture; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SessionPoolUnbalancedTest { + + static PooledSession mockedSession(int channel) { + PooledSession session = mock(PooledSession.class); + when(session.getChannel()).thenReturn(channel); + return session; + } + + static List mockedSessions(int... channels) { + return Arrays.stream(channels) + .mapToObj(SessionPoolUnbalancedTest::mockedSession) + .collect(Collectors.toList()); + } + + static PooledSessionFuture mockedCheckedOutSession(int channel) { + PooledSession session = mockedSession(channel); + PooledSessionFuture future = mock(PooledSessionFuture.class); + when(future.get()).thenReturn(session); + when(future.isDone()).thenReturn(true); + return future; + } + + static Set mockedCheckedOutSessions(int... channels) { + return Arrays.stream(channels) + .mapToObj(SessionPoolUnbalancedTest::mockedCheckedOutSession) + .collect(Collectors.toSet()); + } + + @Test + public void testIsUnbalancedBasics() { + // An empty session pool is never unbalanced. + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1), 1)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1), 2)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1), 4)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1, 1), 1)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1, 1), 2)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1, 1), 4)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1, 1, 1), 1)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1, 1, 1), 2)); + assertFalse(isUnbalanced(1, mockedSessions(), mockedCheckedOutSessions(1, 1, 1, 1, 1), 4)); + + // A session pool that has 2 or fewer sessions checked out is never unbalanced. + // This prevents low-QPS scenarios from re-balancing the pool. + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1), mockedCheckedOutSessions(), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1), mockedCheckedOutSessions(), 2)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1), mockedCheckedOutSessions(), 4)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1), mockedCheckedOutSessions(1), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1), mockedCheckedOutSessions(1), 2)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1), mockedCheckedOutSessions(1), 4)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1, 1), mockedCheckedOutSessions(1, 1), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1, 1), mockedCheckedOutSessions(1, 1), 2)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1, 1), mockedCheckedOutSessions(1, 1), 4)); + + // A session pool that uses only 1 channel is never unbalanced. + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1), mockedCheckedOutSessions(), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1), mockedCheckedOutSessions(), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1, 1), mockedCheckedOutSessions(), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1, 1, 1, 1), mockedCheckedOutSessions(), 1)); + assertFalse(isUnbalanced(1, mockedSessions(1, 1, 1), mockedCheckedOutSessions(1, 1, 1), 1)); + assertFalse( + isUnbalanced(1, mockedSessions(1, 1, 1, 1), mockedCheckedOutSessions(1, 1, 1, 1), 1)); + assertFalse( + isUnbalanced(1, mockedSessions(1, 1, 1, 1, 1), mockedCheckedOutSessions(1, 1, 1, 1, 1), 1)); + assertFalse( + isUnbalanced( + 1, mockedSessions(1, 1, 1, 1, 1, 1), mockedCheckedOutSessions(1, 1, 1, 1, 1, 1), 1)); + } + + @Test + public void testIsUnbalanced_returnsFalseForBalancedPool() { + assertFalse( + isUnbalanced(1, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertFalse( + isUnbalanced(2, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertFalse( + isUnbalanced(3, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertFalse( + isUnbalanced(4, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + + assertFalse( + isUnbalanced(1, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(4, 3, 2, 1), 4)); + assertFalse( + isUnbalanced(2, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(4, 3, 2, 1), 4)); + assertFalse( + isUnbalanced(3, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(4, 3, 2, 1), 4)); + assertFalse( + isUnbalanced(4, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(4, 3, 2, 1), 4)); + + assertFalse( + isUnbalanced( + 1, + mockedSessions(1, 2, 3, 4, 1, 2, 3, 4), + mockedCheckedOutSessions(1, 2, 3, 4, 1, 2, 3, 4), + 4)); + + // We only check the first numChannels sessions that are in the pool, so the fact that the end + // of the pool is unbalanced is not a reason to re-balance. + assertFalse( + isUnbalanced( + 1, mockedSessions(1, 2, 3, 4, 1, 1, 1, 1), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertFalse( + isUnbalanced(1, mockedSessions(1, 2, 1, 1, 1, 1), mockedCheckedOutSessions(1, 2), 2)); + assertFalse( + isUnbalanced( + 1, + mockedSessions(1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1), + mockedCheckedOutSessions(1, 2, 3, 4), + 8)); + assertFalse( + isUnbalanced( + 1, + mockedSessions(1, 1, 2, 2, 3, 3, 4, 4, 1, 1, 1, 1), + mockedCheckedOutSessions(1, 2, 3, 4), + 8)); + + // The list of checked out sessions is allowed to contain up to twice the number of sessions + // with a given channel than it should for a perfect distribution (perfect means + // num_sessions_with_a_channel == num_channels). + assertFalse( + isUnbalanced(1, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 1, 2, 3), 4)); + assertFalse( + isUnbalanced( + 1, + mockedSessions(1, 2, 3, 4), + mockedCheckedOutSessions(1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 2, 3, 4, 5, 6), + 8)); + // We're only checking the list of checked out sessions against the channel that is being added + // to the pool. + assertFalse( + isUnbalanced(1, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(2, 2, 2, 2), 4)); + + // We do not consider a pool unbalanced if the list of checked out sessions only contains 2 of + // the same channel, even if that would still be 'more than twice the ideal number'. This + // prevents that a small number of checked out sessions that happen to use the same channel + // causes the pool to be considered unbalanced. + assertFalse( + isUnbalanced( + 1, mockedSessions(1, 2, 3, 4, 5, 6, 7, 8), mockedCheckedOutSessions(1, 1, 2), 8)); + + // A larger number of checked out sessions means that we can also have a 'large' number of the + // same channels in that list, as long as it does not exceed twice the number that it should be + // for an ideal distribution. + assertFalse( + isUnbalanced( + 1, + mockedSessions(1, 2, 3, 4, 5, 6, 7, 8), + mockedCheckedOutSessions(1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 2, 4, 5, 5, 3, 4, 8, 8), + 8)); + } + + @Test + public void testIsUnbalanced_returnsTrueForUnbalancedPool() { + // The pool is considered unbalanced if the first numChannel sessions contain 3 or more of the + // same sessions as the one that is being added. Also; if the pool uses only 2 channels, then it + // is also considered unbalanced if the two first sessions in the pool already use the same + // channel as the one being added. + assertTrue(isUnbalanced(1, mockedSessions(1, 1), mockedCheckedOutSessions(1, 2, 1, 2), 2)); + assertTrue(isUnbalanced(2, mockedSessions(2, 2), mockedCheckedOutSessions(1, 2, 1, 2), 2)); + + assertTrue( + isUnbalanced(1, mockedSessions(1, 1, 1, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertTrue( + isUnbalanced(2, mockedSessions(2, 2, 2, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertTrue( + isUnbalanced(3, mockedSessions(1, 3, 3, 3), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + assertTrue( + isUnbalanced(4, mockedSessions(1, 4, 4, 4), mockedCheckedOutSessions(1, 2, 3, 4), 4)); + + assertTrue( + isUnbalanced( + 1, mockedSessions(1, 2, 3, 4, 5, 6, 1, 1), mockedCheckedOutSessions(1, 2, 3, 4), 8)); + assertTrue( + isUnbalanced( + 2, mockedSessions(1, 3, 4, 5, 6, 2, 2, 2), mockedCheckedOutSessions(1, 2, 3, 4), 8)); + assertTrue( + isUnbalanced( + 3, mockedSessions(1, 2, 3, 3, 4, 5, 3, 6), mockedCheckedOutSessions(1, 2, 3, 4), 8)); + assertTrue( + isUnbalanced( + 4, mockedSessions(1, 2, 3, 4, 5, 4, 5, 4), mockedCheckedOutSessions(1, 2, 3, 4), 8)); + + // The pool is also considered unbalanced if the list of checked out sessions contain more than + // 2 times as many sessions of the one being returned as it should. + assertTrue( + isUnbalanced(1, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 1, 2, 1), 4)); + assertTrue( + isUnbalanced(2, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(2, 2, 2, 4), 4)); + assertTrue( + isUnbalanced(3, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 3, 3, 3), 4)); + assertTrue( + isUnbalanced(4, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(4, 2, 4, 4), 4)); + assertTrue( + isUnbalanced( + 1, mockedSessions(1, 2, 3, 4), mockedCheckedOutSessions(1, 1, 2, 1, 1, 2, 3, 1), 4)); + + assertTrue( + isUnbalanced( + 1, mockedSessions(1, 2, 3, 4, 5, 6, 7, 8), mockedCheckedOutSessions(1, 1, 1, 3), 8)); + assertTrue( + isUnbalanced( + 1, + mockedSessions(1, 2, 3, 4, 5, 6, 7, 8), + mockedCheckedOutSessions(1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 1, 1), + 8)); + } +}