Skip to content
Permalink
Browse files Browse the repository at this point in the history
Validate container size against total message size
Summary:
For performance reasons we preallocate golang containers with
the size given in the field header.  This allowed an attacker to
trigger very large memory allocations and potentially crash the server
with small messages.  Before creating the golang container confirm
that the message is theoretically large enough to contain a list/map/set
of the given size.

This requires that the binary and compact protocols use transports
that can expose the amount of data waiting to be read.  As a result of
this change you will not be able to do things like use the raw socket
transport or talk to endpoints over the HTTP transport that don't send
a content length header.

Fixes CVE-2019-11939.

Differential Revision: D19595758

fbshipit-source-id: 48bb9dbaf0467cea7a54602f0b07b00a8755c3f9
  • Loading branch information
fiorix authored and facebook-github-bot committed Mar 11, 2020
1 parent c880089 commit 483ed86
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 36 deletions.
16 changes: 14 additions & 2 deletions thrift/lib/go/thrift/binary_protocol.go
Expand Up @@ -329,6 +329,10 @@ func (p *BinaryProtocol) ReadMapBegin() (kType, vType Type, size int, err error)
err = invalidDataLength
return
}
if uint64(size32*2) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
err = invalidDataLength
return
}
size = int(size32)
return kType, vType, size, nil
}
Expand All @@ -353,6 +357,10 @@ func (p *BinaryProtocol) ReadListBegin() (elemType Type, size int, err error) {
err = invalidDataLength
return
}
if uint64(size32) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
err = invalidDataLength
return
}
size = int(size32)

return
Expand All @@ -378,6 +386,10 @@ func (p *BinaryProtocol) ReadSetBegin() (elemType Type, size int, err error) {
err = invalidDataLength
return
}
if uint64(size32) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
err = invalidDataLength
return
}
size = int(size32)
return elemType, size, nil
}
Expand Down Expand Up @@ -456,7 +468,7 @@ func (p *BinaryProtocol) ReadBinary() ([]byte, error) {
if size < 0 {
return nil, invalidDataLength
}
if uint64(size) > p.trans.RemainingBytes() {
if uint64(size) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
return nil, invalidDataLength
}

Expand Down Expand Up @@ -487,7 +499,7 @@ func (p *BinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
if uint64(size) > p.trans.RemainingBytes() {
if uint64(size) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
return "", invalidDataLength
}
var buf []byte
Expand Down
57 changes: 54 additions & 3 deletions thrift/lib/go/thrift/binary_protocol_test.go
Expand Up @@ -17,6 +17,8 @@
package thrift

import (
"bytes"
"fmt"
"strings"
"testing"
"time"
Expand All @@ -31,10 +33,13 @@ func TestSkipUnknownTypeBinaryProtocol(t *testing.T) {
d := NewDeserializer()
f := NewBinaryProtocolFactoryDefault()
d.Protocol = f.GetProtocol(d.Transport)
// skip over a map with invalid key/value type and 1.7B entries
data := []byte("\n\x10\rO\t6\x03\n\n\n\x10\r\n\tslice\x00")
// skip over a map with invalid key/value type and ~550M entries
data := make([]byte, 1100000000)
copy(data[:], []byte("\n\x10\rO\t6\x03\n\n\n\x10\r\n\tsl ce\x00"))
transport, _ := d.Transport.(*MemoryBuffer)
transport.Buffer = bytes.NewBuffer(data)
start := time.Now()
err := d.Read(&m, data)
err := m.Read(d.Protocol)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "unknown type") {
Expand All @@ -45,3 +50,49 @@ func TestSkipUnknownTypeBinaryProtocol(t *testing.T) {
t.Fatalf("It should not take seconds to parse a small message")
}
}

func TestInitialAllocationMapBinaryProtocol(t *testing.T) {
var m MyTestStruct
d := NewDeserializer()
f := NewBinaryProtocolFactoryDefault()
d.Protocol = f.GetProtocol(d.Transport)
// attempts to allocate a map with 1.8B elements for a 20 byte message
data := []byte("\n\x10\rO\t6\x03\n\n\n\x10\r\n\tslice\x00")
err := d.Read(&m, data)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
t.Fatalf("Failed for reason besides Invalid data length")
}
}

func TestInitialAllocationListBinaryProtocol(t *testing.T) {
var m MyTestStruct
d := NewDeserializer()
f := NewBinaryProtocolFactoryDefault()
d.Protocol = f.GetProtocol(d.Transport)
// attempts to allocate a list with 1.8B elements for a 20 byte message
data := []byte("\n\x10\rO\t6\x03\n\n\n\x10\x0f\n\tslice\x00")
err := d.Read(&m, data)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
t.Fatalf("Failed for reason besides Invalid data length")
}
}

func TestInitialAllocationSetBinaryProtocol(t *testing.T) {
var m MyTestStruct
d := NewDeserializer()
f := NewBinaryProtocolFactoryDefault()
d.Protocol = f.GetProtocol(d.Transport)
// attempts to allocate a set with 1.8B elements for a 20 byte message
data := []byte("\n\x12\rO\t6\x03\n\n\n\x10\x0e\n\tslice\x00")
err := d.Read(&m, data)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
fmt.Printf("Got %+v", err)
t.Fatalf("Failed for reason besides Invalid data length")
}
}
14 changes: 11 additions & 3 deletions thrift/lib/go/thrift/compact_protocol.go
Expand Up @@ -412,7 +412,6 @@ func (p *CompactProtocol) ReadFieldBegin() (name string, typeId Type, id int16,
if (t & 0x0f) == STOP {
return "", STOP, 0, nil
}

// mask off the 4 MSB of the type header. it could contain a field id delta.
modifier := int16((t & 0xf0) >> 4)
if modifier == 0 {
Expand Down Expand Up @@ -458,6 +457,10 @@ func (p *CompactProtocol) ReadMapBegin() (keyType Type, valueType Type, size int
err = invalidDataLength
return
}
if uint64(size32*2) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
err = invalidDataLength
return
}
size = int(size32)

keyAndValueType := byte(STOP)
Expand Down Expand Up @@ -496,6 +499,11 @@ func (p *CompactProtocol) ReadListBegin() (elemType Type, size int, err error) {
}
size = int(size2)
}
if uint64(size) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
err = invalidDataLength
return
}

