diff --git a/pallets/subtensor/src/staking/stake_utils.rs b/pallets/subtensor/src/staking/stake_utils.rs index 21ff1a57d5..1cbd0d2b76 100644 --- a/pallets/subtensor/src/staking/stake_utils.rs +++ b/pallets/subtensor/src/staking/stake_utils.rs @@ -21,6 +21,10 @@ impl Pallet { SubnetAlphaIn::::get(netuid).saturating_add(SubnetAlphaOut::::get(netuid)) } + pub fn get_protocol_tao(netuid: NetUid) -> TaoCurrency { + T::SwapInterface::get_protocol_tao(netuid) + } + pub fn get_moving_alpha_price(netuid: NetUid) -> U96F32 { let one = U96F32::saturating_from_num(1.0); if netuid.is_root() { @@ -688,6 +692,9 @@ impl Pallet { price_limit: TaoCurrency, drop_fees: bool, ) -> Result { + // Record the protocol TAO before the swap. + let protocol_tao = Self::get_protocol_tao(netuid); + // Decrease alpha on subnet let actual_alpha_decrease = Self::decrease_stake_for_hotkey_and_coldkey_on_subnet(hotkey, coldkey, netuid, alpha); @@ -696,6 +703,11 @@ impl Pallet { let swap_result = Self::swap_alpha_for_tao(netuid, actual_alpha_decrease, price_limit, drop_fees)?; + // Record the protocol TAO after the swap. + let protocol_tao_after = Self::get_protocol_tao(netuid); + // This should decrease as we are removing TAO from the protocol. + let protocol_tao_delta: TaoCurrency = protocol_tao.saturating_sub(protocol_tao_after); + // Refund the unused alpha (in case if limit price is hit) let refund = actual_alpha_decrease.saturating_sub( swap_result @@ -722,7 +734,7 @@ impl Pallet { // } // Record TAO outflow - Self::record_tao_outflow(netuid, swap_result.amount_paid_out.into()); + Self::record_tao_outflow(netuid, protocol_tao_delta); LastColdkeyHotkeyStakeBlock::::insert(coldkey, hotkey, Self::get_current_block_as_u64()); @@ -761,9 +773,18 @@ impl Pallet { set_limit: bool, drop_fees: bool, ) -> Result { + // Record the protocol TAO before the swap. + let protocol_tao = Self::get_protocol_tao(netuid); + // Swap the tao to alpha. let swap_result = Self::swap_tao_for_alpha(netuid, tao, price_limit, drop_fees)?; + // Record the protocol TAO after the swap. + let protocol_tao_after = Self::get_protocol_tao(netuid); + + // This should increase as we are adding TAO to the protocol. + let protocol_tao_delta: TaoCurrency = protocol_tao_after.saturating_sub(protocol_tao); + ensure!( !swap_result.amount_paid_out.is_zero(), Error::::AmountTooLow @@ -799,7 +820,7 @@ impl Pallet { } // Record TAO inflow - Self::record_tao_inflow(netuid, swap_result.amount_paid_in.into()); + Self::record_tao_inflow(netuid, protocol_tao_delta); LastColdkeyHotkeyStakeBlock::::insert(coldkey, hotkey, Self::get_current_block_as_u64()); diff --git a/pallets/subtensor/src/tests/staking.rs b/pallets/subtensor/src/tests/staking.rs index 27c5b5c16d..4146786709 100644 --- a/pallets/subtensor/src/tests/staking.rs +++ b/pallets/subtensor/src/tests/staking.rs @@ -5596,14 +5596,13 @@ fn test_staking_records_flow() { mock::setup_reserves(netuid, tao_reserve, alpha_in); // Initialize swap v3 - let order = GetAlphaForTao::::with_amount(0); - assert_ok!(::SwapInterface::swap( - netuid.into(), - order, - TaoCurrency::MAX, + SubtensorModule::swap_tao_for_alpha( + netuid, + TaoCurrency::ZERO, + 1_000_000_000_000.into(), false, - true - )); + ) + .unwrap(); // Add stake with slippage safety and check if the result is ok assert_ok!(SubtensorModule::stake_into_subnet( diff --git a/pallets/swap-interface/src/lib.rs b/pallets/swap-interface/src/lib.rs index ae7d375f97..19af1303c1 100644 --- a/pallets/swap-interface/src/lib.rs +++ b/pallets/swap-interface/src/lib.rs @@ -39,6 +39,7 @@ pub trait SwapHandler { fn approx_fee_amount(netuid: NetUid, amount: T) -> T; fn current_alpha_price(netuid: NetUid) -> U96F32; + fn get_protocol_tao(netuid: NetUid) -> TaoCurrency; fn max_price() -> C; fn min_price() -> C; fn adjust_protocol_liquidity( diff --git a/pallets/swap/src/pallet/impls.rs b/pallets/swap/src/pallet/impls.rs index 34b5e624e6..74f6d410d5 100644 --- a/pallets/swap/src/pallet/impls.rs +++ b/pallets/swap/src/pallet/impls.rs @@ -1112,6 +1112,24 @@ impl SwapHandler for Pallet { Self::current_price(netuid.into()) } + fn get_protocol_tao(netuid: NetUid) -> TaoCurrency { + let protocol_account_id = Self::protocol_account_id(); + let mut positions = + Positions::::iter_prefix_values((netuid, protocol_account_id.clone())) + .collect::>(); + + if let Some(position) = positions.get_mut(0) { + let current_sqrt_price = AlphaSqrtPrice::::get(netuid); + // Adjust liquidity + let maybe_token_amounts = position.to_token_amounts(current_sqrt_price); + if let Ok((tao, _)) = maybe_token_amounts { + return tao.into(); + } + } + + TaoCurrency::ZERO + } + fn min_price() -> C { Self::min_price_inner() }