diff --git a/test/dummy.go b/test/dummy.go index 74532ac..b9b5699 100644 --- a/test/dummy.go +++ b/test/dummy.go @@ -2,11 +2,11 @@ package test import ( "bytes" - "context" "crypto/sha512" "fmt" "slices" + "sync" "time" "github.com/rollkit/go-execution/types" @@ -14,6 +14,7 @@ import ( // DummyExecutor is a dummy implementation of the DummyExecutor interface for testing type DummyExecutor struct { + mu sync.RWMutex // Add mutex for thread safety stateRoot types.Hash pendingRoots map[uint64]types.Hash maxBytes uint64 @@ -32,6 +33,9 @@ func NewDummyExecutor() *DummyExecutor { // InitChain initializes the chain state with the given genesis time, initial height, and chain ID. // It returns the state root hash, the maximum byte size, and an error if the initialization fails. func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { + e.mu.Lock() + defer e.mu.Unlock() + hash := sha512.New() hash.Write(e.stateRoot) e.stateRoot = hash.Sum(nil) @@ -40,17 +44,27 @@ func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, in // GetTxs returns the list of transactions (types.Tx) within the DummyExecutor instance and an error if any. func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) { - txs := e.injectedTxs + e.mu.RLock() + defer e.mu.RUnlock() + + txs := make([]types.Tx, len(e.injectedTxs)) + copy(txs, e.injectedTxs) // Create a copy to avoid external modifications return txs, nil } // InjectTx adds a transaction to the internal list of injected transactions in the DummyExecutor instance. func (e *DummyExecutor) InjectTx(tx types.Tx) { + e.mu.Lock() + defer e.mu.Unlock() + e.injectedTxs = append(e.injectedTxs, tx) } // ExecuteTxs simulate execution of transactions. func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { + e.mu.Lock() + defer e.mu.Unlock() + hash := sha512.New() hash.Write(prevStateRoot) for _, tx := range txs { @@ -64,6 +78,9 @@ func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHei // SetFinal marks block at given height as finalized. func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error { + e.mu.Lock() + defer e.mu.Unlock() + if pending, ok := e.pendingRoots[blockHeight]; ok { e.stateRoot = pending delete(e.pendingRoots, blockHeight) @@ -77,3 +94,11 @@ func (e *DummyExecutor) removeExecutedTxs(txs []types.Tx) { return slices.ContainsFunc(txs, func(t types.Tx) bool { return bytes.Equal(tx, t) }) }) } + +// GetStateRoot returns the current state root in a thread-safe manner +func (e *DummyExecutor) GetStateRoot() types.Hash { + e.mu.RLock() + defer e.mu.RUnlock() + + return e.stateRoot +}