diff --git a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/TeachingSynchronizer.java b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/TeachingSynchronizer.java index 78cb728b0fa8..06a5cc6da51f 100644 --- a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/TeachingSynchronizer.java +++ b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/TeachingSynchronizer.java @@ -42,7 +42,6 @@ import java.util.Objects; import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BooleanSupplier; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -87,13 +86,6 @@ public class TeachingSynchronizer { protected final ReconnectConfig reconnectConfig; - /** - * A mechanism to check if teaching should be stopped, e.g. when the teacher itself has - * fallen behind network. - */ - @Nullable - private final BooleanSupplier requestToStopTeaching; - private final Time time; /** @@ -113,8 +105,6 @@ public class TeachingSynchronizer { * if there is a thread stuck on a blocking IO * operation that will never finish due to a * failure. - * @param requestToStopTeaching - * a function to check periodically if teaching should be stopped * @param reconnectConfig * reconnect configuration from platform */ @@ -125,7 +115,6 @@ public TeachingSynchronizer( @NonNull final MerkleDataOutputStream out, @NonNull final MerkleNode root, @Nullable final Runnable breakConnection, - @Nullable final BooleanSupplier requestToStopTeaching, @NonNull final ReconnectConfig reconnectConfig) { this.time = Objects.requireNonNull(time); @@ -137,7 +126,6 @@ public TeachingSynchronizer( subtrees.add(new TeacherSubtree(root)); this.breakConnection = breakConnection; - this.requestToStopTeaching = requestToStopTeaching; this.reconnectConfig = Objects.requireNonNull(reconnectConfig, "reconnectConfig must not be null"); } @@ -202,16 +190,7 @@ private void sendTree(final MerkleNode root, final TeacherTreeView view) final AtomicBoolean senderIsFinished = new AtomicBoolean(false); - new TeacherSendingThread( - time, - reconnectConfig, - workGroup, - in, - out, - subtrees, - view, - requestToStopTeaching, - senderIsFinished) + new TeacherSendingThread(time, reconnectConfig, workGroup, in, out, subtrees, view, senderIsFinished) .start(); new TeacherReceivingThread<>(workGroup, in, view, senderIsFinished).start(); diff --git a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/internal/TeacherSendingThread.java b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/internal/TeacherSendingThread.java index 3d5aa20d8faf..0e2974f2a891 100644 --- a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/internal/TeacherSendingThread.java +++ b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/merkle/synchronization/internal/TeacherSendingThread.java @@ -33,10 +33,8 @@ import com.swirlds.common.threading.pool.StandardWorkGroup; import com.swirlds.common.utility.throttle.RateLimiter; import edu.umd.cs.findbugs.annotations.NonNull; -import edu.umd.cs.findbugs.annotations.Nullable; import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BooleanSupplier; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -63,9 +61,6 @@ public class TeacherSendingThread { private final Queue subtrees; private final TeacherTreeView view; - @Nullable - private final BooleanSupplier requestToStopTeaching; - private final AtomicBoolean senderIsFinished; private final RateLimiter rateLimiter; @@ -82,8 +77,6 @@ public class TeacherSendingThread { * @param subtrees a queue containing roots of subtrees to send, may have more roots added by this * class * @param view an object that interfaces with the subtree - * @param requestToStopTeaching a function to check periodically if teaching should be stopped, e.g. because of the - * teacher has fallen behind network * @param senderIsFinished set to true when this thread has finished */ public TeacherSendingThread( @@ -94,14 +87,12 @@ public TeacherSendingThread( final AsyncOutputStream> out, final Queue subtrees, final TeacherTreeView view, - @Nullable final BooleanSupplier requestToStopTeaching, final AtomicBoolean senderIsFinished) { this.workGroup = workGroup; this.in = in; this.out = out; this.subtrees = subtrees; this.view = view; - this.requestToStopTeaching = requestToStopTeaching; this.senderIsFinished = senderIsFinished; final int maxRate = reconnectConfig.teacherMaxNodesPerSecond(); @@ -212,13 +203,6 @@ private void run() { while (view.areThereNodesToHandle()) { rateLimit(); - - if ((requestToStopTeaching != null) && requestToStopTeaching.getAsBoolean()) { - logger.info( - RECONNECT.getMarker(), - "Teacher's sending thread is requested to stop teaching (fallen behind?)"); - break; - } final T node = view.getNextNodeToHandle(); sendLesson(node); } diff --git a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/threading/BlockingResourceProvider.java b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/threading/BlockingResourceProvider.java index 2968b191460c..45b99bddea7e 100644 --- a/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/threading/BlockingResourceProvider.java +++ b/platform-sdk/swirlds-common/src/main/java/com/swirlds/common/threading/BlockingResourceProvider.java @@ -66,6 +66,16 @@ public boolean acquireProvidePermit() { return providePermit.tryAcquire(); } + /** + * Try to acquire the provide permit bypassing the check to see if the consumer is waiting for the resource, this + * will block the providers until {@link #releaseProvidePermit()} is called + * + * @return true if the permit has been acquired + */ + public boolean tryBlockProvidePermit() { + return providePermit.tryAcquire(); + } + /** * Release a previously acquired provide permit */ diff --git a/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTest.java b/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTest.java index c1e0ea1440e3..3d5c1eddf489 100644 --- a/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTest.java +++ b/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTest.java @@ -41,7 +41,6 @@ import java.util.List; import java.util.Queue; import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeAll; @@ -247,23 +246,6 @@ void fullTeacherSingleLeafLearner2() { assertDoesNotThrow(this::reconnect, "Should not throw a Exception"); } - @Test - @Tags({@Tag("VirtualMerkle"), @Tag("Reconnect")}) - @DisplayName("Teacher is requested to stop teaching after a few attempts") - void simulateTeacherFallenBehind() { - teacherMap.put(A_KEY, APPLE); - teacherMap.put(B_KEY, BANANA); - teacherMap.put(C_KEY, CHERRY); - teacherMap.put(D_KEY, DATE); - teacherMap.put(E_KEY, EGGPLANT); - teacherMap.put(F_KEY, FIG); - - final AtomicInteger counter = new AtomicInteger(0); - requestTeacherToStop = () -> counter.incrementAndGet() == 4; - - reconnectMultipleTimes(2); - } - /** * This test simulates some divergence from the teacher and the learner. At the time both the teacher and learner * had diverged, both had simple integer values for the key and value. At the time of divergence, the teacher had diff --git a/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTestBase.java b/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTestBase.java index 5a28c4626f24..d6fa19817696 100644 --- a/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTestBase.java +++ b/platform-sdk/swirlds-merkle/src/test/java/com/swirlds/virtual/merkle/reconnect/VirtualMapReconnectTestBase.java @@ -54,7 +54,6 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.nio.file.Path; -import java.util.function.BooleanSupplier; import java.util.function.Function; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -96,7 +95,6 @@ public class VirtualMapReconnectTestBase { protected VirtualMap learnerMap; protected BrokenBuilder teacherBuilder; protected BrokenBuilder learnerBuilder; - protected BooleanSupplier requestTeacherToStop; VirtualDataSourceBuilder createBuilder() throws IOException { // The tests create maps with identical names. They would conflict with each other in the default @@ -132,7 +130,6 @@ void setupEach() throws Exception { learnerBuilder = createBrokenBuilder(dataSourceBuilder); teacherMap = new VirtualMap<>("Teacher", teacherBuilder); learnerMap = new VirtualMap<>("Learner", learnerBuilder); - requestTeacherToStop = () -> false; // don't interrupt teaching by default } @BeforeAll @@ -221,10 +218,7 @@ protected void reconnectMultipleTimes( try { final MerkleNode node = MerkleTestUtils.hashAndTestSynchronization( - learnerTree, - failureExpected ? brokenTeacherTree : teacherTree, - requestTeacherToStop, - reconnectConfig); + learnerTree, failureExpected ? brokenTeacherTree : teacherTree, reconnectConfig); node.release(); assertFalse(failureExpected, "We should only succeed on the last try"); final VirtualRoot root = learnerMap.getRight(); diff --git a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectController.java b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectController.java index d4f645fc249b..29a4bbae12da 100644 --- a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectController.java +++ b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectController.java @@ -134,6 +134,16 @@ public boolean acquireLearnerPermit() { return connectionProvider.acquireProvidePermit(); } + /** + * Try to block the learner permit for reconnect. The method {@link #cancelLearnerPermit()} should be called + * to unblock the permit. + * + * @return true if the permit has been blocked + */ + public boolean blockLearnerPermit() { + return connectionProvider.tryBlockProvidePermit(); + } + /** * Releases a previously acquired permit for reconnect */ diff --git a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectProtocol.java b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectProtocol.java index d09f69cf8071..cb12e5bc700f 100644 --- a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectProtocol.java +++ b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectProtocol.java @@ -159,7 +159,7 @@ public boolean shouldAccept() { RECONNECT.getMarker(), "Rejecting reconnect request from node {} because this node has fallen behind", peerId); - reconnectMetrics.recordReconnectRejection(peerId); + reconnectRejected(); return false; } @@ -169,7 +169,7 @@ public boolean shouldAccept() { RECONNECT.getMarker(), "Rejecting reconnect request from node {} because this node isn't ACTIVE", peerId); - reconnectMetrics.recordReconnectRejection(peerId); + reconnectRejected(); return false; } @@ -181,35 +181,52 @@ public boolean shouldAccept() { RECONNECT.getMarker(), "Rejecting reconnect request from node {} due to lack of a fully signed state", peerId); - reconnectMetrics.recordReconnectRejection(peerId); + reconnectRejected(); return false; } if (!teacherState.get().isComplete()) { // this is only possible if signed state manager violates its contractual obligations - teacherState.close(); - teacherState = null; stateIncompleteLogger.error( RECONNECT.getMarker(), "Rejecting reconnect request from node {} due to lack of a fully signed state." + " The signed state manager attempted to provide a state that was not" + " fully signed, which should not be possible.", peerId); - reconnectMetrics.recordReconnectRejection(peerId); + reconnectRejected(); + return false; + } + + // we should not become a learner while we are teaching + // this can happen if we fall behind while we are teaching + // in this case, we want to finish teaching before we start learning + // so we acquire the learner permit and release it when we are done teaching + if (!reconnectController.blockLearnerPermit()) { + reconnectRejected(); return false; } // Check if a reconnect with the learner is permitted by the throttle. final boolean reconnectPermittedByThrottle = teacherThrottle.initiateReconnect(peerId); - if (reconnectPermittedByThrottle) { - initiatedBy = InitiatedBy.PEER; - return true; - } else { + if (!reconnectPermittedByThrottle) { + reconnectRejected(); + reconnectController.cancelLearnerPermit(); + return false; + } + + initiatedBy = InitiatedBy.PEER; + return true; + } + + /** + * Called when we reject a reconnect as a teacher + */ + private void reconnectRejected() { + if (teacherState != null) { teacherState.close(); teacherState = null; - reconnectMetrics.recordReconnectRejection(peerId); - return false; } + reconnectMetrics.recordReconnectRejection(peerId); } /** {@inheritDoc} */ @@ -218,6 +235,8 @@ public void acceptFailed() { teacherState.close(); teacherState = null; teacherThrottle.reconnectAttemptFinished(); + // cancel the permit acquired in shouldAccept() so that we can start learning if we need to + reconnectController.cancelLearnerPermit(); } /** {@inheritDoc} */ @@ -269,13 +288,14 @@ private void teacher(final Connection connection) { connection.getSelfId(), connection.getOtherId(), state.get().getRound(), - fallenBehindManager::hasFallenBehind, reconnectMetrics, configuration) .execute(state.get()); } finally { teacherThrottle.reconnectAttemptFinished(); teacherState = null; + // cancel the permit acquired in shouldAccept() so that we can start learning if we need to + reconnectController.cancelLearnerPermit(); } } diff --git a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectTeacher.java b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectTeacher.java index 05fd51ec4283..5f593b9b76af 100644 --- a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectTeacher.java +++ b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/ReconnectTeacher.java @@ -34,12 +34,10 @@ import com.swirlds.platform.network.Connection; import com.swirlds.platform.state.signed.SignedState; import edu.umd.cs.findbugs.annotations.NonNull; -import edu.umd.cs.findbugs.annotations.Nullable; import java.io.IOException; import java.net.SocketException; import java.time.Duration; import java.util.Objects; -import java.util.function.BooleanSupplier; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -70,12 +68,6 @@ public class ReconnectTeacher { private final ThreadManager threadManager; private final Time time; - /** - * A function to check periodically if teaching should be stopped, e.g. when the teacher has fallen behind. - */ - @Nullable - private final BooleanSupplier requestToStopTeaching; - /** * @param threadManager responsible for managing thread lifecycles * @param connection the connection to be used for the reconnect @@ -94,7 +86,6 @@ public ReconnectTeacher( @NonNull final NodeId selfId, @NonNull final NodeId otherId, final long lastRoundReceived, - @Nullable final BooleanSupplier requestToStopTeaching, @NonNull final ReconnectMetrics statistics, @NonNull final Configuration configuration) { @@ -106,7 +97,6 @@ public ReconnectTeacher( this.selfId = Objects.requireNonNull(selfId); this.otherId = Objects.requireNonNull(otherId); this.lastRoundReceived = lastRoundReceived; - this.requestToStopTeaching = requestToStopTeaching; this.statistics = Objects.requireNonNull(statistics); this.configuration = Objects.requireNonNull(configuration); } @@ -231,7 +221,6 @@ private void reconnect(final SignedState signedState) throws InterruptedExceptio new MerkleDataOutputStream(connection.getDos()), signedState.getState(), connection::disconnect, - requestToStopTeaching, reconnectConfig); synchronizer.synchronize(); diff --git a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectProtocol.java b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectProtocol.java index 9ea989ff6389..1223bf1ac69d 100644 --- a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectProtocol.java +++ b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectProtocol.java @@ -172,13 +172,7 @@ public void runProtocol(final Connection connection) private void teacher(final Connection connection) { try { new EmergencyReconnectTeacher( - time, - threadManager, - stateFinder, - reconnectSocketTimeout, - fallenBehindManager::hasFallenBehind, - reconnectMetrics, - configuration) + time, threadManager, stateFinder, reconnectSocketTimeout, reconnectMetrics, configuration) .execute(connection); } finally { teacherThrottle.reconnectAttemptFinished(); diff --git a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectTeacher.java b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectTeacher.java index c7e91b276e8a..b49ede60cfe5 100644 --- a/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectTeacher.java +++ b/platform-sdk/swirlds-platform-core/src/main/java/com/swirlds/platform/reconnect/emergency/EmergencyReconnectTeacher.java @@ -30,11 +30,9 @@ import com.swirlds.platform.state.signed.SignedState; import com.swirlds.platform.state.signed.SignedStateFinder; import edu.umd.cs.findbugs.annotations.NonNull; -import edu.umd.cs.findbugs.annotations.Nullable; import java.io.IOException; import java.time.Duration; import java.util.Objects; -import java.util.function.BooleanSupplier; import java.util.function.Predicate; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -51,15 +49,11 @@ public class EmergencyReconnectTeacher { private final Configuration configuration; private final Time time; - @Nullable - private final BooleanSupplier requestToStopTeaching; - /** * @param time provides wall clock time * @param threadManager responsible for managing thread lifecycles * @param stateFinder finds an acceptable state for emergency reconnect * @param reconnectSocketTimeout the socket timeout to use when executing a reconnect - * @param requestToStopTeaching to be checked periodically if teaching should be stopped * @param reconnectMetrics tracks reconnect metrics * @param configuration the configuration for the platform */ @@ -68,7 +62,6 @@ public EmergencyReconnectTeacher( @NonNull final ThreadManager threadManager, @NonNull final SignedStateFinder stateFinder, @NonNull final Duration reconnectSocketTimeout, - @Nullable final BooleanSupplier requestToStopTeaching, @NonNull final ReconnectMetrics reconnectMetrics, @NonNull final Configuration configuration) { this.time = Objects.requireNonNull(time); @@ -76,7 +69,6 @@ public EmergencyReconnectTeacher( this.stateFinder = Objects.requireNonNull(stateFinder, "stateFinder must not be null"); this.reconnectSocketTimeout = Objects.requireNonNull(reconnectSocketTimeout, "reconnectSocketTimeout must not be null"); - this.requestToStopTeaching = requestToStopTeaching; this.reconnectMetrics = Objects.requireNonNull(reconnectMetrics, "reconnectMetrics must not be null"); this.configuration = Objects.requireNonNull(configuration, "configuration must not be null"); } @@ -126,7 +118,6 @@ public void execute(final Connection connection) { connection.getSelfId(), connection.getOtherId(), reservedState.get().getRound(), - requestToStopTeaching, reconnectMetrics, configuration) .execute(reservedState.get()); diff --git a/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectProtocolTests.java b/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectProtocolTests.java index 7494ec337499..dbdf23459829 100644 --- a/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectProtocolTests.java +++ b/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectProtocolTests.java @@ -20,12 +20,15 @@ import static com.swirlds.platform.state.signed.ReservedSignedState.createNullReservation; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.swirlds.base.time.Time; @@ -37,6 +40,7 @@ import com.swirlds.config.api.Configuration; import com.swirlds.platform.gossip.FallenBehindManager; import com.swirlds.platform.metrics.ReconnectMetrics; +import com.swirlds.platform.network.Connection; import com.swirlds.platform.state.RandomSignedStateGenerator; import com.swirlds.platform.state.State; import com.swirlds.platform.state.signed.ReservedSignedState; @@ -68,6 +72,9 @@ class ReconnectProtocolTests { */ private PlatformStatusGetter activeStatusGetter; + private ReconnectController reconnectController; + private ReconnectThrottle teacherThrottle; + private static Stream initiateParams() { return Stream.of( Arguments.of(new InitiateParams( @@ -123,6 +130,12 @@ public String toString() { void setup() { activeStatusGetter = mock(PlatformStatusGetter.class); when(activeStatusGetter.getCurrentStatus()).thenReturn(PlatformStatus.ACTIVE); + + reconnectController = mock(ReconnectController.class); + when(reconnectController.blockLearnerPermit()).thenReturn(true); + + teacherThrottle = mock(ReconnectThrottle.class); + when(teacherThrottle.initiateReconnect(any())).thenReturn(true); } @DisplayName("Test the conditions under which the protocol should and should not be initiated") @@ -187,7 +200,7 @@ void testShouldAccept(final AcceptParams params) { () -> reservedSignedState, Duration.of(100, ChronoUnit.MILLIS), mock(ReconnectMetrics.class), - mock(ReconnectController.class), + reconnectController, mock(SignedStateValidator.class), fallenBehindManager, activeStatusGetter, @@ -265,7 +278,7 @@ void testTeacherThrottleReleased() { () -> null, Duration.of(100, ChronoUnit.MILLIS), mock(ReconnectMetrics.class), - mock(ReconnectController.class), + reconnectController, mock(SignedStateValidator.class), fallenBehindManager, activeStatusGetter, @@ -285,7 +298,7 @@ void testTeacherThrottleReleased() { () -> reservedSignedState, Duration.of(100, ChronoUnit.MILLIS), mock(ReconnectMetrics.class), - mock(ReconnectController.class), + reconnectController, mock(SignedStateValidator.class), fallenBehindManager, activeStatusGetter, @@ -369,7 +382,7 @@ void abortedTeacher() { () -> reservedSignedState, Duration.of(100, ChronoUnit.MILLIS), mock(ReconnectMetrics.class), - mock(ReconnectController.class), + reconnectController, mock(SignedStateValidator.class), fallenBehindManager, activeStatusGetter, @@ -417,9 +430,6 @@ void teacherHasNoSignedState() { @Test @DisplayName("Teacher doesn't have a status of ACTIVE") void teacherNotActive() { - final ReconnectThrottle throttle = mock(ReconnectThrottle.class); - when(throttle.initiateReconnect(any())).thenReturn(true); - final FallenBehindManager fallenBehindManager = mock(FallenBehindManager.class); when(fallenBehindManager.hasFallenBehind()).thenReturn(false); @@ -434,7 +444,7 @@ void teacherNotActive() { final ReconnectProtocol protocol = new ReconnectProtocol( getStaticThreadManager(), new NodeId(0), - throttle, + teacherThrottle, () -> reservedSignedState, Duration.of(100, ChronoUnit.MILLIS), mock(ReconnectMetrics.class), @@ -447,4 +457,75 @@ void teacherNotActive() { assertFalse(protocol.shouldAccept()); } + + @Test + @DisplayName("Teacher holds the learner permit while teaching") + void teacherHoldsLearnerPermit() { + final SignedState signedState = spy(new RandomSignedStateGenerator().build()); + when(signedState.isComplete()).thenReturn(true); + signedState.reserve("test"); + + final ReconnectProtocol protocol = new ReconnectProtocol( + getStaticThreadManager(), + new NodeId(0), + teacherThrottle, + () -> signedState.reserve("test"), + Duration.of(100, ChronoUnit.MILLIS), + mock(ReconnectMetrics.class), + reconnectController, + mock(SignedStateValidator.class), + mock(FallenBehindManager.class), + activeStatusGetter, + configuration, + Time.getCurrent()); + + assertTrue(protocol.shouldAccept()); + + verify(reconnectController, times(1)).blockLearnerPermit(); + verify(reconnectController, times(0)).cancelLearnerPermit(); + + protocol.acceptFailed(); + + verify(reconnectController, times(1)).blockLearnerPermit(); + verify(reconnectController, times(1)).cancelLearnerPermit(); + + assertTrue(protocol.shouldAccept()); + + verify(reconnectController, times(2)).blockLearnerPermit(); + verify(reconnectController, times(1)).cancelLearnerPermit(); + + assertThrows(Exception.class, () -> protocol.runProtocol(mock(Connection.class))); + + verify(reconnectController, times(2)).blockLearnerPermit(); + verify(reconnectController, times(2)).cancelLearnerPermit(); + } + + @Test + @DisplayName("Teacher holds the learner permit while teaching") + void teacherCantAcquireLearnerPermit() { + final SignedState signedState = spy(new RandomSignedStateGenerator().build()); + when(signedState.isComplete()).thenReturn(true); + signedState.reserve("test"); + + when(reconnectController.blockLearnerPermit()).thenReturn(false); + + final ReconnectProtocol protocol = new ReconnectProtocol( + getStaticThreadManager(), + new NodeId(0), + teacherThrottle, + () -> signedState.reserve("test"), + Duration.of(100, ChronoUnit.MILLIS), + mock(ReconnectMetrics.class), + reconnectController, + mock(SignedStateValidator.class), + mock(FallenBehindManager.class), + activeStatusGetter, + configuration, + Time.getCurrent()); + + assertFalse(protocol.shouldAccept()); + + verify(reconnectController, times(1)).blockLearnerPermit(); + verify(reconnectController, times(0)).cancelLearnerPermit(); + } } diff --git a/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectTest.java b/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectTest.java index 8f4bc4a2d4ee..7dbcc83380c5 100644 --- a/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectTest.java +++ b/platform-sdk/swirlds-platform-core/src/test/java/com/swirlds/platform/reconnect/ReconnectTest.java @@ -183,7 +183,6 @@ private ReconnectTeacher buildSender( selfId, otherId, lastRoundReceived, - () -> false, reconnectMetrics, platformContext.getConfiguration()); } diff --git a/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/LaggingTeachingSynchronizer.java b/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/LaggingTeachingSynchronizer.java index 01c03fd8b4f0..43af6bf56c6c 100644 --- a/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/LaggingTeachingSynchronizer.java +++ b/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/LaggingTeachingSynchronizer.java @@ -28,8 +28,6 @@ import com.swirlds.common.merkle.synchronization.internal.Lesson; import com.swirlds.common.merkle.synchronization.streams.AsyncOutputStream; import com.swirlds.common.threading.pool.StandardWorkGroup; -import edu.umd.cs.findbugs.annotations.Nullable; -import java.util.function.BooleanSupplier; /** * A {@link TeachingSynchronizer} with simulated latency. @@ -46,18 +44,9 @@ public LaggingTeachingSynchronizer( final MerkleDataOutputStream out, final MerkleNode root, final int latencyMilliseconds, - @Nullable final BooleanSupplier shouldKeepTeaching, final Runnable breakConnection, final ReconnectConfig reconnectConfig) { - super( - Time.getCurrent(), - getStaticThreadManager(), - in, - out, - root, - breakConnection, - shouldKeepTeaching, - reconnectConfig); + super(Time.getCurrent(), getStaticThreadManager(), in, out, root, breakConnection, reconnectConfig); this.latencyMilliseconds = latencyMilliseconds; } diff --git a/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/MerkleTestUtils.java b/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/MerkleTestUtils.java index 380b37855ac8..f81698116081 100644 --- a/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/MerkleTestUtils.java +++ b/platform-sdk/swirlds-unit-tests/common/swirlds-common-test/src/main/java/com/swirlds/common/test/merkle/util/MerkleTestUtils.java @@ -42,7 +42,6 @@ import com.swirlds.common.test.merkle.dummy.DummyMerkleLeaf2; import com.swirlds.common.test.merkle.dummy.DummyMerkleNode; import com.swirlds.common.threading.pool.StandardWorkGroup; -import edu.umd.cs.findbugs.annotations.Nullable; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; @@ -53,7 +52,6 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BooleanSupplier; /** * Utility methods for testing merkle trees. @@ -969,7 +967,7 @@ private static void learningSynchronizerThread(final LearningSynchronizer learne public static T testSynchronization( final MerkleNode startingTree, final MerkleNode desiredTree, final ReconnectConfig reconnectConfig) throws Exception { - return testSynchronization(startingTree, desiredTree, 0, () -> false, reconnectConfig); + return testSynchronization(startingTree, desiredTree, 0, reconnectConfig); } /** @@ -982,16 +980,6 @@ public static T testSynchronization( final int latencyMilliseconds, final ReconnectConfig reconnectConfig) throws Exception { - return testSynchronization(startingTree, desiredTree, latencyMilliseconds, () -> false, reconnectConfig); - } - - public static T testSynchronization( - final MerkleNode startingTree, - final MerkleNode desiredTree, - final int latencyMilliseconds, - @Nullable final BooleanSupplier requestToStopTeaching, - final ReconnectConfig reconnectConfig) - throws Exception { try (PairedStreams streams = new PairedStreams()) { final LearningSynchronizer learner; @@ -1026,7 +1014,6 @@ public static T testSynchronization( e.printStackTrace(); } }, - requestToStopTeaching, reconnectConfig); } else { learner = new LaggingLearningSynchronizer( @@ -1048,7 +1035,6 @@ public static T testSynchronization( streams.getTeacherOutput(), desiredTree, latencyMilliseconds, - requestToStopTeaching, () -> { try { streams.disconnect(); @@ -1184,15 +1170,6 @@ private static void assertReconnectValidity( public static T hashAndTestSynchronization( final MerkleNode startingTree, final MerkleNode desiredTree, final ReconnectConfig reconnectConfig) throws Exception { - return hashAndTestSynchronization(startingTree, desiredTree, () -> false, reconnectConfig); - } - - public static T hashAndTestSynchronization( - final MerkleNode startingTree, - final MerkleNode desiredTree, - final BooleanSupplier requestTeacherToStop, - final ReconnectConfig reconnectConfig) - throws Exception { System.out.println("------------"); System.out.println("starting: " + startingTree); System.out.println("desired: " + desiredTree); @@ -1203,7 +1180,7 @@ public static T hashAndTestSynchronization( if (desiredTree != null && desiredTree.getHash() == null) { MerkleCryptoFactory.getInstance().digestTreeSync(desiredTree); } - return testSynchronization(startingTree, desiredTree, 0, requestTeacherToStop, reconnectConfig); + return testSynchronization(startingTree, desiredTree, 0, reconnectConfig); } /**