diff --git a/CHANGELOG.md b/CHANGELOG.md index a6811698f..bb9deffb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,15 +2,25 @@ ## [unreleased] +### Improved + +* A significant increase in indexing speed. + ### Fixed -* Removed unnecessary file IO calls, improved logging +* Fixed unexpected reindexing caused by a bug in processing zero- and single-level rollbacks. +* Removed unnecessary file IO calls that could cause a `PermissionError` exception in Docker environments. +* Fixed possible violation of block-level atomicity during real-time indexing. + +### Changed + +* Public methods of `TzktDatasource` now return immutable sequences. ## 3.0.3 - 2021-10-01 ### Fixed -* Fixed processing of single level rollbacks emitted before rolled back head +* Fixed processing of single-level rollbacks emitted before the rolled back head. ## 3.0.2 - 2021-09-30 diff --git a/src/dipdup/cli.py b/src/dipdup/cli.py index ff2645b34..63eba3710 100644 --- a/src/dipdup/cli.py +++ b/src/dipdup/cli.py @@ -20,7 +20,7 @@ from dipdup.codegen import DEFAULT_DOCKER_ENV_FILE, DEFAULT_DOCKER_IMAGE, DEFAULT_DOCKER_TAG, DipDupCodeGenerator from dipdup.config import DipDupConfig, LoggingConfig, PostgresDatabaseConfig from dipdup.dipdup import DipDup -from dipdup.exceptions import ConfigurationError, DeprecatedHandlerError, DipDupError, MigrationRequiredError +from dipdup.exceptions import ConfigurationError, DeprecatedHandlerError, DipDupError, InitializationRequiredError, MigrationRequiredError from dipdup.hasura import HasuraGateway from dipdup.migrations import DipDupMigrationManager, deprecated_handlers from dipdup.utils.database import set_decimal_context, tortoise_wrapper @@ -113,7 +113,10 @@ async def cli(ctx, config: List[str], env_file: List[str], logging_config: str): _config = DipDupConfig.load(config) init_sentry(_config) - await DipDupCodeGenerator(_config, {}).create_package() + try: + await DipDupCodeGenerator(_config, {}).create_package() + except Exception as e: + raise InitializationRequiredError from e if _config.spec_version not in spec_version_mapping: raise ConfigurationError(f'Unknown `spec_version`, correct ones: {", ".join(spec_version_mapping)}') diff --git a/src/dipdup/context.py b/src/dipdup/context.py index 35c057ac5..0a87cb7da 100644 --- a/src/dipdup/context.py +++ b/src/dipdup/context.py @@ -90,22 +90,26 @@ async def restart(self) -> None: sys.argv.remove('--reindex') os.execl(sys.executable, sys.executable, *sys.argv) - async def reindex(self, reason: Optional[Union[str, ReindexingReason]] = None) -> None: + async def reindex(self, reason: Optional[Union[str, ReindexingReason]] = None, **context) -> None: """Drop all tables or whole database and restart with the same CLI arguments""" - reason_str = reason.value if isinstance(reason, ReindexingReason) else 'unknown' - self.logger.warning('Reindexing initialized, reason: %s', reason_str) - - if not reason or isinstance(reason, str): + if not reason: + reason = ReindexingReason.MANUAL + elif isinstance(reason, str): + context['message'] = reason reason = ReindexingReason.MANUAL + reason_str = reason.value + (f' ({context["message"]})' if "message" in context else '') + self.logger.warning('Reindexing initialized, reason: %s', reason_str) + self.logger.info('Additional context: %s', context) + if forbid_reindexing: schema = await Schema.filter().get() if schema.reindex: - raise ReindexingRequiredError(schema.reindex) + raise ReindexingRequiredError(schema.reindex, context) schema.reindex = reason await schema.save() - raise 
ReindexingRequiredError(schema.reindex) + raise ReindexingRequiredError(schema.reindex, context) database_config = self.config.database if isinstance(database_config, PostgresDatabaseConfig): diff --git a/src/dipdup/datasources/datasource.py b/src/dipdup/datasources/datasource.py index 743e29651..370066ea6 100644 --- a/src/dipdup/datasources/datasource.py +++ b/src/dipdup/datasources/datasource.py @@ -1,6 +1,6 @@ import logging from abc import abstractmethod -from typing import Awaitable, Callable, List, Set +from typing import Awaitable, Callable, Set, Tuple from dipdup.config import HTTPConfig from dipdup.http import HTTPGateway @@ -11,8 +11,8 @@ HeadCallbackT = Callable[['IndexDatasource', HeadBlockData], Awaitable[None]] -OperationsCallbackT = Callable[['IndexDatasource', List[OperationData]], Awaitable[None]] -BigMapsCallbackT = Callable[['IndexDatasource', List[BigMapData]], Awaitable[None]] +OperationsCallbackT = Callable[['IndexDatasource', Tuple[OperationData, ...]], Awaitable[None]] +BigMapsCallbackT = Callable[['IndexDatasource', Tuple[BigMapData, ...]], Awaitable[None]] RollbackCallbackT = Callable[['IndexDatasource', int, int], Awaitable[None]] @@ -57,11 +57,11 @@ async def emit_head(self, head: HeadBlockData) -> None: for fn in self._on_head: await fn(self, head) - async def emit_operations(self, operations: List[OperationData]) -> None: + async def emit_operations(self, operations: Tuple[OperationData, ...]) -> None: for fn in self._on_operations: await fn(self, operations) - async def emit_big_maps(self, big_maps: List[BigMapData]) -> None: + async def emit_big_maps(self, big_maps: Tuple[BigMapData, ...]) -> None: for fn in self._on_big_maps: await fn(self, big_maps) diff --git a/src/dipdup/datasources/tzkt/datasource.py b/src/dipdup/datasources/tzkt/datasource.py index f0d54fba9..a897a6f04 100644 --- a/src/dipdup/datasources/tzkt/datasource.py +++ b/src/dipdup/datasources/tzkt/datasource.py @@ -1,10 +1,9 @@ import asyncio import logging -from collections import defaultdict +from collections import defaultdict, deque from datetime import datetime, timezone from decimal import Decimal -from enum import Enum -from typing import Any, AsyncGenerator, DefaultDict, Dict, List, NoReturn, Optional, Set, Tuple, cast +from typing import Any, AsyncGenerator, DefaultDict, Deque, Dict, List, NoReturn, Optional, Set, Tuple, cast from aiohttp import ClientResponseError from aiosignalrcore.hub.base_hub_connection import BaseHubConnection # type: ignore @@ -21,62 +20,27 @@ ResolvedIndexConfigT, ) from dipdup.datasources.datasource import IndexDatasource -from dipdup.datasources.tzkt.enums import TzktMessageType +from dipdup.datasources.tzkt.enums import ( + ORIGINATION_MIGRATION_FIELDS, + ORIGINATION_OPERATION_FIELDS, + TRANSACTION_OPERATION_FIELDS, + OperationFetcherRequest, + TzktMessageType, +) from dipdup.enums import MessageType from dipdup.models import BigMapAction, BigMapData, BlockData, HeadBlockData, OperationData, QuoteData -from dipdup.utils import groupby, split_by_chunks +from dipdup.utils import split_by_chunks TZKT_ORIGINATIONS_REQUEST_LIMIT = 100 -OPERATION_FIELDS = ( - "type", - "id", - "level", - "timestamp", - "hash", - "counter", - "sender", - "nonce", - "target", - "initiator", - "amount", - "storage", - "status", - "hasInternals", - "diffs", -) -ORIGINATION_MIGRATION_FIELDS = ( - "id", - "level", - "timestamp", - "storage", - "diffs", - "account", - "balanceChange", -) -ORIGINATION_OPERATION_FIELDS = ( - *OPERATION_FIELDS, - "originatedContract", -) 
-TRANSACTION_OPERATION_FIELDS = ( - *OPERATION_FIELDS, - "parameter", - "hasInternals", -) - - -class OperationFetcherChannel(Enum): - """Represents multiple TzKT calls to be merged into a single batch of operations""" - - sender_transactions = 'sender_transactions' - target_transactions = 'target_transactions' - originations = 'originations' -def dedup_operations(operations: List[OperationData]) -> List[OperationData]: +def dedup_operations(operations: Tuple[OperationData, ...]) -> Tuple[OperationData, ...]: """Merge operations from multiple endpoints""" - return sorted( - list(({op.id: op for op in operations}).values()), - key=lambda op: op.id, + return tuple( + sorted( + tuple(({op.id: op for op in operations}).values()), + key=lambda op: op.id, + ) ) @@ -93,7 +57,7 @@ def __init__( transaction_addresses: Set[str], origination_addresses: Set[str], cache: bool = False, - migration_originations: List[OperationData] = None, + migration_originations: Tuple[OperationData, ...] = None, ) -> None: self._datasource = datasource self._first_level = first_level @@ -104,17 +68,15 @@ def __init__( self._logger = logging.getLogger('dipdup.tzkt') self._head: int = 0 - self._heads: Dict[OperationFetcherChannel, int] = {} - self._offsets: Dict[OperationFetcherChannel, int] = {} - self._fetched: Dict[OperationFetcherChannel, bool] = {} + self._heads: Dict[OperationFetcherRequest, int] = {} + self._offsets: Dict[OperationFetcherRequest, int] = {} + self._fetched: Dict[OperationFetcherRequest, bool] = {} - self._operations: DefaultDict[int, List[OperationData]] - if migration_originations: - self._operations = groupby(migration_originations, lambda op: op.level) - else: - self._operations = defaultdict(list) + self._operations: DefaultDict[int, Deque[OperationData]] = defaultdict(deque) + for origination in migration_originations or (): + self._operations[origination.level].append(origination) - def _get_operations_head(self, operations: List[OperationData]) -> int: + def _get_operations_head(self, operations: Tuple[OperationData, ...]) -> int: """Get latest block level (head) of sorted operations batch""" for i in range(len(operations) - 1)[::-1]: if operations[i].level != operations[i + 1].level: @@ -123,7 +85,7 @@ def _get_operations_head(self, operations: List[OperationData]) -> int: async def _fetch_originations(self) -> None: """Fetch a single batch of originations, bump channel offset""" - key = OperationFetcherChannel.originations + key = OperationFetcherRequest.originations if not self._origination_addresses: self._fetched[key] = True self._heads[key] = self._last_level @@ -142,8 +104,6 @@ async def _fetch_originations(self) -> None: for op in originations: level = op.level - if level not in self._operations: - self._operations[level] = [] self._operations[level].append(op) self._logger.debug('Got %s', len(originations)) @@ -157,7 +117,7 @@ async def _fetch_originations(self) -> None: async def _fetch_transactions(self, field: str) -> None: """Fetch a single batch of transactions, bump channel offset""" - key = getattr(OperationFetcherChannel, field + '_transactions') + key = getattr(OperationFetcherRequest, field + '_transactions') if not self._transaction_addresses: self._fetched[key] = True self._heads[key] = self._last_level @@ -177,8 +137,6 @@ async def _fetch_transactions(self, field: str) -> None: for op in transactions: level = op.level - if level not in self._operations: - self._operations[level] = [] self._operations[level].append(op) self._logger.debug('Got %s', len(transactions)) @@ 
-190,12 +148,15 @@ async def _fetch_transactions(self, field: str) -> None: self._offsets[key] += self._datasource.request_limit self._heads[key] = self._get_operations_head(transactions) - async def fetch_operations_by_level(self) -> AsyncGenerator[Tuple[int, List[OperationData]], None]: - """Iterate by operations from multiple channels. Return is splitted by level, deduped/sorted and ready to be passeed to Matcher.""" + async def fetch_operations_by_level(self) -> AsyncGenerator[Tuple[int, Tuple[OperationData, ...]], None]: + """Iterate over operations fetched with multiple REST requests with different filters. + + Resulting data is split by level, deduped, sorted and ready to be processed by OperationIndex. + """ for type_ in ( - OperationFetcherChannel.sender_transactions, - OperationFetcherChannel.target_transactions, - OperationFetcherChannel.originations, + OperationFetcherRequest.sender_transactions, + OperationFetcherRequest.target_transactions, + OperationFetcherRequest.originations, ): self._heads[type_] = 0 self._offsets[type_] = 0 @@ -203,11 +164,11 @@ async def fetch_operations_by_level(self) -> AsyncGenerator[Tuple[int, List[Oper while True: min_head = sorted(self._heads.items(), key=lambda x: x[1])[0][0] - if min_head == OperationFetcherChannel.originations: + if min_head == OperationFetcherRequest.originations: await self._fetch_originations() - elif min_head == OperationFetcherChannel.target_transactions: + elif min_head == OperationFetcherRequest.target_transactions: await self._fetch_transactions('target') - elif min_head == OperationFetcherChannel.sender_transactions: + elif min_head == OperationFetcherRequest.sender_transactions: await self._fetch_transactions('sender') else: raise RuntimeError @@ -216,7 +177,7 @@ async def fetch_operations_by_level(self) -> AsyncGenerator[Tuple[int, List[Oper while self._head <= head: if self._head in self._operations: operations = self._operations.pop(self._head) - yield self._head, dedup_operations(operations) + yield self._head, dedup_operations(tuple(operations)) self._head += 1 if all(list(self._fetched.values())): @@ -243,11 +204,14 @@ def __init__( self._big_map_paths = big_map_paths self._cache = cache - async def fetch_big_maps_by_level(self) -> AsyncGenerator[Tuple[int, List[BigMapData]], None]: - """Fetch big map diffs via Fetcher (not implemented yet) and pass to message callback""" + async def fetch_big_maps_by_level(self) -> AsyncGenerator[Tuple[int, Tuple[BigMapData, ...]], None]: + """Iterate over big map diffs fetched from REST. + + Resulting data is split by level, deduped, sorted and ready to be processed by BigMapIndex. + """ offset = 0 - big_maps = [] + big_maps: Tuple[BigMapData, ...] 
= tuple() while True: fetched_big_maps = await self._datasource.get_big_maps( @@ -258,13 +222,13 @@ async def fetch_big_maps_by_level(self) -> AsyncGenerator[Tuple[int, List[BigMap self._last_level, cache=self._cache, ) - big_maps += fetched_big_maps + big_maps = big_maps + fetched_big_maps while True: for i in range(len(big_maps) - 1): if big_maps[i].level != big_maps[i + 1].level: - yield big_maps[i].level, big_maps[: i + 1] - big_maps = big_maps[i + 1 :] # noqa: E203 + yield big_maps[i].level, tuple(big_maps[: i + 1]) + big_maps = big_maps[i + 1 :] break else: break @@ -275,7 +239,7 @@ async def fetch_big_maps_by_level(self) -> AsyncGenerator[Tuple[int, List[BigMap offset += self._datasource.request_limit if big_maps: - yield big_maps[0].level, big_maps[: i + 1] + yield big_maps[0].level, tuple(big_maps[: i + 1]) class TzktDatasource(IndexDatasource): @@ -310,7 +274,7 @@ def __init__( self._big_map_subscriptions: Dict[str, Set[str]] = {} self._ws_client: Optional[BaseHubConnection] = None - self._level: Optional[int] = None + self._level: DefaultDict[MessageType, Optional[int]] = defaultdict(lambda: None) self._sync_level: Optional[int] = None @property @@ -321,7 +285,7 @@ def request_limit(self) -> int: def sync_level(self) -> Optional[int]: return self._sync_level - async def get_similar_contracts(self, address: str, strict: bool = False) -> List[str]: + async def get_similar_contracts(self, address: str, strict: bool = False) -> Tuple[str, ...]: """Get list of contracts sharing the same code hash or type hash""" entrypoint = 'same' if strict else 'similar' self._logger.info('Fetching %s contracts for address `%s', entrypoint, address) @@ -334,9 +298,9 @@ async def get_similar_contracts(self, address: str, strict: bool = False) -> Lis limit=self.request_limit, ), ) - return contracts + return tuple(c for c in contracts) - async def get_originated_contracts(self, address: str) -> List[str]: + async def get_originated_contracts(self, address: str) -> Tuple[str, ...]: """Get contracts originated from given address""" self._logger.info('Fetching originated contracts for address `%s', address) contracts = await self._http.request( @@ -346,7 +310,7 @@ async def get_originated_contracts(self, address: str) -> List[str]: limit=self.request_limit, ), ) - return [c['address'] for c in contracts] + return tuple(c['address'] for c in contracts) async def get_contract_summary(self, address: str) -> Dict[str, Any]: """Get contract summary""" @@ -393,7 +357,7 @@ async def get_block(self, level: int) -> BlockData: ) return self.convert_block(block_json) - async def get_migration_originations(self, first_level: int = 0) -> List[OperationData]: + async def get_migration_originations(self, first_level: int = 0) -> Tuple[OperationData, ...]: """Get contracts originated from migrations""" self._logger.info('Fetching contracts originated with migrations') # NOTE: Empty unwrapped request to ensure API supports migration originations @@ -407,7 +371,7 @@ async def get_migration_originations(self, first_level: int = 0) -> List[Operati }, ) except ClientResponseError: - return [] + return () raw_migrations = await self._http.request( 'get', @@ -418,11 +382,11 @@ async def get_migration_originations(self, first_level: int = 0) -> List[Operati 'select': ','.join(ORIGINATION_MIGRATION_FIELDS), }, ) - return [self.convert_migration_origination(m) for m in raw_migrations] + return tuple(self.convert_migration_origination(m) for m in raw_migrations) async def get_originations( self, addresses: Set[str], offset: int, 
first_level: int, last_level: int, cache: bool = False - ) -> List[OperationData]: + ) -> Tuple[OperationData, ...]: raw_originations = [] # NOTE: TzKT may hit URL length limit with hundreds of originations in a single request. # NOTE: Chunk of 100 addresses seems like a reasonable choice - URL of ~3971 characters. @@ -443,16 +407,16 @@ async def get_originations( cache=cache, ) - originations = [] for op in raw_originations: # NOTE: `type` field needs to be set manually when requesting operations by specific type op['type'] = 'origination' - originations.append(self.convert_operation(op)) + + originations = tuple(self.convert_operation(op) for op in raw_originations) return originations async def get_transactions( self, field: str, addresses: Set[str], offset: int, first_level: int, last_level: int, cache: bool = False - ) -> List[OperationData]: + ) -> Tuple[OperationData, ...]: raw_transactions = await self._http.request( 'get', url='v1/operations/transactions', @@ -467,16 +431,16 @@ async def get_transactions( }, cache=cache, ) - transactions = [] for op in raw_transactions: # NOTE: type needs to be set manually when requesting operations by specific type op['type'] = 'transaction' - transactions.append(self.convert_operation(op)) + + transactions = tuple(self.convert_operation(op) for op in raw_transactions) return transactions async def get_big_maps( self, addresses: Set[str], paths: Set[str], offset: int, first_level: int, last_level: int, cache: bool = False - ) -> List[BigMapData]: + ) -> Tuple[BigMapData, ...]: raw_big_maps = await self._http.request( 'get', url='v1/bigmaps/updates', @@ -490,9 +454,7 @@ async def get_big_maps( }, cache=cache, ) - big_maps = [] - for bm in raw_big_maps: - big_maps.append(self.convert_big_map(bm)) + big_maps = tuple(self.convert_big_map(bm) for bm in raw_big_maps) return big_maps async def get_quote(self, level: int) -> QuoteData: @@ -506,7 +468,7 @@ async def get_quote(self, level: int) -> QuoteData: ) return self.convert_quote(quote_json[0]) - async def get_quotes(self, from_level: int, to_level: int) -> List[QuoteData]: + async def get_quotes(self, from_level: int, to_level: int) -> Tuple[QuoteData, ...]: """Get quotes for blocks""" self._logger.info('Fetching quotes for levels %s-%s', from_level, to_level) quotes_json = await self._http.request( @@ -519,7 +481,7 @@ async def get_quotes(self, from_level: int, to_level: int) -> List[QuoteData]: }, cache=False, ) - return [self.convert_quote(quote) for quote in quotes_json] + return tuple(self.convert_quote(quote) for quote in quotes_json) async def add_index(self, index_config: ResolvedIndexConfigT) -> None: """Register index config in internal mappings and matchers. 
Find and register subscriptions.""" @@ -653,29 +615,38 @@ async def _subscribe_to_head(self) -> None: ) async def _extract_message_data(self, type_: MessageType, message: List[Any]) -> AsyncGenerator[Dict, None]: - # TODO: Docstring + """Parse a message received from the WebSocket, ensure it's correct in the current context and yield data.""" for item in message: tzkt_type = TzktMessageType(item['type']) - head_level = item['state'] + level, current_level = item['state'], self._level[type_] + self._level[type_] = level - self._logger.info('Realtime message received: %s, %s', type_, tzkt_type) + self._logger.info('Realtime message received: %s, %s, %s -> %s', type_.value, tzkt_type.name, current_level, level) - # NOTE: State messages will be replaced with WS negotiation some day + # NOTE: Ensure correctness, update sync level if tzkt_type == TzktMessageType.STATE: - if self._sync_level != head_level: - self._logger.info('Datasource level set to %s', head_level) - self._sync_level = head_level - self._level = head_level + if self._sync_level is None or self._sync_level < level: + self._logger.info('Datasource sync level has been updated: %s -> %s', self._sync_level, level) + self._sync_level = level + elif self._sync_level > level: + raise RuntimeError(f'Attempt to set sync level to a lower value: {self._sync_level} -> {level}') + else: + pass + # NOTE: Just yield data elif tzkt_type == TzktMessageType.DATA: - self._level = head_level yield item['data'] + # NOTE: Emit rollback, but not on `head` message elif tzkt_type == TzktMessageType.REORG: - if self._level is None: - raise RuntimeError('Reorg message received but datasource is not connected') - self._logger.info('Emitting rollback from %s to %s', self._level, head_level) - await self.emit_rollback(self._level, head_level) + if current_level is None: + raise RuntimeError('Reorg message received but level is not set') + # NOTE: operation/big_map channels have their own levels + if type_ == MessageType.head: + return + + self._logger.info('Emitting rollback from %s to %s', current_level, level) + await self.emit_rollback(current_level, level) else: raise NotImplementedError @@ -683,23 +654,23 @@ async def _extract_message_data(self, type_: MessageType, message: List[Any]) -> async def _on_operations_message(self, message: List[Dict[str, Any]]) -> None: """Parse and emit raw operations from WS""" async for data in self._extract_message_data(MessageType.operation, message): - operations = [] + operations: Deque[OperationData] = deque() for operation_json in data: if operation_json['status'] != 'applied': continue operation = self.convert_operation(operation_json) operations.append(operation) if operations: - await self.emit_operations(operations) + await self.emit_operations(tuple(operations)) async def _on_big_maps_message(self, message: List[Dict[str, Any]]) -> None: """Parse and emit raw big map diffs from WS""" async for data in self._extract_message_data(MessageType.big_map, message): - big_maps = [] + big_maps: Deque[BigMapData] = deque() for big_map_json in data: big_map = self.convert_big_map(big_map_json) big_maps.append(big_map) - await self.emit_big_maps(big_maps) + await self.emit_big_maps(tuple(big_maps)) async def _on_head_message(self, message: List[Dict[str, Any]]) -> None: """Parse and emit raw head block from WS""" diff --git a/src/dipdup/datasources/tzkt/enums.py b/src/dipdup/datasources/tzkt/enums.py index 836f9d82b..cee21abed 100644 --- a/src/dipdup/datasources/tzkt/enums.py +++ b/src/dipdup/datasources/tzkt/enums.py @@ -5,3 +5,48 @@ class 
TzktMessageType(Enum): STATE = 0 DATA = 1 REORG = 2 + + +OPERATION_FIELDS = ( + "type", + "id", + "level", + "timestamp", + "hash", + "counter", + "sender", + "nonce", + "target", + "initiator", + "amount", + "storage", + "status", + "hasInternals", + "diffs", +) +ORIGINATION_MIGRATION_FIELDS = ( + "id", + "level", + "timestamp", + "storage", + "diffs", + "account", + "balanceChange", +) +ORIGINATION_OPERATION_FIELDS = ( + *OPERATION_FIELDS, + "originatedContract", +) +TRANSACTION_OPERATION_FIELDS = ( + *OPERATION_FIELDS, + "parameter", + "hasInternals", +) + + +class OperationFetcherRequest(Enum): + """Represents multiple TzKT calls to be merged into a single batch of operations""" + + sender_transactions = 'sender_transactions' + target_transactions = 'target_transactions' + originations = 'originations' diff --git a/src/dipdup/dipdup.py b/src/dipdup/dipdup.py index 1111208f7..fa265f1fd 100644 --- a/src/dipdup/dipdup.py +++ b/src/dipdup/dipdup.py @@ -5,7 +5,7 @@ from contextlib import AsyncExitStack, asynccontextmanager, suppress from functools import partial from operator import ne -from typing import Awaitable, Deque, Dict, List, Optional, Set +from typing import Awaitable, Deque, Dict, List, Optional, Set, Tuple, cast from apscheduler.events import EVENT_JOB_ERROR # type: ignore from tortoise.exceptions import OperationalError @@ -31,7 +31,7 @@ from dipdup.enums import ReindexingReason from dipdup.exceptions import ConfigInitializationException, DipDupException from dipdup.hasura import HasuraGateway -from dipdup.index import BigMapIndex, Index, OperationIndex +from dipdup.index import BigMapIndex, Index, OperationIndex, block_cache from dipdup.models import BigMapData, Contract, Head, HeadBlockData from dipdup.models import Index as IndexState from dipdup.models import IndexStatus, OperationData, Schema @@ -64,7 +64,7 @@ async def run( while self._tasks: tasks.append(self._tasks.popleft()) - async with slowdown(1.0): + async with slowdown(0.1): await gather(*tasks) indexes_spawned = False @@ -139,6 +139,8 @@ async def _load_index_states(self) -> None: else: self._logger.warning('Index `%s` was removed from config, ignoring', name) + block_cache.clear() + async def _on_head(self, datasource: TzktDatasource, head: HeadBlockData) -> None: # NOTE: Do not await query results - blocked database connection may cause Websocket timeout. 
self._tasks.append( @@ -154,36 +156,48 @@ async def _on_head(self, datasource: TzktDatasource, head: HeadBlockData) -> Non ) ) - async def _on_operations(self, datasource: TzktDatasource, operations: List[OperationData]) -> None: - assert len(set(op.level for op in operations)) == 1 - level = operations[0].level + async def _on_operations(self, datasource: TzktDatasource, operations: Tuple[OperationData, ...]) -> None: for index in self._indexes.values(): if isinstance(index, OperationIndex) and index.datasource == datasource: - index.push(level, operations) + index.push_operations(operations) - async def _on_big_maps(self, datasource: TzktDatasource, big_maps: List[BigMapData]) -> None: - assert len(set(op.level for op in big_maps)) == 1 - level = big_maps[0].level + async def _on_big_maps(self, datasource: TzktDatasource, big_maps: Tuple[BigMapData]) -> None: for index in self._indexes.values(): if isinstance(index, BigMapIndex) and index.datasource == datasource: - index.push(level, big_maps) + index.push_big_maps(big_maps) async def _on_rollback(self, datasource: TzktDatasource, from_level: int, to_level: int) -> None: - # NOTE: Rollback could be received before head - if from_level - to_level in (0, 1): - # NOTE: Single level rollbacks are processed at Index level. + """Perform a single level rollback when possible, otherwise call `on_rollback` hook""" + self._logger.warning('Datasource `%s` rolled back: %s -> %s', datasource.name, from_level, to_level) + + # NOTE: Zero difference between levels means we received no operations/big_maps on this level and thus channel level hasn't changed + zero_level_rollback = from_level - to_level == 0 + single_level_rollback = from_level - to_level == 1 + + if zero_level_rollback: + self._logger.info('Zero level rollback, ignoring') + + elif single_level_rollback: # NOTE: Notify all indexes which use rolled back datasource to drop duplicated operations from the next block - for index in self._indexes.values(): - if index.datasource == datasource: - # NOTE: Continue to rollback with handler - if not isinstance(index, OperationIndex): - self._logger.info('Single level rollback is not supported by `%s` indexes', index._config.kind) - break - await index.single_level_rollback(from_level) + self._logger.info('Checking if single level rollback is possible') + matching_indexes = tuple(i for i in self._indexes.values() if i.datasource == datasource) + matching_operation_indexes = tuple(i for i in matching_indexes if isinstance(i, OperationIndex)) + self._logger.info( + 'Indexes: %s total, %s matching, %s support single level rollback', + len(self._indexes), + len(matching_indexes), + len(matching_operation_indexes), + ) + + all_indexes_are_operation = len(matching_indexes) == len(matching_operation_indexes) + if all_indexes_are_operation: + for index in cast(List[OperationIndex], matching_indexes): + index.push_rollback(from_level) else: - return + await self._ctx.fire_hook('on_rollback', datasource=datasource, from_level=from_level, to_level=to_level) - await self._ctx.fire_hook('on_rollback', datasource=datasource, from_level=from_level, to_level=to_level) + else: + await self._ctx.fire_hook('on_rollback', datasource=datasource, from_level=from_level, to_level=to_level) class DipDup: diff --git a/src/dipdup/exceptions.py b/src/dipdup/exceptions.py index c11decf63..72c278201 100644 --- a/src/dipdup/exceptions.py +++ b/src/dipdup/exceptions.py @@ -1,8 +1,8 @@ import textwrap import traceback from contextlib import contextmanager -from dataclasses import 
dataclass -from typing import Any, Iterator, Optional, Type +from dataclasses import dataclass, field +from typing import Any, Dict, Iterator, Optional, Type import sentry_sdk from tabulate import tabulate @@ -137,6 +137,7 @@ class ReindexingRequiredError(DipDupError): """Unable to continue indexing with existing database""" reason: ReindexingReason + context: Dict[str, Any] = field(default_factory=dict) def _help(self) -> str: return f""" @@ -144,6 +145,8 @@ def _help(self) -> str: Reason: {self.reason.value} + Additional context: {self.context} + You may want to backup database before proceeding. After that perform one of the following actions: * Eliminate the cause of reindexing and run `UPDATE dupdup_schema SET reindex = NULL;` diff --git a/src/dipdup/index.py b/src/dipdup/index.py index 84a32765a..d657a4ecd 100644 --- a/src/dipdup/index.py +++ b/src/dipdup/index.py @@ -1,6 +1,6 @@ from abc import abstractmethod from collections import defaultdict, deque, namedtuple -from typing import Deque, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union, cast from pydantic.error_wrappers import ValidationError @@ -27,7 +27,13 @@ # NOTE: Operations of a single contract call OperationSubgroup = namedtuple('OperationSubgroup', ('hash', 'counter')) -_cached_blocks: Dict[int, BlockData] = {} +# NOTE: Message queue of OperationIndex +SingleLevelRollback = namedtuple('SingleLevelRollback', ('level')) +Operations = Tuple[OperationData, ...] +OperationQueueItemT = Union[Operations, SingleLevelRollback] + +# NOTE: For initializing the index state on startup +block_cache: Dict[int, BlockData] = {} class Index: @@ -73,9 +79,9 @@ async def initialize_state(self) -> None: if not head: return - if head.level not in _cached_blocks: - _cached_blocks[head.level] = await self.datasource.get_block(head.level) - if head.hash != _cached_blocks[head.level].hash: + if head.level not in block_cache: + block_cache[head.level] = await self.datasource.get_block(head.level) + if head.hash != block_cache[head.level].hash: await self._ctx.reindex(ReindexingReason.BLOCK_HASH_MISMATCH) async def process(self) -> None: @@ -89,7 +95,9 @@ async def process(self) -> None: raise RuntimeError('Call `set_sync_level` before starting IndexDispatcher') elif self.state.level < self._datasource.sync_level: - self._logger.info('Index is behind datasource, sync to datasource level') + self._logger.info( + 'Index is behind datasource, sync to datasource level: %s -> %s', self.state.level, self._datasource.sync_level + ) self._queue.clear() last_level = self._datasource.sync_level await self._synchronize(last_level) @@ -124,44 +132,57 @@ async def _exit_sync_state(self, last_level: int) -> None: self._logger.info('Index is synchronized to level %s', last_level) await self.state.update_status(status=IndexStatus.REALTIME, level=last_level) + def _extract_level(self, message: Union[Tuple[OperationData, ...], Tuple[BigMapData, ...]]) -> int: + batch_levels = tuple(set(item.level for item in message)) + if len(batch_levels) != 1: + raise RuntimeError(f'Items in operation/big_map batch have different levels: {batch_levels}') + return tuple(batch_levels)[0] + class OperationIndex(Index): _config: OperationIndexConfig def __init__(self, ctx: DipDupContext, config: OperationIndexConfig, datasource: TzktDatasource) -> None: super().__init__(ctx, config, datasource) - self._queue: Deque[Tuple[int, List[OperationData]]] = deque() + self._queue: Deque[OperationQueueItemT] = deque() 
self._contract_hashes: Dict[str, Tuple[int, int]] = {} self._rollback_level: Optional[int] = None - self._last_hashes: Set[str] = set() + self._head_hashes: Set[str] = set() self._migration_originations: Optional[Dict[str, OperationData]] = None - def push(self, level: int, operations: List[OperationData]) -> None: - self._queue.append((level, operations)) + def push_operations(self, operations: Tuple[OperationData, ...]) -> None: + self._queue.append(operations) - async def single_level_rollback(self, from_level: int) -> None: - """Ensure next arrived block is the same as rolled back one + def push_rollback(self, level: int) -> None: + self._queue.append(SingleLevelRollback(level)) - Called by IndexDispatcher in case index datasource reported a rollback. + async def _single_level_rollback(self, level: int) -> None: + """Ensure the next block to arrive contains all operations of the previous block; it may also contain additional ones. + + Called by IndexDispatcher when the index datasource receives a single-level rollback. """ if self._rollback_level: - raise RuntimeError('Already in rollback state') - - if self.state.level < from_level: - self._logger.info('Index level is lower than rollback level, ignoring') - elif self.state.level == from_level: - self._logger.info('Single level rollback has been triggered') - self._rollback_level = from_level + raise RuntimeError('Index is already in rollback state') + + state_level = cast(int, self.state.level) + if state_level < level: + self._logger.info('Index level is lower than rollback level, ignoring: %s < %s', state_level, level) + elif state_level == level: + self._logger.info('Single level rollback, next block will be processed partially') + self._rollback_level = level else: - raise RuntimeError('Index level is higher than rollback level') + raise RuntimeError(f'Index level is higher than rollback level: {state_level} > {level}') async def _process_queue(self) -> None: """Process WebSocket queue""" - if self._queue: - self._logger.info('Processing websocket queue') while self._queue: - level, operations = self._queue.popleft() - await self._process_level_operations(level, operations) + message = self._queue.popleft() + if isinstance(message, SingleLevelRollback): + self._logger.info('Processing rollback realtime message, %s left in queue', len(self._queue)) + await self._single_level_rollback(message.level) + else: + self._logger.info('Processing operations realtime message, %s left in queue', len(self._queue)) + await self._process_level_operations(message) async def _synchronize(self, last_level: int, cache: bool = False) -> None: """Fetch operations via Fetcher and pass to message callback""" @@ -173,9 +194,9 @@ async def _synchronize(self, last_level: int, cache: bool = False) -> None: transaction_addresses = await self._get_transaction_addresses() origination_addresses = await self._get_origination_addresses() - migration_originations = [] + migration_originations: Tuple[OperationData, ...] 
= () if self._config.types and OperationType.migration in self._config.types: - migration_originations = await self._datasource.get_migration_originations(first_level) + migration_originations = tuple(await self._datasource.get_migration_originations(first_level)) for op in migration_originations: code_hash, type_hash = await self._get_contract_hashes(cast(str, op.originated_contract_address)) op.originated_contract_code_hash, op.originated_contract_type_hash = code_hash, type_hash @@ -190,14 +211,15 @@ async def _synchronize(self, last_level: int, cache: bool = False) -> None: migration_originations=migration_originations, ) - async for level, operations in fetcher.fetch_operations_by_level(): - await self._process_level_operations(level, operations) + async for _, operations in fetcher.fetch_operations_by_level(): + await self._process_level_operations(operations) await self._exit_sync_state(last_level) - async def _process_level_operations(self, level: int, operations: List[OperationData]) -> None: - if level <= self.state.level: - raise RuntimeError(f'Level of operation batch must be higher than index state level: {level} <= {self.state.level}') + async def _process_level_operations(self, operations: Tuple[OperationData, ...]) -> None: + if not operations: + return + level = self._extract_level(operations) if self._rollback_level: levels = { @@ -210,18 +232,23 @@ async def _process_level_operations(self, level: int, operations: List[Operation raise RuntimeError(f'Index is in a rollback state, but received operation batch with different levels: {levels_repr}') self._logger.info('Rolling back to previous level, verifying processed operations') - expected_hashes = set(self._last_hashes) - received_hashes = set([op.hash for op in operations]) - reused_hashes = received_hashes & expected_hashes - if reused_hashes != expected_hashes: + expected_hashes = set(self._head_hashes) + received_hashes = set(op.hash for op in operations) + new_hashes = received_hashes - expected_hashes + missing_hashes = expected_hashes - received_hashes + + self._logger.info('Comparing hashes: %s new, %s missing', len(new_hashes), len(missing_hashes)) + if missing_hashes: + self._logger.info('Some operations are backtracked: %s', ', '.join(missing_hashes)) await self._ctx.reindex(ReindexingReason.ROLLBACK) self._rollback_level = None - self._last_hashes = set() - new_hashes = received_hashes - expected_hashes - if not new_hashes: - return - operations = [op for op in operations if op.hash in new_hashes] + self._head_hashes = set() + operations = tuple(op for op in operations if op.hash in new_hashes) + + # NOTE: le operator because it could be a single level rollback + elif level < self.state.level: + raise RuntimeError(f'Level of operation batch must be higher than index state level: {level} < {self.state.level}') async with in_global_transaction(): self._logger.info('Processing %s operations of level %s', len(operations), level) @@ -261,14 +288,14 @@ async def _match_operation(self, pattern_config: OperationHandlerPatternConfigT, else: raise NotImplementedError - async def _process_operations(self, operations: List[OperationData]) -> None: + async def _process_operations(self, operations: Iterable[OperationData]) -> None: """Try to match operations in cache with all patterns from indexes. 
Must be wrapped in transaction.""" - self._last_hashes = set() - operation_subgroups: Dict[OperationSubgroup, List[OperationData]] = defaultdict(list) + self._head_hashes = set() + operation_subgroups: Dict[OperationSubgroup, Deque[OperationData]] = defaultdict(deque) for operation in operations: key = OperationSubgroup(operation.hash, operation.counter) operation_subgroups[key].append(operation) - self._last_hashes.add(operation.hash) + self._head_hashes.add(operation.hash) for operation_subgroup, operations in operation_subgroups.items(): self._logger.debug('Matching %s', key) @@ -276,7 +303,7 @@ async def _process_operations(self, operations: List[OperationData]) -> None: for handler_config in self._config.handlers: operation_idx = 0 pattern_idx = 0 - matched_operations: List[Optional[OperationData]] = [] + matched_operations: Deque[Optional[OperationData]] = deque() # TODO: Ensure complex cases work, for ex. required argument after optional one # TODO: Add None to matched_operations where applicable (pattern is optional and operation not found) @@ -304,7 +331,7 @@ async def _process_operations(self, operations: List[OperationData]) -> None: if pattern_idx == len(handler_config.pattern): await self._on_match(operation_subgroup, handler_config, matched_operations) - matched_operations = [] + matched_operations.clear() pattern_idx = 0 if len(matched_operations) >= sum(map(lambda x: 0 if x.optional else 1, handler_config.pattern)): @@ -314,7 +341,7 @@ async def _on_match( self, operation_subgroup: OperationSubgroup, handler_config: OperationHandlerConfig, - matched_operations: List[Optional[OperationData]], + matched_operations: Deque[Optional[OperationData]], ): """Prepare handler arguments, parse parameter and storage. Schedule callback in executor.""" self._logger.info('%s: `%s` handler matched!', operation_subgroup.hash, handler_config.callback) @@ -405,18 +432,18 @@ class BigMapIndex(Index): def __init__(self, ctx: DipDupContext, config: BigMapIndexConfig, datasource: TzktDatasource) -> None: super().__init__(ctx, config, datasource) - self._queue: Deque[Tuple[int, List[BigMapData]]] = deque() + self._queue: Deque[Tuple[BigMapData, ...]] = deque() - def push(self, level: int, big_maps: List[BigMapData]): - self._queue.append((level, big_maps)) + def push_big_maps(self, big_maps: Tuple[BigMapData, ...]) -> None: + self._queue.append(big_maps) async def _process_queue(self) -> None: """Process WebSocket queue""" if self._queue: self._logger.info('Processing websocket queue') while self._queue: - level, big_maps = self._queue.popleft() - await self._process_level_big_maps(level, big_maps) + big_maps = self._queue.popleft() + await self._process_level_big_maps(big_maps) async def _synchronize(self, last_level: int, cache: bool = False) -> None: """Fetch operations via Fetcher and pass to message callback""" @@ -438,12 +465,15 @@ async def _synchronize(self, last_level: int, cache: bool = False) -> None: cache=cache, ) - async for level, big_maps in fetcher.fetch_big_maps_by_level(): - await self._process_level_big_maps(level, big_maps) + async for _, big_maps in fetcher.fetch_big_maps_by_level(): + await self._process_level_big_maps(big_maps) await self._exit_sync_state(last_level) - async def _process_level_big_maps(self, level: int, big_maps: List[BigMapData]): + async def _process_level_big_maps(self, big_maps: Tuple[BigMapData, ...]): + level = self._extract_level(big_maps) + + # NOTE: le operator because single level rollbacks are not supported if level <= self.state.level: raise 
RuntimeError(f'Level of big map batch must be higher than index state level: {level} <= {self.state.level}') @@ -504,7 +534,7 @@ async def _on_match( big_map_diff, ) - async def _process_big_maps(self, big_maps: List[BigMapData]) -> None: + async def _process_big_maps(self, big_maps: Iterable[BigMapData]) -> None: """Try to match big map diffs in cache with all patterns from indexes.""" for big_map in big_maps: diff --git a/tests/integration_tests/test_rollback.py b/tests/integration_tests/test_rollback.py index a16ff7b4a..f0cf18e8f 100644 --- a/tests/integration_tests/test_rollback.py +++ b/tests/integration_tests/test_rollback.py @@ -1,19 +1,21 @@ import asyncio +from contextlib import ExitStack, contextmanager from datetime import datetime -from functools import partial from os.path import dirname, join -from types import MethodType -from unittest import IsolatedAsyncioTestCase, skip -from unittest.mock import AsyncMock, MagicMock, patch +from typing import Generator, Tuple +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, Mock, patch from dipdup.config import DipDupConfig from dipdup.datasources.tzkt.datasource import TzktDatasource -from dipdup.dipdup import DipDup, IndexDispatcher -from dipdup.index import OperationIndex -from dipdup.models import BlockData, HeadBlockData +from dipdup.dipdup import DipDup +from dipdup.models import HeadBlockData from dipdup.models import Index as State from dipdup.models import OperationData +# import logging +# logging.basicConfig(level=logging.INFO) + def _get_operation(hash_: str, level: int) -> OperationData: return OperationData( @@ -34,116 +36,131 @@ def _get_operation(hash_: str, level: int) -> OperationData: ) -# NOTE: Skip synchronization -async def operation_index_process(self: OperationIndex): - await self.initialize_state() - await self._process_queue() +initial_level = 1365000 +next_level = initial_level + 1 +exact_operations = ( + _get_operation('1', next_level), + _get_operation('2', next_level), + _get_operation('3', next_level), +) -# NOTE: Emit operations, rollback, emit again, check state -async def datasource_run(self: TzktDatasource, index_dispatcher: IndexDispatcher, fail=False): +less_operations = ( + _get_operation('1', next_level), + _get_operation('2', next_level), +) - old_block = MagicMock(spec=HeadBlockData) - old_block.hash = 'block_a' - old_block.level = 1365001 - old_block.timestamp = datetime(2018, 1, 1) - new_block = MagicMock(spec=HeadBlockData) - new_block.hash = 'block_b' - new_block.level = 1365001 - new_block.timestamp = datetime(2018, 1, 1) +more_operations = ( + _get_operation('1', next_level), + _get_operation('2', next_level), + _get_operation('3', next_level), + _get_operation('4', next_level), +) - await self.emit_operations( - [ - _get_operation('1', 1365001), - _get_operation('2', 1365001), - _get_operation('3', 1365001), - ], - ) - await asyncio.sleep(0.05) +async def check_level(level: int) -> None: + state = await State.filter(name='hen_mainnet').get() + assert state.level == level, state.level + + +async def emit_messages( + self: TzktDatasource, + old_block: Tuple[OperationData, ...], + new_block: Tuple[OperationData, ...], + level: int, +): + await self.emit_operations(old_block) await self.emit_rollback( - from_level=1365001, - to_level=1365000, - ) - await asyncio.sleep(0.05) - - self.emit_operations( - [ - _get_operation('1', 1365001), - _get_operation('2', 1365001), - ] - + ( - [ - _get_operation('3', 1365001), - ] - if not fail - else [] - ), + 
from_level=next_level, + to_level=next_level - level, ) - await asyncio.sleep(0.05) + await self.emit_operations(new_block) - index_dispatcher.stop() + for _ in range(10): + await asyncio.sleep(0.1) + + raise asyncio.CancelledError + + +async def datasource_run_exact(self: TzktDatasource): + await emit_messages(self, exact_operations, exact_operations, 1) + await check_level(initial_level + 1) + + +async def datasource_run_more(self: TzktDatasource): + await emit_messages(self, exact_operations, more_operations, 1) + await check_level(initial_level + 1) + + +async def datasource_run_less(self: TzktDatasource): + await emit_messages(self, exact_operations, less_operations, 1) + await check_level(initial_level + 1) + + +async def datasource_run_zero(self: TzktDatasource): + await emit_messages(self, (), (exact_operations), 0) + await check_level(initial_level + 1) - # Assert - state = await State.filter(name='hen_mainnet').get() - assert state.level == 1365001 + +async def datasource_run_deep(self: TzktDatasource): + await emit_messages(self, (exact_operations), (), 1337) + await check_level(initial_level + 1) + + +head = Mock(spec=HeadBlockData) +head.level = initial_level + + +@contextmanager +def patch_dipdup(datasource_run) -> Generator: + with ExitStack() as stack: + stack.enter_context(patch('dipdup.index.OperationIndex._synchronize', AsyncMock())) + stack.enter_context(patch('dipdup.datasources.tzkt.datasource.TzktDatasource.run', datasource_run)) + stack.enter_context(patch('dipdup.context.DipDupContext.reindex', AsyncMock())) + stack.enter_context(patch('dipdup.datasources.tzkt.datasource.TzktDatasource.get_head_block', AsyncMock(return_value=head))) + yield + + +def get_dipdup() -> DipDup: + config = DipDupConfig.load([join(dirname(__file__), 'hic_et_nunc.yml')]) + config.database.path = ':memory:' # type: ignore + config.indexes['hen_mainnet'].last_level = 0 # type: ignore + config.initialize() + return DipDup(config) -@skip('RuntimeError: Index is synchronized but has no head block data') class RollbackTest(IsolatedAsyncioTestCase): - async def test_rollback_ok(self): - # Arrange - config = DipDupConfig.load([join(dirname(__file__), 'hic_et_nunc.yml')]) - config.database.path = ':memory:' - - datasource_name, datasource_config = list(config.datasources.items())[0] - datasource = TzktDatasource('test') - dipdup = DipDup(config) - dipdup._datasources[datasource_name] = datasource - dipdup._datasources_by_config[datasource_config] = datasource - - initial_block = MagicMock(spec=BlockData) - initial_block.level = 0 - initial_block.hash = 'block_0' - - datasource.on_operations(dipdup._index_dispatcher._on_operations) - datasource.on_big_maps(dipdup._index_dispatcher._on_big_maps) - datasource.on_rollback(dipdup._index_dispatcher._on_rollback) - - datasource.run = MethodType(partial(datasource_run, index_dispatcher=dipdup._index_dispatcher), datasource) - datasource.get_block = AsyncMock(return_value=initial_block) - - # Act - with patch('dipdup.index.OperationIndex.process', operation_index_process): - with patch('dipdup.dipdup.INDEX_DISPATCHER_INTERVAL', 0.01): - await dipdup.run(False, False) - - async def test_rollback_fail(self): - # Arrange - config = DipDupConfig.load([join(dirname(__file__), 'hic_et_nunc.yml')]) - config.database.path = ':memory:' - - datasource_name, datasource_config = list(config.datasources.items())[0] - datasource = TzktDatasource('test') - dipdup = DipDup(config) - dipdup._datasources[datasource_name] = datasource - 
dipdup._datasources_by_config[datasource_config] = datasource - dipdup._ctx.reindex = AsyncMock() - - initial_block = MagicMock(spec=BlockData) - initial_block.level = 0 - initial_block.hash = 'block_0' - - datasource.on_operations(dipdup._index_dispatcher._on_operations) - datasource.on_big_maps(dipdup._index_dispatcher._on_big_maps) - datasource.on_rollback(dipdup._index_dispatcher._on_rollback) - - datasource.run = MethodType(partial(datasource_run, index_dispatcher=dipdup._index_dispatcher, fail=True), datasource) - datasource.get_block = AsyncMock(return_value=initial_block) - - # Act - with patch('dipdup.index.OperationIndex.process', operation_index_process): - with patch('dipdup.dipdup.INDEX_DISPATCHER_INTERVAL', 0.01): - await dipdup.run(False, False) - - dipdup._ctx.reindex.assert_awaited() + async def test_rollback_exact(self): + with patch_dipdup(datasource_run_exact): + dipdup = get_dipdup() + await dipdup.run(False, False, False) + + assert dipdup._ctx.reindex.call_count == 0 + + async def test_rollback_more(self): + with patch_dipdup(datasource_run_more): + dipdup = get_dipdup() + await dipdup.run(False, False, False) + + assert dipdup._ctx.reindex.call_count == 0 + + async def test_rollback_less(self): + with patch_dipdup(datasource_run_less): + dipdup = get_dipdup() + await dipdup.run(False, False, False) + + assert dipdup._ctx.reindex.call_count == 1 + + async def test_rollback_zero(self): + with patch_dipdup(datasource_run_zero): + dipdup = get_dipdup() + await dipdup.run(False, False, False) + + assert dipdup._ctx.reindex.call_count == 0 + + async def test_rollback_deep(self): + with patch_dipdup(datasource_run_deep): + dipdup = get_dipdup() + await dipdup.run(False, False, False) + + assert dipdup._ctx.reindex.call_count == 1
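For context on the new rollback flow exercised by these tests: the decision logic that `IndexDispatcher._on_rollback` now implements (ignore zero-level rollbacks, let operation indexes absorb single-level rollbacks, delegate everything else to the `on_rollback` hook) can be sketched in isolation as below. This is a minimal illustration under simplified assumptions, not dipdup code; `FakeOperationIndex`, `on_rollback`, and `fire_hook` are placeholder names for the example only.

# Illustrative sketch of the rollback decision logic introduced by this patch.
import asyncio
from typing import Awaitable, Callable, List, Tuple


class FakeOperationIndex:
    """Stand-in for OperationIndex, the only index kind that can absorb a single-level rollback."""

    def __init__(self) -> None:
        self.queue: List[Tuple[str, int]] = []

    def push_rollback(self, level: int) -> None:
        # The real index later compares operation hashes of the re-sent block with the rolled back one.
        self.queue.append(('rollback', level))


async def on_rollback(
    indexes: List[object],
    fire_hook: Callable[[int, int], Awaitable[None]],
    from_level: int,
    to_level: int,
) -> None:
    depth = from_level - to_level
    if depth == 0:
        # Zero-level rollback: the channel level never advanced, nothing to undo.
        return
    if depth == 1 and all(isinstance(i, FakeOperationIndex) for i in indexes):
        # Single-level rollback: every affected index drops duplicates from the next block itself.
        for index in indexes:
            index.push_rollback(from_level)
        return
    # Deeper rollbacks, or indexes that cannot handle them, go to the user-defined hook.
    await fire_hook(from_level, to_level)


async def main() -> None:
    ops_index = FakeOperationIndex()

    async def hook(from_level: int, to_level: int) -> None:
        print(f'on_rollback hook fired: {from_level} -> {to_level}')

    await on_rollback([ops_index], hook, from_level=1365001, to_level=1365000)
    print(ops_index.queue)  # [('rollback', 1365001)]


asyncio.run(main())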