Skip to content

Commit

Permalink
KTOR-6909 Add SocketTimeout for Test Engine (#4021)
Browse files Browse the repository at this point in the history
  • Loading branch information
marychatte authored and bjhham committed May 7, 2024
1 parent f8f8fe8 commit 8ca9632
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 31 deletions.
2 changes: 1 addition & 1 deletion ktor-http/common/src/io/ktor/http/auth/AuthScheme.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public object AuthScheme {
/**
* Bearer Authentication described in the RFC-6749 & RFC6750:
*
* see https://tools.ietf.org/html/rfc6749
* see https://tools.ietf.org/html/rfc6749
* & https://tools.ietf.org/html/rfc6750
*/
public const val Bearer: String = "Bearer"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@

package io.ktor.tests.server.testing

import io.ktor.client.network.sockets.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.server.config.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.testing.*
import io.ktor.server.websocket.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
Expand Down Expand Up @@ -283,6 +288,54 @@ class TestApplicationTestJvm {
}
assertEquals("WebSocket connection failed", error.message)
}

private fun testSocketTimeoutWrite(timeout: Long, expectException: Boolean) = testApplication {
routing {
post {
call.respond(HttpStatusCode.OK, call.request.receiveChannel().readRemaining().toString())
}
}

val clientWithTimeout = createClient {
install(HttpTimeout) {
socketTimeoutMillis = timeout
}
}

val body = object : OutgoingContent.WriteChannelContent() {
override suspend fun writeTo(channel: ByteWriteChannel) {
channel.writeAvailable("Hello".toByteArray())
channel.flush()
delay(300)
channel.writeAvailable("World".toByteArray())
channel.flush()
}
}

if (expectException) {
assertFailsWith<SocketTimeoutException> {
clientWithTimeout.post("/") {
setBody(body)
}
}
} else {
clientWithTimeout.post("/") {
setBody(body)
}.apply {
assertEquals(HttpStatusCode.OK, status)
}
}
}

@Test
fun testSocketTimeoutWriteElapsed() {
testSocketTimeoutWrite(100, true)
}

@Test
fun testSocketTimeoutWriteNotElapsed() {
testSocketTimeoutWrite(1000, false)
}
}

class TestClass(val value: Int) : Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package io.ktor.server.testing

import io.ktor.client.*
import io.ktor.client.engine.*
import io.ktor.client.plugins.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.engine.*
Expand Down Expand Up @@ -203,14 +204,15 @@ class TestApplicationEngine(
setup: TestApplicationRequest.() -> Unit
): TestApplicationCall {
val callJob = GlobalScope.async(coroutineContext) {
handleRequestNonBlocking(closeRequest, setup)
handleRequestNonBlocking(closeRequest, timeoutAttributes = null, setup)
}

return runBlocking { callJob.await() }
}

