From 9784faa5c204e317e8625924b6394f3055f8b1e8 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Tue, 22 Dec 2015 09:36:02 -0800 Subject: [PATCH 1/6] capnp: add integration tests for pointer depth --- integration_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/integration_test.go b/integration_test.go index 27afec49..a0ae56f9 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1706,3 +1706,97 @@ func TestVoidUnionSetters(t *testing.T) { t.Errorf("msg.Marshal() =\n%s\n; want:\n%s", hex.Dump(act), hex.Dump(want)) } } + +func TestPointerDepthDefense(t *testing.T) { + t.Parallel() + const limit = 64 + msg := &capnp.Message{Arena: capnp.SingleSegment([]byte{ + 0, 0, 0, 0, 0, 0, 1, 0, + 0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, + })} + root, err := msg.Root() + if err != nil { + t.Fatal("Root:", err) + } + + curr := capnp.ToStruct(root) + if !capnp.IsValid(curr) { + t.Fatal("Root is not a struct") + } + for i := 0; i < limit-1; i++ { + p, err := curr.Pointer(0) + if err != nil { + t.Fatalf("deref %d fail: %v", i+1, err) + } + if !capnp.IsValid(p) { + t.Fatalf("deref %d is invalid", i+1) + } + curr = capnp.ToStruct(p) + if !capnp.IsValid(curr) { + t.Fatalf("deref %d is not a struct", i+1) + } + } + + _, err = curr.Pointer(0) + if err == nil { + t.Fatalf("deref %d did not fail as expected", limit) + } +} + +func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) { + t.Parallel() + const limit = 64 + msg := &capnp.Message{Arena: capnp.SingleSegment([]byte{ + 0, 0, 0, 0, 0, 0, 2, 0, + 5, 0, 0, 0, 14, 0, 0, 0, + 0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, + })} + + toStruct := func(p capnp.Pointer, err error) (capnp.Struct, error) { + if err != nil { + return capnp.Struct{}, err + } + if !capnp.IsValid(p) { + return capnp.Struct{}, errors.New("invalid pointer") + } + s := capnp.ToStruct(p) + if !capnp.IsValid(s) { + return capnp.Struct{}, errors.New("not a struct") + } + return s, nil + } + toList := func(p capnp.Pointer, err error) (capnp.List, error) { + if err != nil { + return capnp.Struct{}, err + } + if !capnp.IsValid(p) { + return capnp.Struct{}, errors.New("invalid pointer") + } + l := capnp.ToList(p) + if !capnp.IsValid(l) { + return capnp.Struct{}, errors.New("not a list") + } + return l, nil + } + curr, err := toStruct(msg.Root()) + if err != nil { + t.Fatal("Root:", err) + } + for i := limit; i > 2; { + l, err := toList(curr.Pointer(0)) + if err != nil { + t.Fatalf("deref %d (for list): %v", limit-i+1, err) + } + i-- + curr, err = toStruct(capnp.PointerList{List: l}.At(0)) + if err != nil { + t.Fatalf("deref %d (for list): %v", limit-i+1, err) + } + i-- + } + + _, err = curr.Pointer(0) + if err == nil { + t.Fatalf("deref %d did not fail as expected", limit) + } +} From 07bc58010fc7a18f3dff460fe4a42da8437bd97d Mon Sep 17 00:00:00 2001 From: Ross Light Date: Tue, 22 Dec 2015 15:56:04 -0800 Subject: [PATCH 2/6] capnp: make pointer depth struct/list test fail correctly --- integration_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/integration_test.go b/integration_test.go index a0ae56f9..7a9c6b32 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1711,8 +1711,8 @@ func TestPointerDepthDefense(t *testing.T) { t.Parallel() const limit = 64 msg := &capnp.Message{Arena: capnp.SingleSegment([]byte{ - 0, 0, 0, 0, 0, 0, 1, 0, - 0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word + 0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself })} root, err := msg.Root() if err != nil { @@ -1747,9 +1747,9 @@ func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) { t.Parallel() const limit = 64 msg := &capnp.Message{Arena: capnp.SingleSegment([]byte{ - 0, 0, 0, 0, 0, 0, 2, 0, - 5, 0, 0, 0, 14, 0, 0, 0, - 0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word + 0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word) + 0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word })} toStruct := func(p capnp.Pointer, err error) (capnp.Struct, error) { @@ -1767,14 +1767,14 @@ func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) { } toList := func(p capnp.Pointer, err error) (capnp.List, error) { if err != nil { - return capnp.Struct{}, err + return capnp.List{}, err } if !capnp.IsValid(p) { - return capnp.Struct{}, errors.New("invalid pointer") + return capnp.List{}, errors.New("invalid pointer") } l := capnp.ToList(p) if !capnp.IsValid(l) { - return capnp.Struct{}, errors.New("not a list") + return capnp.List{}, errors.New("not a list") } return l, nil } @@ -1790,7 +1790,7 @@ func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) { i-- curr, err = toStruct(capnp.PointerList{List: l}.At(0)) if err != nil { - t.Fatalf("deref %d (for list): %v", limit-i+1, err) + t.Fatalf("deref %d (for struct): %v", limit-i+1, err) } i-- } From 2d30d334382b0cd6232ff94b4e304775fd82dd37 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Thu, 31 Mar 2016 09:03:09 -0700 Subject: [PATCH 3/6] capnp: add test for traverse limit --- integration_test.go | 47 ++++++++++++++++++++++++++++++++++++--------- mem.go | 16 +++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/integration_test.go b/integration_test.go index 918d3d0c..c8799ace 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1890,13 +1890,39 @@ func (ta testArena) Allocate(capnp.Size, map[capnp.SegmentID]*capnp.Segment) (ca return 0, nil, errors.New("test arena: can't allocate") } +func TestPointerTraverseDefense(t *testing.T) { + t.Parallel() + const limit = 128 + msg := &capnp.Message{ + Arena: capnp.SingleSegment([]byte{ + 0, 0, 0, 0, 1, 0, 0, 0, // root 1-word struct pointer to next word + 0, 0, 0, 0, 0, 0, 0, 0, // struct's data + }), + TraverseLimit: limit * 8, + } + + for i := 0; i < limit-1; i++ { + _, err := msg.RootPtr() + if err != nil { + t.Fatal("RootPtr:", err) + } + } + + if _, err := msg.RootPtr(); err == nil { + t.Fatalf("deref %d did not fail as expected", limit) + } +} + func TestPointerDepthDefense(t *testing.T) { t.Parallel() const limit = 64 - msg := &capnp.Message{Arena: capnp.SingleSegment([]byte{ - 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word - 0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself - })} + msg := &capnp.Message{ + Arena: capnp.SingleSegment([]byte{ + 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word + 0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself + }), + DepthLimit: limit, + } root, err := msg.Root() if err != nil { t.Fatal("Root:", err) @@ -1929,11 +1955,14 @@ func TestPointerDepthDefense(t *testing.T) { func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) { t.Parallel() const limit = 64 - msg := &capnp.Message{Arena: capnp.SingleSegment([]byte{ - 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word - 0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word) - 0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word - })} + msg := &capnp.Message{ + Arena: capnp.SingleSegment([]byte{ + 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word + 0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word) + 0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word + }), + DepthLimit: limit, + } toStruct := func(p capnp.Pointer, err error) (capnp.Struct, error) { if err != nil { diff --git a/mem.go b/mem.go index 1da2bf7a..b4c65636 100644 --- a/mem.go +++ b/mem.go @@ -10,6 +10,7 @@ import ( // A Message is a tree of Cap'n Proto objects, split into one or more // segments of contiguous memory. The only required field is Arena. +// A Message is safe to read from multiple goroutines. type Message struct { Arena Arena @@ -22,6 +23,21 @@ type Message struct { // more details on the capability table. CapTable []Client + // TraverseLimit limits how many total bytes of data are allowed to be + // traversed while reading. Traversal is counted when a Struct or + // List is obtained. This means that calling a getter for the same + // sub-struct multiple times will cause it to be double-counted. Once + // the traversal limit is reached, pointer accessors will report + // errors. See https://capnproto.org/encoding.html#amplification-attack + // for more details on this security measure. + // + // If not set, this defaults to 64 MiB. + TraverseLimit uint64 + + // DepthLimit limits how deeply-nested a message structure can be. + // If not set, this defaults to 64. + DepthLimit uint + segs map[SegmentID]*Segment // Preallocated first segment. msg is non-nil once initialized. From 877c47e29764422d4c8f1729635ae153baae742e Mon Sep 17 00:00:00 2001 From: Ross Light Date: Thu, 31 Mar 2016 11:00:47 -0700 Subject: [PATCH 4/6] capnp: add ReadLimiter for traversal limit --- capn.go | 14 ++++- integration_test.go | 6 +- mem.go | 37 ++++++++++++ pointer.go | 21 +++++++ readlimit.go | 38 +++++++++++++ readlimit_test.go | 133 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 244 insertions(+), 5 deletions(-) create mode 100644 readlimit.go create mode 100644 readlimit_test.go diff --git a/capn.go b/capn.go index 0b49d673..97eb007d 100644 --- a/capn.go +++ b/capn.go @@ -114,8 +114,17 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) { return s.msg.Segment(id) } -func (s *Segment) readPtr(off Address) (Ptr, error) { - var err error +func (s *Segment) readPtr(off Address) (ptr Ptr, err error) { + defer func(orig *Message) { + if !ptr.IsValid() || s.msg != orig { + return + } + if !s.msg.ReadLimiter().canRead(ptr.limitSize()) { + if err == nil { + ptr, err = Ptr{}, errReadLimit + } + } + }(s.msg) val := s.readRawPointer(off) s, off, val, err = s.resolveFarPointer(off, val) if err != nil { @@ -478,6 +487,7 @@ var ( errBadTag = errors.New("capnp: invalid tag word") errOtherPointer = errors.New("capnp: unknown pointer type") errObjectSize = errors.New("capnp: invalid object size") + errReadLimit = errors.New("capnp: read traversal limit reached") ) var ( diff --git a/integration_test.go b/integration_test.go index c8799ace..77e6ed21 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1867,7 +1867,7 @@ func BenchmarkUnmarshal_Reuse(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { *ta = testArena(data[r.Intn(len(data))][8:]) - *msg = capnp.Message{Arena: arena} + msg.Reset(arena) a, _ := air.ReadRootBenchmarkA(msg) unmarshalA(a) } @@ -1901,7 +1901,7 @@ func TestPointerTraverseDefense(t *testing.T) { TraverseLimit: limit * 8, } - for i := 0; i < limit-1; i++ { + for i := 0; i < limit; i++ { _, err := msg.RootPtr() if err != nil { t.Fatal("RootPtr:", err) @@ -1909,7 +1909,7 @@ func TestPointerTraverseDefense(t *testing.T) { } if _, err := msg.RootPtr(); err == nil { - t.Fatalf("deref %d did not fail as expected", limit) + t.Fatalf("deref %d did not fail as expected", limit+1) } } diff --git a/mem.go b/mem.go index bbc5c27e..b68362db 100644 --- a/mem.go +++ b/mem.go @@ -9,10 +9,21 @@ import ( "zombiezen.com/go/capnproto2/internal/packed" ) +// Default security limits. +const ( + defaultTraverseLimit = 64 << 20 + defaultDepthLimit = 64 +) + // A Message is a tree of Cap'n Proto objects, split into one or more // segments of contiguous memory. The only required field is Arena. // A Message is safe to read from multiple goroutines. type Message struct { + // rlimit must be first so that it is 64-bit aligned. + // See sync/atomic docs. + rlimit ReadLimiter + rlimitInit sync.Once + Arena Arena // CapTable is the indexed list of the clients referenced in the @@ -73,6 +84,19 @@ func NewMessage(arena Arena) (msg *Message, first *Segment, err error) { return msg, first, nil } +// Reset resets a message to use a different arena, allowing a single +// Message to be reused for reading multiple messages. This invalidates +// any existing pointers in the Message, so use with caution. +func (m *Message) Reset(arena Arena) { + m.mu.Lock() + m.Arena = arena + m.CapTable = nil + m.segs = nil + m.firstSeg = Segment{} + m.mu.Unlock() + m.ReadLimiter().Reset(m.TraverseLimit) +} + // Root is deprecated in favor of RootPtr. func (m *Message) Root() (Pointer, error) { p, err := m.RootPtr() @@ -110,6 +134,19 @@ func (m *Message) AddCap(c Client) CapabilityID { return n } +// ReadLimiter returns the message's read limiter. Useful if you want +// to reset the traversal limit while reading. +func (m *Message) ReadLimiter() *ReadLimiter { + m.rlimitInit.Do(func() { + if m.TraverseLimit == 0 { + m.rlimit.limit = defaultTraverseLimit + } else { + m.rlimit.limit = m.TraverseLimit + } + }) + return &m.rlimit +} + // NumSegments returns the number of segments in the message. func (m *Message) NumSegments() int64 { return int64(m.Arena.NumSegments()) diff --git a/pointer.go b/pointer.go index ecf88bd0..ba63018f 100644 --- a/pointer.go +++ b/pointer.go @@ -197,6 +197,27 @@ func (p Ptr) address() Address { panic("ptr not a valid struct or list") } +// limitSize returns the pointer's size for the purposes of read limit +// accounting. +func (p Ptr) limitSize() Size { + switch p.flags.ptrType() { + case structPtrType: + return p.size.totalSize() + case listPtrType: + elem := p.size.totalSize() + if elem == 0 { + elem = wordSize + } + sz, ok := elem.times(int32(p.lenOrCap)) + if !ok { + return maxSize + } + return sz + default: + return 0 + } +} + // Pointer is deprecated in favor of Ptr. type Pointer interface { // Segment returns the segment this pointer points into. diff --git a/readlimit.go b/readlimit.go new file mode 100644 index 00000000..1e1f9808 --- /dev/null +++ b/readlimit.go @@ -0,0 +1,38 @@ +package capnp + +import "sync/atomic" + +// A ReadLimiter tracks the number of bytes read from a message in order +// to avoid amplification attacks as detailed in +// https://capnproto.org/encoding.html#amplification-attack. +// It is safe to use from multiple goroutines. +type ReadLimiter struct { + limit uint64 +} + +// canRead reports whether the amount of bytes can be stored safely. +func (rl *ReadLimiter) canRead(sz Size) bool { + for { + curr := atomic.LoadUint64(&rl.limit) + ok := curr >= uint64(sz) + var new uint64 + if ok { + new = curr - uint64(sz) + } else { + new = 0 + } + if atomic.CompareAndSwapUint64(&rl.limit, curr, new) { + return ok + } + } +} + +// Reset sets the number of bytes allowed to be read. +func (rl *ReadLimiter) Reset(limit uint64) { + atomic.StoreUint64(&rl.limit, limit) +} + +// Unread increases the limit by sz. +func (rl *ReadLimiter) Unread(sz Size) { + atomic.AddUint64(&rl.limit, uint64(sz)) +} diff --git a/readlimit_test.go b/readlimit_test.go new file mode 100644 index 00000000..21b665da --- /dev/null +++ b/readlimit_test.go @@ -0,0 +1,133 @@ +package capnp + +import "testing" + +func TestReadLimiter_canRead(t *testing.T) { + t.Parallel() + type canReadCall struct { + sz Size + ok bool + } + tests := []struct { + name string + init uint64 + calls []canReadCall + }{ + { + name: "can always read zero", + init: 0, + calls: []canReadCall{ + {0, true}, + }, + }, + { + name: "can't read a byte when limit is zero", + init: 0, + calls: []canReadCall{ + {1, false}, + }, + }, + { + name: "reading a word from a high limit is okay", + init: 128, + calls: []canReadCall{ + {8, true}, + }, + }, + { + name: "reading a byte after depleting the limit fails", + init: 8, + calls: []canReadCall{ + {8, true}, + {1, false}, + }, + }, + { + name: "reading a byte after hitting the limit fails", + init: 8, + calls: []canReadCall{ + {8, true}, + {1, false}, + }, + }, + { + name: "reading a byte after hitting the limit in multiple calls fails", + init: 8, + calls: []canReadCall{ + {6, true}, + {2, true}, + {1, false}, + }, + }, + } + for _, test := range tests { + rl := &ReadLimiter{limit: test.init} + for i, c := range test.calls { + ok := rl.canRead(c.sz) + if ok != c.ok { + // TODO(light): show previous calls + t.Errorf("in %s, calls[%d] ok = %t; want %t", test.name, i, ok, c.ok) + } + } + } +} + +func TestReadLimiter_Reset(t *testing.T) { + { + rl := &ReadLimiter{limit: 42} + t.Log(" rl := &ReadLimiter{limit: 42}") + ok := rl.canRead(42) + t.Logf(" rl.canRead(42) -> %t", ok) + rl.Reset(8) + t.Log(" rl.Reset(8)") + if rl.canRead(8) { + t.Log(" rl.canRead(8) -> true") + } else { + t.Error("!! rl.canRead(8) -> false; want true") + } + } + t.Log() + { + rl := &ReadLimiter{limit: 42} + t.Log(" rl := &ReadLimiter{limit: 42}") + ok := rl.canRead(40) + t.Logf(" rl.canRead(40) -> %t", ok) + rl.Reset(8) + t.Log(" rl.Reset(8)") + if rl.canRead(9) { + t.Error("!! rl.canRead(9) -> true; want false") + } else { + t.Log(" rl.canRead(9) -> false") + } + } +} + +func TestReadLimiter_Unread(t *testing.T) { + { + rl := &ReadLimiter{limit: 42} + t.Log(" rl := &ReadLimiter{limit: 42}") + ok := rl.canRead(42) + t.Logf(" rl.canRead(42) -> %t", ok) + rl.Unread(8) + t.Log(" rl.Unread(8)") + if rl.canRead(8) { + t.Log(" rl.canRead(8) -> true") + } else { + t.Error("!! rl.canRead(8) -> false; want true") + } + } + t.Log() + { + rl := &ReadLimiter{limit: 42} + t.Log(" rl := &ReadLimiter{limit: 42}") + ok := rl.canRead(40) + t.Logf(" rl.canRead(40) -> %t", ok) + rl.Unread(8) + t.Log(" rl.Unread(8)") + if rl.canRead(9) { + t.Log(" rl.canRead(9) -> true") + } else { + t.Error("!! rl.canRead(9) -> false; want true") + } + } +} From 51ff4c596af022b0c27262de633f14e9c29c16ea Mon Sep 17 00:00:00 2001 From: Ross Light Date: Thu, 31 Mar 2016 12:15:43 -0700 Subject: [PATCH 5/6] capnp: track pointer depth --- capability_test.go | 2 +- capn.go | 67 ++++++++++++++++++++---------------- integration_test.go | 2 +- list.go | 82 +++++++++++++++++++++++++-------------------- list_test.go | 20 ++++++----- mem.go | 9 +++++ pointer.go | 31 +++++++++-------- struct.go | 29 +++++++++------- 8 files changed, 139 insertions(+), 103 deletions(-) diff --git a/capability_test.go b/capability_test.go index 8b73b47e..2c6ac3a8 100644 --- a/capability_test.go +++ b/capability_test.go @@ -17,7 +17,7 @@ func TestToInterface(t *testing.T) { }{ {nil, Interface{}}, {Struct{}, Interface{}}, - {Struct{seg: seg, off: 0}, Interface{}}, + {Struct{seg: seg, off: 0, depthLimit: maxDepth}, Interface{}}, {Interface{}, Interface{}}, {Interface{seg, 42}, Interface{seg, 42}}, } diff --git a/capn.go b/capn.go index 97eb007d..db351b52 100644 --- a/capn.go +++ b/capn.go @@ -101,9 +101,10 @@ func (s *Segment) root() PointerList { return PointerList{} } return PointerList{List{ - seg: s, - length: 1, - size: sz, + seg: s, + length: 1, + size: sz, + depthLimit: s.msg.depthLimit(), }} } @@ -114,7 +115,7 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) { return s.msg.Segment(id) } -func (s *Segment) readPtr(off Address) (ptr Ptr, err error) { +func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) { defer func(orig *Message) { if !ptr.IsValid() || s.msg != orig { return @@ -133,6 +134,9 @@ func (s *Segment) readPtr(off Address) (ptr Ptr, err error) { if val == 0 { return Ptr{}, nil } + if depthLimit == 0 { + return Ptr{}, errDepthLimit + } // Be wary of overflow. Offset is 30 bits signed. List size is 29 bits // unsigned. For both of these we need to check in terms of words if // using 32 bit maths as bits or bytes will overflow. @@ -147,9 +151,10 @@ func (s *Segment) readPtr(off Address) (ptr Ptr, err error) { return Ptr{}, errPointerAddress } return Struct{ - seg: s, - off: addr, - size: sz, + seg: s, + off: addr, + size: sz, + depthLimit: depthLimit - 1, }.ToPtr(), nil case listPointer: addr, ok := val.offset().resolve(off) @@ -183,26 +188,29 @@ func (s *Segment) readPtr(off Address) (ptr Ptr, err error) { return Ptr{}, errPointerAddress } return List{ - seg: s, - size: sz, - off: addr, - length: n, - flags: isCompositeList, + seg: s, + size: sz, + off: addr, + length: n, + flags: isCompositeList, + depthLimit: depthLimit - 1, }.ToPtr(), nil } if lt == bit1List { return List{ - seg: s, - off: addr, - length: val.numListElements(), - flags: isBitList, + seg: s, + off: addr, + length: val.numListElements(), + flags: isBitList, + depthLimit: depthLimit - 1, }.ToPtr(), nil } return List{ - seg: s, - size: val.elementSize(), - off: addr, - length: val.numListElements(), + seg: s, + size: val.elementSize(), + off: addr, + length: val.numListElements(), + depthLimit: depthLimit - 1, }.ToPtr(), nil case otherPointer: if val.otherPointerType() != 0 { @@ -413,9 +421,10 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) erro case structPtrType: s := src.Struct() dst := Struct{ - seg: newSeg, - off: newAddr, - size: s.size, + seg: newSeg, + off: newAddr, + size: s.size, + depthLimit: maxDepth, // clear flags } key.newval = dst.ToPtr() @@ -426,11 +435,12 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) erro case listPtrType: l := src.List() dst := List{ - seg: newSeg, - off: newAddr, - length: l.length, - size: l.size, - flags: l.flags, + seg: newSeg, + off: newAddr, + length: l.length, + size: l.size, + flags: l.flags, + depthLimit: maxDepth, } if dst.flags&isCompositeList != 0 { // Copy tag word @@ -488,6 +498,7 @@ var ( errOtherPointer = errors.New("capnp: unknown pointer type") errObjectSize = errors.New("capnp: invalid object size") errReadLimit = errors.New("capnp: read traversal limit reached") + errDepthLimit = errors.New("capnp: depth limit reached") ) var ( diff --git a/integration_test.go b/integration_test.go index 77e6ed21..641e6ae1 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1954,7 +1954,7 @@ func TestPointerDepthDefense(t *testing.T) { func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) { t.Parallel() - const limit = 64 + const limit = 63 msg := &capnp.Message{ Arena: capnp.SingleSegment([]byte{ 0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word diff --git a/list.go b/list.go index b247df11..0f6612ca 100644 --- a/list.go +++ b/list.go @@ -7,11 +7,12 @@ import ( // A List is a reference to an array of values. type List struct { - seg *Segment - off Address - length int32 - size ObjectSize - flags listFlags + seg *Segment + off Address + length int32 + size ObjectSize + depthLimit uint + flags listFlags } // newPrimitiveList allocates a new list of primitive values, preferring placement in s. @@ -25,10 +26,11 @@ func newPrimitiveList(s *Segment, sz Size, n int32) (List, error) { return List{}, err } return List{ - seg: s, - off: addr, - length: n, - size: ObjectSize{DataSize: sz}, + seg: s, + off: addr, + length: n, + size: ObjectSize{DataSize: sz}, + depthLimit: maxDepth, }, nil } @@ -50,11 +52,12 @@ func NewCompositeList(s *Segment, sz ObjectSize, n int32) (List, error) { // Add tag word s.writeRawPointer(addr, rawStructPointer(pointerOffset(n), sz)) return List{ - seg: s, - off: addr + Address(wordSize), - length: n, - size: sz, - flags: isCompositeList, + seg: s, + off: addr + Address(wordSize), + length: n, + size: sz, + flags: isCompositeList, + depthLimit: maxDepth, }, nil } @@ -71,11 +74,12 @@ func ToListDefault(p Pointer, def []byte) (List, error) { // ToPtr converts the list to a generic pointer. func (p List) ToPtr() Ptr { return Ptr{ - seg: p.seg, - off: p.off, - lenOrCap: uint32(p.length), - size: p.size, - flags: listPtrFlag(p.flags), + seg: p.seg, + off: p.off, + lenOrCap: uint32(p.length), + size: p.size, + depthLimit: p.depthLimit, + flags: listPtrFlag(p.flags), } } @@ -176,10 +180,11 @@ func (p List) Struct(i int) Struct { } addr, _ := p.elem(i) return Struct{ - seg: p.seg, - off: addr, - size: p.size, - flags: isListMember, + seg: p.seg, + off: addr, + size: p.size, + flags: isListMember, + depthLimit: p.depthLimit - 1, } } @@ -201,10 +206,11 @@ func NewBitList(s *Segment, n int32) (BitList, error) { return BitList{}, err } return BitList{List{ - seg: s, - off: addr, - length: n, - flags: isBitList, + seg: s, + off: addr, + length: n, + flags: isBitList, + depthLimit: maxDepth, }}, nil } @@ -246,10 +252,11 @@ func NewPointerList(s *Segment, n int32) (PointerList, error) { return PointerList{}, err } return PointerList{List{ - seg: s, - off: addr, - length: n, - size: ObjectSize{PointerCount: 1}, + seg: s, + off: addr, + length: n, + size: ObjectSize{PointerCount: 1}, + depthLimit: maxDepth, }}, nil } @@ -262,7 +269,7 @@ func (p PointerList) At(i int) (Pointer, error) { // PtrAt returns the i'th pointer in the list. func (p PointerList) PtrAt(i int) (Ptr, error) { addr, _ := p.elem(i) - return p.seg.readPtr(addr) + return p.seg.readPtr(addr, p.depthLimit) } // Set is deprecated in favor of SetPtr. @@ -291,7 +298,7 @@ func NewTextList(s *Segment, n int32) (TextList, error) { // At returns the i'th string in the list. func (l TextList) At(i int) (string, error) { addr, _ := l.elem(i) - p, err := l.seg.readPtr(addr) + p, err := l.seg.readPtr(addr, l.depthLimit) if err != nil { return "", err } @@ -302,7 +309,7 @@ func (l TextList) At(i int) (string, error) { // The underlying array of the slice is the segment data. func (l TextList) BytesAt(i int) ([]byte, error) { addr, _ := l.elem(i) - p, err := l.seg.readPtr(addr) + p, err := l.seg.readPtr(addr, l.depthLimit) if err != nil { return nil, err } @@ -334,7 +341,7 @@ func NewDataList(s *Segment, n int32) (DataList, error) { // At returns the i'th data in the list. func (l DataList) At(i int) ([]byte, error) { addr, _ := l.elem(i) - p, err := l.seg.readPtr(addr) + p, err := l.seg.readPtr(addr, l.depthLimit) if err != nil { return nil, err } @@ -358,8 +365,9 @@ type VoidList struct{ List } // s is only used for Segment()'s return value. func NewVoidList(s *Segment, n int32) VoidList { return VoidList{List{ - seg: s, - length: n, + seg: s, + length: n, + depthLimit: maxDepth, }} } diff --git a/list_test.go b/list_test.go index afe7a1fa..5ad8c846 100644 --- a/list_test.go +++ b/list_test.go @@ -20,20 +20,22 @@ func TestToListDefault(t *testing.T) { }{ {nil, nil, List{}}, {Struct{}, nil, List{}}, - {Struct{seg: seg, off: 0}, nil, List{}}, + {Struct{seg: seg, off: 0, depthLimit: maxDepth}, nil, List{}}, {List{}, nil, List{}}, { ptr: List{ - seg: seg, - off: 8, - length: 1, - size: ObjectSize{DataSize: 8}, + seg: seg, + off: 8, + length: 1, + size: ObjectSize{DataSize: 8}, + depthLimit: maxDepth, }, list: List{ - seg: seg, - off: 8, - length: 1, - size: ObjectSize{DataSize: 8}, + seg: seg, + off: 8, + length: 1, + size: ObjectSize{DataSize: 8}, + depthLimit: maxDepth, }, }, } diff --git a/mem.go b/mem.go index b68362db..ad959ac6 100644 --- a/mem.go +++ b/mem.go @@ -15,6 +15,8 @@ const ( defaultDepthLimit = 64 ) +const maxDepth = ^uint(0) + // A Message is a tree of Cap'n Proto objects, split into one or more // segments of contiguous memory. The only required field is Arena. // A Message is safe to read from multiple goroutines. @@ -147,6 +149,13 @@ func (m *Message) ReadLimiter() *ReadLimiter { return &m.rlimit } +func (m *Message) depthLimit() uint { + if m.DepthLimit != 0 { + return m.DepthLimit + } + return defaultDepthLimit +} + // NumSegments returns the number of segments in the message. func (m *Message) NumSegments() int64 { return int64(m.Arena.NumSegments()) diff --git a/pointer.go b/pointer.go index ba63018f..45cf8cb4 100644 --- a/pointer.go +++ b/pointer.go @@ -2,11 +2,12 @@ package capnp // A Ptr is a reference to a Cap'n Proto struct, list, or interface. type Ptr struct { - seg *Segment - off Address - lenOrCap uint32 - size ObjectSize - flags ptrFlags + seg *Segment + off Address + lenOrCap uint32 + size ObjectSize + depthLimit uint + flags ptrFlags } func toPtr(p Pointer) Ptr { @@ -31,10 +32,11 @@ func (p Ptr) Struct() Struct { return Struct{} } return Struct{ - seg: p.seg, - off: p.off, - size: p.size, - flags: p.flags.structFlags(), + seg: p.seg, + off: p.off, + size: p.size, + flags: p.flags.structFlags(), + depthLimit: p.depthLimit, } } @@ -62,11 +64,12 @@ func (p Ptr) List() List { return List{} } return List{ - seg: p.seg, - off: p.off, - length: int32(p.lenOrCap), - size: p.size, - flags: p.flags.listFlags(), + seg: p.seg, + off: p.off, + length: int32(p.lenOrCap), + size: p.size, + flags: p.flags.listFlags(), + depthLimit: p.depthLimit, } } diff --git a/struct.go b/struct.go index 43e9014a..4bd7a66e 100644 --- a/struct.go +++ b/struct.go @@ -2,10 +2,11 @@ package capnp // Struct is a pointer to a struct. type Struct struct { - seg *Segment - off Address - size ObjectSize - flags structFlags + seg *Segment + off Address + size ObjectSize + depthLimit uint + flags structFlags } // NewStruct creates a new struct, preferring placement in s. @@ -19,9 +20,10 @@ func NewStruct(s *Segment, sz ObjectSize) (Struct, error) { return Struct{}, err } return Struct{ - seg: seg, - off: addr, - size: sz, + seg: seg, + off: addr, + size: sz, + depthLimit: maxDepth, }, nil } @@ -58,10 +60,11 @@ func ToStructDefault(p Pointer, def []byte) (Struct, error) { // ToPtr converts the struct to a generic pointer. func (p Struct) ToPtr() Ptr { return Ptr{ - seg: p.seg, - off: p.off, - size: p.size, - flags: structPtrFlag(p.flags), + seg: p.seg, + off: p.off, + size: p.size, + depthLimit: p.depthLimit, + flags: structPtrFlag(p.flags), } } @@ -106,7 +109,7 @@ func (p Struct) Ptr(i uint16) (Ptr, error) { if p.seg == nil || i >= p.size.PointerCount { return Ptr{}, nil } - return p.seg.readPtr(p.pointerAddress(i)) + return p.seg.readPtr(p.pointerAddress(i), p.depthLimit) } // SetPointer is deprecated in favor of SetPtr. @@ -283,7 +286,7 @@ func copyStruct(cc copyContext, dst, src Struct) error { for j := uint16(0); j < numSrcPtrs && j < numDstPtrs; j++ { srcAddr, _ := srcPtrSect.element(int32(j), wordSize) dstAddr, _ := dstPtrSect.element(int32(j), wordSize) - m, err := src.seg.readPtr(srcAddr) + m, err := src.seg.readPtr(srcAddr, maxDepth) // copy already handles depth-limiting if err != nil { return err } From ed82af8991044aeb1245a9d0d82427867454f285 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Thu, 31 Mar 2016 12:49:50 -0700 Subject: [PATCH 6/6] capnp: remove defer in traversal check --- capn.go | 164 ++++++++++++++++++++++++-------------------- integration_test.go | 2 +- list.go | 17 +++++ pointer.go | 21 ------ struct.go | 9 +++ 5 files changed, 116 insertions(+), 97 deletions(-) diff --git a/capn.go b/capn.go index db351b52..7e498497 100644 --- a/capn.go +++ b/capn.go @@ -116,16 +116,6 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) { } func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) { - defer func(orig *Message) { - if !ptr.IsValid() || s.msg != orig { - return - } - if !s.msg.ReadLimiter().canRead(ptr.limitSize()) { - if err == nil { - ptr, err = Ptr{}, errReadLimit - } - } - }(s.msg) val := s.readRawPointer(off) s, off, val, err = s.resolveFarPointer(off, val) if err != nil { @@ -142,76 +132,25 @@ func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) { // using 32 bit maths as bits or bytes will overflow. switch val.pointerType() { case structPointer: - addr, ok := val.offset().resolve(off) - if !ok { - return Ptr{}, errPointerAddress + sp, err := s.readStructPtr(off, val) + if err != nil { + return Ptr{}, err } - sz := val.structSize() - if !s.regionInBounds(addr, sz.totalSize()) { - return Ptr{}, errPointerAddress + if !s.msg.ReadLimiter().canRead(sp.readSize()) { + return Ptr{}, errReadLimit } - return Struct{ - seg: s, - off: addr, - size: sz, - depthLimit: depthLimit - 1, - }.ToPtr(), nil + sp.depthLimit = depthLimit - 1 + return sp.ToPtr(), nil case listPointer: - addr, ok := val.offset().resolve(off) - if !ok { - return Ptr{}, errPointerAddress - } - lt := val.listType() - lsize, ok := val.totalListSize() - if !ok { - return Ptr{}, errOverflow - } - if !s.regionInBounds(addr, lsize) { - return Ptr{}, errPointerAddress - } - if lt == compositeList { - hdr := s.readRawPointer(addr) - var ok bool - addr, ok = addr.addSize(wordSize) - if !ok { - return Ptr{}, errOverflow - } - if hdr.pointerType() != structPointer { - return Ptr{}, errBadTag - } - sz := hdr.structSize() - n := int32(hdr.offset()) - // TODO(light): check that this has the same end address - if tsize, ok := sz.totalSize().times(n); !ok { - return Ptr{}, errOverflow - } else if !s.regionInBounds(addr, tsize) { - return Ptr{}, errPointerAddress - } - return List{ - seg: s, - size: sz, - off: addr, - length: n, - flags: isCompositeList, - depthLimit: depthLimit - 1, - }.ToPtr(), nil + lp, err := s.readListPtr(off, val) + if err != nil { + return Ptr{}, err } - if lt == bit1List { - return List{ - seg: s, - off: addr, - length: val.numListElements(), - flags: isBitList, - depthLimit: depthLimit - 1, - }.ToPtr(), nil + if !s.msg.ReadLimiter().canRead(lp.readSize()) { + return Ptr{}, errReadLimit } - return List{ - seg: s, - size: val.elementSize(), - off: addr, - length: val.numListElements(), - depthLimit: depthLimit - 1, - }.ToPtr(), nil + lp.depthLimit = depthLimit - 1 + return lp.ToPtr(), nil case otherPointer: if val.otherPointerType() != 0 { return Ptr{}, errOtherPointer @@ -226,6 +165,81 @@ func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) { } } +func (s *Segment) readStructPtr(off Address, val rawPointer) (Struct, error) { + addr, ok := val.offset().resolve(off) + if !ok { + return Struct{}, errPointerAddress + } + sz := val.structSize() + if !s.regionInBounds(addr, sz.totalSize()) { + return Struct{}, errPointerAddress + } + return Struct{ + seg: s, + off: addr, + size: sz, + }, nil +} + +func (s *Segment) readListPtr(off Address, val rawPointer) (List, error) { + addr, ok := val.offset().resolve(off) + if !ok { + return List{}, errPointerAddress + } + lt := val.listType() + lsize, ok := val.totalListSize() + if !ok { + return List{}, errOverflow + } + if !s.regionInBounds(addr, lsize) { + return List{}, errPointerAddress + } + limitSize := lsize + if limitSize == 0 { + + } + if lt == compositeList { + hdr := s.readRawPointer(addr) + var ok bool + addr, ok = addr.addSize(wordSize) + if !ok { + return List{}, errOverflow + } + if hdr.pointerType() != structPointer { + return List{}, errBadTag + } + sz := hdr.structSize() + n := int32(hdr.offset()) + // TODO(light): check that this has the same end address + if tsize, ok := sz.totalSize().times(n); !ok { + return List{}, errOverflow + } else if !s.regionInBounds(addr, tsize) { + return List{}, errPointerAddress + } + return List{ + seg: s, + size: sz, + off: addr, + length: n, + flags: isCompositeList, + }, nil + } + if lt == bit1List { + return List{ + seg: s, + off: addr, + length: val.numListElements(), + flags: isBitList, + }, nil + } + return List{ + seg: s, + size: val.elementSize(), + off: addr, + length: val.numListElements(), + }, nil +} + func (s *Segment) resolveFarPointer(off Address, val rawPointer) (*Segment, Address, rawPointer, error) { switch val.pointerType() { case doubleFarPointer: diff --git a/integration_test.go b/integration_test.go index 641e6ae1..3a2585fb 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1904,7 +1904,7 @@ func TestPointerTraverseDefense(t *testing.T) { for i := 0; i < limit; i++ { _, err := msg.RootPtr() if err != nil { - t.Fatal("RootPtr:", err) + t.Fatalf("iteration %d RootPtr: %v", i, err) } } diff --git a/list.go b/list.go index 0f6612ca..23b0c15a 100644 --- a/list.go +++ b/list.go @@ -102,6 +102,23 @@ func (p List) HasData() bool { return sz > 0 } +// readSize returns the list's size for the purposes of read limit +// accounting. +func (p List) readSize() Size { + if p.seg == nil { + return 0 + } + e := p.size.totalSize() + if e == 0 { + e = wordSize + } + sz, ok := e.times(p.length) + if !ok { + return maxSize + } + return sz +} + // value returns the equivalent raw list pointer. func (p List) value(paddr Address) rawPointer { if p.seg == nil { diff --git a/pointer.go b/pointer.go index 45cf8cb4..e6232d84 100644 --- a/pointer.go +++ b/pointer.go @@ -200,27 +200,6 @@ func (p Ptr) address() Address { panic("ptr not a valid struct or list") } -// limitSize returns the pointer's size for the purposes of read limit -// accounting. -func (p Ptr) limitSize() Size { - switch p.flags.ptrType() { - case structPtrType: - return p.size.totalSize() - case listPtrType: - elem := p.size.totalSize() - if elem == 0 { - elem = wordSize - } - sz, ok := elem.times(int32(p.lenOrCap)) - if !ok { - return maxSize - } - return sz - default: - return 0 - } -} - // Pointer is deprecated in favor of Ptr. type Pointer interface { // Segment returns the segment this pointer points into. diff --git a/struct.go b/struct.go index 4bd7a66e..96d85142 100644 --- a/struct.go +++ b/struct.go @@ -88,6 +88,15 @@ func (p Struct) HasData() bool { return !p.size.isZero() } +// readSize returns the struct's size for the purposes of read limit +// accounting. +func (p Struct) readSize() Size { + if p.seg == nil { + return 0 + } + return p.size.totalSize() +} + // value returns a raw struct pointer. func (p Struct) value(paddr Address) rawPointer { off := makePointerOffset(paddr, p.off)