Skip to content

Commit

Permalink
Merge pull request #105 from hnaderi/71-extended-amqp-capabilities
Browse files Browse the repository at this point in the history
Implemented Connection block/unblock notification
  • Loading branch information
hnaderi committed Apr 6, 2023
2 parents 0e92cc9 + 3cd53be commit 62a79c9
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 19 deletions.
12 changes: 7 additions & 5 deletions modules/client/src/main/scala/Connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ object Connection {
case h: Frame.Header => dispatcher.header(h)
case Frame.Method(0, value) =>
value match {
case m @ ConnectionClass.OpenOk => state.onOpened
case m @ ConnectionClass.CloseOk => state.onClosed
case m: ConnectionClass.Close => state.onCloseRequest(m)
case _ => ???
case m @ ConnectionClass.OpenOk => state.onOpened
case m @ ConnectionClass.CloseOk => state.onClosed
case m: ConnectionClass.Close => state.onCloseRequest(m)
case ConnectionClass.Blocked(msg) => state.onBlocked(msg)
case ConnectionClass.Unblocked => state.onUnblocked
case _ => ???
}
case m: Frame.Method => dispatcher.invoke(m)
case Frame.Heartbeat => state.onHeartbeat
Expand Down Expand Up @@ -143,7 +145,7 @@ object Connection {
enum Status {
case Connecting
case Connected
case Opened
case Opened(blocked: Boolean = false)
case Closed
}
}
4 changes: 2 additions & 2 deletions modules/client/src/main/scala/StartupNegotiation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ object StartupNegotiation {
ShortString("publisher_confirms") -> true,
ShortString("authentication_failure_close") -> true,
ShortString("consumer_cancel_notify") -> true,
ShortString("basic.nack") -> true
// ShortString("connection.blocked") -> true
ShortString("basic.nack") -> true,
ShortString("connection.blocked") -> true
).updated(ShortString("connection_name"), connectionName)
)
}
Expand Down
18 changes: 15 additions & 3 deletions modules/client/src/main/scala/internal/ConnectionState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ private[client] trait ConnectionState[F[_]] extends Signal[F, Status] {
def onCloseRequest: F[Unit]
def onCloseRequest(req: ConnectionClass.Close): F[Unit]
def onHeartbeat: F[Unit]
def onBlocked(msg: ShortString): F[Unit]
def onUnblocked: F[Unit]

def config: F[NegotiatedConfig]
def awaitOpened: F[Unit]
Expand Down Expand Up @@ -96,14 +98,24 @@ private[client] object ConnectionState {

override def onOpened: F[Unit] = hasOpened.complete(Right(())) *> underlying
.modify {
case Status.Connected => (Status.Opened, true)
case Status.Connected => (Status.Opened(), true)
case other => (other, false)
}
.ifM(F.unit, F.raiseError(new IllegalStateException))

override def onHeartbeat: F[Unit] = underlying.get.flatMap {
case Status.Opened => output.write(Frame.Heartbeat)
case _ => F.raiseError(new IllegalStateException)
case Status.Opened(_) => output.write(Frame.Heartbeat)
case _ => F.raiseError(new IllegalStateException)
}

override def onBlocked(msg: ShortString): F[Unit] = underlying.get.flatMap {
case Status.Opened(_) => underlying.set(Status.Opened(true))
case _ => F.raiseError(new IllegalStateException)
}

override def onUnblocked: F[Unit] = underlying.get.flatMap {
case Status.Opened(_) => underlying.set(Status.Opened(false))
case _ => F.raiseError(new IllegalStateException)
}

override def config: F[NegotiatedConfig] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,35 @@ class ConnectionReceiveSuite extends InternalTestSuite {
} yield ()
}
}

test("Dispatches connection blocked/unblocked notifications") {
val methods: Gen[ConnectionClass.Blocked | ConnectionClass.Unblocked.type] =
Gen.oneOf(
ConnectionDataGenerator.blockedGen,
ConnectionDataGenerator.unblockedGen
)

forAllF(methods) { method =>
val frame = Frame.Method(ChannelNumber(0), method)
for {
fd <- FakeFrameDispatcher()
output <- FakeFrameOutput()
st <- FakeConnectionState(Status.Opened())
_ <- Stream(frame)
.through(Connection.receive(st, fd))
.compile
.drain
_ <- method match {
case ConnectionClass.Blocked(msg) =>
st.interactions.assertFirst(
FakeConnectionState.Interaction.Blocked(msg)
)
case ConnectionClass.Unblocked =>
st.interactions.assertFirst(
FakeConnectionState.Interaction.Unblocked
)
}
} yield ()
}
}
}
32 changes: 28 additions & 4 deletions modules/client/src/test/scala/internal/ConnectionStateSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ConnectionStateSuite extends InternalTestSuite {
_ <- s.onConnected(config)
_ <- s.onOpened
_ <- s.awaitOpened
_ <- s.get.assertEquals(Status.Opened)
_ <- s.get.assertEquals(Status.Opened())
} yield ()
}

