From bef2ba4bd4cb0e2085740171a15649af77aa289a Mon Sep 17 00:00:00 2001 From: Carter Green Date: Tue, 20 Jun 2023 08:34:10 -0500 Subject: [PATCH 01/17] FIX: Fix definition drop columns --- databento/common/data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/databento/common/data.py b/databento/common/data.py index 16bf1bf..fc6d1d1 100644 --- a/databento/common/data.py +++ b/databento/common/data.py @@ -284,9 +284,10 @@ def get_deriv_ba_fields(level: int) -> list[str]: DEFINITION_DROP_COLUMNS = [ "length", "rtype", - "reserved1", - "reserved2", - "reserved3", + "_reserved1", + "_reserved2", + "_reserved3", + "_reserved4", "dummy", ] From 36f5e76320b1a084d7dd198d8a9b477adfea468c Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Thu, 22 Jun 2023 18:49:06 +1000 Subject: [PATCH 02/17] DOC: Update client library CHANGELOG.md --- CHANGELOG.md | 104 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e22a33..2b9752f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,31 +1,44 @@ # Changelog ## 0.14.1 - 2023-06-16 + - Fixed issue where `DBNStore.to_df()` would raise an exception if no records were present - Fixed exception message when creating a DBNStore from an empty data source ## 0.14.0 - 2023-06-14 -- Added support for reusing a `Live` client to reconnect -- Added `metadata` property to `Live` + +#### Enhancements - Added `DatatbentoLiveProtocol` class +- Added `metadata` property to `Live` +- Added support for reusing a `Live` client to reconnect - Added support for emitting warnings in API response headers -- Changed iteration of `Live` to no longer yield DBN metadata -- Changed `Live` callbacks to no longer yield DBN metadata -- Fixed issue where `Historical.timeseries.get_range` would write empty files on error -- Fixed issue with `numpy` types not being handled in symbols field +- Relaxed 10 minute minimum request time range restriction - Upgraded `aiohttp` to 3.8.3 - Upgraded `numpy` to to 1.23.5 - Upgraded `pandas` to to 1.5.3 - Upgraded `requests` to to 2.28.1 - Upgraded `zstandard` to to 0.21.0 -- Removed 10 minute minimum request time range restriction + +#### Breaking changes - Removed support for Python 3.7 +- Renamed `symbol` to `raw_symbol` in definition schema when converting to a DataFrame +- Changed iteration of `Live` to no longer yield DBN metadata +- Changed `Live` callbacks to no longer yield DBN metadata + +#### Bug fixes +- Fixed issue where `Historical.timeseries.get_range` would write empty files on error +- Fixed issue with `numpy` types not being handled in symbols field +- Fixed optional `end` parameter for `batch.submit_job(...)` ## 0.13.0 - 2023-06-02 + +#### Enhancements - Added support for `statistics` schema - Added batch download support data files (`condition.json` and `symbology.json`) -- Upgraded `databento-dbn` to 0.6.1 - Renamed `booklevel` MBP field to `levels` for brevity and consistent naming +- Upgraded `databento-dbn` to 0.6.1 + +#### Breaking changes - Changed `flags` field to an unsigned int - Changed default of `ts_out` to `False` for `Live` client - Changed `instrument_class` DataFrame representation to be consistent with other `char` types @@ -34,85 +47,118 @@ - Removed support for legacy stypes ## 0.12.0 - 2023-05-01 + +#### Enhancements - Added `Live` client for connecting to Databento's live service +- Added `degraded`, `pending` and `missing` condition variants for `batch.get_dataset_condition` +- Added `last_modified_date` field to `batch.get_dataset_condition` response - Upgraded `databento-dbn` to 0.5.0 - Upgraded `DBNStore` to support mixed schema types to support live data + +#### Breaking changes - Changed iteration `DBNStore` to return record types from `databento-dbn` instead of numpy arrays -- Removed `dtype` property from `DBNStore` -- Removed `record_size` property from `DBNStore` - Renamed the `cost` field to `cost_usd` for `batch.submit_job` and `batch.list_jobs` (value now expressed as US dollars) -- Removed `bad` condition variant from `batch.get_dataset_condition` -- Added `degraded`, `pending` and `missing` condition variants for `batch.get_dataset_condition` -- Added `last_modified_date` field to `batch.get_dataset_condition` response - Renamed `product_id` field to `instrument_id` - Renamed `symbol` field in definitions to `raw_symbol` -- Deprecated `SType.PRODUCT_ID` to `SType.INSTRUMENT_ID` -- Deprecated `SType.NATIVE` to `SType.RAW_SYMBOL` -- Deprecated `SType.SMART` to `SType.PARENT` and `SType.CONTINUOUS` +- Removed `dtype` property from `DBNStore` +- Removed `record_size` property from `DBNStore` +- Removed `bad` condition variant from `batch.get_dataset_condition` - Removed unused `LiveGateway` enum - Removed `STATSTICS` from `Schema` enum - Removed `STATUS` from `Schema` enum - Removed `GATEWAY_ERROR` from `Schema` enum - Removed `SYMBOL_MAPPING` from `Schema` enum +#### Deprecations +- Deprecated `SType.PRODUCT_ID` to `SType.INSTRUMENT_ID` +- Deprecated `SType.NATIVE` to `SType.RAW_SYMBOL` +- Deprecated `SType.SMART` to `SType.PARENT` and `SType.CONTINUOUS` + ## 0.11.0 - 2023-04-13 + - Changed `end` and `end_date` to optional to support new forward-fill behaviour - Upgraded `zstandard` to 0.20.0 ## 0.10.0 - 2023-04-07 + +#### Enhancements +- Added support for `imbalance` schema +- Added `instrument_class`, `strike_price`, and `strike_price_currency` to definition + schema +- Changed parsing of `end` and `end_date` params throughout the API +- Improved exception messages for server and client timeouts - Upgraded `databento-dbn` to 0.4.3 + +#### Breaking changes - Renamed `Bento` class to `DBNStore` - Removed `metadata.list_compressions` (redundant with docs) - Removed `metadata.list_encodings` (redundant with docs) - Removed optional `start` and `end` params from `metadata.list_schemas` (redundant) - Removed `related` and `related_security_id` from definition schema -- Added `instrument_class`, `strike_price`, and `strike_price_currency` to definition - schema -- Added support for `imbalance` schema -- Improved exception messages for server and client timeouts ## 0.9.0 - 2023-03-10 -- Removed `record_count` property from Bento class -- Fixed bug in `Bento` where invalid metadata would prevent iteration + +#### Enhancements - Improved use of the logging module + +#### Breaking changes +- Removed `record_count` property from Bento class - Changed `metadata.get_dataset_condition` response to a list of condition per date +#### Bug fixes +- Fixed bug in `Bento` where invalid metadata would prevent iteration + ## 0.8.1 - 2023-03-05 -- Fixed bug in `Bento` iteration where multiple readers were created + +#### Enhancements - Added `from_dbn` convenience alias for loading DBN files +#### Bug fixes +- Fixed bug in `Bento` iteration where multiple readers were created + ## 0.8.0 - 2023-03-03 -- Integrated DBN encoding 0.3.2 -- Renamed `timeseries.stream` to `timeseries.get_range` -- Renamed `timeseries.stream_async` to `timeseries.get_range_async` -- Deprecated `timeseries.stream(...)` method -- Deprecated `timeseries.stream_async(...)` method + +#### Enhancements - Added `batch.list_files(...)` method - Added `batch.download(...)` method - Added `batch.download_async(...)` method +- Integrated DBN encoding 0.3.2 + +#### Breaking changes +- Dropped support for DBZ encoding +- Renamed `timeseries.stream` to `timeseries.get_range` +- Renamed `timeseries.stream_async` to `timeseries.get_range_async` - Changed `.to_df(...)` `pretty_ts` default argument to `True` - Changed `.to_df(...)` `pretty_px` default argument to `True` - Changed `.to_df(...)` `map_symbols` default argument to `True` -- Drop support for DBZ encoding + +#### Deprecations +- Deprecated `timeseries.stream(...)` method +- Deprecated `timeseries.stream_async(...)` method ## 0.7.0 - 2023-01-10 + - Added support for `definition` schema - Updated `Flags` enum - Upgraded `dbz-python` to 0.2.1 - Upgraded `zstandard` to 0.19.0 ## 0.6.0 - 2022-12-02 + - Added `metadata.get_dataset_condition` method to `Historical` client - Upgraded `dbz-python` to 0.2.0 ## 0.5.0 - 2022-11-07 + - Fixed dataframe columns for derived data schemas (dropped `channel_id`) - Fixed `batch.submit_job` requests for `dbz` encoding - Updated `quickstart.ipynb` jupyter notebook ## 0.4.0 - 2022-09-14 + - Upgraded `dbz-python` to 0.1.5 - Added `map_symbols` option for `.to_df()` (experimental) ## 0.3.0 - 2022-08-30 + - Initial release From 8ec2ec2c1e55875fd2f86a73d3b4d2a30fcc3f1c Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Thu, 22 Jun 2023 00:18:11 -0700 Subject: [PATCH 03/17] ADD: Add Live client symbology mapping --- CHANGELOG.md | 5 +++++ databento/live/client.py | 28 +++++++++++++++++++++++++++- databento/live/protocol.py | 18 ++++++++++++++++-- databento/live/session.py | 14 +++++++------- databento/version.py | 2 +- pyproject.toml | 2 +- tests/test_live_client.py | 12 ++++++++---- 7 files changed, 65 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b9752f..b6b388c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## 0.15.0 - TBD + +### Enhancements +- Added `symbology_map` property to `Live` client + ## 0.14.1 - 2023-06-16 - Fixed issue where `DBNStore.to_df()` would raise an exception if no records were present diff --git a/databento/live/client.py b/databento/live/client.py index 8f96087..5088a42 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -88,7 +88,8 @@ def __init__( self._dbn_queue: DBNQueue = DBNQueue(maxsize=DEFAULT_QUEUE_SIZE) self._metadata: SessionMetadata = SessionMetadata() - self._user_callbacks: list[UserCallback] = [] + self._symbology_map: dict[int, str | int] = {} + self._user_callbacks: list[UserCallback] = [self._map_symbol] self._user_streams: list[IO[bytes]] = [] def factory() -> _SessionProtocol: @@ -237,6 +238,23 @@ def port(self) -> int: """ return self._port + @property + def symbology_map(self) -> dict[int, str | int]: + """ + Return the symbology map for this client session. A symbol mapping is + added when the client receives a SymbolMappingMsg. + + This can be used to transform an `instrument_id` in a DBN record + to the input symbology. + + Returns + ------- + dict[int, str | int] + A mapping of the exchange's instrument_id to the subscription symbology. + + """ + return self._symbology_map + @property def ts_out(self) -> bool: """ @@ -548,3 +566,11 @@ async def _shutdown(self) -> None: if self._session is None: return await self._session.wait_for_close() + self._symbology_map.clear() + + def _map_symbol(self, record: DBNRecord) -> None: + if isinstance(record, databento_dbn.SymbolMappingMsg): + out_symbol = record.stype_out_symbol + instrument_id = record.instrument_id + self._symbology_map[instrument_id] = record.stype_out_symbol + logger.info("added symbology mapping %s to %d", out_symbol, instrument_id) diff --git a/databento/live/protocol.py b/databento/live/protocol.py index 409a876..a077b01 100644 --- a/databento/live/protocol.py +++ b/databento/live/protocol.py @@ -312,8 +312,22 @@ def _process_dbn(self, data: bytes) -> None: logger.debug("dispatching %s", type(record).__name__) if isinstance(record, databento_dbn.Metadata): self.received_metadata(record) - else: - self.received_record(record) + continue + + if isinstance(record, databento_dbn.ErrorMsg): + logger.error( + "gateway error: %s", + record.err, + ) + if isinstance(record, databento_dbn.SystemMsg): + if record.is_heartbeat: + logger.debug("gateway heartbeat") + else: + logger.info( + "gateway message: %s", + record.msg, + ) + self.received_record(record) def _process_gateway(self, data: bytes) -> None: try: diff --git a/databento/live/session.py b/databento/live/session.py index 2507590..126a8a8 100644 --- a/databento/live/session.py +++ b/databento/live/session.py @@ -470,15 +470,15 @@ async def _connect_task( ), timeout=CONNECT_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError: + except asyncio.TimeoutError as exc: raise BentoError( f"Connection to {gateway}:{port} timed out after " f"{CONNECT_TIMEOUT_SECONDS} second(s).", - ) - except OSError: + ) from exc + except OSError as exc: raise BentoError( f"Connection to {gateway}:{port} failed.", - ) + ) from exc logger.debug( "connected to %s:%d", @@ -491,13 +491,13 @@ async def _connect_task( protocol.authenticated, timeout=AUTH_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError: + except asyncio.TimeoutError as exc: raise BentoError( f"Authentication with {gateway}:{port} timed out after " f"{AUTH_TIMEOUT_SECONDS} second(s).", - ) + ) from exc except ValueError as exc: - raise BentoError(f"User authentication failed: {str(exc)}") + raise BentoError(f"User authentication failed: {str(exc)}") from exc logger.info( "authentication with remote gateway completed", diff --git a/databento/version.py b/databento/version.py index f075dd3..9da2f8f 100644 --- a/databento/version.py +++ b/databento/version.py @@ -1 +1 @@ -__version__ = "0.14.1" +__version__ = "0.15.0" diff --git a/pyproject.toml b/pyproject.toml index 9bc052a..435d2f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databento" -version = "0.14.1" +version = "0.15.0" description = "Official Python client library for Databento" authors = [ "Databento ", diff --git a/tests/test_live_client.py b/tests/test_live_client.py index 7a5f3ff..bc7d1a0 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -496,7 +496,7 @@ def callback(_: object) -> None: pass live_client.add_callback(callback) - assert live_client._user_callbacks == [callback] + assert callback in live_client._user_callbacks assert live_client._user_streams == [] @@ -509,7 +509,6 @@ def test_live_add_stream( stream = BytesIO() live_client.add_stream(stream) - assert live_client._user_callbacks == [] assert live_client._user_streams == [stream] @@ -581,7 +580,9 @@ async def test_live_async_iteration_backpressure( symbols="TEST", ) - monkeypatch.setattr(live_client._session._transport, "pause_reading", pause_mock:=MagicMock()) + monkeypatch.setattr( + live_client._session._transport, "pause_reading", pause_mock := MagicMock(), + ) live_client.start() it = live_client.__iter__() @@ -618,7 +619,9 @@ async def test_live_async_iteration_dropped( symbols="TEST", ) - monkeypatch.setattr(live_client._session._transport, "pause_reading", pause_mock:=MagicMock()) + monkeypatch.setattr( + live_client._session._transport, "pause_reading", pause_mock := MagicMock(), + ) live_client.start() it = live_client.__iter__() @@ -630,6 +633,7 @@ async def test_live_async_iteration_dropped( assert len(records) == 1 assert live_client._dbn_queue.empty() + @pytest.mark.skipif(platform.system() == "Windows", reason="flaky on windows runner") async def test_live_async_iteration_stop( live_client: client.Live, From 2ab8e486d77ddf9f7b4ec9e7277bc50f8c7cba65 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Thu, 22 Jun 2023 16:02:19 -0500 Subject: [PATCH 04/17] ADD: Add Live test for CRAM fail reconnect --- tests/test_live_client.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_live_client.py b/tests/test_live_client.py index bc7d1a0..89f09ef 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -137,7 +137,7 @@ def test_live_connection_cram_failure( test_api_key: str, ) -> None: """ - Test that a failed auth message due to an incorrect CRAM raies a + Test that a failed auth message due to an incorrect CRAM raises a BentoError. """ # Dork up the API key in the mock client to fail CRAM @@ -1029,3 +1029,38 @@ async def test_live_stream_with_reconnect( records = list(data) for record in records: assert isinstance(record, schema.get_record_type()) + +def test_live_connection_reconnect_cram_failure( + mock_live_server: MockLiveServer, + monkeypatch: pytest.MonkeyPatch, + test_api_key: str, +) -> None: + """ + Test that a failed connection can reconnect. + """ + # Dork up the API key in the mock client to fail CRAM + bucket_id = test_api_key[-BUCKET_ID_LENGTH:] + invalid_key = "db-invalidkey00000000000000FFFFF" + monkeypatch.setitem(mock_live_server._user_api_keys, bucket_id, invalid_key) + + live_client = client.Live( + key=test_api_key, + gateway=mock_live_server.host, + port=mock_live_server.port, + ) + + with pytest.raises(BentoError) as exc: + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + ) + + # Ensure this was an authentication error + exc.match(r"User authentication failed:") + + # Fix the key in the mock live server to connect + monkeypatch.setitem(mock_live_server._user_api_keys, bucket_id, test_api_key) + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + ) From d91880bd078e8a2c4f48c818cfd5629e4cc71ad9 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Thu, 22 Jun 2023 02:40:48 -0700 Subject: [PATCH 05/17] ADD: Add exception handler to live client --- CHANGELOG.md | 2 ++ databento/live/__init__.py | 6 +++- databento/live/client.py | 45 ++++++++++++++++++++--------- databento/live/protocol.py | 2 +- databento/live/session.py | 41 ++++++++++++++++---------- tests/test_live_client.py | 59 ++++++++++++++++++++++++++++++++++++-- 6 files changed, 122 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6b388c..8880b14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### Enhancements - Added `symbology_map` property to `Live` client +- Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback +- Changed `Live.add_callback` and `Live.add_stream` `func` parameter to `record_callback` ## 0.14.1 - 2023-06-16 diff --git a/databento/live/__init__.py b/databento/live/__init__.py index 179a74a..a71e4f3 100644 --- a/databento/live/__init__.py +++ b/databento/live/__init__.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Callable, Union import databento_dbn @@ -6,6 +6,7 @@ AUTH_TIMEOUT_SECONDS: float = 2 CONNECT_TIMEOUT_SECONDS: float = 5 + DBNRecord = Union[ databento_dbn.MBOMsg, databento_dbn.MBP1Msg, @@ -19,3 +20,6 @@ databento_dbn.SystemMsg, databento_dbn.ErrorMsg, ] + +RecordCallback = Callable[[DBNRecord], None] +ExceptionCallback = Callable[[Exception], None] diff --git a/databento/live/client.py b/databento/live/client.py index 5088a42..bd8d98a 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from concurrent import futures from numbers import Number -from typing import IO, Callable +from typing import IO import databento_dbn @@ -23,6 +23,8 @@ from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string from databento.live import DBNRecord +from databento.live import ExceptionCallback +from databento.live import RecordCallback from databento.live.session import DEFAULT_REMOTE_PORT from databento.live.session import DBNQueue from databento.live.session import Session @@ -31,9 +33,6 @@ logger = logging.getLogger(__name__) - -UserCallback = Callable[[DBNRecord], None] - DEFAULT_QUEUE_SIZE = 2048 @@ -89,8 +88,8 @@ def __init__( self._dbn_queue: DBNQueue = DBNQueue(maxsize=DEFAULT_QUEUE_SIZE) self._metadata: SessionMetadata = SessionMetadata() self._symbology_map: dict[int, str | int] = {} - self._user_callbacks: list[UserCallback] = [self._map_symbol] - self._user_streams: list[IO[bytes]] = [] + self._user_callbacks: dict[RecordCallback, ExceptionCallback | None] = {} + self._user_streams: dict[IO[bytes], ExceptionCallback | None] = {} def factory() -> _SessionProtocol: return _SessionProtocol( @@ -269,15 +268,19 @@ def ts_out(self) -> bool: def add_callback( self, - func: UserCallback, + record_callback: RecordCallback, + exception_callback: ExceptionCallback | None = None, ) -> None: """ Add a callback for handling records. Parameters ---------- - func : Callable[[DBNRecord], None] + record_callback : Callable[[DBNRecord], None] A callback to register for handling live records as they arrive. + exception_callback : Callable[[Exception], None], optional + An error handling callback to process exceptions that are raised + in `record_callback`. Raises ------ @@ -289,13 +292,21 @@ def add_callback( Live.add_stream """ - if not callable(func): - raise ValueError(f"{func} is not callable") - callback_name = getattr(func, "__name__", str(func)) + if not callable(record_callback): + raise ValueError(f"{record_callback} is not callable") + + if exception_callback is not None and not callable(exception_callback): + raise ValueError(f"{exception_callback} is not callable") + + callback_name = getattr(record_callback, "__name__", str(record_callback)) logger.info("adding user callback %s", callback_name) - self._user_callbacks.append(func) + self._user_callbacks[record_callback] = exception_callback - def add_stream(self, stream: IO[bytes]) -> None: + def add_stream( + self, + stream: IO[bytes], + exception_callback: ExceptionCallback | None = None, + ) -> None: """ Add an IO stream to write records to. @@ -303,6 +314,9 @@ def add_stream(self, stream: IO[bytes]) -> None: ---------- stream : IO[bytes] The IO stream to write to when handling live records as they arrive. + exception_callback : Callable[[Exception], None], optional + An error handling callback to process exceptions that are raised + when writing to the stream. Raises ------ @@ -320,11 +334,14 @@ def add_stream(self, stream: IO[bytes]) -> None: if not hasattr(stream, "writable") or not stream.writable(): raise ValueError(f"{type(stream).__name__} is not a writable stream") + if exception_callback is not None and not callable(exception_callback): + raise ValueError(f"{exception_callback} is not callable") + stream_name = getattr(stream, "name", str(stream)) logger.info("adding user stream %s", stream_name) if self.metadata is not None: stream.write(bytes(self.metadata)) - self._user_streams.append(stream) + self._user_streams[stream] = exception_callback def start( self, diff --git a/databento/live/protocol.py b/databento/live/protocol.py index a077b01..9aa48c8 100644 --- a/databento/live/protocol.py +++ b/databento/live/protocol.py @@ -186,7 +186,7 @@ def eof_received(self) -> bool | None: asycnio.BufferedProtocol.eof_received """ - logger.info("received EOF file from remote") + logger.info("received EOF from remote") return super().eof_received() def get_buffer(self, sizehint: int) -> bytearray: diff --git a/databento/live/session.py b/databento/live/session.py index 126a8a8..c12203e 100644 --- a/databento/live/session.py +++ b/databento/live/session.py @@ -20,13 +20,14 @@ from databento.live import AUTH_TIMEOUT_SECONDS from databento.live import CONNECT_TIMEOUT_SECONDS from databento.live import DBNRecord +from databento.live import ExceptionCallback +from databento.live import RecordCallback from databento.live.protocol import DatabentoLiveProtocol logger = logging.getLogger(__name__) -UserCallback = Callable[[DBNRecord], None] DEFAULT_REMOTE_PORT = 13000 @@ -122,8 +123,8 @@ def __init__( api_key: str, dataset: Dataset | str, dbn_queue: DBNQueue, - user_callbacks: list[UserCallback], - user_streams: list[IO[bytes]], + user_callbacks: dict[RecordCallback, ExceptionCallback | None], + user_streams: dict[IO[bytes], ExceptionCallback | None], loop: asyncio.AbstractEventLoop, metadata: SessionMetadata, ts_out: bool = False, @@ -140,8 +141,10 @@ def __init__( def received_metadata(self, metadata: databento_dbn.Metadata) -> None: if not self._metadata: self._metadata.data = metadata - for stream in self._user_streams: - task = self._loop.create_task(self._stream_task(stream, metadata)) + for stream, exc_callback in self._user_streams.items(): + task = self._loop.create_task( + self._stream_task(stream, metadata, exc_callback), + ) task.add_done_callback(self._tasks.remove) self._tasks.add(task) else: @@ -149,13 +152,17 @@ def received_metadata(self, metadata: databento_dbn.Metadata) -> None: return super().received_metadata(metadata) def received_record(self, record: DBNRecord) -> None: - for callback in self._user_callbacks: - task = self._loop.create_task(self._callback_task(callback, record)) + for callback, exc_callback in self._user_callbacks.items(): + task = self._loop.create_task( + self._callback_task(callback, record, exc_callback), + ) task.add_done_callback(self._tasks.remove) self._tasks.add(task) - for stream in self._user_streams: - task = self._loop.create_task(self._stream_task(stream, record)) + for stream, exc_callback in self._user_streams.items(): + task = self._loop.create_task( + self._stream_task(stream, record, exc_callback), + ) task.add_done_callback(self._tasks.remove) self._tasks.add(task) @@ -180,26 +187,29 @@ def received_record(self, record: DBNRecord) -> None: async def _callback_task( self, - func: UserCallback, + record_callback: RecordCallback, record: DBNRecord, + exception_callback: ExceptionCallback | None, ) -> None: try: - func(record) + record_callback(record) except Exception as exc: logger.error( "error dispatching %s to `%s` callback", type(record).__name__, - func.__name__, + record_callback.__name__, exc_info=exc, ) - raise + if exception_callback is not None: + self._loop.call_soon_threadsafe(exception_callback, exc) async def _stream_task( self, stream: IO[bytes], record: databento_dbn.Metadata | DBNRecord, + exc_callback: ExceptionCallback | None, ) -> None: - has_ts_out = self._metadata and self._metadata.data.ts_out + has_ts_out = self._metadata.data and self._metadata.data.ts_out try: stream.write(bytes(record)) if not isinstance(record, databento_dbn.Metadata) and has_ts_out: @@ -212,7 +222,8 @@ async def _stream_task( stream_name, exc_info=exc, ) - raise + if exc_callback is not None: + self._loop.call_soon_threadsafe(exc_callback, exc) async def wait_for_processing(self) -> None: while self._tasks: diff --git a/tests/test_live_client.py b/tests/test_live_client.py index 89f09ef..e275233 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -497,7 +497,8 @@ def callback(_: object) -> None: live_client.add_callback(callback) assert callback in live_client._user_callbacks - assert live_client._user_streams == [] + assert live_client._user_callbacks[callback] is None + assert live_client._user_streams == {} def test_live_add_stream( @@ -509,7 +510,9 @@ def test_live_add_stream( stream = BytesIO() live_client.add_stream(stream) - assert live_client._user_streams == [stream] + assert stream in live_client._user_streams + assert live_client._user_streams[stream] is None + assert live_client._user_callbacks == {} def test_live_add_stream_invalid( @@ -1064,3 +1067,55 @@ def test_live_connection_reconnect_cram_failure( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, ) + +async def test_live_callback_exception_handler( + live_client: client.Live, +) -> None: + """ + Test exceptions that occur during callbacks are dispatched to the assigned + exception handler. + """ + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.RAW_SYMBOL, + symbols="TEST", + ) + + exceptions: list[Exception] = [] + + def callback(_: DBNRecord) -> None: + raise RuntimeError("this is a test") + + live_client.add_callback(callback, exceptions.append) + + live_client.start() + + await live_client.wait_for_close() + assert len(exceptions) == 4 + + +async def test_live_stream_exception_handler( + live_client: client.Live, +) -> None: + """ + Test exceptions that occur during stream writes are dispatched to the + assigned exception handler. + """ + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.RAW_SYMBOL, + symbols="TEST", + ) + + exceptions: list[Exception] = [] + + stream = BytesIO() + live_client.add_stream(stream, exceptions.append) + stream.close() + + live_client.start() + + await live_client.wait_for_close() + assert len(exceptions) == 5 # extra write from metadata From 6dc3dc1c58bf059db5b9077deaae98f6fea97788 Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Mon, 26 Jun 2023 09:59:05 +1000 Subject: [PATCH 06/17] MOD: Support form data for POST requests --- databento/common/parsing.py | 2 +- databento/historical/api/batch.py | 51 +++++++++------------- databento/historical/api/symbology.py | 24 +++++------ databento/historical/api/timeseries.py | 60 +++++++++++++------------- databento/historical/http.py | 18 ++++---- tests/test_historical_batch.py | 46 ++++++++------------ tests/test_historical_client.py | 48 ++++++++++----------- tests/test_historical_timeseries.py | 56 ++++++++++++------------ 8 files changed, 145 insertions(+), 160 deletions(-) diff --git a/databento/common/parsing.py b/databento/common/parsing.py index d6409cf..a4008fb 100644 --- a/databento/common/parsing.py +++ b/databento/common/parsing.py @@ -233,7 +233,7 @@ def datetime_to_date_string(value: pd.Timestamp | date | str | int) -> str: def optional_datetime_to_string( - value: pd.Timestamp | date | str | int, + value: pd.Timestamp | date | str | int | None, ) -> str | None: """ Return a valid datetime string from the given value (if not None). diff --git a/databento/historical/api/batch.py b/databento/historical/api/batch.py index 260cb2a..fe4cc39 100644 --- a/databento/historical/api/batch.py +++ b/databento/historical/api/batch.py @@ -119,43 +119,34 @@ def submit_job( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) - params: list[tuple[str, str | None]] = [ - ("dataset", validate_semantic_string(dataset, "dataset")), - ("start", datetime_to_string(start)), - ("end", optional_datetime_to_string(end)), - ("symbols", str(symbols_list)), - ("schema", str(validate_enum(schema, Schema, "schema"))), - ("stype_in", str(stype_in_valid)), - ("stype_out", str(validate_enum(stype_out, SType, "stype_out"))), - ("encoding", str(validate_enum(encoding, Encoding, "encoding"))), - ( - "compression", - str(validate_enum(compression, Compression, "compression")) - if compression - else None, - ), - ( - "split_duration", - str(validate_enum(split_duration, SplitDuration, "split_duration")), - ), - ( - "packaging", - str(validate_enum(packaging, Packaging, "packaging")) - if packaging - else None, - ), - ("delivery", str(validate_enum(delivery, Delivery, "delivery"))), - ] + data: dict[str, object | None] = { + "dataset": validate_semantic_string(dataset, "dataset"), + "start": datetime_to_string(start), + "end": optional_datetime_to_string(end), + "symbols": str(symbols_list), + "schema": str(validate_enum(schema, Schema, "schema")), + "stype_in": str(stype_in_valid), + "stype_out": str(validate_enum(stype_out, SType, "stype_out")), + "encoding": str(validate_enum(encoding, Encoding, "encoding")), + "compression": str(validate_enum(compression, Compression, "compression")) + if compression + else None, + "split_duration": str(validate_enum(split_duration, SplitDuration, "split_duration")), + "packaging": str(validate_enum(packaging, Packaging, "packaging")) + if packaging + else None, + "delivery": str(validate_enum(delivery, Delivery, "delivery")), + } # Optional Parameters if limit is not None: - params.append(("limit", str(limit))) + data["limit"] = str(limit) if split_size is not None: - params.append(("split_size", str(split_size))) + data["split_size"] = str(split_size) return self._post( url=self._base_url + ".submit_job", - params=params, + data=data, basic_auth=True, ).json() diff --git a/databento/historical/api/symbology.py b/databento/historical/api/symbology.py index 8c41ad8..bfb5a0d 100644 --- a/databento/historical/api/symbology.py +++ b/databento/historical/api/symbology.py @@ -38,7 +38,7 @@ def resolve( """ Request symbology mappings resolution from Databento. - Makes a `GET /symbology.resolve` HTTP request. + Makes a `POST /symbology.resolve` HTTP request. Parameters ---------- @@ -66,19 +66,19 @@ def resolve( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) - params: list[tuple[str, str | None]] = [ - ("dataset", validate_semantic_string(dataset, "dataset")), - ("symbols", symbols_list), - ("stype_in", str(stype_in_valid)), - ("stype_out", str(validate_enum(stype_out, SType, "stype_out"))), - ("start_date", datetime_to_date_string(start_date)), - ("end_date", optional_date_to_string(end_date)), - ("default_value", default_value), - ] + data: dict[str, object | None] = { + "dataset": validate_semantic_string(dataset, "dataset"), + "symbols": symbols_list, + "stype_in": str(stype_in_valid), + "stype_out": str(validate_enum(stype_out, SType, "stype_out")), + "start_date": datetime_to_date_string(start_date), + "end_date": optional_date_to_string(end_date), + "default_value": default_value, + } - response: Response = self._get( + response: Response = self._post( url=self._base_url + ".resolve", - params=params, + data=data, basic_auth=True, ) diff --git a/databento/historical/api/timeseries.py b/databento/historical/api/timeseries.py index dc6e060..e0738b9 100644 --- a/databento/historical/api/timeseries.py +++ b/databento/historical/api/timeseries.py @@ -31,7 +31,7 @@ def __init__(self, key: str, gateway: str) -> None: def get_range( self, - dataset: Dataset | str | None, + dataset: Dataset | str, start: pd.Timestamp | date | str | int, end: pd.Timestamp | date | str | int | None = None, symbols: list[str] | str | None = None, @@ -44,7 +44,7 @@ def get_range( """ Request a historical time series data stream from Databento. - Makes a `GET /timeseries.get_range` HTTP request. + Makes a `POST /timeseries.get_range` HTTP request. Primary method for getting historical intraday market data, daily data, instrument definitions and market status data directly into your application. @@ -99,32 +99,32 @@ def get_range( schema_valid = validate_enum(schema, Schema, "schema") start_valid = datetime_to_string(start) end_valid = optional_datetime_to_string(end) - params: list[tuple[str, str | None]] = [ - ("dataset", validate_semantic_string(dataset, "dataset")), - ("start", start_valid), - ("end", end_valid), - ("symbols", symbols_list), - ("schema", str(schema_valid)), - ("stype_in", str(stype_in_valid)), - ("stype_out", str(validate_enum(stype_out, SType, "stype_out"))), - ("encoding", str(Encoding.DBN)), # Always request dbn - ("compression", str(Compression.ZSTD)), # Always request zstd - ] + data: dict[str, object | None] = { + "dataset": validate_semantic_string(dataset, "dataset"), + "start": start_valid, + "end": end_valid, + "symbols": symbols_list, + "schema": str(schema_valid), + "stype_in": str(stype_in_valid), + "stype_out": str(validate_enum(stype_out, SType, "stype_out")), + "encoding": str(Encoding.DBN), # Always request dbn + "compression": str(Compression.ZSTD), # Always request zstd + } # Optional Parameters if limit is not None: - params.append(("limit", str(limit))) + data["limit"] = str(limit) return self._stream( url=self._base_url + ".get_range", - params=params, + data=data, basic_auth=True, path=path, ) async def get_range_async( self, - dataset: Dataset | str | None, + dataset: Dataset | str, start: pd.Timestamp | date | str | int, end: pd.Timestamp | date | str | int | None = None, symbols: list[str] | str | None = None, @@ -138,7 +138,7 @@ async def get_range_async( Asynchronously request a historical time series data stream from Databento. - Makes a `GET /timeseries.get_range` HTTP request. + Makes a `POST /timeseries.get_range` HTTP request. Primary method for getting historical intraday market data, daily data, instrument definitions and market status data directly into your application. @@ -193,25 +193,25 @@ async def get_range_async( schema_valid = validate_enum(schema, Schema, "schema") start_valid = datetime_to_string(start) end_valid = optional_datetime_to_string(end) - params: list[tuple[str, str | None]] = [ - ("dataset", validate_semantic_string(dataset, "dataset")), - ("start", start_valid), - ("end", end_valid), - ("symbols", symbols_list), - ("schema", str(schema_valid)), - ("stype_in", str(stype_in_valid)), - ("stype_out", str(validate_enum(stype_out, SType, "stype_out"))), - ("encoding", str(Encoding.DBN)), # Always request dbn - ("compression", str(Compression.ZSTD)), # Always request zstd - ] + data: dict[str, object | None] = { + "dataset": validate_semantic_string(dataset, "dataset"), + "start": start_valid, + "end": end_valid, + "symbols": symbols_list, + "schema": str(schema_valid), + "stype_in": str(stype_in_valid), + "stype_out": str(validate_enum(stype_out, SType, "stype_out")), + "encoding": str(Encoding.DBN), # Always request dbn + "compression": str(Compression.ZSTD), # Always request zstd + } # Optional Parameters if limit is not None: - params.append(("limit", str(limit))) + data["limit"] = str(limit) return await self._stream_async( url=self._base_url + ".get_range", - params=params, + data=data, basic_auth=True, path=path, ) diff --git a/databento/historical/http.py b/databento/historical/http.py index 805f5df..dc946b6 100644 --- a/databento/historical/http.py +++ b/databento/historical/http.py @@ -24,7 +24,7 @@ from databento.version import __version__ -_32KB = 1024 * 32 # 32_768 +_32KIB = 1024 * 32 # 32_768 WARNING_HEADER_FIELD: str = "X-Warning" @@ -94,6 +94,7 @@ async def _get_json_async( def _post( self, url: str, + data: dict[str, object | None] | None = None, params: Iterable[tuple[str, str | None]] | None = None, basic_auth: bool = False, ) -> Response: @@ -101,6 +102,7 @@ def _post( with requests.post( url=url, + data=data, params=params, headers=self._headers, auth=HTTPBasicAuth(username=self._key, password="") if basic_auth else None, @@ -113,15 +115,15 @@ def _post( def _stream( self, url: str, - params: Iterable[tuple[str, str | None]], + data: dict[str, object | None], basic_auth: bool, path: PathLike[str] | str | None = None, ) -> DBNStore: self._check_api_key() - with requests.get( + with requests.post( url=url, - params=params, + data=data, headers=self._headers, auth=HTTPBasicAuth(username=self._key, password="") if basic_auth else None, timeout=(self.TIMEOUT, self.TIMEOUT), @@ -135,7 +137,7 @@ def _stream( else: writer = open(path, "x+b") - for chunk in response.iter_content(chunk_size=_32KB): + for chunk in response.iter_content(chunk_size=_32KIB): writer.write(chunk) if path is None: @@ -148,16 +150,16 @@ def _stream( async def _stream_async( self, url: str, - params: Iterable[tuple[str, str | None]], + data: dict[str, object | None] | None, basic_auth: bool, path: PathLike[str] | str | None = None, ) -> DBNStore: self._check_api_key() async with aiohttp.ClientSession() as session: - async with session.get( + async with session.post( url=url, - params=[x for x in params if x[1] is not None], + data=data, headers=self._headers, auth=aiohttp.BasicAuth(login=self._key, password="", encoding="utf-8") if basic_auth diff --git a/tests/test_historical_batch.py b/tests/test_historical_batch.py index 00bfc5d..de37762 100644 --- a/tests/test_historical_batch.py +++ b/tests/test_historical_batch.py @@ -77,27 +77,25 @@ def test_batch_submit_job_sends_expected_request( # Assert call = mocked_post.call_args.kwargs - assert ( - call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.submit_job" - ) + assert call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.submit_job" assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) - assert call["params"] == [ - ("dataset", "GLBX.MDP3"), - ("start", "2020-12-28T12:00"), - ("end", "2020-12-29"), - ("symbols", "ESH1"), - ("schema", "trades"), - ("stype_in", "raw_symbol"), - ("stype_out", "instrument_id"), - ("encoding", "csv"), - ("compression", "zstd"), - ("split_duration", "day"), - ("packaging", "none"), - ("delivery", "download"), - ("split_size", "10000000000"), - ] + assert call["data"] == { + "dataset": "GLBX.MDP3", + "start": "2020-12-28T12:00", + "end": "2020-12-29", + "symbols": "ESH1", + "schema": "trades", + "stype_in": "raw_symbol", + "stype_out": "instrument_id", + "encoding": "csv", + "compression": "zstd", + "split_duration": "day", + "packaging": "none", + "delivery": "download", + "split_size": "10000000000", + } assert call["timeout"] == (100, 100) assert isinstance(call["auth"], requests.auth.HTTPBasicAuth) @@ -114,9 +112,7 @@ def test_batch_list_jobs_sends_expected_request( # Assert call = mocked_get.call_args.kwargs - assert ( - call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.list_jobs" - ) + assert call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.list_jobs" assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) @@ -141,9 +137,7 @@ def test_batch_list_files_sends_expected_request( # Assert call = mocked_get.call_args.kwargs - assert ( - call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.list_files" - ) + assert call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.list_files" assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) @@ -174,9 +168,7 @@ def test_batch_download_single_file_sends_expected_request( # Assert call = mocked_get.call_args.kwargs - assert ( - call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.list_files" - ) + assert call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/batch.list_files" assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) diff --git a/tests/test_historical_client.py b/tests/test_historical_client.py index eb7b327..257885d 100644 --- a/tests/test_historical_client.py +++ b/tests/test_historical_client.py @@ -106,7 +106,7 @@ def test_re_request_symbology_makes_expected_request( historical_client: Historical, ) -> None: # Arrange - monkeypatch.setattr(requests, "get", mocked_get := MagicMock()) + monkeypatch.setattr(requests, "post", mocked_post := MagicMock()) bento = DBNStore.from_file(path=test_data_path(Schema.MBO)) @@ -114,20 +114,20 @@ def test_re_request_symbology_makes_expected_request( bento.request_symbology(historical_client) # Assert - call = mocked_get.call_args.kwargs + call = mocked_post.call_args.kwargs assert ( call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/symbology.resolve" ) - assert call["params"] == [ - ("dataset", "GLBX.MDP3"), - ("symbols", "ESH1"), - ("stype_in", "raw_symbol"), - ("stype_out", "instrument_id"), - ("start_date", "2020-12-28"), - ("end_date", "2020-12-29"), - ("default_value", ""), - ] + assert call["data"] == { + "dataset": "GLBX.MDP3", + "symbols": "ESH1", + "stype_in": "raw_symbol", + "stype_out": "instrument_id", + "start_date": "2020-12-28", + "end_date": "2020-12-29", + "default_value": "", + } assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) @@ -142,7 +142,7 @@ def test_request_full_definitions_expected_request( historical_client: Historical, ) -> None: # Arrange - monkeypatch.setattr(requests, "get", mocked_get := MagicMock()) + monkeypatch.setattr(requests, "post", mocked_post := MagicMock()) # Create an MBO bento bento = DBNStore.from_file(path=test_data_path(Schema.MBO)) @@ -159,22 +159,22 @@ def test_request_full_definitions_expected_request( definition_bento = bento.request_full_definitions(historical_client) # Assert - call = mocked_get.call_args.kwargs + call = mocked_post.call_args.kwargs assert ( call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/timeseries.get_range" ) - assert call["params"] == [ - ("dataset", "GLBX.MDP3"), - ("start", "2020-12-28T13:00:00+00:00"), - ("end", "2020-12-29T13:01:00+00:00"), - ("symbols", "ESH1"), - ("schema", "definition"), - ("stype_in", "raw_symbol"), - ("stype_out", "instrument_id"), - ("encoding", "dbn"), - ("compression", "zstd"), - ] + assert call["data"] == { + "dataset": "GLBX.MDP3", + "start": "2020-12-28T13:00:00+00:00", + "end": "2020-12-29T13:01:00+00:00", + "symbols": "ESH1", + "schema": "definition", + "stype_in": "raw_symbol", + "stype_out": "instrument_id", + "encoding": "dbn", + "compression": "zstd", + } assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) diff --git a/tests/test_historical_timeseries.py b/tests/test_historical_timeseries.py index 16fe776..7286447 100644 --- a/tests/test_historical_timeseries.py +++ b/tests/test_historical_timeseries.py @@ -63,7 +63,7 @@ def test_get_range_error_no_file_write( # Arrange mocked_response = MagicMock() mocked_response.__enter__.return_value = MagicMock(status_code=500) - monkeypatch.setattr(requests, "get", MagicMock(return_value=mocked_response)) + monkeypatch.setattr(requests, "post", MagicMock(return_value=mocked_response)) output_file = tmp_path / "output.dbn" @@ -89,7 +89,7 @@ def test_get_range_sends_expected_request( historical_client: Historical, ) -> None: # Arrange - monkeypatch.setattr(requests, "get", mocked_get := MagicMock()) + monkeypatch.setattr(requests, "post", mocked_post := MagicMock()) stream_bytes = test_data(Schema.TRADES) monkeypatch.setattr( @@ -109,7 +109,7 @@ def test_get_range_sends_expected_request( ) # Assert - call = mocked_get.call_args.kwargs + call = mocked_post.call_args.kwargs assert ( call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/timeseries.get_range" @@ -117,17 +117,17 @@ def test_get_range_sends_expected_request( assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) - assert call["params"] == [ - ("dataset", "GLBX.MDP3"), - ("start", "2020-12-28T12:00"), - ("end", "2020-12-29"), - ("symbols", "ES.c.0"), - ("schema", "trades"), - ("stype_in", "continuous"), - ("stype_out", "instrument_id"), - ("encoding", "dbn"), - ("compression", "zstd"), - ] + assert call["data"] == { + "dataset": "GLBX.MDP3", + "start": "2020-12-28T12:00", + "end": "2020-12-29", + "symbols": "ES.c.0", + "schema": "trades", + "stype_in": "continuous", + "stype_out": "instrument_id", + "encoding": "dbn", + "compression": "zstd", + } assert call["timeout"] == (100, 100) assert isinstance(call["auth"], requests.auth.HTTPBasicAuth) @@ -138,7 +138,7 @@ def test_get_range_with_limit_sends_expected_request( historical_client: Historical, ) -> None: # Arrange - monkeypatch.setattr(requests, "get", mocked_get := MagicMock()) + monkeypatch.setattr(requests, "post", mocked_post := MagicMock()) # Mock from_bytes with the definition stub stream_bytes = test_data(Schema.TRADES) @@ -159,7 +159,7 @@ def test_get_range_with_limit_sends_expected_request( ) # Assert - call = mocked_get.call_args.kwargs + call = mocked_post.call_args.kwargs assert ( call["url"] == f"{historical_client.gateway}/v{db.API_VERSION}/timeseries.get_range" @@ -167,17 +167,17 @@ def test_get_range_with_limit_sends_expected_request( assert sorted(call["headers"].keys()) == ["accept", "user-agent"] assert call["headers"]["accept"] == "application/json" assert all(v in call["headers"]["user-agent"] for v in ("Databento/", "Python/")) - assert call["params"] == [ - ("dataset", "GLBX.MDP3"), - ("start", "2020-12-28T12:00"), - ("end", "2020-12-29"), - ("symbols", "ESH1"), - ("schema", "trades"), - ("stype_in", "raw_symbol"), - ("stype_out", "instrument_id"), - ("encoding", "dbn"), - ("compression", "zstd"), - ("limit", "1000000"), - ] + assert call["data"] == { + "dataset": "GLBX.MDP3", + "start": "2020-12-28T12:00", + "end": "2020-12-29", + "limit": "1000000", + "symbols": "ESH1", + "schema": "trades", + "stype_in": "raw_symbol", + "stype_out": "instrument_id", + "encoding": "dbn", + "compression": "zstd", + } assert call["timeout"] == (100, 100) assert isinstance(call["auth"], requests.auth.HTTPBasicAuth) From 601cffdc724954f7dafec83f0da46089a8313555 Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Mon, 26 Jun 2023 14:25:38 +1000 Subject: [PATCH 07/17] FIX: Fix minor typos in release notes --- CHANGELOG.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8880b14..68407aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,10 +21,10 @@ - Added support for emitting warnings in API response headers - Relaxed 10 minute minimum request time range restriction - Upgraded `aiohttp` to 3.8.3 -- Upgraded `numpy` to to 1.23.5 -- Upgraded `pandas` to to 1.5.3 -- Upgraded `requests` to to 2.28.1 -- Upgraded `zstandard` to to 0.21.0 +- Upgraded `numpy` to 1.23.5 +- Upgraded `pandas` to 1.5.3 +- Upgraded `requests` to 2.28.1 +- Upgraded `zstandard` to 0.21.0 #### Breaking changes - Removed support for Python 3.7 From 2df19768889387deeb9b6303c866967a92b626b0 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Fri, 23 Jun 2023 11:12:23 -0500 Subject: [PATCH 08/17] MOD: Update workflow for new release format --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 17b737c..1d5d1f0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,7 +27,7 @@ jobs: echo "RELEASE_NAME=$(poetry version)" >> $GITHUB_ENV echo "TAG_NAME=v$(poetry version -s)" >> $GITHUB_ENV echo "## Release notes" > NOTES.md - sed -n '4,/^$/p' CHANGELOG.md >> NOTES.md + sed -n '/^## /{n; :a; /^## /q; p; n; ba}' CHANGELOG.md >> NOTES.md - name: Release uses: softprops/action-gh-release@v1 From e9e9c11216c1144e96bad4c0a9bbcc57ed6c3aa8 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Tue, 27 Jun 2023 15:41:35 -0500 Subject: [PATCH 09/17] FIX: Fix missing map_symbols callback --- databento/live/client.py | 4 +++- tests/test_live_client.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/databento/live/client.py b/databento/live/client.py index bd8d98a..023e915 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -88,7 +88,9 @@ def __init__( self._dbn_queue: DBNQueue = DBNQueue(maxsize=DEFAULT_QUEUE_SIZE) self._metadata: SessionMetadata = SessionMetadata() self._symbology_map: dict[int, str | int] = {} - self._user_callbacks: dict[RecordCallback, ExceptionCallback | None] = {} + self._user_callbacks: dict[RecordCallback, ExceptionCallback | None] = { + self._map_symbol: None, + } self._user_streams: dict[IO[bytes], ExceptionCallback | None] = {} def factory() -> _SessionProtocol: diff --git a/tests/test_live_client.py b/tests/test_live_client.py index e275233..0c6e71d 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -179,6 +179,7 @@ def test_live_creation( port=mock_live_server.port, ) + # Subscribe to connect live_client.subscribe( dataset=dataset, schema=Schema.MBO, @@ -189,6 +190,7 @@ def test_live_creation( assert live_client._key == test_api_key assert live_client.dataset == dataset assert live_client.is_connected() is True + assert live_client._map_symbol in live_client._user_callbacks def test_live_connect_auth( @@ -512,7 +514,6 @@ def test_live_add_stream( live_client.add_stream(stream) assert stream in live_client._user_streams assert live_client._user_streams[stream] is None - assert live_client._user_callbacks == {} def test_live_add_stream_invalid( From 76f6749f038e0699f116193084a882626cd6c598 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Fri, 23 Jun 2023 10:24:38 -0500 Subject: [PATCH 10/17] MOD: Upgrade client to databento-dbn 0.7.1 --- CHANGELOG.md | 8 ++- databento/__init__.py | 8 +-- databento/common/data.py | 30 +++++++- databento/common/dbnstore.py | 9 +-- databento/common/enums.py | 98 +------------------------- databento/common/parsing.py | 2 +- databento/common/validation.py | 6 +- databento/historical/api/batch.py | 8 +-- databento/historical/api/metadata.py | 14 ++-- databento/historical/api/symbology.py | 2 +- databento/historical/api/timeseries.py | 8 +-- databento/live/client.py | 4 +- databento/live/gateway.py | 16 ++--- databento/live/protocol.py | 4 +- databento/live/session.py | 4 +- pyproject.toml | 2 +- tests/conftest.py | 7 +- tests/data/generator.py | 2 +- tests/mock_live_server.py | 2 +- tests/test_bento_data_source.py | 6 +- tests/test_common_enums.py | 20 ++++-- tests/test_common_parsing.py | 20 ++++-- tests/test_common_validation.py | 4 +- tests/test_historical_bento.py | 18 ++--- tests/test_historical_client.py | 2 +- tests/test_historical_metadata.py | 2 +- tests/test_historical_timeseries.py | 2 +- tests/test_live_client.py | 41 +++++++---- tests/test_live_gateway_messages.py | 6 +- tests/test_live_protocol.py | 12 ++-- 30 files changed, 169 insertions(+), 198 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68407aa..b2e84ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,14 @@ ## 0.15.0 - TBD -### Enhancements +#### Enhancements - Added `symbology_map` property to `Live` client - Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback -- Changed `Live.add_callback` and `Live.add_stream` `func` parameter to `record_callback` +- Upgraded `databento-dbn` to 0.7.1 +- Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn` + +#### Breaking changes +- Renamed `func` parameter to `record_callback` for `Live.add_callback` and `Live.add_stream` ## 0.14.1 - 2023-06-16 diff --git a/databento/__init__.py b/databento/__init__.py index 371feea..1412fcc 100644 --- a/databento/__init__.py +++ b/databento/__init__.py @@ -1,6 +1,8 @@ import logging import warnings +from databento_dbn import Compression +from databento_dbn import Encoding from databento_dbn import ErrorMsg from databento_dbn import ImbalanceMsg from databento_dbn import InstrumentDefMsg @@ -9,25 +11,23 @@ from databento_dbn import MBP10Msg from databento_dbn import Metadata from databento_dbn import OHLCVMsg +from databento_dbn import Schema from databento_dbn import StatMsg +from databento_dbn import SType from databento_dbn import SymbolMappingMsg from databento_dbn import SystemMsg from databento_dbn import TradeMsg from databento.common import bentologging from databento.common.dbnstore import DBNStore -from databento.common.enums import Compression from databento.common.enums import Dataset from databento.common.enums import Delivery -from databento.common.enums import Encoding from databento.common.enums import FeedMode from databento.common.enums import HistoricalGateway from databento.common.enums import Packaging from databento.common.enums import RecordFlags from databento.common.enums import RollRule -from databento.common.enums import Schema from databento.common.enums import SplitDuration -from databento.common.enums import SType from databento.common.enums import SymbologyResolution from databento.common.error import BentoClientError from databento.common.error import BentoError diff --git a/databento/common/data.py b/databento/common/data.py index fc6d1d1..5344d18 100644 --- a/databento/common/data.py +++ b/databento/common/data.py @@ -1,9 +1,33 @@ from __future__ import annotations import numpy as np - -from databento.common.enums import Schema - +from databento_dbn import ImbalanceMsg +from databento_dbn import InstrumentDefMsg +from databento_dbn import MBOMsg +from databento_dbn import MBP1Msg +from databento_dbn import MBP10Msg +from databento_dbn import OHLCVMsg +from databento_dbn import Schema +from databento_dbn import StatMsg +from databento_dbn import TradeMsg + +from databento.live import DBNRecord + + +SCHEMA_STRUCT_MAP: dict[Schema, type[DBNRecord]] = { + Schema.DEFINITION: InstrumentDefMsg, + Schema.IMBALANCE: ImbalanceMsg, + Schema.MBO: MBOMsg, + Schema.MBP_1: MBP1Msg, + Schema.MBP_10: MBP10Msg, + Schema.OHLCV_1S: OHLCVMsg, + Schema.OHLCV_1M: OHLCVMsg, + Schema.OHLCV_1H: OHLCVMsg, + Schema.OHLCV_1D: OHLCVMsg, + Schema.STATISTICS: StatMsg, + Schema.TBBO: MBP1Msg, + Schema.TRADES: TradeMsg, +} ################################################################################ # DBN struct schema diff --git a/databento/common/dbnstore.py b/databento/common/dbnstore.py index 22444ec..a28248d 100644 --- a/databento/common/dbnstore.py +++ b/databento/common/dbnstore.py @@ -19,9 +19,12 @@ import numpy as np import pandas as pd import zstandard +from databento_dbn import Compression from databento_dbn import DBNDecoder from databento_dbn import ErrorMsg from databento_dbn import Metadata +from databento_dbn import Schema +from databento_dbn import SType from databento_dbn import SymbolMappingMsg from databento_dbn import SystemMsg @@ -30,10 +33,8 @@ from databento.common.data import DEFINITION_PRICE_COLUMNS from databento.common.data import DEFINITION_TYPE_MAX_MAP from databento.common.data import DERIV_SCHEMAS +from databento.common.data import SCHEMA_STRUCT_MAP from databento.common.data import STRUCT_MAP -from databento.common.enums import Compression -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.error import BentoError from databento.common.symbology import InstrumentIdMappingInterval from databento.common.validation import validate_maybe_enum @@ -1059,7 +1060,7 @@ def to_ndarray( schema = self.schema schema_records = filter( - lambda r: isinstance(r, schema.get_record_type()), # type: ignore + lambda r: isinstance(r, SCHEMA_STRUCT_MAP[schema]), # type: ignore self, ) diff --git a/databento/common/enums.py b/databento/common/enums.py index 3f57e5c..23e3110 100644 --- a/databento/common/enums.py +++ b/databento/common/enums.py @@ -6,17 +6,6 @@ from enum import unique from typing import Callable, TypeVar -from databento_dbn import ImbalanceMsg -from databento_dbn import InstrumentDefMsg -from databento_dbn import MBOMsg -from databento_dbn import MBP1Msg -from databento_dbn import MBP10Msg -from databento_dbn import OHLCVMsg -from databento_dbn import StatMsg -from databento_dbn import TradeMsg - -from databento.live import DBNRecord - M = TypeVar("M", bound=Enum) @@ -71,7 +60,7 @@ def _cast_str(value: object) -> str: def coerced_new(enum: type[M], value: object) -> M: if value is None: - raise TypeError( + raise ValueError( f"value `{value}` is not coercible to {enum_type.__name__}.", ) try: @@ -111,7 +100,6 @@ def __str__(self) -> str: return getattr(self, "name").lower() return getattr(self, "value") - @unique @coercible class HistoricalGateway(StringyMixin, str, Enum): @@ -145,77 +133,6 @@ class Dataset(StringyMixin, str, Enum): XNAS_ITCH = "XNAS.ITCH" -@unique -@coercible -class Schema(StringyMixin, str, Enum): - """ - Represents a data record schema. - """ - - MBO = "mbo" - MBP_1 = "mbp-1" - MBP_10 = "mbp-10" - TBBO = "tbbo" - TRADES = "trades" - OHLCV_1S = "ohlcv-1s" - OHLCV_1M = "ohlcv-1m" - OHLCV_1H = "ohlcv-1h" - OHLCV_1D = "ohlcv-1d" - DEFINITION = "definition" - IMBALANCE = "imbalance" - STATISTICS = "statistics" - - def get_record_type(self) -> type[DBNRecord]: - if self == Schema.MBO: - return MBOMsg - if self == Schema.MBP_1: - return MBP1Msg - if self == Schema.MBP_10: - return MBP10Msg - if self == Schema.TBBO: - return MBP1Msg - if self == Schema.TRADES: - return TradeMsg - if self == Schema.OHLCV_1S: - return OHLCVMsg - if self == Schema.OHLCV_1M: - return OHLCVMsg - if self == Schema.OHLCV_1H: - return OHLCVMsg - if self == Schema.OHLCV_1D: - return OHLCVMsg - if self == Schema.DEFINITION: - return InstrumentDefMsg - if self == Schema.IMBALANCE: - return ImbalanceMsg - if self == Schema.STATISTICS: - return StatMsg - raise NotImplementedError(f"No message type for {self}") - - -@unique -@coercible -class Encoding(StringyMixin, str, Enum): - """ - Represents a data output encoding. - """ - - DBN = "dbn" - CSV = "csv" - JSON = "json" - - -@unique -@coercible -class Compression(StringyMixin, str, Enum): - """ - Represents a data compression format (if any). - """ - - NONE = "none" - ZSTD = "zstd" - - @unique @coercible class SplitDuration(StringyMixin, str, Enum): @@ -253,19 +170,6 @@ class Delivery(StringyMixin, str, Enum): DISK = "disk" -@unique -@coercible -class SType(StringyMixin, str, Enum): - """ - Represents a symbology type. - """ - - INSTRUMENT_ID = "instrument_id" - RAW_SYMBOL = "raw_symbol" - PARENT = "parent" - CONTINUOUS = "continuous" - - @unique @coercible class RollRule(StringyMixin, str, Enum): diff --git a/databento/common/parsing.py b/databento/common/parsing.py index a4008fb..df845f0 100644 --- a/databento/common/parsing.py +++ b/databento/common/parsing.py @@ -7,8 +7,8 @@ from numbers import Number import pandas as pd +from databento_dbn import SType -from databento.common.enums import SType from databento.common.symbology import ALL_SYMBOLS from databento.common.validation import validate_smart_symbol diff --git a/databento/common/validation.py b/databento/common/validation.py index b3073e8..3559c0b 100644 --- a/databento/common/validation.py +++ b/databento/common/validation.py @@ -69,7 +69,11 @@ def validate_enum( try: return enum(value) except ValueError as e: - valid = list(map(str, enum)) + if hasattr(enum, "variants"): + valid = list(map(str, enum.variants())) # type: ignore [attr-defined] + else: + valid = list(map(str, enum)) + raise ValueError( f"The `{param}` was not a valid value of {enum}, was '{value}'. " f"Use any of {valid}.", diff --git a/databento/historical/api/batch.py b/databento/historical/api/batch.py index fe4cc39..9acf23b 100644 --- a/databento/historical/api/batch.py +++ b/databento/historical/api/batch.py @@ -10,16 +10,16 @@ import aiohttp import pandas as pd import requests +from databento_dbn import Compression +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType from requests.auth import HTTPBasicAuth -from databento.common.enums import Compression from databento.common.enums import Dataset from databento.common.enums import Delivery -from databento.common.enums import Encoding from databento.common.enums import Packaging -from databento.common.enums import Schema from databento.common.enums import SplitDuration -from databento.common.enums import SType from databento.common.parsing import datetime_to_string from databento.common.parsing import optional_datetime_to_string from databento.common.parsing import optional_symbols_list_to_string diff --git a/databento/historical/api/metadata.py b/databento/historical/api/metadata.py index d1ef510..ad116d2 100644 --- a/databento/historical/api/metadata.py +++ b/databento/historical/api/metadata.py @@ -4,13 +4,13 @@ from typing import Any import pandas as pd +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType from requests import Response from databento.common.enums import Dataset -from databento.common.enums import Encoding from databento.common.enums import FeedMode -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.parsing import datetime_to_string from databento.common.parsing import optional_date_to_string from databento.common.parsing import optional_datetime_to_string @@ -145,7 +145,7 @@ def list_fields( A mapping of dataset to encoding to schema to field to data type. """ - params: list[tuple[str, str | None]] = [ + params: list[tuple[str, Dataset | Schema | Encoding | str | None]] = [ ("dataset", validate_semantic_string(dataset, "dataset")), ("schema", validate_maybe_enum(schema, Schema, "schema")), ("encoding", validate_maybe_enum(encoding, Encoding, "encoding")), @@ -153,7 +153,7 @@ def list_fields( response: Response = self._get( url=self._base_url + ".list_fields", - params=params, + params=params, # type: ignore [arg-type] basic_auth=True, ) return response.json() @@ -185,7 +185,7 @@ def list_unit_prices( Otherwise, return a map of feed mode to schema to unit price. """ - params: list[tuple[str, str | None]] = [ + params: list[tuple[str, Dataset | FeedMode | Schema | str | None]] = [ ("dataset", validate_semantic_string(dataset, "dataset")), ("mode", validate_maybe_enum(mode, FeedMode, "mode")), ("schema", validate_maybe_enum(schema, Schema, "schema")), @@ -193,7 +193,7 @@ def list_unit_prices( response: Response = self._get( url=self._base_url + ".list_unit_prices", - params=params, + params=params, # type: ignore [arg-type] basic_auth=True, ) return response.json() diff --git a/databento/historical/api/symbology.py b/databento/historical/api/symbology.py index bfb5a0d..859f3aa 100644 --- a/databento/historical/api/symbology.py +++ b/databento/historical/api/symbology.py @@ -3,10 +3,10 @@ from datetime import date from typing import Any +from databento_dbn import SType from requests import Response from databento.common.enums import Dataset -from databento.common.enums import SType from databento.common.parsing import datetime_to_date_string from databento.common.parsing import optional_date_to_string from databento.common.parsing import optional_symbols_list_to_string diff --git a/databento/historical/api/timeseries.py b/databento/historical/api/timeseries.py index e0738b9..2f080c7 100644 --- a/databento/historical/api/timeseries.py +++ b/databento/historical/api/timeseries.py @@ -4,13 +4,13 @@ from os import PathLike import pandas as pd +from databento_dbn import Compression +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType from databento.common.dbnstore import DBNStore -from databento.common.enums import Compression from databento.common.enums import Dataset -from databento.common.enums import Encoding -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.parsing import datetime_to_string from databento.common.parsing import optional_datetime_to_string from databento.common.parsing import optional_symbols_list_to_string diff --git a/databento/live/client.py b/databento/live/client.py index 023e915..4b95bd2 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -11,11 +11,11 @@ from typing import IO import databento_dbn +from databento_dbn import Schema +from databento_dbn import SType from databento.common.cram import BUCKET_ID_LENGTH from databento.common.enums import Dataset -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.error import BentoError from databento.common.parsing import optional_datetime_to_unix_nanoseconds from databento.common.parsing import optional_symbols_list_to_string diff --git a/databento/live/gateway.py b/databento/live/gateway.py index 9f29890..5f73989 100644 --- a/databento/live/gateway.py +++ b/databento/live/gateway.py @@ -4,12 +4,14 @@ import logging from functools import partial from io import BytesIO +from operator import attrgetter from typing import TypeVar +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType + from databento.common.enums import Dataset -from databento.common.enums import Encoding -from databento.common.enums import Schema -from databento.common.enums import SType logger = logging.getLogger(__name__) @@ -53,11 +55,9 @@ def parse(cls: type[T], line: str) -> T: ) from type_err def __str__(self) -> str: - tokens = "|".join( - f"{k}={str(v)}" - for k, v in dataclasses.asdict(self).items() - if v is not None - ) + fields = tuple(map(attrgetter("name"), dataclasses.fields(self))) + values = tuple(getattr(self, f) for f in fields) + tokens = "|".join(f"{k}={v}" for k, v in zip(fields, values) if v is not None) return f"{tokens}\n" def __bytes__(self) -> bytes: diff --git a/databento/live/protocol.py b/databento/live/protocol.py index 9aa48c8..c3a1ee3 100644 --- a/databento/live/protocol.py +++ b/databento/live/protocol.py @@ -7,11 +7,11 @@ from numbers import Number import databento_dbn +from databento_dbn import Schema +from databento_dbn import SType from databento.common import cram from databento.common.enums import Dataset -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.error import BentoError from databento.common.parsing import optional_datetime_to_unix_nanoseconds from databento.common.parsing import optional_symbols_list_to_string diff --git a/databento/live/session.py b/databento/live/session.py index c12203e..1e2725f 100644 --- a/databento/live/session.py +++ b/databento/live/session.py @@ -11,10 +11,10 @@ from typing import IO, Callable import databento_dbn +from databento_dbn import Schema +from databento_dbn import SType from databento.common.enums import Dataset -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.error import BentoError from databento.common.symbology import ALL_SYMBOLS from databento.live import AUTH_TIMEOUT_SECONDS diff --git a/pyproject.toml b/pyproject.toml index 435d2f5..752ba1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ repository = "https://github.com/databento/databento-python" [tool.poetry.dependencies] python = "^3.8" aiohttp = "^3.8.3" -databento-dbn = "0.6.1" +databento-dbn = "0.7.1" numpy= ">=1.23.5" pandas = ">=1.5.3" requests = ">=2.24.0" diff --git a/tests/conftest.py b/tests/conftest.py index 75dbfa2..521da44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ import pytest_asyncio from databento import historical from databento import live -from databento.common.enums import Schema +from databento_dbn import Schema from tests import TESTS_ROOT from tests.mock_live_server import MockLiveServer @@ -106,7 +106,10 @@ def fixture_test_data_path() -> Callable[[Schema], pathlib.Path]: """ def func(schema: Schema) -> pathlib.Path: - return pathlib.Path(TESTS_ROOT) / "data" / f"test_data.{schema}.dbn.zst" + path = pathlib.Path(TESTS_ROOT) / "data" / f"test_data.{schema}.dbn.zst" + if not path.exists(): + pytest.skip(f"no test data for schema: {schema}") + return path return func diff --git a/tests/data/generator.py b/tests/data/generator.py index 9fe1db1..a17bcbf 100644 --- a/tests/data/generator.py +++ b/tests/data/generator.py @@ -1,6 +1,6 @@ import databento as db from databento import DBNStore -from databento.common.enums import Schema +from databento_dbn import Schema if __name__ == "__main__": diff --git a/tests/mock_live_server.py b/tests/mock_live_server.py index 5fc5f84..7521e7d 100644 --- a/tests/mock_live_server.py +++ b/tests/mock_live_server.py @@ -17,7 +17,6 @@ import zstandard from databento.common import cram -from databento.common.enums import Schema from databento.live.gateway import AuthenticationRequest from databento.live.gateway import AuthenticationResponse from databento.live.gateway import ChallengeRequest @@ -26,6 +25,7 @@ from databento.live.gateway import SessionStart from databento.live.gateway import SubscriptionRequest from databento.live.gateway import parse_gateway_message +from databento_dbn import Schema LIVE_SERVER_VERSION: str = "1.0.0" diff --git a/tests/test_bento_data_source.py b/tests/test_bento_data_source.py index c1bbbb2..cd81778 100644 --- a/tests/test_bento_data_source.py +++ b/tests/test_bento_data_source.py @@ -4,10 +4,10 @@ import pytest from databento.common.dbnstore import FileDataSource from databento.common.dbnstore import MemoryDataSource -from databento.common.enums import Schema +from databento_dbn import Schema -@pytest.mark.parametrize("schema", [pytest.param(x) for x in Schema]) +@pytest.mark.parametrize("schema", [pytest.param(x) for x in Schema.variants()]) def test_memory_data_source( test_data: Callable[[Schema], bytes], schema: Schema, @@ -22,7 +22,7 @@ def test_memory_data_source( assert repr(data) == data_source.name -@pytest.mark.parametrize("schema", [pytest.param(x) for x in Schema]) +@pytest.mark.parametrize("schema", [pytest.param(x) for x in Schema.variants()]) def test_file_data_source( test_data_path: Callable[[Schema], pathlib.Path], schema: Schema, diff --git a/tests/test_common_enums.py b/tests/test_common_enums.py index 16069e5..b4f262d 100644 --- a/tests/test_common_enums.py +++ b/tests/test_common_enums.py @@ -8,20 +8,20 @@ from itertools import combinations import pytest -from databento.common.enums import Compression from databento.common.enums import Dataset from databento.common.enums import Delivery -from databento.common.enums import Encoding from databento.common.enums import FeedMode from databento.common.enums import HistoricalGateway from databento.common.enums import Packaging from databento.common.enums import RecordFlags from databento.common.enums import RollRule -from databento.common.enums import Schema from databento.common.enums import SplitDuration from databento.common.enums import StringyMixin -from databento.common.enums import SType from databento.common.enums import SymbologyResolution +from databento_dbn import Compression +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType DATABENTO_ENUMS = ( @@ -88,7 +88,12 @@ def test_enum_name_coercion(enum_type: type[Enum]) -> None: See: databento.common.enums.coercible """ - for enum in enum_type: + if enum_type in (Compression, Encoding, Schema, SType): + enum_it = iter(enum_type.variants()) # type: ignore [attr-defined] + else: + enum_it = iter(enum_type) + + for enum in enum_it: assert enum == enum_type(enum.name) assert enum == enum_type(enum.name.replace("_", "-")) assert enum == enum_type(enum.name.lower()) @@ -108,8 +113,11 @@ def test_enum_none_not_coercible(enum_type: type[Enum]) -> None: See: databento.common.enum.coercible """ - with pytest.raises(TypeError): + if enum_type == Compression: enum_type(None) + else: + with pytest.raises(ValueError): + enum_type(None) @pytest.mark.parametrize( diff --git a/tests/test_common_parsing.py b/tests/test_common_parsing.py index a95e371..90a5f69 100644 --- a/tests/test_common_parsing.py +++ b/tests/test_common_parsing.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd import pytest -from databento.common.enums import SType from databento.common.parsing import optional_date_to_string from databento.common.parsing import optional_datetime_to_string from databento.common.parsing import optional_datetime_to_unix_nanoseconds from databento.common.parsing import optional_symbols_list_to_string from databento.common.parsing import optional_values_list_to_string +from databento_dbn import SType # Set the type to `Any` to disable mypy type checking. Used to test if functions @@ -104,7 +104,7 @@ def test_optional_symbols_list_to_string_given_valid_inputs_returns_expected( ], ) def test_optional_symbols_list_to_string_int( - symbols: list[Number] | Number | None, + symbols: list[Number] | Number | None, stype: SType, expected: str | type[Exception], ) -> None: @@ -127,16 +127,24 @@ def test_optional_symbols_list_to_string_int( pytest.param(np.byte(120), SType.INSTRUMENT_ID, "120"), pytest.param(np.short(32_000), SType.INSTRUMENT_ID, "32000"), pytest.param( - [np.intc(12345), np.intc(67890)], SType.INSTRUMENT_ID, "12345,67890", + [np.intc(12345), np.intc(67890)], + SType.INSTRUMENT_ID, + "12345,67890", ), pytest.param( - [np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, "12345,67890", + [np.int_(12345), np.longlong(67890)], + SType.INSTRUMENT_ID, + "12345,67890", ), pytest.param( - [np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, "12345,67890", + [np.int_(12345), np.longlong(67890)], + SType.INSTRUMENT_ID, + "12345,67890", ), pytest.param( - [np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, "12345,67890", + [np.int_(12345), np.longlong(67890)], + SType.INSTRUMENT_ID, + "12345,67890", ), ], ) diff --git a/tests/test_common_validation.py b/tests/test_common_validation.py index a7ed707..91228b5 100644 --- a/tests/test_common_validation.py +++ b/tests/test_common_validation.py @@ -4,13 +4,13 @@ from typing import Any import pytest -from databento.common.enums import Encoding from databento.common.validation import validate_enum from databento.common.validation import validate_gateway from databento.common.validation import validate_maybe_enum from databento.common.validation import validate_path from databento.common.validation import validate_semantic_string from databento.common.validation import validate_smart_symbol +from databento_dbn import Encoding @pytest.mark.parametrize( @@ -37,7 +37,7 @@ def test_validate_enum_given_wrong_types_raises_type_error( enum: type[Enum], ) -> None: # Arrange, Act, Assert - with pytest.raises(TypeError): + with pytest.raises(ValueError): validate_enum(value, enum, "param") def test_validate_enum_given_invalid_value_raises_value_error() -> None: diff --git a/tests/test_historical_bento.py b/tests/test_historical_bento.py index 0e82c89..edf0ffd 100644 --- a/tests/test_historical_bento.py +++ b/tests/test_historical_bento.py @@ -14,11 +14,11 @@ import zstandard from databento.common.data import DEFINITION_DROP_COLUMNS from databento.common.dbnstore import DBNStore -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.error import BentoError from databento.live import DBNRecord from databento_dbn import MBOMsg +from databento_dbn import Schema +from databento_dbn import SType def test_from_file_when_not_exists_raises_expected_exception() -> None: @@ -270,7 +270,7 @@ def test_replay_with_stub_data_record_passes_to_callback( "schema", [ s - for s in Schema + for s in Schema.variants() if s not in ( Schema.OHLCV_1H, @@ -298,7 +298,7 @@ def test_to_df_across_schemas_returns_identical_dimension_dfs( @pytest.mark.parametrize( "schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) def test_to_df_drop_columns( test_data: Callable[[Schema], bytes], @@ -451,7 +451,7 @@ def test_to_df_with_pretty_px_with_various_schemas_converts_prices_as_expected( @pytest.mark.parametrize( "expected_schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) def test_from_file_given_various_paths_returns_expected_metadata( test_data_path: Callable[[Schema], Path], @@ -746,7 +746,7 @@ def test_mbp_1_to_json_with_all_options_writes_expected_file_to_disk( @pytest.mark.parametrize( "schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) def test_dbnstore_repr( test_data: Callable[[Schema], bytes], @@ -839,7 +839,7 @@ def test_dbnstore_iterable_parallel( @pytest.mark.parametrize( "schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) def test_dbnstore_compression_equality( test_data: Callable[[Schema], bytes], @@ -932,7 +932,7 @@ def test_dbnstore_buffer_long( @pytest.mark.parametrize( "schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) def test_dbnstore_to_ndarray_with_schema( schema: Schema, @@ -979,7 +979,7 @@ def test_dbnstore_to_ndarray_with_schema_empty( @pytest.mark.parametrize( "schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) def test_dbnstore_to_df_with_schema( schema: Schema, diff --git a/tests/test_historical_client.py b/tests/test_historical_client.py index 257885d..ed146c0 100644 --- a/tests/test_historical_client.py +++ b/tests/test_historical_client.py @@ -10,7 +10,7 @@ from databento import DBNStore from databento import Historical from databento.common.enums import HistoricalGateway -from databento.common.enums import Schema +from databento_dbn import Schema def test_key_returns_expected() -> None: diff --git a/tests/test_historical_metadata.py b/tests/test_historical_metadata.py index d617265..41ed5be 100644 --- a/tests/test_historical_metadata.py +++ b/tests/test_historical_metadata.py @@ -7,8 +7,8 @@ import requests from databento.common.enums import Dataset from databento.common.enums import FeedMode -from databento.common.enums import Schema from databento.historical.client import Historical +from databento_dbn import Schema def test_list_publishers_sends_expected_request( diff --git a/tests/test_historical_timeseries.py b/tests/test_historical_timeseries.py index 7286447..7161397 100644 --- a/tests/test_historical_timeseries.py +++ b/tests/test_historical_timeseries.py @@ -6,9 +6,9 @@ import pytest import requests from databento import DBNStore -from databento.common.enums import Schema from databento.common.error import BentoServerError from databento.historical.client import Historical +from databento_dbn import Schema def test_get_range_given_invalid_schema_raises_error( diff --git a/tests/test_live_client.py b/tests/test_live_client.py index 0c6e71d..d90e37d 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -13,11 +13,9 @@ import pytest import zstandard from databento.common.cram import BUCKET_ID_LENGTH +from databento.common.data import SCHEMA_STRUCT_MAP from databento.common.dbnstore import DBNStore from databento.common.enums import Dataset -from databento.common.enums import Encoding -from databento.common.enums import Schema -from databento.common.enums import SType from databento.common.error import BentoError from databento.common.symbology import ALL_SYMBOLS from databento.live import DBNRecord @@ -25,6 +23,9 @@ from databento.live import gateway from databento.live import protocol from databento.live import session +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType from tests.mock_live_server import MockLiveServer @@ -309,11 +310,11 @@ def test_live_start_twice( @pytest.mark.parametrize( "schema", - [pytest.param(schema, id=str(schema)) for schema in Schema], + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], ) @pytest.mark.parametrize( "stype_in", - [pytest.param(stype, id=str(stype)) for stype in SType], + [pytest.param(stype, id=str(stype)) for stype in SType.variants()], ) @pytest.mark.parametrize( "symbols", @@ -585,7 +586,9 @@ async def test_live_async_iteration_backpressure( ) monkeypatch.setattr( - live_client._session._transport, "pause_reading", pause_mock := MagicMock(), + live_client._session._transport, + "pause_reading", + pause_mock := MagicMock(), ) live_client.start() @@ -624,7 +627,9 @@ async def test_live_async_iteration_dropped( ) monkeypatch.setattr( - live_client._session._transport, "pause_reading", pause_mock := MagicMock(), + live_client._session._transport, + "pause_reading", + pause_mock := MagicMock(), ) live_client.start() @@ -725,7 +730,7 @@ def callback(record: DBNRecord) -> None: @pytest.mark.parametrize( "schema", - (pytest.param(schema, id=str(schema)) for schema in Schema), + (pytest.param(schema, id=str(schema)) for schema in Schema.variants()), ) async def test_live_stream_to_dbn( tmp_path: pathlib.Path, @@ -763,7 +768,7 @@ async def test_live_stream_to_dbn( @pytest.mark.parametrize( "schema", - (pytest.param(schema, id=str(schema)) for schema in Schema), + (pytest.param(schema, id=str(schema)) for schema in Schema.variants()), ) @pytest.mark.parametrize( "buffer_size", @@ -889,7 +894,7 @@ async def test_live_terminate( @pytest.mark.parametrize( "schema", - (pytest.param(schema, id=str(schema)) for schema in Schema), + (pytest.param(schema, id=str(schema)) for schema in Schema.variants()), ) async def test_live_iteration_with_reconnect( live_client: client.Live, @@ -943,12 +948,12 @@ async def test_live_iteration_with_reconnect( records = list(my_iter) assert len(records) == 2 * len(list(dbn)) for record in records: - assert isinstance(record, schema.get_record_type()) + assert isinstance(record, SCHEMA_STRUCT_MAP[schema]) @pytest.mark.parametrize( "schema", - (pytest.param(schema, id=str(schema)) for schema in Schema), + (pytest.param(schema, id=str(schema)) for schema in Schema.variants()), ) async def test_live_callback_with_reconnect( live_client: client.Live, @@ -990,12 +995,12 @@ async def test_live_callback_with_reconnect( assert len(records) == 5 * len(list(dbn)) for record in records: - assert isinstance(record, schema.get_record_type()) + assert isinstance(record, SCHEMA_STRUCT_MAP[schema]) @pytest.mark.parametrize( "schema", - (pytest.param(schema, id=str(schema)) for schema in Schema), + (pytest.param(schema, id=str(schema)) for schema in Schema.variants()), ) async def test_live_stream_with_reconnect( tmp_path: pathlib.Path, @@ -1009,6 +1014,10 @@ async def test_live_stream_with_reconnect( That output stream should be readable. """ + # TODO: Remove when status schema is available + if schema == "status": + pytest.skip("no stub data for status schema") + output = tmp_path / "output.dbn" live_client.add_stream(output.open("wb", buffering=0)) @@ -1032,7 +1041,8 @@ async def test_live_stream_with_reconnect( records = list(data) for record in records: - assert isinstance(record, schema.get_record_type()) + assert isinstance(record, SCHEMA_STRUCT_MAP[schema]) + def test_live_connection_reconnect_cram_failure( mock_live_server: MockLiveServer, @@ -1069,6 +1079,7 @@ def test_live_connection_reconnect_cram_failure( schema=Schema.MBO, ) + async def test_live_callback_exception_handler( live_client: client.Live, ) -> None: diff --git a/tests/test_live_gateway_messages.py b/tests/test_live_gateway_messages.py index 86a955a..6529942 100644 --- a/tests/test_live_gateway_messages.py +++ b/tests/test_live_gateway_messages.py @@ -2,9 +2,6 @@ import pytest from databento.common.enums import Dataset -from databento.common.enums import Encoding -from databento.common.enums import Schema -from databento.common.enums import SType from databento.live.gateway import AuthenticationRequest from databento.live.gateway import AuthenticationResponse from databento.live.gateway import ChallengeRequest @@ -12,6 +9,9 @@ from databento.live.gateway import Greeting from databento.live.gateway import SessionStart from databento.live.gateway import SubscriptionRequest +from databento_dbn import Encoding +from databento_dbn import Schema +from databento_dbn import SType ALL_MESSAGES = ( diff --git a/tests/test_live_protocol.py b/tests/test_live_protocol.py index a9ac85c..78bed80 100644 --- a/tests/test_live_protocol.py +++ b/tests/test_live_protocol.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock import pytest -from databento.common.enums import Schema -from databento.common.enums import SType from databento.live.protocol import DatabentoLiveProtocol +from databento_dbn import Schema +from databento_dbn import SType from tests.mock_live_server import MockLiveServer @@ -42,8 +42,12 @@ async def test_protocol_connection_streaming( Test the low-level DatabentoLiveProtocol can be used to stream DBN records from the live subscription gateway. """ - monkeypatch.setattr(DatabentoLiveProtocol, "received_metadata", metadata_mock := MagicMock()) - monkeypatch.setattr(DatabentoLiveProtocol, "received_record", record_mock := MagicMock()) + monkeypatch.setattr( + DatabentoLiveProtocol, "received_metadata", metadata_mock := MagicMock(), + ) + monkeypatch.setattr( + DatabentoLiveProtocol, "received_record", record_mock := MagicMock(), + ) transport, protocol = await asyncio.get_event_loop().create_connection( protocol_factory=lambda: DatabentoLiveProtocol( From 8bb77b0e0ad1696561fd24c9dd519531d56ae784 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Thu, 29 Jun 2023 00:07:04 -0500 Subject: [PATCH 11/17] MOD: Live client to start session on iteration --- CHANGELOG.md | 1 + databento/live/client.py | 4 +- databento/live/session.py | 158 +++++++++++++++++++++----------------- tests/test_live_client.py | 21 ++--- 4 files changed, 98 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2e84ce..821d06c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ #### Enhancements - Added `symbology_map` property to `Live` client - Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback +- Changed `Live.__iter__()` and `Live.__aiter__()` to send the session start message if the session is connected but not started - Upgraded `databento-dbn` to 0.7.1 - Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn` diff --git a/databento/live/client.py b/databento/live/client.py index 4b95bd2..d5c577c 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -128,13 +128,15 @@ async def __anext__(self) -> DBNRecord: def __iter__(self) -> Live: logger.debug("starting iteration") self._dbn_queue._enabled.set() + if not self._session.is_started() and self.is_connected(): + self.start() return self def __next__(self) -> DBNRecord: if self._dbn_queue is None: raise ValueError("iteration has not started") - while not self._session.is_disconnected() or self._dbn_queue._qsize() > 0: + while not self._session.is_disconnected() or self._dbn_queue.qsize() > 0: try: record = self._dbn_queue.get(block=False) except queue.Empty: diff --git a/databento/live/session.py b/databento/live/session.py index 1e2725f..107ea9d 100644 --- a/databento/live/session.py +++ b/databento/live/session.py @@ -258,6 +258,7 @@ def __init__( port: int = DEFAULT_REMOTE_PORT, ts_out: bool = False, ) -> None: + self._lock = threading.RLock() self._loop = loop self._ts_out = ts_out self._protocol_factory = protocol_factory @@ -277,13 +278,14 @@ def is_authenticated(self) -> bool: bool """ - if self._protocol is None: - return False - try: - self._protocol.authenticated.result() - except (asyncio.InvalidStateError, asyncio.CancelledError, BentoError): - return False - return True + with self._lock: + if self._protocol is None: + return False + try: + self._protocol.authenticated.result() + except (asyncio.InvalidStateError, asyncio.CancelledError, BentoError): + return False + return True def is_disconnected(self) -> bool: """ @@ -294,9 +296,10 @@ def is_disconnected(self) -> bool: bool """ - if self._protocol is None: - return True - return self._protocol.disconnected.done() + with self._lock: + if self._protocol is None: + return True + return self._protocol.disconnected.done() def is_reading(self) -> bool: """ @@ -307,9 +310,10 @@ def is_reading(self) -> bool: bool """ - if self._transport is None: - return False - return self._transport.is_reading() + with self._lock: + if self._transport is None: + return False + return self._transport.is_reading() def is_started(self) -> bool: """ @@ -320,9 +324,10 @@ def is_started(self) -> bool: bool """ - if self._protocol is None: - return False - return self._protocol.started.is_set() + with self._lock: + if self._protocol is None: + return False + return self._protocol.started.is_set() @property def metadata(self) -> databento_dbn.Metadata | None: @@ -334,9 +339,10 @@ def metadata(self) -> databento_dbn.Metadata | None: databento_dbn.Metadata """ - if self._protocol is None: - return None - return self._protocol._metadata.data + with self._lock: + if self._protocol is None: + return None + return self._protocol._metadata.data def abort(self) -> None: """ @@ -347,20 +353,22 @@ def abort(self) -> None: Session.close """ - if self._transport is None: - return - self._transport.abort() - self._protocol = None + with self._lock: + if self._transport is None: + return + self._transport.abort() + self._protocol = None def close(self) -> None: """ Close the current connection. """ - if self._transport is None: - return - if self._transport.can_write_eof(): - self._loop.call_soon_threadsafe(self._transport.write_eof) - self._loop.call_soon_threadsafe(self._transport.close) + with self._lock: + if self._transport is None: + return + if self._transport.can_write_eof(): + self._loop.call_soon_threadsafe(self._transport.write_eof) + self._loop.call_soon_threadsafe(self._transport.close) def subscribe( self, @@ -389,27 +397,29 @@ def subscribe( within 24 hours. """ - if self._protocol is None: - self._connect( - dataset=dataset, - port=self._port, - loop=self._loop, - ) + with self._lock: + if self._protocol is None: + self._connect( + dataset=dataset, + port=self._port, + loop=self._loop, + ) - self._protocol.subscribe( - schema=schema, - symbols=symbols, - stype_in=stype_in, - start=start, - ) + self._protocol.subscribe( + schema=schema, + symbols=symbols, + stype_in=stype_in, + start=start, + ) def resume_reading(self) -> None: """ Resume reading from the connection. """ - if self._transport is None: - return - self._loop.call_soon_threadsafe(self._transport.resume_reading) + with self._lock: + if self._transport is None: + return + self._loop.call_soon_threadsafe(self._transport.resume_reading) def start(self) -> None: """ @@ -421,9 +431,10 @@ def start(self) -> None: If there is no connection. """ - if self._protocol is None: - raise ValueError("session is not connected") - self._protocol.start() + with self._lock: + if self._protocol is None: + raise ValueError("session is not connected") + self._protocol.start() async def wait_for_close(self) -> None: """ @@ -433,14 +444,21 @@ async def wait_for_close(self) -> None: if self._protocol is None: return + await self._protocol.authenticated await self._protocol.disconnected - disconnect_exc = self._protocol.disconnected.exception() - await self._protocol.wait_for_processing() - self._protocol = self._transport = None - if disconnect_exc is not None: - raise BentoError(disconnect_exc) + try: + self._protocol.authenticated.result() + except Exception as exc: + raise BentoError(exc) + + try: + self._protocol.disconnected.result() + except Exception as exc: + raise BentoError(exc) + + self._protocol = self._transport = None def _connect( self, @@ -448,29 +466,30 @@ def _connect( port: int, loop: asyncio.AbstractEventLoop, ) -> None: - if self._user_gateway is None: - subdomain = dataset.lower().replace(".", "-") - gateway = f"{subdomain}.lsg.databento.com" - logger.debug("using default gateway for dataset %s", dataset) - else: - gateway = self._user_gateway - logger.debug("using user specified gateway: %s", gateway) + with self._lock: + if not self.is_disconnected(): + return + if self._user_gateway is None: + subdomain = dataset.lower().replace(".", "-") + gateway = f"{subdomain}.lsg.databento.com" + logger.debug("using default gateway for dataset %s", dataset) + else: + gateway = self._user_gateway + logger.debug("using user specified gateway: %s", gateway) - asyncio.run_coroutine_threadsafe( - coro=self._connect_task( - gateway=gateway, - port=port, - ), - loop=loop, - ).result() + self._transport, self._protocol = asyncio.run_coroutine_threadsafe( + coro=self._connect_task( + gateway=gateway, + port=port, + ), + loop=loop, + ).result() async def _connect_task( self, gateway: str, port: int, - ) -> None: - if not self.is_disconnected(): - return + ) -> tuple[asyncio.Transport, _SessionProtocol]: logger.info("connecting to remote gateway") try: transport, protocol = await asyncio.wait_for( @@ -514,5 +533,4 @@ async def _connect_task( "authentication with remote gateway completed", ) - self._transport = transport - self._protocol = protocol + return transport, protocol diff --git a/tests/test_live_client.py b/tests/test_live_client.py index d90e37d..cc6b310 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -547,8 +547,6 @@ async def test_live_async_iteration( symbols="TEST", ) - live_client.start() - records: list[DBNRecord] = [] async for record in live_client: records.append(record) @@ -591,13 +589,12 @@ async def test_live_async_iteration_backpressure( pause_mock := MagicMock(), ) - live_client.start() - it = live_client.__iter__() + live_it = iter(live_client) await live_client.wait_for_close() - assert pause_mock.called + pause_mock.assert_called() - records = list(it) + records: list[DBNRecord] = list(live_it) assert len(records) == 4 assert live_client._dbn_queue.empty() @@ -632,13 +629,12 @@ async def test_live_async_iteration_dropped( pause_mock := MagicMock(), ) - live_client.start() - it = live_client.__iter__() + live_it = iter(live_client) await live_client.wait_for_close() - assert pause_mock.called + pause_mock.assert_called() - records = list(it) + records = list(live_it) assert len(records) == 1 assert live_client._dbn_queue.empty() @@ -658,8 +654,6 @@ async def test_live_async_iteration_stop( symbols="TEST", ) - live_client.start() - records = [] async for record in live_client: records.append(record) @@ -683,8 +677,6 @@ def test_live_sync_iteration( symbols="TEST", ) - live_client.start() - records = [] for record in live_client: records.append(record) @@ -918,7 +910,6 @@ async def test_live_iteration_with_reconnect( assert live_client.is_connected() assert live_client.dataset == Dataset.GLBX_MDP3 - live_client.start() my_iter = iter(live_client) await live_client.wait_for_close() From 6d67bc0764820101c0bed54f60c43aaab416e08f Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Wed, 28 Jun 2023 21:42:46 -0500 Subject: [PATCH 12/17] FIX: Fix exception chaining --- CHANGELOG.md | 1 + databento/common/dbnstore.py | 3 +-- databento/common/enums.py | 10 +++++----- databento/common/validation.py | 12 ++++++------ databento/live/client.py | 8 ++++---- databento/live/gateway.py | 4 ++-- databento/live/session.py | 14 +++++++------- tests/test_release.py | 8 ++++---- 8 files changed, 30 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 821d06c..455c7bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Changed `Live.__iter__()` and `Live.__aiter__()` to send the session start message if the session is connected but not started - Upgraded `databento-dbn` to 0.7.1 - Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn` +- Removed exception chaining from exceptions emitted by the library #### Breaking changes - Renamed `func` parameter to `record_callback` for `Live.add_callback` and `Live.add_stream` diff --git a/databento/common/dbnstore.py b/databento/common/dbnstore.py index a28248d..5dfd0cd 100644 --- a/databento/common/dbnstore.py +++ b/databento/common/dbnstore.py @@ -789,10 +789,9 @@ def replay(self, callback: Callable[[Any], None]) -> None: for record in self: try: callback(record) - except Exception as exc: + except Exception: logger.exception( "exception while replaying to user callback", - exc_info=exc, ) raise diff --git a/databento/common/enums.py b/databento/common/enums.py index 23e3110..b69f119 100644 --- a/databento/common/enums.py +++ b/databento/common/enums.py @@ -65,17 +65,17 @@ def coerced_new(enum: type[M], value: object) -> M: ) try: return _new(enum, coerce_fn(value)) - except ValueError as ve: + except ValueError: name_to_try = str(value).replace(".", "_").replace("-", "_").upper() named = enum._member_map_.get(name_to_try) if named is not None: return named - enum_values = tuple(value for value in enum._value2member_map_) + enum_values = list(value for value in enum._value2member_map_) raise ValueError( - f"value `{value}` is not a member of {enum_type.__name__}. " - f"use one of {enum_values}.", - ) from ve + f"The `{value}` was not a valid value of {enum_type.__name__}" + f", was '{value}'. Use any of {enum_values}.", + ) from None setattr(enum_type, "__new__", coerced_new) diff --git a/databento/common/validation.py b/databento/common/validation.py index 3559c0b..3bfe9ad 100644 --- a/databento/common/validation.py +++ b/databento/common/validation.py @@ -30,11 +30,11 @@ def validate_path(value: PathLike[str] | str, param: str) -> Path: """ try: return Path(value) - except TypeError as e: + except TypeError: raise TypeError( f"The `{param}` was not a valid path type. " "Use any of [str, bytes, os.PathLike].", - ) from e + ) from None def validate_enum( @@ -68,16 +68,16 @@ def validate_enum( """ try: return enum(value) - except ValueError as e: + except ValueError: if hasattr(enum, "variants"): valid = list(map(str, enum.variants())) # type: ignore [attr-defined] else: valid = list(map(str, enum)) raise ValueError( - f"The `{param}` was not a valid value of {enum}, was '{value}'. " - f"Use any of {valid}.", - ) from e + f"The `{param}` was not a valid value of {enum.__name__}" + f", was '{value}'. Use any of {valid}.", + ) from None def validate_maybe_enum( diff --git a/databento/live/client.py b/databento/live/client.py index d5c577c..8c2302d 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -527,9 +527,9 @@ def block_for_close( self.terminate() if isinstance(exc, KeyboardInterrupt): raise - except Exception as exc: + except Exception: logger.exception("exception encountered blocking for close") - raise BentoError("connection lost") from exc + raise BentoError("connection lost") from None async def wait_for_close( self, @@ -572,9 +572,9 @@ async def wait_for_close( self.terminate() if isinstance(exc, KeyboardInterrupt): raise - except Exception as exc: + except Exception: logger.exception("exception encountered waiting for close") - raise BentoError("connection lost") from exc + raise BentoError("connection lost") from None async def _shutdown(self) -> None: """ diff --git a/databento/live/gateway.py b/databento/live/gateway.py index 5f73989..0f173fe 100644 --- a/databento/live/gateway.py +++ b/databento/live/gateway.py @@ -49,10 +49,10 @@ def parse(cls: type[T], line: str) -> T: try: return cls(**data_dict) - except TypeError as type_err: + except TypeError: raise ValueError( f"`{line.strip()} is not a parsible {cls.__name__}", - ) from type_err + ) from None def __str__(self) -> str: fields = tuple(map(attrgetter("name"), dataclasses.fields(self))) diff --git a/databento/live/session.py b/databento/live/session.py index 107ea9d..5f35d07 100644 --- a/databento/live/session.py +++ b/databento/live/session.py @@ -500,15 +500,15 @@ async def _connect_task( ), timeout=CONNECT_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError as exc: + except asyncio.TimeoutError: raise BentoError( f"Connection to {gateway}:{port} timed out after " f"{CONNECT_TIMEOUT_SECONDS} second(s).", - ) from exc + ) from None except OSError as exc: raise BentoError( - f"Connection to {gateway}:{port} failed.", - ) from exc + f"Connection to {gateway}:{port} failed: {exc}", + ) from None logger.debug( "connected to %s:%d", @@ -521,13 +521,13 @@ async def _connect_task( protocol.authenticated, timeout=AUTH_TIMEOUT_SECONDS, ) - except asyncio.TimeoutError as exc: + except asyncio.TimeoutError: raise BentoError( f"Authentication with {gateway}:{port} timed out after " f"{AUTH_TIMEOUT_SECONDS} second(s).", - ) from exc + ) from None except ValueError as exc: - raise BentoError(f"User authentication failed: {str(exc)}") from exc + raise BentoError(f"User authentication failed: {str(exc)}") from None logger.info( "authentication with remote gateway completed", diff --git a/tests/test_release.py b/tests/test_release.py index 0ac0ad8..93343dc 100644 --- a/tests/test_release.py +++ b/tests/test_release.py @@ -45,16 +45,16 @@ def test_release_changelog(changelog: str) -> None: try: versions = list(map(operator.itemgetter(0), releases)) version_tuples = [tuple(map(int, v.split("."))) for v in versions] - except Exception as exc: + except Exception: # This could happen if we have an irregular version string. - raise AssertionError("Failed to parse version from CHANGELOG.md") from exc + raise AssertionError("Failed to parse version from CHANGELOG.md") try: date_strings = list(map(operator.itemgetter(1), releases)) dates = list(map(date.fromisoformat, date_strings)) - except Exception as exc: + except Exception: # This could happen if we have TBD as the release date. - raise AssertionError("Failed to parse release date from CHANGELOG.md") from exc + raise AssertionError("Failed to parse release date from CHANGELOG.md") # Ensure latest version matches `__version__` assert databento.__version__ == versions[0] From fa03a623d2cc6fd53b9a16e0c87e3304d6922cbd Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Fri, 30 Jun 2023 12:22:38 -0700 Subject: [PATCH 13/17] MOD: Batch large symbol subscriptions --- CHANGELOG.md | 5 ++ databento/common/parsing.py | 73 ++++++++------- databento/historical/api/batch.py | 6 +- databento/historical/api/metadata.py | 14 +-- databento/historical/api/symbology.py | 6 +- databento/historical/api/timeseries.py | 10 +-- databento/live/client.py | 2 - databento/live/gateway.py | 1 - databento/live/protocol.py | 54 ++++++++++-- tests/test_common_parsing.py | 117 +++++++++++++------------ tests/test_live_client.py | 34 +++++++ 11 files changed, 200 insertions(+), 122 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 455c7bf..88fd0a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,14 +4,19 @@ #### Enhancements - Added `symbology_map` property to `Live` client +- Added `optional_symbols_list_to_list` parsing function - Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback - Changed `Live.__iter__()` and `Live.__aiter__()` to send the session start message if the session is connected but not started - Upgraded `databento-dbn` to 0.7.1 - Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn` - Removed exception chaining from exceptions emitted by the library +#### Bug fixes +- Fixed issue where a large unreadable symbol subscription message could be sent + #### Breaking changes - Renamed `func` parameter to `record_callback` for `Live.add_callback` and `Live.add_stream` +- Removed `optional_symbols_list_to_string` parsing function ## 0.14.1 - 2023-06-16 diff --git a/databento/common/parsing.py b/databento/common/parsing.py index df845f0..e99c291 100644 --- a/databento/common/parsing.py +++ b/databento/common/parsing.py @@ -58,12 +58,13 @@ def optional_values_list_to_string( @singledispatch -def optional_symbols_list_to_string( +def optional_symbols_list_to_list( symbols: Iterable[str] | Iterable[Number] | str | Number | None, stype_in: SType, -) -> str: +) -> list[str]: """ - Concatenate a symbols string or iterable of symbol strings (if not None). + Create a list from a symbols string or iterable of symbol strings (if not + None). Parameters ---------- @@ -74,11 +75,11 @@ def optional_symbols_list_to_string( Returns ------- - str + list[str] Notes ----- - If None is given, ALL_SYMBOLS is returned. + If None is given, [ALL_SYMBOLS] is returned. """ raise TypeError( @@ -87,48 +88,48 @@ def optional_symbols_list_to_string( ) -@optional_symbols_list_to_string.register -def _(_: None, __: SType) -> str: +@optional_symbols_list_to_list.register +def _(_: None, __: SType) -> list[str]: """ - Dispatch method for optional_symbols_list_to_string. Handles None which - defaults to ALL_SYMBOLS. + Dispatch method for optional_symbols_list_to_list. Handles None which + defaults to [ALL_SYMBOLS]. See Also -------- - optional_symbols_list_to_string + optional_symbols_list_to_list """ - return ALL_SYMBOLS + return [ALL_SYMBOLS] -@optional_symbols_list_to_string.register -def _(symbols: Number, stype_in: SType) -> str: +@optional_symbols_list_to_list.register +def _(symbols: Number, stype_in: SType) -> list[str]: """ - Dispatch method for optional_symbols_list_to_string. Handles numerical - types, alerting when an integer is given for STypes that expect strings. + Dispatch method for optional_symbols_list_to_list. Handles numerical types, + alerting when an integer is given for STypes that expect strings. See Also -------- - optional_symbols_list_to_string + optional_symbols_list_to_list """ if stype_in == SType.INSTRUMENT_ID: - return str(symbols) + return [str(symbols)] raise ValueError( f"value `{symbols}` is not a valid symbol for stype {stype_in}; " "did you mean to use `instrument_id`?", ) -@optional_symbols_list_to_string.register -def _(symbols: str, stype_in: SType) -> str: +@optional_symbols_list_to_list.register +def _(symbols: str, stype_in: SType) -> list[str]: """ - Dispatch method for optional_symbols_list_to_string. Handles str, splitting + Dispatch method for optional_symbols_list_to_list. Handles str, splitting on commas and validating smart symbology. See Also -------- - optional_symbols_list_to_string + optional_symbols_list_to_list """ if not symbols: @@ -137,35 +138,33 @@ def _(symbols: str, stype_in: SType) -> str: "an empty string is not allowed", ) - if "," in symbols: - symbol_to_string = partial( - optional_symbols_list_to_string, - stype_in=stype_in, - ) - symbol_list = symbols.strip().strip(",").split(",") - return ",".join(map(symbol_to_string, symbol_list)) + symbol_list = symbols.strip().strip(",").split(",") if stype_in in (SType.PARENT, SType.CONTINUOUS): - return validate_smart_symbol(symbols) - return symbols.strip().upper() + return list(map(str.strip, map(validate_smart_symbol, symbol_list))) + + return list(map(str.upper, map(str.strip, symbol_list))) -@optional_symbols_list_to_string.register(cls=Iterable) -def _(symbols: Iterable[str] | Iterable[int], stype_in: SType) -> str: +@optional_symbols_list_to_list.register(cls=Iterable) +def _(symbols: Iterable[str] | Iterable[int], stype_in: SType) -> list[str]: """ - Dispatch method for optional_symbols_list_to_string. Handles Iterables by + Dispatch method for optional_symbols_list_to_list. Handles Iterables by dispatching the individual members. See Also -------- - optional_symbols_list_to_string + optional_symbols_list_to_list """ - symbol_to_string = partial( - optional_symbols_list_to_string, + symbol_to_list = partial( + optional_symbols_list_to_list, stype_in=stype_in, ) - return ",".join(map(symbol_to_string, symbols)) + aggregated: list[str] = [] + for sym in map(symbol_to_list, symbols): + aggregated.extend(sym) + return aggregated def optional_date_to_string(value: date | str | None) -> str | None: diff --git a/databento/historical/api/batch.py b/databento/historical/api/batch.py index 9acf23b..ec410ff 100644 --- a/databento/historical/api/batch.py +++ b/databento/historical/api/batch.py @@ -22,7 +22,7 @@ from databento.common.enums import SplitDuration from databento.common.parsing import datetime_to_string from databento.common.parsing import optional_datetime_to_string -from databento.common.parsing import optional_symbols_list_to_string +from databento.common.parsing import optional_symbols_list_to_list from databento.common.parsing import optional_values_list_to_string from databento.common.validation import validate_enum from databento.common.validation import validate_path @@ -118,12 +118,12 @@ def submit_job( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) data: dict[str, object | None] = { "dataset": validate_semantic_string(dataset, "dataset"), "start": datetime_to_string(start), "end": optional_datetime_to_string(end), - "symbols": str(symbols_list), + "symbols": ",".join(symbols_list), "schema": str(validate_enum(schema, Schema, "schema")), "stype_in": str(stype_in_valid), "stype_out": str(validate_enum(stype_out, SType, "stype_out")), diff --git a/databento/historical/api/metadata.py b/databento/historical/api/metadata.py index ad116d2..609741c 100644 --- a/databento/historical/api/metadata.py +++ b/databento/historical/api/metadata.py @@ -14,7 +14,7 @@ from databento.common.parsing import datetime_to_string from databento.common.parsing import optional_date_to_string from databento.common.parsing import optional_datetime_to_string -from databento.common.parsing import optional_symbols_list_to_string +from databento.common.parsing import optional_symbols_list_to_list from databento.common.validation import validate_enum from databento.common.validation import validate_maybe_enum from databento.common.validation import validate_semantic_string @@ -318,10 +318,10 @@ def get_record_count( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) params: list[tuple[str, str | None]] = [ ("dataset", validate_semantic_string(dataset, "dataset")), - ("symbols", symbols_list), + ("symbols", ",".join(symbols_list)), ("schema", str(validate_enum(schema, Schema, "schema"))), ("start", optional_datetime_to_string(start)), ("end", optional_datetime_to_string(end)), @@ -387,12 +387,12 @@ def get_billable_size( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) params: list[tuple[str, str | None]] = [ ("dataset", validate_semantic_string(dataset, "dataset")), ("start", datetime_to_string(start)), ("end", optional_datetime_to_string(end)), - ("symbols", symbols_list), + ("symbols", ",".join(symbols_list)), ("schema", str(validate_enum(schema, Schema, "schema"))), ("stype_in", str(stype_in_valid)), ("stype_out", str(SType.INSTRUMENT_ID)), @@ -459,12 +459,12 @@ def get_cost( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) params: list[tuple[str, str | None]] = [ ("dataset", validate_semantic_string(dataset, "dataset")), ("start", datetime_to_string(start)), ("end", optional_datetime_to_string(end)), - ("symbols", symbols_list), + ("symbols", ",".join(symbols_list)), ("schema", str(validate_enum(schema, Schema, "schema"))), ("stype_in", str(stype_in_valid)), ("stype_out", str(SType.INSTRUMENT_ID)), diff --git a/databento/historical/api/symbology.py b/databento/historical/api/symbology.py index 859f3aa..80dcb5f 100644 --- a/databento/historical/api/symbology.py +++ b/databento/historical/api/symbology.py @@ -9,7 +9,7 @@ from databento.common.enums import Dataset from databento.common.parsing import datetime_to_date_string from databento.common.parsing import optional_date_to_string -from databento.common.parsing import optional_symbols_list_to_string +from databento.common.parsing import optional_symbols_list_to_list from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string from databento.historical.api import API_VERSION @@ -65,10 +65,10 @@ def resolve( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) data: dict[str, object | None] = { "dataset": validate_semantic_string(dataset, "dataset"), - "symbols": symbols_list, + "symbols": ",".join(symbols_list), "stype_in": str(stype_in_valid), "stype_out": str(validate_enum(stype_out, SType, "stype_out")), "start_date": datetime_to_date_string(start_date), diff --git a/databento/historical/api/timeseries.py b/databento/historical/api/timeseries.py index 2f080c7..5fdd3c4 100644 --- a/databento/historical/api/timeseries.py +++ b/databento/historical/api/timeseries.py @@ -13,7 +13,7 @@ from databento.common.enums import Dataset from databento.common.parsing import datetime_to_string from databento.common.parsing import optional_datetime_to_string -from databento.common.parsing import optional_symbols_list_to_string +from databento.common.parsing import optional_symbols_list_to_list from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string from databento.historical.api import API_VERSION @@ -95,7 +95,7 @@ def get_range( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) schema_valid = validate_enum(schema, Schema, "schema") start_valid = datetime_to_string(start) end_valid = optional_datetime_to_string(end) @@ -103,7 +103,7 @@ def get_range( "dataset": validate_semantic_string(dataset, "dataset"), "start": start_valid, "end": end_valid, - "symbols": symbols_list, + "symbols": ",".join(symbols_list), "schema": str(schema_valid), "stype_in": str(stype_in_valid), "stype_out": str(validate_enum(stype_out, SType, "stype_out")), @@ -189,7 +189,7 @@ async def get_range_async( """ stype_in_valid = validate_enum(stype_in, SType, "stype_in") - symbols_list = optional_symbols_list_to_string(symbols, stype_in_valid) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) schema_valid = validate_enum(schema, Schema, "schema") start_valid = datetime_to_string(start) end_valid = optional_datetime_to_string(end) @@ -197,7 +197,7 @@ async def get_range_async( "dataset": validate_semantic_string(dataset, "dataset"), "start": start_valid, "end": end_valid, - "symbols": symbols_list, + "symbols": ",".join(symbols_list), "schema": str(schema_valid), "stype_in": str(stype_in_valid), "stype_out": str(validate_enum(stype_out, SType, "stype_out")), diff --git a/databento/live/client.py b/databento/live/client.py index 8c2302d..73f681e 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -18,7 +18,6 @@ from databento.common.enums import Dataset from databento.common.error import BentoError from databento.common.parsing import optional_datetime_to_unix_nanoseconds -from databento.common.parsing import optional_symbols_list_to_string from databento.common.symbology import ALL_SYMBOLS from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string @@ -453,7 +452,6 @@ def subscribe( dataset = validate_semantic_string(dataset, "dataset") schema = validate_enum(schema, Schema, "schema") stype_in = validate_enum(stype_in, SType, "stype_in") - symbols = optional_symbols_list_to_string(symbols, stype_in) start = optional_datetime_to_unix_nanoseconds(start) if not self.dataset: diff --git a/databento/live/gateway.py b/databento/live/gateway.py index 0f173fe..6f2f3ca 100644 --- a/databento/live/gateway.py +++ b/databento/live/gateway.py @@ -18,7 +18,6 @@ T = TypeVar("T", bound="GatewayControl") - @dataclasses.dataclass class GatewayControl: """ diff --git a/databento/live/protocol.py b/databento/live/protocol.py index c3a1ee3..10494de 100644 --- a/databento/live/protocol.py +++ b/databento/live/protocol.py @@ -1,10 +1,12 @@ from __future__ import annotations import asyncio +import itertools import logging from collections.abc import Iterable from functools import singledispatchmethod from numbers import Number +from typing import TypeVar import databento_dbn from databento_dbn import Schema @@ -14,7 +16,7 @@ from databento.common.enums import Dataset from databento.common.error import BentoError from databento.common.parsing import optional_datetime_to_unix_nanoseconds -from databento.common.parsing import optional_symbols_list_to_string +from databento.common.parsing import optional_symbols_list_to_list from databento.common.symbology import ALL_SYMBOLS from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string @@ -34,6 +36,38 @@ logger = logging.getLogger(__name__) +_C = TypeVar("_C") + + +def chunk(iterable: Iterable[_C], size: int) -> Iterable[tuple[_C, ...]]: + """ + Break an iterable into chunks with a length of + at most `size`. + + Parameters + ---------- + iterable: Iterable[_C] + The iterable to break up. + size : int + The maximum size of each chunk. + + Returns + ------- + Iterable[_C] + + Raises + ------ + ValueError + If `size` is less than 1. + + """ + if size < 1: + raise ValueError("size must be at least 1") + + it = iter(iterable) + return iter(lambda: tuple(itertools.islice(it, size)), ()) + + class DatabentoLiveProtocol(asyncio.BufferedProtocol): """ A BufferedProtocol implementation for the Databento live subscription @@ -274,14 +308,18 @@ def subscribe( start if start is not None else "now", ) stype_in_valid = validate_enum(stype_in, SType, "stype_in") - message = SubscriptionRequest( - schema=validate_enum(schema, Schema, "schema"), - stype_in=stype_in_valid, - symbols=optional_symbols_list_to_string(symbols, stype_in_valid), - start=optional_datetime_to_unix_nanoseconds(start), - ) + symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) + + for batch in chunk(symbols_list, 128): + batch_str = ",".join(batch) + message = SubscriptionRequest( + schema=validate_enum(schema, Schema, "schema"), + stype_in=stype_in_valid, + symbols=batch_str, + start=optional_datetime_to_unix_nanoseconds(start), + ) - self.transport.write(bytes(message)) + self.transport.write(bytes(message)) def start( self, diff --git a/tests/test_common_parsing.py b/tests/test_common_parsing.py index 90a5f69..28dda40 100644 --- a/tests/test_common_parsing.py +++ b/tests/test_common_parsing.py @@ -10,7 +10,7 @@ from databento.common.parsing import optional_date_to_string from databento.common.parsing import optional_datetime_to_string from databento.common.parsing import optional_datetime_to_unix_nanoseconds -from databento.common.parsing import optional_symbols_list_to_string +from databento.common.parsing import optional_symbols_list_to_list from databento.common.parsing import optional_values_list_to_string from databento_dbn import SType @@ -47,26 +47,19 @@ def test_maybe_values_list_to_string_given_valid_inputs_returns_expected( # Assert assert result == expected - -def test_maybe_symbols_list_to_string_given_invalid_input_raises_type_error() -> None: - # Arrange, Act, Assert - with pytest.raises(TypeError): - optional_symbols_list_to_string(INCORRECT_TYPE, SType.RAW_SYMBOL) - - @pytest.mark.parametrize( "stype, symbols, expected", [ - pytest.param(SType.RAW_SYMBOL, None, "ALL_SYMBOLS"), - pytest.param(SType.PARENT, "ES.fut", "ES.FUT"), - pytest.param(SType.PARENT, "ES,CL", "ES,CL"), - pytest.param(SType.PARENT, "ES,CL,", "ES,CL"), - pytest.param(SType.PARENT, "es,cl,", "ES,CL"), - pytest.param(SType.PARENT, ["ES", "CL"], "ES,CL"), - pytest.param(SType.PARENT, ["es", "cl"], "ES,CL"), - pytest.param(SType.CONTINUOUS, ["ES.N.0", "CL.n.0"], "ES.n.0,CL.n.0"), - pytest.param(SType.CONTINUOUS, ["ES.N.0", ["ES,cl"]], "ES.n.0,ES,CL"), - pytest.param(SType.CONTINUOUS, ["ES.N.0", "ES,cl"], "ES.n.0,ES,CL"), + pytest.param(SType.RAW_SYMBOL, None, ["ALL_SYMBOLS"]), + pytest.param(SType.PARENT, "ES.fut", ["ES.FUT"]), + pytest.param(SType.PARENT, "ES,CL", ["ES", "CL"]), + pytest.param(SType.PARENT, "ES,CL,", ["ES", "CL"]), + pytest.param(SType.PARENT, "es,cl,", ["ES", "CL"]), + pytest.param(SType.PARENT, ["ES", "CL"], ["ES", "CL"]), + pytest.param(SType.PARENT, ["es", "cl"], ["ES", "CL"]), + pytest.param(SType.CONTINUOUS, ["ES.N.0", "CL.n.0"], ["ES.n.0", "CL.n.0"]), + pytest.param(SType.CONTINUOUS, ["ES.N.0", ["ES,cl"]], ["ES.n.0", "ES", "CL"]), + pytest.param(SType.CONTINUOUS, ["ES.N.0", "ES,cl"], ["ES.n.0", "ES", "CL"]), pytest.param(SType.CONTINUOUS, "", ValueError), pytest.param(SType.CONTINUOUS, [""], ValueError), pytest.param(SType.CONTINUOUS, ["ES.N.0", ""], ValueError), @@ -74,27 +67,35 @@ def test_maybe_symbols_list_to_string_given_invalid_input_raises_type_error() -> pytest.param(SType.PARENT, 123458, ValueError), ], ) -def test_optional_symbols_list_to_string_given_valid_inputs_returns_expected( +def test_optional_symbols_list_to_list_given_valid_inputs_returns_expected( stype: SType, symbols: list[str] | None, - expected: str | type[Exception], + expected: list[object] | type[Exception], ) -> None: # Arrange, Act, Assert - if isinstance(expected, str): - assert optional_symbols_list_to_string(symbols, stype) == expected + if isinstance(expected, list): + assert optional_symbols_list_to_list(symbols, stype) == expected else: with pytest.raises(expected): - optional_symbols_list_to_string(symbols, stype) + optional_symbols_list_to_list(symbols, stype) @pytest.mark.parametrize( "symbols, stype, expected", [ - pytest.param(12345, SType.INSTRUMENT_ID, "12345"), - pytest.param("67890", SType.INSTRUMENT_ID, "67890"), - pytest.param([12345, " 67890"], SType.INSTRUMENT_ID, "12345,67890"), - pytest.param([12345, [67890, 66]], SType.INSTRUMENT_ID, "12345,67890,66"), - pytest.param([12345, "67890,66"], SType.INSTRUMENT_ID, "12345,67890,66"), + pytest.param(12345, SType.INSTRUMENT_ID, ["12345"]), + pytest.param("67890", SType.INSTRUMENT_ID, ["67890"]), + pytest.param([12345, " 67890"], SType.INSTRUMENT_ID, ["12345", "67890"]), + pytest.param( + [12345, [67890, 66]], + SType.INSTRUMENT_ID, + ["12345", "67890", "66"], + ), + pytest.param( + [12345, "67890,66"], + SType.INSTRUMENT_ID, + ["12345", "67890", "66"], + ), pytest.param("", SType.INSTRUMENT_ID, ValueError), pytest.param([12345, ""], SType.INSTRUMENT_ID, ValueError), pytest.param([12345, [""]], SType.INSTRUMENT_ID, ValueError), @@ -103,10 +104,10 @@ def test_optional_symbols_list_to_string_given_valid_inputs_returns_expected( pytest.param(12345, SType.CONTINUOUS, ValueError), ], ) -def test_optional_symbols_list_to_string_int( +def test_optional_symbols_list_to_list_int( symbols: list[Number] | Number | None, stype: SType, - expected: str | type[Exception], + expected: list[object] | type[Exception], ) -> None: """ Test that integers are allowed for SType.INSTRUMENT_ID. @@ -114,44 +115,44 @@ def test_optional_symbols_list_to_string_int( If integers are given for a different SType we expect a ValueError. """ - if isinstance(expected, str): - assert optional_symbols_list_to_string(symbols, stype) == expected + if isinstance(expected, list): + assert optional_symbols_list_to_list(symbols, stype) == expected else: with pytest.raises(expected): - optional_symbols_list_to_string(symbols, stype) + optional_symbols_list_to_list(symbols, stype) @pytest.mark.parametrize( "symbols, stype, expected", [ - pytest.param(np.byte(120), SType.INSTRUMENT_ID, "120"), - pytest.param(np.short(32_000), SType.INSTRUMENT_ID, "32000"), + pytest.param(np.byte(120), SType.INSTRUMENT_ID, ["120"]), + pytest.param(np.short(32_000), SType.INSTRUMENT_ID, ["32000"]), pytest.param( [np.intc(12345), np.intc(67890)], SType.INSTRUMENT_ID, - "12345,67890", + ["12345", "67890"], ), pytest.param( [np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, - "12345,67890", + ["12345", "67890"], ), pytest.param( [np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, - "12345,67890", + ["12345", "67890"], ), pytest.param( [np.int_(12345), np.longlong(67890)], SType.INSTRUMENT_ID, - "12345,67890", + ["12345", "67890"], ), ], ) -def test_optional_symbols_list_to_string_numpy( +def test_optional_symbols_list_to_list_numpy( symbols: list[Number] | Number | None, stype: SType, - expected: str | type[Exception], + expected: list[object] | type[Exception], ) -> None: """ Test that weird numpy types are allowed for SType.INSTRUMENT_ID. @@ -159,42 +160,46 @@ def test_optional_symbols_list_to_string_numpy( If integers are given for a different SType we expect a ValueError. """ - if isinstance(expected, str): - assert optional_symbols_list_to_string(symbols, stype) == expected + if isinstance(expected, list): + assert optional_symbols_list_to_list(symbols, stype) == expected else: with pytest.raises(expected): - optional_symbols_list_to_string(symbols, stype) + optional_symbols_list_to_list(symbols, stype) @pytest.mark.parametrize( "symbols, stype, expected", [ - pytest.param("NVDA", SType.RAW_SYMBOL, "NVDA"), - pytest.param(" nvda ", SType.RAW_SYMBOL, "NVDA"), - pytest.param("NVDA,amd", SType.RAW_SYMBOL, "NVDA,AMD"), - pytest.param("NVDA,amd,NOC,", SType.RAW_SYMBOL, "NVDA,AMD,NOC"), - pytest.param("NVDA, amd,NOC, ", SType.RAW_SYMBOL, "NVDA,AMD,NOC"), - pytest.param(["NVDA", ["NOC", "AMD"]], SType.RAW_SYMBOL, "NVDA,NOC,AMD"), - pytest.param(["NVDA", "NOC,AMD"], SType.RAW_SYMBOL, "NVDA,NOC,AMD"), + pytest.param("NVDA", SType.RAW_SYMBOL, ["NVDA"]), + pytest.param(" nvda ", SType.RAW_SYMBOL, ["NVDA"]), + pytest.param("NVDA,amd", SType.RAW_SYMBOL, ["NVDA", "AMD"]), + pytest.param("NVDA,amd,NOC,", SType.RAW_SYMBOL, ["NVDA", "AMD", "NOC"]), + pytest.param("NVDA, amd,NOC, ", SType.RAW_SYMBOL, ["NVDA", "AMD", "NOC"]), + pytest.param( + ["NVDA", ["NOC", "AMD"]], + SType.RAW_SYMBOL, + ["NVDA", "NOC", "AMD"], + ), + pytest.param(["NVDA", "NOC,AMD"], SType.RAW_SYMBOL, ["NVDA", "NOC", "AMD"]), pytest.param("", SType.RAW_SYMBOL, ValueError), pytest.param([""], SType.RAW_SYMBOL, ValueError), pytest.param(["NVDA", ""], SType.RAW_SYMBOL, ValueError), pytest.param(["NVDA", [""]], SType.RAW_SYMBOL, ValueError), ], ) -def test_optional_symbols_list_to_string_raw_symbol( +def test_optional_symbols_list_to_list_raw_symbol( symbols: list[Number] | Number | None, stype: SType, - expected: str | type[Exception], + expected: list[object] | type[Exception], ) -> None: """ Test that str are allowed for SType.RAW_SYMBOL. """ - if isinstance(expected, str): - assert optional_symbols_list_to_string(symbols, stype) == expected + if isinstance(expected, list): + assert optional_symbols_list_to_list(symbols, stype) == expected else: with pytest.raises(expected): - optional_symbols_list_to_string(symbols, stype) + optional_symbols_list_to_list(symbols, stype) @pytest.mark.parametrize( diff --git a/tests/test_live_client.py b/tests/test_live_client.py index cc6b310..28ff14f 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -5,6 +5,8 @@ import pathlib import platform +import random +import string from io import BytesIO from typing import Callable from unittest.mock import MagicMock @@ -365,6 +367,38 @@ def test_live_subscribe( assert message.start == start +async def test_live_subscribe_large_symbol_list( + live_client: client.Live, + mock_live_server: MockLiveServer, +) -> None: + """ + Test that sending a subscription with a large symbol list breaks that list + up into multiple messages. + """ + large_symbol_list = list( + random.choices(string.ascii_uppercase, k=256), # noqa: S311 + ) + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.RAW_SYMBOL, + symbols=large_symbol_list, + ) + + first_message = mock_live_server.get_message_of_type( + gateway.SubscriptionRequest, + timeout=1, + ) + + second_message = mock_live_server.get_message_of_type( + gateway.SubscriptionRequest, + timeout=1, + ) + + reconstructed = first_message.symbols.split(",") + second_message.symbols.split(",") + assert reconstructed == large_symbol_list + + @pytest.mark.usefixtures("mock_live_server") def test_live_stop( live_client: client.Live, From c2a65e1f6ab12b3f68ad699fec571b6a28e3cd71 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Fri, 30 Jun 2023 09:41:02 -0700 Subject: [PATCH 14/17] VER: Release 0.15.0 --- CHANGELOG.md | 4 ++-- README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88fd0a7..c83f787 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## 0.15.0 - TBD +## 0.15.0 - 2023-07-03 #### Enhancements - Added `symbology_map` property to `Live` client @@ -8,13 +8,13 @@ - Changed `Live.add_callback` and `Live.add_stream` to accept an exception callback - Changed `Live.__iter__()` and `Live.__aiter__()` to send the session start message if the session is connected but not started - Upgraded `databento-dbn` to 0.7.1 -- Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn` - Removed exception chaining from exceptions emitted by the library #### Bug fixes - Fixed issue where a large unreadable symbol subscription message could be sent #### Breaking changes +- Removed `Encoding`, `Compression`, `Schema`, and `SType` enums as they are now exposed by `databento-dbn` - Renamed `func` parameter to `record_callback` for `Live.add_callback` and `Live.add_stream` - Removed `optional_symbols_list_to_string` parsing function diff --git a/README.md b/README.md index 984e279..9aba4d9 100644 --- a/README.md +++ b/README.md @@ -32,10 +32,10 @@ The library is fully compatible with the latest distribution of Anaconda 3.8 and The minimum dependencies as found in the `pyproject.toml` are also listed below: - python = "^3.8" - aiohttp = "^3.8.3" -- databento-dbn = "0.6.1" +- databento-dbn = "0.7.1" - numpy= ">=1.23.5" - pandas = ">=1.5.3" -- requests = ">=2.28.1" +- requests = ">=2.24.0" - zstandard = ">=0.21.0" ## Installation From 248f7258e99f22f7707a4c9278f9852f08103044 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Mon, 3 Jul 2023 14:59:52 -0700 Subject: [PATCH 15/17] FIX: Fix pipeline errors in client library --- CHANGELOG.md | 5 +++++ databento/common/parsing.py | 6 +++--- databento/version.py | 2 +- tests/mock_live_server.py | 5 ++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c83f787..26c78b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## 0.15.1 - TBD + +#### Bug Fixes +- Fixed an `ImportError` observed in Python 3.8 + ## 0.15.0 - 2023-07-03 #### Enhancements diff --git a/databento/common/parsing.py b/databento/common/parsing.py index e99c291..ea1a54c 100644 --- a/databento/common/parsing.py +++ b/databento/common/parsing.py @@ -88,7 +88,7 @@ def optional_symbols_list_to_list( ) -@optional_symbols_list_to_list.register +@optional_symbols_list_to_list.register(cls=type(None)) def _(_: None, __: SType) -> list[str]: """ Dispatch method for optional_symbols_list_to_list. Handles None which @@ -102,7 +102,7 @@ def _(_: None, __: SType) -> list[str]: return [ALL_SYMBOLS] -@optional_symbols_list_to_list.register +@optional_symbols_list_to_list.register(cls=Number) def _(symbols: Number, stype_in: SType) -> list[str]: """ Dispatch method for optional_symbols_list_to_list. Handles numerical types, @@ -121,7 +121,7 @@ def _(symbols: Number, stype_in: SType) -> list[str]: ) -@optional_symbols_list_to_list.register +@optional_symbols_list_to_list.register(cls=str) def _(symbols: str, stype_in: SType) -> list[str]: """ Dispatch method for optional_symbols_list_to_list. Handles str, splitting diff --git a/databento/version.py b/databento/version.py index 9da2f8f..903e77c 100644 --- a/databento/version.py +++ b/databento/version.py @@ -1 +1 @@ -__version__ = "0.15.0" +__version__ = "0.15.1" diff --git a/tests/mock_live_server.py b/tests/mock_live_server.py index 7521e7d..b625293 100644 --- a/tests/mock_live_server.py +++ b/tests/mock_live_server.py @@ -616,7 +616,10 @@ def get_message_of_type( end_time = self._loop.time() + timeout while start_time < end_time: remaining_time = abs(end_time - self._loop.time()) - message = self._message_queue.get(timeout=remaining_time) + try: + message = self._message_queue.get(timeout=remaining_time) + except queue.Empty: + continue if isinstance(message, message_type): return message From 6bcce0c7b43bd1e0d23cbd2c8e9b8829290f6b0d Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Mon, 3 Jul 2023 16:30:29 -0700 Subject: [PATCH 16/17] FIX: Skip test for Windows runner --- tests/mock_live_server.py | 2 +- tests/test_live_client.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/mock_live_server.py b/tests/mock_live_server.py index b625293..4b928c6 100644 --- a/tests/mock_live_server.py +++ b/tests/mock_live_server.py @@ -619,7 +619,7 @@ def get_message_of_type( try: message = self._message_queue.get(timeout=remaining_time) except queue.Empty: - continue + break if isinstance(message, message_type): return message diff --git a/tests/test_live_client.py b/tests/test_live_client.py index 28ff14f..7c28682 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -367,6 +367,7 @@ def test_live_subscribe( assert message.start == start +@pytest.mark.skipif(platform.system() == "Windows", reason="timeout on windows") async def test_live_subscribe_large_symbol_list( live_client: client.Live, mock_live_server: MockLiveServer, From b2abb912bece2e7f20e86c100800597c699a321a Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Wed, 5 Jul 2023 10:42:44 -0700 Subject: [PATCH 17/17] VER: Release 0.15.1 --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26c78b1..700841a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## 0.15.1 - TBD +## 0.15.1 - 2023-07-05 #### Bug Fixes - Fixed an `ImportError` observed in Python 3.8