diff --git a/datasquare.go b/datasquare.go index 5c6890d..5b82e01 100644 --- a/datasquare.go +++ b/datasquare.go @@ -135,6 +135,9 @@ func (ds *dataSquare) setRowSlice(x uint, y uint, newRow [][]byte) error { return errors.New("invalid chunk size") } } + if y+uint(len(newRow)) > ds.width { + return fmt.Errorf("cannot set row slice at (%d, %d) of length %d: because it would exceed the data square width %d", x, y, len(newRow), ds.width) + } ds.dataMutex.Lock() defer ds.dataMutex.Unlock() @@ -165,6 +168,9 @@ func (ds *dataSquare) setColSlice(x uint, y uint, newCol [][]byte) error { return errors.New("invalid chunk size") } } + if x+uint(len(newCol)) > ds.width { + return fmt.Errorf("cannot set col slice at (%d, %d) of length %d: because it would exceed the data square width %d", x, y, len(newCol), ds.width) + } ds.dataMutex.Lock() defer ds.dataMutex.Unlock() diff --git a/datasquare_test.go b/datasquare_test.go index 0db5303..90b28a4 100644 --- a/datasquare_test.go +++ b/datasquare_test.go @@ -219,6 +219,120 @@ func TestDefaultTreeProofs(t *testing.T) { } } +func Test_setRowSlice(t *testing.T) { + type testCase struct { + name string + newRow [][]byte + x uint + y uint + want [][]byte + wantErr bool + } + testCases := []testCase{ + { + name: "overwrite the first row", + newRow: [][]byte{{5}, {6}}, + x: 0, + y: 0, + want: [][]byte{{5}, {6}, {3}, {4}}, + wantErr: false, + }, + { + name: "overwrite the last row", + newRow: [][]byte{{5}, {6}}, + x: 1, + y: 0, + want: [][]byte{{1}, {2}, {5}, {6}}, + wantErr: false, + }, + { + name: "returns an error if the new row has an invalid chunk size", + newRow: [][]byte{{5, 6}}, + x: 0, + y: 0, + wantErr: true, + }, + { + name: "returns an error if the new row would surpass the data square's width", + newRow: [][]byte{{5}, {6}}, + x: 0, + y: 1, + wantErr: true, + }, + } + + for _, tc := range testCases { + ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + assert.NoError(t, err) + err = ds.setRowSlice(tc.x, tc.y, tc.newRow) + + if tc.wantErr { + assert.Error(t, err) + return + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, ds.Flattened()) + } + } +} + +func Test_setColSlice(t *testing.T) { + type testCase struct { + name string + newCol [][]byte + x uint + y uint + want [][]byte + wantErr bool + } + testCases := []testCase{ + { + name: "overwrite the first col", + newCol: [][]byte{{5}, {6}}, + x: 0, + y: 0, + want: [][]byte{{5}, {2}, {6}, {4}}, + wantErr: false, + }, + { + name: "overwrite the last col", + newCol: [][]byte{{5}, {6}}, + x: 0, + y: 1, + want: [][]byte{{1}, {5}, {3}, {6}}, + wantErr: false, + }, + { + name: "returns an error if the new col has an invalid chunk size", + newCol: [][]byte{{5, 6}}, + x: 0, + y: 0, + wantErr: true, + }, + { + name: "returns an error if the new col would surpass the data square's width", + newCol: [][]byte{{5}, {6}}, + x: 1, + y: 0, + wantErr: true, + }, + } + + for _, tc := range testCases { + ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + assert.NoError(t, err) + err = ds.setColSlice(tc.x, tc.y, tc.newCol) + + if tc.wantErr { + assert.Error(t, err) + return + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, ds.Flattened()) + } + } +} + func BenchmarkEDSRoots(b *testing.B) { for i := 32; i < 513; i *= 2 { square, err := newDataSquare(genRandDS(i*2), NewDefaultTree)