Skip to content

Commit

Permalink
capnp: introduce Ptr type
Browse files Browse the repository at this point in the history
See https://github.com/zombiezen/go-capnproto2/wiki/New-Ptr-Type for
background and migration steps.
  • Loading branch information
zombiezen committed Mar 25, 2016
2 parents b5a265f + 22732bb commit c78274f
Show file tree
Hide file tree
Showing 22 changed files with 1,503 additions and 1,282 deletions.
51 changes: 35 additions & 16 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ func NewInterface(s *Segment, cap CapabilityID) Interface {
}
}

// ToInterface attempts to convert p into an interface. If p is not a
// valid interface, then ToInterface returns an invalid Interface.
// ToInterface is deprecated in favor of Ptr.Interface.
func ToInterface(p Pointer) Interface {
if !IsValid(p) {
return Interface{}
Expand All @@ -35,11 +34,25 @@ func ToInterface(p Pointer) Interface {
return i
}

// ToPtr converts the interface to a generic pointer.
func (p Interface) ToPtr() Ptr {
return Ptr{
seg: p.seg,
lenOrCap: uint32(p.cap),
flags: interfacePtrFlag,
}
}

// Segment returns the segment this pointer came from.
func (i Interface) Segment() *Segment {
return i.seg
}

// IsValid returns whether the interface is valid.
func (i Interface) IsValid() bool {
return i.seg != nil
}

// HasData is always true.
func (i Interface) HasData() bool {
return true
Expand Down Expand Up @@ -271,11 +284,11 @@ func (p *Pipeline) Struct() (Struct, error) {
if err != nil {
return Struct{}, err
}
ptr, err := Transform(s, p.Transform())
ptr, err := TransformPtr(s.ToPtr(), p.Transform())
if err != nil {
return Struct{}, err
}
return ToStruct(ptr), nil
return ptr.Struct(), nil
}

// Client returns the client version of p.
Expand Down Expand Up @@ -370,31 +383,37 @@ func (m *Method) String() string {
return string(buf)
}

// Transform applies a sequence of pipeline operations to a pointer
// and returns the result.
// Transform is deprecated in favor of TransformPtr.
func Transform(p Pointer, transform []PipelineOp) (Pointer, error) {
pp, err := TransformPtr(toPtr(p), transform)
return pp.toPointer(), err
}

// TransformPtr applies a sequence of pipeline operations to a pointer
// and returns the result.
func TransformPtr(p Ptr, transform []PipelineOp) (Ptr, error) {
n := len(transform)
if n == 0 {
return p, nil
}
s := ToStruct(p)
s := p.Struct()
for _, op := range transform[:n-1] {
field, err := s.Pointer(op.Field)
field, err := s.Ptr(op.Field)
if err != nil {
return nil, err
return Ptr{}, err
}
s, err = ToStructDefault(field, op.DefaultValue)
s, err = field.StructDefault(op.DefaultValue)
if err != nil {
return nil, err
return Ptr{}, err
}
}
op := transform[n-1]
p, err := s.Pointer(op.Field)
p, err := s.Ptr(op.Field)
if err != nil {
return nil, err
return Ptr{}, err
}
if op.DefaultValue != nil {
p, err = PointerDefault(p, op.DefaultValue)
p, err = p.Default(op.DefaultValue)
}
return p, err
}
Expand All @@ -413,11 +432,11 @@ func (ans immediateAnswer) Struct() (Struct, error) {
}

func (ans immediateAnswer) findClient(transform []PipelineOp) Client {
p, err := Transform(ans.s, transform)
p, err := TransformPtr(ans.s.ToPtr(), transform)
if err != nil {
return ErrorClient(err)
}
return ToInterface(p).Client()
return p.Interface().Client()
}

func (ans immediateAnswer) PipelineCall(transform []PipelineOp, call *Call) Answer {
Expand Down
125 changes: 65 additions & 60 deletions capn.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) {
return s.msg.Segment(id)
}

func (s *Segment) readPtr(off Address) (Pointer, error) {
func (s *Segment) readPtr(off Address) (Ptr, error) {
var err error
val := s.readRawPointer(off)
s, off, val, err = s.resolveFarPointer(off, val)
if err != nil {
return nil, err
return Ptr{}, err
}
if val == 0 {
return nil, nil
return Ptr{}, nil
}
// Be wary of overflow. Offset is 30 bits signed. List size is 29 bits
// unsigned. For both of these we need to check in terms of words if
Expand All @@ -131,81 +131,81 @@ func (s *Segment) readPtr(off Address) (Pointer, error) {
case structPointer:
addr, ok := val.offset().resolve(off)
if !ok {
return nil, errPointerAddress
return Ptr{}, errPointerAddress
}
sz := val.structSize()
if !s.regionInBounds(addr, sz.totalSize()) {
return nil, errPointerAddress
return Ptr{}, errPointerAddress
}
return Struct{
seg: s,
off: addr,
size: sz,
}, nil
}.ToPtr(), nil
case listPointer:
addr, ok := val.offset().resolve(off)
if !ok {
return nil, errPointerAddress
return Ptr{}, errPointerAddress
}
lt := val.listType()
lsize, ok := val.totalListSize()
if !ok {
return nil, errOverflow
return Ptr{}, errOverflow
}
if !s.regionInBounds(addr, lsize) {
return nil, errPointerAddress
return Ptr{}, errPointerAddress
}
if lt == compositeList {
hdr := s.readRawPointer(addr)
var ok bool
addr, ok = addr.addSize(wordSize)
if !ok {
return nil, errOverflow
return Ptr{}, errOverflow
}
if hdr.pointerType() != structPointer {
return nil, errBadTag
return Ptr{}, errBadTag
}
sz := hdr.structSize()
n := int32(hdr.offset())
// TODO(light): check that this has the same end address
if tsize, ok := sz.totalSize().times(n); !ok {
return nil, errOverflow
return Ptr{}, errOverflow
} else if !s.regionInBounds(addr, tsize) {
return nil, errPointerAddress
return Ptr{}, errPointerAddress
}
return List{
seg: s,
size: sz,
off: addr,
length: n,
flags: isCompositeList,
}, nil
}.ToPtr(), nil
}
if lt == bit1List {
return List{
seg: s,
off: addr,
length: val.numListElements(),
flags: isBitList,
}, nil
}.ToPtr(), nil
}
return List{
seg: s,
size: val.elementSize(),
off: addr,
length: val.numListElements(),
}, nil
}.ToPtr(), nil
case otherPointer:
if val.otherPointerType() != 0 {
return nil, errOtherPointer
return Ptr{}, errOtherPointer
}
return Interface{
seg: s,
cap: val.capabilityIndex(),
}, nil
}.ToPtr(), nil
default:
// Only other types are far pointers.
return nil, errBadLandingPad
return Ptr{}, errBadLandingPad
}
}

