diff --git a/src/MSAAdvanced.sol b/src/MSAAdvanced.sol index 7f87271..dbfc7fd 100644 --- a/src/MSAAdvanced.sol +++ b/src/MSAAdvanced.sol @@ -36,8 +36,8 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyEntryPointOrSelf - withHook { + (address hook, bytes memory hookData) = _preCheck(); (CallType callType, ExecType execType,,) = mode.decode(); // check if calltype is batch or single @@ -68,6 +68,7 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { } else { revert UnsupportedCallType(callType); } + _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -84,11 +85,11 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyExecutorModule - withHook returns ( bytes[] memory returnData // TODO returnData is not used ) { + (address hook, bytes memory hookData) = _preCheck(); (CallType callType, ExecType execType,,) = mode.decode(); // check if calltype is batch or single @@ -127,6 +128,7 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { } else { revert UnsupportedCallType(callType); } + _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -174,11 +176,17 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { payable onlyEntryPointOrSelf { - if (moduleTypeId == MODULE_TYPE_VALIDATOR) _uninstallValidator(module, deInitData); - else if (moduleTypeId == MODULE_TYPE_EXECUTOR) _uninstallExecutor(module, deInitData); - else if (moduleTypeId == MODULE_TYPE_FALLBACK) _uninstallFallbackHandler(module, deInitData); - else if (moduleTypeId == MODULE_TYPE_HOOK) _uninstallHook(module, deInitData); - else revert UnsupportedModuleType(moduleTypeId); + if (moduleTypeId == MODULE_TYPE_VALIDATOR) { + _uninstallValidator(module, deInitData); + } else if (moduleTypeId == MODULE_TYPE_EXECUTOR) { + _uninstallExecutor(module, deInitData); + } else if (moduleTypeId == MODULE_TYPE_FALLBACK) { + _uninstallFallbackHandler(module, deInitData); + } else if (moduleTypeId == MODULE_TYPE_HOOK) { + _uninstallHook(module, deInitData); + } else { + revert UnsupportedModuleType(moduleTypeId); + } emit ModuleUninstalled(moduleTypeId, module); } diff --git a/src/core/HookManager.sol b/src/core/HookManager.sol index 136e411..30b84db 100644 --- a/src/core/HookManager.sol +++ b/src/core/HookManager.sol @@ -4,11 +4,11 @@ pragma solidity ^0.8.21; import "./ModuleManager.sol"; import "../interfaces/IERC7579Account.sol"; import "../interfaces/IERC7579Module.sol"; + /** * @title reference implementation of HookManager * @author zeroknots.eth | rhinestone.wtf */ - abstract contract HookManager { /// @custom:storage-location erc7201:hookmanager.storage.msa struct HookManagerStorage { @@ -22,14 +22,24 @@ abstract contract HookManager { error HookPostCheckFailed(); error HookAlreadyInstalled(address currentHook); - modifier withHook() { - address hook = _getHook(); - if (hook == address(0)) { - _; - } else { - bytes memory hookData = IHook(hook).preCheck(msg.sender, msg.data); - _; - if (!IHook(hook).postCheck(hookData)) revert HookPostCheckFailed(); + function _preCheck() internal returns (address hook, bytes memory hookData) { + hook = _getHook(); + if (hook != address(0)) { + hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data); + return (hook, hookData); + } + } + + function _postCheck( + address hook, + bytes memory hookData, + bool executionSuccess, + bytes memory executionReturnValue + ) + internal + { + if (hook != address(0)) { + IHook(hook).postCheck(hookData, executionSuccess, executionReturnValue); } } diff --git a/src/interfaces/IERC7579Module.sol b/src/interfaces/IERC7579Module.sol index b82b552..4dc2312 100644 --- a/src/interfaces/IERC7579Module.sol +++ b/src/interfaces/IERC7579Module.sol @@ -85,11 +85,18 @@ interface IExecutor is IModule { } interface IHook is IModule { function preCheck( address msgSender, + uint256 msgValue, bytes calldata msgData ) external returns (bytes memory hookData); - function postCheck(bytes calldata hookData) external returns (bool success); + + function postCheck( + bytes calldata hookData, + bool executionSuccess, + bytes calldata executionReturnValue + ) + external; } interface IFallback is IModule { } diff --git a/test/mocks/MockHook.sol b/test/mocks/MockHook.sol index e3c0d0c..ce5b953 100644 --- a/test/mocks/MockHook.sol +++ b/test/mocks/MockHook.sol @@ -10,12 +10,19 @@ contract MockHook is IHook { function preCheck( address msgSender, + uint256 msgValue, bytes calldata msgData ) external returns (bytes memory hookData) { } - function postCheck(bytes calldata hookData) external returns (bool success) { } + function postCheck( + bytes calldata hookData, + bool executionSuccess, + bytes calldata executionReturnValue + ) + external + { } function isModuleType(uint256 moduleTypeId) external view returns (bool) { return moduleTypeId == MODULE_TYPE_HOOK;