Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Yamux] Increase write buffer size and make it configurable #317

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ Add the library to the `dependencies` section of the pom file:
<groupId>io.libp2p</groupId>
<artifactId>jvm-libp2p</artifactId>
<version>X.Y.Z-RELEASE</version>
<type>pom</type>
</dependency>
```

Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import java.net.URL
// To publish the release artifact to CloudSmith repo run the following :
// ./gradlew publish -PcloudsmithUser=<user> -PcloudsmithApiKey=<api-key>

description = "an implementation of libp2p for the jvm"
description = "a libp2p implementation for the JVM, written in Kotlin"

plugins {
val kotlinVersion = "1.6.21"
Expand Down
22 changes: 15 additions & 7 deletions libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.libp2p.core.mux
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.multistream.ProtocolBinding
import io.libp2p.mux.mplex.MplexStreamMuxer
import io.libp2p.mux.yamux.DEFAULT_MAX_BUFFERED_CONNECTION_WRITES
import io.libp2p.mux.yamux.YamuxStreamMuxer

fun interface StreamMuxerProtocol {
Expand All @@ -20,14 +21,21 @@ fun interface StreamMuxerProtocol {
)
}

/**
* @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection before termination
*/
@JvmStatic
val Yamux = StreamMuxerProtocol { multistreamProtocol, protocols ->
YamuxStreamMuxer(
multistreamProtocol.createMultistream(
protocols
).toStreamHandler(),
multistreamProtocol
)
@JvmOverloads
fun getYamux(maxBufferedConnectionWrites: Int = DEFAULT_MAX_BUFFERED_CONNECTION_WRITES): StreamMuxerProtocol {
return StreamMuxerProtocol { multistreamProtocol, protocols ->
YamuxStreamMuxer(
multistreamProtocol.createMultistream(
protocols
).toStreamHandler(),
multistreamProtocol,
maxBufferedConnectionWrites
)
}
}
}
}
25 changes: 16 additions & 9 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val MAX_BUFFERED_CONNECTION_WRITES = 1024 * 1024
const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB

open class YamuxHandler(
override val multistreamProtocol: MultistreamProtocol,
override val maxFrameDataLength: Int,
ready: CompletableFuture<StreamMuxer.Session>?,
inboundStreamHandler: StreamHandler<*>,
initiator: Boolean
initiator: Boolean,
private val maxBufferedConnectionWrites: Int
) : MuxHandler(ready, inboundStreamHandler) {
private val idGenerator = AtomicInteger(if (initiator) 1 else 2) // 0 is reserved
private val windowSizes = ConcurrentHashMap<MuxId, AtomicInteger>()
private val sendBuffers = ConcurrentHashMap<MuxId, SendBuffer>()

inner class SendBuffer(val id: MuxId, val ctx: ChannelHandlerContext) {
private inner class SendBuffer(val id: MuxId, val ctx: ChannelHandlerContext) {
private val bufferedData = ArrayDeque<ByteBuf>()

fun add(data: ByteBuf) {
Expand All @@ -44,19 +45,24 @@ open class YamuxHandler(
val data = bufferedData.first()
val length = data.readableBytes()
if (length <= windowSize.get()) {
sendBlocks(ctx, data, windowSize, id)
sendFrames(ctx, data, windowSize, id)
bufferedData.removeFirst()
} else {
// partial write to fit within window
val toRead = windowSize.get()
if (toRead > 0) {
val partialData = data.readRetainedSlice(toRead)
sendBlocks(ctx, partialData, windowSize, id)
sendFrames(ctx, partialData, windowSize, id)
}
break
}
}
}

fun close() {
bufferedData.forEach { it.release() }
bufferedData.clear()
}
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
Expand Down Expand Up @@ -149,23 +155,24 @@ open class YamuxHandler(
val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(child.id, ctx) }
buffer.add(data)
val totalBufferedWrites = calculateTotalBufferedWrites()
if (totalBufferedWrites > MAX_BUFFERED_CONNECTION_WRITES) {
if (totalBufferedWrites > maxBufferedConnectionWrites) {
buffer.close()
throw Libp2pException(
"Overflowed send buffer ($totalBufferedWrites/$MAX_BUFFERED_CONNECTION_WRITES) for connection ${
"Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites) for connection ${
ctx.channel().id().asLongText()
}"
)
}
return
}
sendBlocks(ctx, data, windowSize, child.id)
sendFrames(ctx, data, windowSize, child.id)
}

private fun calculateTotalBufferedWrites(): Int {
return sendBuffers.values.sumOf { it.bufferedBytes() }
}

fun sendBlocks(ctx: ChannelHandlerContext, data: ByteBuf, windowSize: AtomicInteger, id: MuxId) {
fun sendFrames(ctx: ChannelHandlerContext, data: ByteBuf, windowSize: AtomicInteger, id: MuxId) {
data.sliceMaxSize(minOf(windowSize.get(), maxFrameDataLength))
.map { slicedData ->
val length = slicedData.readableBytes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import java.util.concurrent.CompletableFuture

class YamuxStreamMuxer(
val inboundStreamHandler: StreamHandler<*>,
private val multistreamProtocol: MultistreamProtocol
private val multistreamProtocol: MultistreamProtocol,
private val maxBufferedConnectionWrites: Int
) : StreamMuxer, StreamMuxerDebug {

override val protocolDescriptor = ProtocolDescriptor("/yamux/1.0.0")
Expand All @@ -30,7 +31,8 @@ class YamuxStreamMuxer(
yamuxFrameCodec.maxFrameDataLength,
muxSessionReady,
inboundStreamHandler,
ch.isInitiator
ch.isInitiator,
maxBufferedConnectionWrites
)
)

Expand Down
37 changes: 36 additions & 1 deletion libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.mux.yamux

import io.libp2p.core.Libp2pException
import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocolV1
import io.libp2p.etc.types.fromHex
Expand All @@ -8,13 +9,16 @@ import io.libp2p.mux.MuxHandler
import io.libp2p.mux.MuxHandlerAbstractTest
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
import io.libp2p.tools.readAllBytesAndRelease
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

class YamuxHandlerTest : MuxHandlerAbstractTest() {

override val maxFrameDataLength = 256
private val maxBufferedConnectionWrites = 512

private val readFrameQueue = ArrayDeque<AbstractTestMuxFrame>()

override fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler =
Expand All @@ -23,7 +27,8 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
maxFrameDataLength,
null,
streamHandler,
true
true,
maxBufferedConnectionWrites
) {
// MuxHandler consumes the exception. Override this behaviour for testing
@Deprecated("Deprecated in Java")
Expand Down Expand Up @@ -197,6 +202,36 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
assertThat(frame.data).isEqualTo("84")
}

@Test
fun `overflowing buffer throws an exception`() {
val handler = openStreamByLocal()
val streamId = readFrameOrThrow().streamId

ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
-INITIAL_WINDOW_SIZE.toLong()
)
)

val createMessage: () -> ByteBuf =
{ "42".repeat(maxBufferedConnectionWrites / 5).fromHex().toByteBuf(allocateBuf()) }

for (i in 1..5) {
val writeResult = handler.ctx.writeAndFlush(createMessage())
assertThat(writeResult.isSuccess).isTrue()
}

// next message will overflow the configured buffer
val writeResult = handler.ctx.writeAndFlush(createMessage())
assertThat(writeResult.isSuccess).isFalse()
assertThat(writeResult.cause())
.isInstanceOf(Libp2pException::class.java)
.hasMessage("Overflowed send buffer (612/512) for connection test")
}

@Test
fun `test ping`() {
val id: Long = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ val MultistreamProtocolV1: MultistreamProtocolDebug = MultistreamProtocolDebugV1
@Tag("secure-channel")
class TlsSecureChannelTest : SecureChannelTestBase(
::TlsSecureChannel,
listOf(StreamMuxerProtocol.Yamux.createMuxer(MultistreamProtocolV1, listOf())),
listOf(StreamMuxerProtocol.getYamux().createMuxer(MultistreamProtocolV1, listOf())),
TlsSecureChannel.announce
) {
@Test
Expand Down