Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/MSAAdvanced.sol
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager {
external
payable
onlyEntryPointOrSelf
withHook
{
(address hook, bytes memory hookData) = _preCheck();
(CallType callType, ExecType execType,,) = mode.decode();

// check if calltype is batch or single
Expand Down Expand Up @@ -68,6 +68,7 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager {
} else {
revert UnsupportedCallType(callType);
}
_postCheck(hook, hookData, true, new bytes(0));
}

/**
Expand All @@ -84,11 +85,11 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager {
external
payable
onlyExecutorModule
withHook
returns (
bytes[] memory returnData // TODO returnData is not used
)
{
(address hook, bytes memory hookData) = _preCheck();
(CallType callType, ExecType execType,,) = mode.decode();

// check if calltype is batch or single
Expand Down Expand Up @@ -127,6 +128,7 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager {
} else {
revert UnsupportedCallType(callType);
}
_postCheck(hook, hookData, true, new bytes(0));
}

/**
Expand Down Expand Up @@ -174,11 +176,17 @@ contract MSAAdvanced is IMSA, ExecutionHelper, ModuleManager, HookManager {
payable
onlyEntryPointOrSelf
{
if (moduleTypeId == MODULE_TYPE_VALIDATOR) _uninstallValidator(module, deInitData);
else if (moduleTypeId == MODULE_TYPE_EXECUTOR) _uninstallExecutor(module, deInitData);
else if (moduleTypeId == MODULE_TYPE_FALLBACK) _uninstallFallbackHandler(module, deInitData);
else if (moduleTypeId == MODULE_TYPE_HOOK) _uninstallHook(module, deInitData);
else revert UnsupportedModuleType(moduleTypeId);
if (moduleTypeId == MODULE_TYPE_VALIDATOR) {
_uninstallValidator(module, deInitData);
} else if (moduleTypeId == MODULE_TYPE_EXECUTOR) {
_uninstallExecutor(module, deInitData);
} else if (moduleTypeId == MODULE_TYPE_FALLBACK) {
_uninstallFallbackHandler(module, deInitData);
} else if (moduleTypeId == MODULE_TYPE_HOOK) {
_uninstallHook(module, deInitData);
} else {
revert UnsupportedModuleType(moduleTypeId);
}
emit ModuleUninstalled(moduleTypeId, module);
}

Expand Down
28 changes: 19 additions & 9 deletions src/core/HookManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ pragma solidity ^0.8.21;
import "./ModuleManager.sol";
import "../interfaces/IERC7579Account.sol";
import "../interfaces/IERC7579Module.sol";

/**
* @title reference implementation of HookManager
* @author zeroknots.eth | rhinestone.wtf
*/

abstract contract HookManager {
/// @custom:storage-location erc7201:hookmanager.storage.msa
struct HookManagerStorage {
Expand All @@ -22,14 +22,24 @@ abstract contract HookManager {
error HookPostCheckFailed();
error HookAlreadyInstalled(address currentHook);

modifier withHook() {
address hook = _getHook();
if (hook == address(0)) {
_;
} else {
bytes memory hookData = IHook(hook).preCheck(msg.sender, msg.data);
_;
if (!IHook(hook).postCheck(hookData)) revert HookPostCheckFailed();
function _preCheck() internal returns (address hook, bytes memory hookData) {
hook = _getHook();
if (hook != address(0)) {
hookData = IHook(hook).preCheck(msg.sender, msg.value, msg.data);
return (hook, hookData);
}
}

function _postCheck(
address hook,
bytes memory hookData,
bool executionSuccess,
bytes memory executionReturnValue
)
internal
{
if (hook != address(0)) {
IHook(hook).postCheck(hookData, executionSuccess, executionReturnValue);
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/interfaces/IERC7579Module.sol
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,18 @@ interface IExecutor is IModule { }
interface IHook is IModule {
function preCheck(
address msgSender,
uint256 msgValue,
bytes calldata msgData
)
external
returns (bytes memory hookData);
function postCheck(bytes calldata hookData) external returns (bool success);

function postCheck(
bytes calldata hookData,
bool executionSuccess,
bytes calldata executionReturnValue
)
external;
}

interface IFallback is IModule { }
9 changes: 8 additions & 1 deletion test/mocks/MockHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@ contract MockHook is IHook {

function preCheck(
address msgSender,
uint256 msgValue,
bytes calldata msgData
)
external
returns (bytes memory hookData)
{ }
function postCheck(bytes calldata hookData) external returns (bool success) { }
function postCheck(
bytes calldata hookData,
bool executionSuccess,
bytes calldata executionReturnValue
)
external
{ }

function isModuleType(uint256 moduleTypeId) external view returns (bool) {
return moduleTypeId == MODULE_TYPE_HOOK;
Expand Down