Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[websocket] Add roles and path params to weBeforeUpgrade #2173

Merged
merged 1 commit into from Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -105,4 +105,4 @@ object DefaultTasks {

}

private fun Endpoint.hasPathParams() = this.path.contains("{") || this.path.contains("<")
internal fun Endpoint.hasPathParams() = this.path.contains("{") || this.path.contains("<")
Expand Up @@ -94,7 +94,9 @@ class JavalinServletContext(
handlerType = parsedEndpoint.endpoint.method
if (matchedPath != parsedEndpoint.endpoint.path) { // if the path has changed, we have to extract path params
matchedPath = parsedEndpoint.endpoint.path
pathParamMap = parsedEndpoint.extractPathParams(requestUri)
if (parsedEndpoint.endpoint.hasPathParams()) {
pathParamMap = parsedEndpoint.extractPathParams(requestUri)
}
}
if (handlerType != AFTER) {
endpointHandlerPath = parsedEndpoint.endpoint.path
Expand Down
Expand Up @@ -60,6 +60,7 @@ class JavalinJettyServlet(val cfg: JavalinConfig) : JettyWebSocketServlet() {
matchedPath = entry.path,
pathParamMap = entry.extractPathParams(requestUri),
)
upgradeContext.setRouteRoles(entry.roles) // set roles for the matched handler
req.setAttribute(upgradeContextKey, upgradeContext)
setWsProtocolHeader(req, res)
// add before handlers
Expand Down
18 changes: 18 additions & 0 deletions javalin/src/test/java/io/javalin/TestBeforeAfterMatched.kt
Expand Up @@ -8,6 +8,7 @@ import io.javalin.http.servlet.DefaultTasks.BEFORE
import io.javalin.http.servlet.DefaultTasks.ERROR
import io.javalin.http.servlet.DefaultTasks.HTTP
import io.javalin.http.staticfiles.Location
import io.javalin.security.RouteRole
import io.javalin.testing.TestUtil
import kong.unirest.HttpResponse
import org.assertj.core.api.Assertions.assertThat
Expand Down Expand Up @@ -338,4 +339,21 @@ class TestBeforeAfterMatched {
}) { _, http ->
assertThat(http.getBody("/p")).isEqualTo("{before=p}/{endpoint=p}")
}

private enum class Role : RouteRole { A }

@Test
fun `routeRoles are available in beforeMatched`() = TestUtil.test { app, http ->
app.beforeMatched { it.result(it.routeRoles().toString()) }
app.get("/test", {}, Role.A)
assertThat(http.getBody("/test")).isEqualTo("[A]")
}

@Test
fun `routeRoles are available in afterMatched`() = TestUtil.test { app, http ->
app.get("/test", {}, Role.A)
app.afterMatched { it.result(it.routeRoles().toString()) }
assertThat(http.getBody("/test")).isEqualTo("[A]")
}

}
150 changes: 60 additions & 90 deletions javalin/src/test/java/io/javalin/TestWebSocket.kt
Expand Up @@ -14,6 +14,7 @@ import io.javalin.http.HttpStatus
import io.javalin.http.UnauthorizedResponse
import io.javalin.json.toJsonString
import io.javalin.plugin.bundled.DevLoggingPlugin
import io.javalin.security.RouteRole
import io.javalin.testing.SerializableObject
import io.javalin.testing.TestUtil
import io.javalin.testing.TypedException
Expand Down Expand Up @@ -65,26 +66,6 @@ class TestWebSocket {
cfg?.invoke(it)
}

private fun accessManagedJavalin(): Javalin = Javalin.create().apply {
this.wsBeforeUpgrade { ctx ->
this.logger().log.add("handling upgrade request ...")
when {
ctx.queryParam("exception") == "true" -> throw UnauthorizedResponse()
ctx.queryParam("allowed") == "true" -> {
this.logger().log.add("upgrade request valid!")
return@wsBeforeUpgrade
}
else -> {
this.logger().log.add("upgrade request invalid!")
ctx.skipRemainingHandlers()
}
}
}
this.ws("/*") { ws ->
ws.onConnect { this.logger().log.add("connected with upgrade request") }
}
}

@Test
fun `each connection receives a unique id`() {
val logger = TestLogger()
Expand Down Expand Up @@ -368,6 +349,26 @@ class TestWebSocket {
}
}

