diff --git a/mocks/utils/fetcher_helper.go b/mocks/utils/fetcher_helper.go index 0337ac800..02fbf73dd 100644 --- a/mocks/utils/fetcher_helper.go +++ b/mocks/utils/fetcher_helper.go @@ -59,6 +59,49 @@ func (_m *FetcherHelper) AccountBalanceRetry(ctx context.Context, network *types return r0, r1, r2, r3 } +// AccountCoinsRetry provides a mock function with given fields: ctx, network, account, includeMempool, currencies +func (_m *FetcherHelper) AccountCoinsRetry(ctx context.Context, network *types.NetworkIdentifier, account *types.AccountIdentifier, includeMempool bool, currencies []*types.Currency) (*types.BlockIdentifier, []*types.Coin, map[string]interface{}, *fetcher.Error) { + ret := _m.Called(ctx, network, account, includeMempool, currencies) + + var r0 *types.BlockIdentifier + if rf, ok := ret.Get(0).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, bool, []*types.Currency) *types.BlockIdentifier); ok { + r0 = rf(ctx, network, account, includeMempool, currencies) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BlockIdentifier) + } + } + + var r1 []*types.Coin + if rf, ok := ret.Get(1).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, bool, []*types.Currency) []*types.Coin); ok { + r1 = rf(ctx, network, account, includeMempool, currencies) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]*types.Coin) + } + } + + var r2 map[string]interface{} + if rf, ok := ret.Get(2).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, bool, []*types.Currency) map[string]interface{}); ok { + r2 = rf(ctx, network, account, includeMempool, currencies) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(map[string]interface{}) + } + } + + var r3 *fetcher.Error + if rf, ok := ret.Get(3).(func(context.Context, *types.NetworkIdentifier, *types.AccountIdentifier, bool, []*types.Currency) *fetcher.Error); ok { + r3 = rf(ctx, network, account, includeMempool, currencies) + } else { + if ret.Get(3) != nil { + r3 = ret.Get(3).(*fetcher.Error) + } + } + + return r0, r1, r2, r3 +} + // NetworkList provides a mock function with given fields: ctx, metadata func (_m *FetcherHelper) NetworkList(ctx context.Context, metadata map[string]interface{}) (*types.NetworkListResponse, *fetcher.Error) { ret := _m.Called(ctx, metadata) diff --git a/storage/modules/coin_storage.go b/storage/modules/coin_storage.go index 52a39f485..20cf69fca 100644 --- a/storage/modules/coin_storage.go +++ b/storage/modules/coin_storage.go @@ -501,21 +501,28 @@ func (c *CoinStorage) GetLargestCoin( // This is used when importing prefunded addresses. func (c *CoinStorage) SetCoinsImported( ctx context.Context, - accountBalances []*utils.AccountBalance, + accts []*types.AccountIdentifier, + acctCoinsResp []*utils.AccountCoinsResponse, ) error { - var accountCoins []*types.AccountCoin - for _, accountBalance := range accountBalances { - for _, coin := range accountBalance.Coins { - accountCoin := &types.AccountCoin{ - Account: accountBalance.Account, + // Request array length should always equal response array length. + // But we still check it for sure. + if len(accts) != len(acctCoinsResp) { + return errors.ErrCoinImportFailed + } + + var acctCoins []*types.AccountCoin + for i, resp := range acctCoinsResp { + for _, coin := range resp.Coins { + acctCoin := &types.AccountCoin{ + Account: accts[i], Coin: coin, } - accountCoins = append(accountCoins, accountCoin) + acctCoins = append(acctCoins, acctCoin) } } - if err := c.AddCoins(ctx, accountCoins); err != nil { + if err := c.AddCoins(ctx, acctCoins); err != nil { return fmt.Errorf("%w: %v", errors.ErrCoinImportFailed, err) } diff --git a/storage/modules/coin_storage_test.go b/storage/modules/coin_storage_test.go index 3773be797..8d16e9816 100644 --- a/storage/modules/coin_storage_test.go +++ b/storage/modules/coin_storage_test.go @@ -386,6 +386,15 @@ var ( }, }, } + + acctCoins = []*utils.AccountCoinsResponse{ + { + Coins: accBalance1.Coins, + }, + { + Coins: accBalance2.Coins, + }, + } ) func TestCoinStorage(t *testing.T) { @@ -961,9 +970,8 @@ func TestCoinStorage(t *testing.T) { }) t.Run("SetCoinsImported", func(t *testing.T) { - accBalances := []*utils.AccountBalance{accBalance1, accBalance2} - - err := c.SetCoinsImported(ctx, accBalances) + accts := []*types.AccountIdentifier{accBalance1.Account, accBalance2.Account} + err := c.SetCoinsImported(ctx, accts, acctCoins) assert.NoError(t, err) mockHelper.On( diff --git a/utils/utils.go b/utils/utils.go index da286da9e..4740fd834 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -192,6 +192,14 @@ type FetcherHelper interface { block *types.PartialBlockIdentifier, currencies []*types.Currency, ) (*types.BlockIdentifier, []*types.Amount, map[string]interface{}, *fetcher.Error) + + AccountCoinsRetry( + ctx context.Context, + network *types.NetworkIdentifier, + acct *types.AccountIdentifier, + includeMempool bool, + currencies []*types.Currency, + ) (*types.BlockIdentifier, []*types.Coin, map[string]interface{}, *fetcher.Error) } type BlockStorageHelper interface { @@ -491,6 +499,57 @@ func GetAccountBalances( return accountBalances, nil } +// ------------------------------------------------------------------------------- +// ----------------- Helper struct for fetching account coins -------------------- +// ------------------------------------------------------------------------------- + +// AccountCoinsRequest defines the required information to get an account's coins. +type AccountCoinsRequest struct { + Account *types.AccountIdentifier + Network *types.NetworkIdentifier + Currencies []*types.Currency + IncludeMempool bool +} + +// AccountCoins defines an account's coins info at tip. +type AccountCoinsResponse struct { + Coins []*types.Coin +} + +// GetAccountCoins calls /account/coins endpoint and returns an array of coins at tip. +func GetAccountCoins( + ctx context.Context, + fetcher FetcherHelper, + acctCoinsReqs []*AccountCoinsRequest, +) ([]*AccountCoinsResponse, error) { + var acctCoins []*AccountCoinsResponse + for _, req := range acctCoinsReqs { + _, coins, _, err := fetcher.AccountCoinsRetry( + ctx, + req.Network, + req.Account, + req.IncludeMempool, + req.Currencies, + ) + + if err != nil { + return nil, err.Err + } + + resp := &AccountCoinsResponse{ + Coins: coins, + } + + acctCoins = append(acctCoins, resp) + } + + return acctCoins, nil +} + +// ------------------------------------------------------------------------------- +// ------------------- End of helper struct for account coins -------------------- +// ------------------------------------------------------------------------------- + // AtTip returns a boolean indicating if a block timestamp // is within tipDelay from the current time. func AtTip(