diff --git a/src/MSAAdvanced.sol b/src/MSAAdvanced.sol index c0d6593..3abfbca 100644 --- a/src/MSAAdvanced.sol +++ b/src/MSAAdvanced.sol @@ -36,8 +36,8 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyEntryPointOrSelf + withHook { - (address hook, bytes memory hookData) = _preCheck(); (CallType callType, ExecType execType,,) = mode.decode(); // check if calltype is batch or single @@ -68,9 +68,6 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { } else { revert UnsupportedCallType(callType); } - - // TODO: add correct data - _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -87,11 +84,11 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyExecutorModule + withHook returns ( bytes[] memory returnData // TODO returnData is not used ) { - (address hook, bytes memory hookData) = _preCheck(); (CallType callType, ExecType execType,,) = mode.decode(); // check if calltype is batch or single @@ -130,9 +127,6 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { } else { revert UnsupportedCallType(callType); } - - // TODO: add correct data - _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -167,18 +161,14 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyEntryPointOrSelf + withHook { - (address hook, bytes memory hookData) = _preCheck(); - if (moduleTypeId == MODULE_TYPE_VALIDATOR) _installValidator(module, initData); else if (moduleTypeId == MODULE_TYPE_EXECUTOR) _installExecutor(module, initData); else if (moduleTypeId == MODULE_TYPE_FALLBACK) _installFallbackHandler(module, initData); else if (moduleTypeId == MODULE_TYPE_HOOK) _installHook(module, initData); else revert UnsupportedModuleType(moduleTypeId); emit ModuleInstalled(moduleTypeId, module); - - // TODO: add correct data - _postCheck(hook, hookData, true, new bytes(0)); } /** @@ -192,9 +182,8 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { external payable onlyEntryPointOrSelf + withHook { - (address hook, bytes memory hookData) = _preCheck(); - if (moduleTypeId == MODULE_TYPE_VALIDATOR) { _uninstallValidator(module, deInitData); } else if (moduleTypeId == MODULE_TYPE_EXECUTOR) { @@ -207,9 +196,6 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager { revert UnsupportedModuleType(moduleTypeId); } emit ModuleUninstalled(moduleTypeId, module); - - // TODO: add correct data - _postCheck(hook, hookData, true, new bytes(0)); } /** diff --git a/src/core/HookManager.sol b/src/core/HookManager.sol index 30b84db..81fecce 100644 --- a/src/core/HookManager.sol +++ b/src/core/HookManager.sol @@ -22,24 +22,14 @@ abstract contract HookManager { error HookPostCheckFailed(); error HookAlreadyInstalled(address currentHook); - function _preCheck() internal returns (address hook, bytes memory hookData) { - hook = _getHook(); - if (hook != address(0)) { - hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data); - return (hook, hookData); - } - } - - function _postCheck( - address hook, - bytes memory hookData, - bool executionSuccess, - bytes memory executionReturnValue - ) - internal - { - if (hook != address(0)) { - IHook(hook).postCheck(hookData, executionSuccess, executionReturnValue); + modifier withHook() { + address hook = _getHook(); + if (hook == address(0)) { + _; + } else { + bytes memory hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data); + _; + IHook(hook).postCheck(hookData); } } diff --git a/src/interfaces/IERC7579Account.sol b/src/interfaces/IERC7579Account.sol index 1890fff..6ab3b0f 100644 --- a/src/interfaces/IERC7579Account.sol +++ b/src/interfaces/IERC7579Account.sol @@ -2,7 +2,6 @@ pragma solidity ^0.8.21; import { CallType, ExecType, ModeCode } from "../lib/ModeLib.sol"; -import { PackedUserOperation } from "account-abstraction/interfaces/IAccount.sol"; struct Execution { address target; diff --git a/src/interfaces/IERC7579Module.sol b/src/interfaces/IERC7579Module.sol index 4dc2312..09480a4 100644 --- a/src/interfaces/IERC7579Module.sol +++ b/src/interfaces/IERC7579Module.sol @@ -91,12 +91,7 @@ interface IHook is IModule { external returns (bytes memory hookData); - function postCheck( - bytes calldata hookData, - bool executionSuccess, - bytes calldata executionReturnValue - ) - external; + function postCheck(bytes calldata hookData) external; } interface IFallback is IModule { } diff --git a/test/mocks/MockHook.sol b/test/mocks/MockHook.sol index ce5b953..b3c4996 100644 --- a/test/mocks/MockHook.sol +++ b/test/mocks/MockHook.sol @@ -16,13 +16,7 @@ contract MockHook is IHook { external returns (bytes memory hookData) { } - function postCheck( - bytes calldata hookData, - bool executionSuccess, - bytes calldata executionReturnValue - ) - external - { } + function postCheck(bytes calldata hookData) external { } function isModuleType(uint256 moduleTypeId) external view returns (bool) { return moduleTypeId == MODULE_TYPE_HOOK;