Skip to content

Commit

Permalink
Merge pull request #6811 from aarmoa/feat/injective_low_leve_api_comp…
Browse files Browse the repository at this point in the history
…onents

Feat/injective low leve api components
  • Loading branch information
rapcmia committed Feb 2, 2024
2 parents ae9c26a + e9f9b75 commit cecc0dd
Show file tree
Hide file tree
Showing 21 changed files with 2,969 additions and 874 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(
self._configure_event_forwarders()
self._latest_polled_order_fill_time: float = self._time()
self._orders_transactions_check_task: Optional[asyncio.Task] = None
self._last_received_message_timestamp = 0
self._orders_queued_to_create: List[GatewayPerpetualInFlightOrder] = []
self._orders_queued_to_cancel: List[GatewayPerpetualInFlightOrder] = []

Expand Down Expand Up @@ -926,31 +925,26 @@ def _configure_event_forwarders(self):
self._data_source.add_listener(event_tag=InjectiveEvent.ChainTransactionEvent, listener=event_forwarder)

def _process_balance_event(self, event: BalanceUpdateEvent):
self._last_received_message_timestamp = self._time()
self._all_trading_events_queue.put_nowait(
{"channel": "balance", "data": event}
)

def _process_position_event(self, event: BalanceUpdateEvent):
self._last_received_message_timestamp = self._time()
self._all_trading_events_queue.put_nowait(
{"channel": "position", "data": event}
)

def _process_user_order_update(self, order_update: OrderUpdate):
self._last_received_message_timestamp = self._time()
self._all_trading_events_queue.put_nowait(
{"channel": "order", "data": order_update}
)

def _process_user_trade_update(self, trade_update: TradeUpdate):
self._last_received_message_timestamp = self._time()
self._all_trading_events_queue.put_nowait(
{"channel": "trade", "data": trade_update}
)

def _process_transaction_event(self, transaction_event: Dict[str, Any]):
self._last_received_message_timestamp = self._time()
self._all_trading_events_queue.put_nowait(
{"channel": "transaction", "data": transaction_event}
)
Expand Down Expand Up @@ -1045,7 +1039,7 @@ async def _get_last_traded_price(self, trading_pair: str) -> float:
return float(last_price)

def _get_poll_interval(self, timestamp: float) -> float:
last_recv_diff = timestamp - self._last_received_message_timestamp
last_recv_diff = timestamp - self._data_source.last_received_message_timestamp
poll_interval = (
self.SHORT_POLL_INTERVAL
if last_recv_diff > self.TICK_INTERVAL_LIMIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from hummingbot.client.config.config_data_types import BaseConnectorConfigMap, ClientFieldData
from hummingbot.connector.exchange.injective_v2.injective_v2_utils import (
ACCOUNT_MODES,
FEE_CALCULATOR_MODES,
NETWORK_MODES,
InjectiveMainnetNetworkMode,
InjectiveReadOnlyAccountMode,
InjectiveSimulatedTransactionFeeCalculatorMode,
)
from hummingbot.core.data_type.trade_fee import TradeFeeSchema

Expand Down Expand Up @@ -43,6 +45,13 @@ class InjectiveConfigMap(BaseConnectorConfigMap):
prompt_on_new=True,
),
)
fee_calculator: Union[tuple(FEE_CALCULATOR_MODES.values())] = Field(
default=InjectiveSimulatedTransactionFeeCalculatorMode(),
client_data=ClientFieldData(
prompt=lambda cm: f"Select the fee calculator ({'/'.join(list(FEE_CALCULATOR_MODES.keys()))})",
prompt_on_new=True,
),
)

