Skip to content

Commit

Permalink
Make sure to call handleServerError when the server fails to decode t…
Browse files Browse the repository at this point in the history
…he request entity
  • Loading branch information
julienrf committed Jun 20, 2023
1 parent a09e9ff commit 6a50949
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ trait EndpointsWithCustomErrors extends algebra.EndpointsWithCustomErrors with M
case NonFatal(t) => handleServerError(http4sRequest, t)
}
case Left(errorResponse) => errorResponse.pure[Effect]
})
}.recoverWith { case NonFatal(t) => handleServerError(http4sRequest, t) })
} catch {
case NonFatal(t) => Some(handleServerError(http4sRequest, t))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ trait MuxEndpoints extends algebra.MuxEndpoints with EndpointsWithCustomErrors {
case NonFatal(t) => handleServerError(http4sRequest, t)
}
case Left(errorResponse) => errorResponse.pure[Effect]
})
}.recoverWith { case NonFatal(t) => handleServerError(http4sRequest, t) })
} catch {
case NonFatal(t) => Some(handleServerError(http4sRequest, t))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package endpoints4s.http4s.server

import java.net.ServerSocket

import endpoints4s.{Invalid, Valid}
import endpoints4s.algebra.server.{
ChunkedJsonEntitiesTestSuite,
DecodedUrl,
EndpointsTestSuite
}
import org.http4s.server.Router
import org.http4s.{HttpRoutes, Uri}
import org.http4s.HttpRoutes

import akka.stream.scaladsl.Source

Expand All @@ -24,7 +22,8 @@ import org.http4s.blaze.server.BlazeServerBuilder
import cats.effect.IO

class ChunkedEntitiesServerInterpreterTest
extends EndpointsTestSuite[ChunkedEntitiesEndpointsTestApi]
extends Http4sServerTest[ChunkedEntitiesEndpointsTestApi]
with EndpointsTestSuite[ChunkedEntitiesEndpointsTestApi]
with ChunkedJsonEntitiesTestSuite[ChunkedEntitiesEndpointsTestApi] {

val serverApi = new ChunkedEntitiesEndpointsTestApi()
Expand Down Expand Up @@ -85,51 +84,4 @@ class ChunkedEntitiesServerInterpreterTest
()
}

override def serveEndpoint[Req, Resp](endpoint: serverApi.Endpoint[Req, Resp], response: => Resp)(
runTests: Int => Unit
): Unit = {
val port = {
val socket = new ServerSocket(0)
try socket.getLocalPort
finally if (socket != null) socket.close()
}
val service = HttpRoutes.of[IO](endpoint.implementedBy(_ => response))
val httpApp = Router("/" -> service).orNotFound
val server =
BlazeServerBuilder[IO]
.bindHttp(port, "localhost")
.withHttpApp(httpApp)
server.resource.use(_ => IO(runTests(port))).unsafeRunSync()
()
}

override def serveIdentityEndpoint[Resp](endpoint: serverApi.Endpoint[Resp, Resp])(
runTests: Int => Unit
): Unit = {
val port = {
val socket = new ServerSocket(0)
try socket.getLocalPort
finally if (socket != null) socket.close()
}
val service = HttpRoutes.of[IO](endpoint.implementedBy(identity[Resp]))
val httpApp = Router("/" -> service).orNotFound
val server =
BlazeServerBuilder[IO]
.bindHttp(port, "localhost")
.withHttpApp(httpApp)
server.resource.use(_ => IO(runTests(port))).unsafeRunSync()
()
}

def decodeUrl[A](url: serverApi.Url[A])(rawValue: String): DecodedUrl[A] = {
val uri =
Uri.fromString(rawValue).getOrElse(sys.error(s"Illegal URI: $rawValue"))

url.decodeUrl(uri) match {
case None => DecodedUrl.NotMatched
case Some(Invalid(errors)) => DecodedUrl.Malformed(errors)
case Some(Valid(a)) => DecodedUrl.Matched(a)
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package endpoints4s.http4s.server

import akka.http.scaladsl.model.HttpMethods.PUT
import akka.http.scaladsl.model.HttpRequest
import cats.effect.IO
import endpoints4s.algebra
import org.http4s

import java.util.UUID

class ErrorHandlingTest extends Http4sServerTest[Endpoints[IO]] {

val serverApi = new Endpoints[IO]
with algebra.EndpointsTestApi {

private val magicValue = UUID.randomUUID().toString

// Pretend that we could not decode the incoming request
override def emptyRequest: http4s.Request[IO] => IO[Either[http4s.Response[IO], Unit]] =
_ => IO.raiseError(new RuntimeException(magicValue))

// Transform the error when the request could not be decoded
override def handleServerError(request: http4s.Request[IO], throwable: Throwable): IO[http4s.Response[IO]] = {
if (throwable.getMessage == magicValue)
IO.pure(http4s.Response(http4s.Status.PaymentRequired))
else
super.handleServerError(request, throwable)
}

}

"Server" should {
"call the hook handleServerError when a request fails to match the endpoints" in {
serveEndpoint(serverApi.putEndpoint, ()) { port =>
val request =
HttpRequest(method = PUT, uri = s"http://localhost:$port/user/foo123")
whenReady(send(request)) { case (response, _) =>
assert(response.status.intValue() == 402)
}
()
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package endpoints4s.http4s.server

import java.net.ServerSocket

import cats.effect.IO
import endpoints4s.{Invalid, Valid, algebra}
import endpoints4s.algebra.server.DecodedUrl
import org.http4s.Uri
import org.http4s.server.Router
import org.http4s.HttpRoutes
import org.http4s.blaze.server.BlazeServerBuilder

import cats.effect.unsafe.implicits.global

/**
* Base class for http4s server interpreters tests.
*/
trait Http4sServerTest[T <: Endpoints[IO]]
extends algebra.server.ServerTestBase[T] {

def decodeUrl[A](url: serverApi.Url[A])(rawValue: String): DecodedUrl[A] = {
val uri =
Uri.fromString(rawValue).getOrElse(sys.error(s"Illegal URI: $rawValue"))

url.decodeUrl(uri) match {
case None => DecodedUrl.NotMatched
case Some(Invalid(errors)) => DecodedUrl.Malformed(errors)
case Some(Valid(a)) => DecodedUrl.Matched(a)
}
}

private def serveGeneralEndpoint[Req, Resp](
endpoint: serverApi.Endpoint[Req, Resp],
request2response: Req => Resp
)(runTests: Int => Unit): Unit = {
val port = {
val socket = new ServerSocket(0)
try socket.getLocalPort
finally if (socket != null) socket.close()
}

val service = HttpRoutes.of[IO](endpoint.implementedBy(request2response))
val httpApp = Router("/" -> service).orNotFound
val server =
BlazeServerBuilder[IO]
.bindHttp(port, "localhost")
.withHttpApp(httpApp)
server.resource.use(_ => IO(runTests(port))).unsafeRunSync()
}

def serveEndpoint[Req, Resp](
endpoint: serverApi.Endpoint[Req, Resp],
response: => Resp
)(runTests: Int => Unit): Unit =
serveGeneralEndpoint(endpoint, (_: Any) => response)(runTests)

def serveIdentityEndpoint[Resp](
endpoint: serverApi.Endpoint[Resp, Resp]
)(runTests: Int => Unit): Unit =
serveGeneralEndpoint(endpoint, identity[Resp])(runTests)
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
package endpoints4s.http4s.server

import java.net.ServerSocket

import cats.effect.IO
import endpoints4s.{Invalid, Valid}
import endpoints4s.algebra.server.{
AssetsTestSuite,
BasicAuthenticationTestSuite,
DecodedUrl,
EndpointsTestSuite,
JsonEntitiesFromSchemasTestSuite,
SumTypedEntitiesTestSuite,
TextEntitiesTestSuite
}
import org.http4s.server.Router
import org.http4s.{HttpRoutes, Uri}
import org.http4s.blaze.server.BlazeServerBuilder

import endpoints4s.algebra.server.AssetsTestSuite
import cats.effect.unsafe.implicits.global

class ServerInterpreterTest
extends EndpointsTestSuite[EndpointsTestApi]
extends Http4sServerTest[EndpointsTestApi]
with EndpointsTestSuite[EndpointsTestApi]
with BasicAuthenticationTestSuite[EndpointsTestApi]
with JsonEntitiesFromSchemasTestSuite[EndpointsTestApi]
with TextEntitiesTestSuite[EndpointsTestApi]
Expand All @@ -30,36 +22,6 @@ class ServerInterpreterTest

val serverApi = new EndpointsTestApi

def decodeUrl[A](url: serverApi.Url[A])(rawValue: String): DecodedUrl[A] = {
val uri =
Uri.fromString(rawValue).getOrElse(sys.error(s"Illegal URI: $rawValue"))

url.decodeUrl(uri) match {
case None => DecodedUrl.NotMatched
case Some(Invalid(errors)) => DecodedUrl.Malformed(errors)
case Some(Valid(a)) => DecodedUrl.Matched(a)
}
}

private def serveGeneralEndpoint[Req, Resp](
endpoint: serverApi.Endpoint[Req, Resp],
request2response: Req => Resp
)(runTests: Int => Unit): Unit = {
val port = {
val socket = new ServerSocket(0)
try socket.getLocalPort
finally if (socket != null) socket.close()
}

val service = HttpRoutes.of[IO](endpoint.implementedBy(request2response))
val httpApp = Router("/" -> service).orNotFound
val server =
BlazeServerBuilder[IO]
.bindHttp(port, "localhost")
.withHttpApp(httpApp)
server.resource.use(_ => IO(runTests(port))).unsafeRunSync()
}

def assetsResources(pathPrefix: Option[String]) =
serverApi.assetsResources(pathPrefix)

Expand All @@ -70,16 +32,6 @@ class ServerInterpreterTest
],
response: => serverApi.AssetResponse
)(runTests: Int => Unit): Unit =
serveGeneralEndpoint(endpoint, (_: Any) => response)(runTests)

def serveEndpoint[Req, Resp](
endpoint: serverApi.Endpoint[Req, Resp],
response: => Resp
)(runTests: Int => Unit): Unit =
serveGeneralEndpoint(endpoint, (_: Any) => response)(runTests)
serveEndpoint(endpoint, response)(runTests)

def serveIdentityEndpoint[Resp](
endpoint: serverApi.Endpoint[Resp, Resp]
)(runTests: Int => Unit): Unit =
serveGeneralEndpoint(endpoint, identity[Resp])(runTests)
}

0 comments on commit 6a50949

Please sign in to comment.