Skip to content

Commit

Permalink
capnp: remove defer in traversal check
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Mar 31, 2016
1 parent 51ff4c5 commit ed82af8
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 97 deletions.
164 changes: 89 additions & 75 deletions capn.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,6 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) {
}

func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) {
defer func(orig *Message) {
if !ptr.IsValid() || s.msg != orig {
return
}
if !s.msg.ReadLimiter().canRead(ptr.limitSize()) {
if err == nil {
ptr, err = Ptr{}, errReadLimit
}
}
}(s.msg)
val := s.readRawPointer(off)
s, off, val, err = s.resolveFarPointer(off, val)
if err != nil {
Expand All @@ -142,76 +132,25 @@ func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) {
// using 32 bit maths as bits or bytes will overflow.
switch val.pointerType() {
case structPointer:
addr, ok := val.offset().resolve(off)
if !ok {
return Ptr{}, errPointerAddress
sp, err := s.readStructPtr(off, val)
if err != nil {
return Ptr{}, err
}
sz := val.structSize()
if !s.regionInBounds(addr, sz.totalSize()) {
return Ptr{}, errPointerAddress
if !s.msg.ReadLimiter().canRead(sp.readSize()) {
return Ptr{}, errReadLimit
}
return Struct{
seg: s,
off: addr,
size: sz,
depthLimit: depthLimit - 1,
}.ToPtr(), nil
sp.depthLimit = depthLimit - 1
return sp.ToPtr(), nil
case listPointer:
addr, ok := val.offset().resolve(off)
if !ok {
return Ptr{}, errPointerAddress
}
lt := val.listType()
lsize, ok := val.totalListSize()
if !ok {
return Ptr{}, errOverflow
}
if !s.regionInBounds(addr, lsize) {
return Ptr{}, errPointerAddress
}
if lt == compositeList {
hdr := s.readRawPointer(addr)
var ok bool
addr, ok = addr.addSize(wordSize)
if !ok {
return Ptr{}, errOverflow
}
if hdr.pointerType() != structPointer {
return Ptr{}, errBadTag
}
sz := hdr.structSize()
n := int32(hdr.offset())
// TODO(light): check that this has the same end address
if tsize, ok := sz.totalSize().times(n); !ok {
return Ptr{}, errOverflow
} else if !s.regionInBounds(addr, tsize) {
return Ptr{}, errPointerAddress
}
return List{
seg: s,
size: sz,
off: addr,
length: n,
flags: isCompositeList,
depthLimit: depthLimit - 1,
}.ToPtr(), nil
lp, err := s.readListPtr(off, val)
if err != nil {
return Ptr{}, err
}
if lt == bit1List {
return List{
seg: s,
off: addr,
length: val.numListElements(),
flags: isBitList,
depthLimit: depthLimit - 1,
}.ToPtr(), nil
if !s.msg.ReadLimiter().canRead(lp.readSize()) {
return Ptr{}, errReadLimit
}
return List{
seg: s,
size: val.elementSize(),
off: addr,
length: val.numListElements(),
depthLimit: depthLimit - 1,
}.ToPtr(), nil
lp.depthLimit = depthLimit - 1
return lp.ToPtr(), nil
case otherPointer:
if val.otherPointerType() != 0 {
return Ptr{}, errOtherPointer
Expand All @@ -226,6 +165,81 @@ func (s *Segment) readPtr(off Address, depthLimit uint) (ptr Ptr, err error) {
}
}

func (s *Segment) readStructPtr(off Address, val rawPointer) (Struct, error) {
addr, ok := val.offset().resolve(off)
if !ok {
return Struct{}, errPointerAddress
}
sz := val.structSize()
if !s.regionInBounds(addr, sz.totalSize()) {
return Struct{}, errPointerAddress
}
return Struct{
seg: s,
off: addr,
size: sz,
}, nil
}

func (s *Segment) readListPtr(off Address, val rawPointer) (List, error) {
addr, ok := val.offset().resolve(off)
if !ok {
return List{}, errPointerAddress
}
lt := val.listType()
lsize, ok := val.totalListSize()
if !ok {
return List{}, errOverflow
}
if !s.regionInBounds(addr, lsize) {
return List{}, errPointerAddress
}
limitSize := lsize
if limitSize == 0 {

}
if lt == compositeList {
hdr := s.readRawPointer(addr)
var ok bool
addr, ok = addr.addSize(wordSize)
if !ok {
return List{}, errOverflow
}
if hdr.pointerType() != structPointer {
return List{}, errBadTag
}
sz := hdr.structSize()
n := int32(hdr.offset())
// TODO(light): check that this has the same end address
if tsize, ok := sz.totalSize().times(n); !ok {
return List{}, errOverflow
} else if !s.regionInBounds(addr, tsize) {
return List{}, errPointerAddress
}
return List{
seg: s,
size: sz,
off: addr,
length: n,
flags: isCompositeList,
}, nil
}
if lt == bit1List {
return List{
seg: s,
off: addr,
length: val.numListElements(),
flags: isBitList,
}, nil
}
return List{
seg: s,
size: val.elementSize(),
off: addr,
length: val.numListElements(),
}, nil
}

func (s *Segment) resolveFarPointer(off Address, val rawPointer) (*Segment, Address, rawPointer, error) {
switch val.pointerType() {
case doubleFarPointer:
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,7 @@ func TestPointerTraverseDefense(t *testing.T) {
for i := 0; i < limit; i++ {
_, err := msg.RootPtr()
if err != nil {
t.Fatal("RootPtr:", err)
t.Fatalf("iteration %d RootPtr: %v", i, err)
}
}

Expand Down
17 changes: 17 additions & 0 deletions list.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ func (p List) HasData() bool {
return sz > 0
}

// readSize returns the list's size for the purposes of read limit
// accounting.
func (p List) readSize() Size {
if p.seg == nil {
return 0
}
e := p.size.totalSize()
if e == 0 {
e = wordSize
}
sz, ok := e.times(p.length)
if !ok {
return maxSize
}
return sz
}

// value returns the equivalent raw list pointer.
func (p List) value(paddr Address) rawPointer {
if p.seg == nil {
Expand Down
21 changes: 0 additions & 21 deletions pointer.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,27 +200,6 @@ func (p Ptr) address() Address {
panic("ptr not a valid struct or list")
}

// limitSize returns the pointer's size for the purposes of read limit
// accounting.
func (p Ptr) limitSize() Size {
switch p.flags.ptrType() {
case structPtrType:
return p.size.totalSize()
case listPtrType:
elem := p.size.totalSize()
if elem == 0 {
elem = wordSize
}
sz, ok := elem.times(int32(p.lenOrCap))
if !ok {
return maxSize
}
return sz
default:
return 0
}
}

// Pointer is deprecated in favor of Ptr.
type Pointer interface {
// Segment returns the segment this pointer points into.
Expand Down
9 changes: 9 additions & 0 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ func (p Struct) HasData() bool {
return !p.size.isZero()
}

// readSize returns the struct's size for the purposes of read limit
// accounting.
func (p Struct) readSize() Size {
if p.seg == nil {
return 0
}
return p.size.totalSize()
}

// value returns a raw struct pointer.
func (p Struct) value(paddr Address) rawPointer {
off := makePointerOffset(paddr, p.off)
Expand Down

0 comments on commit ed82af8

Please sign in to comment.