Skip to content

Commit

Permalink
Merge pull request #90 from hnaderi/69-connection-termination-must-in…
Browse files Browse the repository at this point in the history
…terrupt-all-channels

Connection close terminates all channels
  • Loading branch information
hnaderi committed Feb 21, 2023
2 parents f46b792 + 35a7bb7 commit 5051b06
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 18 deletions.
2 changes: 1 addition & 1 deletion modules/client/src/main/scala/Connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object Connection {
).toResource
dispatcher <- FrameDispatcher[F].toResource
output <- OutputWriter(sendQ.offer).toResource
state <- ConnectionState(output, path).toResource
state <- ConnectionState(output, dispatcher, path).toResource
newChannel = ChannelBuilder(
output,
state,
Expand Down
2 changes: 2 additions & 0 deletions modules/client/src/main/scala/internal/ConnectionState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ private[client] trait ConnectionState[F[_]] extends Signal[F, Status] {
private[client] object ConnectionState {
def apply[F[_]](
output: OutputWriter[F, Frame],
dispatcher: FrameDispatcher[F],
path: Path = Path("/")
)(using F: Concurrent[F]): F[ConnectionState[F]] = for {
underlying <- SignallingRef[F, Status](Status.Connecting)
Expand Down Expand Up @@ -87,6 +88,7 @@ private[client] object ConnectionState {
hasOpened.complete(Left(TerminalState)) *>
configDef.complete(Left(TerminalState)) *>
output.onClose *>
dispatcher.onClose *>
underlying.set(Status.Closed)

override def onOpened: F[Unit] = hasOpened.complete(Right(())) *> underlying
Expand Down
5 changes: 5 additions & 0 deletions modules/client/src/main/scala/internal/FrameDispatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ private[client] trait FrameDispatcher[F[_]] {
def header(h: Frame.Header): F[Unit]
def body(b: Frame.Body): F[Unit]
def invoke(m: Frame.Method): F[Unit]
def onClose: F[Unit]

def add[CHANNEL <: ChannelReceiver[F]](
build: ChannelNumber => Resource[F, CHANNEL]
Expand All @@ -50,6 +51,10 @@ private[client] object FrameDispatcher {
def apply[F[_]](using F: Concurrent[F]): F[FrameDispatcher[F]] = for {
state <- SignallingRef[F].of(State[F]())
} yield new {

def onClose: F[Unit] =
state.get.map(_.channels.values.toList).flatMap(_.traverse_(_.onClose))

def header(h: Frame.Header): F[Unit] =
call(h.channel)(_.header(h))

Expand Down
3 changes: 2 additions & 1 deletion modules/client/src/main/scala/internal/LowlevelChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ private[client] trait ChannelReceiver[F[_]] {
def header(h: Frame.Header): F[Unit]
def body(h: Frame.Body): F[Unit]
def method(m: Method): F[Unit]
def onClose: F[Unit]
}

private[client] trait ChannelTransmitter[F[_]] {
Expand Down Expand Up @@ -94,7 +95,7 @@ private[client] object LowlevelChannel {
state <- SignallingRef[F].of(Status.Active)
} yield new LowlevelChannel[F] {

private def onClose: F[Unit] =
override def onClose: F[Unit] =
state.set(Status.Closed) >> rpc.sendNoWait(ChannelClass.CloseOk)

private def handle(f: F[Unit]) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ConnectionReceiveSuite extends InternalTestSuite {
for {
fd <- FakeFrameDispatcher()
output <- FakeFrameOutput()
st <- ConnectionState[IO](output)
st <- ConnectionState(output, fd)
_ <- Stream.empty.through(Connection.receive(st, fd)).compile.drain
_ <- st.get.assertEquals(Status.Closed)
} yield ()
Expand All @@ -54,7 +54,7 @@ class ConnectionReceiveSuite extends InternalTestSuite {
for {
fd <- FakeFrameDispatcher()
output <- FakeFrameOutput()
st <- ConnectionState[IO](output)
st <- ConnectionState(output, fd)
_ <- Stream
.raiseError(new Exception)
.through(Connection.receive(st, fd))
Expand All @@ -70,7 +70,7 @@ class ConnectionReceiveSuite extends InternalTestSuite {
for {
fd <- FakeFrameDispatcher()
output <- FakeFrameOutput()
st <- ConnectionState[IO](output)
st <- ConnectionState(output, fd)
_ <- st.onConnected(config)
_ <- st.onOpened
_ <- Stream(Frame.Heartbeat)
Expand All @@ -90,7 +90,7 @@ class ConnectionReceiveSuite extends InternalTestSuite {
for {
fd <- FakeFrameDispatcher()
output <- FakeFrameOutput()
st <- ConnectionState[IO](output)
st <- ConnectionState(output, fd)
_ <- st.onConnected(config)
_ <- st.onOpened
_ <- Stream(frame)
Expand All @@ -109,7 +109,7 @@ class ConnectionReceiveSuite extends InternalTestSuite {
for {
fd <- FakeFrameDispatcher()
output <- FakeFrameOutput()
st <- ConnectionState[IO](output)
st <- ConnectionState(output, fd)
_ <- st.onConnected(config)
_ <- st.onOpened
_ <- Stream(method)
Expand Down
34 changes: 27 additions & 7 deletions modules/client/src/test/scala/internal/ConnectionStateSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ import scala.concurrent.duration.*
import Connection.Status

class ConnectionStateSuite extends InternalTestSuite {
private val SUT =
OutputWriter[IO, Frame](_ => IO.unit).flatMap(ConnectionState(_))
private val SUT = for {
out <- OutputWriter[IO, Frame](_ => IO.unit)
fd <- FrameDispatcher[IO]
state <- ConnectionState(out, fd)
} yield state

private val config = NegotiatedConfig(1, 2, 3)

test("Initial state is connecting") {
Expand All @@ -58,7 +62,8 @@ class ConnectionStateSuite extends InternalTestSuite {
forAllF(DomainGenerators.path) { vhost =>
for {
sent <- FakeFrameOutput()
s <- ConnectionState(sent, vhost)
fd <- FrameDispatcher[IO]
s <- ConnectionState(sent, fd, vhost)
_ <- s.onConnected(config)
_ <- s.config.assertEquals(config)
_ <- s.get.assertEquals(Status.Connected)
Expand Down Expand Up @@ -168,7 +173,8 @@ class ConnectionStateSuite extends InternalTestSuite {
forAllF(ConnectionDataGenerator.closeGen) { close =>
for {
sent <- FakeFrameOutput()
s <- ConnectionState(sent)
fd <- FrameDispatcher[IO]
s <- ConnectionState(sent, fd)
_ <- s.onConnected(config)
_ <- s.onOpened
_ <- sent.interactions.reset
Expand All @@ -187,7 +193,8 @@ class ConnectionStateSuite extends InternalTestSuite {
forAllF(ConnectionDataGenerator.closeGen) { close =>
for {
sent <- FakeFrameOutput()
s <- ConnectionState(sent)
fd <- FrameDispatcher[IO]
s <- ConnectionState(sent, fd)
_ <- s.onConnected(config)
_ <- s.onOpened
_ <- sent.interactions.reset
Expand All @@ -213,7 +220,8 @@ class ConnectionStateSuite extends InternalTestSuite {
test("Responds to heartbeats if is opened") {
for {
sent <- FakeFrameOutput()
s <- ConnectionState(sent)
fd <- FrameDispatcher[IO]
s <- ConnectionState(sent, fd)
_ <- s.onConnected(config)
_ <- s.onOpened
_ <- sent.interactions.reset
Expand All @@ -235,9 +243,21 @@ class ConnectionStateSuite extends InternalTestSuite {
test("Output terminates after getting closed") {
for {
sent <- FakeFrameOutput()
s <- ConnectionState(sent)
fd <- FrameDispatcher[IO]
s <- ConnectionState(sent, fd)
_ <- s.onClosed
_ <- sent.interactions.assert(FakeFrameOutput.Interaction.Closed)
} yield ()
}

test("Frame dispatcher terminates after getting closed") {
for {
sent <- FakeFrameOutput()
fd <- FakeFrameDispatcher()
s <- ConnectionState(sent, fd)
_ <- fd.assertOpen
_ <- s.onClosed
_ <- fd.assertClosed
} yield ()
}
}
12 changes: 10 additions & 2 deletions modules/client/src/test/scala/internal/FakeFrameDispatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import cats.effect.IO
import cats.effect.kernel.Ref
import cats.effect.kernel.Resource
import cats.effect.std.Queue
import cats.syntax.all.*
import fs2.concurrent.Signal
import lepus.protocol.*
import lepus.protocol.domains.ChannelNumber
Expand All @@ -32,9 +33,12 @@ import Frame.*

final class FakeFrameDispatcher(
val dispatched: InteractionList[Frame],
error: Option[Exception]
error: Option[Exception],
closed: Ref[IO, Boolean]
) extends FrameDispatcher[IO] {

override def onClose: IO[Unit] = closed.set(true)

override def body(b: Body): IO[Unit] = dispatch(b)

override def invoke(m: Frame.Method): IO[Unit] = dispatch(m)
Expand All @@ -50,9 +54,13 @@ final class FakeFrameDispatcher(

private def dispatch(frame: Frame) =
dispatched.add(frame) >> error.fold(IO.unit)(IO.raiseError)

def assertClosed: IO[Unit] = closed.get.assert
def assertOpen: IO[Unit] = closed.get.map(!_).assert
}

object FakeFrameDispatcher {
def apply(error: Option[Exception] = None): IO[FakeFrameDispatcher] =
InteractionList[Frame].map(new FakeFrameDispatcher(_, error))
(InteractionList[Frame], IO.ref(false))
.mapN(new FakeFrameDispatcher(_, error, _))
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import fs2.Stream
import fs2.concurrent.Signal
import fs2.concurrent.SignallingRef
import lepus.client.Channel.Status
import lepus.client.Confirmation
import lepus.client.internal.FakeLowLevelChannel.Interaction
import lepus.codecs.ConnectionDataGenerator
import lepus.codecs.DomainGenerators
Expand All @@ -34,7 +35,6 @@ import lepus.protocol.BasicClass.Publish
import lepus.protocol.*
import lepus.protocol.constants.ReplyCode
import lepus.protocol.domains.*
import lepus.client.Confirmation

final class FakeLowLevelChannel(
val interactions: InteractionList[Interaction],
Expand All @@ -44,6 +44,8 @@ final class FakeLowLevelChannel(
private def call[T](f: LowlevelChannel[IO] => IO[T]) = channel.get.flatMap(f)
private val channelS = Stream.eval(channel.get)

override def onClose: IO[Unit] = call(_.onClose)

override def asyncContent(m: ContentMethod): IO[Unit] = call(
_.asyncContent(m)
)
Expand Down
5 changes: 4 additions & 1 deletion modules/client/src/test/scala/internal/FakeReceiver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import lepus.protocol.constants.ReplyCode
import lepus.protocol.domains.*
import munit.CatsEffectSuite
import munit.ScalaCheckSuite
import munit.CatsEffectAssertions.*
import org.scalacheck.effect.PropF.forAllF
import scodec.bits.ByteVector

Expand All @@ -46,7 +47,7 @@ final class FakeReceiver(
def body(h: Frame.Body): IO[Unit] = interact(Interaction.Body(h))
def method(m: Method): IO[Unit] = interact(Interaction.Method(m))

def close: IO[Unit] = interact(Interaction.Close).void
def onClose: IO[Unit] = interact(Interaction.Close).void

private def interact(i: Interaction): IO[Unit] =
interactionList.update(_.prepended(i)) >> error.get.flatMap(
Expand All @@ -58,6 +59,8 @@ final class FakeReceiver(

def interactions: IO[List[Interaction]] = interactionList.get
def lastInteraction: IO[Option[Interaction]] = interactions.map(_.headOption)
def assertClosed: IO[Unit] =
lastInteraction.assertEquals(Some(Interaction.Close))
}

object FakeReceiver {
Expand Down
14 changes: 14 additions & 0 deletions modules/client/src/test/scala/internal/FrameDispatcherSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,18 @@ class FrameDispatcherSuite extends InternalTestSuite {
} yield ()
}
}

test("Must close all channels on close") {
for {
fd <- FrameDispatcher[IO]
channels <- FakeReceiver().replicateA(10)
_ <- channels
.traverse(ch => fd.add(_ => Resource.pure(ch)))
.use(all =>
all.flatTraverse(_.interactions).assertEquals(Nil) >>
fd.onClose >>
all.traverse_(_.assertClosed)
)
} yield ()
}
}

0 comments on commit 5051b06

Please sign in to comment.