Skip to content

Commit

Permalink
refresh all vault orders
Browse files Browse the repository at this point in the history
  • Loading branch information
tqin7 committed Mar 20, 2024
1 parent bd041ab commit c400abd
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 17 deletions.
6 changes: 6 additions & 0 deletions protocol/lib/metrics/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ const (
DistributedRewardTokens = "distributed_reward_tokens"
TreasuryBalanceAfterDistribution = "treasury_balance_after_distribution"

// Vault.
VaultCancelOrder = "vault_cancel_order"
VaultPlaceOrder = "vault_place_order"
VaultType = "vault_type"
VaultId = "vault_id"

// Vest.
GetVestEntry = "get_vest_entry"
VestAmount = "vest_amount"
Expand Down
6 changes: 6 additions & 0 deletions protocol/x/vault/abci_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package vault_test

import "testing"

// TODO (TRA-168): add endblocker test once deposit is implemented.
func TestEndBlocker(t *testing.T) {}
84 changes: 76 additions & 8 deletions protocol/x/vault/keeper/orders.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (

errorsmod "cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/dtypes"
"github.com/dydxprotocol/v4-chain/protocol/lib"
"github.com/dydxprotocol/v4-chain/protocol/lib/log"
"github.com/dydxprotocol/v4-chain/protocol/lib/metrics"
clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
"github.com/dydxprotocol/v4-chain/protocol/x/vault/types"
)

Expand All @@ -30,10 +32,79 @@ const (
)

// RefreshAllVaultOrders refreshes all orders for all vaults by
// TODO(TRA-134)
// 1. Cancelling all existing orders.
// 2. Placing new orders.
func (k Keeper) RefreshAllVaultOrders(ctx sdk.Context) {
// Iterate through all vaults.
totalSharesIterator := k.getTotalSharesIterator(ctx)
defer totalSharesIterator.Close()
for ; totalSharesIterator.Valid(); totalSharesIterator.Next() {
var vaultId types.VaultId
k.cdc.MustUnmarshal(totalSharesIterator.Key(), &vaultId)
var totalShares types.NumShares
k.cdc.MustUnmarshal(totalSharesIterator.Value(), &totalShares)

// Skip if TotalShares is non-positive.
if totalShares.NumShares.Cmp(dtypes.NewInt(0)) <= 0 {
continue
}

// Refresh orders depending on vault type.
switch vaultId.Type {
case types.VaultType_VAULT_TYPE_CLOB:
err := k.RefreshVaultClobOrders(ctx, vaultId)
if err != nil {
log.ErrorLogWithError(ctx, "Failed to refresh vault clob orders", err, "vaultId", vaultId)
}
}
}
}

// RefreshVaultClobOrders refreshes orders of a CLOB vault.
func (k Keeper) RefreshVaultClobOrders(ctx sdk.Context, vaultId types.VaultId) (err error) {
// Cancel CLOB orders from last block.
ordersToCancel, err := k.GetVaultClobOrders(
ctx.WithBlockHeight(ctx.BlockHeight()-1),
vaultId,
)
if err != nil {
log.ErrorLogWithError(ctx, "Failed to get vault clob orders to cancel", err, "vaultId", vaultId)
return err
}
for _, order := range ordersToCancel {
if _, exists := k.clobKeeper.GetLongTermOrderPlacement(ctx, order.OrderId); exists {
err := k.clobKeeper.HandleMsgCancelOrder(ctx, clobtypes.NewMsgCancelOrderStateful(
order.OrderId,
uint32(ctx.BlockTime().Unix())+ORDER_EXPIRATION_SECONDS,
))
if err != nil {
log.ErrorLogWithError(ctx, "Failed to cancel order", err, "order", order, "vaultId", vaultId)
}
vaultId.IncrCounterWithLabels(
metrics.VaultCancelOrder,
metrics.GetLabelForBoolValue(metrics.Success, err == nil),
)
}
}

// Place new CLOB orders.
ordersToPlace, err := k.GetVaultClobOrders(ctx, vaultId)
if err != nil {
log.ErrorLogWithError(ctx, "Failed to get vault clob orders to place", err, "vaultId", vaultId)
return err
}
for _, order := range ordersToPlace {
err := k.clobKeeper.HandleMsgPlaceOrder(ctx, clobtypes.NewMsgPlaceOrder(*order))
if err != nil {
log.ErrorLogWithError(ctx, "Failed to place order", err, "order", order, "vaultId", vaultId)
}
vaultId.IncrCounterWithLabels(
metrics.VaultPlaceOrder,
metrics.GetLabelForBoolValue(metrics.Success, err == nil),
)
}

return nil
}

// GetVaultClobOrders returns a list of long term orders for a given vault, with its corresponding
Expand Down Expand Up @@ -80,10 +151,7 @@ func (k Keeper) GetVaultClobOrders(
}

// Get vault (subaccount 0 of corresponding module account).
vault := satypes.SubaccountId{
Owner: vaultId.ToModuleAccountAddress(),
Number: 0,
}
vault := vaultId.ToSubaccountId()
// Calculate spread.
spreadPpm := lib.Max(
MIN_BASE_SPREAD_PPM,
Expand Down Expand Up @@ -112,7 +180,7 @@ func (k Keeper) GetVaultClobOrders(
// Construct ask at this layer.
ask := clobtypes.Order{
OrderId: clobtypes.OrderId{
SubaccountId: vault,
SubaccountId: *vault,
ClientId: k.GetVaultClobOrderClientId(ctx, clobtypes.Order_SIDE_SELL, uint8(i+1)),
OrderFlags: clobtypes.OrderIdFlags_LongTerm,
ClobPairId: clobPair.Id,
Expand All @@ -130,7 +198,7 @@ func (k Keeper) GetVaultClobOrders(
// Construct bid at this layer.
bid := clobtypes.Order{
OrderId: clobtypes.OrderId{
SubaccountId: vault,
SubaccountId: *vault,
ClientId: k.GetVaultClobOrderClientId(ctx, clobtypes.Order_SIDE_BUY, uint8(i+1)),
OrderFlags: clobtypes.OrderIdFlags_LongTerm,
ClobPairId: clobPair.Id,
Expand Down
202 changes: 195 additions & 7 deletions protocol/x/vault/keeper/orders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/cometbft/cometbft/types"
"github.com/dydxprotocol/v4-chain/protocol/dtypes"
testapp "github.com/dydxprotocol/v4-chain/protocol/testutil/app"
"github.com/dydxprotocol/v4-chain/protocol/testutil/constants"
clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
Expand All @@ -23,6 +24,196 @@ const (
orderExpirationSeconds = uint32(5) // 5 seconds
)

func TestRefreshAllVaultOrders(t *testing.T) {
tests := map[string]struct {
// Vault IDs.
vaultIds []vaulttypes.VaultId
// Total Shares of each vault ID above.
totalShares []vaulttypes.NumShares
}{
"Two Vaults, Both Positive Shares": {
vaultIds: []vaulttypes.VaultId{
constants.Vault_Clob_0,
constants.Vault_Clob_1,
},
totalShares: []vaulttypes.NumShares{
{
NumShares: dtypes.NewInt(1_000),
},
{
NumShares: dtypes.NewInt(200),
},
},
},
"Two Vaults, One Positive Shares, One Zero Shares": {
vaultIds: []vaulttypes.VaultId{
constants.Vault_Clob_0,
constants.Vault_Clob_1,
},
totalShares: []vaulttypes.NumShares{
{
NumShares: dtypes.NewInt(1_000),
},
{
NumShares: dtypes.NewInt(0),
},
},
},
"Two Vaults, Both Zero Shares": {
vaultIds: []vaulttypes.VaultId{
constants.Vault_Clob_0,
constants.Vault_Clob_1,
},
totalShares: []vaulttypes.NumShares{
{
NumShares: dtypes.NewInt(0),
},
{
NumShares: dtypes.NewInt(0),
},
},
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
// Initialize tApp and ctx (in deliverTx mode).
tApp := testapp.NewTestAppBuilder(t).WithGenesisDocFn(func() (genesis types.GenesisDoc) {
genesis = testapp.DefaultGenesis()
// Initialize each vault with quote quantums to be able to place orders.
testapp.UpdateGenesisDocWithAppStateForModule(
&genesis,
func(genesisState *satypes.GenesisState) {
subaccounts := make([]satypes.Subaccount, len(tc.vaultIds))
for i, vaultId := range tc.vaultIds {
subaccounts[i] = satypes.Subaccount{
Id: vaultId.ToSubaccountId(),
AssetPositions: []*satypes.AssetPosition{
{
AssetId: 0,
Quantums: dtypes.NewInt(1_000_000_000), // 1,000 USDC
},
},
}
}
genesisState.Subaccounts = subaccounts
},
)
return genesis
}).Build()
ctx := tApp.InitChain().WithIsCheckTx(false)

// Set total shares for each vault ID.
for i, vaultId := range tc.vaultIds {
err := tApp.App.VaultKeeper.SetTotalShares(ctx, vaultId, tc.totalShares[i])
require.NoError(t, err)
}

// Check that there's no stateful orders yet.
allStatefulOrders := tApp.App.ClobKeeper.GetAllStatefulOrders(ctx)
require.Len(t, allStatefulOrders, 0)

// Refresh all vault orders.
tApp.App.VaultKeeper.RefreshAllVaultOrders(ctx)

// Check orders are as expected.
numExpectedOrders := 0
allExpectedOrderIds := make(map[clobtypes.OrderId]bool)
for i, vaultId := range tc.vaultIds {
if tc.totalShares[i].NumShares.Cmp(dtypes.NewInt(0)) > 0 {
expectedOrders, err := tApp.App.VaultKeeper.GetVaultClobOrders(ctx, vaultId)
require.NoError(t, err)
numExpectedOrders += len(expectedOrders)
for _, order := range expectedOrders {
allExpectedOrderIds[order.OrderId] = true
}
}
}
allStatefulOrders = tApp.App.ClobKeeper.GetAllStatefulOrders(ctx)
require.Len(t, allStatefulOrders, numExpectedOrders)
for _, order := range allStatefulOrders {
require.True(t, allExpectedOrderIds[order.OrderId])
}
})
}
}

func TestRefreshVaultClobOrders(t *testing.T) {
tests := map[string]struct {
/* --- Setup --- */
// Vault ID.
vaultId vaulttypes.VaultId

/* --- Expectations --- */
expectedErr error
}{
"Success - Refresh Orders from Vault for Clob Pair 0": {
vaultId: constants.Vault_Clob_0,
},
"Error - Refresh Orders from Vault for Clob Pair 4321 (non-existent clob pair)": {
vaultId: vaulttypes.VaultId{
Type: vaulttypes.VaultType_VAULT_TYPE_CLOB,
Number: 4321,
},
expectedErr: vaulttypes.ErrClobPairNotFound,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
// Initialize tApp and ctx (in deliverTx mode).
tApp := testapp.NewTestAppBuilder(t).WithGenesisDocFn(func() (genesis types.GenesisDoc) {
genesis = testapp.DefaultGenesis()
// Initialize vault with quote quantums to be able to place orders.
testapp.UpdateGenesisDocWithAppStateForModule(
&genesis,
func(genesisState *satypes.GenesisState) {
genesisState.Subaccounts = []satypes.Subaccount{
{
Id: tc.vaultId.ToSubaccountId(),
AssetPositions: []*satypes.AssetPosition{
{
AssetId: 0,
Quantums: dtypes.NewInt(1_000_000_000), // 1,000 USDC
},
},
},
}
},
)
return genesis
}).Build()
ctx := tApp.InitChain().WithIsCheckTx(false)

// Check that there's no stateful orders yet.
allStatefulOrders := tApp.App.ClobKeeper.GetAllStatefulOrders(ctx)
require.Len(t, allStatefulOrders, 0)

// Refresh vault orders.
err := tApp.App.VaultKeeper.RefreshVaultClobOrders(ctx, tc.vaultId)
allStatefulOrders = tApp.App.ClobKeeper.GetAllStatefulOrders(ctx)
if tc.expectedErr != nil {
// Check that the error is as expected.
require.ErrorContains(t, err, tc.expectedErr.Error())
// Check that there's no stateful orders.
require.Len(t, allStatefulOrders, 0)
return
} else {
// Check that there's no error.
require.NoError(t, err)
// Check that the number of orders is as expected.
require.Len(t, allStatefulOrders, int(numLayers)*2)
// Check that the orders are as expected.
expectedOrders, err := tApp.App.VaultKeeper.GetVaultClobOrders(ctx, tc.vaultId)
require.NoError(t, err)
for i := uint8(0); i < numLayers*2; i++ {
require.Equal(t, *expectedOrders[i], allStatefulOrders[i])
}
}
})
}
}

func TestGetVaultClobOrders(t *testing.T) {
tests := map[string]struct {
/* --- Setup --- */
Expand Down Expand Up @@ -167,13 +358,10 @@ func TestGetVaultClobOrders(t *testing.T) {
) *clobtypes.Order {
return &clobtypes.Order{
OrderId: clobtypes.OrderId{
SubaccountId: satypes.SubaccountId{
Owner: tc.vaultId.ToModuleAccountAddress(),
Number: 0,
},
ClientId: tApp.App.VaultKeeper.GetVaultClobOrderClientId(ctx, side, layer),
OrderFlags: clobtypes.OrderIdFlags_LongTerm,
ClobPairId: tc.vaultId.Number,
SubaccountId: *tc.vaultId.ToSubaccountId(),
ClientId: tApp.App.VaultKeeper.GetVaultClobOrderClientId(ctx, side, layer),
OrderFlags: clobtypes.OrderIdFlags_LongTerm,
ClobPairId: tc.vaultId.Number,
},
Side: side,
Quantums: quantums,
Expand Down
8 changes: 8 additions & 0 deletions protocol/x/vault/keeper/shares.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
"cosmossdk.io/store/prefix"
storetypes "cosmossdk.io/store/types"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/dtypes"
"github.com/dydxprotocol/v4-chain/protocol/x/vault/types"
Expand Down Expand Up @@ -39,3 +40,10 @@ func (k Keeper) SetTotalShares(

return nil
}

// getTotalSharesIterator returns an iterator over all TotalShares.
func (k Keeper) getTotalSharesIterator(ctx sdk.Context) storetypes.Iterator {
store := prefix.NewStore(ctx.KVStore(k.storeKey), []byte(types.TotalSharesKeyPrefix))

return storetypes.KVStorePrefixIterator(store, []byte{})
}
Loading

0 comments on commit c400abd

Please sign in to comment.