class Config:
title = "injective_v2_perpetual"
Expand Down Expand Up @@ -71,11 +80,24 @@ def validate_account_type(cls, v: Union[(str, Dict) + tuple(ACCOUNT_MODES.values
sub_model = ACCOUNT_MODES[v].construct()
return sub_model

@validator("fee_calculator", pre=True)
def validate_fee_calculator(cls, v: Union[(str, Dict) + tuple(FEE_CALCULATOR_MODES.values())]):
if isinstance(v, tuple(FEE_CALCULATOR_MODES.values()) + (Dict,)):
sub_model = v
elif v not in FEE_CALCULATOR_MODES:
raise ValueError(
f"Invalid fee calculator, please choose a value from {list(FEE_CALCULATOR_MODES.keys())}."
)
else:
sub_model = FEE_CALCULATOR_MODES[v].construct()
return sub_model

def create_data_source(self):
return self.account_type.create_data_source(
network=self.network.network(),
use_secure_connection=self.network.use_secure_connection(),
rate_limits=self.network.rate_limits(),
fee_calculator_mode=self.fee_calculator,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from abc import ABC, abstractmethod
from decimal import Decimal
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

from bidict import bidict
from google.protobuf import any_pb2
from grpc import RpcError
from pyinjective import Transaction
from pyinjective.composer import Composer, injective_exchange_tx_pb
from pyinjective.core.market import DerivativeMarket, SpotMarket
Expand Down Expand Up @@ -104,6 +104,11 @@ def portfolio_account_subaccount_index(self) -> int:
def network_name(self) -> str:
raise NotImplementedError

@property
@abstractmethod
def last_received_message_timestamp(self):
raise NotImplementedError

@abstractmethod
async def composer(self) -> Composer:
raise NotImplementedError
Expand Down Expand Up @@ -319,8 +324,8 @@ async def all_account_balances(self) -> Dict[str, Dict[str, Decimal]]:
async with self.throttler.execute_task(limit_id=CONSTANTS.PORTFOLIO_BALANCES_LIMIT_ID):
portfolio_response = await self.query_executor.account_portfolio(account_address=account_address)

bank_balances = portfolio_response["bankBalances"]
sub_account_balances = portfolio_response.get("subaccounts", [])
bank_balances = portfolio_response["portfolio"]["bankBalances"]
sub_account_balances = portfolio_response["portfolio"].get("subaccounts", [])

balances_dict: Dict[str, Dict[str, Decimal]] = {}

Expand Down Expand Up @@ -639,7 +644,7 @@ async def funding_info(self, market_id: str) -> FundingInfo:
trading_pair=await self.trading_pair_for_market(market_id=market_id),
index_price=last_traded_price, # Use the last traded price as the index_price
mark_price=oracle_price,
next_funding_utc_timestamp=int(updated_market_info["perpetualMarketInfo"]["nextFundingTimestamp"]),
next_funding_utc_timestamp=int(updated_market_info["market"]["perpetualMarketInfo"]["nextFundingTimestamp"]),
rate=funding_rate,
)
return funding_info
Expand Down Expand Up @@ -717,6 +722,10 @@ async def _generate_injective_order_data(self, order: GatewayInFlightOrder, mark
async def _updated_derivative_market_info_for_id(self, market_id: str) -> Dict[str, Any]:
raise NotImplementedError

@abstractmethod
async def _configure_gas_fee_for_transaction(self, transaction: Transaction):
raise NotImplementedError

def _place_order_results(
self,
orders_to_create: List[GatewayInFlightOrder],
Expand Down Expand Up @@ -775,12 +784,15 @@ async def _oracle_price(self, market_id: str) -> Decimal:

return price

def _chain_stream(
async def _listen_chain_stream_updates(
self,
spot_markets: List[InjectiveSpotMarket],
derivative_markets: List[InjectiveDerivativeMarket],
subaccount_ids: List[str],
composer: Composer,
callback: Callable,
on_end_callback: Optional[Callable] = None,
on_status_callback: Optional[Callable] = None,
):
spot_market_ids = [market_info.market_id for market_info in spot_markets]
derivative_market_ids = []
Expand Down Expand Up @@ -820,7 +832,10 @@ def _chain_stream(
positions_filter = None
oracle_price_filter = None

stream = self.query_executor.chain_stream(
await self.query_executor.listen_chain_stream_updates(
callback=callback,
on_end_callback=on_end_callback,
on_status_callback=on_status_callback,
subaccount_deposits_filter=subaccount_deposits_filter,
spot_trades_filter=spot_trades_filter,
derivative_trades_filter=derivative_trades_filter,
Expand All @@ -831,11 +846,18 @@ def _chain_stream(
positions_filter=positions_filter,
oracle_price_filter=oracle_price_filter
)
return stream

def _transactions_stream(self):
stream = self.query_executor.transactions_stream()
return stream
async def _listen_transactions_updates(
self,
callback: Callable,
on_end_callback: Callable,
on_status_callback: Callable,
):
await self.query_executor.listen_transactions_updates(
callback=callback,
on_end_callback=on_end_callback,
on_status_callback=on_status_callback,
)

async def _parse_spot_trade_entry(self, trade_info: Dict[str, Any]) -> TradeUpdate:
exchange_order_id: str = trade_info["orderHash"]
Expand Down Expand Up @@ -963,25 +985,14 @@ async def _send_in_transaction(self, messages: List[any_pb2.Any]) -> Dict[str, A
transaction.with_account_num(await self.trading_account_number())
transaction.with_chain_id(self.injective_chain_id)

signed_transaction_data = self._sign_and_encode(transaction=transaction)

async with self.throttler.execute_task(limit_id=CONSTANTS.SIMULATE_TRANSACTION_LIMIT_ID):
try:
simulation_result = await self.query_executor.simulate_tx(tx_byte=signed_transaction_data)
await self._configure_gas_fee_for_transaction(transaction=transaction)
except RuntimeError as simulation_ex:
if CONSTANTS.ACCOUNT_SEQUENCE_MISMATCH_ERROR in str(simulation_ex):
await self.initialize_trading_account()
raise

composer = await self.composer()
gas_limit = int(simulation_result["gasInfo"]["gasUsed"]) + CONSTANTS.EXTRA_TRANSACTION_GAS
fee = [composer.Coin(
amount=gas_limit * CONSTANTS.DEFAULT_GAS_PRICE,
denom=self.fee_denom,
)]

transaction.with_gas(gas_limit)
transaction.with_fee(fee)
transaction.with_memo("")
transaction.with_timeout_height(await self.timeout_height())

Expand All @@ -995,58 +1006,60 @@ async def _send_in_transaction(self, messages: List[any_pb2.Any]) -> Dict[str, A

return result

def _chain_stream_exception_handler(self, exception: RpcError):
self.logger().warning(f"Error while listening to chain stream ({exception})")

def _chain_stream_closed_handler(self):
self.logger().debug("Reconnecting stream for chain stream")

async def _listen_to_chain_updates(
self,
spot_markets: List[InjectiveSpotMarket],
derivative_markets: List[InjectiveDerivativeMarket],
subaccount_ids: List[str],
):
composer = await self.composer()
await self._listen_stream_events(
stream_provider=partial(
self._chain_stream,

async def _chain_stream_event_handler(event: Dict[str, Any]):
try:
await self._process_chain_stream_update(
chain_stream_update=event, derivative_markets=derivative_markets,
)
except asyncio.CancelledError:
raise
except Exception as ex:
self.logger().warning(f"Invalid chain stream event format ({ex})\n{event}")

while True:
# Running in a cycle to reconnect to the stream after connection errors
await self._listen_chain_stream_updates(
spot_markets=spot_markets,
derivative_markets=derivative_markets,
subaccount_ids=subaccount_ids,
composer=composer
),
event_processor=self._process_chain_stream_update,
event_name_for_errors="chain stream",
spot_markets=spot_markets,
derivative_markets=derivative_markets,
)
composer=composer,
callback=_chain_stream_event_handler,
on_end_callback=self._chain_stream_closed_handler,
on_status_callback=self._chain_stream_exception_handler,
)

async def _listen_to_chain_transactions(self):
await self._listen_stream_events(
stream_provider=self._transactions_stream,
event_processor=self._process_transaction_update,
event_name_for_errors="transaction",
)
def _transaction_stream_exception_handler(self, exception: RpcError):
self.logger().warning(f"Error while listening to transaction stream ({exception})")

async def _listen_stream_events(
self,
stream_provider: Callable,
event_processor: Callable,
event_name_for_errors: str,
**kwargs):
def _transaction_stream_closed_handler(self):
self.logger().debug("Reconnecting stream for transaction stream")

async def _listen_to_chain_transactions(self):
while True:
self.logger().debug(f"Starting stream for {event_name_for_errors}")
try:
stream = stream_provider()
async for event in stream:
try:
await event_processor(event, **kwargs)
except asyncio.CancelledError:
raise
except Exception as ex:
self.logger().warning(f"Invalid {event_name_for_errors} event format ({ex})\n{event}")
except asyncio.CancelledError:
raise
except Exception as ex:
self.logger().error(f"Error while listening to {event_name_for_errors} stream, reconnecting ... ({ex})")
self.logger().debug(f"Reconnecting stream for {event_name_for_errors}")
# Running in a cycle to reconnect to the stream after connection errors
await self._listen_transactions_updates(
callback=self._process_transaction_update,
on_end_callback=self._transaction_stream_closed_handler,
on_status_callback=self._transaction_stream_exception_handler,
)

async def _process_chain_stream_update(self, chain_stream_update: Dict[str, Any], **kwargs):
async def _process_chain_stream_update(
self, chain_stream_update: Dict[str, Any], derivative_markets: List[InjectiveDerivativeMarket],
):
block_height = int(chain_stream_update["blockHeight"])
block_timestamp = int(chain_stream_update["blockTime"]) * 1e-3
tasks = []
Expand Down Expand Up @@ -1129,7 +1142,7 @@ async def _process_chain_stream_update(self, chain_stream_update: Dict[str, Any]
oracle_price_updates=chain_stream_update.get("oraclePrices", []),
block_height=block_height,
block_timestamp=block_timestamp,
derivative_markets=kwargs.get("derivative_markets", [])
derivative_markets=derivative_markets,
)
)
)
Expand Down

0 comments on commit cecc0dd

Please sign in to comment.