From c1a86667b93467dc2479394f005a8f009cf58970 Mon Sep 17 00:00:00 2001 From: Mihai Date: Mon, 17 Jul 2023 22:22:13 -0400 Subject: [PATCH 1/7] add checkpoints to db --- elfpy/data/acquire_data.py | 16 ++- elfpy/data/db_schema.py | 12 ++ elfpy/data/postgres.py | 79 ++++++++++- elfpy/hyperdrive_interface/__init__.py | 5 +- .../hyperdrive_interface.py | 33 +++++ tests/data/test_checkpoint.py | 129 ++++++++++++++++++ 6 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 tests/data/test_checkpoint.py diff --git a/elfpy/data/acquire_data.py b/elfpy/data/acquire_data.py index 6d51f6845b..ab51a6fbc6 100644 --- a/elfpy/data/acquire_data.py +++ b/elfpy/data/acquire_data.py @@ -177,11 +177,22 @@ def main( # Query and add block_pool_info pool_info_dict = hyperdrive_interface.get_hyperdrive_pool_info(web3, hyperdrive_contract, block_number) # Set defaults - for key in db_schema.PoolInfo.__annotations__.keys(): - if key not in pool_info_dict.keys(): + for key in db_schema.PoolInfo.__annotations__: + if key not in pool_info_dict: pool_info_dict[key] = None block_pool_info = db_schema.PoolInfo(**pool_info_dict) postgres.add_pool_infos([block_pool_info], session) + + # Query and add block_checkpoint_info + checkpoint_info_dict = hyperdrive_interface.get_hyperdrive_checkpoint_info( + web3, hyperdrive_contract, block_number + ) + # Set defaults + for key in db_schema.CheckpointInfo.__annotations__: + if key not in checkpoint_info_dict: + checkpoint_info_dict[key] = None + block_checkpoint_info = db_schema.CheckpointInfo(**checkpoint_info_dict) + postgres.add_checkpoint_infos([block_checkpoint_info], session) # Query and add block transactions block_transactions = db_schema.fetch_transactions_for_block(web3, hyperdrive_contract, block_number) postgres.add_transactions(block_transactions, session) @@ -225,6 +236,7 @@ def main( continue if block_pool_info: postgres.add_pool_infos([block_pool_info], session) + block_transactions = None for _ in range(RETRY_COUNT): try: 
diff --git a/elfpy/data/db_schema.py b/elfpy/data/db_schema.py index 3ab072502b..f9a04d47cd 100644 --- a/elfpy/data/db_schema.py +++ b/elfpy/data/db_schema.py @@ -84,6 +84,18 @@ class PoolConfig(Base): termLength: Mapped[Union[float, None]] = mapped_column(Numeric, default=None) +class CheckpointInfo(Base): + """Table/dataclass schema for checkpoint information""" + + __tablename__ = "checkpointinfo" + + blockNumber: Mapped[int] = mapped_column(BigInteger, primary_key=True) + timestamp: Mapped[datetime] = mapped_column(DateTime, index=True) + sharePrice: Mapped[Union[float, None]] = mapped_column(Numeric, default=None) + longSharePrice: Mapped[Union[float, None]] = mapped_column(Numeric, default=None) + shortBaseVolume: Mapped[Union[float, None]] = mapped_column(Numeric, default=None) + + class PoolInfo(Base): """ Table/dataclass schema for pool info diff --git a/elfpy/data/postgres.py b/elfpy/data/postgres.py index b7c28f2718..604957109b 100644 --- a/elfpy/data/postgres.py +++ b/elfpy/data/postgres.py @@ -11,7 +11,7 @@ from sqlalchemy import URL, create_engine, func from sqlalchemy.orm import Session, sessionmaker -from elfpy.data.db_schema import Base, PoolConfig, PoolInfo, Transaction, UserMap, WalletInfo +from elfpy.data.db_schema import Base, PoolConfig, PoolInfo, CheckpointInfo, Transaction, UserMap, WalletInfo # classes for sqlalchemy that define table schemas have no methods. 
# pylint: disable=too-few-public-methods @@ -59,6 +59,24 @@ def build_postgres_config() -> PostgresConfig: return PostgresConfig(**arg_dict) +def query_tables(session): + """Return a list of tables in the database.""" + stmt = text( + "SELECT tablename FROM pg_catalog.pg_tables " + "WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + ) + result = session.execute(stmt) # Execute the statement + tables = result.fetchall() # Fetch all the results + return tables + + +def drop_table(session, table_name): + """Drop a table from the database.""" + stmt = text(f"DROP TABLE IF EXISTS {table_name}") + session.execute(stmt) + session.commit() + + def initialize_session() -> Session: """Initialize the database if not already initialized""" @@ -180,6 +198,25 @@ def add_pool_infos(pool_infos: list[PoolInfo], session: Session) -> None: raise err +def add_checkpoint_infos(checkpoint_infos: list[CheckpointInfo], session: Session) -> None: + """Add checkpoint info to the checkpointinfo table + + Arguments + --------- + checkpoint_infos: list[Checkpoint] + A list of Checkpoint objects to insert into postgres + session: Session + The initialized session object + """ + for checkpoint_info in checkpoint_infos: + session.add(checkpoint_info) + try: + session.commit() + except sqlalchemy.exc.DataError as err: # type: ignore + session.rollback() + raise err + + def add_transactions(transactions: list[Transaction], session: Session) -> None: """Add transactions to the poolinfo table @@ -302,6 +339,46 @@ def get_pool_info(session: Session, start_block: int | None = None, end_block: i return pd.read_sql(query.statement, con=session.connection()).set_index("blockNumber") +def get_checkpoint_info(session: Session, start_block: int | None = None, end_block: int | None = None) -> pd.DataFrame: + """Gets all info associated with a given checkpoint. + + This includes + - `sharePrice` : The share price of the first transaction in the checkpoint. 
+ - `longSharePrice` : The weighted average of the share prices that all longs in the checkpoint were opened at. + - `shortBaseVolume` : The aggregate amount of base committed by LPs to pay for bonds sold short in the checkpoint. + + Arguments + --------- + session : Session + The initialized session object + block : int | None, optional + The block number whose checkpoint to return. If None, returns the most recent checkpoint. + + Returns + ------- + DataFrame + A DataFrame that consists of the queried checkpoint info + """ + + query = session.query(CheckpointInfo) + + # Support for negative indices + if (start_block is not None) and (start_block < 0): + start_block = get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + start_block + 1 + if (end_block is not None) and (end_block < 0): + end_block = get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + end_block + 1 + + if start_block is not None: + query = query.filter(CheckpointInfo.blockNumber >= start_block) + if end_block is not None: + query = query.filter(CheckpointInfo.blockNumber < end_block) + + # Always sort by time in order + query = query.order_by(CheckpointInfo.timestamp) + + return pd.read_sql(query.statement, con=session.connection()).set_index("blockNumber") + + def get_transactions(session: Session, start_block: int | None = None, end_block: int | None = None) -> pd.DataFrame: """ Gets all transactions and returns as a pandas dataframe diff --git a/elfpy/hyperdrive_interface/__init__.py b/elfpy/hyperdrive_interface/__init__.py index fce3bad67b..aa7991b60c 100644 --- a/elfpy/hyperdrive_interface/__init__.py +++ b/elfpy/hyperdrive_interface/__init__.py @@ -3,8 +3,9 @@ from .hyperdrive_addresses import HyperdriveAddresses from .hyperdrive_interface import ( fetch_hyperdrive_address_from_url, - get_hyperdrive_config, get_hyperdrive_contract, - get_hyperdrive_market, get_hyperdrive_pool_info, + get_hyperdrive_checkpoint_info, + get_hyperdrive_config, + 
get_hyperdrive_market, ) diff --git a/elfpy/hyperdrive_interface/hyperdrive_interface.py b/elfpy/hyperdrive_interface/hyperdrive_interface.py index eec3ac35cc..c6f29ab9a2 100644 --- a/elfpy/hyperdrive_interface/hyperdrive_interface.py +++ b/elfpy/hyperdrive_interface/hyperdrive_interface.py @@ -119,6 +119,39 @@ def get_hyperdrive_pool_info(web3: Web3, hyperdrive_contract: Contract, block_nu return pool_info +def get_hyperdrive_checkpoint_info( + web3: Web3, hyperdrive_contract: Contract, block_number: BlockNumber +) -> dict[str, Any]: + """Returns the checkpoint info of Hyperdrive contract for the given block. + + Arguments + --------- + web3: Web3 + web3 provider object + hyperdrive_contract: Contract + The contract to query the checkpoint info from + block_number: BlockNumber + The block number to query from the chain + + Returns + ------- + Checkpoint + A Checkpoint object ready to be inserted into Postgres + """ + current_block: BlockData = web3.eth.get_block(block_number) + current_block_timestamp = current_block.get("timestamp") + if current_block_timestamp is None: + raise AssertionError("Current block has no timestamp") + checkpoint_data: dict[str, int] = eth.smart_contract_read(hyperdrive_contract, "getCheckpoint", block_number) + return { + "blockNumber": int(block_number), + "timestamp": datetime.fromtimestamp(current_block_timestamp), + "sharePrice": eth.convert_scaled_value(checkpoint_data["sharePrice"]), + "longSharePrice": eth.convert_scaled_value(checkpoint_data["longSharePrice"]), + "shortBaseVolume": eth.convert_scaled_value(checkpoint_data["shortBaseVolume"]), + } + + def get_hyperdrive_config(hyperdrive_contract: Contract) -> dict[str, Any]: """Get the hyperdrive config from a deployed hyperdrive contract. 
diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py new file mode 100644 index 0000000000..e1c14d04f6 --- /dev/null +++ b/tests/data/test_checkpoint.py @@ -0,0 +1,129 @@ +"""CRUD tests for CheckpointInfo""" +from datetime import datetime + +import pytest +import numpy as np +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from elfpy.data import postgres +from elfpy.data.db_schema import Base, CheckpointInfo + +engine = create_engine("sqlite:///:memory:") # in-memory SQLite database for testing +Session = sessionmaker(bind=engine) + +# fixture arguments in test function have to be the same as the fixture name +# pylint: disable=redefined-outer-name + + +@pytest.fixture(scope="function") +def session(): + """Session fixture for tests""" + Base.metadata.create_all(engine) # create tables + session_ = Session() + yield session_ + session_.close() + Base.metadata.drop_all(engine) # drop tables + + +class TestCheckpointTable: + """CRUD tests for checkpoint table""" + + def test_create_checkpoint(self, session): + """Create and entry""" + # Note: this test is using inmemory sqlite, which doesn't seem to support + # autoincrementing ids without init, whereas postgres does this with no issues + # Hence, we explicitly add id here + timestamp = datetime.now() + checkpoint = CheckpointInfo(blockNumber=1, timestamp=timestamp) + postgres.add_checkpoint_infos([checkpoint], session) + session.commit() + + retrieved_checkpoint = session.query(CheckpointInfo).filter_by(blockNumber=1).first() + assert retrieved_checkpoint is not None + # event_value retreieved from postgres is in Decimal, cast to float + assert retrieved_checkpoint.timestamp == timestamp + + def test_update_checkpoint(self, session): + """Update an entry""" + timestamp = datetime.now() + checkpoint = CheckpointInfo(blockNumber=1, timestamp=timestamp) + postgres.add_checkpoint_infos([checkpoint], session) + session.commit() + + checkpoint.sharePrice = 5.0 + 
session.commit() + + updated_checkpoint = session.query(CheckpointInfo).filter_by(blockNumber=1).first() + # event_value retreieved from postgres is in Decimal, cast to float + assert updated_checkpoint.sharePrice == 5.0 + + def test_delete_checkpoint(self, session): + """Delete an entry""" + timestamp = datetime.now() + checkpoint = CheckpointInfo(blockNumber=1, timestamp=timestamp) + postgres.add_checkpoint_infos([checkpoint], session) + session.commit() + + session.delete(checkpoint) + session.commit() + + deleted_checkpoint = session.query(CheckpointInfo).filter_by(blockNumber=1).first() + assert deleted_checkpoint is None + + +class TestCheckpointInterface: + """Testing postgres interface for checkpoint table""" + + def test_latest_block_number(self, session): + """Testing retrieval of checkpoint via interface""" + checkpoint_1 = CheckpointInfo(blockNumber=1, timestamp=datetime.now()) + postgres.add_checkpoint_infos([checkpoint_1], session) + session.commit() + + latest_block_number = postgres.get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + assert latest_block_number == 1 + + checkpoint_2 = CheckpointInfo(blockNumber=2, timestamp=datetime.now()) + checkpoint_3 = CheckpointInfo(blockNumber=3, timestamp=datetime.now()) + postgres.add_checkpoint_infos([checkpoint_2, checkpoint_3], session) + + latest_block_number = postgres.get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + assert latest_block_number == 3 + + def test_get_checkpoints(self, session): + """Testing retrieval of checkpoints via interface""" + date_1 = datetime(1945, 8, 6) + date_2 = datetime(1984, 8, 9) + date_3 = datetime(2001, 9, 11) + checkpoint_1 = CheckpointInfo(blockNumber=0, timestamp=date_1) + checkpoint_2 = CheckpointInfo(blockNumber=1, timestamp=date_2) + checkpoint_3 = CheckpointInfo(blockNumber=2, timestamp=date_3) + postgres.add_checkpoint_infos([checkpoint_1, checkpoint_2, checkpoint_3], session) + + checkpoints_df = 
postgres.get_checkpoint_info(session) + np.testing.assert_array_equal( + checkpoints_df["timestamp"].dt.to_pydatetime(), np.array([date_1, date_2, date_3]) + ) + + def test_block_query_checkpoints(self, session): + """Testing querying by block number of checkpoints via interface""" + checkpoint_1 = CheckpointInfo(blockNumber=0, timestamp=datetime.now(), sharePrice=3.1) + checkpoint_2 = CheckpointInfo(blockNumber=1, timestamp=datetime.now(), sharePrice=3.2) + checkpoint_3 = CheckpointInfo(blockNumber=2, timestamp=datetime.now(), sharePrice=3.3) + postgres.add_checkpoint_infos([checkpoint_1, checkpoint_2, checkpoint_3], session) + + checkpoints_df = postgres.get_checkpoint_info(session, start_block=1) + np.testing.assert_array_equal(checkpoints_df["sharePrice"], [3.2, 3.3]) + + checkpoints_df = postgres.get_checkpoint_info(session, start_block=-1) + np.testing.assert_array_equal(checkpoints_df["sharePrice"], [3.3]) + + checkpoints_df = postgres.get_checkpoint_info(session, end_block=1) + np.testing.assert_array_equal(checkpoints_df["sharePrice"], [3.1]) + + checkpoints_df = postgres.get_checkpoint_info(session, end_block=-1) + np.testing.assert_array_equal(checkpoints_df["sharePrice"], [3.1, 3.2]) + + checkpoints_df = postgres.get_checkpoint_info(session, start_block=1, end_block=-1) + np.testing.assert_array_equal(checkpoints_df["sharePrice"], [3.2]) From ef06b41baef9c73b7c480b15af8abbc36248b062 Mon Sep 17 00:00:00 2001 From: Mihai Date: Tue, 18 Jul 2023 18:15:32 -0400 Subject: [PATCH 2/7] fixes --- elfpy/data/postgres.py | 2 +- examples/hackweek_demo/extract_data_logs.py | 6 +----- examples/hackweek_demo/run_demo.py | 3 +-- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/elfpy/data/postgres.py b/elfpy/data/postgres.py index 604957109b..7fb098db01 100644 --- a/elfpy/data/postgres.py +++ b/elfpy/data/postgres.py @@ -470,7 +470,7 @@ def get_wallet_info_history(session: Session) -> dict[str, pd.DataFrame]: # Get data all_wallet_info = 
get_all_wallet_info(session) - pool_info_lookup = get_pool_info(session)[["blockNumber", "timestamp", "sharePrice"]].set_index("blockNumber") + pool_info_lookup = get_pool_info(session)[["timestamp", "sharePrice"]] # Pivot tokenType to columns, keeping walletAddress and blockNumber all_wallet_info = all_wallet_info.pivot( diff --git a/examples/hackweek_demo/extract_data_logs.py b/examples/hackweek_demo/extract_data_logs.py index fea3b7e862..17206a2e66 100644 --- a/examples/hackweek_demo/extract_data_logs.py +++ b/examples/hackweek_demo/extract_data_logs.py @@ -60,7 +60,7 @@ def get_combined_data(txn_data, pool_info_data): pool_info_data.index = pool_info_data.index.astype(int) # txn_data.index = txn_data["blockNumber"] # Combine pool info data and trans data by block number - data = txn_data.merge(pool_info_data) + data = txn_data.merge(pool_info_data, on="blockNumber") rename_dict = { "event_operator": "operator", @@ -71,7 +71,6 @@ def get_combined_data(txn_data, pool_info_data): "event_maturity_time": "maturity_time", "event_value": "value", "bondReserves": "bond_reserves", - "blockNumber": "block_number", "input_method": "trade_type", "longsOutstanding": "longs_outstanding", "longAverageMaturityTime": "longs_average_maturity_time", @@ -92,9 +91,6 @@ def get_combined_data(txn_data, pool_info_data): # Rename columns trade_data = trade_data.rename(columns=rename_dict) - # TODO: Fix this -- will break if we allow multiple trades per block - trade_data.index = trade_data["block_number"] - # Calculate trade type and timetsamp from args.id def decode_prefix(row): # Check for nans diff --git a/examples/hackweek_demo/run_demo.py b/examples/hackweek_demo/run_demo.py index 508b145f7e..03cf04aa8c 100644 --- a/examples/hackweek_demo/run_demo.py +++ b/examples/hackweek_demo/run_demo.py @@ -55,8 +55,7 @@ def get_ticker(data): """Given transaction data, return a subset of the dataframe""" # Return reverse of methods to put most recent transactions at the top - out = 
data[["blockNumber", "input_method"]].set_index("blockNumber").iloc[::-1] - return out + return data[["input_method"]].iloc[::-1] agent_list = postgres.get_agents(session, start_block=start_block) From 7df638782c7623b6d12dfa6fd16dae7ce010fa41 Mon Sep 17 00:00:00 2001 From: Mihai Date: Tue, 18 Jul 2023 19:34:46 -0400 Subject: [PATCH 3/7] incorporate comments --- elfpy/data/acquire_data.py | 11 +- elfpy/data/db_schema.py | 4 - elfpy/data/postgres.py | 221 ++++++++++++++++++++++------------ tests/data/test_checkpoint.py | 6 +- tests/data/test_db_utils.py | 63 ++++++++++ 5 files changed, 212 insertions(+), 93 deletions(-) create mode 100644 tests/data/test_db_utils.py diff --git a/elfpy/data/acquire_data.py b/elfpy/data/acquire_data.py index ab51a6fbc6..82baed9acc 100644 --- a/elfpy/data/acquire_data.py +++ b/elfpy/data/acquire_data.py @@ -177,6 +177,7 @@ def main( # Query and add block_pool_info pool_info_dict = hyperdrive_interface.get_hyperdrive_pool_info(web3, hyperdrive_contract, block_number) # Set defaults + # TODO: abstract this out: pull the conversion between the interface to the db object into various functions for key in db_schema.PoolInfo.__annotations__: if key not in pool_info_dict: pool_info_dict[key] = None @@ -216,8 +217,7 @@ def main( latest_mined_block, ) continue - # get_block_pool_info crashes randomly with ValueError on some intermediate block, - # keep trying until it returns + # keep querying until it returns to avoid random crashes with ValueError on some intermediate block block_pool_info = None for _ in range(RETRY_COUNT): try: @@ -225,18 +225,19 @@ def main( web3, hyperdrive_contract, block_number ) # Set defaults - for key in db_schema.PoolInfo.__annotations__.keys(): - if key not in pool_info_dict.keys(): + for key in db_schema.PoolInfo.__annotations__: + if key not in pool_info_dict: pool_info_dict[key] = None block_pool_info = db_schema.PoolInfo(**pool_info_dict) break except ValueError: - logging.warning("Error in get_block_pool_info, 
retrying") + logging.warning("Error in get_hyperdrive_pool_info, retrying") time.sleep(1) continue if block_pool_info: postgres.add_pool_infos([block_pool_info], session) + # keep querying until it returns to avoid random crashes with ValueError on some intermediate block block_transactions = None for _ in range(RETRY_COUNT): try: diff --git a/elfpy/data/db_schema.py b/elfpy/data/db_schema.py index f9a04d47cd..37e77cd84e 100644 --- a/elfpy/data/db_schema.py +++ b/elfpy/data/db_schema.py @@ -62,8 +62,6 @@ class PoolConfig(Base): Table/dataclass schema for pool config """ - # pylint: disable=too-many-instance-attributes - __tablename__ = "poolconfig" contractAddress: Mapped[str] = mapped_column(String, primary_key=True) @@ -102,8 +100,6 @@ class PoolInfo(Base): Mapped class that is a data class on the python side, and an declarative base on the sql side. """ - # pylint: disable=too-many-instance-attributes - __tablename__ = "poolinfo" blockNumber: Mapped[int] = mapped_column(BigInteger, primary_key=True) diff --git a/elfpy/data/postgres.py b/elfpy/data/postgres.py index 7fb098db01..1a21feb71c 100644 --- a/elfpy/data/postgres.py +++ b/elfpy/data/postgres.py @@ -8,7 +8,7 @@ import pandas as pd import sqlalchemy -from sqlalchemy import URL, create_engine, func +from sqlalchemy import URL, create_engine, func, create_engine, inspect, MetaData, Table, exc from sqlalchemy.orm import Session, sessionmaker from elfpy.data.db_schema import Base, PoolConfig, PoolInfo, CheckpointInfo, Transaction, UserMap, WalletInfo @@ -16,13 +16,26 @@ # classes for sqlalchemy that define table schemas have no methods. # pylint: disable=too-few-public-methods -# replace the user, password, and db_name with credentials -# TODO remove engine as global - @dataclass class PostgresConfig: - """The configuration dataclass for postgress connections""" + """The configuration dataclass for postgress connections. + + Replace the user, password, and db_name with the credentials of your setup. 
+ + Attributes + ---------- + POSTGRES_USER : str + The username to authenticate with + POSTGRES_PASSWORD : str + The password to authenticate with + POSTGRES_DB : str + The name of the database + POSTGRES_HOST : str + The hostname to connect to + POSTGRES_PORT : int + The port to connect to + """ # default values for local postgres # Matching environemnt variables to search for @@ -35,8 +48,14 @@ class PostgresConfig: def build_postgres_config() -> PostgresConfig: - """Build a PostgresConfig that looks for environmental variables + """Build a PostgresConfig that looks for environmental variables. + If env var exists, use that, otherwise, default + + Returns + ------- + config : PostgresConfig + Config settings required to connect to and use the database """ user = os.getenv("POSTGRES_USER") password = os.getenv("POSTGRES_PASSWORD") @@ -59,26 +78,50 @@ return PostgresConfig(**arg_dict) -def query_tables(session): - """Return a list of tables in the database.""" - stmt = text( - "SELECT tablename FROM pg_catalog.pg_tables " - "WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" - ) - result = session.execute(stmt) # Execute the statement - tables = result.fetchall() # Fetch all the results - return tables +def query_tables(session: Session) -> list[str]: + """Return a list of tables in the database. + + Arguments + --------- + session : Session + The initialized session object + Returns + ------- + table_names : list[str] + A list of table names in the database + """ + inspector = inspect(session.bind) # nice gadget + assert inspector is not None, "inspector is None" + return inspector.get_table_names() -def drop_table(session, table_name): - """Drop a table from the database.""" - stmt = text(f"DROP TABLE IF EXISTS {table_name}") - session.execute(stmt) - session.commit() + +def drop_table(session: Session, table_name: str) -> None: + """Drop a table from the database. 
+ + Arguments + --------- + session : Session + The initialized session object + table_names : str + The name of the table to be dropped + """ + metadata = MetaData() + table = Table(table_name, metadata) + bind = session.bind + assert isinstance(bind, sqlalchemy.engine.base.Engine), "bind is not an engine" + # checkfirst=true automatically adds an "IF EXISTS" clause + table.drop(checkfirst=True, bind=bind) def initialize_session() -> Session: - """Initialize the database if not already initialized""" + """Initialize the database if not already initialized. + + Returns + ------- + session : Session + The initialized session object + """ postgres_config = build_postgres_config() @@ -108,7 +151,7 @@ def initialize_session() -> Session: def close_session(session: Session) -> None: - """Close the session + """Close the session. Arguments --------- @@ -119,7 +162,8 @@ def close_session(session: Session) -> None: def add_wallet_infos(wallet_infos: list[WalletInfo], session: Session) -> None: - """Add wallet info to the walletinfo table + """Add wallet info to the walletinfo table. + Arguments --------- wallet_infos: list[WalletInfo] @@ -131,22 +175,22 @@ def add_wallet_infos(wallet_infos: list[WalletInfo], session: Session) -> None: session.add(wallet_info) try: session.commit() - except sqlalchemy.exc.DataError as err: # type: ignore + except exc.DataError as err: session.rollback() print(f"{wallet_infos=}") raise err def add_pool_config(pool_config: PoolConfig, session: Session) -> None: - """ - Add pool config to the pool config table if not exist - Verify pool config if it does exist + """Add pool config to the pool config table if not exist. + + Verify pool config if it does exist. 
Arguments --------- - pool_config: PoolConfig + pool_config : PoolConfig A PoolConfig object to insert into postgres - session: Session + session : Session The initialized session object """ @@ -160,7 +204,7 @@ def add_pool_config(pool_config: PoolConfig, session: Session) -> None: session.add(pool_config) try: session.commit() - except sqlalchemy.exc.DataError as err: # type: ignore + except exc.DataError as err: session.rollback() print(f"{pool_config=}") raise err @@ -179,20 +223,20 @@ def add_pool_config(pool_config: PoolConfig, session: Session) -> None: def add_pool_infos(pool_infos: list[PoolInfo], session: Session) -> None: - """Add a pool info to the poolinfo table + """Add a pool info to the poolinfo table. Arguments --------- - pool_infos: list[PoolInfo] + pool_infos : list[PoolInfo] A list of PoolInfo objects to insert into postgres - session: Session + session : Session The initialized session object """ for pool_info in pool_infos: session.add(pool_info) try: session.commit() - except sqlalchemy.exc.DataError as err: # type: ignore + except exc.DataError as err: session.rollback() print(f"{pool_infos=}") raise err @@ -203,50 +247,50 @@ def add_checkpoint_infos(checkpoint_infos: list[CheckpointInfo], session: Sessio Arguments --------- - checkpoint_infos: list[Checkpoint] + checkpoint_infos : list[Checkpoint] A list of Checkpoint objects to insert into postgres - session: Session + session : Session The initialized session object """ for checkpoint_info in checkpoint_infos: session.add(checkpoint_info) try: session.commit() - except sqlalchemy.exc.DataError as err: # type: ignore + except exc.DataError as err: session.rollback() raise err def add_transactions(transactions: list[Transaction], session: Session) -> None: - """Add transactions to the poolinfo table + """Add transactions to the poolinfo table. 
Arguments --------- - transactions: list[Transaction] + transactions : list[Transaction] A list of Transaction objects to insert into postgres - session: Session + session : Session The initialized session object """ for transaction in transactions: session.add(transaction) try: session.commit() - except sqlalchemy.exc.DataError as err: # type: ignore + except exc.DataError as err: session.rollback() print(f"{transactions=}") raise err def add_user_map(username: str, addresses: list[str], session: Session) -> None: - """Add username mapping to postgres during evm_bots initialization + """Add username mapping to postgres during evm_bots initialization. Arguments --------- - username: str + username : str The logical username to attach to the wallet address - addresses: list[str] + addresses : list[str] A list of wallet addresses to map to the username - session: Session + session : Session The initialized session object """ @@ -272,20 +316,19 @@ def add_user_map(username: str, addresses: list[str], session: Session) -> None: try: session.commit() - except sqlalchemy.exc.DataError as err: # type: ignore + except exc.DataError as err: print(f"{username=}, {addresses=}") raise err def get_pool_config(session: Session, contract_address: str | None = None) -> pd.DataFrame: - """ - Gets all pool config and returns as a pandas dataframe + """Gets all pool config and returns as a pandas dataframe. Arguments --------- - session: Session + session : Session The initialized session object - contract_address: str | None + contract_address : str | None, optional The contract_address to filter the results on. Return all if None Returns @@ -300,17 +343,16 @@ def get_pool_config(session: Session, contract_address: str | None = None) -> pd def get_pool_info(session: Session, start_block: int | None = None, end_block: int | None = None) -> pd.DataFrame: - """ - Gets all pool info and returns as a pandas dataframe + """Gets all pool info and returns as a pandas dataframe. 
Arguments --------- - session: Session + session : Session The initialized session object - start_block: int | None + start_block : int | None, optional The starting block to filter the query on. start_block integers matches python slicing notation, e.g., list[:3], list[:-3] - end_block: int | None + end_block : int | None, optional The ending block to filter the query on. end_block integers matches python slicing notation, e.g., list[:3], list[:-3] @@ -380,17 +422,16 @@ def get_checkpoint_info(session: Session, start_block: int | None = None, end_bl def get_transactions(session: Session, start_block: int | None = None, end_block: int | None = None) -> pd.DataFrame: - """ - Gets all transactions and returns as a pandas dataframe + """Gets all transactions and returns as a pandas dataframe. Arguments --------- - session: Session + session : Session The initialized session object - start_block: int | None + start_block : int | None The starting block to filter the query on. start_block integers matches python slicing notation, e.g., list[:3], list[:-3] - end_block: int | None + end_block : int | None The ending block to filter the query on. end_block integers matches python slicing notation, e.g., list[:3], list[:-3] @@ -417,17 +458,16 @@ def get_transactions(session: Session, start_block: int | None = None, end_block def get_all_wallet_info(session: Session, start_block: int | None = None, end_block: int | None = None) -> pd.DataFrame: - """ - Gets all of the wallet_info data in history and returns as a pandas dataframe + """Gets all of the wallet_info data in history and returns as a pandas dataframe. Arguments --------- - session: Session + session : Session The initialized session object - start_block: int | None + start_block : int | None, optional The starting block to filter the query on. 
start_block integers matches python slicing notation, e.g., list[:3], list[:-3] - end_block: int | None + end_block : int | None, optional The ending block to filter the query on. end_block integers matches python slicing notation, e.g., list[:3], list[:-3] @@ -454,10 +494,11 @@ def get_all_wallet_info(session: Session, start_block: int | None = None, end_bl def get_wallet_info_history(session: Session) -> dict[str, pd.DataFrame]: - """Gets the history of all wallet info over block time + """Gets the history of all wallet info over block time. + Arguments --------- - session: Session + session : Session The initialized session object Returns @@ -507,12 +548,12 @@ def get_current_wallet_info( Arguments --------- - session: Session + session : Session The initialized session object - start_block: int | None + start_block : int | None, optional The starting block to filter the query on. start_block integers matches python slicing notation, e.g., list[:3], list[:-3] - end_block: int | None + end_block : int | None, optional The ending block to filter the query on. 
end_block integers matches python slicing notation, e.g., list[:3], list[:-3] @@ -526,7 +567,7 @@ def get_current_wallet_info( # Get last entry in the table of each wallet address and token type # This should always return a dataframe # Pandas doesn't play nice with types - current_wallet_info: pd.DataFrame = ( + result = ( all_wallet_info.sort_values("blockNumber", ascending=False) .groupby(["walletAddress", "tokenType"]) .agg( @@ -538,7 +579,9 @@ def get_current_wallet_info( "sharePrice": "first", } ) - ) # type: ignore + ) + assert isinstance(result, pd.DataFrame), "result is not a dataframe" + current_wallet_info: pd.DataFrame = result # Rename blockNumber column current_wallet_info = current_wallet_info.rename({"blockNumber": "latestUpdateBlock"}, axis=1) @@ -549,7 +592,24 @@ def get_current_wallet_info( def get_agents(session: Session, start_block: int | None = None, end_block: int | None = None) -> list[str]: - """Gets the list of all agents from the WalletInfo table""" + """Gets the list of all agents from the WalletInfo table. + + Arguments + --------- + session : Session + The initialized session object + start_block : int | None, optional + The starting block to filter the query on. start_block integers + matches python slicing notation, e.g., list[:3], list[:-3] + end_block : int | None, optional + The ending block to filter the query on. end_block integers + matches python slicing notation, e.g., list[:3], list[:-3] + + Returns + ------- + list[str] + A list of agent addresses + """ query = session.query(WalletInfo.walletAddress) # Support for negative indices if (start_block is not None) and (start_block < 0): @@ -572,14 +632,13 @@ def get_agents(session: Session, start_block: int | None = None, end_block: int def get_user_map(session: Session, address: str | None = None) -> pd.DataFrame: - """ - Gets all usermapping and returns as a pandas dataframe + """Gets all usermapping and returns as a pandas dataframe. 
Arguments --------- - session: Session + session : Session The initialized session object - address: str | None + address : str | None, optional The wallet address to filter the results on. Return all if None Returns @@ -598,7 +657,7 @@ def get_latest_block_number(session: Session) -> int: Arguments --------- - session: Session + session : Session The initialized session object Returns @@ -609,12 +668,14 @@ def get_latest_block_number(session: Session) -> int: return get_latest_block_number_from_table(PoolInfo, session) -def get_latest_block_number_from_table(table_obj: Type[WalletInfo | PoolInfo | Transaction], session: Session) -> int: +def get_latest_block_number_from_table( + table_obj: Type[WalletInfo | PoolInfo | Transaction | CheckpointInfo], session: Session +) -> int: """Gets the latest block number based on the specified table in the db. Arguments --------- - table : Type[WalletInfo | PoolInfo | Transaction] + table_obj : Type[WalletInfo | PoolInfo | Transaction | CheckpointInfo] The sqlalchemy class that contains the blockNumber column session : Session The initialized session object diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index e1c14d04f6..881de48b95 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -41,7 +41,6 @@ def test_create_checkpoint(self, session): retrieved_checkpoint = session.query(CheckpointInfo).filter_by(blockNumber=1).first() assert retrieved_checkpoint is not None - # event_value retreieved from postgres is in Decimal, cast to float assert retrieved_checkpoint.timestamp == timestamp def test_update_checkpoint(self, session): @@ -55,7 +54,6 @@ def test_update_checkpoint(self, session): session.commit() updated_checkpoint = session.query(CheckpointInfo).filter_by(blockNumber=1).first() - # event_value retreieved from postgres is in Decimal, cast to float assert updated_checkpoint.sharePrice == 5.0 def test_delete_checkpoint(self, session): @@ -81,14 +79,14 @@ def 
test_latest_block_number(self, session): postgres.add_checkpoint_infos([checkpoint_1], session) session.commit() - latest_block_number = postgres.get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + latest_block_number = postgres.get_latest_block_number_from_table(CheckpointInfo, session) assert latest_block_number == 1 checkpoint_2 = CheckpointInfo(blockNumber=2, timestamp=datetime.now()) checkpoint_3 = CheckpointInfo(blockNumber=3, timestamp=datetime.now()) postgres.add_checkpoint_infos([checkpoint_2, checkpoint_3], session) - latest_block_number = postgres.get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + latest_block_number = postgres.get_latest_block_number_from_table(CheckpointInfo, session) assert latest_block_number == 3 def test_get_checkpoints(self, session): diff --git a/tests/data/test_db_utils.py b/tests/data/test_db_utils.py new file mode 100644 index 0000000000..5d3c02e8bb --- /dev/null +++ b/tests/data/test_db_utils.py @@ -0,0 +1,63 @@ +"""CRUD tests for CheckpointInfo""" +from datetime import datetime + +import pytest +import numpy as np +from sqlalchemy import create_engine, MetaData, String +from sqlalchemy.orm import sessionmaker + +from elfpy.data import postgres +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column + +engine = create_engine("sqlite:///:memory:") # in-memory SQLite database for testing +Session = sessionmaker(bind=engine) + +# fixture arguments in test function have to be the same as the fixture name +# pylint: disable=redefined-outer-name + + +class Based(MappedAsDataclass, DeclarativeBase): + """Base class to subclass from to define the schema""" + + +class Very(Based): + """Dummy but very sincere table schema.""" + + __tablename__ = "verybased" + + key: Mapped[str] = mapped_column(String, primary_key=True) + + +class DropMe(Based): + """Dummy table schema that wants to be dropped.""" + + __tablename__ = "dropme" + + key: Mapped[str] = 
mapped_column(String, primary_key=True) + + +@pytest.fixture(scope="function") +def session(): + """Session fixture for tests""" + Based.metadata.create_all(engine) # create tables + session_ = Session() + yield session_ + session_.close() + Based.metadata.drop_all(engine) # drop tables + + +def test_query_tables(session): + """Return a list of tables in the database.""" + table_names = postgres.query_tables(session) + session.commit() + + np.testing.assert_array_equal(table_names, ["dropme", "verybased"]) + + +def test_drop_table(session): + """Drop a table from the database.""" + postgres.drop_table(session, "dropme") + table_names = postgres.query_tables(session) + session.commit() + + np.testing.assert_array_equal(table_names, ["verybased"]) From 390178d11b17f59a5ffe9b0b663b3a3a928eec7f Mon Sep 17 00:00:00 2001 From: Mihai Date: Tue, 18 Jul 2023 19:37:10 -0400 Subject: [PATCH 4/7] lint a lil with ruff --- elfpy/data/acquire_data.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/elfpy/data/acquire_data.py b/elfpy/data/acquire_data.py index 82baed9acc..608d49df42 100644 --- a/elfpy/data/acquire_data.py +++ b/elfpy/data/acquire_data.py @@ -1,4 +1,4 @@ -"""Script to format on-chain hyperdrive pool, config, and transaction data post-processing""" +"""Script to format on-chain hyperdrive pool, config, and transaction data post-processing.""" from __future__ import annotations import logging @@ -31,13 +31,14 @@ def get_wallet_info( transactions: list[db_schema.Transaction], pool_info: db_schema.PoolInfo, ) -> list[db_schema.WalletInfo]: - """Retrieves wallet information at a given block given a transaction + """Retrieve wallet information at a given block given a transaction. + Transactions are needed here to get (1) the wallet address of a transaction, and (2) the token id of the transaction Arguments - ---------- + --------- hyperdrive_contract : Contract The deployed hyperdrive contract instance. 
The maximum number of blocks to look back
None = None, end_bl # Support for negative indices if (start_block is not None) and (start_block < 0): - start_block = get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + start_block + 1 + start_block = get_latest_block_number_from_table(CheckpointInfo, session) + start_block + 1 if (end_block is not None) and (end_block < 0): - end_block = get_latest_block_number_from_table(CheckpointInfo.__tablename__, session) + end_block + 1 + end_block = get_latest_block_number_from_table(CheckpointInfo, session) + end_block + 1 if start_block is not None: query = query.filter(CheckpointInfo.blockNumber >= start_block) @@ -687,6 +687,8 @@ def get_latest_block_number_from_table( """ # For some reason, pylint doesn't like func.max from sqlalchemy + if not isinstance(table_obj, (WalletInfo, PoolInfo, Transaction, CheckpointInfo)): + assert "table_obj input is not a WalletInfo, PoolInfo, Transaction, or CheckpointInfo" result = session.query(func.max(table_obj.blockNumber)).first() # pylint: disable=not-callable # If table doesn't exist if result is None: diff --git a/tests/data/test_db_utils.py b/tests/data/test_db_utils.py index 5d3c02e8bb..eb53865ff5 100644 --- a/tests/data/test_db_utils.py +++ b/tests/data/test_db_utils.py @@ -1,19 +1,16 @@ """CRUD tests for CheckpointInfo""" -from datetime import datetime - import pytest import numpy as np -from sqlalchemy import create_engine, MetaData, String +from sqlalchemy import create_engine, String from sqlalchemy.orm import sessionmaker - -from elfpy.data import postgres from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column +from elfpy.data import postgres engine = create_engine("sqlite:///:memory:") # in-memory SQLite database for testing Session = sessionmaker(bind=engine) # fixture arguments in test function have to be the same as the fixture name -# pylint: disable=redefined-outer-name +# pylint: disable=redefined-outer-name, too-few-public-methods class 
Based(MappedAsDataclass, DeclarativeBase): From 8dda8445eeecd7a79c863221925cf1c79d51d7fe Mon Sep 17 00:00:00 2001 From: Mihai Date: Tue, 18 Jul 2023 20:04:44 -0400 Subject: [PATCH 6/7] fix assert --- elfpy/data/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elfpy/data/postgres.py b/elfpy/data/postgres.py index 528970a938..b91bb28631 100644 --- a/elfpy/data/postgres.py +++ b/elfpy/data/postgres.py @@ -688,7 +688,7 @@ def get_latest_block_number_from_table( # For some reason, pylint doesn't like func.max from sqlalchemy if not isinstance(table_obj, (WalletInfo, PoolInfo, Transaction, CheckpointInfo)): - assert "table_obj input is not a WalletInfo, PoolInfo, Transaction, or CheckpointInfo" + raise ValueError("table_obj input is not a WalletInfo, PoolInfo, Transaction, or CheckpointInfo") result = session.query(func.max(table_obj.blockNumber)).first() # pylint: disable=not-callable # If table doesn't exist if result is None: From 638ac78ca32348a514c963ecd658fb259d843cf0 Mon Sep 17 00:00:00 2001 From: Mihai Date: Tue, 18 Jul 2023 21:21:04 -0400 Subject: [PATCH 7/7] remove type check --- elfpy/data/postgres.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/elfpy/data/postgres.py b/elfpy/data/postgres.py index b91bb28631..a743dc99d2 100644 --- a/elfpy/data/postgres.py +++ b/elfpy/data/postgres.py @@ -685,10 +685,7 @@ def get_latest_block_number_from_table( int The latest block number from the specified table """ - # For some reason, pylint doesn't like func.max from sqlalchemy - if not isinstance(table_obj, (WalletInfo, PoolInfo, Transaction, CheckpointInfo)): - raise ValueError("table_obj input is not a WalletInfo, PoolInfo, Transaction, or CheckpointInfo") result = session.query(func.max(table_obj.blockNumber)).first() # pylint: disable=not-callable # If table doesn't exist if result is None: