Skip to content

Commit

Permalink
fix: Change PublisherImpl and SerialBatcher interplay to not call into the network layer on the downcall (#975)
Browse files Browse the repository at this point in the history

* fix: Change PublisherImpl and SerialBatcher interplay to not call into the network layer on the downcall

This call can cause user publish() calls to block until the stream is able to reconnect on transient stream disconnections

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
dpcollins-google and gcf-owl-bot[bot] committed Dec 8, 2021
1 parent 9327e79 commit e771c49
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 106 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ If you are using Maven, add this to your pom.xml file:
If you are using Gradle without BOM, add this to your dependencies

```Groovy
implementation 'com.google.cloud:google-cloud-pubsublite:1.4.1'
implementation 'com.google.cloud:google-cloud-pubsublite:1.4.2'
```

If you are using SBT, add this to your dependencies

```Scala
libraryDependencies += "com.google.cloud" % "google-cloud-pubsublite" % "1.4.1"
libraryDependencies += "com.google.cloud" % "google-cloud-pubsublite" % "1.4.2"
```

## Authentication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ public final class PublisherImpl extends ProxyService
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();

private final AlarmFactory alarmFactory;
private final BatchingSettings batchingSettings;
private final PublishRequest initialRequest;

private final CloseableMonitor monitor = new CloseableMonitor();
Expand All @@ -84,7 +83,10 @@ public boolean isSatisfied() {
@GuardedBy("monitor.monitor")
private Optional<Offset> lastSentOffset = Optional.empty();

@GuardedBy("monitor.monitor")
// batcherMonitor is always acquired after monitor.monitor when both are held.
private final CloseableMonitor batcherMonitor = new CloseableMonitor();

@GuardedBy("batcherMonitor.monitor")
private final SerialBatcher batcher;

private static class InFlightBatch {
Expand Down Expand Up @@ -112,7 +114,6 @@ private static class InFlightBatch {
this.alarmFactory = alarmFactory;
Preconditions.checkNotNull(batchingSettings.getRequestByteThreshold());
Preconditions.checkNotNull(batchingSettings.getElementCountThreshold());
this.batchingSettings = batchingSettings;
this.initialRequest = PublishRequest.newBuilder().setInitialRequest(initialRequest).build();
this.connection =
new RetryingConnectionImpl<>(streamFactory, publisherFactory, this, this.initialRequest);
Expand Down Expand Up @@ -216,40 +217,25 @@ protected void stop() {
flush(); // Flush again in case messages were added since shutdown was set.
}

// Registers the batch as in-flight and hands its messages to the current stream
// connection. No-op for an empty batch. Caller must hold monitor.monitor (see @GuardedBy).
// Throws CheckedApiException if the connection cannot accept the publish.
@GuardedBy("monitor.monitor")
private void processBatch(Collection<UnbatchedMessage> batch) throws CheckedApiException {
if (batch.isEmpty()) return;
InFlightBatch inFlightBatch = new InFlightBatch(batch);
// Record the batch before publishing; NOTE(review): presumably it is later matched
// against stream responses elsewhere in this class — not visible in this hunk.
batchesInFlight.add(inFlightBatch);
connection.modifyConnection(
connectionOr -> {
// An absent connection means the stream already shut down; fail the publish.
checkState(connectionOr.isPresent(), "Published after the stream shut down.");
connectionOr.get().publish(inFlightBatch.messages);
});
}

@GuardedBy("monitor.monitor")
private void terminateOutstandingPublishes(CheckedApiException e) {
batchesInFlight.forEach(
batch -> batch.messageFutures.forEach(future -> future.setException(e)));
batcher.flush().forEach(m -> m.future().setException(e));
try (CloseableMonitor.Hold h = batcherMonitor.enter()) {
batcher.flush().forEach(batch -> batch.forEach(m -> m.future().setException(e)));
}
batchesInFlight.clear();
}

@Override
public ApiFuture<Offset> publish(Message message) {
PubSubMessage proto = message.toProto();
try (CloseableMonitor.Hold h = monitor.enter()) {
try (CloseableMonitor.Hold h = batcherMonitor.enter()) {
ApiService.State currentState = state();
checkState(
currentState == ApiService.State.RUNNING,
String.format("Cannot publish when Publisher state is %s.", currentState.name()));
checkState(!shutdown, "Published after the stream shut down.");
ApiFuture<Offset> messageFuture = batcher.add(proto);
if (batcher.shouldFlush()) {
processBatch(batcher.flush());
}
return messageFuture;
return batcher.add(proto);
} catch (CheckedApiException e) {
onPermanentError(e);
return ApiFutures.immediateFailedFuture(e);
Expand All @@ -267,12 +253,30 @@ public void cancelOutstandingPublishes() {
private void flushToStream() {
try (CloseableMonitor.Hold h = monitor.enter()) {
if (shutdown) return;
processBatch(batcher.flush());
List<List<UnbatchedMessage>> batches;
try (CloseableMonitor.Hold h2 = batcherMonitor.enter()) {
batches = batcher.flush();
}
for (List<UnbatchedMessage> batch : batches) {
processBatch(batch);
}
} catch (CheckedApiException e) {
onPermanentError(e);
}
}

// Sends one assembled batch down the stream, tracking it in batchesInFlight first.
// Empty batches are ignored. Must be called with monitor.monitor held (see @GuardedBy).
// Throws CheckedApiException when the underlying connection rejects the publish.
@GuardedBy("monitor.monitor")
private void processBatch(Collection<UnbatchedMessage> batch) throws CheckedApiException {
if (batch.isEmpty()) return;
InFlightBatch inFlightBatch = new InFlightBatch(batch);
// Add to the in-flight list before the network call so the batch is never sent
// without being tracked.
batchesInFlight.add(inFlightBatch);
connection.modifyConnection(
connectionOr -> {
// The connection is absent once the stream has shut down; publishing then is an error.
checkState(connectionOr.isPresent(), "Published after the stream shut down.");
connectionOr.get().publish(inFlightBatch.messages);
});
}

// Flushable implementation
@Override
public void flush() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@
import com.google.auto.value.AutoValue;
import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.proto.PubSubMessage;
import com.google.common.base.Preconditions;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// A thread compatible batcher which preserves message order.
class SerialBatcher {
private final long byteLimit;
private final long messageLimit;
private long byteCount = 0L;
private Deque<UnbatchedMessage> messages = new ArrayDeque<>();

@AutoValue
Expand All @@ -49,36 +48,29 @@ public static UnbatchedMessage of(PubSubMessage message, SettableApiFuture<Offse
this.messageLimit = messageLimit;
}

// Enqueues a message for batching. Returns a future that is completed with the
// message's offset by code outside this class — NOTE(review): completion happens in
// the publisher, not visible in this hunk.
// Callers should always call shouldFlush() after add, and flush() if that returns true.
ApiFuture<Offset> add(PubSubMessage message) {
// Maintain the running byte total that shouldFlush() compares against byteLimit.
byteCount += message.getSerializedSize();
SettableApiFuture<Offset> future = SettableApiFuture.create();
messages.add(UnbatchedMessage.of(message, future));
return future;
}

// Reports whether the pending messages have reached either configured threshold:
// total serialized bytes (byteLimit) or message count (messageLimit).
boolean shouldFlush() {
  boolean bytesAtLimit = byteCount >= byteLimit;
  boolean countAtLimit = messages.size() >= messageLimit;
  return bytesAtLimit || countAtLimit;
}

// If callers satisfy the conditions on add, one of two things will be true after a call to flush.
// Either, there will be 0-many messages remaining and they will be within the limits, or
// there will be 1 message remaining.
//
// This means, an isolated call to flush will always return all messages in the batcher.
Collection<UnbatchedMessage> flush() {
Deque<UnbatchedMessage> toReturn = messages;
messages = new ArrayDeque<>();
while ((byteCount > byteLimit || toReturn.size() > messageLimit) && toReturn.size() > 1) {
messages.addFirst(toReturn.removeLast());
byteCount -= toReturn.peekLast().message().getSerializedSize();
List<List<UnbatchedMessage>> flush() {
List<List<UnbatchedMessage>> toReturn = new ArrayList<>();
List<UnbatchedMessage> currentBatch = new ArrayList<>();
toReturn.add(currentBatch);
long currentBatchBytes = 0;
for (UnbatchedMessage message : messages) {
long newBatchBytes = currentBatchBytes + message.message().getSerializedSize();
if (currentBatch.size() + 1 > messageLimit || newBatchBytes > byteLimit) {
// If we would be pushed over the limit, create a new batch.
currentBatch = new ArrayList<>();
toReturn.add(currentBatch);
newBatchBytes = message.message().getSerializedSize();
}
currentBatchBytes = newBatchBytes;
currentBatch.add(message);
}
byteCount = messages.stream().mapToLong(value -> value.message().getSerializedSize()).sum();
// Validate the postcondition.
Preconditions.checkState(
messages.size() == 1 || (byteCount <= byteLimit && messages.size() <= messageLimit),
"Postcondition violation in SerialBatcher::flush. The caller is likely not calling flush"
+ " after calling add.");
messages = new ArrayDeque<>();
return toReturn;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@
package com.google.cloud.pubsublite.internal.wire;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.api.core.ApiFuture;
import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.internal.wire.SerialBatcher.UnbatchedMessage;
import com.google.cloud.pubsublite.proto.PubSubMessage;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.junit.Test;
Expand All @@ -36,21 +32,29 @@
@RunWith(JUnit4.class)
public class SerialBatcherTest {
private static final PubSubMessage MESSAGE_1 =
PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("Some data")).build();
PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("data")).build();
private static final PubSubMessage MESSAGE_2 =
PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("Some other data")).build();
PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("other data")).build();
private static final PubSubMessage MESSAGE_3 =
PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("more data")).build();

private static List<PubSubMessage> extractMessages(Collection<UnbatchedMessage> messages) {
// Flattens a list of batches and returns the contained proto messages in order.
private static List<PubSubMessage> extractMessages(List<List<UnbatchedMessage>> batches) {
  return batches.stream()
      .flatMap(List::stream)
      .map(UnbatchedMessage::message)
      .collect(Collectors.toList());
}

// Pulls the proto message out of each UnbatchedMessage in a single batch, keeping order.
private static List<PubSubMessage> extractMessagesFromBatch(List<UnbatchedMessage> batch) {
  return batch.stream().map(m -> m.message()).collect(Collectors.toList());
}

@Test
public void shouldFlushAtMessageLimit() throws Exception {
public void needsImmediateFlushAtMessageLimit() throws Exception {
SerialBatcher batcher = new SerialBatcher(/*byteLimit=*/ 10000, /*messageLimit=*/ 1);
assertThat(batcher.shouldFlush()).isFalse();
ApiFuture<Offset> future = batcher.add(PubSubMessage.getDefaultInstance());
assertThat(batcher.shouldFlush()).isTrue();
ImmutableList<UnbatchedMessage> messages = ImmutableList.copyOf(batcher.flush());
List<List<UnbatchedMessage>> batches = batcher.flush();
assertThat(batches).hasSize(1);
List<UnbatchedMessage> messages = batches.get(0);
assertThat(messages).hasSize(1);
assertThat(future.isDone()).isFalse();
messages.get(0).future().set(Offset.of(43));
Expand All @@ -59,31 +63,47 @@ public void shouldFlushAtMessageLimit() throws Exception {

@Test
@SuppressWarnings({"CheckReturnValue", "FutureReturnValueIgnored"})
public void shouldFlushAtMessageLimitAggregated() {
public void moreThanLimitMultipleBatches() throws Exception {
SerialBatcher batcher =
new SerialBatcher(
/*byteLimit=*/ MESSAGE_1.getSerializedSize() + MESSAGE_2.getSerializedSize(),
/*messageLimit=*/ 1000);
batcher.add(MESSAGE_1);
batcher.add(MESSAGE_2);
batcher.add(MESSAGE_3);
List<List<UnbatchedMessage>> batches = batcher.flush();
assertThat(batches).hasSize(2);
assertThat(extractMessagesFromBatch(batches.get(0))).containsExactly(MESSAGE_1, MESSAGE_2);
assertThat(extractMessagesFromBatch(batches.get(1))).containsExactly(MESSAGE_3);
}

@Test
@SuppressWarnings({"CheckReturnValue", "FutureReturnValueIgnored"})
public void flushMessageLimit() {
SerialBatcher batcher = new SerialBatcher(/*byteLimit=*/ 10000, /*messageLimit=*/ 2);
assertThat(batcher.shouldFlush()).isFalse();
batcher.add(MESSAGE_1);
assertThat(batcher.shouldFlush()).isFalse();
batcher.add(MESSAGE_2);
assertThat(batcher.shouldFlush()).isTrue();
assertThat(extractMessages(batcher.flush())).containsExactly(MESSAGE_1, MESSAGE_2);
batcher.add(MESSAGE_3);
List<List<UnbatchedMessage>> batches = batcher.flush();
assertThat(batches.size()).isEqualTo(2);
assertThat(extractMessagesFromBatch(batches.get(0))).containsExactly(MESSAGE_1, MESSAGE_2);
assertThat(extractMessagesFromBatch(batches.get(1))).containsExactly(MESSAGE_3);
}

@Test
@SuppressWarnings({"CheckReturnValue", "FutureReturnValueIgnored"})
public void shouldFlushAtByteLimitAggregated() {
public void flushByteLimit() {
SerialBatcher batcher =
new SerialBatcher(
/*byteLimit=*/ MESSAGE_1.getSerializedSize() + 1, /*messageLimit=*/ 10000);
assertThat(batcher.shouldFlush()).isFalse();
/*byteLimit=*/ MESSAGE_1.getSerializedSize() + MESSAGE_2.getSerializedSize() + 1,
/*messageLimit=*/ 10000);
batcher.add(MESSAGE_1);
assertThat(batcher.shouldFlush()).isFalse();
batcher.add(MESSAGE_2);
assertThat(batcher.shouldFlush()).isTrue();
assertThat(extractMessages(batcher.flush())).containsExactly(MESSAGE_1);
Preconditions.checkArgument(MESSAGE_2.getSerializedSize() > MESSAGE_1.getSerializedSize());
assertThat(batcher.shouldFlush()).isTrue();
assertThat(extractMessages(batcher.flush())).containsExactly(MESSAGE_2);
batcher.add(MESSAGE_3);
List<List<UnbatchedMessage>> batches = batcher.flush();
assertThat(batches.size()).isEqualTo(2);
assertThat(extractMessagesFromBatch(batches.get(0))).containsExactly(MESSAGE_1, MESSAGE_2);
assertThat(extractMessagesFromBatch(batches.get(1))).containsExactly(MESSAGE_3);
}

@Test
Expand All @@ -93,37 +113,8 @@ public void batchesMessagesAtLimit() {
new SerialBatcher(
/*byteLimit=*/ MESSAGE_1.getSerializedSize() + MESSAGE_2.getSerializedSize(),
/*messageLimit=*/ 10000);
assertThat(batcher.shouldFlush()).isFalse();
batcher.add(MESSAGE_2);
assertThat(batcher.shouldFlush()).isFalse();
batcher.add(MESSAGE_1);
assertThat(batcher.shouldFlush()).isTrue();
assertThat(extractMessages(batcher.flush())).containsExactly(MESSAGE_2, MESSAGE_1);
}

// Verifies that flush() fails its internal postcondition check (see the
// Preconditions.checkState in SerialBatcher.flush) when the caller keeps adding
// past the message limit without flushing, violating add()'s documented contract.
@Test
@SuppressWarnings({"CheckReturnValue", "FutureReturnValueIgnored"})
public void callerNoFlushFailsMessagePrecondition() {
// messageLimit=1: a single pending message already demands a flush.
SerialBatcher batcher = new SerialBatcher(/*byteLimit=*/ 10000, /*messageLimit=*/ 1);
batcher.add(MESSAGE_1);
assertThat(batcher.shouldFlush()).isTrue();
batcher.add(MESSAGE_2);
assertThat(batcher.shouldFlush()).isTrue();
batcher.add(PubSubMessage.getDefaultInstance());
assertThat(batcher.shouldFlush()).isTrue();
// Three pending messages against a limit of one: flush() cannot restore its
// invariant and throws IllegalStateException.
assertThrows(IllegalStateException.class, batcher::flush);
}

// Mirror of callerNoFlushFailsMessagePrecondition for the byte threshold: adding
// repeatedly past byteLimit without flushing makes flush() trip its
// Preconditions.checkState and throw.
@Test
@SuppressWarnings({"CheckReturnValue", "FutureReturnValueIgnored"})
public void callerNoFlushFailsBytePrecondition() {
// byteLimit=1: any non-empty message exceeds the byte threshold immediately.
SerialBatcher batcher = new SerialBatcher(/*byteLimit=*/ 1, /*messageLimit=*/ 10000);
batcher.add(MESSAGE_1);
assertThat(batcher.shouldFlush()).isTrue();
batcher.add(MESSAGE_2);
assertThat(batcher.shouldFlush()).isTrue();
batcher.add(PubSubMessage.getDefaultInstance());
assertThat(batcher.shouldFlush()).isTrue();
// With multiple over-limit messages pending, flush() cannot satisfy its
// postcondition and throws IllegalStateException.
assertThrows(IllegalStateException.class, batcher::flush);
}
}

0 comments on commit e771c49

Please sign in to comment.