Skip to content

Commit

Permalink
KTOR-4164 Fix ClassCastException when development mode is on (#3082)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsinukov committed Jun 27, 2022
1 parent 2fffabe commit 0f60e99
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 4 deletions.
Expand Up @@ -36,11 +36,16 @@ public class ApplicationEngineEnvironmentReloading(
override val connectors: List<EngineConnectorConfig>,
internal val modules: List<Application.() -> Unit>,
internal val watchPaths: List<String> = emptyList(),
override val parentCoroutineContext: CoroutineContext = EmptyCoroutineContext,
parentCoroutineContext: CoroutineContext = EmptyCoroutineContext,
override val rootPath: String = "",
override val developmentMode: Boolean = true
) : ApplicationEngineEnvironment {

override val parentCoroutineContext: CoroutineContext = when {
developmentMode -> parentCoroutineContext + ClassLoaderAwareContinuationInterceptor
else -> parentCoroutineContext
}

public constructor(
classLoader: ClassLoader,
log: Logger,
Expand Down Expand Up @@ -365,3 +370,20 @@ public class ApplicationEngineEnvironmentReloading(

public companion object
}

private object ClassLoaderAwareContinuationInterceptor : ContinuationInterceptor {
override val key: CoroutineContext.Key<*> =
object : CoroutineContext.Key<ClassLoaderAwareContinuationInterceptor> {}

override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> {
val classLoader = Thread.currentThread().contextClassLoader
return object : Continuation<T> {
override val context: CoroutineContext = continuation.context

override fun resumeWith(result: Result<T>) {
Thread.currentThread().contextClassLoader = classLoader
continuation.resumeWith(result)
}
}
}
}
Expand Up @@ -14,6 +14,9 @@ import io.ktor.server.routing.*
import io.ktor.server.testing.*
import io.ktor.server.websocket.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import java.io.*
import kotlin.coroutines.*
import kotlin.test.*
import io.ktor.client.plugins.websocket.WebSockets as ClientWebSockets

Expand Down Expand Up @@ -100,9 +103,9 @@ class TestApplicationTestJvm {

@Test
fun testExternalServicesCustomConfig() = testApplication {
environment {
config = ApplicationConfig("application-custom.conf")
}
environment {
config = ApplicationConfig("application-custom.conf")
}
externalServices {
hosts("http://www.google.com") {
val config = environment.config
Expand All @@ -119,9 +122,60 @@ class TestApplicationTestJvm {
assertEquals("another_test_value", external.bodyAsText())
}

@Test
fun testModuleWithLaunch() = testApplication {
var error: Throwable? = null
val exceptionHandler: CoroutineContext = object : CoroutineExceptionHandler {
override val key: CoroutineContext.Key<*> = CoroutineExceptionHandler.Key
override fun handleException(context: CoroutineContext, exception: Throwable) {
error = exception
}
}
environment {
parentCoroutineContext = exceptionHandler
}
application {
launch {
val byteArrayInputStream = ByteArrayOutputStream()
val objectOutputStream = ObjectOutputStream(byteArrayInputStream)
objectOutputStream.writeObject(TestClass(123))
objectOutputStream.flush()
objectOutputStream.close()

val ois = TestObjectInputStream(ByteArrayInputStream(byteArrayInputStream.toByteArray()))
val test = ois.readObject()
test as TestClass
}
}
routing {
get("/") {
call.respond("OK")
}
}

client.get("/")
Thread.sleep(3000)
assertNull(error)
}

public fun Application.module() {
routing {
get { call.respond("OK FROM MODULE") }
}
}
}

class TestClass(val value: Int) : Serializable

class TestObjectInputStream(input: InputStream) : ObjectInputStream(input) {
override fun resolveClass(desc: ObjectStreamClass?): Class<*> {
val name = desc?.name
val loader = Thread.currentThread().contextClassLoader

return try {
Class.forName(name, false, loader)
} catch (e: ClassNotFoundException) {
super.resolveClass(desc)
}
}
}

0 comments on commit 0f60e99

Please sign in to comment.