Expand Down Expand Up @@ -258,29 +258,31 @@ func (s *Segment) resolveFarPointer(off Address, val rawPointer) (*Segment, Addr
type offset struct {
id SegmentID
boff, bend int64 // in bits
newval Pointer
newval Ptr
}

func makeOffsetKey(p Pointer) offset {
func makeOffsetKey(p Ptr) offset {
// Since this is used for copying, the address boundaries should already be clamped.
switch p := p.underlying().(type) {
case Struct:
switch p.flags.ptrType() {
case structPtrType:
s := p.Struct()
return offset{
id: p.seg.id,
boff: int64(p.off) * 8,
bend: (int64(p.off) + int64(p.size.totalSize())) * 8,
id: s.seg.id,
boff: int64(s.off) * 8,
bend: (int64(s.off) + int64(s.size.totalSize())) * 8,
}
case List:
case listPtrType:
l := p.List()
key := offset{
id: p.seg.id,
boff: int64(p.off) * 8,
id: l.seg.id,
boff: int64(l.off) * 8,
}
if p.flags&isBitList != 0 {
key.bend = int64(p.off)*8 + int64(p.length)
if l.flags&isBitList != 0 {
key.bend = int64(l.off)*8 + int64(l.length)
} else {
key.bend = (int64(p.off) + int64(p.size.totalSize())*int64(p.length)) * 8
key.bend = (int64(l.off) + int64(l.size.totalSize())*int64(l.length)) * 8
}
if p.flags&isCompositeList != 0 {
if l.flags&isCompositeList != 0 {
// Composite lists' offsets are after the tag word.
key.boff -= int64(wordSize) * 8
}
Expand All @@ -305,31 +307,32 @@ func compare(a, b rbtree.Item) int {
}
}

func needsCopy(dest *Segment, src Pointer) bool {
if src.Segment().msg != dest.msg {
func needsCopy(dest *Segment, src Ptr) bool {
if src.seg.msg != dest.msg {
return true
}
if s := ToStruct(src); IsValid(s) {
// Structs can only be referenced if they're not list members.
return s.flags&isListMember != 0
s := src.Struct()
if s.seg == nil {
return false
}
return false
// Structs can only be referenced if they're not list members.
return s.flags&isListMember != 0
}

func (s *Segment) writePtr(cc copyContext, off Address, src Pointer) error {
func (s *Segment) writePtr(cc copyContext, off Address, src Ptr) error {
// handle nulls
if !IsValid(src) {
if !src.IsValid() {
s.writeRawPointer(off, 0)
return nil
}
srcSeg := src.Segment()

if i := ToInterface(src); IsValid(i) {
if i := src.Interface(); i.Segment() != nil {
if s.msg != srcSeg.msg {
c := s.msg.AddCap(i.Client())
src = Pointer(NewInterface(s, c))
i = NewInterface(s, c)
}
s.writeRawPointer(off, src.value(off))
s.writeRawPointer(off, i.value(off))
return nil
}
if s != srcSeg {
Expand All @@ -345,7 +348,7 @@ func (s *Segment) writePtr(cc copyContext, off Address, src Pointer) error {
return err
}

srcAddr := pointerAddress(src)
srcAddr := src.address()
t.writeRawPointer(dstAddr, rawFarPointer(srcSeg.id, srcAddr))
// alloc guarantees that two words are available.
t.writeRawPointer(dstAddr+Address(wordSize), src.value(srcAddr-Address(wordSize)))
Expand All @@ -362,7 +365,7 @@ func (s *Segment) writePtr(cc copyContext, off Address, src Pointer) error {
return nil
}

func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Pointer) error {
func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Ptr) error {
if cc.depth >= 32 {
return errCopyDepth
}
Expand Down Expand Up @@ -397,44 +400,46 @@ func copyPointer(cc copyContext, dstSeg *Segment, dstAddr Address, src Pointer)
if err != nil {
return err
}
switch src := src.underlying().(type) {
case Struct:
switch src.flags.ptrType() {
case structPtrType:
s := src.Struct()
dst := Struct{
seg: newSeg,
off: newAddr,
size: src.size,
size: s.size,
// clear flags
}
key.newval = dst
key.newval = dst.ToPtr()
cc.copies.Insert(key)
if err := copyStruct(cc, dst, src); err != nil {
if err := copyStruct(cc, dst, s); err != nil {
return err
}
case List:
case listPtrType:
l := src.List()
dst := List{
seg: newSeg,
off: newAddr,
length: src.length,
size: src.size,
flags: src.flags,
length: l.length,
size: l.size,
flags: l.flags,
}
if dst.flags&isCompositeList != 0 {
// Copy tag word
newSeg.writeRawPointer(newAddr, src.seg.readRawPointer(src.off-Address(wordSize)))
newSeg.writeRawPointer(newAddr, l.seg.readRawPointer(l.off-Address(wordSize)))
var ok bool
dst.off, ok = dst.off.addSize(wordSize)
if !ok {
return errOverflow
}
}
key.newval = dst
key.newval = dst.ToPtr()
cc.copies.Insert(key)
// TODO(light): fast path for copying text/data
if dst.flags&isBitList != 0 {
copy(newSeg.data[newAddr:], src.seg.data[src.off:src.length+7/8])
copy(newSeg.data[newAddr:], l.seg.data[l.off:l.length+7/8])
} else {
for i := 0; i < src.Len(); i++ {
err := copyStruct(cc, dst.Struct(i), src.Struct(i))
for i := 0; i < l.Len(); i++ {
err := copyStruct(cc, dst.Struct(i), l.Struct(i))
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit c78274f

Please sign in to comment.