elemType, e := p.getType(compactType(size_and_type))
if e != nil {
err = NewProtocolException(e)
Expand Down Expand Up @@ -596,7 +604,7 @@ func (p *CompactProtocol) ReadString() (value string, err error) {
if length < 0 {
return "", invalidDataLength
}
if uint64(length) > p.trans.RemainingBytes() {
if uint64(length) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
return "", invalidDataLength
}

Expand Down Expand Up @@ -625,7 +633,7 @@ func (p *CompactProtocol) ReadBinary() (value []byte, err error) {
if length < 0 {
return nil, invalidDataLength
}
if uint64(length) > p.trans.RemainingBytes() {
if uint64(length) > p.trans.RemainingBytes() || p.trans.RemainingBytes() == UnknownRemaining {
return nil, invalidDataLength
}

Expand Down
64 changes: 63 additions & 1 deletion thrift/lib/go/thrift/compact_protocol_test.go
Expand Up @@ -18,6 +18,7 @@ package thrift

import (
"bytes"
"strings"
"testing"
)

Expand All @@ -27,7 +28,6 @@ func TestReadWriteCompactProtocol(t *testing.T) {
ReadWriteProtocolParallelTest(t, NewCompactProtocolFactory())
transports := []Transport{
NewMemoryBuffer(),
NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))),
NewFramedTransport(NewMemoryBuffer()),
}
for _, trans := range transports {
Expand All @@ -54,3 +54,65 @@ func TestReadWriteCompactProtocol(t *testing.T) {
trans.Close()
}
}

func TestInitialAllocationMapCompactProtocol(t *testing.T) {
var m MyTestStruct
d := NewDeserializer()
f := NewCompactProtocolFactory()
d.Protocol = f.GetProtocol(d.Transport)
// attempts to allocate a map of 930M elements for a 9 byte message
data := []byte("%0\x88\x8a\x97\xb7\xc4\x030")
err := d.Read(&m, data)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
t.Fatalf("Failed for reason besides Invalid data length")
}
}

func TestInitialAllocationListCompactProtocol(t *testing.T) {
var m MyTestStruct
d := NewDeserializer()
f := NewCompactProtocolFactory()
d.Protocol = f.GetProtocol(d.Transport)
// attempts to allocate a list of 950M elements for an 11 byte message
data := []byte("%0\x98\xfa\xb7\xb7\xc4\xc4\x03\x01a")
err := d.Read(&m, data)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
t.Fatalf("Failed for reason besides Invalid data length")
}
}

func TestInitialAllocationSetCompactProtocol(t *testing.T) {
var m MyTestStruct
d := NewDeserializer()
f := NewCompactProtocolFactory()
d.Protocol = f.GetProtocol(d.Transport)
// attempts to allocate a list of 950M elements for an 11 byte message
data := []byte("%0\xa8\xfa\x97\xb7\xc4\xc4\x03\x01a")
err := d.Read(&m, data)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
t.Fatalf("Failed for reason besides Invalid data length")
}
}

