Skip to content

Commit

Permalink
KTOR-5979 Fix frame flag defragmentation (#3703)
Browse files Browse the repository at this point in the history
  • Loading branch information
e5l committed Jul 26, 2023
1 parent b7f0564 commit 45b5415
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import io.ktor.server.testing.*
import io.ktor.server.websocket.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.bits.*
import io.ktor.utils.io.charsets.*
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
Expand Down Expand Up @@ -579,6 +580,49 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
}
}

@Test
open fun testFragmentedFlagsFromTheFirstFrame() = runTest {
val first = CompletableDeferred<Frame.Text>()
val second = CompletableDeferred<Frame.Text>()
createAndStartServer {
webSocket("/") {
val frame = incoming.receive()
assertIs<Frame.Text>(frame)
first.complete(frame)

val frame2 = incoming.receive()
assertIs<Frame.Text>(frame2)
second.complete(frame2)
}
}

useSocket {
negotiateHttpWebSocket()

output.apply {
repeat(2) {
writeFrameTest(Frame.Text(false, "Hello".toByteArray(), true, false, false), false)
writeFrameTest(Frame.Text(true, ", World".toByteArray(), false, false, false), false, opcode = 0)
}
writeFrameTest(Frame.Close(), false)
flush()
}
}

fun checkFrame(frame: Frame) {
assertIs<Frame.Text>(frame)
assertTrue(frame.fin)
assertTrue(frame.rsv1)
assertFalse(frame.rsv2)
assertFalse(frame.rsv3)

assertEquals("Hello, World", frame.readText())
}

checkFrame(first.await())
checkFrame(second.await())
}

private suspend fun Connection.negotiateHttpWebSocket() {
// send upgrade request
output.apply {
Expand Down Expand Up @@ -621,6 +665,7 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
if (replyCloseFrame) socket.close()
break@loop
}

else -> fail("Unexpected frame $frame: \n${hex(frame.data)}")
}
}
Expand Down Expand Up @@ -684,3 +729,53 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
}
}
}

internal suspend fun ByteWriteChannel.writeFrameTest(frame: Frame, masking: Boolean, opcode: Int? = null) {
val length = frame.data.size

val flagsAndOpcode = frame.fin.flagAt(7) or
frame.rsv1.flagAt(6) or
frame.rsv2.flagAt(5) or
frame.rsv3.flagAt(4) or
(opcode ?: frame.frameType.opcode)

writeByte(flagsAndOpcode.toByte())

val formattedLength = when {
length < 126 -> length
length <= 0xffff -> 126
else -> 127
}

val maskAndLength = masking.flagAt(7) or formattedLength

writeByte(maskAndLength.toByte())

when (formattedLength) {
126 -> writeShort(length.toShort())
127 -> writeLong(length.toLong())
}

val data = ByteReadPacket(frame.data)

val maskedData = when (masking) {
true -> {
val maskKey = Random.nextInt()
writeInt(maskKey)
data.mask(maskKey)
}
false -> data
}
writePacket(maskedData)
}

internal fun Boolean.flagAt(at: Int) = if (this) 1 shl at else 0

private fun ByteReadPacket.mask(maskKey: Int): ByteReadPacket = withMemory(4) { maskMemory ->
maskMemory.storeIntAt(0, maskKey)
buildPacket {
repeat(remaining.toInt()) { i ->
writeByte((readByte().toInt() xor (maskMemory[i % 4].toInt())).toByte())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ class TomcatWebSocketTest :
@Ignore
override fun testClientClosingFirst() {
}

@Ignore
override fun testFragmentedFlagsFromTheFirstFrame() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ internal class DefaultWebSocketSessionImpl(
private fun runIncomingProcessor(ponger: SendChannel<Frame.Ping>): Job = launch(
IncomingProcessorCoroutineName + Dispatchers.Unconfined
) {
var last: BytePacketBuilder? = null
var firstFrame: Frame? = null
var frameBody: BytePacketBuilder? = null
var closeFramePresented = false
try {
@OptIn(DelicateCoroutinesApi::class)
Expand All @@ -177,31 +178,37 @@ internal class DefaultWebSocketSessionImpl(
is Frame.Pong -> pinger.value?.send(frame)
is Frame.Ping -> ponger.send(frame)
else -> {
checkMaxFrameSize(last, frame)
checkMaxFrameSize(frameBody, frame)

if (!frame.fin) {
if (last == null) {
last = BytePacketBuilder()
if (firstFrame == null) {
firstFrame = frame
}
if (frameBody == null) {
frameBody = BytePacketBuilder()
}

frameBody!!.writeFully(frame.data)
return@consumeEach
}

last!!.writeFully(frame.data)
if (firstFrame == null) {
filtered.send(processIncomingExtensions(frame))
return@consumeEach
}

val frameToSend = last?.let { builder ->
builder.writeFully(frame.data)
Frame.byType(
fin = true,
frame.frameType,
builder.build().readBytes(),
frame.rsv1,
frame.rsv2,
frame.rsv3
)
} ?: frame

last = null
filtered.send(processIncomingExtensions(frameToSend))
frameBody!!.writeFully(frame.data)
val defragmented = Frame.byType(
fin = true,
firstFrame!!.frameType,
frameBody!!.build().readBytes(),
firstFrame!!.rsv1,
firstFrame!!.rsv2,
firstFrame!!.rsv3
)

firstFrame = null
filtered.send(processIncomingExtensions(defragmented))
}
}
}
Expand All @@ -211,7 +218,7 @@ internal class DefaultWebSocketSessionImpl(
filtered.close(cause)
} finally {
ponger.close()
last?.release()
frameBody?.release()
filtered.close()

if (!closeFramePresented) {
Expand Down

0 comments on commit 45b5415

Please sign in to comment.