Skip to content

Commit

Permalink
refactor(tree): safety check for all node types
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Dec 5, 2023
1 parent 5a8a683 commit 4020135
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 31 deletions.
1 change: 1 addition & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var (
errSerializeHashedNode = errors.New("trying to serialize a hashed internal node")
errInsertIntoOtherStem = errors.New("insert splits a stem where it should not happen")
errUnknownNodeType = errors.New("unknown node type detected")
errNilNodeType = errors.New("nil node type detected")
errMissingNodeInStateless = errors.New("trying to access a node that is missing from the stateless view")
errIsPOAStub = errors.New("trying to read/write a proof of absence leaf node")
)
Expand Down
98 changes: 67 additions & 31 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,14 @@ func (n *InternalNode) Children() []VerkleNode {

// SetChild *replaces* the child at the given index with the given node.
func (n *InternalNode) SetChild(i int, c VerkleNode) error {
if c == nil {
return errNilNodeType
}

if i >= NodeWidth {
return errors.New("child index higher than node width")
}

n.children[i] = c
return nil
}
Expand Down Expand Up @@ -608,8 +613,14 @@ func (n *InternalNode) Delete(key []byte, resolver NodeResolverFn) (bool, error)
// signal that this node should be deleted
// as well.
for _, c := range n.children {
if _, ok := c.(Empty); !ok {
switch c.(type) {
case *InternalNode, *LeafNode:
break
case Empty, HashedNode:
case UnknownNode:
panic(errUnknownNodeType)
default:
panic(errNilNodeType)
}
}

Expand Down Expand Up @@ -639,14 +650,20 @@ func (n *InternalNode) Flush(flush NodeFlushFn) {

n.Commit()
for i, child := range n.children {
if c, ok := child.(*InternalNode); ok {
switch c := child.(type) {
case *InternalNode:
c.Commit()
c.Flush(flushAndCapturePath)
n.children[i] = HashedNode{}
} else if c, ok := child.(*LeafNode); ok {
case *LeafNode:
c.Commit()
flushAndCapturePath(c.stem[:n.depth+1], n.children[i])
flushAndCapturePath(c.stem[:n.depth+1], c)
n.children[i] = HashedNode{}
case Empty, HashedNode:
case UnknownNode:
panic(errUnknownNodeType)
default:
panic(errNilNodeType)
}
}
flush(path, n)
Expand Down Expand Up @@ -875,33 +892,32 @@ func (n *InternalNode) GetProofItems(keys keylist, resolver NodeResolverFn) (*Pr
var fiPtrs [NodeWidth]*Fr
var points [NodeWidth]*Point
for i, child := range n.children {
var c VerkleNode

fiPtrs[i] = &fi[i]
if child != nil {
var c VerkleNode
if _, ok := child.(HashedNode); ok {
childpath := make([]byte, n.depth+1)
copy(childpath[:n.depth+1], keys[0][:n.depth])
childpath[n.depth] = byte(i)
if resolver == nil {
return nil, nil, nil, fmt.Errorf("no resolver for path %x", childpath)
}
serialized, err := resolver(childpath)
if err != nil {
return nil, nil, nil, fmt.Errorf("error resolving for path %x: %w", childpath, err)
}
c, err = ParseNode(serialized, n.depth+1)
if err != nil {
return nil, nil, nil, err
}
n.children[i] = c
} else {
c = child
switch child := child.(type) {
case HashedNode:
childpath := make([]byte, n.depth+1)
copy(childpath[:n.depth+1], keys[0][:n.depth])
childpath[n.depth] = byte(i)
if resolver == nil {
return nil, nil, nil, fmt.Errorf("no resolver for path %x", childpath)
}
points[i] = c.Commitment()
} else {
// TODO: add a test case to cover this scenario.
points[i] = new(Point)
serialized, err := resolver(childpath)
if err != nil {
return nil, nil, nil, fmt.Errorf("error resolving for path %x: %w", childpath, err)
}
c, err = ParseNode(serialized, n.depth+1)
if err != nil {
return nil, nil, nil, err
}
n.children[i] = c
case *InternalNode, *LeafNode, Empty, UnknownNode:
c = child
default:
panic(errNilNodeType)
}
points[i] = c.Commitment()
}
if err := banderwagon.BatchMapToScalarField(fiPtrs[:], points[:]); err != nil {
return nil, nil, nil, fmt.Errorf("batch mapping to scalar fields: %s", err)
Expand Down Expand Up @@ -972,8 +988,14 @@ func (n *InternalNode) Serialize() ([]byte, error) {
// Write the <bitlist>.
bitlist := ret[internalBitlistOffset:internalCommitmentOffset]
for i, c := range n.children {
if _, ok := c.(Empty); !ok {
switch c.(type) {
case *InternalNode, *LeafNode:
setBit(bitlist, i)
case Empty, HashedNode:
case UnknownNode:
panic(errUnknownNodeType)
default:
panic(errNilNodeType)
}
}

Expand All @@ -995,6 +1017,9 @@ func (n *InternalNode) Copy() VerkleNode {
}

for i, child := range n.children {
if child == nil {
panic(errNilNodeType)
}
ret.children[i] = child.Copy()
}

Expand Down Expand Up @@ -1024,7 +1049,7 @@ func (n *InternalNode) toDot(parent, path string) string {

for i, child := range n.children {
if child == nil {
continue
panic(errNilNodeType)
}
ret = fmt.Sprintf("%s%s", ret, child.toDot(me, fmt.Sprintf("%s%02x", path, i)))
}
Expand Down Expand Up @@ -1719,6 +1744,11 @@ func (n *InternalNode) collectNonHashedNodes(list []VerkleNode, paths [][]byte,
copy(childpath, path)
childpath[len(path)] = byte(i)
list, paths = childNode.collectNonHashedNodes(list, paths, childpath)
case Empty, HashedNode:
case UnknownNode:
panic(errUnknownNodeType)
default:
panic(errNilNodeType)
}
}
return list, paths
Expand All @@ -1729,8 +1759,14 @@ func (n *InternalNode) serializeInternalWithUncompressedCommitment(pointsIdx map
serialized := make([]byte, nodeTypeSize+bitlistSize+banderwagon.UncompressedSize)
bitlist := serialized[internalBitlistOffset:internalCommitmentOffset]
for i, c := range n.children {
if _, ok := c.(Empty); !ok {
switch c.(type) {
case *InternalNode, *LeafNode:
setBit(bitlist, i)
case Empty, HashedNode:
case UnknownNode:
panic(errUnknownNodeType)
default:
panic(errNilNodeType)
}
}
serialized[nodeTypeOffset] = internalRLPType
Expand Down

0 comments on commit 4020135

Please sign in to comment.