From 3d3987c44ab120430d143aca2c65411c1573abe0 Mon Sep 17 00:00:00 2001 From: Trol Date: Wed, 22 Apr 2020 21:39:18 +0800 Subject: [PATCH] (non-intrusive) Implement optional thread interrupt on coroutine cancellation (#57) This is implementation of issue #57 and non-intrusive variant of #1922 Signed-off-by: Trol --- .../api/kotlinx-coroutines-core.api | 4 + .../jvm/src/CancellationPoint.kt | 117 +++++++++++++++ .../InterruptibleCancellationPointTest.kt | 142 ++++++++++++++++++ 3 files changed, 263 insertions(+) create mode 100644 kotlinx-coroutines-core/jvm/src/CancellationPoint.kt create mode 100644 kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 54e355ec37..fc278f9182 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -86,6 +86,10 @@ public final class kotlinx/coroutines/CancellableContinuationKt { public static final fun suspendCancellableCoroutine (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class kotlinx/coroutines/CancellationPointKt { + public static final fun interruptible (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public abstract interface class kotlinx/coroutines/ChildHandle : kotlinx/coroutines/DisposableHandle { public abstract fun childCancelled (Ljava/lang/Throwable;)Z } diff --git a/kotlinx-coroutines-core/jvm/src/CancellationPoint.kt b/kotlinx-coroutines-core/jvm/src/CancellationPoint.kt new file mode 100644 index 0000000000..cd37b07bc1 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/src/CancellationPoint.kt @@ -0,0 +1,117 @@ +package kotlinx.coroutines + +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn + +/** + * Makes a blocking code block cancellable (become a cancellation point of the coroutine). + * + * The blocking code block will be interrupted and this function will throw [CancellationException] + * if the coroutine is cancelled. + * + * Example: + * ``` + * GlobalScope.launch(Dispatchers.IO) { + * async { + * // This function will throw [CancellationException]. + * interruptible { + * doSomethingUseful() + * + * // This blocking procedure will be interrupted when this coroutine is canceled + * // by Exception thrown by the below async block. + * doSomethingElseUsefulInterruptible() + * } + * } + * + * async { + * delay(500L) + * throw Exception() + * } + * } + * ``` + */ +public suspend fun interruptible(block: () -> T): T = suspendCoroutineUninterceptedOrReturn sc@{ uCont -> + try { + // fast path: no job + val job = uCont.context[Job] ?: return@sc block() + // slow path + val threadState = ThreadState().apply { initInterrupt(job) } + try { + block() + } finally { + threadState.clearInterrupt() + } + } catch (e: InterruptedException) { + throw CancellationException() + } +} + +private class ThreadState { + + fun initInterrupt(job: Job) { + // starts with Init + if (state.value !== Init) throw IllegalStateException("impossible state") + // remembers this running thread + state.value = Working(Thread.currentThread(), null) + // watches the job for cancellation + val cancelHandle = + job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = CancelHandler()) + // remembers the cancel handle or drops it + state.loop { s -> + when { + s is Working -> if (state.compareAndSet(s, Working(s.thread, cancelHandle))) return + s === Interrupting || s === Interrupted -> return + s === Init || s === Finish -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + + fun clearInterrupt() { + state.loop { s -> + when { + s is Working -> if (state.compareAndSet(s, Finish)) { s.cancelHandle!!.dispose(); return } + s === Interrupting -> Thread.yield() // eases the thread + s === Interrupted -> { Thread.interrupted(); return } // no interrupt leak + s === Init || s === Finish -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + + private inner class CancelHandler : CompletionHandler { + override fun invoke(cause: Throwable?) { + state.loop { s -> + when { + s is Working -> { + if (state.compareAndSet(s, Interrupting)) { + s.thread!!.interrupt() + state.value = Interrupted + return + } + } + s === Finish -> return + s === Interrupting || s === Interrupted -> return + s === Init -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + } + + private val state: AtomicRef = atomic(Init) + + private interface State + // initial state + private object Init : State + // cancellation watching is setup and/or the continuation is running + private data class Working(val thread: Thread?, val cancelHandle: DisposableHandle?) : State + // the continuation done running without interruption + private object Finish : State + // interrupting this thread + private object Interrupting: State + // done interrupting + private object Interrupted: State +} diff --git a/kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt b/kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt new file mode 100644 index 0000000000..d6009c694f --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt @@ -0,0 +1,142 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import org.junit.Test +import java.io.IOException +import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.assertEquals +import kotlin.test.assertFalse + +class InterruptibleCancellationPointTest: TestBase() { + + @Test + fun testNormalRun() = runBlocking { + var result = interruptible { + var x = doSomethingUsefulBlocking(1, 1) + var y = doSomethingUsefulBlocking(1, 2) + x + y + } + assertEquals(3, result) + } + + @Test + fun testInterrupt() { + val count = AtomicInteger(0) + try { + expect(1) + runBlocking { + launch(Dispatchers.IO) { + async { + try { + // `interruptible` makes a blocking block cancelable (become a cancellation point) + // by interrupting it on cancellation and throws CancellationException + interruptible { + try { + doSomethingUsefulBlocking(100, 1) + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } catch (e: InterruptedException) { + expect(3) + throw e + } + } + } catch (e: CancellationException) { + expect(4) + } + } + + async { + delay(500L) + expect(2) + throw IOException() + } + } + } + } catch (e: IOException) { + expect(5) + } + finish(6) + } + + @Test + fun testNoInterruptLeak() = runBlocking { + var interrupted = true + + var task = launch(Dispatchers.IO) { + try { + interruptible { + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } + } finally { + interrupted = Thread.currentThread().isInterrupted + } + } + + delay(500) + task.cancel() + task.join() + assertFalse(interrupted) + } + + @Test + fun testStress() { + val REPEAT_TIMES = 2_000 + + Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher -> + val interruptLeak = AtomicBoolean(false) + val enterCount = AtomicInteger(0) + val interruptedCount = AtomicInteger(0) + val otherExceptionCount = AtomicInteger(0) + + runBlocking { + repeat(REPEAT_TIMES) { repeat -> + var job = launch(start = CoroutineStart.LAZY, context = dispatcher) { + try { + interruptible { + enterCount.incrementAndGet() + try { + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } catch (e: InterruptedException) { + interruptedCount.incrementAndGet() + throw e + } + } + } catch (e: CancellationException) { + } catch (e: Throwable) { + otherExceptionCount.incrementAndGet() + } finally { + interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted) + } + } + + var cancelJob = launch(start = CoroutineStart.LAZY, context = dispatcher) { + job.cancel() + } + + launch (dispatcher) { + delay((REPEAT_TIMES - repeat).toLong()) + job.start() + } + + launch (dispatcher) { + delay(repeat.toLong()) + cancelJob.start() + } + } + } + + assertFalse(interruptLeak.get()) + assertEquals(enterCount.get(), interruptedCount.get()) + assertEquals(0, otherExceptionCount.get()) + } + } + + private fun doSomethingUsefulBlocking(timeUseMillis: Long, result: Int): Int { + Thread.sleep(timeUseMillis) + return result + } +}