Skip to content

Commit

Permalink
capnp: track pointer depth
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Mar 31, 2016
1 parent 877c47e commit 51ff4c5
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 103 deletions.
2 changes: 1 addition & 1 deletion capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestToInterface(t *testing.T) {
}{
{nil, Interface{}},
{Struct{}, Interface{}},
{Struct{seg: seg, off: 0}, Interface{}},
{Struct{seg: seg, off: 0, depthLimit: maxDepth}, Interface{}},
{Interface{}, Interface{}},
{Interface{seg, 42}, Interface{seg, 42}},
}
Expand Down
67 changes: 39 additions & 28 deletions capn.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ func (s *Segment) root() PointerList {
return PointerList{}
}
return PointerList{List{
seg: s,
length: 1,
size: sz,
seg: s,
length: 1,
size: sz,
depthLimit: s.msg.depthLimit(),
}}
}

Expand All @@ -114,7 +115,7 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) {
return s.msg.Segment(id)
}

func (s *Segment) readPtr(off Address) (ptr Ptr, err error) {
func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) {
defer func(orig *Message) {
if !ptr.IsValid() || s.msg != orig {
return
Expand All @@ -133,6 +134,9 @@ func (s *Segment) readPtr(off Address) (ptr Ptr, err error) {
if val == 0 {
return Ptr{}, nil
}
if depthLimit == 0 {
return Ptr{}, errDepthLimit
}
// Be wary of overflow. Offset is 30 bits signed. List size is 29 bits
// unsigned. For both of these we need to check in terms of words if
// using 32 bit maths as bits or bytes will overflow.
Expand All @@ -147,9 +151,10 @@ func (s *Segment) readPtr(off Address) (ptr Ptr, err error) {
return Ptr{}, errPointerAddress
}
return Struct{
seg: s,
off: addr,
size: sz,
seg: s,
off: addr,
size: sz,
depthLimit: depthLimit - 1,
}.ToPtr(), nil
case listPointer:
addr, ok := val.offset().resolve(off)
Expand Down Expand Up @@ -183,26 +188,29 @@ func (s *Segment) readPtr(off Address) (ptr Ptr, err error) {
return Ptr{}, errPointerAddress
}
return List{
seg: s,
size: sz,
off: addr,
length: n,
flags: isCompositeList,
seg: s,
size: sz,
off: addr,
length: n,
flags: isCompositeList,
depthLimit: depthLimit - 1,
}.ToPtr(), nil
}
if lt == bit1List {
return List{
seg: s,
off: addr,
length: val.numListElements(),
flags: isBitList,
seg: s,
off: addr,
length: val.numListElements(),
flags: isBitList,
depthLimit: depthLimit - 1,
}.ToPtr(), nil
}
return List{
seg: s,
size: val.elementSize(),
off: addr,
length: val.numListElements(),
seg: s,
size: val.elementSize(),
off: addr,
length: val.numListElements(),
depthLimit: depthLimit - 1,
}.ToPtr(), nil
case otherPointer:
if val.otherPointerType() != 0 {
Expand Down Expand Up @@ -413,9 +421,10 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) erro
case structPtrType:
s := src.Struct()
dst := Struct{
seg: newSeg,
off: newAddr,
size: s.size,
seg: newSeg,
off: newAddr,
size: s.size,
depthLimit: maxDepth,
// clear flags
}
key.newval = dst.ToPtr()
Expand All @@ -426,11 +435,12 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) erro
case listPtrType:
l := src.List()
dst := List{
seg: newSeg,
off: newAddr,
length: l.length,
size: l.size,
flags: l.flags,
seg: newSeg,
off: newAddr,
length: l.length,
size: l.size,
flags: l.flags,
depthLimit: maxDepth,
}
if dst.flags&isCompositeList != 0 {
// Copy tag word
Expand Down Expand Up @@ -488,6 +498,7 @@ var (
errOtherPointer = errors.New("capnp: unknown pointer type")
errObjectSize = errors.New("capnp: invalid object size")
errReadLimit = errors.New("capnp: read traversal limit reached")
errDepthLimit = errors.New("capnp: depth limit reached")
)

var (
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1954,7 +1954,7 @@ func TestPointerDepthDefense(t *testing.T) {

func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) {
t.Parallel()
const limit = 64
const limit = 63
msg := &capnp.Message{
Arena: capnp.SingleSegment([]byte{
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
Expand Down
82 changes: 45 additions & 37 deletions list.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import (

// A List is a reference to an array of values.
type List struct {
seg *Segment
off Address
length int32
size ObjectSize
flags listFlags
seg *Segment
off Address
length int32
size ObjectSize
depthLimit uint
flags listFlags
}

// newPrimitiveList allocates a new list of primitive values, preferring placement in s.
Expand All @@ -25,10 +26,11 @@ func newPrimitiveList(s *Segment, sz Size, n int32) (List, error) {
return List{}, err
}
return List{
seg: s,
off: addr,
length: n,
size: ObjectSize{DataSize: sz},
seg: s,
off: addr,
length: n,
size: ObjectSize{DataSize: sz},
depthLimit: maxDepth,
}, nil
}

Expand All @@ -50,11 +52,12 @@ func NewCompositeList(s *Segment, sz ObjectSize, n int32) (List, error) {
// Add tag word
s.writeRawPointer(addr, rawStructPointer(pointerOffset(n), sz))
return List{
seg: s,
off: addr + Address(wordSize),
length: n,
size: sz,
flags: isCompositeList,
seg: s,
off: addr + Address(wordSize),
length: n,
size: sz,
flags: isCompositeList,
depthLimit: maxDepth,
}, nil
}

Expand All @@ -71,11 +74,12 @@ func ToListDefault(p Pointer, def []byte) (List, error) {
// ToPtr converts the list to a generic pointer.
func (p List) ToPtr() Ptr {
return Ptr{
seg: p.seg,
off: p.off,
lenOrCap: uint32(p.length),
size: p.size,
flags: listPtrFlag(p.flags),
seg: p.seg,
off: p.off,
lenOrCap: uint32(p.length),
size: p.size,
depthLimit: p.depthLimit,
flags: listPtrFlag(p.flags),
}
}

Expand Down Expand Up @@ -176,10 +180,11 @@ func (p List) Struct(i int) Struct {
}
addr, _ := p.elem(i)
return Struct{
seg: p.seg,
off: addr,
size: p.size,
flags: isListMember,
seg: p.seg,
off: addr,
size: p.size,
flags: isListMember,
depthLimit: p.depthLimit - 1,
}
}

Expand All @@ -201,10 +206,11 @@ func NewBitList(s *Segment, n int32) (BitList, error) {
return BitList{}, err
}
return BitList{List{
seg: s,
off: addr,
length: n,
flags: isBitList,
seg: s,
off: addr,
length: n,
flags: isBitList,
depthLimit: maxDepth,
}}, nil
}

Expand Down Expand Up @@ -246,10 +252,11 @@ func NewPointerList(s *Segment, n int32) (PointerList, error) {
return PointerList{}, err
}
return PointerList{List{
seg: s,
off: addr,
length: n,
size: ObjectSize{PointerCount: 1},
seg: s,
off: addr,
length: n,
size: ObjectSize{PointerCount: 1},
depthLimit: maxDepth,
}}, nil
}

Expand All @@ -262,7 +269,7 @@ func (p PointerList) At(i int) (Pointer, error) {
// PtrAt returns the i'th pointer in the list.
func (p PointerList) PtrAt(i int) (Ptr, error) {
addr, _ := p.elem(i)
return p.seg.readPtr(addr)
return p.seg.readPtr(addr, p.depthLimit)
}

// Set is deprecated in favor of SetPtr.
Expand Down Expand Up @@ -291,7 +298,7 @@ func NewTextList(s *Segment, n int32) (TextList, error) {
// At returns the i'th string in the list.
func (l TextList) At(i int) (string, error) {
addr, _ := l.elem(i)
p, err := l.seg.readPtr(addr)
p, err := l.seg.readPtr(addr, l.depthLimit)
if err != nil {
return "", err
}
Expand All @@ -302,7 +309,7 @@ func (l TextList) At(i int) (string, error) {
// The underlying array of the slice is the segment data.
func (l TextList) BytesAt(i int) ([]byte, error) {
addr, _ := l.elem(i)
p, err := l.seg.readPtr(addr)
p, err := l.seg.readPtr(addr, l.depthLimit)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -334,7 +341,7 @@ func NewDataList(s *Segment, n int32) (DataList, error) {
// At returns the i'th data in the list.
func (l DataList) At(i int) ([]byte, error) {
addr, _ := l.elem(i)
p, err := l.seg.readPtr(addr)
p, err := l.seg.readPtr(addr, l.depthLimit)
if err != nil {
return nil, err
}
Expand All @@ -358,8 +365,9 @@ type VoidList struct{ List }
// s is only used for Segment()'s return value.
func NewVoidList(s *Segment, n int32) VoidList {
return VoidList{List{
seg: s,
length: n,
seg: s,
length: n,
depthLimit: maxDepth,
}}
}

Expand Down
20 changes: 11 additions & 9 deletions list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ func TestToListDefault(t *testing.T) {
}{
{nil, nil, List{}},
{Struct{}, nil, List{}},
{Struct{seg: seg, off: 0}, nil, List{}},
{Struct{seg: seg, off: 0, depthLimit: maxDepth}, nil, List{}},
{List{}, nil, List{}},
{
ptr: List{
seg: seg,
off: 8,
length: 1,
size: ObjectSize{DataSize: 8},
seg: seg,
off: 8,
length: 1,
size: ObjectSize{DataSize: 8},
depthLimit: maxDepth,
},
list: List{
seg: seg,
off: 8,
length: 1,
size: ObjectSize{DataSize: 8},
seg: seg,
off: 8,
length: 1,
size: ObjectSize{DataSize: 8},
depthLimit: maxDepth,
},
},
}
Expand Down
9 changes: 9 additions & 0 deletions mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ const (
defaultDepthLimit = 64
)

const maxDepth = ^uint(0)

// A Message is a tree of Cap'n Proto objects, split into one or more
// segments of contiguous memory. The only required field is Arena.
// A Message is safe to read from multiple goroutines.
Expand Down Expand Up @@ -147,6 +149,13 @@ func (m *Message) ReadLimiter() *ReadLimiter {
return &m.rlimit
}

func (m *Message) depthLimit() uint {
if m.DepthLimit != 0 {
return m.DepthLimit
}
return defaultDepthLimit
}

// NumSegments returns the number of segments in the message.
func (m *Message) NumSegments() int64 {
return int64(m.Arena.NumSegments())
Expand Down
Loading

0 comments on commit 51ff4c5

Please sign in to comment.