diff --git a/fs2/src/jsonrpclib/fs2/FS2Channel.scala b/fs2/src/jsonrpclib/fs2/FS2Channel.scala index 52b7ca5..15d2792 100644 --- a/fs2/src/jsonrpclib/fs2/FS2Channel.scala +++ b/fs2/src/jsonrpclib/fs2/FS2Channel.scala @@ -14,6 +14,7 @@ import jsonrpclib.internals.MessageDispatcher import jsonrpclib.internals._ import scala.util.Try +import _root_.fs2.concurrent.SignallingRef trait FS2Channel[F[_]] extends Channel[F] { def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, Unit] = @@ -21,6 +22,9 @@ trait FS2Channel[F[_]] extends Channel[F] { def withEndpoints(endpoint: Endpoint[F], rest: Endpoint[F]*)(implicit F: Monad[F]): Resource[F, Unit] = (endpoint :: rest.toList).traverse_(withEndpoint) + + def open: Resource[F, Unit] + def openStream: Stream[F, Unit] } object FS2Channel { @@ -28,23 +32,25 @@ object FS2Channel { def lspCompliant[F[_]: Concurrent]( byteStream: Stream[F, Byte], byteSink: Pipe[F, Byte, Nothing], - startingEndpoints: List[Endpoint[F]] = List.empty, bufferSize: Int = 512 ): Stream[F, FS2Channel[F]] = internals.LSP.writeSink(byteSink, bufferSize).flatMap { sink => - apply[F](internals.LSP.readStream(byteStream), sink, startingEndpoints) + apply[F](internals.LSP.readStream(byteStream), sink) } def apply[F[_]: Concurrent]( payloadStream: Stream[F, Payload], - payloadSink: Payload => F[Unit], - startingEndpoints: List[Endpoint[F]] = List.empty[Endpoint[F]] + payloadSink: Payload => F[Unit] ): Stream[F, FS2Channel[F]] = { - val endpointsMap = startingEndpoints.map(ep => ep.method -> ep).toMap for { supervisor <- Stream.resource(Supervisor[F]) - ref <- Ref[F].of(State[F](Map.empty, endpointsMap, 0)).toStream - impl = new Impl(payloadSink, ref, supervisor) - _ <- Stream(()).concurrently(payloadStream.evalMap(impl.handleReceivedPayload)) + ref <- Ref[F].of(State[F](Map.empty, Map.empty, 0)).toStream + isOpen <- SignallingRef[F].of(false).toStream + awaitingSink = isOpen.waitUntil(identity) >> payloadSink(_: Payload) + impl = new Impl(awaitingSink, ref, isOpen, supervisor) + _ <- Stream(()).concurrently { + // Gatekeeping the pull until the channel is actually marked as open + payloadStream.pauseWhen(isOpen.map(b => !b)).evalMap(impl.handleReceivedPayload) + } } yield impl } @@ -72,6 +78,7 @@ object FS2Channel { private class Impl[F[_]]( private val sink: Payload => F[Unit], private val state: Ref[F, FS2Channel.State[F]], + private val isOpen: SignallingRef[F, Boolean], supervisor: Supervisor[F] )(implicit F: Concurrent[F]) extends MessageDispatcher[F] @@ -88,6 +95,9 @@ object FS2Channel { def unmountEndpoint(method: String): F[Unit] = state.update(_.removeEndpoint(method)) + def open: Resource[F, Unit] = Resource.make[F, Unit](isOpen.set(true))(_ => isOpen.set(false)) + def openStream: Stream[F, Unit] = Stream.resource(open) + protected def background[A](fa: F[A]): F[Unit] = supervisor.supervise(fa).void protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ??? protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method)) diff --git a/fs2/src/jsonrpclib/fs2/package.scala b/fs2/src/jsonrpclib/fs2/package.scala index 8992858..c77c114 100644 --- a/fs2/src/jsonrpclib/fs2/package.scala +++ b/fs2/src/jsonrpclib/fs2/package.scala @@ -3,6 +3,8 @@ package jsonrpclib import _root_.fs2.Stream import cats.MonadThrow import cats.Monad +import cats.effect.kernel.Resource +import cats.effect.kernel.MonadCancel package object fs2 { @@ -10,6 +12,10 @@ package object fs2 { def toStream: Stream[F, A] = Stream.eval(fa) } + private[jsonrpclib] implicit class ResourceOps[F[_], A](private val fa: Resource[F, A]) extends AnyVal { + def asStream(implicit F: MonadCancel[F, Throwable]): Stream[F, A] = Stream.resource(fa) + } + implicit def catsMonadic[F[_]: MonadThrow]: Monadic[F] = new Monadic[F] { def doFlatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = Monad[F].flatMap(fa)(f) diff --git a/fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala b/fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala index 16a6e83..16db208 100644 --- a/fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala +++ b/fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala @@ -21,7 +21,7 @@ object FS2ChannelSpec extends SimpleIOSuite { } def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit = - test(name)(run.compile.lastOrError) + test(name)(run.compile.lastOrError.timeout(10.second)) testRes("Round trip") { val endpoint: Endpoint[IO] = Endpoint[IO]("inc").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1))) @@ -31,8 +31,10 @@ object FS2ChannelSpec extends SimpleIOSuite { stdin <- Queue.bounded[IO, Payload](10).toStream serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer) clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer) - _ <- Stream.resource(serverSideChannel.withEndpoint(endpoint)) + _ <- serverSideChannel.withEndpoint(endpoint).asStream remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc") + _ <- serverSideChannel.open.asStream + _ <- clientSideChannel.open.asStream result <- remoteFunction(IntWrapper(1)).toStream } yield { expect.same(result, IntWrapper(2)) @@ -44,9 +46,11 @@ object FS2ChannelSpec extends SimpleIOSuite { for { stdout <- Queue.bounded[IO, Payload](10).toStream stdin <- Queue.bounded[IO, Payload](10).toStream - _ <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer) + serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer) clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer) remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc") + _ <- serverSideChannel.open.asStream + _ <- clientSideChannel.open.asStream result <- remoteFunction(IntWrapper(1)).attempt.toStream } yield { expect.same(result, Left(ErrorPayload(-32601, "Method inc not found", None))) @@ -65,8 +69,10 @@ object FS2ChannelSpec extends SimpleIOSuite { stdin <- Queue.bounded[IO, Payload](10).toStream serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), payload => stdout.offer(payload)) clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), payload => stdin.offer(payload)) - _ <- Stream.resource(serverSideChannel.withEndpoint(endpoint)) + _ <- serverSideChannel.withEndpoint(endpoint).asStream remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc") + _ <- serverSideChannel.open.asStream + _ <- clientSideChannel.open.asStream timedResults <- (1 to 10).toList.map(IntWrapper(_)).parTraverse(remoteFunction).timed.toStream } yield { val (time, results) = timedResults