Skip to content

Commit

Permalink
Use the fact that network headers specify the number of bytes in the … (
Browse files Browse the repository at this point in the history
#783)

* Use the fact that network headers specify the number of bytes in the payload rather than just parsing from bytes.size, this should allow us to be more precise when parsing NetworkPayloads rather than _hoping_ that bytes.size does not land on pseudo-valid NetworkPayload. This hopefully resolves #782

* Add safety check around HeadersMessage.toString()

* Add invariant to NetworkMessage saying payloadSize in header must be the actual payload size
  • Loading branch information
Christewart committed Oct 6, 2019
1 parent 5cc0b30 commit 09ea1fb
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 21 deletions.
Expand Up @@ -9,4 +9,21 @@ class NetworkMessageTest extends BitcoinSUnitTest {
NetworkMessage(NodeTestUtil.rawNetworkMessage).hex must be(
NodeTestUtil.rawNetworkMessage)
}


it must "serialize and deserialize a version message example from the bitcoin wiki" in {
val hex = {
//taken from here with slight modifications
//https://en.bitcoin.it/wiki/Protocol_documentation#Message_structure
//this example uses an old protocol version WITHOUT the relay flag on the version message
//since we only support protocol version > 7, i added it manually
//this means the payload size is bumped by 1 byte in the NetworkHeader from 100 -> 101
//and a relay byte "00" is appended to the end of the payload
"F9BEB4D976657273696F6E000000000065000000358d4932" +
"62EA0000010000000000000011B2D05000000000010000000000000000000000000000000000FFFF000000000000010000000000000000000000000000000000FFFF0000000000003B2EB35D8CE617650F2F5361746F7368693A302E372E322FC03E0300" +
"00"
}.toLowerCase
val networkMsg = NetworkMessage.fromHex(hex)
networkMsg.hex must be (hex)
}
}
Expand Up @@ -10,6 +10,8 @@ import scodec.bits.ByteVector
* Represents a P2P network message
*/
sealed abstract class NetworkMessage extends NetworkElement {
require(header.payloadSize.toInt == payload.bytes.length, s"Payload size is not what header says it is, " +
s"header.payloadSize=${header.payloadSize.toInt} actual=${payload.bytes.length}")
def header: NetworkHeader
def payload: NetworkPayload
override def bytes: ByteVector = RawNetworkMessageSerializer.write(this)
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/org/bitcoins/core/p2p/NetworkPayload.scala
Expand Up @@ -277,6 +277,15 @@ case class HeadersMessage(count: CompactSizeUInt, headers: Vector[BlockHeader])
override def commandName = NetworkPayload.headersCommandName

override def bytes: ByteVector = RawHeadersMessageSerializer.write(this)

override def toString(): String = {
if (headers.nonEmpty) {
s"HeadersMessage(${count},${headers.head.hashBE.hex}..${headers.last.hashBE.hex}"
} else {
super.toString
}

}
}

object HeadersMessage extends Factory[HeadersMessage] {
Expand Down Expand Up @@ -721,7 +730,7 @@ object PingMessage extends Factory[PingMessage] {
private case class PingMessageImpl(nonce: UInt64) extends PingMessage
override def fromBytes(bytes: ByteVector): PingMessage = {
val pingMsg = RawPingMessageSerializer.read(bytes)
PingMessageImpl(pingMsg.nonce)
pingMsg
}

def apply(nonce: UInt64): PingMessage = PingMessageImpl(nonce)
Expand Down Expand Up @@ -753,7 +762,7 @@ object PongMessage extends Factory[PongMessage] {

def fromBytes(bytes: ByteVector): PongMessage = {
val pongMsg = RawPongMessageSerializer.read(bytes)
PongMessageImpl(pongMsg.nonce)
pongMsg
}

def apply(nonce: UInt64): PongMessage = PongMessageImpl(nonce)
Expand Down
Expand Up @@ -8,9 +8,15 @@ trait RawNetworkMessageSerializer extends RawBitcoinSerializer[NetworkMessage] {

def read(bytes: ByteVector): NetworkMessage = {
//first 24 bytes are the header
val header = NetworkHeader(bytes.take(24))
val payload = NetworkPayload(header, bytes.slice(24, bytes.size))
NetworkMessage(header, payload)
val (headerBytes,payloadBytes) = bytes.splitAt(24)
val header = NetworkHeader.fromBytes(headerBytes)
if (header.payloadSize.toInt > payloadBytes.length) {
throw new RuntimeException(s"We do not have enough bytes for payload! Expected=${header.payloadSize.toInt} got=${payloadBytes.length}")
} else {
val payload = NetworkPayload(header, payloadBytes)
val n = NetworkMessage(header, payload)
n
}
}

def write(networkMessage: NetworkMessage): ByteVector = {
Expand Down
Expand Up @@ -21,6 +21,7 @@ trait RawNetworkHeaderSerializer
* @return the native object for the MessageHeader
*/
def read(bytes: ByteVector): NetworkHeader = {
require(bytes.length == 24, s"Got bytes.length=${bytes.length} when NetworkHeader expects 24 bytes")
val network = Networks.magicToNetwork(bytes.take(4))
//.trim removes the null characters appended to the command name
val commandName = bytes.slice(4, 16).toArray.map(_.toChar).mkString.trim
Expand Down
22 changes: 6 additions & 16 deletions node/src/main/scala/org/bitcoins/node/networking/P2PClient.scala
Expand Up @@ -377,22 +377,12 @@ object P2PClient extends P2PLogger {
val messageTry = Try(NetworkMessage(remainingBytes))
messageTry match {
case Success(message) =>
val expectedPayloadSize = message.header.payloadSize.toInt
val actualPayloadSize = message.payload.bytes.size
if (expectedPayloadSize != actualPayloadSize) {
//this means our tcp frame was not aligned, therefore put the message back in the
//buffer and wait for the remaining bytes
logger.trace(
s"TCP frame not aligned, payload sizes differed. Expected=$expectedPayloadSize, actual=$actualPayloadSize")
(accum.reverse, remainingBytes)
} else {
val newRemainingBytes = remainingBytes.slice(
message.bytes.length,
remainingBytes.length)
logger.trace(
s"Parsed a message=${message.header.commandName} from bytes, continuing with remainingBytes=${newRemainingBytes.length}")
loop(newRemainingBytes, message :: accum)
}
val newRemainingBytes = remainingBytes.slice(
message.bytes.length,
remainingBytes.length)
logger.trace(
s"Parsed a message=${message.header.commandName} from bytes, continuing with remainingBytes=${newRemainingBytes.length}")
loop(newRemainingBytes, message :: accum)
case Failure(exc) =>
logger.trace(
s"Failed to parse network message, could be because TCP frame isn't aligned: $exc")
Expand Down

0 comments on commit 09ea1fb

Please sign in to comment.