Commit

Merge branch 'security'
This adds size and depth security checks while reading. If this breaks
your application, set the DepthLimit and TraverseLimit as appropriate
after a Decode or Unmarshal.

Fixes #14
zombiezen committed Mar 31, 2016
2 parents c15a22e + ed82af8 commit 8ed242d
Showing 10 changed files with 581 additions and 149 deletions.
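
As the commit message says, applications that trip the new defaults can widen the limits on the decoded message before reading further. A minimal sketch, assuming the Message fields exercised by the tests in this commit (TraverseLimit, DepthLimit), the package's Unmarshal entry point, and its canonical import path; the helper name and the limit values are illustrative:

    package example

    import capnp "zombiezen.com/go/capnproto2"

    // decodeWithRelaxedLimits decodes a message, then widens the new
    // security limits for applications with unusually large or deep data.
    func decodeWithRelaxedLimits(data []byte) (*capnp.Message, error) {
        msg, err := capnp.Unmarshal(data)
        if err != nil {
            return nil, err
        }
        msg.TraverseLimit = 64 << 20 // byte budget charged as pointers are read
        msg.DepthLimit = 128         // how deeply pointers may be nested
        return msg, nil
    }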
2 changes: 1 addition & 1 deletion capability_test.go
@@ -17,7 +17,7 @@ func TestToInterface(t *testing.T) {
     }{
         {nil, Interface{}},
         {Struct{}, Interface{}},
-        {Struct{seg: seg, off: 0}, Interface{}},
+        {Struct{seg: seg, off: 0, depthLimit: maxDepth}, Interface{}},
         {Interface{}, Interface{}},
         {Interface{seg, 42}, Interface{seg, 42}},
     }
183 changes: 109 additions & 74 deletions capn.go
Expand Up @@ -101,9 +101,10 @@ func (s *Segment) root() PointerList {
         return PointerList{}
     }
     return PointerList{List{
-        seg:    s,
-        length: 1,
-        size:   sz,
+        seg:        s,
+        length:     1,
+        size:       sz,
+        depthLimit: s.msg.depthLimit(),
     }}
 }

@@ -114,8 +115,7 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) {
     return s.msg.Segment(id)
 }
 
-func (s *Segment) readPtr(off Address) (Ptr, error) {
-    var err error
+func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) {
     val := s.readRawPointer(off)
     s, off, val, err = s.resolveFarPointer(off, val)
     if err != nil {
@@ -124,77 +124,33 @@ func (s *Segment) readPtr(off Address) (Ptr, error) {
     if val == 0 {
         return Ptr{}, nil
     }
+    if depthLimit == 0 {
+        return Ptr{}, errDepthLimit
+    }
     // Be wary of overflow. Offset is 30 bits signed. List size is 29 bits
     // unsigned. For both of these we need to check in terms of words if
     // using 32 bit maths as bits or bytes will overflow.
     switch val.pointerType() {
     case structPointer:
-        addr, ok := val.offset().resolve(off)
-        if !ok {
-            return Ptr{}, errPointerAddress
+        sp, err := s.readStructPtr(off, val)
+        if err != nil {
+            return Ptr{}, err
         }
-        sz := val.structSize()
-        if !s.regionInBounds(addr, sz.totalSize()) {
-            return Ptr{}, errPointerAddress
+        if !s.msg.ReadLimiter().canRead(sp.readSize()) {
+            return Ptr{}, errReadLimit
         }
-        return Struct{
-            seg:  s,
-            off:  addr,
-            size: sz,
-        }.ToPtr(), nil
+        sp.depthLimit = depthLimit - 1
+        return sp.ToPtr(), nil
     case listPointer:
-        addr, ok := val.offset().resolve(off)
-        if !ok {
-            return Ptr{}, errPointerAddress
-        }
-        lt := val.listType()
-        lsize, ok := val.totalListSize()
-        if !ok {
-            return Ptr{}, errOverflow
-        }
-        if !s.regionInBounds(addr, lsize) {
-            return Ptr{}, errPointerAddress
-        }
-        if lt == compositeList {
-            hdr := s.readRawPointer(addr)
-            var ok bool
-            addr, ok = addr.addSize(wordSize)
-            if !ok {
-                return Ptr{}, errOverflow
-            }
-            if hdr.pointerType() != structPointer {
-                return Ptr{}, errBadTag
-            }
-            sz := hdr.structSize()
-            n := int32(hdr.offset())
-            // TODO(light): check that this has the same end address
-            if tsize, ok := sz.totalSize().times(n); !ok {
-                return Ptr{}, errOverflow
-            } else if !s.regionInBounds(addr, tsize) {
-                return Ptr{}, errPointerAddress
-            }
-            return List{
-                seg:    s,
-                size:   sz,
-                off:    addr,
-                length: n,
-                flags:  isCompositeList,
-            }.ToPtr(), nil
+        lp, err := s.readListPtr(off, val)
+        if err != nil {
+            return Ptr{}, err
         }
-        if lt == bit1List {
-            return List{
-                seg:    s,
-                off:    addr,
-                length: val.numListElements(),
-                flags:  isBitList,
-            }.ToPtr(), nil
+        if !s.msg.ReadLimiter().canRead(lp.readSize()) {
+            return Ptr{}, errReadLimit
         }
-        return List{
-            seg:    s,
-            size:   val.elementSize(),
-            off:    addr,
-            length: val.numListElements(),
-        }.ToPtr(), nil
+        lp.depthLimit = depthLimit - 1
+        return lp.ToPtr(), nil
     case otherPointer:
         if val.otherPointerType() != 0 {
             return Ptr{}, errOtherPointer
@@ -209,6 +165,81 @@ func (s *Segment) readPtr(off Address) (Ptr, error) {
     }
 }
 
+func (s *Segment) readStructPtr(off Address, val rawPointer) (Struct, error) {
+    addr, ok := val.offset().resolve(off)
+    if !ok {
+        return Struct{}, errPointerAddress
+    }
+    sz := val.structSize()
+    if !s.regionInBounds(addr, sz.totalSize()) {
+        return Struct{}, errPointerAddress
+    }
+    return Struct{
+        seg:  s,
+        off:  addr,
+        size: sz,
+    }, nil
+}
+
+func (s *Segment) readListPtr(off Address, val rawPointer) (List, error) {
+    addr, ok := val.offset().resolve(off)
+    if !ok {
+        return List{}, errPointerAddress
+    }
+    lt := val.listType()
+    lsize, ok := val.totalListSize()
+    if !ok {
+        return List{}, errOverflow
+    }
+    if !s.regionInBounds(addr, lsize) {
+        return List{}, errPointerAddress
+    }
+    limitSize := lsize
+    if limitSize == 0 {
+
+    }
+    if lt == compositeList {
+        hdr := s.readRawPointer(addr)
+        var ok bool
+        addr, ok = addr.addSize(wordSize)
+        if !ok {
+            return List{}, errOverflow
+        }
+        if hdr.pointerType() != structPointer {
+            return List{}, errBadTag
+        }
+        sz := hdr.structSize()
+        n := int32(hdr.offset())
+        // TODO(light): check that this has the same end address
+        if tsize, ok := sz.totalSize().times(n); !ok {
+            return List{}, errOverflow
+        } else if !s.regionInBounds(addr, tsize) {
+            return List{}, errPointerAddress
+        }
+        return List{
+            seg:    s,
+            size:   sz,
+            off:    addr,
+            length: n,
+            flags:  isCompositeList,
+        }, nil
+    }
+    if lt == bit1List {
+        return List{
+            seg:    s,
+            off:    addr,
+            length: val.numListElements(),
+            flags:  isBitList,
+        }, nil
+    }
+    return List{
+        seg:    s,
+        size:   val.elementSize(),
+        off:    addr,
+        length: val.numListElements(),
+    }, nil
+}
+
 func (s *Segment) resolveFarPointer(off Address, val rawPointer) (*Segment, Address, rawPointer, error) {
     switch val.pointerType() {
     case doubleFarPointer:
@@ -404,9 +435,10 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) error {
     case structPtrType:
         s := src.Struct()
         dst := Struct{
-            seg:  newSeg,
-            off:  newAddr,
-            size: s.size,
+            seg:        newSeg,
+            off:        newAddr,
+            size:       s.size,
+            depthLimit: maxDepth,
             // clear flags
         }
         key.newval = dst.ToPtr()
@@ -417,11 +449,12 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) error {
     case listPtrType:
         l := src.List()
         dst := List{
-            seg:    newSeg,
-            off:    newAddr,
-            length: l.length,
-            size:   l.size,
-            flags:  l.flags,
+            seg:        newSeg,
+            off:        newAddr,
+            length:     l.length,
+            size:       l.size,
+            flags:      l.flags,
+            depthLimit: maxDepth,
         }
         if dst.flags&isCompositeList != 0 {
             // Copy tag word
@@ -478,6 +511,8 @@ var (
     errBadTag       = errors.New("capnp: invalid tag word")
     errOtherPointer = errors.New("capnp: unknown pointer type")
     errObjectSize   = errors.New("capnp: invalid object size")
+    errReadLimit    = errors.New("capnp: read traversal limit reached")
+    errDepthLimit   = errors.New("capnp: depth limit reached")
 )
 
 var (
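
Both defenses follow the same budget pattern: readPtr charges each object's readSize() against the message's ReadLimiter before handing it out, and stores depthLimit - 1 on the result so that every pointer hop spends one unit of depth until errDepthLimit trips. A toy sketch of the byte-budget half, assuming the real ReadLimiter behaves like a plain decrementing counter (the type and field names here are illustrative, not the package's):

    // readLimiter is an illustrative stand-in for the package's ReadLimiter:
    // a message-wide budget that pointer reads draw down until it runs dry.
    type readLimiter struct {
        remaining uint64 // bytes the reader may still visit
    }

    // canRead reports whether a read of sz bytes fits in the remaining
    // budget and consumes it if so. Once it returns false, readPtr above
    // fails with errReadLimit.
    func (rl *readLimiter) canRead(sz uint64) bool {
        if sz > rl.remaining {
            return false
        }
        rl.remaining -= sz
        return true
    }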
125 changes: 124 additions & 1 deletion integration_test.go
@@ -1867,7 +1867,7 @@ func BenchmarkUnmarshal_Reuse(b *testing.B) {
     b.ResetTimer()
     for i := 0; i < b.N; i++ {
         *ta = testArena(data[r.Intn(len(data))][8:])
-        *msg = capnp.Message{Arena: arena}
+        msg.Reset(arena)
         a, _ := air.ReadRootBenchmarkA(msg)
         unmarshalA(a)
     }
@@ -1889,3 +1889,126 @@ func (ta testArena) Data(id capnp.SegmentID) ([]byte, error) {
 func (ta testArena) Allocate(capnp.Size, map[capnp.SegmentID]*capnp.Segment) (capnp.SegmentID, []byte, error) {
     return 0, nil, errors.New("test arena: can't allocate")
 }
+
+func TestPointerTraverseDefense(t *testing.T) {
+    t.Parallel()
+    const limit = 128
+    msg := &capnp.Message{
+        Arena: capnp.SingleSegment([]byte{
+            0, 0, 0, 0, 1, 0, 0, 0, // root 1-word struct pointer to next word
+            0, 0, 0, 0, 0, 0, 0, 0, // struct's data
+        }),
+        TraverseLimit: limit * 8,
+    }
+
+    for i := 0; i < limit; i++ {
+        _, err := msg.RootPtr()
+        if err != nil {
+            t.Fatalf("iteration %d RootPtr: %v", i, err)
+        }
+    }
+
+    if _, err := msg.RootPtr(); err == nil {
+        t.Fatalf("deref %d did not fail as expected", limit+1)
+    }
+}
+
+func TestPointerDepthDefense(t *testing.T) {
+    t.Parallel()
+    const limit = 64
+    msg := &capnp.Message{
+        Arena: capnp.SingleSegment([]byte{
+            0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
+            0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself
+        }),
+        DepthLimit: limit,
+    }
+    root, err := msg.Root()
+    if err != nil {
+        t.Fatal("Root:", err)
+    }
+
+    curr := capnp.ToStruct(root)
+    if !capnp.IsValid(curr) {
+        t.Fatal("Root is not a struct")
+    }
+    for i := 0; i < limit-1; i++ {
+        p, err := curr.Pointer(0)
+        if err != nil {
+            t.Fatalf("deref %d fail: %v", i+1, err)
+        }
+        if !capnp.IsValid(p) {
+            t.Fatalf("deref %d is invalid", i+1)
+        }
+        curr = capnp.ToStruct(p)
+        if !capnp.IsValid(curr) {
+            t.Fatalf("deref %d is not a struct", i+1)
+        }
+    }
+
+    _, err = curr.Pointer(0)
+    if err == nil {
+        t.Fatalf("deref %d did not fail as expected", limit)
+    }
+}
+
+func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) {
+    t.Parallel()
+    const limit = 63
+    msg := &capnp.Message{
+        Arena: capnp.SingleSegment([]byte{
+            0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
+            0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word)
+            0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word
+        }),
+        DepthLimit: limit,
+    }
+
+    toStruct := func(p capnp.Pointer, err error) (capnp.Struct, error) {
+        if err != nil {
+            return capnp.Struct{}, err
+        }
+        if !capnp.IsValid(p) {
+            return capnp.Struct{}, errors.New("invalid pointer")
+        }
+        s := capnp.ToStruct(p)
+        if !capnp.IsValid(s) {
+            return capnp.Struct{}, errors.New("not a struct")
+        }
+        return s, nil
+    }
+    toList := func(p capnp.Pointer, err error) (capnp.List, error) {
+        if err != nil {
+            return capnp.List{}, err
+        }
+        if !capnp.IsValid(p) {
+            return capnp.List{}, errors.New("invalid pointer")
+        }
+        l := capnp.ToList(p)
+        if !capnp.IsValid(l) {
+            return capnp.List{}, errors.New("not a list")
+        }
+        return l, nil
+    }
+    curr, err := toStruct(msg.Root())
+    if err != nil {
+        t.Fatal("Root:", err)
+    }
+    for i := limit; i > 2; {
+        l, err := toList(curr.Pointer(0))
+        if err != nil {
+            t.Fatalf("deref %d (for list): %v", limit-i+1, err)
+        }
+        i--
+        curr, err = toStruct(capnp.PointerList{List: l}.At(0))
+        if err != nil {
+            t.Fatalf("deref %d (for struct): %v", limit-i+1, err)
+        }
+        i--
+    }
+
+    _, err = curr.Pointer(0)
+    if err == nil {
+        t.Fatalf("deref %d did not fail as expected", limit)
+    }
+}
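
The raw words in these fixtures follow the Cap'n Proto pointer encoding, read little-endian: the low 2 bits give the pointer type (0 for struct), the next 30 bits a signed offset in words, and the upper 32 bits the data-word and pointer-word counts. A small self-contained check of the self-referential word from TestPointerDepthDefense; the bit layout is the wire format's, the variable names are ours:

    package main

    import "fmt"

    func main() {
        // 0xfc 0xff 0xff 0xff 0x00 0x00 0x01 0x00 read as little-endian:
        word := uint64(0x00010000fffffffc)

        ptrType := word & 3                // 0 => struct pointer
        offset := int32(uint32(word)) >> 2 // -1 word: the struct starts at the pointer word itself
        dataWords := uint16(word >> 32)    // 0 data words
        ptrWords := uint16(word >> 48)     // 1 pointer word, so pointer 0 is this same word again

        fmt.Println(ptrType, offset, dataWords, ptrWords) // prints: 0 -1 0 1
    }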
