From 712465e7b2bff9c7ba188ebf5aea0eb280ace40b Mon Sep 17 00:00:00 2001 From: kopy-kat Date: Wed, 28 Feb 2024 21:39:12 +0000 Subject: [PATCH 1/2] feat: hook changes --- src/MSAAdvanced.sol | 22 +++++++++++++++------- src/core/HookManager.sol | 24 +++++++++++++++++------- src/interfaces/IERC7579Module.sol | 9 ++++++++- test/mocks/MockHook.sol | 9 ++++++++- 4 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/MSAAdvanced.sol b/src/MSAAdvanced.sol index 7f87271..dbfc7fd 100644 --- a/src/MSAAdvanced.sol +++ b/src/MSAAdvanced.sol @@ -36,8 +36,8 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyEntryPointOrSelf - withHook { + (address hook, bytes memory hookData) = _preCheck(); (CallType callType, ExecType execType,,) = mode.decode(); // check if calltype is batch or single @@ -68,6 +68,7 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { } else { revert UnsupportedCallType(callType); } + _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -84,11 +85,11 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyExecutorModule - withHook returns ( bytes[] memory returnData // TODO returnData is not used ) { + (address hook, bytes memory hookData) = _preCheck(); (CallType callType, ExecType execType,,) = mode.decode(); // check if calltype is batch or single @@ -127,6 +128,7 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { } else { revert UnsupportedCallType(callType); } + _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -174,11 +176,17 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { payable onlyEntryPointOrSelf { - if (moduleTypeId == MODULE_TYPE_VALIDATOR) _uninstallValidator(module, deInitData); - else if (moduleTypeId == MODULE_TYPE_EXECUTOR) _uninstallExecutor(module, deInitData); - else if (moduleTypeId == MODULE_TYPE_FALLBACK) _uninstallFallbackHandler(module, deInitData); - else if (moduleTypeId == MODULE_TYPE_HOOK) _uninstallHook(module, deInitData); - else revert UnsupportedModuleType(moduleTypeId); + if (moduleTypeId == MODULE_TYPE_VALIDATOR) { + _uninstallValidator(module, deInitData); + } else if (moduleTypeId == MODULE_TYPE_EXECUTOR) { + _uninstallExecutor(module, deInitData); + } else if (moduleTypeId == MODULE_TYPE_FALLBACK) { + _uninstallFallbackHandler(module, deInitData); + } else if (moduleTypeId == MODULE_TYPE_HOOK) { + _uninstallHook(module, deInitData); + } else { + revert UnsupportedModuleType(moduleTypeId); + } emit ModuleUninstalled(moduleTypeId, module); } diff --git a/src/core/HookManager.sol b/src/core/HookManager.sol index 136e411..9ba5553 100644 --- a/src/core/HookManager.sol +++ b/src/core/HookManager.sol @@ -22,14 +22,24 @@ abstract contract HookManager { error HookPostCheckFailed(); error HookAlreadyInstalled(address currentHook); - modifier withHook() { + function _preCheck() internal returns (address hook, bytes memory hookData) { address hook = _getHook(); - if (hook == address(0)) { - _; - } else { - bytes memory hookData = IHook(hook).preCheck(msg.sender, msg.data); - _; - if (!IHook(hook).postCheck(hookData)) revert HookPostCheckFailed(); + if (hook != address(0)) { + bytes memory hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data); + return (hook, hookData); + } + } + + function _postCheck( + address hook, + bytes memory hookData, + bool executionSuccess, + bytes memory executionReturnValue + ) + internal + { + if (hook != address(0)) { + IHook(hook).postCheck(hookData, executionSuccess, executionReturnValue); } } diff --git a/src/interfaces/IERC7579Module.sol b/src/interfaces/IERC7579Module.sol index b82b552..4dc2312 100644 --- a/src/interfaces/IERC7579Module.sol +++ b/src/interfaces/IERC7579Module.sol @@ -85,11 +85,18 @@ interface IExecutor is IModule { } interface IHook is IModule { function preCheck( address msgSender, + uint256 msgValue, bytes calldata msgData ) external returns (bytes memory hookData); - function postCheck(bytes calldata hookData) external returns (bool success); + + function postCheck( + bytes calldata hookData, + bool executionSuccess, + bytes calldata executionReturnValue + ) + external; } interface IFallback is IModule { } diff --git a/test/mocks/MockHook.sol b/test/mocks/MockHook.sol index e3c0d0c..ce5b953 100644 --- a/test/mocks/MockHook.sol +++ b/test/mocks/MockHook.sol @@ -10,12 +10,19 @@ contract MockHook is IHook { function preCheck( address msgSender, + uint256 msgValue, bytes calldata msgData ) external returns (bytes memory hookData) { } - function postCheck(bytes calldata hookData) external returns (bool success) { } + function postCheck( + bytes calldata hookData, + bool executionSuccess, + bytes calldata executionReturnValue + ) + external + { } function isModuleType(uint256 moduleTypeId) external view returns (bool) { return moduleTypeId == MODULE_TYPE_HOOK; From 469d7c92b3f6c1e135d51f75d7b0fba7b13967c2 Mon Sep 17 00:00:00 2001 From: kopy-kat Date: Mon, 8 Apr 2024 14:49:39 +0200 Subject: [PATCH 2/2] feat: use return vars --- src/core/HookManager.sol | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/HookManager.sol b/src/core/HookManager.sol index 9ba5553..30b84db 100644 --- a/src/core/HookManager.sol +++ b/src/core/HookManager.sol @@ -4,11 +4,11 @@ pragma solidity ^0.8.21; import "./ModuleManager.sol"; import "../interfaces/IERC7579Account.sol"; import "../interfaces/IERC7579Module.sol"; + /** * @title reference implementation of HookManager * @author zeroknots.eth | rhinestone.wtf */ - abstract contract HookManager { /// @custom:storage-location erc7201:hookmanager.storage.msa struct HookManagerStorage { @@ -23,9 +23,9 @@ abstract contract HookManager { error HookAlreadyInstalled(address currentHook); function _preCheck() internal returns (address hook, bytes memory hookData) { - address hook = _getHook(); + hook = _getHook(); if (hook != address(0)) { - bytes memory hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data); + hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data); return (hook, hookData); } }