Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore!: add error return params to tree interface #157

Merged
merged 5 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions datasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"math"
"sync"

"golang.org/x/sync/errgroup"
)

// ErrUnevenChunks is thrown when non-nil chunks are not all of equal size.
Expand Down Expand Up @@ -187,29 +189,41 @@ func (ds *dataSquare) resetRoots() {
}
}

func (ds *dataSquare) computeRoots() {
var wg sync.WaitGroup
func (ds *dataSquare) computeRoots() error {
var g errgroup.Group

rowRoots := make([][]byte, ds.width)
colRoots := make([][]byte, ds.width)

for i := uint(0); i < ds.width; i++ {
wg.Add(2)

go func(i uint) {
defer wg.Done()
rowRoots[i] = ds.getRowRoot(i)
}(i)
i := i // https://go.dev/doc/faq#closures_and_goroutines
g.Go(func() error {
rowRoot, err := ds.getRowRoot(i)
if err != nil {
return err
}
rowRoots[i] = rowRoot
return nil
})

g.Go(func() error {
colRoot, err := ds.getColRoot(i)
if err != nil {
return err
}
colRoots[i] = colRoot
return nil
})
}

go func(i uint) {
defer wg.Done()
colRoots[i] = ds.getColRoot(i)
}(i)
err := g.Wait()
if err != nil {
return err
}

wg.Wait()
ds.rowRoots = rowRoots
ds.colRoots = colRoots
return nil
}

// getRowRoots returns the Merkle roots of all the rows in the square.
Expand All @@ -223,9 +237,9 @@ func (ds *dataSquare) getRowRoots() [][]byte {

// getRowRoot calculates and returns the root of the selected row. Note: unlike the
// getRowRoots method, getRowRoot does not write to the built-in cache.
func (ds *dataSquare) getRowRoot(x uint) []byte {
func (ds *dataSquare) getRowRoot(x uint) ([]byte, error) {
if ds.rowRoots != nil {
return ds.rowRoots[x]
return ds.rowRoots[x], nil
}

tree := ds.createTreeFn(Row, x)
Expand All @@ -247,9 +261,9 @@ func (ds *dataSquare) getColRoots() [][]byte {

// getColRoot calculates and returns the root of the selected row. Note: unlike the
// getColRoots method, getColRoot does not write to the built-in cache.
func (ds *dataSquare) getColRoot(y uint) []byte {
func (ds *dataSquare) getColRoot(y uint) ([]byte, error) {
if ds.colRoots != nil {
return ds.colRoots[y]
return ds.colRoots[y], nil
}

tree := ds.createTreeFn(Col, y)
Expand Down
77 changes: 57 additions & 20 deletions datasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"reflect"
"testing"

"github.com/celestiaorg/merkletree"
"github.com/minio/sha256-simd"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -138,36 +140,60 @@ func TestLazyRootGeneration(t *testing.T) {
var colRoots [][]byte

for i := uint(0); i < square.width; i++ {
rowRoots = append(rowRoots, square.getRowRoot(i))
colRoots = append(rowRoots, square.getColRoot(i))
rowRoot, err := square.getRowRoot(i)
assert.NoError(t, err)
colRoot, err := square.getColRoot(i)
assert.NoError(t, err)
rowRoots = append(rowRoots, rowRoot)
colRoots = append(colRoots, colRoot)
}

square.computeRoots()
err = square.computeRoots()
assert.NoError(t, err)

if !reflect.DeepEqual(square.rowRoots, rowRoots) && !reflect.DeepEqual(square.colRoots, colRoots) {
t.Error("getRowRoot or getColRoot did not produce identical roots to computeRoots")
}
}

func TestComputeRoots(t *testing.T) {
t.Run("default tree computeRoots() returns no error", func(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
assert.NoError(t, err)
err = square.computeRoots()
assert.NoError(t, err)
})
t.Run("error tree computeRoots() returns an error", func(t *testing.T) {
square, err := newDataSquare([][]byte{{1}}, newErrorTree)
assert.NoError(t, err)
err = square.computeRoots()
assert.Error(t, err)
})
}

func TestRootAPI(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
if err != nil {
panic(err)
}

for i := uint(0); i < square.width; i++ {
if !reflect.DeepEqual(square.getRowRoots()[i], square.getRowRoot(i)) {
rowRoot, err := square.getRowRoot(i)
assert.NoError(t, err)
if !reflect.DeepEqual(square.getRowRoots()[i], rowRoot) {
t.Errorf(
"Row root API results in different roots, expected %v got %v",
square.getRowRoots()[i],
square.getRowRoot(i),
rowRoot,
)
}
if !reflect.DeepEqual(square.getColRoots()[i], square.getColRoot(i)) {
colRoot, err := square.getColRoot(i)
assert.NoError(t, err)
if !reflect.DeepEqual(square.getColRoots()[i], colRoot) {
t.Errorf(
"Column root API results in different roots, expected %v got %v",
square.getColRoots()[i],
square.getColRoot(i),
colRoot,
)
}
}
Expand Down Expand Up @@ -205,7 +231,8 @@ func BenchmarkEDSRoots(b *testing.B) {
func(b *testing.B) {
for n := 0; n < b.N; n++ {
square.resetRoots()
square.computeRoots()
err := square.computeRoots()
assert.NoError(b, err)
}
},
)
Expand All @@ -224,18 +251,6 @@ func computeRowProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, ui
return merkleRoot, proof, uint(proofIndex), uint(numLeaves), nil
}

func computeColProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed an unused function

tree := ds.createTreeFn(Col, y)
data := ds.col(y)

for i := uint(0); i < ds.width; i++ {
tree.Push(data[i])
}
// TODO(ismail): check for overflow when casting from uint -> int
merkleRoot, proof, proofIndex, numLeaves := treeProve(tree.(*DefaultTree), int(x))
return merkleRoot, proof, uint(proofIndex), uint(numLeaves), nil
}

func treeProve(d *DefaultTree, idx int) (merkleRoot []byte, proofSet [][]byte, proofIndex uint64, numLeaves uint64) {
if err := d.Tree.SetIndex(uint64(idx)); err != nil {
panic(fmt.Sprintf("don't call prove on a already used tree: %v", err))
Expand All @@ -245,3 +260,25 @@ func treeProve(d *DefaultTree, idx int) (merkleRoot []byte, proofSet [][]byte, p
}
return d.Tree.Prove()
}

type errorTree struct {
*merkletree.Tree
leaves [][]byte
}

func newErrorTree(axis Axis, index uint) Tree {
return &errorTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
}
}

func (d *errorTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
}

func (d *errorTree) Root() ([]byte, error) {
return nil, fmt.Errorf("error")
}
40 changes: 28 additions & 12 deletions extendeddatacrossword.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ func (e *ErrByzantineData) Error() string {
// square (EDS), comparing repaired rows and columns against expected Merkle
// roots.
//
// Input
// # Input
//
// Missing shares must be nil.
//
// Output
// # Output
//
// The EDS is modified in-place. If repairing is successful, the EDS will be
// complete. If repairing is unsuccessful, the EDS will be the most-repaired
Expand Down Expand Up @@ -282,10 +282,14 @@ func (eds *ExtendedDataSquare) verifyAgainstRowRoots(
rebuiltShare []byte,
) error {
var root []byte
var err error
if rebuiltIndex < 0 || rebuiltShare == nil {
root = eds.computeSharesRoot(oldShares, Row, r)
root, err = eds.computeSharesRoot(oldShares, Row, r)
} else {
root = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare)
root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare)
}
if err != nil {
return err
}

if !bytes.Equal(root, rowRoots[r]) {
Expand All @@ -303,10 +307,14 @@ func (eds *ExtendedDataSquare) verifyAgainstColRoots(
rebuiltShare []byte,
) error {
var root []byte
var err error
if rebuiltIndex < 0 || rebuiltShare == nil {
root = eds.computeSharesRoot(oldShares, Col, c)
root, err = eds.computeSharesRoot(oldShares, Col, c)
} else {
root = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare)
root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare)
}
if err != nil {
return err
}

if !bytes.Equal(root, colRoots[c]) {
Expand All @@ -331,8 +339,12 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(
if rowIsComplete {
errs.Go(func() error {
// ensure that the roots are equal
if !bytes.Equal(rowRoots[i], eds.getRowRoot(i)) {
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], eds.getRowRoot(i))
rowRoot, err := eds.getRowRoot(i)
if err != nil {
return err
}
if !bytes.Equal(rowRoots[i], rowRoot) {
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], rowRoot)
}
return nil
})
Expand All @@ -342,8 +354,12 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(
if colIsComplete {
errs.Go(func() error {
// ensure that the roots are equal
if !bytes.Equal(colRoots[i], eds.getColRoot(i)) {
return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], eds.getColRoot(i))
colRoot, err := eds.getColRoot(i)
if err != nil {
return err
}
if !bytes.Equal(colRoots[i], colRoot) {
return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], colRoot)
}
return nil
})
Expand Down Expand Up @@ -391,15 +407,15 @@ func noMissingData(input [][]byte, rebuiltIndex int) bool {
return true
}

func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) []byte {
func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) ([]byte, error) {
tree := eds.createTreeFn(axis, i)
for _, d := range shares {
tree.Push(d)
}
return tree.Root()
}

func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) []byte {
func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) ([]byte, error) {
tree := eds.createTreeFn(axis, i)
for _, d := range shares[:rebuiltIndex] {
tree.Push(d)
Expand Down
7 changes: 5 additions & 2 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,11 @@ func TestValidFraudProof(t *testing.T) {
if err != nil {
t.Errorf("could not decode fraud proof shares; got: %v", err)
}
root := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
if bytes.Equal(root, corrupted.getRowRoot(fraudProof.Index)) {
root, err := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
assert.NoError(t, err)
rowRoot, err := corrupted.getRowRoot(fraudProof.Index)
assert.NoError(t, err)
if bytes.Equal(root, rowRoot) {
// If the roots match, then the fraud proof should be for invalid erasure coding.
parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth])
if err != nil {
Expand Down
14 changes: 8 additions & 6 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (
"github.com/celestiaorg/merkletree"
)

// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle inside of rsmt2d.
// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree
// inside of rsmt2d.
type TreeConstructorFn = func(axis Axis, index uint) Tree

// SquareIndex contains all information needed to identify the cell that is being
Expand All @@ -17,8 +18,8 @@ type SquareIndex struct {

// Tree wraps Merkle tree implementations to work with rsmt2d
type Tree interface {
Push(data []byte)
Root() []byte
Push(data []byte) error
Root() ([]byte, error)
}

var _ Tree = &DefaultTree{}
Expand All @@ -36,17 +37,18 @@ func NewDefaultTree(axis Axis, index uint) Tree {
}
}

func (d *DefaultTree) Push(data []byte) {
func (d *DefaultTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
}

func (d *DefaultTree) Root() []byte {
func (d *DefaultTree) Root() ([]byte, error) {
if d.root == nil {
for _, l := range d.leaves {
d.Tree.Push(l)
}
d.root = d.Tree.Root()
}
return d.root
return d.root, nil
}