Skip to content

Commit

Permalink
Merge pull request #2924 from cfromknecht/write-and-flush
Browse files Browse the repository at this point in the history
peer: resume partial writes due to timeouts
  • Loading branch information
Roasbeef committed Apr 27, 2019
2 parents 0393793 + e3a6de1 commit 42b081b
Show file tree
Hide file tree
Showing 5 changed files with 405 additions and 95 deletions.
37 changes: 34 additions & 3 deletions brontide/conn.go
Expand Up @@ -166,7 +166,11 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// If the message doesn't require any chunking, then we can go ahead
// with a single write.
if len(b) <= math.MaxUint16 {
return len(b), c.noise.WriteMessage(c.conn, b)
err = c.noise.WriteMessage(b)
if err != nil {
return 0, err
}
return c.noise.Flush(c.conn)
}

// If we need to split the message into fragments, then we'll write
Expand All @@ -185,16 +189,43 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// Slice off the next chunk to be written based on our running
// counter and next chunk size.
chunk := b[bytesWritten : bytesWritten+chunkSize]
if err := c.noise.WriteMessage(c.conn, chunk); err != nil {
if err := c.noise.WriteMessage(chunk); err != nil {
return bytesWritten, err
}

bytesWritten += len(chunk)
n, err := c.noise.Flush(c.conn)
bytesWritten += n
if err != nil {
return bytesWritten, err
}
}

return bytesWritten, nil
}

// WriteMessage encrypts and buffers the next message p for the connection. The
// ciphertext of the message is prepended with an encrypt+auth'd length which
// must be used as the AD to the AEAD construction when being decrypted by the
// other side.
//
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written.
func (c *Conn) WriteMessage(b []byte) error {
return c.noise.WriteMessage(b)
}

// Flush attempts to write a message buffered using WriteMessage to the
// underlying connection. If no buffered message exists, this will result in a
// NOP. Otherwise, it will continue to write the remaining bytes, picking up
// where the byte stream left off in the event of a partial write. The number of
// bytes returned reflects the number of plaintext bytes in the payload, and
// does not account for the overhead of the header or MACs.
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (c *Conn) Flush() (int, error) {
return c.noise.Flush(c.conn)
}

// Close closes the connection. Any blocked Read or Write operations will be
// unblocked and return errors.
//
Expand Down
129 changes: 113 additions & 16 deletions brontide/noise.go
Expand Up @@ -31,6 +31,10 @@ const (
// length of a message payload.
lengthHeaderSize = 2

// encHeaderSize is the number of bytes required to hold an encrypted
// header and it's MAC.
encHeaderSize = lengthHeaderSize + macSize

// keyRotationInterval is the number of messages sent on a single
// cipher stream before the keys are rotated forwards.
keyRotationInterval = 1000
Expand All @@ -48,6 +52,10 @@ var (
ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " +
"the max allowed message length of (2^16)-1")

// ErrMessageNotFlushed signals that the connection cannot accept a new
// message because the prior message has not been fully flushed.
ErrMessageNotFlushed = errors.New("prior message not flushed")

// lightningPrologue is the noise prologue that is used to initialize
// the brontide noise handshake.
lightningPrologue = []byte("lightning")
Expand Down Expand Up @@ -366,7 +374,17 @@ type Machine struct {
// nextCipherHeader is a static buffer that we'll use to read in the
// next ciphertext header from the wire. The header is a 2 byte length
// (of the next ciphertext), followed by a 16 byte MAC.
nextCipherHeader [lengthHeaderSize + macSize]byte
nextCipherHeader [encHeaderSize]byte

// nextHeaderSend holds a reference to the remaining header bytes to
// write out for a pending message. This allows us to tolerate timeout
// errors that cause partial writes.
nextHeaderSend []byte

// nextHeaderBody holds a reference to the remaining body bytes to write
// out for a pending message. This allows us to tolerate timeout errors
// that cause partial writes.
nextBodySend []byte
}

// NewBrontideMachine creates a new instance of the brontide state-machine. If
Expand Down Expand Up @@ -682,37 +700,116 @@ func (b *Machine) split() {
}
}

