From b3215e017a7aae588099690e41f534fe3d38ac59 Mon Sep 17 00:00:00 2001 From: "Ross A. Baker" Date: Mon, 22 Jun 2020 22:09:31 -0400 Subject: [PATCH] Raise throwables into body if streaming has begun --- .../client/asynchttpclient/AsyncHttpClient.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/async-http-client/src/main/scala/org/http4s/client/asynchttpclient/AsyncHttpClient.scala b/async-http-client/src/main/scala/org/http4s/client/asynchttpclient/AsyncHttpClient.scala index dcada008673..acbd0974ad6 100644 --- a/async-http-client/src/main/scala/org/http4s/client/asynchttpclient/AsyncHttpClient.scala +++ b/async-http-client/src/main/scala/org/http4s/client/asynchttpclient/AsyncHttpClient.scala @@ -22,18 +22,14 @@ import org.asynchttpclient.handler.StreamedAsyncHandler import org.asynchttpclient.request.body.generator.{BodyGenerator, ReactiveStreamsBodyGenerator} import org.asynchttpclient.{Request => AsyncRequest, Response => _, _} import org.http4s.internal.CollectionCompat.CollectionConverters._ -import org.http4s.internal.invokeCallback import org.http4s.internal.bug import org.http4s.internal.threads._ -import org.log4s.getLogger import org.reactivestreams.Publisher import _root_.io.netty.handler.codec.http.cookie.Cookie import org.asynchttpclient.uri.Uri import org.asynchttpclient.cookie.CookieStore object AsyncHttpClient { - private[this] val logger = getLogger - val defaultConfig = new DefaultAsyncHttpClientConfig.Builder() .setMaxConnectionsPerHost(200) .setMaxConnections(400) @@ -89,6 +85,7 @@ object AsyncHttpClient { var response: Response[F] = Response() val dispose = F.delay { state = State.ABORT } val onStreamCalled = Ref.unsafe[F, Boolean](false) + val deferredThrowable = Deferred.unsafe[F, Throwable] override def onStream(publisher: Publisher[HttpResponseBodyPart]): State = { val eff = for { @@ -106,6 +103,7 @@ object AsyncHttpClient { subscriber .stream(bodyDisposal.set(F.unit) >> subscribeF) .flatMap(part => chunk(Chunk.bytes(part.getBodyPartBytes))) + .mergeHaltBoth(Stream.eval(deferredThrowable.get.flatMap(F.raiseError[Byte]))) responseWithBody = response.copy(body = body) @@ -132,7 +130,12 @@ object AsyncHttpClient { } override def onThrowable(throwable: Throwable): Unit = - invokeCallback(logger)(cb(Left(throwable))) + onStreamCalled.get + .ifM( + ifTrue = deferredThrowable.complete(throwable), + ifFalse = invokeCallbackF(cb(Left(throwable)))) + .runAsync(_ => IO.unit) + .unsafeRunSync() override def onCompleted(): Unit = onStreamCalled.get