Skip to content

Commit

Permalink
statedb: Check that modifications are made with a copy
Browse files Browse the repository at this point in the history
Add a simple form of mutation detection that checks (by pointer comparison)
that on modification to an existing value a DeepCopy()'d object is inserted.

Extend the tests to do modifications. Validated manually that the mutation
detection resulted in panic.

Signed-off-by: Jussi Maki <jussi@isovalent.com>
  • Loading branch information
joamaki authored and squeed committed Apr 3, 2023
1 parent b1523a4 commit 03bc4dd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
6 changes: 6 additions & 0 deletions pkg/statedb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ func (t *transaction) Commit() error {
changedTables := map[string]struct{}{}
for _, change := range t.txn.Changes() {
changedTables[change.Table] = struct{}{}

// Verify that a copy of the original object is being
// inserted rather than mutated in-place.
if change.Before == change.After {
panic("statedb: The original object is being modified without being copied first!")
}
}
t.db.revision = t.revision
t.txn.Commit()
Expand Down
36 changes: 32 additions & 4 deletions pkg/statedb/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func runTest(t *testing.T, p testParams) {
db := p.DB
fooId1, fooId2 := NewUUID(), NewUUID()

// Helper function to assert that the two "foo" objects exist.
assertGet := func(tx ReadTransaction) {
foos := p.Foos.Reader(tx)

Expand Down Expand Up @@ -105,6 +106,7 @@ func runTest(t *testing.T, p testParams) {
assertGet(tx)
tx.Commit()
}

// Check that it's been committed.
assertGet(db.ReadTxn())

Expand Down Expand Up @@ -139,15 +141,41 @@ func runTest(t *testing.T, p testParams) {
t.Errorf("expected Invalidated() channel to be closed!")
}

// Check that modifications to existing objects also result in notification.
it, err = p.Foos.Reader(db.ReadTxn()).Get(All)
assert.NoError(t, err)
ch = it.Invalidated()
select {
case <-ch:
t.Errorf("expected Invalidated() channel to block!")
default:
}

tx3 := db.WriteTxn()
foo2, err := p.Foos.Reader(tx3).First(ByUUID(fooId2))
assert.NoError(t, err)
assert.NotNil(t, foo2)
foo2 = foo2.DeepCopy()
foo2.Num = 222
err = p.Foos.Writer(tx3).Insert(foo2)
assert.NoError(t, err)
tx3.Commit()

select {
case <-ch:
case <-time.After(time.Second):
t.Errorf("expected Invalidated() channel to be closed!")
}

it, err = p.Foos.Reader(db.ReadTxn()).Get(All)
assert.NoError(t, err)
assert.Equal(t, Length[*Foo](it), 3)

// Aborting doesn't change anything.
tx3 := db.WriteTxn()
err = p.Foos.Writer(tx3).Insert(&Foo{UUID: NewUUID(), Num: 3})
tx4 := db.WriteTxn()
err = p.Foos.Writer(tx4).Insert(&Foo{UUID: NewUUID(), Num: 3})
assert.NoError(t, err)
tx3.Abort()
tx4.Abort()

it, err = p.Foos.Reader(db.ReadTxn()).Get(All)
assert.NoError(t, err)
Expand All @@ -165,7 +193,7 @@ func runTest(t *testing.T, p testParams) {
foos, ok := result[fooTableSchema.Name]
assert.True(t, ok, "There should be a 'foos' table")
assert.Len(t, foos, 3)
assert.True(t, foos[0].Num > 0 && foos[0].Num <= 3)
assert.True(t, foos[0].Num > 0)
assert.True(t, len(foos[0].UUID) > 0)
}

Expand Down

0 comments on commit 03bc4dd

Please sign in to comment.