diff --git a/Makefile b/Makefile index fcb68a2b..8ab13a51 100644 --- a/Makefile +++ b/Makefile @@ -38,8 +38,7 @@ lint-examples: golangci-lint run -v -E ${LINT_SETTINGS} lint: | lint-examples - golangci-lint run --timeout 2m0s -v -E ${LINT_SETTINGS},gomnd && \ - make check-comments; + golangci-lint run --timeout 2m0s -v -E ${LINT_SETTINGS},gomnd format: gofmt -s -w -l . diff --git a/mocks/syncer/handler.go b/mocks/syncer/handler.go index f2bf144c..d435c6df 100644 --- a/mocks/syncer/handler.go +++ b/mocks/syncer/handler.go @@ -42,3 +42,17 @@ func (_m *Handler) BlockRemoved(ctx context.Context, block *types.BlockIdentifie return r0 } + +// BlockSeen provides a mock function with given fields: ctx, block +func (_m *Handler) BlockSeen(ctx context.Context, block *types.Block) error { + ret := _m.Called(ctx, block) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *types.Block) error); ok { + r0 = rf(ctx, block) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/statefulsyncer/configuration.go b/statefulsyncer/configuration.go index 3ef3f1c3..de1e37a0 100644 --- a/statefulsyncer/configuration.go +++ b/statefulsyncer/configuration.go @@ -57,3 +57,12 @@ func WithPruneSleepTime(sleepTime int) Option { s.pruneSleepTime = time.Duration(sleepTime) * time.Second } } + +// WithSeenConcurrency overrides the number of concurrent +// invocations of BlockSeen we will handle. We default +// to the value of runtime.NumCPU(). 
+func WithSeenConcurrency(concurrency int64) Option { + return func(s *StatefulSyncer) { + s.seenSemaphoreSize = concurrency + } +} diff --git a/statefulsyncer/stateful_syncer.go b/statefulsyncer/stateful_syncer.go index 1218789e..9056645c 100644 --- a/statefulsyncer/stateful_syncer.go +++ b/statefulsyncer/stateful_syncer.go @@ -19,8 +19,11 @@ import ( "errors" "fmt" "log" + "runtime" "time" + "golang.org/x/sync/semaphore" + "github.com/coinbase/rosetta-sdk-go/fetcher" storageErrs "github.com/coinbase/rosetta-sdk-go/storage/errors" "github.com/coinbase/rosetta-sdk-go/storage/modules" @@ -40,6 +43,9 @@ const ( // pruneBuffer is the cushion we apply to pastBlockLimit // when pruning. pruneBuffer = 2 + + // semaphoreWeight is the weight of each semaphore request. + semaphoreWeight = int64(1) ) // StatefulSyncer is an abstraction layer over @@ -61,6 +67,11 @@ type StatefulSyncer struct { pastBlockLimit int adjustmentWindow int64 pruneSleepTime time.Duration + + // seenSemaphore limits how many executions of + // BlockSeen occur concurrently (it is sized by seenSemaphoreSize). + seenSemaphore *semaphore.Weighted + seenSemaphoreSize int64 } // Logger is used by the statefulsyncer to @@ -103,11 +114,12 @@ func New( logger: logger, // Optional args - cacheSize: syncer.DefaultCacheSize, - maxConcurrency: syncer.DefaultMaxConcurrency, - pastBlockLimit: syncer.DefaultPastBlockLimit, - adjustmentWindow: syncer.DefaultAdjustmentWindow, - pruneSleepTime: DefaultPruneSleepTime, + cacheSize: syncer.DefaultCacheSize, + maxConcurrency: syncer.DefaultMaxConcurrency, + pastBlockLimit: syncer.DefaultPastBlockLimit, + adjustmentWindow: syncer.DefaultAdjustmentWindow, + pruneSleepTime: DefaultPruneSleepTime, + seenSemaphoreSize: int64(runtime.NumCPU()), } // Override defaults with any provided options @@ -115,6 +127,11 @@ func New( opt(s) } + // We set this after options because the caller + // has the ability to set the max concurrency + // of seen invocations. 
+ s.seenSemaphore = semaphore.NewWeighted(s.seenSemaphoreSize) + return s } @@ -219,6 +236,25 @@ func (s *StatefulSyncer) Prune(ctx context.Context, helper PruneHelper) error { return ctx.Err() } +// BlockSeen is called by the syncer when a block is seen. +func (s *StatefulSyncer) BlockSeen(ctx context.Context, block *types.Block) error { + if err := s.seenSemaphore.Acquire(ctx, semaphoreWeight); err != nil { + return err + } + defer s.seenSemaphore.Release(semaphoreWeight) + + if err := s.blockStorage.SeeBlock(ctx, block); err != nil { + return fmt.Errorf( + "%w: unable to pre-store block %s:%d", + err, + block.BlockIdentifier.Hash, + block.BlockIdentifier.Index, + ) + } + + return nil +} + // BlockAdded is called by the syncer when a block is added. func (s *StatefulSyncer) BlockAdded(ctx context.Context, block *types.Block) error { err := s.blockStorage.AddBlock(ctx, block) diff --git a/storage/encoder/encoder.go b/storage/encoder/encoder.go index 1cefd9e7..d8470c62 100644 --- a/storage/encoder/encoder.go +++ b/storage/encoder/encoder.go @@ -221,22 +221,6 @@ const ( unicodeRecordSeparator = '\u001E' ) -// Indexes of encoded AccountCoin struct -const ( - accountAddress = iota - coinIdentifier - amountValue - amountCurrencySymbol - amountCurrencyDecimals - - // If none exist below, we stop after amount. - accountMetadata - subAccountAddress - subAccountMetadata - amountMetadata - currencyMetadata -) - func (e *Encoder) encodeAndWrite(output *bytes.Buffer, object interface{}) error { buf := e.pool.Get() err := getEncoder(buf).Encode(object) @@ -368,6 +352,22 @@ func (e *Encoder) DecodeAccountCoin( // nolint:gocognit accountCoin *types.AccountCoin, reclaimInput bool, ) error { + // Indices of encoded AccountCoin struct + const ( + accountAddress = iota + coinIdentifier + amountValue + amountCurrencySymbol + amountCurrencyDecimals + + // If none exist below, we stop after amount. 
+ accountMetadata + subAccountAddress + subAccountMetadata + amountMetadata + currencyMetadata + ) + count := 0 currentBytes := b for { @@ -475,3 +475,192 @@ func (e *Encoder) DecodeAccountCoin( // nolint:gocognit return nil } + +// EncodeAccountCurrency is used to encode an AccountCurrency using the scheme (on the happy path): +// accountAddress|currencySymbol|currencyDecimals +// +// And the following scheme on the unhappy path: +// accountAddress|currencySymbol|currencyDecimals|accountMetadata| +// subAccountAddress|subAccountMetadata|currencyMetadata +// +// In both cases, the | character is represented by the unicodeRecordSeparator rune. +func (e *Encoder) EncodeAccountCurrency( // nolint:gocognit + accountCurrency *types.AccountCurrency, +) ([]byte, error) { + output := e.pool.Get() + if _, err := output.WriteString(accountCurrency.Account.Address); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + if _, err := output.WriteRune(unicodeRecordSeparator); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + if _, err := output.WriteString(accountCurrency.Currency.Symbol); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + if _, err := output.WriteRune(unicodeRecordSeparator); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + if _, err := output.WriteString( + strconv.FormatInt(int64(accountCurrency.Currency.Decimals), 10), + ); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + + // Exit early if we don't have any complex data to record (this helps + // us save a lot of space on the happy path). 
+ if accountCurrency.Account.Metadata == nil && + accountCurrency.Account.SubAccount == nil && + accountCurrency.Currency.Metadata == nil { + return output.Bytes(), nil + } + + if _, err := output.WriteRune(unicodeRecordSeparator); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + if accountCurrency.Account.Metadata != nil { + if err := e.encodeAndWrite(output, accountCurrency.Account.Metadata); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + } + if _, err := output.WriteRune(unicodeRecordSeparator); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + + if accountCurrency.Account.SubAccount != nil { + if _, err := output.WriteString(accountCurrency.Account.SubAccount.Address); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + } + if _, err := output.WriteRune(unicodeRecordSeparator); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + + if accountCurrency.Account.SubAccount != nil && + accountCurrency.Account.SubAccount.Metadata != nil { + if err := e.encodeAndWrite(output, accountCurrency.Account.SubAccount.Metadata); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + } + if _, err := output.WriteRune(unicodeRecordSeparator); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + + if accountCurrency.Currency.Metadata != nil { + if err := e.encodeAndWrite(output, accountCurrency.Currency.Metadata); err != nil { + return nil, fmt.Errorf("%w: %s", errors.ErrObjectEncodeFailed, err.Error()) + } + } + + return output.Bytes(), nil +} + +// DecodeAccountCurrency decodes an AccountCurrency and optionally +// reclaims the memory associated with the input. 
+func (e *Encoder) DecodeAccountCurrency( // nolint:gocognit + b []byte, + accountCurrency *types.AccountCurrency, + reclaimInput bool, +) error { + // Indices of the record-separator-delimited fields of an encoded AccountCurrency + const ( + accountAddress = iota + currencySymbol + currencyDecimals + + // If none exist below, we stop after currencyDecimals. + accountMetadata + subAccountAddress + subAccountMetadata + currencyMetadata + ) + + count := 0 + currentBytes := b + for { + nextRune := bytes.IndexRune(currentBytes, unicodeRecordSeparator) + if nextRune == -1 { + if count != currencyDecimals && count != currencyMetadata { + return fmt.Errorf("%w: next rune is -1 at %d", errors.ErrRawDecodeFailed, count) + } + + nextRune = len(currentBytes) + } + + val := currentBytes[:nextRune] + if len(val) == 0 && count > currencySymbol { // address and symbol are always decoded (even when empty) so Account and Currency are never nil in the cases below + goto handleNext + } + + switch count { + case accountAddress: + accountCurrency.Account = &types.AccountIdentifier{ + Address: string(val), + } + case currencySymbol: + accountCurrency.Currency = &types.Currency{ + Symbol: string(val), + } + case currencyDecimals: + i, err := strconv.ParseInt(string(val), 10, 32) + if err != nil { + return fmt.Errorf("%w: %s", errors.ErrRawDecodeFailed, err.Error()) + } + + accountCurrency.Currency.Decimals = int32(i) + case accountMetadata: + m, err := e.decodeMap(val) + if err != nil { + return fmt.Errorf("%w: account metadata %s", errors.ErrRawDecodeFailed, err.Error()) + } + + accountCurrency.Account.Metadata = m + case subAccountAddress: + accountCurrency.Account.SubAccount = &types.SubAccountIdentifier{ + Address: string(val), + } + case subAccountMetadata: + if accountCurrency.Account.SubAccount == nil { + return errors.ErrRawDecodeFailed // sub-account metadata requires a sub-account address + } + + m, err := e.decodeMap(val) + if err != nil { + return fmt.Errorf( + "%w: subaccount metadata %s", + errors.ErrRawDecodeFailed, + err.Error(), + ) + } + + accountCurrency.Account.SubAccount.Metadata = m + case currencyMetadata: + m, err := e.decodeMap(val) + if err != nil { + return fmt.Errorf( + "%w: 
currency metadata %s", + errors.ErrRawDecodeFailed, + err.Error(), + ) + } + + accountCurrency.Currency.Metadata = m + default: + return fmt.Errorf("%w: count %d > end", errors.ErrRawDecodeFailed, count) + } + + handleNext: + if nextRune == len(currentBytes) && + (count == currencyDecimals || count == currencyMetadata) { + break + } + + currentBytes = currentBytes[nextRune+1:] + count++ + } + + if reclaimInput { + e.pool.PutByteSlice(b) + } + + return nil +} diff --git a/storage/encoder/encoder_test.go b/storage/encoder/encoder_test.go index 99dd444c..3bfd97a2 100644 --- a/storage/encoder/encoder_test.go +++ b/storage/encoder/encoder_test.go @@ -257,3 +257,75 @@ func TestEncodeDecodeAccountCoin(t *testing.T) { }) } } + +func TestEncodeDecodeAccountCurrency(t *testing.T) { + tests := map[string]struct { + accountCurrency *types.AccountCurrency + }{ + "simple": { + accountCurrency: &types.AccountCurrency{ + Account: &types.AccountIdentifier{ + Address: "hello", + }, + Currency: &types.Currency{ + Symbol: "BTC", + Decimals: 8, + }, + }, + }, + "sub account info": { + accountCurrency: &types.AccountCurrency{ + Account: &types.AccountIdentifier{ + Address: "hello", + SubAccount: &types.SubAccountIdentifier{ + Address: "sub", + Metadata: map[string]interface{}{ + "test": "stuff", + }, + }, + }, + Currency: &types.Currency{ + Symbol: "BTC", + Decimals: 8, + }, + }, + }, + "currency metadata": { + accountCurrency: &types.AccountCurrency{ + Account: &types.AccountIdentifier{ + Address: "hello", + }, + Currency: &types.Currency{ + Symbol: "BTC", + Decimals: 8, + Metadata: map[string]interface{}{ + "issuer": "satoshi", + }, + }, + }, + }, + } + + for name, test := range tests { + e, err := NewEncoder(nil, NewBufferPool(), true) + assert.NoError(t, err) + + t.Run(name, func(t *testing.T) { + standardResult, err := e.Encode("", test.accountCurrency) + assert.NoError(t, err) + optimizedResult, err := e.EncodeAccountCurrency(test.accountCurrency) + assert.NoError(t, err) + 
fmt.Printf( + "Uncompressed: %d, Standard Compressed: %d, Optimized: %d\n", + len(types.PrintStruct(test.accountCurrency)), + len(standardResult), + len(optimizedResult), + ) + + var decoded types.AccountCurrency + assert.NoError(t, e.DecodeAccountCurrency(optimizedResult, &decoded, true)) + + assert.Equal(t, test.accountCurrency, &decoded) + }) + } +} diff --git a/storage/modules/balance_storage.go b/storage/modules/balance_storage.go index 49b9a349..3108147c 100644 --- a/storage/modules/balance_storage.go +++ b/storage/modules/balance_storage.go @@ -329,11 +329,6 @@ func (b *BalanceStorage) RemovingBlock( }, nil } -type accountEntry struct { - Account *types.AccountIdentifier `json:"account"` - Currency *types.Currency `json:"currency"` -} - // SetBalance allows a client to set the balance of an account in a database // transaction (removing all historical states). This is particularly useful // for bootstrapping balances. @@ -360,7 +355,7 @@ func (b *BalanceStorage) SetBalance( } // Serialize account entry - serialAcc, err := b.db.Encoder().Encode(accountNamespace, accountEntry{ + serialAcc, err := b.db.Encoder().EncodeAccountCurrency(&types.AccountCurrency{ Account: account, Currency: amount.Currency, }) @@ -480,26 +475,29 @@ func (b *BalanceStorage) ReconciliationCoverage( ) (float64, error) { seen := 0 validCoverage := 0 - err := b.getAllAccountEntries(ctx, func(txn database.Transaction, entry accountEntry) error { - seen++ + err := b.getAllAccountEntries( + ctx, + func(txn database.Transaction, entry *types.AccountCurrency) error { + seen++ - // Fetch last reconciliation index in same database.Transaction - key := GetAccountKey(reconciliationNamepace, entry.Account, entry.Currency) - exists, lastReconciled, err := BigIntGet(ctx, key, txn) - if err != nil { - return err - } + // Fetch last reconciliation index in same database.Transaction + key := GetAccountKey(reconciliationNamepace, entry.Account, entry.Currency) + exists, lastReconciled, err := 
BigIntGet(ctx, key, txn) + if err != nil { + return err + } - if !exists { - return nil - } + if !exists { + return nil + } - if lastReconciled.Int64() >= minimumIndex { - validCoverage++ - } + if lastReconciled.Int64() >= minimumIndex { + validCoverage++ + } - return nil - }) + return nil + }, + ) if err != nil { return -1, fmt.Errorf("%w: unable to get all account entries", err) } @@ -859,7 +857,7 @@ func (b *BalanceStorage) UpdateBalance( if !exists { newAccount = true key := GetAccountKey(accountNamespace, change.Account, change.Currency) - serialAcc, err := b.db.Encoder().Encode(accountNamespace, accountEntry{ + serialAcc, err := b.db.Encoder().EncodeAccountCurrency(&types.AccountCurrency{ Account: change.Account, Currency: change.Currency, }) @@ -1137,7 +1135,7 @@ func (b *BalanceStorage) BootstrapBalances( func (b *BalanceStorage) getAllAccountEntries( ctx context.Context, - handler func(database.Transaction, accountEntry) error, + handler func(database.Transaction, *types.AccountCurrency) error, ) error { txn := b.db.ReadTransaction(ctx) defer txn.Discard(ctx) @@ -1146,9 +1144,9 @@ func (b *BalanceStorage) getAllAccountEntries( []byte(accountNamespace), []byte(accountNamespace), func(k []byte, v []byte) error { - var accEntry accountEntry + var accCurrency types.AccountCurrency // We should not reclaim memory during a scan!! 
- err := b.db.Encoder().Decode(accountNamespace, v, &accEntry, false) + err := b.db.Encoder().DecodeAccountCurrency(v, &accCurrency, false) if err != nil { return fmt.Errorf( "%w: unable to parse balance entry for %s", @@ -1157,7 +1155,7 @@ func (b *BalanceStorage) getAllAccountEntries( ) } - return handler(txn, accEntry) + return handler(txn, &accCurrency) }, false, false, @@ -1177,22 +1175,14 @@ func (b *BalanceStorage) GetAllAccountCurrency( ) ([]*types.AccountCurrency, error) { log.Println("Loading previously seen accounts (this could take a while)...") - accountEntries := []*accountEntry{} - if err := b.getAllAccountEntries(ctx, func(_ database.Transaction, entry accountEntry) error { - accountEntries = append(accountEntries, &entry) + accounts := []*types.AccountCurrency{} + if err := b.getAllAccountEntries(ctx, func(_ database.Transaction, account *types.AccountCurrency) error { + accounts = append(accounts, account) return nil }); err != nil { return nil, fmt.Errorf("%w: unable to get all balance entries", err) } - accounts := make([]*types.AccountCurrency, len(accountEntries)) - for i, account := range accountEntries { - accounts[i] = &types.AccountCurrency{ - Account: account.Account, - Currency: account.Currency, - } - } - return accounts, nil } diff --git a/storage/modules/block_storage.go b/storage/modules/block_storage.go index 078c613e..42e519e0 100644 --- a/storage/modules/block_storage.go +++ b/storage/modules/block_storage.go @@ -561,22 +561,53 @@ func (b *BlockStorage) GetBlock( return b.GetBlockTransactional(ctx, transaction, blockIdentifier) } -func (b *BlockStorage) storeBlock( +func (b *BlockStorage) seeBlock( ctx context.Context, transaction database.Transaction, blockResponse *types.BlockResponse, -) error { +) (bool, error) { blockIdentifier := blockResponse.Block.BlockIdentifier namespace, key := getBlockHashKey(blockIdentifier.Hash) buf, err := b.db.Encoder().Encode(namespace, blockResponse) if err != nil { - return fmt.Errorf("%w: %v", 
storageErrs.ErrBlockEncodeFailed, err) + return false, fmt.Errorf("%w: %v", storageErrs.ErrBlockEncodeFailed, err) } - if err := storeUniqueKey(ctx, transaction, key, buf, true); err != nil { - return fmt.Errorf("%w: %v", storageErrs.ErrBlockStoreFailed, err) + exists, val, err := transaction.Get(ctx, key) + if err != nil { + return false, err + } + + if !exists { + return false, transaction.Set(ctx, key, buf, true) + } + + var rosettaBlockResponse types.BlockResponse + err = b.db.Encoder().Decode(namespace, val, &rosettaBlockResponse, true) + if err != nil { + return false, err } + // Exit early if block already exists! + if blockResponse.Block.BlockIdentifier.Hash == rosettaBlockResponse.Block.BlockIdentifier.Hash && + blockResponse.Block.BlockIdentifier.Index == rosettaBlockResponse.Block.BlockIdentifier.Index { + return true, nil + } + + return false, fmt.Errorf( + "%w: duplicate key %s found", + storageErrs.ErrDuplicateKey, + string(key), + ) +} + +func (b *BlockStorage) storeBlock( + ctx context.Context, + transaction database.Transaction, + blockIdentifier *types.BlockIdentifier, +) error { + _, key := getBlockHashKey(blockIdentifier.Hash) + if err := storeUniqueKey( ctx, transaction, @@ -598,12 +629,13 @@ func (b *BlockStorage) storeBlock( return nil } -// AddBlock stores a block or returns an error. -func (b *BlockStorage) AddBlock( +// SeeBlock pre-stores a block or returns an error. 
+func (b *BlockStorage) SeeBlock( ctx context.Context, block *types.Block, ) error { - transaction := b.db.WriteTransaction(ctx, blockSyncIdentifier, true) + _, key := getBlockHashKey(block.BlockIdentifier.Hash) + transaction := b.db.WriteTransaction(ctx, string(key), true) defer transaction.Discard(ctx) // Store all transactions in order and check for duplicates @@ -639,11 +671,15 @@ func (b *BlockStorage) AddBlock( } // Store block - err := b.storeBlock(ctx, transaction, blockWithoutTransactions) + exists, err := b.seeBlock(ctx, transaction, blockWithoutTransactions) if err != nil { return fmt.Errorf("%w: %v", storageErrs.ErrBlockStoreFailed, err) } + if exists { + return nil + } + g, gctx := errgroup.WithContextN(ctx, b.numCPU, b.numCPU) for i := range block.Transactions { // We need to set variable before calling goroutine @@ -668,6 +704,23 @@ func (b *BlockStorage) AddBlock( return err } + return transaction.Commit(ctx) +} + +// AddBlock stores a block or returns an error. +func (b *BlockStorage) AddBlock( + ctx context.Context, + block *types.Block, +) error { + transaction := b.db.WriteTransaction(ctx, blockSyncIdentifier, true) + defer transaction.Discard(ctx) + + // Store block + err := b.storeBlock(ctx, transaction, block.BlockIdentifier) + if err != nil { + return fmt.Errorf("%w: %v", storageErrs.ErrBlockStoreFailed, err) + } + return b.callWorkersAndCommit(ctx, block, transaction, true) } @@ -987,10 +1040,21 @@ func (b *BlockStorage) FindTransaction( return nil, nil, nil } + head, err := b.GetHeadBlockIdentifier(ctx) + if err != nil { + return nil, nil, err + } + var newestBlock *types.BlockIdentifier var newestTransaction *types.Transaction for _, blockTransaction := range blockTransactions { if newestBlock == nil || blockTransaction.BlockIdentifier.Index > newestBlock.Index { + // Now that we are optimistically storing data, there is a chance + // we may fetch a transaction from a seen but unsequenced block. 
+ if head != nil && blockTransaction.BlockIdentifier.Index > head.Index { + continue + } + newestBlock = blockTransaction.BlockIdentifier newestTransaction = blockTransaction.Transaction } diff --git a/storage/modules/block_storage_test.go b/storage/modules/block_storage_test.go index ca60f8cb..4f866ae3 100644 --- a/storage/modules/block_storage_test.go +++ b/storage/modules/block_storage_test.go @@ -335,6 +335,9 @@ func TestBlock(t *testing.T) { assert.Equal(t, int64(-1), oldestIndex) assert.Error(t, storageErrs.ErrOldestIndexMissing, err) + err = storage.SeeBlock(ctx, genesisBlock) + assert.NoError(t, err) + err = storage.AddBlock(ctx, genesisBlock) assert.NoError(t, err) @@ -349,6 +352,9 @@ func TestBlock(t *testing.T) { }) t.Run("Set and get block", func(t *testing.T) { + err = storage.SeeBlock(ctx, newBlock) + assert.NoError(t, err) + err = storage.AddBlock(ctx, newBlock) assert.NoError(t, err) @@ -419,6 +425,9 @@ func TestBlock(t *testing.T) { }) t.Run("Set duplicate transaction hash (from prior block)", func(t *testing.T) { + err = storage.SeeBlock(ctx, newBlock2) + assert.NoError(t, err) + err = storage.AddBlock(ctx, newBlock2) assert.NoError(t, err) @@ -486,6 +495,9 @@ func TestBlock(t *testing.T) { assert.NoError(t, err) assert.Equal(t, newBlock2.ParentBlockIdentifier, head) + err = storage.SeeBlock(ctx, newBlock2) + assert.NoError(t, err) + err = storage.AddBlock(ctx, newBlock2) assert.NoError(t, err) @@ -504,7 +516,10 @@ func TestBlock(t *testing.T) { }) t.Run("Add block with complex metadata", func(t *testing.T) { - err := storage.AddBlock(ctx, complexBlock) + err := storage.SeeBlock(ctx, complexBlock) + assert.NoError(t, err) + + err = storage.AddBlock(ctx, complexBlock) assert.NoError(t, err) oldestIndex, err := storage.GetOldestBlockIndex(ctx) @@ -533,7 +548,7 @@ func TestBlock(t *testing.T) { }) t.Run("Set duplicate transaction hash (same block)", func(t *testing.T) { - err = storage.AddBlock(ctx, duplicateTxBlock) + err = storage.SeeBlock(ctx, 
duplicateTxBlock) assert.Contains(t, err.Error(), storageErrs.ErrDuplicateTransactionHash.Error()) head, err := storage.GetHeadBlockIdentifier(ctx) @@ -542,7 +557,10 @@ func TestBlock(t *testing.T) { }) t.Run("Add block after omitted", func(t *testing.T) { - err := storage.AddBlock(ctx, gapBlock) + err := storage.SeeBlock(ctx, gapBlock) + assert.NoError(t, err) + + err = storage.AddBlock(ctx, gapBlock) assert.NoError(t, err) block, err := storage.GetBlock( @@ -565,6 +583,9 @@ func TestBlock(t *testing.T) { assert.NoError(t, err) assert.Equal(t, gapBlock.ParentBlockIdentifier, head) + err = storage.SeeBlock(ctx, gapBlock) + assert.NoError(t, err) + err = storage.AddBlock(ctx, gapBlock) assert.NoError(t, err) @@ -611,6 +632,7 @@ func TestBlock(t *testing.T) { ParentBlockIdentifier: parentBlockIdentifier, } + assert.NoError(t, storage.SeeBlock(ctx, block)) assert.NoError(t, storage.AddBlock(ctx, block)) head, err := storage.GetHeadBlockIdentifier(ctx) assert.NoError(t, err) @@ -736,6 +758,9 @@ func TestCreateBlockCache(t *testing.T) { }) t.Run("1 block processed", func(t *testing.T) { + err = storage.SeeBlock(ctx, genesisBlock) + assert.NoError(t, err) + err = storage.AddBlock(ctx, genesisBlock) assert.NoError(t, err) assert.Equal( @@ -746,6 +771,9 @@ func TestCreateBlockCache(t *testing.T) { }) t.Run("2 blocks processed", func(t *testing.T) { + err = storage.SeeBlock(ctx, newBlock) + assert.NoError(t, err) + err = storage.AddBlock(ctx, newBlock) assert.NoError(t, err) assert.Equal( @@ -764,6 +792,9 @@ func TestCreateBlockCache(t *testing.T) { ParentBlockIdentifier: newBlock.BlockIdentifier, } + err = storage.SeeBlock(ctx, simpleGap) + assert.NoError(t, err) + err = storage.AddBlock(ctx, simpleGap) assert.NoError(t, err) assert.Equal( @@ -804,7 +835,7 @@ func TestAtTip(t *testing.T) { }) t.Run("Add old block", func(t *testing.T) { - err := storage.AddBlock(ctx, &types.Block{ + b := &types.Block{ BlockIdentifier: &types.BlockIdentifier{ Hash: "block 0", Index: 0, @@ 
-814,7 +845,11 @@ func TestAtTip(t *testing.T) { Index: 0, }, Timestamp: utils.Milliseconds() - (3 * tipDelay * utils.MillisecondsInSecond), - }) + } + err := storage.SeeBlock(ctx, b) + assert.NoError(t, err) + + err = storage.AddBlock(ctx, b) assert.NoError(t, err) atTip, blockIdentifier, err := storage.AtTip(ctx, tipDelay) @@ -828,7 +863,7 @@ func TestAtTip(t *testing.T) { }) t.Run("Add new block", func(t *testing.T) { - err := storage.AddBlock(ctx, &types.Block{ + b := &types.Block{ BlockIdentifier: &types.BlockIdentifier{ Hash: "block 1", Index: 1, @@ -838,7 +873,11 @@ func TestAtTip(t *testing.T) { Index: 0, }, Timestamp: utils.Milliseconds(), - }) + } + err := storage.SeeBlock(ctx, b) + assert.NoError(t, err) + + err = storage.AddBlock(ctx, b) assert.NoError(t, err) atTip, blockIdentifier, err := storage.AtTip(ctx, tipDelay) diff --git a/syncer/syncer.go b/syncer/syncer.go index 9617c908..c4067233 100644 --- a/syncer/syncer.go +++ b/syncer/syncer.go @@ -211,6 +211,7 @@ func (s *Syncer) addBlockIndices( endIndex int64, ) error { defer close(blockIndices) + i := startIndex for i <= endIndex { s.concurrencyLock.Lock() @@ -266,6 +267,10 @@ func (s *Syncer) fetchBlockResult( br.block = block } + if err := s.handleSeenBlock(ctx, br); err != nil { + return nil, err + } + return br, nil } @@ -280,11 +285,11 @@ func (s *Syncer) safeExit(err error) error { return err } -// fetchChannelBlocks fetches blocks from a +// fetchBlocks fetches blocks from a // channel with retries until there are no // more blocks in the channel or there is an // error. -func (s *Syncer) fetchChannelBlocks( +func (s *Syncer) fetchBlocks( ctx context.Context, network *types.NetworkIdentifier, blockIndices chan int64, @@ -451,6 +456,75 @@ func (s *Syncer) adjustWorkers() bool { return shouldCreate } +func (s *Syncer) handleSeenBlock( + ctx context.Context, + result *blockResult, +) error { + // If the helper returns ErrOrphanHead + // for a block fetch, result.block will + // be nil. 
+ if result.block == nil { + return nil + } + + return s.handler.BlockSeen(ctx, result.block) +} + +func (s *Syncer) sequenceBlocks( // nolint:golint + ctx context.Context, + pipelineCtx context.Context, + g *errgroup.Group, + blockIndices chan int64, + fetchedBlocks chan *blockResult, + endIndex int64, +) error { + cache := make(map[int64]*blockResult) + for result := range fetchedBlocks { + cache[result.index] = result + + if err := s.processBlocks(ctx, cache, endIndex); err != nil { + return fmt.Errorf("%w: %v", ErrBlocksProcessMultipleFailed, err) + } + + // Determine if concurrency should be adjusted. + s.recentBlockSizes = append(s.recentBlockSizes, utils.SizeOf(result)) + s.lastAdjustment++ + + s.concurrencyLock.Lock() + shouldCreate := s.adjustWorkers() + if !shouldCreate { + s.concurrencyLock.Unlock() + continue + } + + // If we have finished loading blocks or the pipelineCtx + // has an error (like context.Canceled), we should avoid + // creating more goroutines (as there is a chance that + // Wait has returned). Attempting to create more goroutines + // after Wait has returned will cause a panic. + s.doneLoadingLock.Lock() + if !s.doneLoading && pipelineCtx.Err() == nil { + g.Go(func() error { + return s.fetchBlocks( + pipelineCtx, + s.network, + blockIndices, + fetchedBlocks, + ) + }) + } else { + s.concurrency-- + } + s.doneLoadingLock.Unlock() + + // Hold concurrencyLock until after we attempt to create another + // new goroutine in the case we accidentally go to 0 during shutdown. + s.concurrencyLock.Unlock() + } + + return nil +} + // syncRange fetches and processes a range of blocks // (from syncer.nextIndex to endIndex, inclusive) // with syncer.concurrency. @@ -459,7 +533,7 @@ func (s *Syncer) syncRange( endIndex int64, ) error { blockIndices := make(chan int64) - results := make(chan *blockResult) + fetchedBlocks := make(chan *blockResult) // Ensure default concurrency is less than max concurrency. 
startingConcurrency := DefaultConcurrency @@ -496,58 +570,29 @@ func (s *Syncer) syncRange( for j := int64(0); j < s.concurrency; j++ { g.Go(func() error { - return s.fetchChannelBlocks(pipelineCtx, s.network, blockIndices, results) + return s.fetchBlocks(pipelineCtx, s.network, blockIndices, fetchedBlocks) }) } // Wait for all block fetching goroutines to exit - // before closing the results channel. + // before closing the fetchedBlocks channel. go func() { _ = g.Wait() - close(results) + close(fetchedBlocks) }() - cache := make(map[int64]*blockResult) - for b := range results { - cache[b.index] = b - - if err := s.processBlocks(ctx, cache, endIndex); err != nil { - return fmt.Errorf("%w: %v", ErrBlocksProcessMultipleFailed, err) - } - - // Determine if concurrency should be adjusted. - s.recentBlockSizes = append(s.recentBlockSizes, utils.SizeOf(b)) - s.lastAdjustment++ - - s.concurrencyLock.Lock() - shouldCreate := s.adjustWorkers() - if !shouldCreate { - s.concurrencyLock.Unlock() - continue - } - - // If we have finished loading blocks or the pipelineCtx - // has an error (like context.Canceled), we should avoid - // creating more goroutines (as there is a chance that - // Wait has returned). Attempting to create more goroutines - // after Wait has returned will cause a panic. - s.doneLoadingLock.Lock() - if !s.doneLoading && pipelineCtx.Err() == nil { - g.Go(func() error { - return s.fetchChannelBlocks(pipelineCtx, s.network, blockIndices, results) - }) - } else { - s.concurrency-- - } - s.doneLoadingLock.Unlock() - - // Hold concurrencyLock until after we attempt to create another - // new goroutine in the case we accidentally go to 0 during shutdown. 
- s.concurrencyLock.Unlock() + if err := s.sequenceBlocks( + ctx, + pipelineCtx, + g, + blockIndices, + fetchedBlocks, + endIndex, + ); err != nil { + return err } - err := g.Wait() - if err != nil { + if err := g.Wait(); err != nil { return fmt.Errorf("%w: unable to sync to %d", err, endIndex) } diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go index e31c2717..7e730546 100644 --- a/syncer/syncer_test.go +++ b/syncer/syncer_test.go @@ -468,6 +468,15 @@ func TestSync_NoReorg(t *testing.T) { continue } + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Run(func(args mock.Arguments) { + assertNotCanceled(t, args) + }).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -476,9 +485,6 @@ func TestSync_NoReorg(t *testing.T) { nil, ).Run(func(args mock.Arguments) { assertNotCanceled(t, args) - if index == 1100 { - assert.Equal(t, int64(3), syncer.concurrency) - } // Test tip method if index > 200 { @@ -533,6 +539,13 @@ func TestSync_SpecificStart(t *testing.T) { ).Run(func(args mock.Arguments) { assertNotCanceled(t, args) }).Once() + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -596,6 +609,13 @@ func TestSync_Cancel(t *testing.T) { b, nil, ).Once() + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -649,6 +669,15 @@ func TestSync_Reorg(t *testing.T) { ).Run(func(args mock.Arguments) { assertNotCanceled(t, args) }).Once() + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Run(func(args mock.Arguments) { + assertNotCanceled(t, args) + }).Return( + nil, + ).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -703,6 +732,13 @@ func 
TestSync_Reorg(t *testing.T) { }).Once() } + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + newBlocks[0], + ).Return( + nil, + ).Once() // only fetch this block once mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -727,6 +763,20 @@ func TestSync_Reorg(t *testing.T) { ).Run(func(args mock.Arguments) { assertNotCanceled(t, args) }).Once() + + seenTimes := 2 + if b.BlockIdentifier.Index > 801 { + seenTimes = 1 + } + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Run(func(args mock.Arguments) { + assertNotCanceled(t, args) + }).Times(seenTimes) mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -790,6 +840,15 @@ func TestSync_ManualReorg(t *testing.T) { ).Run(func(args mock.Arguments) { assertNotCanceled(t, args) }).Once() + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Run(func(args mock.Arguments) { + assertNotCanceled(t, args) + }).Return( + nil, + ).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -838,6 +897,13 @@ func TestSync_ManualReorg(t *testing.T) { ).Run(func(args mock.Arguments) { assertNotCanceled(t, args) }).Once() + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -941,6 +1007,15 @@ func TestSync_Dynamic(t *testing.T) { continue } + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Run(func(args mock.Arguments) { + assertNotCanceled(t, args) + }).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), @@ -1022,6 +1097,15 @@ func TestSync_DynamicOverhead(t *testing.T) { continue } + mockHandler.On( + "BlockSeen", + mock.AnythingOfType("*context.cancelCtx"), + b, + ).Return( + nil, + ).Run(func(args mock.Arguments) { + 
assertNotCanceled(t, args) + }).Once() mockHandler.On( "BlockAdded", mock.AnythingOfType("*context.cancelCtx"), diff --git a/syncer/types.go b/syncer/types.go index f52e7417..fc06421a 100644 --- a/syncer/types.go +++ b/syncer/types.go @@ -81,6 +81,15 @@ const ( // to handle different events. It is common to write logs or // perform reconciliation in the sync processor. type Handler interface { + // BlockSeen is invoked AT LEAST ONCE + // by the syncer prior to calling BlockAdded + // with the same arguments. This allows for + // storing block data before it is sequenced. + BlockSeen( + ctx context.Context, + block *types.Block, + ) error + BlockAdded( ctx context.Context, block *types.Block,