diff --git a/src/demo_quipuswap/dipdup.yml b/src/demo_quipuswap/dipdup.yml index 98a0a82e0..287fbe2e1 100644 --- a/src/demo_quipuswap/dipdup.yml +++ b/src/demo_quipuswap/dipdup.yml @@ -74,6 +74,9 @@ templates: - type: transaction destination: entrypoint: withdrawProfit + - type: transaction + source: + optional: True quipuswap_fa2: kind: operation @@ -118,6 +121,9 @@ templates: - type: transaction destination: entrypoint: withdrawProfit + - type: transaction + source: + optional: True indexes: kusd_mainnet: diff --git a/src/demo_quipuswap/handlers/on_fa12_withdraw_profit.py b/src/demo_quipuswap/handlers/on_fa12_withdraw_profit.py index 80ecdf0ef..3c5794808 100644 --- a/src/demo_quipuswap/handlers/on_fa12_withdraw_profit.py +++ b/src/demo_quipuswap/handlers/on_fa12_withdraw_profit.py @@ -1,16 +1,17 @@ from decimal import Decimal +from typing import Optional import demo_quipuswap.models as models from demo_quipuswap.types.quipu_fa12.parameter.withdraw_profit import WithdrawProfitParameter from demo_quipuswap.types.quipu_fa12.storage import QuipuFa12Storage -from dipdup.models import OperationHandlerContext, TransactionContext +from dipdup.models import OperationData, OperationHandlerContext, OriginationContext, TransactionContext async def on_fa12_withdraw_profit( ctx: OperationHandlerContext, withdraw_profit: TransactionContext[WithdrawProfitParameter, QuipuFa12Storage], + transaction_0: Optional[OperationData], ) -> None: - if ctx.template_values is None: raise Exception('This index must be templated') @@ -18,9 +19,8 @@ async def on_fa12_withdraw_profit( trader = withdraw_profit.data.sender_address position, _ = await models.Position.get_or_create(trader=trader, symbol=symbol) - transaction = next(op for op in ctx.operations if op.amount) - - assert transaction.amount is not None - position.realized_pl += Decimal(transaction.amount) / (10 ** 6) # type: ignore + if transaction_0: + assert transaction_0.amount is not None + position.realized_pl += Decimal(transaction_0.amount) / (10 ** 6) # type: ignore - await position.save() + await position.save() diff --git a/src/demo_quipuswap/handlers/on_fa20_withdraw_profit.py b/src/demo_quipuswap/handlers/on_fa20_withdraw_profit.py index e29b68615..5fb36765a 100644 --- a/src/demo_quipuswap/handlers/on_fa20_withdraw_profit.py +++ b/src/demo_quipuswap/handlers/on_fa20_withdraw_profit.py @@ -1,14 +1,16 @@ from decimal import Decimal +from typing import Optional import demo_quipuswap.models as models from demo_quipuswap.types.quipu_fa2.parameter.withdraw_profit import WithdrawProfitParameter from demo_quipuswap.types.quipu_fa2.storage import QuipuFa2Storage -from dipdup.models import OperationHandlerContext, TransactionContext +from dipdup.models import OperationData, OperationHandlerContext, OriginationContext, TransactionContext async def on_fa20_withdraw_profit( ctx: OperationHandlerContext, withdraw_profit: TransactionContext[WithdrawProfitParameter, QuipuFa2Storage], + transaction_0: Optional[OperationData], ) -> None: if ctx.template_values is None: @@ -18,9 +20,9 @@ async def on_fa20_withdraw_profit( trader = withdraw_profit.data.sender_address position, _ = await models.Position.get_or_create(trader=trader, symbol=symbol) - transaction = next(op for op in ctx.operations if op.amount) - assert transaction.amount is not None - position.realized_pl += Decimal(transaction.amount) / (10 ** 6) # type: ignore + if transaction_0: + assert transaction_0.amount is not None + position.realized_pl += Decimal(transaction_0.amount) / (10 ** 6) # type: ignore - await position.save() + await position.save() diff --git a/src/dipdup/codegen.py b/src/dipdup/codegen.py index a8cfed9c9..e9bc80bae 100644 --- a/src/dipdup/codegen.py +++ b/src/dipdup/codegen.py @@ -17,6 +17,7 @@ DipDupConfig, IndexTemplateConfig, OperationHandlerConfig, + OperationHandlerOriginationPatternConfig, OperationHandlerTransactionPatternConfig, OperationIndexConfig, TzktDatasourceConfig, @@ -91,7 +92,17 @@ async def fetch_schemas(config: DipDupConfig): if isinstance(index_config, OperationIndexConfig): for operation_handler_config in index_config.handlers: for operation_pattern_config in operation_handler_config.pattern: - contract_config = operation_pattern_config.contract_config + + if ( + isinstance(operation_pattern_config, OperationHandlerTransactionPatternConfig) + and operation_pattern_config.entrypoint + ): + contract_config = operation_pattern_config.destination_contract_config + elif isinstance(operation_pattern_config, OperationHandlerOriginationPatternConfig): + contract_config = operation_pattern_config.contract_config + else: + continue + contract_schemas = await schemas_cache.get(index_config.datasource_config, contract_config) contract_schemas_path = join(schemas_path, contract_config.module_name) diff --git a/src/dipdup/config.py b/src/dipdup/config.py index e46cf9176..7e4cdc80e 100644 --- a/src/dipdup/config.py +++ b/src/dipdup/config.py @@ -6,14 +6,13 @@ import re import sys from collections import defaultdict -from dataclasses import field from enum import Enum from os import environ as env from os.path import dirname from typing import Any, Callable, Dict, List, Optional, Type, Union, cast from urllib.parse import urlparse -from pydantic import Field, validator +from pydantic import validator from pydantic.dataclasses import dataclass from pydantic.json import pydantic_encoder from ruamel.yaml import YAML @@ -155,21 +154,35 @@ class OperationHandlerTransactionPatternConfig: :param entrypoint: Contract entrypoint """ - type: Literal['transaction'] - destination: Union[str, ContractConfig] - entrypoint: str + type: Literal['transaction'] = 'transaction' + source: Optional[Union[str, ContractConfig]] = None + destination: Optional[Union[str, ContractConfig]] = None + entrypoint: Optional[str] = None + optional: bool = False def __post_init_post_parse__(self): + if self.entrypoint and not self.destination: + raise ConfigurationError('Transactions with entrypoint must also have destination') self._parameter_type_cls = None self._storage_type_cls = None + self._transaction_id = None @property - def contract_config(self) -> ContractConfig: - assert isinstance(self.destination, ContractConfig) + def source_contract_config(self) -> ContractConfig: + if not isinstance(self.source, ContractConfig): + raise RuntimeError('Config is not initialized') + return self.source + + @property + def destination_contract_config(self) -> ContractConfig: + if not isinstance(self.destination, ContractConfig): + raise RuntimeError('Config is not initialized') return self.destination @property def parameter_type_cls(self) -> Optional[Type]: + if not self.entrypoint: + raise RuntimeError('entrypoint is empty') if self._parameter_type_cls is None: raise RuntimeError('Config is not initialized') return self._parameter_type_cls @@ -180,6 +193,8 @@ def parameter_type_cls(self, typ: Type) -> None: @property def storage_type_cls(self) -> Type: + if not self.entrypoint: + raise RuntimeError('entrypoint is empty') if self._storage_type_cls is None: raise RuntimeError('Config is not initialized') return self._storage_type_cls @@ -188,22 +203,51 @@ def storage_type_cls(self) -> Type: def storage_type_cls(self, typ: Type) -> None: self._storage_type_cls = typ + @property + def transaction_id(self) -> int: + if self._transaction_id is None: + raise RuntimeError('Config is not initialized') + return self._transaction_id + + @transaction_id.setter + def transaction_id(self, id_: int) -> None: + self._transaction_id = id_ + def get_handler_imports(self, package: str) -> str: - return '\n'.join( - [ - f'from {package}.types.{self.contract_config.module_name}.parameter.{camel_to_snake(self.entrypoint)} import {snake_to_camel(self.entrypoint)}Parameter', - f'from {package}.types.{self.contract_config.module_name}.storage import {snake_to_camel(self.contract_config.module_name)}Storage', - ] - ) + if self.entrypoint: + module_name = self.destination_contract_config.module_name + entrypoint = camel_to_snake(self.entrypoint) + parameter_cls = f'{snake_to_camel(self.entrypoint)}Parameter' + storage_cls = f'{snake_to_camel(module_name)}Storage' + return '\n'.join( + [ + f'from {package}.types.{module_name}.parameter.{entrypoint} import {parameter_cls}', + f'from {package}.types.{module_name}.storage import {storage_cls}', + ] + ) + else: + return '' def get_handler_argument(self) -> str: - return f'{camel_to_snake(self.entrypoint)}: TransactionContext[{snake_to_camel(self.entrypoint)}Parameter, {snake_to_camel(self.contract_config.module_name)}Storage],' + if self.entrypoint: + module_name = self.destination_contract_config.module_name + entrypoint = camel_to_snake(self.entrypoint) + parameter_cls = f'{snake_to_camel(self.entrypoint)}Parameter' + storage_cls = f'{snake_to_camel(module_name)}Storage' + if self.optional: + return f'{entrypoint}: Optional[TransactionContext[{parameter_cls}, {storage_cls}]],' + return f'{entrypoint}: TransactionContext[{parameter_cls}, {storage_cls}],' + else: + if self.optional: + return f'transaction_{self._transaction_id}: Optional[OperationData],' + return f'transaction_{self._transaction_id}: OperationData,' @dataclass class OperationHandlerOriginationPatternConfig: - type: Literal['origination'] originated_contract: Union[str, ContractConfig] + type: Literal['origination'] = 'origination' + optional: bool = False def __post_init_post_parse__(self): self._storage_type_cls = None @@ -228,10 +272,16 @@ def storage_type_cls(self, typ: Type) -> None: self._storage_type_cls = typ def get_handler_imports(self, package: str) -> str: - return f'from {package}.types.{self.contract_config.module_name}.storage import {snake_to_camel(self.contract_config.module_name)}Storage' + module_name = self.contract_config.module_name + storage_cls = f'{snake_to_camel(module_name)}Storage' + return f'from {package}.types.{module_name}.storage import {storage_cls}' def get_handler_argument(self) -> str: - return f'{self.contract_config.module_name}_origination: OriginationContext[{snake_to_camel(self.contract_config.module_name)}Storage],' + module_name = self.contract_config.module_name + storage_cls = f'{snake_to_camel(module_name)}Storage' + if self.optional: + return f'{module_name}_origination: Optional[OriginationContext[{storage_cls}]],' + return f'{module_name}_origination: OriginationContext[{storage_cls}],' OperationHandlerPatternConfig = Union[OperationHandlerOriginationPatternConfig, OperationHandlerTransactionPatternConfig] @@ -481,6 +531,7 @@ def __post_init_post_parse__(self): except KeyError as e: raise ConfigurationError(f'Contract `{contract}` not found in `contracts` config section') from e + transaction_id = 0 for handler_config in index_config.handlers: callback_patterns[handler_config.callback].append(handler_config.pattern) for pattern_config in handler_config.pattern: @@ -492,6 +543,17 @@ def __post_init_post_parse__(self): raise ConfigurationError( f'Contract `{pattern_config.destination}` not found in `contracts` config section' ) from e + if isinstance(pattern_config.source, str): + try: + pattern_config.source = self.contracts[pattern_config.source] + except KeyError as e: + raise ConfigurationError( + f'Contract `{pattern_config.source}` not found in `contracts` config section' + ) from e + if not pattern_config.entrypoint: + pattern_config.transaction_id = transaction_id + transaction_id += 1 + elif isinstance(pattern_config, OperationHandlerOriginationPatternConfig): if isinstance(pattern_config.originated_contract, str): try: @@ -522,7 +584,13 @@ def __post_init_post_parse__(self): if len(patterns) > 1: def get_pattern_type(pattern: List[OperationHandlerPatternConfig]): - return '::'.join(map(lambda x: x.contract_config.module_name, pattern)) + module_names = [] + for pattern_config in pattern: + if isinstance(pattern_config, OperationHandlerTransactionPatternConfig) and pattern_config.entrypoint: + module_names.append(pattern_config.destination_contract_config.module_name) + elif isinstance(pattern_config, OperationHandlerOriginationPatternConfig): + module_names.append(pattern_config.contract_config.module_name) + return '::'.join(module_names) pattern_types = list(map(get_pattern_type, patterns)) if any(map(lambda x: x != pattern_types[0], pattern_types)): @@ -616,11 +684,14 @@ async def initialize(self) -> None: for operation_pattern_config in operation_handler_config.pattern: if isinstance(operation_pattern_config, OperationHandlerTransactionPatternConfig): + if not operation_pattern_config.entrypoint: + continue + _logger.info('Registering parameter type for entrypoint `%s`', operation_pattern_config.entrypoint) parameter_type_module = importlib.import_module( f'{self.package}' f'.types' - f'.{operation_pattern_config.contract_config.module_name}' + f'.{operation_pattern_config.destination_contract_config.module_name}' f'.parameter' f'.{camel_to_snake(operation_pattern_config.entrypoint)}' ) @@ -629,14 +700,25 @@ async def initialize(self) -> None: ) operation_pattern_config.parameter_type_cls = parameter_type_cls - _logger.info('Registering storage type') - storage_type_module = importlib.import_module( - f'{self.package}.types.{operation_pattern_config.contract_config.module_name}.storage' - ) - storage_type_cls = getattr( - storage_type_module, snake_to_camel(operation_pattern_config.contract_config.module_name) + 'Storage' - ) - operation_pattern_config.storage_type_cls = storage_type_cls + _logger.info('Registering storage type') + storage_type_module = importlib.import_module( + f'{self.package}.types.{operation_pattern_config.destination_contract_config.module_name}.storage' + ) + storage_type_cls = getattr( + storage_type_module, + snake_to_camel(operation_pattern_config.destination_contract_config.module_name) + 'Storage', + ) + operation_pattern_config.storage_type_cls = storage_type_cls + + elif isinstance(operation_handler_config, OperationHandlerOriginationPatternConfig): + _logger.info('Registering storage type') + storage_type_module = importlib.import_module( + f'{self.package}.types.{operation_pattern_config.contract_config.module_name}.storage' + ) + storage_type_cls = getattr( + storage_type_module, snake_to_camel(operation_pattern_config.contract_config.module_name) + 'Storage' + ) + operation_pattern_config.storage_type_cls = storage_type_cls elif isinstance(index_config, BigMapIndexConfig): for big_map_handler_config in index_config.handlers: diff --git a/src/dipdup/datasources/tzkt/cache.py b/src/dipdup/datasources/tzkt/cache.py index c6ff40023..937a8ab07 100644 --- a/src/dipdup/datasources/tzkt/cache.py +++ b/src/dipdup/datasources/tzkt/cache.py @@ -51,53 +51,79 @@ async def add(self, operation: OperationData): def match_operation(self, pattern_config: OperationHandlerPatternConfig, operation: OperationData) -> bool: if isinstance(pattern_config, OperationHandlerTransactionPatternConfig): - return all( - [ - pattern_config.entrypoint == operation.entrypoint, - pattern_config.contract_config.address == operation.target_address, - ] - ) + if pattern_config.entrypoint != operation.entrypoint: + return False + if pattern_config.destination and pattern_config.destination_contract_config.address != operation.target_address: + return False + if pattern_config.source and pattern_config.source_contract_config.address != operation.sender_address: + return False + return True if isinstance(pattern_config, OperationHandlerOriginationPatternConfig): return pattern_config.contract_config.address == operation.originated_contract_address raise NotImplementedError async def process( self, - callback: Callable[[OperationIndexConfig, OperationHandlerConfig, List[OperationData], List[OperationData]], Awaitable[None]], + callback: Callable[ + [OperationIndexConfig, OperationHandlerConfig, List[Optional[OperationData]], List[OperationData]], + Awaitable[None], + ], ) -> int: + async def on_match( + key: OperationGroup, + index_config: OperationIndexConfig, + handler_config: OperationHandlerConfig, + matched_operations: List[Optional[OperationData]], + operations: List[OperationData], + ) -> None: + self._logger.info('Handler `%s` matched! %s', handler_config.callback, key) + await callback(index_config, handler_config, matched_operations, operations) + + index_config.state.level = self._level + await index_config.state.save() + if self._level is None: raise RuntimeError('Add operations to cache before processing') keys = list(self._operations.keys()) self._logger.info('Matching %s operation groups', len(keys)) - for key, operations in copy(self._operations).items(): - self._logger.debug('Processing %s', key) + for key, operations in self._operations.items(): + self._logger.debug('Matching %s', key) matched = False for index_config in self._indexes.values(): - if matched: - break for handler_config in index_config.handlers: - matched_operations = [] - for pattern_config in handler_config.pattern: - for operation in operations: - operation_matched = self.match_operation(pattern_config, operation) - if operation_matched: - matched_operations.append(operation) - - if len(matched_operations) == len(handler_config.pattern): - self._logger.info('Handler `%s` matched! %s', handler_config.callback, key) + operation_idx = 0 + pattern_idx = 0 + matched_operations: List[Optional[OperationData]] = [] + + while operation_idx < len(operations): + pattern_config = handler_config.pattern[pattern_idx] + matched = self.match_operation(pattern_config, operations[operation_idx]) + if matched: + matched_operations.append(operations[operation_idx]) + pattern_idx += 1 + operation_idx += 1 + elif pattern_config.optional: + matched_operations.append(None) + pattern_idx += 1 + else: + operation_idx += 1 + + if pattern_idx == len(handler_config.pattern): + await on_match(key, index_config, handler_config, matched_operations, operations) + matched = True + matched_operations = [] + pattern_idx = 0 + + if len(matched_operations) >= sum(map(lambda x: 0 if x.optional else 1, handler_config.pattern)): + await on_match(key, index_config, handler_config, matched_operations, operations) matched = True - await callback(index_config, handler_config, matched_operations, operations) - - index_config.state.level = self._level - await index_config.state.save() - del self._operations[key] - break + # NOTE: Only one index could match as addresses do not intersect between indexes (checked on config initialization) + if matched: + break - keys_left = self._operations.keys() - self._logger.info('%s operation groups unmatched', len(keys_left)) self._logger.info('Current level: %s', self._level) self._operations = {} diff --git a/src/dipdup/datasources/tzkt/datasource.py b/src/dipdup/datasources/tzkt/datasource.py index b1f60d509..6670bec3a 100644 --- a/src/dipdup/datasources/tzkt/datasource.py +++ b/src/dipdup/datasources/tzkt/datasource.py @@ -567,17 +567,23 @@ async def on_operation_match( self, index_config: OperationIndexConfig, handler_config: OperationHandlerConfig, - matched_operations: List[OperationData], + matched_operations: List[Optional[OperationData]], operations: List[OperationData], ): handler_context = OperationHandlerContext( operations=operations, template_values=index_config.template_values, ) - args: List[Union[OperationHandlerContext, TransactionContext, OriginationContext]] = [handler_context] + args: List[Optional[Union[OperationHandlerContext, TransactionContext, OriginationContext, OperationData]]] = [handler_context] for pattern_config, operation in zip(handler_config.pattern, matched_operations): + if operation is None: + args.append(None) + + elif isinstance(pattern_config, OperationHandlerTransactionPatternConfig): + if not pattern_config.entrypoint: + args.append(operation) + continue - if isinstance(pattern_config, OperationHandlerTransactionPatternConfig): parameter_type = pattern_config.parameter_type_cls parameter = parameter_type.parse_obj(operation.parameter_json) if parameter_type else None diff --git a/src/dipdup/templates/operation_handler.py.j2 b/src/dipdup/templates/operation_handler.py.j2 index 95fe54662..c1afde6e9 100644 --- a/src/dipdup/templates/operation_handler.py.j2 +++ b/src/dipdup/templates/operation_handler.py.j2 @@ -1,4 +1,6 @@ -from dipdup.models import OperationHandlerContext, TransactionContext, OriginationContext +from typing import Optional + +from dipdup.models import OperationData, OperationHandlerContext, OriginationContext, TransactionContext import {{ package }}.models as models {% for pattern in patterns %} diff --git a/tests/integration_tests/quipuswap.yml b/tests/integration_tests/quipuswap.yml index efc48efa1..0874c0cbe 100644 --- a/tests/integration_tests/quipuswap.yml +++ b/tests/integration_tests/quipuswap.yml @@ -69,6 +69,14 @@ templates: - type: transaction destination: entrypoint: transfer + - callback: on_fa12_withdraw_profit + pattern: + - type: transaction + destination: + entrypoint: withdrawProfit + - type: transaction + source: + optional: True first_block: 1407528 last_block: 1408934 @@ -110,6 +118,14 @@ templates: - type: transaction destination: entrypoint: transfer + - callback: on_fa20_withdraw_profit + pattern: + - type: transaction + destination: + entrypoint: withdrawProfit + - type: transaction + source: + optional: True first_block: 1407528 last_block: 1408928