diff --git a/extendeddatacrossword.go b/extendeddatacrossword.go index 7114321..b44442e 100644 --- a/extendeddatacrossword.go +++ b/extendeddatacrossword.go @@ -378,11 +378,8 @@ func (eds *ExtendedDataSquare) preRepairSanityCheck( return nil }) errs.Go(func() error { - parityShares, err := eds.codec.Encode(eds.rowSlice(i, 0, eds.originalDataWidth)) + err := eds.verifyEncoding(eds.row(i), noShareInsertion, nil) if err != nil { - return err - } - if !bytes.Equal(flattenShares(parityShares), flattenShares(eds.rowSlice(i, eds.originalDataWidth, eds.originalDataWidth))) { return &ErrByzantineData{Row, i, eds.row(i)} } return nil @@ -407,12 +404,8 @@ func (eds *ExtendedDataSquare) preRepairSanityCheck( return nil }) errs.Go(func() error { - // check if we take the first half of the col and encode it, we get the second half - parityShares, err := eds.codec.Encode(eds.colSlice(0, i, eds.originalDataWidth)) + err := eds.verifyEncoding(eds.col(i), noShareInsertion, nil) if err != nil { - return err - } - if !bytes.Equal(flattenShares(parityShares), flattenShares(eds.colSlice(eds.originalDataWidth, i, eds.originalDataWidth))) { return &ErrByzantineData{Col, i, eds.col(i)} } return nil @@ -473,7 +466,14 @@ func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte // verifyEncoding checks the Reed-Solomon encoding of the provided data. func (eds *ExtendedDataSquare) verifyEncoding(data [][]byte, rebuiltIndex int, rebuiltShare []byte) error { - data[rebuiltIndex] = rebuiltShare + if rebuiltShare != nil && rebuiltIndex >= 0 { + data[rebuiltIndex] = rebuiltShare + defer func() { + // revert the change to the data slice after the verification + data[rebuiltIndex] = nil + }() + } + half := len(data) / 2 original := data[:half] parity, err := eds.codec.Encode(original) @@ -483,10 +483,8 @@ func (eds *ExtendedDataSquare) verifyEncoding(data [][]byte, rebuiltIndex int, r for i := half; i < len(data); i++ { if !bytes.Equal(data[i], parity[i-half]) { - data[rebuiltIndex] = nil return errors.New("parity data does not match encoded data") } } - return nil } diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 4ad7be3..4fb8566 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -78,6 +78,39 @@ func TestRepairExtendedDataSquare(t *testing.T) { t.Errorf("did not return an error on trying to repair an unrepairable square") } }) + + t.Run("repair in random order", func(t *testing.T) { + for i := 0; i < 100; i++ { + newEds, err := NewExtendedDataSquare(codec, NewDefaultTree, original.Width(), shareSize) + require.NoError(t, err) + // Randomly set shares in the newEds from the original and repair. + for { + x := rand.Intn(int(original.Width())) + y := rand.Intn(int(original.Width())) + if newEds.GetCell(uint(x), uint(y)) != nil { + continue + } + err = newEds.SetCell(uint(x), uint(y), original.GetCell(uint(x), uint(y))) + require.NoError(t, err) + + // Repair square. + err = newEds.Repair(rowRoots, colRoots) + if errors.Is(err, ErrUnrepairableDataSquare) { + continue + } + require.NoError(t, err) + break + } + + require.True(t, newEds.Equals(original)) + newRowRoots, err := newEds.RowRoots() + require.NoError(t, err) + require.Equal(t, rowRoots, newRowRoots) + newColRoots, err := newEds.ColRoots() + require.NoError(t, err) + require.Equal(t, colRoots, newColRoots) + } + }) } func TestValidFraudProof(t *testing.T) {