func TestInitialAllocationMapCompactProtocolLimitedR(t *testing.T) {
var m MyTestStruct

// attempts to allocate a map of 930M elements for a 9 byte message
data := []byte("%0\x88\x8a\x97\xb7\xc4\x030")
p := NewCompactProtocol(
NewStreamTransportLimitedR(bytes.NewBuffer(data), len(data)),
)

err := m.Read(p)
if err == nil {
t.Fatalf("Parsed invalid message correctly")
} else if !strings.Contains(err.Error(), "Invalid data length") {
t.Fatalf("Failed for reason besides Invalid data length")
}
}
37 changes: 15 additions & 22 deletions thrift/lib/go/thrift/http_client.go
Expand Up @@ -19,7 +19,6 @@ package thrift
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
Expand All @@ -35,6 +34,7 @@ type HTTPClient struct {
response *http.Response
url *url.URL
requestBuffer *bytes.Buffer
responseBuffer bytes.Buffer
header http.Header
nsecConnectTimeout int64
nsecReadTimeout int64
Expand Down Expand Up @@ -164,20 +164,9 @@ func (p *HTTPClient) IsOpen() bool {
}

func (p *HTTPClient) closeResponse() error {
var err error
if p.response != nil && p.response.Body != nil {
// The docs specify that if keepalive is enabled and the response body is not
// read to completion the connection will never be returned to the pool and
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
io.Copy(ioutil.Discard, p.response.Body)

err = p.response.Body.Close()
}

p.response = nil
return err
p.responseBuffer.Reset()
return nil
}

func (p *HTTPClient) Close() error {
Expand All @@ -192,15 +181,15 @@ func (p *HTTPClient) Read(buf []byte) (int, error) {
if p.response == nil {
return 0, NewTransportException(NOT_OPEN, "Response buffer is empty, no request.")
}
n, err := p.response.Body.Read(buf)
n, err := p.responseBuffer.Read(buf)
if n > 0 && (err == nil || err == io.EOF) {
return n, nil
}
return n, NewTransportExceptionFromError(err)
}

func (p *HTTPClient) ReadByte() (c byte, err error) {
return readByte(p.response.Body)
return readByte(&p.responseBuffer)
}

func (p *HTTPClient) Write(buf []byte) (int, error) {
Expand Down Expand Up @@ -230,6 +219,9 @@ func (p *HTTPClient) Flush() error {
if err != nil {
return NewTransportExceptionFromError(err)
}

defer response.Body.Close()

if response.StatusCode != http.StatusOK {
// Close the response to avoid leaking file descriptors. closeResponse does
// more than just call Close(), so temporarily assign it and reuse the logic.
Expand All @@ -239,15 +231,16 @@ func (p *HTTPClient) Flush() error {
// TODO(pomack) log bad response
return NewTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode))
}

_, err = io.Copy(&p.responseBuffer, response.Body)
if err != nil {
return NewTransportExceptionFromError(err)
}

p.response = response
return nil
}

func (p *HTTPClient) RemainingBytes() (num_bytes uint64) {
len := p.response.ContentLength
if len >= 0 {
return uint64(len)
}

return UnknownRemaining // the truth is, we just don't know unless framed is used
return uint64(p.responseBuffer.Len())
}
5 changes: 0 additions & 5 deletions thrift/lib/go/thrift/protocol_test.go
Expand Up @@ -17,7 +17,6 @@
package thrift

import (
"bytes"
"io"
"io/ioutil"
"math"
Expand Down Expand Up @@ -172,8 +171,6 @@ func ReadWriteProtocolParallelTest(t *testing.T, protocolFactory ProtocolFactory
rConn, wConn := tcpStreamSetupForTest(t)
rdr, writer := io.Pipe()
transports := []TransportFactory{
NewStreamTransportFactory(rdr, writer, false), // use a pipe
NewStreamTransportFactory(rConn, wConn, false), // use tcp over network
NewFramedTransportFactory(NewStreamTransportFactory(rdr, writer, false)), // framed over pipe
NewFramedTransportFactory(NewStreamTransportFactory(rConn, wConn, false)), // framed over tcp
}
Expand Down Expand Up @@ -237,12 +234,10 @@ func ReadWriteProtocolParallelTest(t *testing.T, protocolFactory ProtocolFactory
}

func ReadWriteProtocolTest(t *testing.T, protocolFactory ProtocolFactory) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
l := HTTPClientSetupForTest(t)
defer l.Close()
transports := []TransportFactory{
NewMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewFramedTransportFactory(NewMemoryBufferTransportFactory(1024)),
NewHTTPPostClientTransportFactory("http://" + l.Addr().String()),
}
Expand Down

0 comments on commit 483ed86

Please sign in to comment.