diff --git a/elfpy/data/acquire_data.py b/elfpy/data/acquire_data.py index e204bc5e40..b082bce6f5 100644 --- a/elfpy/data/acquire_data.py +++ b/elfpy/data/acquire_data.py @@ -6,9 +6,12 @@ from dotenv import load_dotenv from eth_typing import URI, BlockNumber +from eth_utils import address from web3 import Web3 +from web3.contract.contract import Contract -from elfpy.data import contract_interface, postgres +from elfpy.data import postgres +from elfpy import eth, hyperdrive_interface from elfpy.utils import logs as log_utils # pylint: disable=too-many-arguments @@ -32,18 +35,20 @@ def main( # initialize the postgres session session = postgres.initialize_session() # get web3 provider - web3: Web3 = contract_interface.initialize_web3_with_http_provider(ethereum_node, request_kwargs={"timeout": 60}) + web3: Web3 = eth.initialize_web3_with_http_provider(ethereum_node, request_kwargs={"timeout": 60}) # send a request to the local server to fetch the deployed contract addresses and # all Hyperdrive contract addresses from the server response - addresses = contract_interface.fetch_address_from_url(contracts_url) - abis = contract_interface.load_all_abis(abi_dir) + addresses = hyperdrive_interface.fetch_hyperdrive_address_from_url(contracts_url) + abis = eth.abi.load_all_abis(abi_dir) - hyperdrive_contract = contract_interface.get_hyperdrive_contract(web3, abis, addresses) - base_contract = contract_interface.get_funding_contract(web3, abis, addresses) + hyperdrive_contract = hyperdrive_interface.get_hyperdrive_contract(web3, abis, addresses) + base_contract: Contract = web3.eth.contract( + address=address.to_checksum_address(addresses.base_token), abi=abis["ERC20Mintable"] + ) # get pool config from hyperdrive contract - pool_config = contract_interface.get_hyperdrive_config(hyperdrive_contract) + pool_config = hyperdrive_interface.get_hyperdrive_config(hyperdrive_contract) postgres.add_pool_config(pool_config, session) # Get last entry of pool info in db @@ -65,11 +70,11 
@@ def main( # and if the chain has executed until start_block (based on latest_mined_block check) if data_latest_block_number < block_number < latest_mined_block: # Query and add block_pool_info - block_pool_info = contract_interface.get_block_pool_info(web3, hyperdrive_contract, block_number) + block_pool_info = hyperdrive_interface.get_hyperdrive_pool_info(web3, hyperdrive_contract, block_number) postgres.add_pool_infos([block_pool_info], session) # Query and add block transactions - block_transactions = contract_interface.fetch_transactions_for_block(web3, hyperdrive_contract, block_number) + block_transactions = eth.transactions.fetch_transactions_for_block(web3, hyperdrive_contract, block_number) postgres.add_transactions(block_transactions, session) # monitor for new blocks & add pool info per block @@ -97,7 +102,7 @@ def main( block_pool_info = None for _ in range(RETRY_COUNT): try: - block_pool_info = contract_interface.get_block_pool_info( + block_pool_info = hyperdrive_interface.get_hyperdrive_pool_info( web3, hyperdrive_contract, block_number ) break @@ -111,7 +116,7 @@ def main( block_transactions = None for _ in range(RETRY_COUNT): try: - block_transactions = contract_interface.fetch_transactions_for_block( + block_transactions = eth.transactions.fetch_transactions_for_block( web3, hyperdrive_contract, block_number ) break @@ -122,7 +127,7 @@ def main( if block_transactions: postgres.add_transactions(block_transactions, session) - wallet_info_for_transactions = contract_interface.get_wallet_info( + wallet_info_for_transactions = hyperdrive_interface.get_wallet_info( hyperdrive_contract, base_contract, block_number, block_transactions ) postgres.add_wallet_infos(wallet_info_for_transactions, session) diff --git a/elfpy/data/contract_interface.py b/elfpy/data/contract_interface.py deleted file mode 100644 index da6fe34a10..0000000000 --- a/elfpy/data/contract_interface.py +++ /dev/null @@ -1,739 +0,0 @@ -"""Functions and classes for interfacing with 
smart contracts""" -from __future__ import annotations - -import json -import logging -import os -import re -import time -from datetime import datetime -from typing import Any, Sequence - -import attr -import requests -from eth_account import Account -from eth_account.signers.local import LocalAccount -from eth_typing import URI, BlockNumber, ChecksumAddress -from eth_utils import address -from fixedpointmath import FixedPoint -from hexbytes import HexBytes -from web3 import Web3 -from web3.contract.contract import Contract, ContractEvent, ContractFunction -from web3.middleware import geth_poa -from web3.types import ( - ABI, - ABIEvent, - ABIFunctionComponents, - ABIFunctionParams, - BlockData, - EventData, - LogReceipt, - RPCEndpoint, - RPCResponse, - TxReceipt, -) - -from elfpy.data.db_schema import PoolConfig, PoolInfo, Transaction, WalletInfo -from elfpy.markets.hyperdrive import hyperdrive_assets - -RETRY_COUNT = 10 - - -class TestAccount: - """Web3 account that has helper functions & associated funding source""" - - # TODO: We should be adding more methods to this class. - # If not, we can delete it at the end of the refactor. 
- # pylint: disable=too-few-public-methods - - def __init__(self, extra_entropy: str = "TEST ACCOUNT"): - """Initialize an account""" - self.account: LocalAccount = Account().create(extra_entropy=extra_entropy) - - @property - def checksum_address(self) -> ChecksumAddress: - """Return the checksum address of the account""" - return Web3.to_checksum_address(self.account.address) - - -@attr.s -class HyperdriveAddressesJson: - """Addresses for deployed Hyperdrive contracts.""" - - # pylint: disable=too-few-public-methods - - base_token: str = attr.ib() - mock_hyperdrive: str = attr.ib() - mock_hyperdrive_math: str = attr.ib() - - -def initialize_web3_with_http_provider(ethereum_node: URI | str, request_kwargs: dict | None = None) -> Web3: - """Initialize a Web3 instance using an HTTP provider and inject a geth Proof of Authority (poa) middleware. - - Arguments - --------- - ethereum_node: URI | str - Address of the http provider - request_kwargs: dict - The HTTPProvider uses the python requests library for making requests. - If you would like to modify how requests are made, - you can use the request_kwargs to do so. - - Notes - ----- - The geth_poa_middleware is required to connect to geth --dev or the Goerli public network. - It may also be needed for other EVM compatible blockchains like Polygon or BNB Chain (Binance Smart Chain). - See more `here `_. 
- """ - if request_kwargs is None: - request_kwargs = {} - provider = Web3.HTTPProvider(ethereum_node, request_kwargs) - web3 = Web3(provider) - web3.middleware_onion.inject(geth_poa.geth_poa_middleware, layer=0) - return web3 - - -def set_anvil_account_balance(web3: Web3, account_address: str, amount_wei: int) -> RPCResponse: - """Set an the account using the web3 provider - - Arguments - --------- - amount_wei : int - amount_wei to fund, in wei - """ - if not web3.is_checksum_address(account_address): - raise ValueError(f"argument {account_address=} must be a checksum address") - params = [account_address, hex(amount_wei)] # account, amount - rpc_response = web3.provider.make_request(method=RPCEndpoint("anvil_setBalance"), params=params) - return rpc_response - - -def mint_tokens(token_contract: Contract, account_address: str, amount_wei: int) -> HexBytes: - """Add funds to the account - - Arguments - --------- - amount_wei : int - amount_wei to fund, in wei - """ - tx_receipt = token_contract.functions.mint(account_address, amount_wei).transact() - return tx_receipt - - -def get_account_balance_from_provider(web3: Web3, account_address: str) -> int | None: - """Get the balance for an account deployed on the web3 provider""" - if not web3.is_checksum_address(account_address): - raise ValueError(f"argument {account_address=} must be a checksum address") - rpc_response = web3.provider.make_request(method=RPCEndpoint("eth_getBalance"), params=[account_address, "latest"]) - hex_result = rpc_response.get("result") - if hex_result is not None: - return int(hex_result, base=16) # cast hex to int - return None - - -def load_all_abis(abi_folder: str) -> dict: - """Load all ABI jsons given an abi_folder - - Arguments - --------- - abi_folder: str - The local directory that contains all abi json - """ - abis = {} - abi_files = _collect_files(abi_folder) - loaded = [] - for abi_file in abi_files: - file_name = os.path.splitext(os.path.basename(abi_file))[0] - with 
open(abi_file, mode="r", encoding="UTF-8") as file: - data = json.load(file) - if "abi" in data: - abis[file_name] = data["abi"] - loaded.append(abi_file) - else: - logging.warning("JSON file %s did not contain an ABI", abi_file) - logging.info("Loaded ABI files %s", str(loaded)) - return abis - - -def fetch_and_decode_logs(web3: Web3, contract: Contract, tx_receipt: TxReceipt) -> list[dict[Any, Any]]: - """Decode logs from a transaction receipt""" - logs = [] - if tx_receipt.get("logs"): - for log in tx_receipt["logs"]: - event_data, event = get_event_object(web3, contract, log, tx_receipt) - if event_data and event: - formatted_log = dict(event_data) - formatted_log["event"] = event.get("name") - formatted_log["args"] = dict(event_data["args"]) - logs.append(formatted_log) - return logs - - -def get_event_object( - web3: Web3, contract: Contract, log: LogReceipt, tx_receipt: TxReceipt -) -> tuple[EventData, ABIEvent] | tuple[None, None]: - """Retrieves the event object and anonymous types for a given contract and log""" - abi_events = [abi for abi in contract.abi if abi["type"] == "event"] # type: ignore - for event in abi_events: # type: ignore - # Get event signature components - name = event["name"] # type: ignore - inputs = [param["type"] for param in event["inputs"]] # type: ignore - inputs = ",".join(inputs) - # Hash event signature - event_signature_text = f"{name}({inputs})" - event_signature_hex = web3.keccak(text=event_signature_text).hex() - # Find match between log's event signature and ABI's event signature - receipt_event_signature_hex = log["topics"][0].hex() - if event_signature_hex == receipt_event_signature_hex: - # Decode matching log - contract_event: ContractEvent = contract.events[event["name"]]() # type: ignore - event_data: EventData = contract_event.process_receipt(tx_receipt)[0] - return event_data, event # type: ignore - return (None, None) - - -def contract_function_abi_outputs(contract_abi: ABI, function_name: str) -> list[tuple[str, 
def _get_name_and_type_from_abi(abi_outputs: ABIFunctionComponents | ABIFunctionParams) -> tuple[str, str]:
    """Retrieve and narrow the (name, type) strings for one abi output entry.

    Missing entries are reported as the literal string "none".
    """
    return_value_name: str | None = abi_outputs.get("name")
    if return_value_name is None:
        return_value_name = "none"
    return_value_type: str | None = abi_outputs.get("type")
    if return_value_type is None:
        return_value_type = "none"
    return (return_value_name, return_value_type)


def contract_function_abi_outputs(contract_abi: ABI, function_name: str) -> list[tuple[str, str]] | None:
    """Parse the function abi to get the name and type for each output.

    Arguments
    ---------
    contract_abi: ABI
        Full contract abi; the first entry whose "name" equals *function_name* is used.
    function_name: str
        Name of the function whose output names/types are wanted.

    Returns
    -------
    list[tuple[str, str]] | None
        One (name, type) pair per output ([] for void functions); None when the
        function or its outputs cannot be found or are malformed.
    """
    # find the first abi entry matching the function name
    function_abi = next((abi for abi in contract_abi if abi.get("name") == function_name), None)
    if function_abi is None:
        logging.warning("could not find function_name=%s in contract abi", function_name)
        return None
    function_outputs = function_abi.get("outputs")
    if function_outputs is None:
        logging.warning("function abi does not specify outputs")
        return None
    if not isinstance(function_outputs, Sequence):
        logging.warning("function abi outputs are not a sequence")
        return None
    # FIX: guard the void-function case; the original indexed function_outputs[0]
    # unconditionally and raised IndexError for functions declaring "outputs": []
    if len(function_outputs) == 0:
        return []
    if len(function_outputs) > 1:
        # multiple (possibly unnamed) outputs: report each one.
        # FIX: return immediately so the single-output branch below cannot
        # clobber this result (the original's flat if/if/else could).
        return [_get_name_and_type_from_abi(output) for output in function_outputs]
    first_output = function_outputs[0]
    if first_output.get("type") == "tuple" and first_output.get("components") is not None:
        # a single struct output: report each named component
        return [_get_name_and_type_from_abi(component) for component in first_output["components"]]
    # single scalar output
    return [_get_name_and_type_from_abi(first_output)]
def smart_contract_read(contract: Contract, function_name: str, *fn_args, **fn_kwargs) -> dict[str, Any]:
    """Call a read-only contract function and map returned values to their abi output names.

    .. todo::
        function to recursively find component names & types
        function to dynamically assign types to output variables
        would be cool if this also put stuff into FixedPoint
    """
    # fn_kwargs (e.g. block_identifier) are forwarded to .call(), not the function binding
    function: ContractFunction = contract.get_function_by_name(function_name)(*fn_args)
    return_values = function.call(**fn_kwargs)
    if not isinstance(return_values, Sequence):
        return_values = [return_values]
    if contract.abi:  # not all contracts have an associated ABI
        return_names_and_types = contract_function_abi_outputs(contract.abi, function_name)
        if return_names_and_types is not None:
            if len(return_names_and_types) != len(return_values):
                raise AssertionError(f"{len(return_names_and_types)=} must equal {len(return_values)=}.")
            return {
                name_and_type[0]: value
                for name_and_type, value in zip(return_names_and_types, return_values)
            }
    # no usable abi metadata: fall back to positional names
    return {f"var_{idx}": value for idx, value in enumerate(return_values)}


def smart_contract_transact(
    web3: Web3, contract: Contract, function_name: str, from_account: TestAccount, *fn_args
) -> TxReceipt:
    """Sign and submit a state-changing contract call, blocking until the receipt is available."""
    func_handle = contract.get_function_by_name(function_name)(*fn_args)
    sender = from_account.checksum_address
    unsent_txn = func_handle.build_transaction(
        {"from": sender, "nonce": web3.eth.get_transaction_count(sender)}
    )
    signed_txn = from_account.account.sign_transaction(unsent_txn)
    tx_hash = web3.eth.send_raw_transaction(signed_txn.rawTransaction)
    # wait for approval to complete
    return web3.eth.wait_for_transaction_receipt(tx_hash)
def fetch_address_from_url(contracts_url: str) -> HyperdriveAddressesJson:
    """Fetch addresses for deployed contracts in the Hyperdrive system.

    Polls *contracts_url* up to 100 times (10s apart) until it answers 200.

    Arguments
    ---------
    contracts_url: str
        URL of the artifacts server that publishes the deployed addresses as JSON.

    Returns
    -------
    HyperdriveAddressesJson
        Deployed addresses, with camelCase JSON keys converted to snake_case fields.

    Raises
    ------
    ConnectionError
        If no successful response is obtained after all attempts.
    """
    response = None
    # BUG FIX: the original loop `continue`d past its attempt counter, so a persistently
    # failing server caused an infinite retry loop — and a successful response did not
    # break, so up to 100 redundant requests were made. Bounded retry with early exit:
    for _ in range(100):
        response = requests.get(contracts_url, timeout=60)
        if response.status_code == 200:
            break
        logging.warning("Request failed with status code %s @ %s", response.status_code, time.ctime())
        time.sleep(10)
    if response is None:
        raise ConnectionError("Request failed, returning status `None`")
    if response.status_code != 200:
        raise ConnectionError(f"Request failed with status code {response.status_code} @ {time.ctime()}")
    addresses_json = response.json()
    return HyperdriveAddressesJson(**{_camel_to_snake(key): value for key, value in addresses_json.items()})


def get_hyperdrive_contract(web3: Web3, abis: dict, addresses: HyperdriveAddressesJson) -> Contract:
    """Get the hyperdrive contract given abis

    Arguments
    ---------
    web3: Web3
        web3 provider object
    abis: dict
        A dictionary that contains all abis keyed by the abi name, returned from `load_all_abis`
    addresses: HyperdriveAddressesJson
        The deployed contract addresses

    Returns
    -------
    Contract
        The hyperdrive contract instance
    """
    if "IHyperdrive" not in abis:
        raise AssertionError("IHyperdrive ABI was not provided")
    return web3.eth.contract(
        address=address.to_checksum_address(addresses.mock_hyperdrive), abi=abis["IHyperdrive"]
    )
def get_funding_contract(web3: Web3, abis: dict, addresses: HyperdriveAddressesJson) -> Contract:
    """Get the base-token (funding) contract instance.

    Arguments
    ---------
    web3: Web3
        web3 provider object
    abis: dict
        A dictionary that contains all abis keyed by the abi name, returned from `load_all_abis`
    addresses: HyperdriveAddressesJson
        The deployed contract addresses

    Returns
    -------
    Contract
        The ERC20Mintable base token contract instance
    """
    if "ERC20Mintable" not in abis:
        raise AssertionError("ERC20 ABI for minting base tokens was not provided")
    return web3.eth.contract(
        address=address.to_checksum_address(addresses.base_token), abi=abis["ERC20Mintable"]
    )


def fetch_transactions_for_block(web3: Web3, contract: Contract, block_number: BlockNumber) -> list[Transaction]:
    """Collect, decode, and package every transaction sent to *contract* in the given block.

    Arguments
    ---------
    web3: Web3
        web3 provider object
    contract: Contract
        The contract whose transactions should be collected
    block_number: BlockNumber
        The block number to query from the chain

    Returns
    -------
    list[Transaction]
        Transaction objects ready to be inserted into Postgres
    """
    block: BlockData = web3.eth.get_block(block_number, full_transactions=True)
    transactions = block.get("transactions")
    if not transactions:
        logging.info("no transactions in block %s", block.get("number"))
        return []
    out_transactions = []
    for transaction in transactions:
        if isinstance(transaction, HexBytes):
            # full_transactions=True should give dicts; a bare hash means nothing to decode
            logging.warning("transaction HexBytes")
            continue
        if transaction.get("to") != contract.address:
            logging.warning("transaction not from contract")
            continue
        transaction_dict: dict[str, Any] = dict(transaction)
        tx_hash = transaction.get("hash") or HexBytes("")
        transaction_dict["hash"] = tx_hash.hex()  # store HexBytes as a hex string
        try:
            method, params = contract.decode_function_input(transaction["input"])
        except ValueError:
            continue  # input is not meant for this contract; skip it
        transaction_dict["input"] = {"method": method.fn_name, "params": params}
        tx_receipt = web3.eth.get_transaction_receipt(tx_hash)
        logs = fetch_and_decode_logs(web3, contract, tx_receipt)
        receipt: dict[str, Any] = _recursive_dict_conversion(tx_receipt)  # type: ignore
        out_transactions.append(_build_transaction_object(transaction_dict, logs, receipt))
    return out_transactions
def get_block_pool_info(web3: Web3, hyperdrive_contract: Contract, block_number: BlockNumber) -> PoolInfo:
    """Return the pool info from the Hyperdrive contract at a given block.

    Arguments
    ---------
    web3: Web3
        web3 provider object
    hyperdrive_contract: Contract
        The contract to query the pool info from
    block_number: BlockNumber
        The block number to query from the chain

    Returns
    -------
    PoolInfo
        A PoolInfo object ready to be inserted into Postgres
    """
    raw_pool_info = smart_contract_read(hyperdrive_contract, "getPoolInfo", block_identifier=block_number)
    # unscale every solidity fixed-point value to a float
    scaled: dict[Any, Any] = {key: _convert_scaled_value(value) for key, value in raw_pool_info.items()}
    current_block: BlockData = web3.eth.get_block(block_number)
    block_timestamp = current_block.get("timestamp")
    if block_timestamp is None:
        raise AssertionError("Current block has no timestamp")
    scaled["timestamp"] = block_timestamp
    scaled["blockNumber"] = block_number
    # populate exactly the fields PoolInfo declares; anything missing defaults to None
    pool_info_kwargs = {}
    for key in PoolInfo.__annotations__:
        if key == "timestamp":
            pool_info_kwargs[key] = datetime.fromtimestamp(scaled[key])
        elif key == "blockNumber":
            pool_info_kwargs[key] = scaled[key]
        else:
            pool_info_kwargs[key] = scaled.get(key, None)
    return PoolInfo(**pool_info_kwargs)
def get_hyperdrive_config(hyperdrive_contract: Contract) -> PoolConfig:
    """Get the hyperdrive config from a deployed hyperdrive contract.

    Arguments
    ----------
    hyperdrive_contract : Contract
        The deployed hyperdrive contract instance.

    Returns
    -------
    hyperdrive_config : PoolConfig
        The hyperdrive config.
    """
    raw_config: dict[str, Any] = smart_contract_read(hyperdrive_contract, "getPoolConfig")
    out_config: dict[str, Any] = {
        "contractAddress": hyperdrive_contract.address,
        "baseToken": raw_config.get("baseToken", None),
        "initializeSharePrice": _convert_scaled_value(raw_config.get("initializeSharePrice", None)),
        "positionDuration": raw_config.get("positionDuration", None),
        "checkpointDuration": raw_config.get("checkpointDuration", None),
    }
    # timeStretch and its inverse are both derived from the same scaled value
    config_time_stretch = raw_config.get("timeStretch", None)
    if config_time_stretch:
        fp_time_stretch = FixedPoint(scaled_value=config_time_stretch)
        time_stretch = float(fp_time_stretch)
        inv_time_stretch = float(1 / fp_time_stretch)
    else:
        time_stretch = None
        inv_time_stretch = None
    out_config["timeStretch"] = time_stretch
    out_config["governance"] = raw_config.get("governance", None)
    out_config["feeCollector"] = raw_config.get("feeCollector", None)
    # "fees" is a (curve, flat, governance) triple of scaled values
    curve_fee, flat_fee, governance_fee = raw_config.get("fees", (None, None, None))
    out_config["curveFee"] = _convert_scaled_value(curve_fee)
    out_config["flatFee"] = _convert_scaled_value(flat_fee)
    out_config["governanceFee"] = _convert_scaled_value(governance_fee)
    out_config["oracleSize"] = raw_config.get("oracleSize", None)
    out_config["updateGap"] = raw_config.get("updateGap", None)
    out_config["invTimeStretch"] = inv_time_stretch
    position_duration = out_config["positionDuration"]
    # position duration is in seconds; term length is reported in days
    out_config["termLength"] = position_duration / 60 / 60 / 24 if position_duration is not None else None
    return PoolConfig(**out_config)
def get_wallet_info(
    hyperdrive_contract: Contract,
    base_contract: Contract,
    block_number: BlockNumber,
    transactions: list[Transaction],
) -> list[WalletInfo]:
    """Retrieve wallet information at a given block given a transaction.

    Transactions are needed here to get
    (1) the wallet address of a transaction, and
    (2) the token id of the transaction

    Arguments
    ----------
    hyperdrive_contract : Contract
        The deployed hyperdrive contract instance.
    base_contract : Contract
        The deployed base contract instance
    block_number : BlockNumber
        The block number to query
    transactions : list[Transaction]
        The list of transactions to get events from

    Returns
    -------
    list[WalletInfo]
        The list of WalletInfo objects ready to be inserted into postgres
    """
    # pylint: disable=too-many-locals
    out_wallet_info = []
    for transaction in transactions:
        wallet_addr = transaction.event_operator
        token_id = transaction.event_id
        token_prefix = transaction.event_prefix
        token_maturity_time = transaction.event_maturity_time
        if wallet_addr is None:
            continue
        # base-token balance, with bounded retries against transient node failures
        num_base_token_scaled = None
        for _ in range(RETRY_COUNT):
            try:
                num_base_token_scaled = base_contract.functions.balanceOf(wallet_addr).call(
                    block_identifier=block_number
                )
                break
            except ValueError:
                logging.warning("Error in getting base token balance, retrying")
                time.sleep(1)
                continue
        num_base_token = _convert_scaled_value(num_base_token_scaled)
        if (num_base_token is not None) and (wallet_addr is not None):
            out_wallet_info.append(
                WalletInfo(
                    blockNumber=block_number,
                    walletAddress=wallet_addr,
                    baseTokenType="BASE",
                    tokenType="BASE",
                    tokenValue=num_base_token,
                )
            )
        # Handle cases where these fields don't exist
        if (token_id is not None) and (token_prefix is not None):
            base_token_type = hyperdrive_assets.AssetIdPrefix(token_prefix).name
            if (token_maturity_time is not None) and (token_maturity_time > 0):
                token_type = base_token_type + "-" + str(token_maturity_time)
                maturity_time = token_maturity_time
            else:
                token_type = base_token_type
                maturity_time = None
            num_custom_token_scaled = None
            for _ in range(RETRY_COUNT):
                try:
                    num_custom_token_scaled = hyperdrive_contract.functions.balanceOf(
                        int(token_id), wallet_addr
                    ).call(block_identifier=block_number)
                    # BUG FIX: the original loop had no break on success (unlike the base-token
                    # loop above), so even a successful call was repeated RETRY_COUNT times.
                    break
                except ValueError:
                    logging.warning("Error in getting custom token balance, retrying")
                    time.sleep(1)
                    continue
            num_custom_token = _convert_scaled_value(num_custom_token_scaled)
            if num_custom_token is not None:
                out_wallet_info.append(
                    WalletInfo(
                        blockNumber=block_number,
                        walletAddress=wallet_addr,
                        baseTokenType=base_token_type,
                        tokenType=token_type,
                        tokenValue=num_custom_token,
                        maturityTime=maturity_time,
                    )
                )
    return out_wallet_info
range(RETRY_COUNT): - try: - num_custom_token_scaled = hyperdrive_contract.functions.balanceOf(int(token_id), wallet_addr).call( - block_identifier=block_number - ) - except ValueError: - logging.warning("Error in getting custom token balance, retrying") - time.sleep(1) - continue - num_custom_token = _convert_scaled_value(num_custom_token_scaled) - - if num_custom_token is not None: - out_wallet_info.append( - WalletInfo( - blockNumber=block_number, - walletAddress=wallet_addr, - baseTokenType=base_token_type, - tokenType=token_type, - tokenValue=num_custom_token, - maturityTime=maturity_time, - ) - ) - - return out_wallet_info - - -def _convert_scaled_value(input_val: int | None) -> float | None: - """ - Given a scaled value int, converts it to an unscaled value in float, while dealing with Nones - - Arguments - ---------- - input_val: int | None - The scaled integer value to unscale and convert to float - - Returns - ------- - float | None - The unscaled floating point value - """ - - # We cast to FixedPoint, then to floats to keep noise to a minimum - # This is assuming there's no loss of precision going from Fixedpoint to float - # Once this gets fed into postgres, postgres has fixed precision Numeric type - if input is not None: - return float(FixedPoint(scaled_value=input_val)) - return None - - -def _build_transaction_object( - transaction_dict: dict[str, Any], - logs: list[dict[str, Any]], - receipt: dict[str, Any], -) -> Transaction: - """ - Conversion function to translate output of chain queries to the Transaction object - - Arguments - ---------- - transaction_dict : dict[str, Any] - A dictionary representing the decoded transactions from the query - logs: list[str, Any] - A dictionary representing the decoded logs from the query - receipt: dict[str, Any] - A dictionary representing the transaction receipt from the query - - Returns - ------- - Transaction - A transaction object to be inserted into postgres - """ - - # Build output obj dict 
def _build_transaction_object(
    transaction_dict: dict[str, Any],
    logs: list[dict[str, Any]],
    receipt: dict[str, Any],
) -> Transaction:
    """
    Conversion function to translate output of chain queries to the Transaction object

    Arguments
    ----------
    transaction_dict : dict[str, Any]
        A dictionary representing the decoded transactions from the query
    logs: list[dict[str, Any]]
        A list of dictionaries representing the decoded logs from the query
    receipt: dict[str, Any]
        A dictionary representing the transaction receipt from the query

    Returns
    -------
    Transaction
        A transaction object to be inserted into postgres
    """
    # Build the kwargs dict incrementally, then construct Transaction(**out_dict)
    out_dict: dict[str, Any] = {
        "blockNumber": transaction_dict["blockNumber"],
        "transactionIndex": transaction_dict["transactionIndex"],
        "nonce": transaction_dict["nonce"],
        "transactionHash": transaction_dict["hash"],
        "txn_to": transaction_dict["to"],
        "txn_from": transaction_dict["from"],
        "gasUsed": receipt["gasUsed"],
    }
    # Input solidity method and parameters
    # TODO can the input field ever be empty or not exist?
    out_dict["input_method"] = transaction_dict["input"]["method"]
    input_params = transaction_dict["input"]["params"]
    # scaled numeric params are unscaled to floats; addresses/flags/times pass through as-is
    scaled_param_keys = {
        "input_params_contribution": "_contribution",
        "input_params_apr": "_apr",
        "input_params_baseAmount": "_baseAmount",
        "input_params_minOutput": "_minOutput",
        "input_params_bondAmount": "_bondAmount",
        "input_params_maxDeposit": "_maxDeposit",
        "input_params_minApr": "_minApr",
        "input_params_maxApr": "_maxApr",
        "input_params_shares": "_shares",
    }
    for out_key, param_key in scaled_param_keys.items():
        out_dict[out_key] = _convert_scaled_value(input_params.get(param_key, None))
    out_dict["input_params_destination"] = input_params.get("_destination", None)
    out_dict["input_params_asUnderlying"] = input_params.get("_asUnderlying", None)
    out_dict["input_params_maturityTime"] = input_params.get("_maturityTime", None)
    # Assuming one TransferSingle per transfer
    # TODO there can be two TransferSingle logs (e.g. DAI from agent to hyperdrive);
    # currently the first one (e.g. minting a long: address 0 -> agent) is used.
    transfer_logs = [log for log in logs if log["event"] == "TransferSingle"]
    event_args: dict[str, Any] = {}
    if transfer_logs:
        if len(transfer_logs) > 1:
            # FIX: corrected "Tranfer" typo in the original warning message
            logging.warning("Transfer event contains multiple TransferSingle logs, selecting first")
        event_args = transfer_logs[0]["args"]
    out_dict["event_value"] = _convert_scaled_value(event_args.get("value", None))
    out_dict["event_from"] = event_args.get("from", None)
    out_dict["event_to"] = event_args.get("to", None)
    out_dict["event_operator"] = event_args.get("operator", None)
    out_dict["event_id"] = event_args.get("id", None)
    # Decode the asset id into its prefix and maturity time when present
    if out_dict["event_id"] is not None:
        event_prefix, event_maturity_time = hyperdrive_assets.decode_asset_id(out_dict["event_id"])
        out_dict["event_prefix"] = event_prefix
        out_dict["event_maturity_time"] = event_maturity_time
    return Transaction(**out_dict)
def _recursive_dict_conversion(obj):
    """Recursively convert HexBytes leaves to hex strings inside nested dict-like containers.

    Returns the same structure with every HexBytes value replaced by its hex string.
    """
    if isinstance(obj, HexBytes):
        return obj.hex()
    if hasattr(obj, "items"):
        # covers plain dicts and web3's AttributeDict alike — the original special-cased
        # dict in a redundant duplicate branch
        return {key: _recursive_dict_conversion(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # FIX: descend into sequences too, so HexBytes nested inside e.g. receipt["logs"]
        # are converted instead of being passed through untouched
        return type(obj)(_recursive_dict_conversion(item) for item in obj)
    return obj


def _camel_to_snake(camel_string: str) -> str:
    """Convert camelCase to snake_case"""
    # insert "_" before each non-leading capital, then lowercase everything
    # (regex reconstructed — TODO confirm against the original pattern)
    return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_string).lower()
def _collect_files(folder_path: str, extension: str = ".json") -> list[str]:
    """Recursively gather the paths of all files under *folder_path* ending in *extension*."""
    matches: list[str] = []
    for root, _, files in os.walk(folder_path):
        matches.extend(os.path.join(root, name) for name in files if name.endswith(extension))
    return matches


def _get_name_and_type_from_abi(abi_outputs: ABIFunctionComponents | ABIFunctionParams) -> tuple[str, str]:
    """Retrieve and narrow the types for abi outputs; missing entries become the string "none"."""
    return_value_name: str | None = abi_outputs.get("name")
    return_value_type: str | None = abi_outputs.get("type")
    if return_value_name is None:
        return_value_name = "none"
    if return_value_type is None:
        return_value_type = "none"
    return (return_value_name, return_value_type)
import accounts +from .numeric_utils import convert_scaled_value +from .rpc_interface import set_anvil_account_balance, get_account_balance_from_provider +from .transactions import smart_contract_read, smart_contract_transact, fetch_transactions_for_block +from .web3_setup import initialize_web3_with_http_provider diff --git a/elfpy/eth/abi/__init__.py b/elfpy/eth/abi/__init__.py new file mode 100644 index 0000000000..2aa52056c3 --- /dev/null +++ b/elfpy/eth/abi/__init__.py @@ -0,0 +1,2 @@ +"""ABI related functions and classes""" +from .load_all_abis import load_all_abis diff --git a/elfpy/eth/abi/load_all_abis.py b/elfpy/eth/abi/load_all_abis.py new file mode 100644 index 0000000000..485a280444 --- /dev/null +++ b/elfpy/eth/abi/load_all_abis.py @@ -0,0 +1,41 @@ +"""Load all abis""" +from __future__ import annotations + +import json +import logging +import os + + +def load_all_abis(abi_folder: str) -> dict: + """Load all ABI jsons given an abi_folder + + Arguments + --------- + abi_folder: str + The local directory that contains all abi json + """ + abis = {} + abi_files = _collect_files(abi_folder) + loaded = [] + for abi_file in abi_files: + file_name = os.path.splitext(os.path.basename(abi_file))[0] + with open(abi_file, mode="r", encoding="UTF-8") as file: + data = json.load(file) + if "abi" in data: + abis[file_name] = data["abi"] + loaded.append(abi_file) + else: + logging.warning("JSON file %s did not contain an ABI", abi_file) + logging.info("Loaded ABI files %s", str(loaded)) + return abis + + +def _collect_files(folder_path: str, extension: str = ".json") -> list[str]: + """Load all files with the given extension into a list""" + collected_files = [] + for root, _, files in os.walk(folder_path): + for file in files: + if file.endswith(extension): + file_path = os.path.join(root, file) + collected_files.append(file_path) + return collected_files diff --git a/elfpy/eth/accounts/__init__.py b/elfpy/eth/accounts/__init__.py new file mode 100644 index 
0000000000..3927ccbe38 --- /dev/null +++ b/elfpy/eth/accounts/__init__.py @@ -0,0 +1,3 @@ +"""Helper utilities for creating and managing accounts""" + +from .agent_account import AgentAccount diff --git a/elfpy/eth/accounts/agent_account.py b/elfpy/eth/accounts/agent_account.py new file mode 100644 index 0000000000..7ca1916cb7 --- /dev/null +++ b/elfpy/eth/accounts/agent_account.py @@ -0,0 +1,24 @@ +"""Empty test accounts for testing smart contracts""" +from __future__ import annotations + +from eth_account import Account +from eth_account.signers.local import LocalAccount +from eth_typing import ChecksumAddress +from web3 import Web3 + + +class AgentAccount: + """Web3 account that has helper functions & associated funding source""" + + # TODO: We should be adding more methods to this class. + # If not, we can delete it at the end of the refactor. + # pylint: disable=too-few-public-methods + + def __init__(self, extra_entropy: str = "TEST ACCOUNT"): + """Initialize an account""" + self.account: LocalAccount = Account().create(extra_entropy=extra_entropy) + + @property + def checksum_address(self) -> ChecksumAddress: + """Return the checksum address of the account""" + return Web3.to_checksum_address(self.account.address) diff --git a/elfpy/eth/numeric_utils.py b/elfpy/eth/numeric_utils.py new file mode 100644 index 0000000000..9b448ab64e --- /dev/null +++ b/elfpy/eth/numeric_utils.py @@ -0,0 +1,29 @@ +"""Utilities to convert solidity numbers to python numbers""" +from __future__ import annotations + +from fixedpointmath import FixedPoint + + +def convert_scaled_value(input_val: int | None) -> float | None: + """ + Given a scaled value int, converts it to an unscaled value in float, while dealing with Nones + + Arguments + ---------- + input_val: int | None + The scaled integer value to unscale and convert to float + + Returns + ------- + float | None + The unscaled floating point value + + Note + ---- + We cast to FixedPoint, then to floats to keep noise to a 
minimum. + There is no loss of precision when going from Fixedpoint to float. + Once this is fed into postgres, postgres will use the fixed-precision Numeric type. + """ + if input_val is not None: + return float(FixedPoint(scaled_value=input_val)) + return None diff --git a/elfpy/eth/rpc_interface.py b/elfpy/eth/rpc_interface.py new file mode 100644 index 0000000000..3414dd569e --- /dev/null +++ b/elfpy/eth/rpc_interface.py @@ -0,0 +1,31 @@ +"""Functions for interfacing with the anvil or ethereum RPC endpoint""" +from __future__ import annotations + +from web3 import Web3 +from web3.types import RPCEndpoint, RPCResponse + + +def set_anvil_account_balance(web3: Web3, account_address: str, amount_wei: int) -> RPCResponse: + """Set an the account using the web3 provider + + Arguments + --------- + amount_wei : int + amount_wei to fund, in wei + """ + if not web3.is_checksum_address(account_address): + raise ValueError(f"argument {account_address=} must be a checksum address") + params = [account_address, hex(amount_wei)] # account, amount + rpc_response = web3.provider.make_request(method=RPCEndpoint("anvil_setBalance"), params=params) + return rpc_response + + +def get_account_balance_from_provider(web3: Web3, account_address: str) -> int | None: + """Get the balance for an account deployed on the web3 provider""" + if not web3.is_checksum_address(account_address): + raise ValueError(f"argument {account_address=} must be a checksum address") + rpc_response = web3.provider.make_request(method=RPCEndpoint("eth_getBalance"), params=[account_address, "latest"]) + hex_result = rpc_response.get("result") + if hex_result is not None: + return int(hex_result, base=16) # cast hex to int + return None diff --git a/elfpy/eth/transactions.py b/elfpy/eth/transactions.py new file mode 100644 index 0000000000..203ce48ba7 --- /dev/null +++ b/elfpy/eth/transactions.py @@ -0,0 +1,297 @@ +"""Web3 powered functions for interfacing with smart contracts""" +from __future__ import 
annotations + +import logging + +from typing import Any, Sequence + +from eth_typing import BlockNumber +from hexbytes import HexBytes +from web3 import Web3 +from web3.contract.contract import Contract, ContractEvent, ContractFunction +from web3.types import ( + ABI, + ABIFunctionComponents, + ABIFunctionParams, + ABIEvent, + BlockData, + EventData, + LogReceipt, + TxReceipt, +) + +from elfpy.data.db_schema import Transaction +from elfpy.markets.hyperdrive import hyperdrive_assets + +from .accounts import AgentAccount +from .numeric_utils import convert_scaled_value + + +def smart_contract_read(contract: Contract, function_name: str, *fn_args, **fn_kwargs) -> dict[str, Any]: + """Return from a smart contract read call + + .. todo:: + function to recursively find component names & types + function to dynamically assign types to output variables + would be cool if this also put stuff into FixedPoint + """ + # get the callable contract function from function_name & call it + function: ContractFunction = contract.get_function_by_name(function_name)(*fn_args) # , **fn_kwargs) + return_values = function.call(**fn_kwargs) + if not isinstance(return_values, Sequence): # could be list or tuple + return_values = [return_values] + if contract.abi: # not all contracts have an associated ABI + return_names_and_types = _contract_function_abi_outputs(contract.abi, function_name) + if return_names_and_types is not None: + if len(return_names_and_types) != len(return_values): + raise AssertionError( + f"{len(return_names_and_types)=} must equal {len(return_values)=}." 
+ f"\n{return_names_and_types=}\n{return_values=}" + ) + function_return_dict = dict( + (var_name_and_type[0], var_value) + for var_name_and_type, var_value in zip(return_names_and_types, return_values) + ) + return function_return_dict + return {f"var_{idx}": value for idx, value in enumerate(return_values)} + + +def smart_contract_transact( + web3: Web3, contract: Contract, function_name: str, from_account: AgentAccount, *fn_args +) -> TxReceipt: + """Execute a named function on a contract that requires a signature & gas""" + func_handle = contract.get_function_by_name(function_name)(*fn_args) + unsent_txn = func_handle.build_transaction( + { + "from": from_account.checksum_address, + "nonce": web3.eth.get_transaction_count(from_account.checksum_address), + } + ) + signed_txn = from_account.account.sign_transaction(unsent_txn) + tx_hash = web3.eth.send_raw_transaction(signed_txn.rawTransaction) + # wait for approval to complete + tx_receipt = web3.eth.wait_for_transaction_receipt(tx_hash) + return tx_receipt + + +def fetch_transactions_for_block(web3: Web3, contract: Contract, block_number: BlockNumber) -> list[Transaction]: + """ + Fetch transactions related to the contract + Returns the block pool info from the Hyperdrive contract + + Arguments + --------- + web3: Web3 + web3 provider object + hyperdrive_contract: Contract + The contract to query the pool info from + block_number: BlockNumber + The block number to query from the chain + + Returns + ------- + list[Transaction] + A list of Transaction objects ready to be inserted into Postgres + """ + block: BlockData = web3.eth.get_block(block_number, full_transactions=True) + transactions = block.get("transactions") + if not transactions: + logging.info("no transactions in block %s", block.get("number")) + return [] + out_transactions = [] + for transaction in transactions: + if isinstance(transaction, HexBytes): + logging.warning("transaction HexBytes") + continue + if transaction.get("to") != 
contract.address: + logging.warning("transaction not from contract") + continue + transaction_dict: dict[str, Any] = dict(transaction) + # Convert the HexBytes fields to their hex representation + tx_hash = transaction.get("hash") or HexBytes("") + transaction_dict["hash"] = tx_hash.hex() + # Decode the transaction input + try: + method, params = contract.decode_function_input(transaction["input"]) + transaction_dict["input"] = {"method": method.fn_name, "params": params} + except ValueError: # if the input is not meant for the contract, ignore it + continue + tx_receipt = web3.eth.get_transaction_receipt(tx_hash) + logs = _fetch_and_decode_logs(web3, contract, tx_receipt) + receipt: dict[str, Any] = _recursive_dict_conversion(tx_receipt) # type: ignore + out_transactions.append(_build_transaction_object(transaction_dict, logs, receipt)) + return out_transactions + + +def _get_name_and_type_from_abi(abi_outputs: ABIFunctionComponents | ABIFunctionParams) -> tuple[str, str]: + """Retrieve and narrow the types for abi outputs""" + return_value_name: str | None = abi_outputs.get("name") + if return_value_name is None: + return_value_name = "none" + return_value_type: str | None = abi_outputs.get("type") + if return_value_type is None: + return_value_type = "none" + return (return_value_name, return_value_type) + + +def _contract_function_abi_outputs(contract_abi: ABI, function_name: str) -> list[tuple[str, str]] | None: + """Parse the function abi to get the name and type for each output""" + function_abi = None + # find the first function matching the function_name + for abi in contract_abi: # loop over each entry in the abi list + if abi.get("name") == function_name: # check the name + function_abi = abi # pull out the one with the desired name + break + if function_abi is None: + logging.warning("could not find function_name=%s in contract abi", function_name) + return None + function_outputs = function_abi.get("outputs") + if function_outputs is None: + 
logging.warning("function abi does not specify outputs") + return None + if not isinstance(function_outputs, Sequence): # could be list or tuple + logging.warning("function abi outputs are not a sequence") + return None + if len(function_outputs) > 1: # multiple unnamed vars were returned + return_names_and_types = [] + for output in function_outputs: + return_names_and_types.append(_get_name_and_type_from_abi(output)) + if ( + function_outputs[0].get("type") == "tuple" and function_outputs[0].get("components") is not None + ): # multiple named outputs were returned in a struct + abi_components = function_outputs[0].get("components") + if abi_components is None: + logging.warning("function abi output componenets are not a included") + return None + return_names_and_types = [] + for component in abi_components: + return_names_and_types.append(_get_name_and_type_from_abi(component)) + else: # final condition is a single output + return_names_and_types = [_get_name_and_type_from_abi(function_outputs[0])] + return return_names_and_types + + +def _recursive_dict_conversion(obj): + """Recursively converts a dictionary to convert objects to hex values""" + if isinstance(obj, HexBytes): + return obj.hex() + if isinstance(obj, dict): + return {key: _recursive_dict_conversion(value) for key, value in obj.items()} + if hasattr(obj, "items"): + return {key: _recursive_dict_conversion(value) for key, value in obj.items()} + return obj + + +def _build_transaction_object( + transaction_dict: dict[str, Any], + logs: list[dict[str, Any]], + receipt: dict[str, Any], +) -> Transaction: + """ + Conversion function to translate output of chain queries to the Transaction object + + Arguments + ---------- + transaction_dict : dict[str, Any] + A dictionary representing the decoded transactions from the query + logs: list[str, Any] + A dictionary representing the decoded logs from the query + receipt: dict[str, Any] + A dictionary representing the transaction receipt from the query + + 
Returns + ------- + Transaction + A transaction object to be inserted into postgres + """ + # Build output obj dict incrementally to be passed into Transaction + # i.e., Transaction(**out_dict) + # Base transaction fields + out_dict: dict[str, Any] = { + "blockNumber": transaction_dict["blockNumber"], + "transactionIndex": transaction_dict["transactionIndex"], + "nonce": transaction_dict["nonce"], + "transactionHash": transaction_dict["hash"], + "txn_to": transaction_dict["to"], + "txn_from": transaction_dict["from"], + "gasUsed": receipt["gasUsed"], + } + # Input solidity methods and parameters + # TODO can the input field ever be empty or not exist? + out_dict["input_method"] = transaction_dict["input"]["method"] + input_params = transaction_dict["input"]["params"] + out_dict["input_params_contribution"] = convert_scaled_value(input_params.get("_contribution", None)) + out_dict["input_params_apr"] = convert_scaled_value(input_params.get("_apr", None)) + out_dict["input_params_destination"] = input_params.get("_destination", None) + out_dict["input_params_asUnderlying"] = input_params.get("_asUnderlying", None) + out_dict["input_params_baseAmount"] = convert_scaled_value(input_params.get("_baseAmount", None)) + out_dict["input_params_minOutput"] = convert_scaled_value(input_params.get("_minOutput", None)) + out_dict["input_params_bondAmount"] = convert_scaled_value(input_params.get("_bondAmount", None)) + out_dict["input_params_maxDeposit"] = convert_scaled_value(input_params.get("_maxDeposit", None)) + out_dict["input_params_maturityTime"] = input_params.get("_maturityTime", None) + out_dict["input_params_minApr"] = convert_scaled_value(input_params.get("_minApr", None)) + out_dict["input_params_maxApr"] = convert_scaled_value(input_params.get("_maxApr", None)) + out_dict["input_params_shares"] = convert_scaled_value(input_params.get("_shares", None)) + # Assuming one TransferSingle per transfer + # TODO Fix this below eventually + # There can be two transfer 
singles + # Currently grab first transfer single (e.g., Minting hyperdrive long, so address 0 to agent) + # Eventually need grabbing second transfer single (e.g., DAI from agent to hyperdrive) + event_logs = [log for log in logs if log["event"] == "TransferSingle"] + if len(event_logs) == 0: + event_args: dict[str, Any] = {} + # Set args as None + elif len(event_logs) == 1: + event_args: dict[str, Any] = event_logs[0]["args"] + else: + logging.warning("Tranfer event contains multiple TransferSingle logs, selecting first") + event_args: dict[str, Any] = event_logs[0]["args"] + out_dict["event_value"] = convert_scaled_value(event_args.get("value", None)) + out_dict["event_from"] = event_args.get("from", None) + out_dict["event_to"] = event_args.get("to", None) + out_dict["event_operator"] = event_args.get("operator", None) + out_dict["event_id"] = event_args.get("id", None) + # Decode logs here + if out_dict["event_id"] is not None: + event_prefix, event_maturity_time = hyperdrive_assets.decode_asset_id(out_dict["event_id"]) + out_dict["event_prefix"] = event_prefix + out_dict["event_maturity_time"] = event_maturity_time + transaction = Transaction(**out_dict) + return transaction + + +def _fetch_and_decode_logs(web3: Web3, contract: Contract, tx_receipt: TxReceipt) -> list[dict[Any, Any]]: + """Decode logs from a transaction receipt""" + logs = [] + if tx_receipt.get("logs"): + for log in tx_receipt["logs"]: + event_data, event = _get_event_object(web3, contract, log, tx_receipt) + if event_data and event: + formatted_log = dict(event_data) + formatted_log["event"] = event.get("name") + formatted_log["args"] = dict(event_data["args"]) + logs.append(formatted_log) + return logs + + +def _get_event_object( + web3: Web3, contract: Contract, log: LogReceipt, tx_receipt: TxReceipt +) -> tuple[EventData, ABIEvent] | tuple[None, None]: + """Retrieves the event object and anonymous types for a given contract and log""" + abi_events = [abi for abi in contract.abi if 
abi["type"] == "event"] # type: ignore + for event in abi_events: # type: ignore + # Get event signature components + name = event["name"] # type: ignore + inputs = [param["type"] for param in event["inputs"]] # type: ignore + inputs = ",".join(inputs) + # Hash event signature + event_signature_text = f"{name}({inputs})" + event_signature_hex = web3.keccak(text=event_signature_text).hex() + # Find match between log's event signature and ABI's event signature + receipt_event_signature_hex = log["topics"][0].hex() + if event_signature_hex == receipt_event_signature_hex: + # Decode matching log + contract_event: ContractEvent = contract.events[event["name"]]() # type: ignore + event_data: EventData = contract_event.process_receipt(tx_receipt)[0] + return event_data, event # type: ignore + return (None, None) diff --git a/elfpy/eth/web3_setup.py b/elfpy/eth/web3_setup.py new file mode 100644 index 0000000000..f24c1d34bd --- /dev/null +++ b/elfpy/eth/web3_setup.py @@ -0,0 +1,32 @@ +"""Functions and classes for setting up a web3py interface""" +from __future__ import annotations + +from eth_typing import URI +from web3 import Web3 +from web3.middleware import geth_poa + + +def initialize_web3_with_http_provider(ethereum_node: URI | str, request_kwargs: dict | None = None) -> Web3: + """Initialize a Web3 instance using an HTTP provider and inject a geth Proof of Authority (poa) middleware. + + Arguments + --------- + ethereum_node: URI | str + Address of the http provider + request_kwargs: dict + The HTTPProvider uses the python requests library for making requests. + If you would like to modify how requests are made, + you can use the request_kwargs to do so. + + Notes + ----- + The geth_poa_middleware is required to connect to geth --dev or the Goerli public network. + It may also be needed for other EVM compatible blockchains like Polygon or BNB Chain (Binance Smart Chain). + See more `here `_. 
+ """ + if request_kwargs is None: + request_kwargs = {} + provider = Web3.HTTPProvider(ethereum_node, request_kwargs) + web3 = Web3(provider) + web3.middleware_onion.inject(geth_poa.geth_poa_middleware, layer=0) + return web3 diff --git a/elfpy/hyperdrive_interface/__init__.py b/elfpy/hyperdrive_interface/__init__.py new file mode 100644 index 0000000000..dbf5f40e32 --- /dev/null +++ b/elfpy/hyperdrive_interface/__init__.py @@ -0,0 +1,10 @@ +"""Interfaces for elfpy bots and hyperdrive smart contracts""" + +from .hyperdrive_addresses import HyperdriveAddresses +from .hyperdrive_interface import ( + fetch_hyperdrive_address_from_url, + get_hyperdrive_contract, + get_hyperdrive_pool_info, + get_hyperdrive_config, + get_wallet_info, +) diff --git a/elfpy/hyperdrive_interface/hyperdrive_addresses.py b/elfpy/hyperdrive_interface/hyperdrive_addresses.py new file mode 100644 index 0000000000..44876fa63c --- /dev/null +++ b/elfpy/hyperdrive_interface/hyperdrive_addresses.py @@ -0,0 +1,15 @@ +"""Helper class for storing Hyperdrive addresses""" +from __future__ import annotations + +import attr + + +@attr.s +class HyperdriveAddresses: + """Addresses for deployed Hyperdrive contracts.""" + + # pylint: disable=too-few-public-methods + + base_token: str = attr.ib() + mock_hyperdrive: str = attr.ib() + mock_hyperdrive_math: str = attr.ib() diff --git a/elfpy/hyperdrive_interface/hyperdrive_interface.py b/elfpy/hyperdrive_interface/hyperdrive_interface.py new file mode 100644 index 0000000000..8d0acb1799 --- /dev/null +++ b/elfpy/hyperdrive_interface/hyperdrive_interface.py @@ -0,0 +1,259 @@ +"""Helper functions for interfacing with hyperdrive""" +from __future__ import annotations + +import logging +import re +import time + +from datetime import datetime +from typing import Any + +import requests + +from eth_typing import BlockNumber +from eth_utils import address +from fixedpointmath import FixedPoint +from web3.contract.contract import Contract +from web3 import Web3 +from 
web3.types import BlockData + +from elfpy import eth +from elfpy.data.db_schema import PoolConfig, PoolInfo, Transaction, WalletInfo +from elfpy.markets.hyperdrive import hyperdrive_assets + +from .hyperdrive_addresses import HyperdriveAddresses + +RETRY_COUNT = 10 + + +def fetch_hyperdrive_address_from_url(contracts_url: str) -> HyperdriveAddresses: + """Fetch addresses for deployed contracts in the Hyperdrive system.""" + attempt_num = 0 + response = None + while attempt_num < 100: + response = requests.get(contracts_url, timeout=60) + # Check the status code and retry the request if it fails + if response.status_code != 200: + logging.warning("Request failed with status code %s @ %s", response.status_code, time.ctime()) + time.sleep(10) + continue + attempt_num += 1 + if response is None: + raise ConnectionError("Request failed, returning status `None`") + if response.status_code != 200: + raise ConnectionError(f"Request failed with status code {response.status_code} @ {time.ctime()}") + addresses_json = response.json() + + def camel_to_snake(snake_string: str) -> str: + return re.sub(r"(? 
Contract: + """Get the hyperdrive contract given abis + + Arguments + --------- + web3: Web3 + web3 provider object + abis: dict + A dictionary that contains all abis keyed by the abi name, returned from `load_all_abis` + addresses: HyperdriveAddressesJson + The block number to query from the chain + + Returns + ------- + Contract + The contract object returned from the query + """ + if "IHyperdrive" not in abis: + raise AssertionError("IHyperdrive ABI was not provided") + state_abi = abis["IHyperdrive"] + # get contract instance of hyperdrive + hyperdrive_contract: Contract = web3.eth.contract( + address=address.to_checksum_address(addresses.mock_hyperdrive), abi=state_abi + ) + return hyperdrive_contract + + +def get_hyperdrive_pool_info(web3: Web3, hyperdrive_contract: Contract, block_number: BlockNumber) -> PoolInfo: + """ + Returns the block pool info from the Hyperdrive contract + + Arguments + --------- + web3: Web3 + web3 provider object + hyperdrive_contract: Contract + The contract to query the pool info from + block_number: BlockNumber + The block number to query from the chain + + Returns + ------- + PoolInfo + A PoolInfo object ready to be inserted into Postgres + """ + pool_info_data_dict = eth.smart_contract_read(hyperdrive_contract, "getPoolInfo", block_identifier=block_number) + pool_info_data_dict: dict[Any, Any] = { + key: eth.convert_scaled_value(value) for (key, value) in pool_info_data_dict.items() + } + current_block: BlockData = web3.eth.get_block(block_number) + current_block_timestamp = current_block.get("timestamp") + if current_block_timestamp is None: + raise AssertionError("Current block has no timestamp") + pool_info_data_dict.update({"timestamp": current_block_timestamp}) + pool_info_data_dict.update({"blockNumber": block_number}) + pool_info_dict = {} + for key in PoolInfo.__annotations__.keys(): + # Required keys + if key == "timestamp": + pool_info_dict[key] = datetime.fromtimestamp(pool_info_data_dict[key]) + elif key == 
"blockNumber": + pool_info_dict[key] = pool_info_data_dict[key] + # Otherwise default to None if not exist + else: + pool_info_dict[key] = pool_info_data_dict.get(key, None) + # Populating the dataclass from the dictionary + pool_info = PoolInfo(**pool_info_dict) + return pool_info + + +def get_hyperdrive_config(hyperdrive_contract: Contract) -> PoolConfig: + """Get the hyperdrive config from a deployed hyperdrive contract. + + Arguments + ---------- + hyperdrive_contract : Contract + The deployed hyperdrive contract instance. + + Returns + ------- + hyperdrive_config : PoolConfig + The hyperdrive config. + """ + hyperdrive_config: dict[str, Any] = eth.smart_contract_read(hyperdrive_contract, "getPoolConfig") + out_config = {} + out_config["contractAddress"] = hyperdrive_contract.address + out_config["baseToken"] = hyperdrive_config.get("baseToken", None) + out_config["initializeSharePrice"] = eth.convert_scaled_value(hyperdrive_config.get("initializeSharePrice", None)) + out_config["positionDuration"] = hyperdrive_config.get("positionDuration", None) + out_config["checkpointDuration"] = hyperdrive_config.get("checkpointDuration", None) + config_time_stretch = hyperdrive_config.get("timeStretch", None) + if config_time_stretch: + fp_time_stretch = FixedPoint(scaled_value=config_time_stretch) + time_stretch = float(fp_time_stretch) + inv_time_stretch = float(1 / fp_time_stretch) + else: + time_stretch = None + inv_time_stretch = None + out_config["timeStretch"] = time_stretch + out_config["governance"] = hyperdrive_config.get("governance", None) + out_config["feeCollector"] = hyperdrive_config.get("feeCollector", None) + curve_fee, flat_fee, governance_fee = hyperdrive_config.get("fees", (None, None, None)) + out_config["curveFee"] = eth.convert_scaled_value(curve_fee) + out_config["flatFee"] = eth.convert_scaled_value(flat_fee) + out_config["governanceFee"] = eth.convert_scaled_value(governance_fee) + out_config["oracleSize"] = hyperdrive_config.get("oracleSize", 
None) + out_config["updateGap"] = hyperdrive_config.get("updateGap", None) + out_config["invTimeStretch"] = inv_time_stretch + if out_config["positionDuration"] is not None: + term_length = out_config["positionDuration"] / 60 / 60 / 24 # in days + else: + term_length = None + out_config["termLength"] = term_length + return PoolConfig(**out_config) + + +def get_wallet_info( + hyperdrive_contract: Contract, + base_contract: Contract, + block_number: BlockNumber, + transactions: list[Transaction], +) -> list[WalletInfo]: + """Retrieves wallet information at a given block given a transaction + Transactions are needed here to get + (1) the wallet address of a transaction, and + (2) the token id of the transaction + + Arguments + ---------- + hyperdrive_contract : Contract + The deployed hyperdrive contract instance. + base_contract : Contract + The deployed base contract instance + block_number : BlockNumber + The block number to query + transactions : list[Transaction] + The list of transactions to get events from + + Returns + ------- + list[WalletInfo] + The list of WalletInfo objects ready to be inserted into postgres + """ + # pylint: disable=too-many-locals + out_wallet_info = [] + for transaction in transactions: + wallet_addr = transaction.event_operator + token_id = transaction.event_id + token_prefix = transaction.event_prefix + token_maturity_time = transaction.event_maturity_time + if wallet_addr is None: + continue + num_base_token_scaled = None + for _ in range(RETRY_COUNT): + try: + num_base_token_scaled = base_contract.functions.balanceOf(wallet_addr).call( + block_identifier=block_number + ) + break + except ValueError: + logging.warning("Error in getting base token balance, retrying") + time.sleep(1) + continue + num_base_token = eth.convert_scaled_value(num_base_token_scaled) + if (num_base_token is not None) and (wallet_addr is not None): + out_wallet_info.append( + WalletInfo( + blockNumber=block_number, + walletAddress=wallet_addr, + 
baseTokenType="BASE", + tokenType="BASE", + tokenValue=num_base_token, + ) + ) + # Handle cases where these fields don't exist + if (token_id is not None) and (token_prefix is not None): + base_token_type = hyperdrive_assets.AssetIdPrefix(token_prefix).name + if (token_maturity_time is not None) and (token_maturity_time > 0): + token_type = base_token_type + "-" + str(token_maturity_time) + maturity_time = token_maturity_time + else: + token_type = base_token_type + maturity_time = None + num_custom_token_scaled = None + for _ in range(RETRY_COUNT): + try: + num_custom_token_scaled = hyperdrive_contract.functions.balanceOf(int(token_id), wallet_addr).call( + block_identifier=block_number + ) + except ValueError: + logging.warning("Error in getting custom token balance, retrying") + time.sleep(1) + continue + num_custom_token = eth.convert_scaled_value(num_custom_token_scaled) + if num_custom_token is not None: + out_wallet_info.append( + WalletInfo( + blockNumber=block_number, + walletAddress=wallet_addr, + baseTokenType=base_token_type, + tokenType=token_type, + tokenValue=num_custom_token, + maturityTime=maturity_time, + ) + ) + return out_wallet_info diff --git a/elfpy/math/__init__.py b/elfpy/math/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/elfpy/simulators/__init__.py b/elfpy/simulators/__init__.py index fa30809dba..048495ac50 100644 --- a/elfpy/simulators/__init__.py +++ b/elfpy/simulators/__init__.py @@ -1,2 +1,11 @@ -"""trading and simulation modules""" -from .simulators import * +"""Trading and simulation modules""" +from .simulators import Simulator +from .config import Config +from .simulation_state import ( + SimulationState, + RunSimVariables, + DaySimVariables, + BlockSimVariables, + TradeSimVariables, + NewSimulationState, +) diff --git a/elfpy/simulators/simulation_state.py b/elfpy/simulators/simulation_state.py index eb08d10ff8..e3ed6b4b9b 100644 --- a/elfpy/simulators/simulation_state.py +++ 
b/elfpy/simulators/simulation_state.py @@ -79,29 +79,6 @@ def __setitem__(self, key, value): setattr(self, key, value) -def simulation_state_aggreagator(constructor): - """Returns a dataclass that aggregates simulation state attributes""" - # Wrap the type from the constructor in a list, but keep the name - attribs = [(str(key), list[val], field(default_factory=list)) for key, val in constructor.__annotations__.items()] - - # Make a new dataclass that has helper functions for appending to the list - def update(obj, dictionary): - for key, value in dictionary.items(): - obj.update_item(key, value) - - # The lambda is used because of the self variable -- TODO: can possibly remove? - # pylint: disable=unnecessary-lambda - aggregator = make_dataclass( - cls_name=constructor.__name__ + "Aggregator", - fields=attribs, - namespace={ - "update_item": lambda self, key, value: getattr(self, key).append(value), - "update": lambda self, dict_like: update(self, dict_like), - }, - )() - return aggregator - - @dataclass class RunSimVariables: """Simulation state variables that change by run""" @@ -188,10 +165,10 @@ class NewSimulationState: def __post_init__(self) -> None: r"""Construct empty dataclasses with appropriate attributes for each state variable type""" - self._run_updates = simulation_state_aggreagator(RunSimVariables) - self._day_updates = simulation_state_aggreagator(DaySimVariables) - self._block_updates = simulation_state_aggreagator(BlockSimVariables) - self._trade_updates = simulation_state_aggreagator(TradeSimVariables) + self._run_updates = _simulation_state_aggreagator(RunSimVariables) + self._day_updates = _simulation_state_aggreagator(DaySimVariables) + self._block_updates = _simulation_state_aggreagator(BlockSimVariables) + self._trade_updates = _simulation_state_aggreagator(TradeSimVariables) def update( self, @@ -237,3 +214,26 @@ def combined_dataframe(self) -> pd.DataFrame: with entries in the smaller dataframes duplicated accordingly """ return 
self.trade_updates.merge(self.block_updates.merge(self.day_updates.merge(self.run_updates))) + + +def _simulation_state_aggreagator(constructor): + """Returns a dataclass that aggregates simulation state attributes""" + # Wrap the type from the constructor in a list, but keep the name + attribs = [(str(key), list[val], field(default_factory=list)) for key, val in constructor.__annotations__.items()] + + # Make a new dataclass that has helper functions for appending to the list + def update(obj, dictionary): + for key, value in dictionary.items(): + obj.update_item(key, value) + + # The lambda is used because of the self variable -- TODO: can possibly remove? + # pylint: disable=unnecessary-lambda + aggregator = make_dataclass( + cls_name=constructor.__name__ + "Aggregator", + fields=attribs, + namespace={ + "update_item": lambda self, key, value: getattr(self, key).append(value), + "update": lambda self, dict_like: update(self, dict_like), + }, + )() + return aggregator diff --git a/elfpy/time/__init__.py b/elfpy/time/__init__.py index ce0fa5e7ae..199bdb8b8b 100644 --- a/elfpy/time/__init__.py +++ b/elfpy/time/__init__.py @@ -1,2 +1,8 @@ """Time and time related utilities""" -from .time import * +from .time import ( + TimeUnit, + BlockTime, + StretchedTime, + get_years_remaining, + days_to_time_remaining, +) diff --git a/elfpy/wallet/__init__.py b/elfpy/wallet/__init__.py new file mode 100644 index 0000000000..e69de29bb2