diff --git a/storage/modules/block_storage.go b/storage/modules/block_storage.go index 4d41bafd..d34565f0 100644 --- a/storage/modules/block_storage.go +++ b/storage/modules/block_storage.go @@ -21,6 +21,7 @@ import ( "log" "runtime" "strconv" + "strings" "github.com/neilotoole/errgroup" @@ -77,9 +78,27 @@ func getBlockIndexKey(index int64) []byte { return []byte(fmt.Sprintf("%s/%d", blockIndexNamespace, index)) } -func getTransactionHashKey(transactionIdentifier *types.TransactionIdentifier) (string, []byte) { +func getTransactionHashKey( + transactionIdentifier *types.TransactionIdentifier, + blockIdentifier *types.BlockIdentifier, +) (string, []byte) { return transactionNamespace, []byte( - fmt.Sprintf("%s/%s", transactionNamespace, transactionIdentifier.Hash), + fmt.Sprintf( + "%s/%s/%s", + transactionNamespace, + transactionIdentifier.Hash, + blockIdentifier.Hash, + ), + ) +} + +func getTransactionHashPrefix(transactionIdentifier *types.TransactionIdentifier) []byte { + return []byte( + fmt.Sprintf( + "%s/%s/", + transactionNamespace, + transactionIdentifier.Hash, + ), ) } @@ -230,10 +249,24 @@ func (b *BlockStorage) pruneBlock( if err == nil { blockIdentifier := blockResponse.Block.BlockIdentifier - for _, tx := range blockResponse.OtherTransactions { - if err := b.pruneTransaction(ctx, dbTx, blockIdentifier, tx); err != nil { - return -1, fmt.Errorf("%w: %v", storageErrs.ErrCannotPruneTransaction, err) - } + // Remove all transaction hashes + g, gctx := errgroup.WithContextN(ctx, b.numCPU, b.numCPU) + for i := range blockResponse.OtherTransactions { + // We need to set variable before calling goroutine + // to avoid getting an updated pointer as loop iteration + // continues. + tx := blockResponse.OtherTransactions[i] + g.Go(func() error { + if err := b.pruneTransaction(gctx, dbTx, blockIdentifier, tx); err != nil { + return fmt.Errorf("%w: %v", storageErrs.ErrCannotPruneTransaction, err) + } + + return nil + }) + } + + if err := g.Wait(); err != nil { + return -1, err } _, blockKey := getBlockHashKey(blockIdentifier.Hash) @@ -851,9 +884,9 @@ func (b *BlockStorage) updateTransaction( dbTx database.Transaction, hashKey []byte, namespace string, - blocks map[string]*blockTransaction, + blockTransaction *blockTransaction, ) error { - encodedResult, err := b.db.Encoder().Encode(namespace, blocks) + encodedResult, err := b.db.Encoder().Encode(namespace, blockTransaction) if err != nil { return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err) } @@ -871,29 +904,16 @@ func (b *BlockStorage) storeTransaction( blockIdentifier *types.BlockIdentifier, tx *types.Transaction, ) error { - namespace, hashKey := getTransactionHashKey(tx.TransactionIdentifier) - exists, val, err := transaction.Get(ctx, hashKey) - if err != nil { - return err - } + namespace, hashKey := getTransactionHashKey(tx.TransactionIdentifier, blockIdentifier) - var blocks map[string]*blockTransaction - if !exists { - blocks = make(map[string]*blockTransaction) - } else { - err := b.db.Encoder().Decode(namespace, val, &blocks, true) - if err != nil { - return fmt.Errorf("%w: could not decode transaction hash contents", err) - } - } // We check for duplicates before storing transaction, // so this must be a new key. - blocks[blockIdentifier.Hash] = &blockTransaction{ + bt := &blockTransaction{ Transaction: tx, BlockIndex: blockIdentifier.Index, } - return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks) + return b.updateTransaction(ctx, transaction, hashKey, namespace, bt) } func (b *BlockStorage) pruneTransaction( @@ -902,25 +922,12 @@ func (b *BlockStorage) pruneTransaction( blockIdentifier *types.BlockIdentifier, txIdentifier *types.TransactionIdentifier, ) error { - namespace, hashKey := getTransactionHashKey(txIdentifier) - exists, val, err := transaction.Get(ctx, hashKey) - if err != nil { - return err - } - if !exists { - return storageErrs.ErrTransactionNotFound - } - - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, val, &blocks, true); err != nil { - return fmt.Errorf("%w: could not decode transaction hash contents", err) - } - - blocks[blockIdentifier.Hash] = &blockTransaction{ + namespace, hashKey := getTransactionHashKey(txIdentifier, blockIdentifier) + bt := &blockTransaction{ BlockIndex: blockIdentifier.Index, } - return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks) + return b.updateTransaction(ctx, transaction, hashKey, namespace, bt) } func (b *BlockStorage) removeTransaction( @@ -929,36 +936,48 @@ func (b *BlockStorage) removeTransaction( blockIdentifier *types.BlockIdentifier, transactionIdentifier *types.TransactionIdentifier, ) error { - namespace, hashKey := getTransactionHashKey(transactionIdentifier) - exists, val, err := transaction.Get(ctx, hashKey) - if err != nil { - return err - } - - if !exists { - return fmt.Errorf( - "%w %s", - storageErrs.ErrTransactionDeleteFailed, - transactionIdentifier.Hash, - ) - } - - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, val, &blocks, true); err != nil { - return fmt.Errorf("%w: could not decode transaction hash contents", err) - } - - if _, exists := blocks[blockIdentifier.Hash]; !exists { - return fmt.Errorf("%w %s", storageErrs.ErrTransactionHashNotFound, blockIdentifier.Hash) - } + _, hashKey := getTransactionHashKey(transactionIdentifier, blockIdentifier) + return transaction.Delete(ctx, hashKey) +} - delete(blocks, blockIdentifier.Hash) +func (b *BlockStorage) getAllTransactionsByIdentifier( + ctx context.Context, + transactionIdentifier *types.TransactionIdentifier, + txn database.Transaction, +) ([]*types.BlockTransaction, error) { + blockTransactions := []*types.BlockTransaction{} + _, err := txn.Scan( + ctx, + getTransactionHashPrefix(transactionIdentifier), + getTransactionHashPrefix(transactionIdentifier), + func(k []byte, v []byte) error { + // Decode blockTransaction + var bt blockTransaction + if err := b.db.Encoder().Decode(transactionNamespace, v, &bt, false); err != nil { + return fmt.Errorf("%w: unable to decode block data for transaction", err) + } - if len(blocks) == 0 { - return transaction.Delete(ctx, hashKey) + // Extract hash from key + splitKey := strings.Split(string(k), "/") + blockHash := splitKey[len(splitKey)-1] + + blockTransactions = append(blockTransactions, &types.BlockTransaction{ + BlockIdentifier: &types.BlockIdentifier{ + Index: bt.BlockIndex, + Hash: blockHash, + }, + Transaction: bt.Transaction, + }) + return nil + }, + false, + true, + ) + if err != nil { + return nil, err } - return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks) + return blockTransactions, nil } // FindTransaction returns the most recent *types.BlockIdentifier containing the @@ -968,27 +987,20 @@ func (b *BlockStorage) FindTransaction( transactionIdentifier *types.TransactionIdentifier, txn database.Transaction, ) (*types.BlockIdentifier, *types.Transaction, error) { - namespace, key := getTransactionHashKey(transactionIdentifier) - txExists, tx, err := txn.Get(ctx, key) + blockTransactions, err := b.getAllTransactionsByIdentifier(ctx, transactionIdentifier, txn) if err != nil { return nil, nil, fmt.Errorf("%w: %v", storageErrs.ErrTransactionDBQueryFailed, err) } - if !txExists { + if len(blockTransactions) == 0 { return nil, nil, nil } - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, tx, &blocks, true); err != nil { - return nil, nil, fmt.Errorf("%w: unable to decode block data for transaction", err) - } - var newestBlock *types.BlockIdentifier var newestTransaction *types.Transaction - for hash, blockTransaction := range blocks { - b := &types.BlockIdentifier{Hash: hash, Index: blockTransaction.BlockIndex} - if newestBlock == nil || blockTransaction.BlockIndex > newestBlock.Index { - newestBlock = b + for _, blockTransaction := range blockTransactions { + if newestBlock == nil || blockTransaction.BlockIdentifier.Index > newestBlock.Index { + newestBlock = blockTransaction.BlockIdentifier newestTransaction = blockTransaction.Transaction } } @@ -1016,7 +1028,7 @@ func (b *BlockStorage) findBlockTransaction( return nil, storageErrs.ErrCannotAccessPrunedData } - namespace, key := getTransactionHashKey(transactionIdentifier) + namespace, key := getTransactionHashKey(transactionIdentifier, blockIdentifier) txExists, tx, err := txn.Get(ctx, key) if err != nil { return nil, fmt.Errorf("%w: %v", storageErrs.ErrTransactionDBQueryFailed, err) @@ -1030,22 +1042,12 @@ func (b *BlockStorage) findBlockTransaction( ) } - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, tx, &blocks, true); err != nil { + var bt blockTransaction + if err := b.db.Encoder().Decode(namespace, tx, &bt, true); err != nil { return nil, fmt.Errorf("%w: unable to decode block data for transaction", err) } - val, ok := blocks[blockIdentifier.Hash] - if !ok { - return nil, fmt.Errorf( - "%w: did not find transaction %s in block %s", - storageErrs.ErrTransactionDoesNotExistInBlock, - transactionIdentifier.Hash, - blockIdentifier.Hash, - ) - } - - return val.Transaction, nil + return bt.Transaction, nil } // GetBlockTransaction retrieves a transaction belonging to a certain