Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the case when a stream is closed while still having buffered data for write #330

Merged
merged 8 commits into from
Oct 11, 2023
16 changes: 16 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.etc.types

import kotlin.properties.Delegates
import kotlin.properties.ReadWriteProperty
import kotlin.reflect.KProperty

Expand Down Expand Up @@ -92,3 +93,18 @@ data class CappedValueDelegate<C : Comparable<C>>(
}
}
}

fun <T : Any> Delegates.writeOnce(initialValue: T): ReadWriteProperty<Any?, T> = object : ReadWriteProperty<Any?, T> {
private var value: T = initialValue
private var wasSet = false

public override fun getValue(thisRef: Any?, property: KProperty<*>): T {
return value
}

public override fun setValue(thisRef: Any?, property: KProperty<*>, value: T) {
if (wasSet) throw IllegalStateException("Property ${property.name} cannot be set more than once.")
this.value = value
wasSet = true
}
}
1 change: 1 addition & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream w
class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null)

class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null)
class ClosedForWritingMuxerException(muxId: MuxId) : WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null)
23 changes: 16 additions & 7 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.mux.StreamMuxer
import io.libp2p.etc.types.sliceMaxSize
import io.libp2p.etc.types.writeOnce
import io.libp2p.etc.util.netty.ByteBufQueue
import io.libp2p.etc.util.netty.mux.MuxChannel
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.ClosedForWritingMuxerException
import io.libp2p.mux.InvalidFrameMuxerException
import io.libp2p.mux.MuxHandler
import io.libp2p.mux.UnknownStreamIdMuxerException
Expand All @@ -19,6 +21,7 @@ import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import kotlin.math.max
import kotlin.properties.Delegates

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB
Expand All @@ -39,6 +42,7 @@ open class YamuxHandler(
val sendWindowSize = AtomicInteger(initialWindowSize)
val receiveWindowSize = AtomicInteger(initialWindowSize)
val sendBuffer = ByteBufQueue()
var closedForWriting by Delegates.writeOnce(false)

fun dispose() {
sendBuffer.dispose()
Expand Down Expand Up @@ -72,7 +76,7 @@ open class YamuxHandler(
val delta = msg.length.toInt()
sendWindowSize.addAndGet(delta)
// try to send any buffered messages after the window update
drainBuffer()
drainBufferAndMaybeClose()
}

private fun handleFlags(msg: YamuxFrame) {
Expand All @@ -98,7 +102,7 @@ open class YamuxHandler(
}
}

private fun drainBuffer() {
private fun drainBufferAndMaybeClose() {
val maxSendLength = max(0, sendWindowSize.get())
val data = sendBuffer.take(maxSendLength)
sendWindowSize.addAndGet(-data.readableBytes())
Expand All @@ -107,11 +111,18 @@ open class YamuxHandler(
val length = slicedData.readableBytes()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData))
}

if (closedForWriting && sendBuffer.readableBytes() == 0) {
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
}
}

fun sendData(data: ByteBuf) {
if (closedForWriting) {
throw ClosedForWritingMuxerException(id)
}
fillBuffer(data)
drainBuffer()
drainBufferAndMaybeClose()
}

fun onLocalOpen() {
Expand All @@ -123,10 +134,8 @@ open class YamuxHandler(
}

fun onLocalDisconnect() {
// TODO: this implementation drops remaining data
drainBuffer()
sendBuffer.dispose()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
closedForWriting = true
drainBufferAndMaybeClose()
}

fun onLocalClose() {
Expand Down
43 changes: 43 additions & 0 deletions libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,49 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
msgPart3.data!!.release()
}

@Test
fun `local close for writing should flush buffered data and send close frame on writeWindow update`() {
val handler = openStreamLocal()
val muxId = readFrameOrThrow().streamId.toMuxId()

val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf())
// writing a message which is larger than sendWindowSize
handler.ctx.writeAndFlush(msg)

val msgPart1 = readYamuxFrameOrThrow()
assertThat(msgPart1.length).isEqualTo(256L)
assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256)
msgPart1.data!!.release()

val msgPart2 = readYamuxFrameOrThrow()
assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256)
assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256)
msgPart2.data!!.release()

// locally close for writing while some outbound data is still buffered
handler.ctx.disconnect()

// ACKing message receive
ech.writeInbound(
YamuxFrame(
muxId,
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
initialWindowSize.toLong()
)
)

val msgPart3 = readYamuxFrameOrThrow()
assertThat(msgPart3.length).isEqualTo(1L)
assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1)
msgPart3.data!!.release()

val closeFrame = readYamuxFrameOrThrow()
assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN)
assertThat(closeFrame.length).isEqualTo(0L)
assertThat(closeFrame.data).isNull()
}

companion object {
private fun YamuxStreamIdGenerator.toIterator() = iterator {
while (true) {
Expand Down