Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket client - use reactive streams #408

Merged
merged 13 commits into from
Feb 4, 2023
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ lazy val client = project
"org.http4s.netty.client.Http4sChannelPoolMap.this"),
ProblemFilters.exclude[MissingClassProblem]("org.http4s.netty.client.Http4sHandler$"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.http4s.netty.client.Http4sChannelPoolMap.resource")
"org.http4s.netty.client.Http4sChannelPoolMap.resource"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.http4s.netty.client.Http4sWebsocketHandler#Conn.this")
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import cats.effect.kernel.Deferred
import cats.effect.std.Dispatcher
import cats.effect.std.Queue
import cats.syntax.all._
import com.typesafe.netty.HandlerPublisher
import io.netty.buffer.Unpooled
import io.netty.channel._
import io.netty.handler.codec.http.websocketx._
Expand All @@ -31,6 +32,8 @@ import org.http4s.client.websocket.WSFrame
import org.http4s.netty.NettyModelConversion
import org.http4s.netty.client.Http4sWebsocketHandler.fromWSFrame
import org.http4s.netty.client.Http4sWebsocketHandler.toWSFrame
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import scodec.bits.ByteVector

import scala.concurrent.ExecutionContext
Expand All @@ -42,76 +45,110 @@ private[client] class Http4sWebsocketHandler[F[_]](
dispatcher: Dispatcher[F],
callback: (Either[Throwable, WSConnection[F]]) => Unit
)(implicit F: Async[F])
extends SimpleChannelInboundHandler[WebSocketFrame] {
extends SimpleUserEventChannelHandler[
WebSocketClientProtocolHandler.ClientHandshakeStateEvent] {
private val logger = org.log4s.getLogger
private var callbackIssued = false

override def channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame): Unit = {
logger.trace("got> " + msg.getClass)
void(msg match {
case frame: CloseWebSocketFrame =>
val op =
queue.offer(Right(toWSFrame(frame))) >> closed.complete(()) >> F.delay(ctx.close())
dispatcher.unsafeRunSync(op)
case frame: WebSocketFrame =>
val op = queue.offer(Right(toWSFrame(frame)))
dispatcher.unsafeRunSync(op)
})
override def channelActive(ctx: ChannelHandlerContext): Unit = void {
super.channelActive(ctx)
if (!ctx.channel().config().isAutoRead) {
ctx.read()
}
}

override def channelActive(ctx: ChannelHandlerContext): Unit =
logger.trace("channel active")
private def safeRunCallback(result: Either[Throwable, WSConnection[F]]): Unit =
if (!callbackIssued) {
callback(result)
callbackIssued = true
}

@SuppressWarnings(Array("deprecated"))
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = void {
logger.error(cause)("something failed")
callback(Left(cause))
safeRunCallback(Left(cause))
dispatcher.unsafeRunAndForget(
queue.offer(Left(cause)) >> closed.complete(()) >> F.delay(ctx.close()))
}

override def userEventTriggered(ctx: ChannelHandlerContext, evt: Any): Unit =
override def eventReceived(
ctx: ChannelHandlerContext,
evt: WebSocketClientProtocolHandler.ClientHandshakeStateEvent): Unit =
evt match {
case WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_ISSUED =>
logger.trace("Handshake issued")
case WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE =>
logger.trace("Handshake complete")
ctx.read()
callback(new Conn(handshaker.actualSubprotocol(), ctx, queue, closed).asRight[Throwable])
case _ =>
super.userEventTriggered(ctx, evt)

def complete =
closed.complete(()).void >> F.delay(ctx.close()).void

val publisher = new HandlerPublisher(ctx.executor(), classOf[WebSocketFrame]) {
override def requestDemand(): Unit = void {
if (!ctx.channel().config().isAutoRead) {
ctx.read()
}
}

override def cancelled(): Unit =
dispatcher.unsafeRunAndForget(complete)
}
ctx.pipeline().addBefore(ctx.name(), "stream-publisher", publisher)

publisher.subscribe(new Subscriber[WebSocketFrame] {

def isCloseFrame(ws: WSFrame) = ws.isInstanceOf[WSFrame.Close]

override def onSubscribe(s: Subscription): Unit =
s.request(Long.MaxValue)

override def onNext(t: WebSocketFrame): Unit = void {
val converted = toWSFrame(t)
val offer = queue.offer(Right(converted))
val op = if (isCloseFrame(converted)) {
complete >> offer
} else {
offer
}
dispatcher.unsafeRunSync(op)
}

override def onError(t: Throwable): Unit = void {
dispatcher.unsafeRunSync(complete >> queue.offer(Left(t)))
}

override def onComplete(): Unit = void {
dispatcher.unsafeRunSync(complete)
}
})
safeRunCallback(new Conn(handshaker.actualSubprotocol(), ctx).asRight[Throwable])

case WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT =>
safeRunCallback(
Left(new IllegalStateException("Handshake timeout"))
)
}

class Conn(
private class Conn(
sub: String,
ctx: ChannelHandlerContext,
queue: Queue[F, Either[Throwable, WSFrame]],
closed: Deferred[F, Unit])
extends WSConnection[F] {
ctx: ChannelHandlerContext
) extends WSConnection[F] {
private val runInNetty = F.evalOnK(ExecutionContext.fromExecutor(ctx.executor()))

override def send(wsf: WSFrame): F[Unit] = {
logger.trace(s"writing $wsf")
runInNetty(F.delay {
if (ctx.channel().isOpen && ctx.channel().isWritable) {
ctx.writeAndFlush(fromWSFrame(wsf))
()
}
})
}
override def send(wsf: WSFrame): F[Unit] =
sendMany(List(wsf))

override def sendMany[G[_], A <: WSFrame](wsfs: G[A])(implicit G: Foldable[G]): F[Unit] =
runInNetty(F.delay {
if (ctx.channel().isOpen && ctx.channel().isWritable) {
val list = wsfs.toList
list.foreach(wsf => ctx.write(fromWSFrame(wsf)))
ctx.flush()
}
()
})
if (ctx.channel().isActive) {
wsfs.traverse_(wsf => runInNetty(F.delay(ctx.writeAndFlush(fromWSFrame(wsf))).liftToF))
} else {
closed.complete(()).void
}

override def receive: F[Option[WSFrame]] = closed.tryGet.flatMap {
case Some(_) =>
logger.trace("closing")
ctx.close()
none[WSFrame].pure[F]
case None =>
queue.take.rethrow.map(_.some)
Expand All @@ -120,7 +157,7 @@ private[client] class Http4sWebsocketHandler[F[_]](
override def subprotocol: Option[String] = Option(sub)

def close: F[Unit] =
closed.complete(()).void >> F.delay(ctx.close).liftToF
F.unit
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class NettyWSClientBuilder[F[_]](
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Int.box(5 * 1000))
.option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
.option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
.option(ChannelOption.AUTO_READ, java.lang.Boolean.FALSE)
nettyChannelOptions.foldLeft(bootstrap) { case (boot, (opt, value)) =>
boot.option(opt, value)
}
Expand All @@ -127,7 +128,7 @@ class NettyWSClientBuilder[F[_]](

def resource: Resource[F, WSClient[F]] = for {
bs <- createBootstrap
disp <- Dispatcher.parallel[F]
disp <- Dispatcher.parallel[F](await = true)
} yield mkWSClient(bs, disp)

private def mkWSClient(bs: Bootstrap, dispatcher: Dispatcher[F]) =
Expand Down Expand Up @@ -185,8 +186,13 @@ class NettyWSClientBuilder[F[_]](
pipeline.addLast("http-aggregate", new HttpObjectAggregator(8192))
pipeline.addLast("protocol-handler", websocketinit)
pipeline.addLast(
"aggregate2",
"websocket-aggregate",
new WebSocketFrameAggregator(config.maxFramePayloadLength()))
if (idleTimeout.isFinite && idleTimeout.length > 0)
pipeline
.addLast(
"timeout",
new IdleStateHandler(0, 0, idleTimeout.length, idleTimeout.unit))
pipeline.addLast(
"websocket",
new Http4sWebsocketHandler[F](
Expand All @@ -196,11 +202,6 @@ class NettyWSClientBuilder[F[_]](
dispatcher,
callback)
)
if (idleTimeout.isFinite && idleTimeout.length > 0)
pipeline
.addLast(
"timeout",
new IdleStateHandler(0, 0, idleTimeout.length, idleTimeout.unit))
}
})
F.delay(bs.connect(socketAddress).sync()).as(None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class EmberWebsocketTest extends IOSuite {
netty <- EmberServerBuilder
.default[IO]
.withHttpWebSocketApp(echoRoutes(_).orNotFound)
.withPort(port"19999")
.withShutdownTimeout(1.second)
.withPort(port"0")
.withShutdownTimeout(100.milli)
.build
.map(s => httpToWsUri(s.baseUri))
} yield netty,
Expand Down Expand Up @@ -68,8 +68,7 @@ class EmberWebsocketTest extends IOSuite {
WSFrame.Binary(ByteVector(3, 99, 12)),
WSFrame.Text("foo"),
WSFrame.Close(1000, "")
)
)
))
}

test("send and receive frames in high-level mode") {
Expand Down