From 3cd53bee1351ea9c8482f3504b8e6645bbda66f1 Mon Sep 17 00:00:00 2001 From: Hossein Naderi Date: Thu, 6 Apr 2023 16:10:48 +0330 Subject: [PATCH] Implemented Connection block/unblock notification --- .../client/src/main/scala/Connection.scala | 12 ++++--- .../src/main/scala/StartupNegotiation.scala | 4 +-- .../main/scala/internal/ConnectionState.scala | 18 +++++++++-- .../internal/ConnectionReceiveSuite.scala | 31 ++++++++++++++++++ .../scala/internal/ConnectionStateSuite.scala | 32 ++++++++++++++++--- .../scala/internal/FakeConnectionState.scala | 18 ++++++++--- 6 files changed, 96 insertions(+), 19 deletions(-) diff --git a/modules/client/src/main/scala/Connection.scala b/modules/client/src/main/scala/Connection.scala index 9b2c9a9..a077b80 100644 --- a/modules/client/src/main/scala/Connection.scala +++ b/modules/client/src/main/scala/Connection.scala @@ -107,10 +107,12 @@ object Connection { case h: Frame.Header => dispatcher.header(h) case Frame.Method(0, value) => value match { - case m @ ConnectionClass.OpenOk => state.onOpened - case m @ ConnectionClass.CloseOk => state.onClosed - case m: ConnectionClass.Close => state.onCloseRequest(m) - case _ => ??? + case m @ ConnectionClass.OpenOk => state.onOpened + case m @ ConnectionClass.CloseOk => state.onClosed + case m: ConnectionClass.Close => state.onCloseRequest(m) + case ConnectionClass.Blocked(msg) => state.onBlocked(msg) + case ConnectionClass.Unblocked => state.onUnblocked + case _ => ??? } case m: Frame.Method => dispatcher.invoke(m) case Frame.Heartbeat => state.onHeartbeat @@ -143,7 +145,7 @@ object Connection { enum Status { case Connecting case Connected - case Opened + case Opened(blocked: Boolean = false) case Closed } } diff --git a/modules/client/src/main/scala/StartupNegotiation.scala b/modules/client/src/main/scala/StartupNegotiation.scala index 1662736..ca4f738 100644 --- a/modules/client/src/main/scala/StartupNegotiation.scala +++ b/modules/client/src/main/scala/StartupNegotiation.scala @@ -156,8 +156,8 @@ object StartupNegotiation { ShortString("publisher_confirms") -> true, ShortString("authentication_failure_close") -> true, ShortString("consumer_cancel_notify") -> true, - ShortString("basic.nack") -> true - // ShortString("connection.blocked") -> true + ShortString("basic.nack") -> true, + ShortString("connection.blocked") -> true ).updated(ShortString("connection_name"), connectionName) ) } diff --git a/modules/client/src/main/scala/internal/ConnectionState.scala b/modules/client/src/main/scala/internal/ConnectionState.scala index 9abe922..6f6d763 100644 --- a/modules/client/src/main/scala/internal/ConnectionState.scala +++ b/modules/client/src/main/scala/internal/ConnectionState.scala @@ -40,6 +40,8 @@ private[client] trait ConnectionState[F[_]] extends Signal[F, Status] { def onCloseRequest: F[Unit] def onCloseRequest(req: ConnectionClass.Close): F[Unit] def onHeartbeat: F[Unit] + def onBlocked(msg: ShortString): F[Unit] + def onUnblocked: F[Unit] def config: F[NegotiatedConfig] def awaitOpened: F[Unit] @@ -96,14 +98,24 @@ private[client] object ConnectionState { override def onOpened: F[Unit] = hasOpened.complete(Right(())) *> underlying .modify { - case Status.Connected => (Status.Opened, true) + case Status.Connected => (Status.Opened(), true) case other => (other, false) } .ifM(F.unit, F.raiseError(new IllegalStateException)) override def onHeartbeat: F[Unit] = underlying.get.flatMap { - case Status.Opened => output.write(Frame.Heartbeat) - case _ => F.raiseError(new IllegalStateException) + case Status.Opened(_) => output.write(Frame.Heartbeat) + case _ => F.raiseError(new IllegalStateException) + } + + override def onBlocked(msg: ShortString): F[Unit] = underlying.get.flatMap { + case Status.Opened(_) => underlying.set(Status.Opened(true)) + case _ => F.raiseError(new IllegalStateException) + } + + override def onUnblocked: F[Unit] = underlying.get.flatMap { + case Status.Opened(_) => underlying.set(Status.Opened(false)) + case _ => F.raiseError(new IllegalStateException) } override def config: F[NegotiatedConfig] = diff --git a/modules/client/src/test/scala/internal/ConnectionReceiveSuite.scala b/modules/client/src/test/scala/internal/ConnectionReceiveSuite.scala index 18a1320..43f491a 100644 --- a/modules/client/src/test/scala/internal/ConnectionReceiveSuite.scala +++ b/modules/client/src/test/scala/internal/ConnectionReceiveSuite.scala @@ -120,4 +120,35 @@ class ConnectionReceiveSuite extends InternalTestSuite { } yield () } } + + test("Dispatches connection blocked/unblocked notifications") { + val methods: Gen[ConnectionClass.Blocked | ConnectionClass.Unblocked.type] = + Gen.oneOf( + ConnectionDataGenerator.blockedGen, + ConnectionDataGenerator.unblockedGen + ) + + forAllF(methods) { method => + val frame = Frame.Method(ChannelNumber(0), method) + for { + fd <- FakeFrameDispatcher() + output <- FakeFrameOutput() + st <- FakeConnectionState(Status.Opened()) + _ <- Stream(frame) + .through(Connection.receive(st, fd)) + .compile + .drain + _ <- method match { + case ConnectionClass.Blocked(msg) => + st.interactions.assertFirst( + FakeConnectionState.Interaction.Blocked(msg) + ) + case ConnectionClass.Unblocked => + st.interactions.assertFirst( + FakeConnectionState.Interaction.Unblocked + ) + } + } yield () + } + } } diff --git a/modules/client/src/test/scala/internal/ConnectionStateSuite.scala b/modules/client/src/test/scala/internal/ConnectionStateSuite.scala index 075bf40..ce272eb 100644 --- a/modules/client/src/test/scala/internal/ConnectionStateSuite.scala +++ b/modules/client/src/test/scala/internal/ConnectionStateSuite.scala @@ -121,7 +121,7 @@ class ConnectionStateSuite extends InternalTestSuite { _ <- s.onConnected(config) _ <- s.onOpened _ <- s.awaitOpened - _ <- s.get.assertEquals(Status.Opened) + _ <- s.get.assertEquals(Status.Opened()) } yield () } @@ -132,7 +132,7 @@ class ConnectionStateSuite extends InternalTestSuite { _ <- s.awaitOpened.timeout(10.days).intercept[TimeoutException] _ <- s.onOpened _ <- s.awaitOpened - _ <- s.get.assertEquals(Status.Opened) + _ <- s.get.assertEquals(Status.Opened()) } yield () } @@ -180,6 +180,30 @@ class ConnectionStateSuite extends InternalTestSuite { } yield () } + test("Accept onBlocked/onUnblocked if opened") { + for { + s <- SUT + _ <- s.onConnected(config) + _ <- s.onOpened + _ <- s.get.assertEquals(Status.Opened(false)) + _ <- s.onBlocked(ShortString.empty) + _ <- s.get.assertEquals(Status.Opened(true)) + _ <- s.onUnblocked + _ <- s.get.assertEquals(Status.Opened(false)) + } yield () + } + + test("Ignores redundant onUnblocked if opened") { + for { + s <- SUT + _ <- s.onConnected(config) + _ <- s.onOpened + _ <- s.get.assertEquals(Status.Opened(false)) + _ <- s.onUnblocked + _ <- s.get.assertEquals(Status.Opened(false)) + } yield () + } + test("Accepts server close request if is opened") { forAllF(ConnectionDataGenerator.closeGen) { close => for { @@ -195,7 +219,7 @@ class ConnectionStateSuite extends InternalTestSuite { Frame.Method(ChannelNumber(0), ConnectionClass.CloseOk) ) ) - _ <- s.get.assertEquals(Status.Opened) + _ <- s.get.assertEquals(Status.Opened()) } yield () } } @@ -223,7 +247,7 @@ class ConnectionStateSuite extends InternalTestSuite { ) ) ) - _ <- s.get.assertEquals(Status.Opened) + _ <- s.get.assertEquals(Status.Opened()) } yield () } } diff --git a/modules/client/src/test/scala/internal/FakeConnectionState.scala b/modules/client/src/test/scala/internal/FakeConnectionState.scala index 31fa60b..5998144 100644 --- a/modules/client/src/test/scala/internal/FakeConnectionState.scala +++ b/modules/client/src/test/scala/internal/FakeConnectionState.scala @@ -74,7 +74,7 @@ final class FakeConnectionState( override def onFailed(ex: Throwable): IO[Unit] = interactions.add(Interaction.Failed(ex)) >> state.set(Status.Closed) - override def onOpened: IO[Unit] = state.set(Status.Opened) >> + override def onOpened: IO[Unit] = state.set(Status.Opened()) >> openedDef.complete(Right(())) >> interactions.add(Interaction.Opened) override def config: IO[NegotiatedConfig] = connectedDef.get @@ -84,6 +84,12 @@ final class FakeConnectionState( override def onHeartbeat: IO[Unit] = interactions.add(Interaction.Heartbeat) *> heartbeatError.run + override def onBlocked(msg: ShortString): IO[Unit] = + interactions.add(Interaction.Blocked(msg)) >> state.set(Status.Opened(true)) + + override def onUnblocked: IO[Unit] = + interactions.add(Interaction.Unblocked) >> state.set(Status.Opened(false)) + def setAsWontOpen = openedDef.complete(Left(new Exception)).void } @@ -92,6 +98,8 @@ object FakeConnectionState { case Connected(config: NegotiatedConfig) case CloseRequest(close: ConnectionClass.Close) case ClientCloseRequest, Opened, Closed, Heartbeat + case Blocked(msg: ShortString) + case Unblocked case Failed(ex: Throwable) } @@ -110,10 +118,10 @@ object FakeConnectionState { then connected.complete(defaultConfig).void else IO.unit - _ <- - if currentState == Status.Opened - then opened.complete(Right(())).void - else IO.unit + _ <- currentState match { + case Status.Opened(_) => opened.complete(Right(())).void + case _ => IO.unit + } } yield new FakeConnectionState( interactions,