Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 3 additions & 35 deletions internal/api/services/pools_balances.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@ package services

import (
"context"
"math/big"

"github.com/formancehq/payments/internal/models"
"github.com/google/uuid"
)

type aggregatedBalance struct {
amount *big.Int
relatedAccounts []models.AccountID
}

func (s *Service) PoolsBalances(
ctx context.Context,
poolID uuid.UUID,
Expand All @@ -30,35 +24,9 @@ func (s *Service) PoolsBalances(
}
}

res := make(map[string]*aggregatedBalance)
for i := range pool.PoolAccounts {
balances, err := s.storage.BalancesGetLatest(ctx, pool.PoolAccounts[i])
if err != nil {
return nil, newStorageError(err, "cannot get latest balances")
}

for _, balance := range balances {
v, ok := res[balance.Asset]
if !ok {
v = &aggregatedBalance{
amount: big.NewInt(0),
relatedAccounts: []models.AccountID{},
}
}

v.amount = v.amount.Add(v.amount, balance.Balance)
v.relatedAccounts = append(v.relatedAccounts, balance.AccountID)
res[balance.Asset] = v
}
}

balances := make([]models.AggregatedBalance, 0, len(res))
for asset, balance := range res {
balances = append(balances, models.AggregatedBalance{
Asset: asset,
Amount: balance.amount,
RelatedAccounts: balance.relatedAccounts,
})
balances, err := s.storage.BalancesGetFromAccountIDs(ctx, pool.PoolAccounts, nil)
if err != nil {
return nil, newStorageError(err, "cannot get latest balances")
}

return balances, nil
Expand Down
34 changes: 4 additions & 30 deletions internal/api/services/pools_balances_at.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package services
import (
"context"
"encoding/json"
"math/big"
"time"

"github.com/formancehq/go-libs/v3/bun/bunpaginate"
"github.com/formancehq/go-libs/v3/pointer"
"github.com/formancehq/go-libs/v3/query"
"github.com/formancehq/payments/internal/models"
"github.com/formancehq/payments/internal/storage"
Expand All @@ -31,35 +31,9 @@ func (s *Service) PoolsBalancesAt(
}
}

res := make(map[string]*aggregatedBalance)
for i := range pool.PoolAccounts {
balances, err := s.storage.BalancesGetAt(ctx, pool.PoolAccounts[i], at)
if err != nil {
return nil, newStorageError(err, "cannot get balances")
}

for _, balance := range balances {
v, ok := res[balance.Asset]
if !ok {
v = &aggregatedBalance{
amount: big.NewInt(0),
relatedAccounts: []models.AccountID{},
}
}

v.amount = v.amount.Add(v.amount, balance.Balance)
v.relatedAccounts = append(v.relatedAccounts, balance.AccountID)
res[balance.Asset] = v
}
}

balances := make([]models.AggregatedBalance, 0, len(res))
for asset, balance := range res {
balances = append(balances, models.AggregatedBalance{
Asset: asset,
Amount: balance.amount,
RelatedAccounts: balance.relatedAccounts,
})
balances, err := s.storage.BalancesGetFromAccountIDs(ctx, pool.PoolAccounts, pointer.For(at))
if err != nil {
return nil, newStorageError(err, "cannot get balances")
}

return balances, nil
Expand Down
41 changes: 21 additions & 20 deletions internal/api/services/pools_balances_at_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/formancehq/go-libs/v3/pointer"
"github.com/formancehq/payments/internal/connectors/engine"
"github.com/formancehq/payments/internal/models"
"github.com/formancehq/payments/internal/storage"
Expand All @@ -27,30 +28,30 @@ func TestPoolsBalancesAt(t *testing.T) {

id := uuid.New()
poolsAccount := []models.AccountID{{}}
balancesResponse := []*models.Balance{
balancesResponse := []models.AggregatedBalance{
{
AccountID: models.AccountID{
Reference: "test1",
ConnectorID: models.ConnectorID{},
RelatedAccounts: []models.AccountID{
{
Reference: "test1",
ConnectorID: models.ConnectorID{},
},
{
Reference: "test2",
ConnectorID: models.ConnectorID{},
},
},
Asset: "EUR/2",
Balance: big.NewInt(100),
Asset: "EUR/2",
Amount: big.NewInt(400),
},
{
AccountID: models.AccountID{
Reference: "test1",
ConnectorID: models.ConnectorID{},
RelatedAccounts: []models.AccountID{
{
Reference: "test1",
ConnectorID: models.ConnectorID{},
},
},
Asset: "USD/2",
Balance: big.NewInt(200),
},
{
AccountID: models.AccountID{
Reference: "test2",
ConnectorID: models.ConnectorID{},
},
Asset: "EUR/2",
Balance: big.NewInt(300),
Asset: "USD/2",
Amount: big.NewInt(200),
},
}
at := time.Now().Add(-time.Hour)
Expand Down Expand Up @@ -98,7 +99,7 @@ func TestPoolsBalancesAt(t *testing.T) {
PoolAccounts: poolsAccount,
}, test.poolsGetStorageErr)
if test.poolsGetStorageErr == nil {
store.EXPECT().BalancesGetAt(gomock.Any(), models.AccountID{}, at).Return(balancesResponse, test.accountsBalancesAtErr)
store.EXPECT().BalancesGetFromAccountIDs(gomock.Any(), gomock.Any(), pointer.For(at)).Return(balancesResponse, test.accountsBalancesAtErr)
}

balances, err := s.PoolsBalancesAt(context.Background(), id, at)
Expand Down
40 changes: 20 additions & 20 deletions internal/api/services/pools_balances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,30 @@ func TestPoolsBalancesLatest(t *testing.T) {

id := uuid.New()
poolsAccount := []models.AccountID{{}}
balancesResponse := []*models.Balance{
balancesResponse := []models.AggregatedBalance{
{
AccountID: models.AccountID{
Reference: "test1",
ConnectorID: models.ConnectorID{},
RelatedAccounts: []models.AccountID{
{
Reference: "test1",
ConnectorID: models.ConnectorID{},
},
{
Reference: "test2",
ConnectorID: models.ConnectorID{},
},
},
Asset: "EUR/2",
Balance: big.NewInt(100),
Asset: "EUR/2",
Amount: big.NewInt(400),
},
{
AccountID: models.AccountID{
Reference: "test1",
ConnectorID: models.ConnectorID{},
RelatedAccounts: []models.AccountID{
{
Reference: "test1",
ConnectorID: models.ConnectorID{},
},
},
Asset: "USD/2",
Balance: big.NewInt(200),
},
{
AccountID: models.AccountID{
Reference: "test2",
ConnectorID: models.ConnectorID{},
},
Asset: "EUR/2",
Balance: big.NewInt(300),
Asset: "USD/2",
Amount: big.NewInt(200),
},
}

Expand Down Expand Up @@ -97,7 +97,7 @@ func TestPoolsBalancesLatest(t *testing.T) {
PoolAccounts: poolsAccount,
}, test.poolsGetStorageErr)
if test.poolsGetStorageErr == nil {
store.EXPECT().BalancesGetLatest(gomock.Any(), models.AccountID{}).Return(balancesResponse, test.accountsBalancesErr)
store.EXPECT().BalancesGetFromAccountIDs(gomock.Any(), gomock.Any(), nil).Return(balancesResponse, test.accountsBalancesErr)
}

balances, err := s.PoolsBalances(context.Background(), id)
Expand Down
123 changes: 41 additions & 82 deletions internal/storage/balances.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/formancehq/go-libs/v3/bun/bunpaginate"
"github.com/formancehq/go-libs/v3/pointer"
internalTime "github.com/formancehq/go-libs/v3/time"
internalEvents "github.com/formancehq/payments/internal/events"
"github.com/formancehq/payments/internal/models"
Expand Down Expand Up @@ -287,99 +286,59 @@ func (s *store) BalancesList(ctx context.Context, q ListBalancesQuery) (*bunpagi
}, nil
}

func (s *store) balancesListAssets(ctx context.Context, accountID models.AccountID) ([]string, error) {
var assets []string

err := s.db.NewSelect().
ColumnExpr("DISTINCT asset").
Model(&models.Balance{}).
Where("account_id = ?", accountID).
Scan(ctx, &assets)
if err != nil {
return nil, e("failed to list balance assets", err)
}

return assets, nil
}

func (s *store) balancesGetAtByAsset(ctx context.Context, accountID models.AccountID, asset string, at time.Time) (*models.Balance, error) {
var balance balance

err := s.db.NewSelect().
Model(&balance).
Where("account_id = ?", accountID).
Where("asset = ?", asset).
Where("created_at <= ?", at).
Where("last_updated_at >= ?", at).
Order("created_at DESC", "sort_id DESC").
Limit(1).
Scan(ctx)
if err != nil {
return nil, e("failed to get balance", err)
}

return pointer.For(toBalanceModels(balance)), nil
}

func (s *store) balancesGetLatestByAsset(ctx context.Context, accountID models.AccountID, asset string) (*models.Balance, error) {
var balance balance

err := s.db.NewSelect().
Model(&balance).
Where("account_id = ?", accountID).
Where("asset = ?", asset).
Order("created_at DESC", "sort_id DESC").
Limit(1).
Scan(ctx)
if err != nil {
return nil, e("failed to get latest balance", err)
// Get balances from account IDs at a specific time. If at is nil, it will return the latest balances.
func (s *store) BalancesGetFromAccountIDs(ctx context.Context, accountIDs []models.AccountID, at *time.Time) ([]models.AggregatedBalance, error) {
if len(accountIDs) == 0 {
// return empty array if no account IDs are provided
return []models.AggregatedBalance{}, nil
}

return pointer.For(toBalanceModels(balance)), nil
}

func (s *store) BalancesGetAt(ctx context.Context, accountID models.AccountID, at time.Time) ([]*models.Balance, error) {
assets, err := s.balancesListAssets(ctx, accountID)
if err != nil {
return nil, e("failed to list balance assets", err)
type assetBalances struct {
AccountIDs []string `bun:"account_ids,array"`
Asset string `bun:"asset"`
Balance *big.Int `bun:"balance"`
}

var balances []*models.Balance
for _, currency := range assets {
balance, err := s.balancesGetAtByAsset(ctx, accountID, currency, at)
if err != nil {
if errors.Is(err, ErrNotFound) {
continue
}
return nil, e("failed to get balance", err)
}

balances = append(balances, balance)
selectedBalancesQuery := s.db.NewSelect().
Model((*balance)(nil)).
DistinctOn("account_id, asset").
Column("account_id", "asset", "created_at", "sort_id", "balance").
Where("account_id IN (?)", bun.In(accountIDs)).
Order("account_id desc", "asset desc", "created_at desc", "sort_id desc")

if at != nil && !at.IsZero() {
selectedBalancesQuery = selectedBalancesQuery.Where("created_at <= ?", at).
Where("last_updated_at >= ?", at)
}

return balances, nil
}

func (s *store) BalancesGetLatest(ctx context.Context, accountID models.AccountID) ([]*models.Balance, error) {
assets, err := s.balancesListAssets(ctx, accountID)
var balanceAssets []assetBalances
query := s.db.NewSelect().
With(
"selected_balances",
selectedBalancesQuery,
).
ModelTableExpr("selected_balances").
ColumnExpr("array_agg(account_id) as account_ids, asset, SUM(balance) AS balance").
Group("asset")

err := query.Scan(ctx, &balanceAssets)
if err != nil {
return nil, e("failed to list balance assets", err)
}

var balances []*models.Balance
for _, currency := range assets {
balance, err := s.balancesGetLatestByAsset(ctx, accountID, currency)
if err != nil {
if errors.Is(err, ErrNotFound) {
continue
}
return nil, e("failed to get latest balance for asset", err)
res := make([]models.AggregatedBalance, 0, len(balanceAssets))
for _, balanceAsset := range balanceAssets {
relatedAccounts := make([]models.AccountID, len(balanceAsset.AccountIDs))
for i, accountID := range balanceAsset.AccountIDs {
relatedAccounts[i] = models.MustAccountIDFromString(accountID)
}

balances = append(balances, balance)
res = append(res, models.AggregatedBalance{
Asset: balanceAsset.Asset,
Amount: balanceAsset.Balance,
RelatedAccounts: relatedAccounts,
})
}

return balances, nil
return res, nil
}

func fromBalancesModels(from []models.Balance) []balance {
Expand Down
Loading
Loading