internal suspend fun handleRequestNonBlocking(
closeRequest: Boolean = true,
timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration? = null,
setup: TestApplicationRequest.() -> Unit
): TestApplicationCall {
val job = Job(testEngineJob)
Expand All @@ -220,6 +222,9 @@ class TestApplicationEngine(
setup = { processRequest(setup) },
context = Dispatchers.IOBridge + job
)
if (timeoutAttributes != null) {
call.attributes.put(timeoutAttributesKey, timeoutAttributes)
}

val context = SupervisorJob(job) + CoroutineName("request")
withContext(coroutineContext + context) {
Expand Down Expand Up @@ -306,3 +311,5 @@ fun TestApplicationEngine.cookiesSession(callback: () -> Unit) {
callback()
}
}

internal val timeoutAttributesKey = AttributeKey<HttpTimeout.HttpTimeoutCapabilityConfiguration>("TimeoutAttributes")
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public class TestApplicationResponse(
call: TestApplicationCall,
private val readResponse: Boolean = false
) : BaseApplicationResponse(call), CoroutineScope by call {
private val scope: CoroutineScope get() = this

private val timeoutAttributes get() = call.attributes.getOrNull(timeoutAttributesKey)

/**
* Gets a response body text content. Could be blocking. Remains `null` until response appears.
Expand Down Expand Up @@ -76,16 +79,19 @@ public class TestApplicationResponse(
}

@Suppress("DEPRECATION")
@OptIn(DelicateCoroutinesApi::class)
override suspend fun responseChannel(): ByteWriteChannel {
val result = ByteChannel(autoFlush = true)

if (readResponse) {
launchResponseJob(result)
}

val job = GlobalScope.reader(responseJob ?: EmptyCoroutineContext) {
channel.copyAndClose(result, Long.MAX_VALUE)
val job = scope.reader(responseJob ?: EmptyCoroutineContext) {
val readJob = launch {
channel.copyAndClose(result, Long.MAX_VALUE)
}

configureSocketTimeoutIfNeeded(timeoutAttributes, readJob) { channel.totalBytesRead }
}

if (responseJob == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

package io.ktor.server.testing

import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.util.*
import io.ktor.utils.io.*
import kotlinx.coroutines.*

/**
* [on] function receiver object
Expand Down Expand Up @@ -35,3 +40,44 @@ fun TestApplicationResponse.contentType(): ContentType {
val contentTypeHeader = requireNotNull(headers[HttpHeaders.ContentType])
return ContentType.parse(contentTypeHeader)
}

internal fun CoroutineScope.configureSocketTimeoutIfNeeded(
timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration?,
job: Job,
extract: () -> Long
) {
val socketTimeoutMillis = timeoutAttributes?.socketTimeoutMillis
if (socketTimeoutMillis != null) {
socketTimeoutKiller(socketTimeoutMillis, job, extract)
}
}

internal fun CoroutineScope.socketTimeoutKiller(socketTimeoutMillis: Long, job: Job, extract: () -> Long) {
val killJob = launch {
var cur = extract()
while (job.isActive) {
delay(socketTimeoutMillis)
val next = extract()
if (cur == next) {
throw io.ktor.network.sockets.SocketTimeoutException("Socket timeout elapsed")
}
cur = next
}
}
job.invokeOnCompletion {
killJob.cancel()
}
}

@OptIn(InternalAPI::class)
internal fun Throwable.mapToKtor(data: HttpRequestData): Throwable {
return when {
this is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException(data, this)
cause?.rootCause is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException(
data,
cause?.rootCause
)

else -> this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package io.ktor.server.testing.client

import io.ktor.client.call.*
import io.ktor.client.engine.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*
Expand Down Expand Up @@ -41,22 +42,26 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie

@OptIn(InternalAPI::class)
override suspend fun execute(data: HttpRequestData): HttpResponseData {
app.start()
if (data.isUpgradeRequest()) {
val (testServerCall, session) = with(data) {
bridge.runWebSocketRequest(url.fullPath, headers, body, callContext())
try {
app.start()
if (data.isUpgradeRequest()) {
val (testServerCall, session) = with(data) {
bridge.runWebSocketRequest(url.fullPath, headers, body, callContext())
}
return with(testServerCall.response) {
httpResponseData(session)
}
}
return with(testServerCall.response) {
httpResponseData(session)
}
}

val testServerCall = with(data) {
runRequest(method, url, headers, body, url.protocol)
}
val testServerCall = with(data) {
runRequest(method, url, headers, body, url.protocol, data.getCapabilityOrNull(HttpTimeout))
}

return with(testServerCall.response) {
httpResponseData(ByteReadChannel(byteContent ?: byteArrayOf()))
return with(testServerCall.response) {
httpResponseData(ByteReadChannel(byteContent ?: byteArrayOf()))
}
} catch (cause: Throwable) {
throw cause.mapToKtor(data)
}
}

Expand All @@ -65,17 +70,18 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie
url: Url,
headers: Headers,
content: OutgoingContent,
protocol: URLProtocol
protocol: URLProtocol,
timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration? = null
): TestApplicationCall {
return app.handleRequestNonBlocking {
return app.handleRequestNonBlocking(timeoutAttributes = timeoutAttributes) {
this.uri = url.fullPath
this.port = url.port
this.method = method
appendRequestHeaders(headers, content)
this.protocol = protocol.name

if (content !is OutgoingContent.NoContent) {
bodyChannel = content.toByteReadChannel()
bodyChannel = content.toByteReadChannel(timeoutAttributes)
}
}
}
Expand Down Expand Up @@ -112,14 +118,20 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie
}
}

@Suppress("DEPRECATION")
private fun OutgoingContent.toByteReadChannel(): ByteReadChannel = when (this) {
is OutgoingContent.NoContent -> ByteReadChannel.Empty
is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes())
is OutgoingContent.ReadChannelContent -> readFrom()
is OutgoingContent.WriteChannelContent -> writer(coroutineContext) {
writeTo(channel)
}.channel
is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this)
}
private fun OutgoingContent.toByteReadChannel(
timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration?
): ByteReadChannel =
when (this) {
is OutgoingContent.NoContent -> ByteReadChannel.Empty
is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes())
is OutgoingContent.ReadChannelContent -> readFrom()
is OutgoingContent.WriteChannelContent -> writer(coroutineContext) {
val job = launch {
writeTo(channel)
}

configureSocketTimeoutIfNeeded(timeoutAttributes, job) { channel.totalBytesWritten }
}.channel
is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.ktor.tests.server.testing

import io.ktor.client.*
import io.ktor.client.network.sockets.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
Expand All @@ -21,6 +22,7 @@ import io.ktor.server.testing.*
import io.ktor.server.testing.client.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.*
import kotlin.coroutines.*
import kotlin.test.*
Expand Down Expand Up @@ -398,6 +400,49 @@ class TestApplicationTest {
}
}

private fun testSocketTimeoutRead(timeout: Long, expectException: Boolean) = testApplication {
routing {
get {
call.respond(
HttpStatusCode.OK,
object : OutgoingContent.WriteChannelContent() {
override suspend fun writeTo(channel: ByteWriteChannel) {
channel.writeAvailable("Hello".toByteArray())
channel.flush()
delay(300)
}
}
)
}
}

val clientWithTimeout = createClient {
install(HttpTimeout) {
socketTimeoutMillis = timeout
}
}

if (expectException) {
assertFailsWith<SocketTimeoutException> {
clientWithTimeout.get("/")
}
} else {
clientWithTimeout.get("/").apply {
assertEquals(HttpStatusCode.OK, status)
}
}
}

@Test
fun testSocketTimeoutReadElapsed() {
testSocketTimeoutRead(100, true)
}

@Test
fun testSocketTimeoutReadNotElapsed() {
testSocketTimeoutRead(1000, false)
}

class MyElement(val data: String) : CoroutineContext.Element {
override val key: CoroutineContext.Key<*>
get() = MyElement
Expand Down

0 comments on commit 8ca9632

Please sign in to comment.