// WriteMessage writes the next message p to the passed io.Writer. The
// ciphertext of the message is prepended with an encrypt+auth'd length which
// must be used as the AD to the AEAD construction when being decrypted by the
// other side.
func (b *Machine) WriteMessage(w io.Writer, p []byte) error {
// WriteMessage encrypts and buffers the next message p. The ciphertext of the
// message is prepended with an encrypt+auth'd length which must be used as the
// AD to the AEAD construction when being decrypted by the other side.
//
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written.
func (b *Machine) WriteMessage(p []byte) error {
// The total length of each message payload including the MAC size
// payload exceed the largest number encodable within a 16-bit unsigned
// integer.
if len(p) > math.MaxUint16 {
return ErrMaxMessageLengthExceeded
}

// If a prior message was written but it hasn't been fully flushed,
// return an error as we only support buffering of one message at a
// time.
if len(b.nextHeaderSend) > 0 || len(b.nextBodySend) > 0 {
return ErrMessageNotFlushed
}

// The full length of the packet is only the packet length, and does
// NOT include the MAC.
fullLength := uint16(len(p))

var pktLen [2]byte
binary.BigEndian.PutUint16(pktLen[:], fullLength)

// First, write out the encrypted+MAC'd length prefix for the packet.
cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:])
if _, err := w.Write(cipherLen); err != nil {
return err
// First, generate the encrypted+MAC'd length prefix for the packet.
b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])

// Finally, generate the encrypted packet itself.
b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)

return nil
}

// Flush attempts to write a message buffered using WriteMessage to the provided
// io.Writer. If no buffered message exists, this will result in a NOP.
// Otherwise, it will continue to write the remaining bytes, picking up where
// the byte stream left off in the event of a partial write. The number of bytes
// returned reflects the number of plaintext bytes in the payload, and does not
// account for the overhead of the header or MACs.
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (b *Machine) Flush(w io.Writer) (int, error) {
// First, write out the pending header bytes, if any exist. Any header
// bytes written will not count towards the total amount flushed.
if len(b.nextHeaderSend) > 0 {
// Write any remaining header bytes and shift the slice to point
// to the next segment of unwritten bytes. If an error is
// encountered, we can continue to write the header from where
// we left off on a subsequent call to Flush.
n, err := w.Write(b.nextHeaderSend)
b.nextHeaderSend = b.nextHeaderSend[n:]
if err != nil {
return 0, err
}
}

// Finally, write out the encrypted packet itself. We only write out a
// single packet, as any fragmentation should have taken place at a
// higher level.
cipherText := b.sendCipher.Encrypt(nil, nil, p)
_, err := w.Write(cipherText)
return err
// Next, write the pending body bytes, if any exist. Only the number of
// bytes written that correspond to the ciphertext will be included in
// the total bytes written, bytes written as part of the MAC will not be
// counted.
var nn int
if len(b.nextBodySend) > 0 {
// Write out all bytes excluding the mac and shift the body
// slice depending on the number of actual bytes written.
n, err := w.Write(b.nextBodySend)
b.nextBodySend = b.nextBodySend[n:]

// If we partially or fully wrote any of the body's MAC, we'll
// subtract that contribution from the total amount flushed to
// preserve the abstraction of returning the number of plaintext
// bytes written by the connection.
//
// There are three possible scenarios we must handle to ensure
// the returned value is correct. In the first case, the write
// straddles both payload and MAC bytes, and we must subtract
// the number of MAC bytes written from n. In the second, only
// payload bytes are written, thus we can return n unmodified.
// The final scenario pertains to the case where only MAC bytes
// are written, none of which count towards the total.
//
// |-----------Payload------------|----MAC----|
// Straddle: S---------------------------------E--------0
// Payload-only: S------------------------E-----------------0
// MAC-only: S-------E-0
start, end := n+len(b.nextBodySend), len(b.nextBodySend)
switch {

// Straddles payload and MAC bytes, subtract number of MAC bytes
// written from the actual number written.
case start > macSize && end <= macSize:
nn = n - (macSize - end)

// Only payload bytes are written, return n directly.
case start > macSize && end > macSize:
nn = n

// Only MAC bytes are written, return 0 bytes written.
default:
}

if err != nil {
return nn, err
}
}

return nn, nil
}

// ReadMessage attempts to read the next message from the passed io.Reader. In
Expand Down

0 comments on commit 42b081b

Please sign in to comment.