From 81a812cdd9b082cf1d24ed12029f602afe58e06e Mon Sep 17 00:00:00 2001 From: Paul Nicolas Date: Wed, 26 Nov 2025 14:39:39 +0100 Subject: [PATCH 1/4] feat(accounts,pools): compute balances inside postgres --- internal/api/services/pools_balances.go | 32 +---- internal/api/services/pools_balances_at.go | 34 +---- .../api/services/pools_balances_at_test.go | 41 +++--- internal/api/services/pools_balances_test.go | 40 +++--- internal/storage/balances.go | 125 ++++++------------ internal/storage/balances_test.go | 79 ++++++++--- internal/storage/storage.go | 3 +- internal/storage/storage_generated.go | 30 +---- 8 files changed, 161 insertions(+), 223 deletions(-) diff --git a/internal/api/services/pools_balances.go b/internal/api/services/pools_balances.go index 17b7c6dbd..e81619838 100644 --- a/internal/api/services/pools_balances.go +++ b/internal/api/services/pools_balances.go @@ -30,35 +30,9 @@ func (s *Service) PoolsBalances( } } - res := make(map[string]*aggregatedBalance) - for i := range pool.PoolAccounts { - balances, err := s.storage.BalancesGetLatest(ctx, pool.PoolAccounts[i]) - if err != nil { - return nil, newStorageError(err, "cannot get latest balances") - } - - for _, balance := range balances { - v, ok := res[balance.Asset] - if !ok { - v = &aggregatedBalance{ - amount: big.NewInt(0), - relatedAccounts: []models.AccountID{}, - } - } - - v.amount = v.amount.Add(v.amount, balance.Balance) - v.relatedAccounts = append(v.relatedAccounts, balance.AccountID) - res[balance.Asset] = v - } - } - - balances := make([]models.AggregatedBalance, 0, len(res)) - for asset, balance := range res { - balances = append(balances, models.AggregatedBalance{ - Asset: asset, - Amount: balance.amount, - RelatedAccounts: balance.relatedAccounts, - }) + balances, err := s.storage.BalancesGetFromAccountIDs(ctx, pool.PoolAccounts, nil) + if err != nil { + return nil, newStorageError(err, "cannot get latest balances") } return balances, nil diff --git a/internal/api/services/pools_balances_at.go b/internal/api/services/pools_balances_at.go index 4322fde5e..df7d4bc3b 100644 --- a/internal/api/services/pools_balances_at.go +++ b/internal/api/services/pools_balances_at.go @@ -3,10 +3,10 @@ package services import ( "context" "encoding/json" - "math/big" "time" "github.com/formancehq/go-libs/v3/bun/bunpaginate" + "github.com/formancehq/go-libs/v3/pointer" "github.com/formancehq/go-libs/v3/query" "github.com/formancehq/payments/internal/models" "github.com/formancehq/payments/internal/storage" @@ -31,35 +31,9 @@ func (s *Service) PoolsBalancesAt( } } - res := make(map[string]*aggregatedBalance) - for i := range pool.PoolAccounts { - balances, err := s.storage.BalancesGetAt(ctx, pool.PoolAccounts[i], at) - if err != nil { - return nil, newStorageError(err, "cannot get balances") - } - - for _, balance := range balances { - v, ok := res[balance.Asset] - if !ok { - v = &aggregatedBalance{ - amount: big.NewInt(0), - relatedAccounts: []models.AccountID{}, - } - } - - v.amount = v.amount.Add(v.amount, balance.Balance) - v.relatedAccounts = append(v.relatedAccounts, balance.AccountID) - res[balance.Asset] = v - } - } - - balances := make([]models.AggregatedBalance, 0, len(res)) - for asset, balance := range res { - balances = append(balances, models.AggregatedBalance{ - Asset: asset, - Amount: balance.amount, - RelatedAccounts: balance.relatedAccounts, - }) + balances, err := s.storage.BalancesGetFromAccountIDs(ctx, pool.PoolAccounts, pointer.For(at)) + if err != nil { + return nil, newStorageError(err, "cannot get balances") } return balances, nil diff --git a/internal/api/services/pools_balances_at_test.go b/internal/api/services/pools_balances_at_test.go index fcf1f6fd4..64d7866c8 100644 --- a/internal/api/services/pools_balances_at_test.go +++ b/internal/api/services/pools_balances_at_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/formancehq/go-libs/v3/pointer" "github.com/formancehq/payments/internal/connectors/engine" "github.com/formancehq/payments/internal/models" "github.com/formancehq/payments/internal/storage" @@ -27,30 +28,30 @@ func TestPoolsBalancesAt(t *testing.T) { id := uuid.New() poolsAccount := []models.AccountID{{}} - balancesResponse := []*models.Balance{ + balancesResponse := []models.AggregatedBalance{ { - AccountID: models.AccountID{ - Reference: "test1", - ConnectorID: models.ConnectorID{}, + RelatedAccounts: []models.AccountID{ + { + Reference: "test1", + ConnectorID: models.ConnectorID{}, + }, + { + Reference: "test2", + ConnectorID: models.ConnectorID{}, + }, }, - Asset: "EUR/2", - Balance: big.NewInt(100), + Asset: "EUR/2", + Amount: big.NewInt(400), }, { - AccountID: models.AccountID{ - Reference: "test1", - ConnectorID: models.ConnectorID{}, + RelatedAccounts: []models.AccountID{ + { + Reference: "test1", + ConnectorID: models.ConnectorID{}, + }, }, - Asset: "USD/2", - Balance: big.NewInt(200), - }, - { - AccountID: models.AccountID{ - Reference: "test2", - ConnectorID: models.ConnectorID{}, - }, - Asset: "EUR/2", - Balance: big.NewInt(300), + Asset: "USD/2", + Amount: big.NewInt(200), }, } at := time.Now().Add(-time.Hour) @@ -98,7 +99,7 @@ func TestPoolsBalancesAt(t *testing.T) { PoolAccounts: poolsAccount, }, test.poolsGetStorageErr) if test.poolsGetStorageErr == nil { - store.EXPECT().BalancesGetAt(gomock.Any(), models.AccountID{}, at).Return(balancesResponse, test.accountsBalancesAtErr) + store.EXPECT().BalancesGetFromAccountIDs(gomock.Any(), gomock.Any(), pointer.For(at)).Return(balancesResponse, test.accountsBalancesAtErr) } balances, err := s.PoolsBalancesAt(context.Background(), id, at) diff --git a/internal/api/services/pools_balances_test.go b/internal/api/services/pools_balances_test.go index e00d7f1fb..6054ef961 100644 --- a/internal/api/services/pools_balances_test.go +++ b/internal/api/services/pools_balances_test.go @@ -27,30 +27,30 @@ func TestPoolsBalancesLatest(t *testing.T) { id := uuid.New() poolsAccount := []models.AccountID{{}} - balancesResponse := []*models.Balance{ + balancesResponse := []models.AggregatedBalance{ { - AccountID: models.AccountID{ - Reference: "test1", - ConnectorID: models.ConnectorID{}, + RelatedAccounts: []models.AccountID{ + { + Reference: "test1", + ConnectorID: models.ConnectorID{}, + }, + { + Reference: "test2", + ConnectorID: models.ConnectorID{}, + }, }, - Asset: "EUR/2", - Balance: big.NewInt(100), + Asset: "EUR/2", + Amount: big.NewInt(400), }, { - AccountID: models.AccountID{ - Reference: "test1", - ConnectorID: models.ConnectorID{}, + RelatedAccounts: []models.AccountID{ + { + Reference: "test1", + ConnectorID: models.ConnectorID{}, + }, }, - Asset: "USD/2", - Balance: big.NewInt(200), - }, - { - AccountID: models.AccountID{ - Reference: "test2", - ConnectorID: models.ConnectorID{}, - }, - Asset: "EUR/2", - Balance: big.NewInt(300), + Asset: "USD/2", + Amount: big.NewInt(200), }, } @@ -97,7 +97,7 @@ func TestPoolsBalancesLatest(t *testing.T) { PoolAccounts: poolsAccount, }, test.poolsGetStorageErr) if test.poolsGetStorageErr == nil { - store.EXPECT().BalancesGetLatest(gomock.Any(), models.AccountID{}).Return(balancesResponse, test.accountsBalancesErr) + store.EXPECT().BalancesGetFromAccountIDs(gomock.Any(), gomock.Any(), nil).Return(balancesResponse, test.accountsBalancesErr) } balances, err := s.PoolsBalances(context.Background(), id) diff --git a/internal/storage/balances.go b/internal/storage/balances.go index 1e52e4b56..ff8fcc532 100644 --- a/internal/storage/balances.go +++ b/internal/storage/balances.go @@ -4,11 +4,11 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "math/big" "time" "github.com/formancehq/go-libs/v3/bun/bunpaginate" - "github.com/formancehq/go-libs/v3/pointer" internalTime "github.com/formancehq/go-libs/v3/time" internalEvents "github.com/formancehq/payments/internal/events" "github.com/formancehq/payments/internal/models" @@ -287,99 +287,58 @@ func (s *store) BalancesList(ctx context.Context, q ListBalancesQuery) (*bunpagi }, nil } -func (s *store) balancesListAssets(ctx context.Context, accountID models.AccountID) ([]string, error) { - var assets []string - - err := s.db.NewSelect(). - ColumnExpr("DISTINCT asset"). - Model(&models.Balance{}). - Where("account_id = ?", accountID). - Scan(ctx, &assets) - if err != nil { - return nil, e("failed to list balance assets", err) +// Get balances from account IDs at a specific time. If at is nil, it will return the latest balances. +func (s *store) BalancesGetFromAccountIDs(ctx context.Context, accountIDs []models.AccountID, at *time.Time) ([]models.AggregatedBalance, error) { + type assetBalances struct { + AccountIDs []string `bun:"account_ids,array"` + Asset string `bun:"asset"` + Balance *big.Int `bun:"balance"` } - return assets, nil -} - -func (s *store) balancesGetAtByAsset(ctx context.Context, accountID models.AccountID, asset string, at time.Time) (*models.Balance, error) { - var balance balance - - err := s.db.NewSelect(). - Model(&balance). - Where("account_id = ?", accountID). - Where("asset = ?", asset). - Where("created_at <= ?", at). - Where("last_updated_at >= ?", at). - Order("created_at DESC", "sort_id DESC"). - Limit(1). - Scan(ctx) - if err != nil { - return nil, e("failed to get balance", err) - } - - return pointer.For(toBalanceModels(balance)), nil -} - -func (s *store) balancesGetLatestByAsset(ctx context.Context, accountID models.AccountID, asset string) (*models.Balance, error) { - var balance balance - - err := s.db.NewSelect(). - Model(&balance). - Where("account_id = ?", accountID). - Where("asset = ?", asset). - Order("created_at DESC", "sort_id DESC"). - Limit(1). - Scan(ctx) - if err != nil { - return nil, e("failed to get latest balance", err) + selectedBalancesQuery := s.db.NewSelect(). + Model((*balance)(nil)). + DistinctOn("account_id, asset"). + Column("account_id", "asset", "created_at", "sort_id", "balance"). + Where("account_id IN (?)", bun.In(accountIDs)). + Order("account_id desc", "asset desc", "created_at desc", "sort_id desc") + + if at != nil && !at.IsZero() { + selectedBalancesQuery = selectedBalancesQuery.Where("created_at <= ?", at). + Where("last_updated_at >= ?", at) } - - return pointer.For(toBalanceModels(balance)), nil -} - -func (s *store) BalancesGetAt(ctx context.Context, accountID models.AccountID, at time.Time) ([]*models.Balance, error) { - assets, err := s.balancesListAssets(ctx, accountID) + var balanceAssets []assetBalances + query := s.db.NewSelect(). + With( + "selected_balances", + selectedBalancesQuery, + ). + ModelTableExpr("selected_balances"). + ColumnExpr("array_agg(account_id) as account_ids, asset, SUM(balance) AS balance"). + Group("asset") + + fmt.Println(query.String()) + + err := query.Scan(ctx, &balanceAssets) if err != nil { return nil, e("failed to list balance assets", err) } - var balances []*models.Balance - for _, currency := range assets { - balance, err := s.balancesGetAtByAsset(ctx, accountID, currency, at) - if err != nil { - if errors.Is(err, ErrNotFound) { - continue - } - return nil, e("failed to get balance", err) - } - - balances = append(balances, balance) - } - - return balances, nil -} + fmt.Println(balanceAssets) -func (s *store) BalancesGetLatest(ctx context.Context, accountID models.AccountID) ([]*models.Balance, error) { - assets, err := s.balancesListAssets(ctx, accountID) - if err != nil { - return nil, e("failed to list balance assets", err) - } - - var balances []*models.Balance - for _, currency := range assets { - balance, err := s.balancesGetLatestByAsset(ctx, accountID, currency) - if err != nil { - if errors.Is(err, ErrNotFound) { - continue - } - return nil, e("failed to get latest balance for asset", err) + res := make([]models.AggregatedBalance, 0, len(balanceAssets)) + for _, balanceAsset := range balanceAssets { + relatedAccounts := make([]models.AccountID, len(balanceAsset.AccountIDs)) + for i, accountID := range balanceAsset.AccountIDs { + relatedAccounts[i] = models.MustAccountIDFromString(accountID) } - - balances = append(balances, balance) + res = append(res, models.AggregatedBalance{ + Asset: balanceAsset.Asset, + Amount: balanceAsset.Balance, + RelatedAccounts: relatedAccounts, + }) } - return balances, nil + return res, nil } func fromBalancesModels(from []models.Balance) []balance { diff --git a/internal/storage/balances_test.go b/internal/storage/balances_test.go index be32c6ba8..f01d47dac 100644 --- a/internal/storage/balances_test.go +++ b/internal/storage/balances_test.go @@ -3,6 +3,7 @@ package storage import ( "context" "encoding/json" + "fmt" "math/big" "testing" "time" @@ -66,6 +67,42 @@ func defaultBalances2() []models.Balance { } } +func defaultBalancesDuplicates() []models.Balance { + defaultAccounts := defaultAccounts() + return []models.Balance{ + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-60 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-60 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(100), + }, + { + AccountID: defaultAccounts[1].ID, + CreatedAt: now.Add(-30 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-30 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(1000), + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-55 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-55 * time.Minute).UTC().Time, + Asset: "EUR/2", + Balance: big.NewInt(150), + PsuID: &defaultPSU2.ID, + OpenBankingConnectionID: &defaultOpenBankingConnection.ConnectionID, + }, + { + AccountID: defaultAccounts[0].ID, + CreatedAt: now.Add(-50 * time.Minute).UTC().Time, + LastUpdatedAt: now.Add(-50 * time.Minute).UTC().Time, + Asset: "USD/2", + Balance: big.NewInt(200), + }, + } +} + func upsertBalances(t *testing.T, ctx context.Context, storage Storage, balances []models.Balance) { require.NoError(t, storage.BalancesUpsert(ctx, balances)) } @@ -840,21 +877,21 @@ func TestBalancesGetAt(t *testing.T) { t.Run("get balances at before first balance should return empty", func(t *testing.T) { accounts := defaultAccounts() - balances, err := store.BalancesGetAt(ctx, accounts[0].ID, now.Add(-61*time.Minute).UTC().Time) + balances, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, pointer.For(now.Add(-61*time.Minute).UTC().Time)) require.NoError(t, err) - require.Nil(t, balances) + require.Empty(t, balances) }) t.Run("get balances at after last balance updated at should return empty", func(t *testing.T) { accounts := defaultAccounts() - balances, err := store.BalancesGetAt(ctx, accounts[0].ID, now.Add(-50*time.Minute).UTC().Time) + balances, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, pointer.For(now.Add(-50*time.Minute).UTC().Time)) require.NoError(t, err) - require.Nil(t, balances) + require.Empty(t, balances) }) t.Run("get balances at", func(t *testing.T) { accounts := defaultAccounts() - balances, err := store.BalancesGetAt(ctx, accounts[0].ID, now.Add(-60*time.Minute).UTC().Time) + balances, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, pointer.For(now.Add(-60*time.Minute).UTC().Time)) require.NoError(t, err) require.NotNil(t, balances) require.Len(t, balances, 1) @@ -872,7 +909,7 @@ func TestBalancesGetAt(t *testing.T) { upsertBalances(t, ctx, store, []models.Balance{b}) - balances, err := store.BalancesGetAt(ctx, accounts[0].ID, now.Add(-50*time.Minute).UTC().Time) + balances, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, pointer.For(now.Add(-50*time.Minute).UTC().Time)) require.NoError(t, err) require.NotNil(t, balances) require.Len(t, balances, 1) @@ -899,14 +936,14 @@ func TestBalancesGetAt(t *testing.T) { upsertBalances(t, ctx, store, []models.Balance{b, b1}) - balances, err := store.BalancesGetAt(ctx, accounts[0].ID, now.Add(-50*time.Minute).UTC().Time) + balances, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, pointer.For(now.Add(-50*time.Minute).UTC().Time)) require.NoError(t, err) require.NotNil(t, balances) require.Len(t, balances, 2) }) } -func TestBalancesGetLatest(t *testing.T) { +func TestBalancesGetLatestFromAccountIDs(t *testing.T) { t.Parallel() ctx := logging.TestingContext() @@ -917,11 +954,15 @@ func TestBalancesGetLatest(t *testing.T) { createPSU(t, ctx, store, defaultPSU2) createOpenBankingConnection(t, ctx, store, defaultPSU2.ID, defaultOpenBankingConnection) upsertAccounts(t, ctx, store, defaultAccounts()) - upsertBalances(t, ctx, store, defaultBalances()) + upsertBalances(t, ctx, store, defaultBalancesDuplicates()) t.Run("get latest balances returns 1 balance per currency", func(t *testing.T) { accounts := defaultAccounts() - balances, err := store.BalancesGetLatest(ctx, accounts[0].ID) + accountIDs := make([]models.AccountID, len(accounts)) + for i, account := range accounts { + accountIDs[i] = account.ID + } + balances, err := store.BalancesGetFromAccountIDs(ctx, accountIDs, nil) require.NoError(t, err) require.NotNil(t, balances) require.Len(t, balances, 2) @@ -929,8 +970,12 @@ func TestBalancesGetLatest(t *testing.T) { assert.Equal(t, balances[1].Asset, "USD/2") }) - t.Run("get balances after inserting a new balance", func(t *testing.T) { + t.Run("get latest balances after inserting a new balance", func(t *testing.T) { accounts := defaultAccounts() + accountIDs := make([]models.AccountID, len(accounts)) + for i, account := range accounts { + accountIDs[i] = account.ID + } b := models.Balance{ AccountID: accounts[0].ID, CreatedAt: now.Add(-20 * time.Minute).UTC().Time, @@ -941,15 +986,17 @@ func TestBalancesGetLatest(t *testing.T) { upsertBalances(t, ctx, store, []models.Balance{b}) - balances, err := store.BalancesGetLatest(ctx, accounts[0].ID) + balances, err := store.BalancesGetFromAccountIDs(ctx, accountIDs, nil) require.NoError(t, err) + fmt.Println(balances) require.NotNil(t, balances) require.Len(t, balances, 2) assert.Equal(t, balances[1].Asset, "USD/2") - assert.Equal(t, balances[1].Balance, b.Balance) + assert.Equal(t, balances[1].Amount, b.Balance) + assert.Equal(t, balances[1].RelatedAccounts, []models.AccountID{accounts[0].ID}) }) - t.Run("rollback on foreign key violation", func(t *testing.T) { + t.Run("rollback on foreign key violation with latest balances", func(t *testing.T) { // Create a valid balance first upsertConnector(t, ctx, store, defaultConnector) createPSU(t, ctx, store, defaultPSU2) @@ -958,7 +1005,7 @@ func TestBalancesGetLatest(t *testing.T) { accounts := defaultAccounts() // Count existing balances - balancesBefore, err := store.BalancesGetLatest(ctx, accounts[0].ID) + balancesBefore, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, nil) require.NoError(t, err) countBefore := len(balancesBefore) @@ -977,7 +1024,7 @@ func TestBalancesGetLatest(t *testing.T) { require.Error(t, err) // Verify no balance was inserted - balancesAfter, err := store.BalancesGetLatest(ctx, accounts[0].ID) + balancesAfter, err := store.BalancesGetFromAccountIDs(ctx, []models.AccountID{accounts[0].ID}, nil) require.NoError(t, err) assert.Equal(t, countBefore, len(balancesAfter), "no balances should be inserted on error") diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 8bfad85ee..98f4ab209 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -30,8 +30,7 @@ type Storage interface { // Balances BalancesUpsert(ctx context.Context, balances []models.Balance) error BalancesList(ctx context.Context, q ListBalancesQuery) (*bunpaginate.Cursor[models.Balance], error) - BalancesGetAt(ctx context.Context, accountID models.AccountID, at time.Time) ([]*models.Balance, error) - BalancesGetLatest(ctx context.Context, accountID models.AccountID) ([]*models.Balance, error) + BalancesGetFromAccountIDs(ctx context.Context, accountIDs []models.AccountID, at *time.Time) ([]models.AggregatedBalance, error) // Bank Accounts BankAccountsUpsert(ctx context.Context, bankAccount models.BankAccount) error diff --git a/internal/storage/storage_generated.go b/internal/storage/storage_generated.go index c5e0ff831..49af018a1 100644 --- a/internal/storage/storage_generated.go +++ b/internal/storage/storage_generated.go @@ -25,7 +25,6 @@ import ( type MockStorage struct { ctrl *gomock.Controller recorder *MockStorageMockRecorder - isgomock struct{} } // MockStorageMockRecorder is the mock recorder for MockStorage. @@ -159,34 +158,19 @@ func (mr *MockStorageMockRecorder) AccountsUpsert(ctx, accounts any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccountsUpsert", reflect.TypeOf((*MockStorage)(nil).AccountsUpsert), ctx, accounts) } -// BalancesGetAt mocks base method. -func (m *MockStorage) BalancesGetAt(ctx context.Context, accountID models.AccountID, at time.Time) ([]*models.Balance, error) { +// BalancesGetFromAccountIDs mocks base method. +func (m *MockStorage) BalancesGetFromAccountIDs(ctx context.Context, accountIDs []models.AccountID, at *time.Time) ([]models.AggregatedBalance, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BalancesGetAt", ctx, accountID, at) - ret0, _ := ret[0].([]*models.Balance) + ret := m.ctrl.Call(m, "BalancesGetFromAccountIDs", ctx, accountIDs, at) + ret0, _ := ret[0].([]models.AggregatedBalance) ret1, _ := ret[1].(error) return ret0, ret1 } -// BalancesGetAt indicates an expected call of BalancesGetAt. -func (mr *MockStorageMockRecorder) BalancesGetAt(ctx, accountID, at any) *gomock.Call { +// BalancesGetFromAccountIDs indicates an expected call of BalancesGetFromAccountIDs. +func (mr *MockStorageMockRecorder) BalancesGetFromAccountIDs(ctx, accountIDs, at any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BalancesGetAt", reflect.TypeOf((*MockStorage)(nil).BalancesGetAt), ctx, accountID, at) -} - -// BalancesGetLatest mocks base method. -func (m *MockStorage) BalancesGetLatest(ctx context.Context, accountID models.AccountID) ([]*models.Balance, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BalancesGetLatest", ctx, accountID) - ret0, _ := ret[0].([]*models.Balance) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BalancesGetLatest indicates an expected call of BalancesGetLatest. -func (mr *MockStorageMockRecorder) BalancesGetLatest(ctx, accountID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BalancesGetLatest", reflect.TypeOf((*MockStorage)(nil).BalancesGetLatest), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BalancesGetFromAccountIDs", reflect.TypeOf((*MockStorage)(nil).BalancesGetFromAccountIDs), ctx, accountIDs, at) } // BalancesList mocks base method. From a4780b622b9d0d5944bfe20033eaa0708e032659 Mon Sep 17 00:00:00 2001 From: Paul Nicolas Date: Wed, 26 Nov 2025 15:08:31 +0100 Subject: [PATCH 2/4] fix dirty --- internal/api/services/pools_balances.go | 6 ------ internal/storage/storage_generated.go | 1 + 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/internal/api/services/pools_balances.go b/internal/api/services/pools_balances.go index e81619838..e4e93dca7 100644 --- a/internal/api/services/pools_balances.go +++ b/internal/api/services/pools_balances.go @@ -2,17 +2,11 @@ package services import ( "context" - "math/big" "github.com/formancehq/payments/internal/models" "github.com/google/uuid" ) -type aggregatedBalance struct { - amount *big.Int - relatedAccounts []models.AccountID -} - func (s *Service) PoolsBalances( ctx context.Context, poolID uuid.UUID, diff --git a/internal/storage/storage_generated.go b/internal/storage/storage_generated.go index 49af018a1..a8dd79bfc 100644 --- a/internal/storage/storage_generated.go +++ b/internal/storage/storage_generated.go @@ -25,6 +25,7 @@ import ( type MockStorage struct { ctrl *gomock.Controller recorder *MockStorageMockRecorder + isgomock struct{} } // MockStorageMockRecorder is the mock recorder for MockStorage. From aa6bafca9609fb319a1497960a6db22f702972b8 Mon Sep 17 00:00:00 2001 From: Paul Nicolas Date: Wed, 26 Nov 2025 16:53:59 +0100 Subject: [PATCH 3/4] remove useless prints --- internal/storage/balances.go | 5 ----- internal/storage/balances_test.go | 2 -- 2 files changed, 7 deletions(-) diff --git a/internal/storage/balances.go b/internal/storage/balances.go index ff8fcc532..dc2f90b9a 100644 --- a/internal/storage/balances.go +++ b/internal/storage/balances.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "encoding/json" - "fmt" "math/big" "time" @@ -316,15 +315,11 @@ func (s *store) BalancesGetFromAccountIDs(ctx context.Context, accountIDs []mode ColumnExpr("array_agg(account_id) as account_ids, asset, SUM(balance) AS balance"). Group("asset") - fmt.Println(query.String()) - err := query.Scan(ctx, &balanceAssets) if err != nil { return nil, e("failed to list balance assets", err) } - fmt.Println(balanceAssets) - res := make([]models.AggregatedBalance, 0, len(balanceAssets)) for _, balanceAsset := range balanceAssets { relatedAccounts := make([]models.AccountID, len(balanceAsset.AccountIDs)) diff --git a/internal/storage/balances_test.go b/internal/storage/balances_test.go index f01d47dac..fb15a2eef 100644 --- a/internal/storage/balances_test.go +++ b/internal/storage/balances_test.go @@ -3,7 +3,6 @@ package storage import ( "context" "encoding/json" - "fmt" "math/big" "testing" "time" @@ -988,7 +987,6 @@ func TestBalancesGetLatestFromAccountIDs(t *testing.T) { balances, err := store.BalancesGetFromAccountIDs(ctx, accountIDs, nil) require.NoError(t, err) - fmt.Println(balances) require.NotNil(t, balances) require.Len(t, balances, 2) assert.Equal(t, balances[1].Asset, "USD/2") From 05b31cbf5bcd08e32825707c94da3fd4442ecbf3 Mon Sep 17 00:00:00 2001 From: Paul Nicolas Date: Thu, 27 Nov 2025 10:06:49 +0100 Subject: [PATCH 4/4] add check about accounts id length --- internal/storage/balances.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/storage/balances.go b/internal/storage/balances.go index dc2f90b9a..703a0ecc7 100644 --- a/internal/storage/balances.go +++ b/internal/storage/balances.go @@ -288,6 +288,11 @@ func (s *store) BalancesList(ctx context.Context, q ListBalancesQuery) (*bunpagi // Get balances from account IDs at a specific time. If at is nil, it will return the latest balances. func (s *store) BalancesGetFromAccountIDs(ctx context.Context, accountIDs []models.AccountID, at *time.Time) ([]models.AggregatedBalance, error) { + if len(accountIDs) == 0 { + // return empty array if no account IDs are provided + return []models.AggregatedBalance{}, nil + } + type assetBalances struct { AccountIDs []string `bun:"account_ids,array"` Asset string `bun:"asset"`