private fun accessManagedJavalin(): Javalin = Javalin.create().apply {
this.wsBeforeUpgrade { ctx ->
this.logger().log.add("handling upgrade request ...")
when {
ctx.queryParam("exception") == "true" -> throw UnauthorizedResponse()
ctx.queryParam("allowed") == "true" -> {
this.logger().log.add("upgrade request valid!")
return@wsBeforeUpgrade
}
else -> {
this.logger().log.add("upgrade request invalid!")
ctx.skipRemainingHandlers()
}
}
}
this.ws("/*") { ws ->
ws.onConnect { this.logger().log.add("connected with upgrade request") }
}
}

@Test
fun `AccessManager rejects invalid request`() = TestUtil.test(accessManagedJavalin()) { app, _ ->
TestClient(app, "/").connectAndDisconnect()
Expand Down Expand Up @@ -490,22 +491,17 @@ class TestWebSocket {

@Test
fun `unmapped exceptions are caught by default handler`() = TestUtil.test { app, _ ->
val exception = Exception("Error message")

app.ws("/ws") { it.onConnect { throw exception } }

app.ws("/ws") { it.onConnect { throw Exception("EX") } }
val client = object : TestClient(app, "/ws") {
override fun onClose(status: Int, message: String, byRemote: Boolean) {
this.app.logger().log.add("Status code: $status")
this.app.logger().log.add("Reason: $message")
}
}

doBlocking({ client.connect() }, { !client.isClosed }) // hmmm

assertThat(client.app.logger().log).containsExactly(
"Status code: ${StatusCode.SERVER_ERROR}",
"Reason: ${exception.message}"
"Reason: EX"
)
}

Expand Down Expand Up @@ -593,7 +589,6 @@ class TestWebSocket {
client.disconnectBlocking()
}


@Test
@Timeout(value = 2, unit = TimeUnit.SECONDS)
fun `websocket disableAutomaticPings() works`() = TestUtil.test(pingingApp()) { app, _ ->
Expand All @@ -608,118 +603,93 @@ class TestWebSocket {
}

@Test
fun `wsBeforeUpgrade and wsAfterUpgrade are invoked`() = TestUtil.test { app, _ ->
app.wsBeforeUpgrade {
app.logger().log.add("before")
}

app.wsAfterUpgrade {
app.logger().log.add("after")
}

fun `wsBeforeUpgrade and wsAfterUpgrade are invoked`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { app.logger().log.add("before") }
app.wsAfterUpgrade { app.logger().log.add("after") }
app.ws("/ws") {}
Unirest.get("http://localhost:${app.port()}/ws")
.header(Header.SEC_WEBSOCKET_KEY, "not-null")
.asString()
http.wsUpgradeRequest("/ws")
assertThat(app.logger().log).containsExactly("before", "after")
}

@Test
fun `wsBeforeUpgrade can modify the upgrade request but wsAfterUpgrade can not`() = TestUtil.test { app, _ ->
app.wsBeforeUpgrade { ctx ->
ctx.header("X-Before", "demo")
}
app.wsAfterUpgrade { ctx ->
ctx.header("X-After", "after")
}
fun `wsBeforeUpgrade can modify the upgrade request but wsAfterUpgrade can not`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { it.header("X-Before", "demo") }
app.wsAfterUpgrade { it.header("X-After", "after") }

app.ws("/ws") {}
val response = Unirest.get("http://localhost:${app.port()}/ws")
.header(Header.SEC_WEBSOCKET_KEY, "not-null")
.asString()
val response = http.wsUpgradeRequest("/ws")
assertThat(response.headers.getFirst("X-Before")).isEqualTo("demo")
assertThat(response.headers.containsKey("X-After")).isFalse()
}

@Test
fun `wsBeforeUpgrade can stop an upgrade request in progress`() = TestUtil.test { app, _ ->
app.wsBeforeUpgrade { _ ->
throw IllegalStateException("denied")
}

fun `wsBeforeUpgrade can stop an upgrade request in progress`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { _ -> throw IllegalStateException("denied") }
app.ws("/ws") { ws ->
ws.onConnect {
app.logger().log.add("connected")
}
ws.onConnect { app.logger().log.add("connected") }
}

val response = Unirest.get("http://localhost:${app.port()}/ws")
.header(Header.SEC_WEBSOCKET_KEY, "not-null")
.asString()
val response = http.wsUpgradeRequest("/ws")
assertThat(response.status).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.code)
assertThat(app.logger().log).isEmpty()
}

@Test
fun `wsBeforeUpgrade exception pattern can be combined with a custom exception handler`() = TestUtil.test { app, _ ->
app.wsBeforeUpgrade {
throw IllegalStateException("denied")
}

fun `wsBeforeUpgrade exception pattern can be combined with a custom exception handler`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { throw IllegalStateException("denied") }
app.exception(IllegalStateException::class.java) { _, ctx ->
app.logger().log.add("exception handled")
ctx.status(HttpStatus.FORBIDDEN)
}

app.ws("/ws") {}

val response = Unirest.get("http://localhost:${app.port()}/ws")
.header(Header.SEC_WEBSOCKET_KEY, "not-null")
.asString()
val response = http.wsUpgradeRequest("/ws")
assertThat(app.logger().log).containsExactly("exception handled")
assertThat(response.status).isEqualTo(HttpStatus.FORBIDDEN.code)
}

@Test
fun `wsBeforeUpgrade does work with skipRemainingHandlers`() = TestUtil.test { app, _ ->
fun `wsBeforeUpgrade does work with skipRemainingHandlers`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { it.status(HttpStatus.FORBIDDEN).skipRemainingHandlers() }

app.ws("/ws") { ws ->
ws.onConnect {
app.logger().log.add("connected")
}
ws.onConnect { app.logger().log.add("connected") }
}

val client = TestClient(app, "/ws")
client.connectAndDisconnect()
val response = Unirest.get("http://localhost:${app.port()}/ws")
.header(Header.SEC_WEBSOCKET_KEY, "not-null")
.asString()
val response = http.wsUpgradeRequest("/ws")
assertThat(response.status).isEqualTo(HttpStatus.FORBIDDEN.code)
assertThat(app.logger().log).isEmpty()
}

@Test
fun `wsBeforeUpgrade in full lifecycle`() = TestUtil.test { app, _ ->
app.wsBeforeUpgrade {
app.logger().log.add("before-upgrade")
}

app.wsAfterUpgrade {
app.logger().log.add("after-upgrade")
}

app.wsBeforeUpgrade { app.logger().log.add("before-upgrade") }
app.wsAfterUpgrade { app.logger().log.add("after-upgrade") }
app.ws("/ws") { ws ->
ws.onConnect { app.logger().log.add("connect") }
ws.onMessage { app.logger().log.add("msg") }
ws.onClose { app.logger().log.add("close") }
}

val client = TestClient(app, "/ws")
client.connectSendAndDisconnect("test-message")
assertThat(app.logger().log).containsExactly("before-upgrade", "after-upgrade", "connect", "msg", "close")
}

private enum class Role : RouteRole { A }

@Test
fun `routeRoles are available in wsBeforeUpgrade`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { app.logger().log.add(it.routeRoles().toString()) }
app.ws("/ws", {}, Role.A)
http.wsUpgradeRequest("/ws")
assertThat(app.logger().log).containsExactly("[A]")
}

@Test
fun `pathParams are available in wsBeforeUpgrade`() = TestUtil.test { app, http ->
app.wsBeforeUpgrade { app.logger().log.add(it.pathParam("param")) }
app.ws("/ws/{param}") {}
http.wsUpgradeRequest("/ws/123")
assertThat(app.logger().log).containsExactly("123")
}

// ********************************************************************************************
// Helpers
Expand Down
3 changes: 2 additions & 1 deletion javalin/src/test/java/io/javalin/testing/HttpUtil.kt
Expand Up @@ -7,6 +7,7 @@
package io.javalin.testing

import io.javalin.http.ContentType
import io.javalin.http.Header
import io.javalin.http.HttpStatus
import kong.unirest.HttpMethod
import kong.unirest.HttpResponse
Expand All @@ -32,7 +33,7 @@ class HttpUtil(port: Int) {
fun htmlGet(path: String) = Unirest.get(origin + path).header("Accept", ContentType.HTML).asString()
fun jsonGet(path: String) = Unirest.get(origin + path).header("Accept", ContentType.JSON).asString()
fun sse(path: String) = Unirest.get(origin + path).header("Accept", "text/event-stream").header("Connection", "keep-alive").header("Cache-Control", "no-cache").asStringAsync()

fun wsUpgradeRequest(path: String) =Unirest.get(origin + path).header(Header.SEC_WEBSOCKET_KEY, "not-null").asString()
}

fun HttpResponse<*>.httpCode(): HttpStatus =
Expand Down