From 07967424598eb1abf8b0659fa9fbc52a6e224b5d Mon Sep 17 00:00:00 2001 From: borgespires Date: Mon, 27 Jan 2020 00:55:23 +0000 Subject: [PATCH] Fixes #2135, stop sending duplicate Sec-WebSocket-Protocol header in response when a subprotocol is defined --- .../gateway/AkkaHttpServiceGateway.scala | 14 ++++++- .../gateway/AkkaHttpServiceGatewaySpec.scala | 41 +++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/dev/service-registry/service-locator/src/main/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGateway.scala b/dev/service-registry/service-locator/src/main/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGateway.scala index 5f29322ff1..dabd988a5c 100644 --- a/dev/service-registry/service-locator/src/main/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGateway.scala +++ b/dev/service-registry/service-locator/src/main/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGateway.scala @@ -113,7 +113,7 @@ class AkkaHttpServiceGateway( Flow.fromSinkAndSource(Sink.fromSubscriber(subscriber), Source.fromPublisher(publisher)), chosenSubprotocol ) - webSocketResponse.withHeaders(webSocketResponse.headers ++ filterHeaders(response.headers)) + webSocketResponse.withHeaders(concatHeaders(response, webSocketResponse)) case InvalidUpgradeResponse(response, cause) => log.debug("WebSocket upgrade response was invalid: {}", cause) @@ -191,6 +191,18 @@ class AkkaHttpServiceGateway( headers.filterNot(header => HeadersToFilter(header.lowercaseName())) } + private def concatHeaders(serviceResponse: HttpResponse, webSocketClientResponse: HttpResponse) = { + def addHeader(res: immutable.Seq[HttpHeader], header: HttpHeader) = { + if (res.exists(other => other.lowercaseName().equals(header.lowercaseName()))) { + res + } else { + res :+ header + } + } + + filterHeaders(serviceResponse.headers).foldLeft(webSocketClientResponse.headers)(addHeader) + } + private val bindingFuture = Http().bindAndHandle(handler, config.host, config.port) coordinatedShutdown.addTask(CoordinatedShutdown.PhaseServiceUnbind, "unbind-akka-http-service-gateway") { () => diff --git a/dev/service-registry/service-locator/src/test/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGatewaySpec.scala b/dev/service-registry/service-locator/src/test/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGatewaySpec.scala index c44315d27d..abecb528cb 100644 --- a/dev/service-registry/service-locator/src/test/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGatewaySpec.scala +++ b/dev/service-registry/service-locator/src/test/scala/com/lightbend/lagom/gateway/AkkaHttpServiceGatewaySpec.scala @@ -14,11 +14,10 @@ import akka.http.scaladsl.model.ws.Message import akka.http.scaladsl.model.ws.TextMessage import akka.http.scaladsl.model.ws.UpgradeToWebSocket import akka.http.scaladsl.model.ws.WebSocketRequest -import akka.http.scaladsl.model.HttpEntity -import akka.http.scaladsl.model.HttpRequest -import akka.http.scaladsl.model.HttpResponse +import akka.http.scaladsl.model._ import akka.stream.ActorMaterializer import akka.stream.scaladsl.Flow +import akka.stream.scaladsl.Keep import akka.stream.scaladsl.Sink import akka.stream.scaladsl.Source import akka.util.ByteString @@ -54,7 +53,17 @@ class AkkaHttpServiceGatewaySpec extends WordSpec with Matchers with BeforeAndAf case req if req.uri.path.toString() == "/echo-headers" => HttpResponse(entity = HttpEntity(req.headers.map(h => s"${h.name()}: ${h.value}").mkString("\n"))) case stream if stream.uri.path.toString() == "/stream" => - stream.header[UpgradeToWebSocket].get.handleMessages(Flow[Message]) + stream + .header[UpgradeToWebSocket] + .get + .handleMessages( + Flow[Message], + stream.headers + .find(_.lowercaseName() == "sec-websocket-protocol") + .map(_.value) + .map(_.split(",")) + .map(_.head) + ) }, "localhost", port = 0 @@ -118,6 +127,30 @@ class AkkaHttpServiceGatewaySpec extends WordSpec with Matchers with BeforeAndAf (result should contain).inOrderOnly("Hello", "world") } + "serve websocket requests with the correct response" in { + val flow = http.webSocketClientFlow( + WebSocketRequest( + s"$gatewayWsUri/stream", + collection.immutable.Seq.empty[HttpHeader], + collection.immutable.Seq("echo") + ) + ) + val result = Await.result( + Source + .single(TextMessage("hello world!")) + .viaMat(flow)(Keep.right) + .to(Sink.ignore) + .run(), + 10.seconds + ) + + result.response.status should equal(StatusCodes.SwitchingProtocols) + result.response.headers.count(_.lowercaseName() == "sec-websocket-protocol") should equal(1) + result.response.headers.find(_.lowercaseName() == "sec-websocket-protocol").map(_.value()).get should equal( + "echo" + ) + } + "serve not found when no ACL matches" in { val response = Await.result(http.singleRequest(HttpRequest(uri = s"$gatewayUri/notfound")), 10.seconds) response.status.intValue() should ===(404)