Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions src/dipdup/datasources/tzkt/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import Any, AsyncGenerator, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast
from functools import wraps
from typing import Any, AsyncGenerator, Callable, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast

from aiohttp import ClientResponseError
from aiosignalrcore.hub.base_hub_connection import BaseHubConnection # type: ignore
Expand Down Expand Up @@ -340,7 +341,7 @@ def __init__(
self._transaction_subscriptions: Set[str] = set()
self._origination_subscriptions: bool = False
self._big_map_subscriptions: Dict[str, Set[str]] = {}
self._client: Optional[BaseHubConnection] = None
self._ws_client: Optional[BaseHubConnection] = None

self._block_cache: BlockCache = BlockCache()
self._head: Optional[Head] = None
Expand Down Expand Up @@ -587,41 +588,53 @@ async def add_index(self, index_config: ResolvedIndexConfigT) -> None:

await self._on_connect()

def _get_client(self) -> BaseHubConnection:
def _get_ws_client(self) -> BaseHubConnection:
"""Create SignalR client, register message callbacks"""
if self._client is None:
self._logger.info('Creating websocket client')
self._client = (
HubConnectionBuilder()
.with_url(self._http._url + '/v1/events')
.with_automatic_reconnect(
{
"type": "raw",
"keep_alive_interval": 10,
"reconnect_interval": 5,
"max_attempts": 5,
}
)
).build()
if self._ws_client:
return self._ws_client

self._logger.info('Creating websocket client')
self._ws_client = (
HubConnectionBuilder()
.with_url(self._http._url + '/v1/events')
.with_automatic_reconnect(
{
"type": "raw",
"keep_alive_interval": 10,
"reconnect_interval": 5,
"max_attempts": 5,
}
)
).build()

_ws_lock = asyncio.Lock()

def _lock_wrapper(fn: Callable):
@wraps(fn)
async def _wrapper(*args, **kwargs):
async with _ws_lock:
return await fn(*args, **kwargs)

return _wrapper

self._client.on_open(self._on_connect)
self._client.on_error(self._on_error)
self._client.on('operations', self._on_operation_message)
self._client.on('bigmaps', self._on_big_map_message)
self._client.on('head', self._on_head_message)
self._ws_client.on_open(_lock_wrapper(self._on_connect))
self._ws_client.on_error(_lock_wrapper(self._on_error))
self._ws_client.on('operations', _lock_wrapper(self._on_operation_message))
self._ws_client.on('bigmaps', _lock_wrapper(self._on_big_map_message))
self._ws_client.on('head', _lock_wrapper(self._on_head_message))

return self._client
return self._ws_client

async def run(self) -> None:
"""Main loop. Sync indexes via REST, start WS connection"""
self._logger.info('Starting datasource')

self._logger.info('Starting websocket client')
await self._get_client().start()
await self._get_ws_client().start()

async def _on_connect(self) -> None:
"""Subscribe to all required channels on established WS connection"""
if self._get_client().transport.state != ConnectionState.connected:
if self._get_ws_client().transport.state != ConnectionState.connected:
return

self._logger.info('Realtime connection established, subscribing to channels')
Expand Down Expand Up @@ -932,7 +945,7 @@ def convert_quote(cls, quote_json: Dict[str, Any]) -> QuoteData:
)

async def _send(self, method: str, arguments: List[Dict[str, Any]], on_invocation=None) -> None:
client = self._get_client()
client = self._get_ws_client()
while client.transport.state != ConnectionState.connected:
await asyncio.sleep(0.1)
await client.send(method, arguments, on_invocation)
Expand Down