Expand All @@ -132,7 +132,7 @@ class ConnectionStateSuite extends InternalTestSuite {
_ <- s.awaitOpened.timeout(10.days).intercept[TimeoutException]
_ <- s.onOpened
_ <- s.awaitOpened
_ <- s.get.assertEquals(Status.Opened)
_ <- s.get.assertEquals(Status.Opened())
} yield ()
}

Expand Down Expand Up @@ -180,6 +180,30 @@ class ConnectionStateSuite extends InternalTestSuite {
} yield ()
}

test("Accept onBlocked/onUnblocked if opened") {
for {
s <- SUT
_ <- s.onConnected(config)
_ <- s.onOpened
_ <- s.get.assertEquals(Status.Opened(false))
_ <- s.onBlocked(ShortString.empty)
_ <- s.get.assertEquals(Status.Opened(true))
_ <- s.onUnblocked
_ <- s.get.assertEquals(Status.Opened(false))
} yield ()
}

test("Ignores redundant onUnblocked if opened") {
for {
s <- SUT
_ <- s.onConnected(config)
_ <- s.onOpened
_ <- s.get.assertEquals(Status.Opened(false))
_ <- s.onUnblocked
_ <- s.get.assertEquals(Status.Opened(false))
} yield ()
}

test("Accepts server close request if is opened") {
forAllF(ConnectionDataGenerator.closeGen) { close =>
for {
Expand All @@ -195,7 +219,7 @@ class ConnectionStateSuite extends InternalTestSuite {
Frame.Method(ChannelNumber(0), ConnectionClass.CloseOk)
)
)
_ <- s.get.assertEquals(Status.Opened)
_ <- s.get.assertEquals(Status.Opened())
} yield ()
}
}
Expand Down Expand Up @@ -223,7 +247,7 @@ class ConnectionStateSuite extends InternalTestSuite {
)
)
)
_ <- s.get.assertEquals(Status.Opened)
_ <- s.get.assertEquals(Status.Opened())
} yield ()
}
}
Expand Down
18 changes: 13 additions & 5 deletions modules/client/src/test/scala/internal/FakeConnectionState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ final class FakeConnectionState(
override def onFailed(ex: Throwable): IO[Unit] =
interactions.add(Interaction.Failed(ex)) >> state.set(Status.Closed)

override def onOpened: IO[Unit] = state.set(Status.Opened) >>
override def onOpened: IO[Unit] = state.set(Status.Opened()) >>
openedDef.complete(Right(())) >> interactions.add(Interaction.Opened)

override def config: IO[NegotiatedConfig] = connectedDef.get
Expand All @@ -84,6 +84,12 @@ final class FakeConnectionState(
override def onHeartbeat: IO[Unit] =
interactions.add(Interaction.Heartbeat) *> heartbeatError.run

override def onBlocked(msg: ShortString): IO[Unit] =
interactions.add(Interaction.Blocked(msg)) >> state.set(Status.Opened(true))

override def onUnblocked: IO[Unit] =
interactions.add(Interaction.Unblocked) >> state.set(Status.Opened(false))

def setAsWontOpen = openedDef.complete(Left(new Exception)).void
}

Expand All @@ -92,6 +98,8 @@ object FakeConnectionState {
case Connected(config: NegotiatedConfig)
case CloseRequest(close: ConnectionClass.Close)
case ClientCloseRequest, Opened, Closed, Heartbeat
case Blocked(msg: ShortString)
case Unblocked
case Failed(ex: Throwable)
}

Expand All @@ -110,10 +118,10 @@ object FakeConnectionState {
then connected.complete(defaultConfig).void
else IO.unit

_ <-
if currentState == Status.Opened
then opened.complete(Right(())).void
else IO.unit
_ <- currentState match {
case Status.Opened(_) => opened.complete(Right(())).void
case _ => IO.unit
}

} yield new FakeConnectionState(
interactions,
Expand Down

0 comments on commit 62a79c9

Please sign in to comment.