Skip to content

Commit

Permalink
Broadcast test
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-ogrady committed Jul 26, 2020
1 parent 8cc4b5c commit 25730b5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 11 deletions.
23 changes: 12 additions & 11 deletions internal/storage/broadcast_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ type BroadcastStorageHandler interface {
) error
}

// broadcast is persisted to the db to track transaction broadcast.
type broadcast struct {
// Broadcast is persisted to the db to track transaction broadcast.
type Broadcast struct {
Identifier *types.TransactionIdentifier `json:"identifier"`
Sender string `json:"sender"`
Intent []*types.Operation `json:"intent"`
Expand Down Expand Up @@ -149,7 +149,7 @@ func (b *BroadcastStorage) addBlockCommitWorker(
ctx context.Context,
block *types.Block,
staleTransactions []*types.TransactionIdentifier,
confirmedTransactions []*broadcast,
confirmedTransactions []*Broadcast,
foundTransactions []*types.Transaction,
) error {
for _, stale := range staleTransactions {
Expand Down Expand Up @@ -187,13 +187,13 @@ func (b *BroadcastStorage) AddingBlock(
block *types.Block,
transaction DatabaseTransaction,
) (CommitWorker, error) {
broadcasts, err := b.getAllBroadcasts(ctx)
broadcasts, err := b.GetAllBroadcasts(ctx)
if err != nil {
return nil, fmt.Errorf("%w: unable to get all broadcasts", err)
}

staleTransactions := []*types.TransactionIdentifier{}
confirmedTransactions := []*broadcast{}
confirmedTransactions := []*Broadcast{}
foundTransactions := []*types.Transaction{}

for _, broadcast := range broadcasts {
Expand Down Expand Up @@ -281,7 +281,7 @@ func (b *BroadcastStorage) Broadcast(
return fmt.Errorf("already broadcasting transaction %s", transactionIdentifier.Hash)
}

bytes, err := encode(&broadcast{
bytes, err := encode(&Broadcast{
Identifier: transactionIdentifier,
Sender: sender,
Intent: intent,
Expand All @@ -303,15 +303,16 @@ func (b *BroadcastStorage) Broadcast(
return nil
}

func (b *BroadcastStorage) getAllBroadcasts(ctx context.Context) ([]*broadcast, error) {
// GetAllBroadcasts returns all currently in-process broadcasts.
func (b *BroadcastStorage) GetAllBroadcasts(ctx context.Context) ([]*Broadcast, error) {
rawBroadcasts, err := b.db.Scan(ctx, []byte(transactionBroadcastNamespace))
if err != nil {
return nil, fmt.Errorf("%w: unable to scan for all broadcasts", err)
}

broadcasts := make([]*broadcast, len(rawBroadcasts))
broadcasts := make([]*Broadcast, len(rawBroadcasts))
for i, rawBroadcast := range rawBroadcasts {
var b broadcast
var b Broadcast
if err := decode(rawBroadcast, &b); err != nil {
return nil, fmt.Errorf("%w: unable to decode broadcast", err)
}
Expand All @@ -338,7 +339,7 @@ func (b *BroadcastStorage) broadcastPending(ctx context.Context) error {
return nil
}

broadcasts, err := b.getAllBroadcasts(ctx)
broadcasts, err := b.GetAllBroadcasts(ctx)
if err != nil {
return fmt.Errorf("%w: unable to get all broadcasts", err)
}
Expand Down Expand Up @@ -417,7 +418,7 @@ func (b *BroadcastStorage) broadcastPending(ctx context.Context) error {
// The caller SHOULD NOT broadcast a transaction from an account if it is
// considered locked!
func (b *BroadcastStorage) LockedAddresses(ctx context.Context) ([]string, error) {
broadcasts, err := b.getAllBroadcasts(ctx)
broadcasts, err := b.GetAllBroadcasts(ctx)
if err != nil {
return nil, fmt.Errorf("%w: unable to get all broadcasts", err)
}
Expand Down
60 changes: 60 additions & 0 deletions internal/storage/broadcast_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,44 @@ const (
func makeFillerBlocks(start int64, end int64) []*types.Block {
blocks := []*types.Block{}
for i := start; i < end; i++ {
parentIndex := i - 1
if parentIndex < 0 {
parentIndex = 0
}
blocks = append(blocks, &types.Block{
BlockIdentifier: &types.BlockIdentifier{
Index: i,
Hash: fmt.Sprintf("block %d", i),
},
ParentBlockIdentifier: &types.BlockIdentifier{
Index: parentIndex,
Hash: fmt.Sprintf("block %d", parentIndex),
},
})
}

return blocks
}

func opFiller(sender string, opNumber int) []*types.Operation {
ops := make([]*types.Operation, opNumber)
for i := 0; i < opNumber; i++ {
ops[i] = &types.Operation{
OperationIdentifier: &types.OperationIdentifier{
Index: int64(i),
},
Account: &types.AccountIdentifier{
Address: sender,
SubAccount: &types.SubAccountIdentifier{
Address: sender,
},
},
}
}

return ops
}

func TestBroadcastStorage(t *testing.T) {
ctx := context.Background()

Expand All @@ -67,6 +94,39 @@ func TestBroadcastStorage(t *testing.T) {
assert.Len(t, addresses, 0)
assert.NotNil(t, addresses)
})

t.Run("broadcast", func(t *testing.T) {
send1 := opFiller("addr 1", 11)
err := storage.Broadcast(ctx, "addr 1", send1, &types.TransactionIdentifier{Hash: "tx 1"}, "payload 1")
assert.NoError(t, err)

send2 := opFiller("addr 2", 13)
err = storage.Broadcast(ctx, "addr 2", send2, &types.TransactionIdentifier{Hash: "tx 2"}, "payload 2")
assert.NoError(t, err)

// Check to make sure duplicate instances of address aren't reported
addresses, err := storage.LockedAddresses(ctx)
assert.NoError(t, err)
assert.Len(t, addresses, 2)
assert.ElementsMatch(t, []string{"addr 1", "addr 2"}, addresses)

broadcasts, err := storage.GetAllBroadcasts(ctx)
assert.NoError(t, err)
assert.ElementsMatch(t, []*Broadcast{
{
Identifier: &types.TransactionIdentifier{Hash: "tx 1"},
Sender: "addr 1",
Intent: send1,
Payload: "payload 1",
},
{
Identifier: &types.TransactionIdentifier{Hash: "tx 2"},
Sender: "addr 2",
Intent: send2,
Payload: "payload 2",
},
}, broadcasts)
})
}

var _ BroadcastStorageHelper = (*MockBroadcastStorageHelper)(nil)
Expand Down

0 comments on commit 25730b5

Please sign in to comment.