diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 526fe297..6baac4e1 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -50,14 +50,15 @@ abstract contract AccountLoupe is IAccountLoupe { AccountStorage storage _storage = getAccountStorage(); FunctionReference[] memory preExecHooks = - toFunctionReferenceArray(_storage.selectorData[selector].preExecHooks); + toFunctionReferenceArray(_storage.selectorData[selector].executionHooks.preHooks); uint256 numHooks = preExecHooks.length; execHooks = new ExecutionHooks[](numHooks); for (uint256 i = 0; i < numHooks;) { execHooks[i].preExecHook = preExecHooks[i]; - execHooks[i].postExecHook = _storage.selectorData[selector].associatedPostExecHooks[preExecHooks[i]]; + execHooks[i].postExecHook = + _storage.selectorData[selector].executionHooks.associatedPostHooks[preExecHooks[i]]; unchecked { ++i; @@ -76,7 +77,7 @@ abstract contract AccountLoupe is IAccountLoupe { bytes24 key = getPermittedCallKey(callingPlugin, selector); FunctionReference[] memory prePermittedCallHooks = - toFunctionReferenceArray(_storage.permittedCalls[key].prePermittedCallHooks); + toFunctionReferenceArray(_storage.permittedCalls[key].permittedCallHooks.preHooks); uint256 numHooks = prePermittedCallHooks.length; execHooks = new ExecutionHooks[](numHooks); @@ -84,7 +85,7 @@ abstract contract AccountLoupe is IAccountLoupe { for (uint256 i = 0; i < numHooks;) { execHooks[i].preExecHook = prePermittedCallHooks[i]; execHooks[i].postExecHook = - _storage.permittedCalls[key].associatedPostPermittedCallHooks[prePermittedCallHooks[i]]; + _storage.permittedCalls[key].permittedCallHooks.associatedPostHooks[prePermittedCallHooks[i]]; unchecked { ++i; diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index eb3833e8..3fcf2c30 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -10,6 +10,7 @@ import { SelectorData, PermittedCallData, getPermittedCallKey, + HookGroup, PermittedExternalCallData, StoredInjectedHook } from "../libraries/AccountStorage.sol"; @@ -129,34 +130,18 @@ abstract contract PluginManagerInternals is IPluginManager { function _addExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook) internal - notNullFunction(preExecHook) { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (!_selectorData.preExecHooks.add(_toSetValue(preExecHook))) { - // Treat the pre-exec and post-exec hook as a single unit, identified by the pre-exec hook. - // If the pre-exec hook exists, revert. - revert ExecutionHookAlreadySet(selector, preExecHook); - } - - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - _selectorData.associatedPostExecHooks[preExecHook] = postExecHook; - } + _addHooks(_selectorData.executionHooks, selector, preExecHook, postExecHook); } function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook) internal - notNullFunction(preExecHook) { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - _selectorData.preExecHooks.remove(_toSetValue(preExecHook)); - - // If the post exec hook is set, clear it. - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - _selectorData.associatedPostExecHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; - } + _removeHooks(_selectorData.executionHooks, preExecHook, postExecHook); } function _enableExecFromPlugin(bytes4 selector, address plugin, AccountStorage storage accountStorage) @@ -181,19 +166,11 @@ abstract contract PluginManagerInternals is IPluginManager { address plugin, FunctionReference preExecHook, FunctionReference postExecHook - ) internal notNullPlugin(plugin) notNullFunction(preExecHook) { + ) internal notNullPlugin(plugin) { bytes24 permittedCallKey = getPermittedCallKey(plugin, selector); PermittedCallData storage _permittedCalldata = getAccountStorage().permittedCalls[permittedCallKey]; - if (!_permittedCalldata.prePermittedCallHooks.add(_toSetValue(preExecHook))) { - // Treat the pre-exec and post-exec hook as a single unit, identified by the pre-exec hook. - // If the pre-exec hook exists, revert. - revert PermittedCallHookAlreadySet(selector, plugin, preExecHook); - } - - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - _permittedCalldata.associatedPostPermittedCallHooks[preExecHook] = postExecHook; - } + _addHooks(_permittedCalldata.permittedCallHooks, selector, preExecHook, postExecHook); } function _removePermittedCallHooks( @@ -201,17 +178,55 @@ abstract contract PluginManagerInternals is IPluginManager { address plugin, FunctionReference preExecHook, FunctionReference postExecHook - ) internal notNullPlugin(plugin) notNullFunction(preExecHook) { + ) internal notNullPlugin(plugin) { bytes24 permittedCallKey = getPermittedCallKey(plugin, selector); - PermittedCallData storage _permittedCalldata = getAccountStorage().permittedCalls[permittedCallKey]; + PermittedCallData storage _permittedCallData = getAccountStorage().permittedCalls[permittedCallKey]; - // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - _permittedCalldata.prePermittedCallHooks.remove(_toSetValue(preExecHook)); + _removeHooks(_permittedCallData.permittedCallHooks, preExecHook, postExecHook); + } + + function _addHooks( + HookGroup storage hooks, + bytes4 selector, + FunctionReference preExecHook, + FunctionReference postExecHook + ) internal { + if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // add pre or pre/post pair of exec hooks + if (!hooks.preHooks.add(_toSetValue(preExecHook))) { + revert ExecutionHookAlreadySet(selector, preExecHook); + } + + if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + hooks.associatedPostHooks[preExecHook] = postExecHook; + } + } else { + if (postExecHook == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // both pre and post hooks cannot be null + revert NullFunctionReference(); + } + + hooks.postOnlyHooks.add(_toSetValue(postExecHook)); + } + } + + function _removeHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook) + internal + { + if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // remove pre or pre/post pair of exec hooks + + // May ignore return value, as the manifest hash is validated to ensure that the hook exists. + hooks.preHooks.remove(_toSetValue(preExecHook)); + + if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + hooks.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; + } + } else { + // THe case where both pre and post hooks are null was checked during installation. - // If the post permitted call exec hook is set, clear it. - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - _permittedCalldata.associatedPostPermittedCallHooks[preExecHook] = - FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; + // May ignore return value, as the manifest hash is validated to ensure that the hook exists. + hooks.postOnlyHooks.remove(_toSetValue(postExecHook)); } } diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index eb9cb5cc..0e0541af 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -10,7 +10,7 @@ import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeab import {AccountExecutor} from "./AccountExecutor.sol"; import {AccountLoupe} from "./AccountLoupe.sol"; -import {AccountStorage, getAccountStorage, getPermittedCallKey} from "../libraries/AccountStorage.sol"; +import {AccountStorage, HookGroup, getAccountStorage, getPermittedCallKey} from "../libraries/AccountStorage.sol"; import {AccountStorageInitializable} from "./AccountStorageInitializable.sol"; import {FunctionReference, FunctionReferenceLib} from "../libraries/FunctionReferenceLib.sol"; import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol"; @@ -66,7 +66,7 @@ contract UpgradeableModularAccount is modifier wrapNativeFunction() { _doRuntimeValidationIfNotFromEP(); - PostExecToRun[] memory postExecHooks = _doPreExecHooks(msg.sig); + PostExecToRun[] memory postExecHooks = _doPreExecHooks(msg.sig, msg.data); _; @@ -127,7 +127,7 @@ contract UpgradeableModularAccount is PostExecToRun[] memory postExecHooks; // Cache post-exec hooks in memory - postExecHooks = _doPreExecHooks(msg.sig); + postExecHooks = _doPreExecHooks(msg.sig, msg.data); // execute the function, bubbling up any reverts (bool execSuccess, bytes memory execReturnData) = execPlugin.call(msg.data); @@ -188,7 +188,8 @@ contract UpgradeableModularAccount is revert ExecFromPluginNotPermitted(callingPlugin, selector); } - PostExecToRun[] memory postPermittedCallHooks = _doPrePermittedCallHooks(selector, callingPlugin); + PostExecToRun[] memory postPermittedCallHooks = + _doPrePermittedCallHooks(getPermittedCallKey(callingPlugin, selector), data); address execFunctionPlugin = _storage.selectorData[selector].plugin; @@ -196,7 +197,7 @@ contract UpgradeableModularAccount is revert UnrecognizedFunction(selector); } - PostExecToRun[] memory postExecHooks = _doPreExecHooks(selector); + PostExecToRun[] memory postExecHooks = _doPreExecHooks(selector, data); (bool success, bytes memory returnData) = execFunctionPlugin.call(data); @@ -250,11 +251,13 @@ contract UpgradeableModularAccount is // Run any pre plugin exec specific to this caller and the `executeFromPluginExternal` selector - PostExecToRun[] memory postPermittedCallHooks = - _doPrePermittedCallHooks(IPluginExecutor.executeFromPluginExternal.selector, msg.sender); + PostExecToRun[] memory postPermittedCallHooks = _doPrePermittedCallHooks( + getPermittedCallKey(msg.sender, IPluginExecutor.executeFromPluginExternal.selector), msg.data + ); // Run any pre exec hooks for this selector - PostExecToRun[] memory postExecHooks = _doPreExecHooks(IPluginExecutor.executeFromPluginExternal.selector); + PostExecToRun[] memory postExecHooks = + _doPreExecHooks(IPluginExecutor.executeFromPluginExternal.selector, msg.data); // Perform the external call bytes memory returnData = _exec(target, value, data); @@ -476,68 +479,34 @@ contract UpgradeableModularAccount is } } - function _doPreExecHooks(bytes4 selector) internal returns (PostExecToRun[] memory postHooksToRun) { - EnumerableSet.Bytes32Set storage preExecHooks = getAccountStorage().selectorData[selector].preExecHooks; - - uint256 postExecHooksLength = 0; - uint256 preExecHooksLength = preExecHooks.length(); - - // Over-allocate on length, but not all of this may get filled up. - postHooksToRun = new PostExecToRun[](preExecHooksLength); - for (uint256 i = 0; i < preExecHooksLength;) { - FunctionReference preExecHook = _toFunctionReference(preExecHooks.at(i)); - - if (preExecHook.isEmptyOrMagicValue()) { - if (preExecHook == FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY) { - revert AlwaysDenyRule(); - } - // Function reference cannot be 0. If _RUNTIME_VALIDATION_ALWAYS_ALLOW, revert since it's an - // invalid configuration. - revert InvalidConfiguration(); - } - - (address plugin, uint8 functionId) = preExecHook.unpack(); - bytes memory preExecHookReturnData; - try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, msg.data) returns ( - bytes memory returnData - ) { - preExecHookReturnData = returnData; - } catch (bytes memory revertReason) { - revert PreExecHookReverted(plugin, functionId, revertReason); - } - - // Check to see if there is a postExec hook set for this preExec hook - FunctionReference postExecHook = - getAccountStorage().selectorData[selector].associatedPostExecHooks[preExecHook]; - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - postHooksToRun[postExecHooksLength].postExecHook = postExecHook; - postHooksToRun[postExecHooksLength].preExecHookReturnData = preExecHookReturnData; - unchecked { - ++postExecHooksLength; - } - } + function _doPreExecHooks(bytes4 selector, bytes calldata data) + internal + returns (PostExecToRun[] memory postHooksToRun) + { + HookGroup storage hooks = getAccountStorage().selectorData[selector].executionHooks; - unchecked { - ++i; - } - } + return _doPreHooks(hooks, data); } - function _doPrePermittedCallHooks(bytes4 executionSelector, address callerPlugin) + function _doPrePermittedCallHooks(bytes24 permittedCallKey, bytes calldata data) internal returns (PostExecToRun[] memory postHooksToRun) { - bytes24 permittedCallKey = getPermittedCallKey(callerPlugin, executionSelector); + HookGroup storage hooks = getAccountStorage().permittedCalls[permittedCallKey].permittedCallHooks; - EnumerableSet.Bytes32Set storage preExecHooks = - getAccountStorage().permittedCalls[permittedCallKey].prePermittedCallHooks; + return _doPreHooks(hooks, data); + } + function _doPreHooks(HookGroup storage hooks, bytes calldata data) + internal + returns (PostExecToRun[] memory postHooksToRun) + { uint256 postExecHooksLength = 0; - uint256 preExecHooksLength = preExecHooks.length(); - postHooksToRun = new PostExecToRun[](preExecHooksLength); // Over-allocate on length, but not all of this - // may get filled up. + uint256 preExecHooksLength = hooks.preHooks.length(); + // Over-allocate on length, but not all of this may get filled up. + postHooksToRun = new PostExecToRun[](preExecHooksLength + hooks.postOnlyHooks.length()); for (uint256 i = 0; i < preExecHooksLength;) { - FunctionReference preExecHook = _toFunctionReference(preExecHooks.at(i)); + FunctionReference preExecHook = _toFunctionReference(hooks.preHooks.at(i)); if (preExecHook.isEmptyOrMagicValue()) { if (preExecHook == FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY) { @@ -550,7 +519,7 @@ contract UpgradeableModularAccount is (address plugin, uint8 functionId) = preExecHook.unpack(); bytes memory preExecHookReturnData; - try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, msg.data) returns ( + try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, data) returns ( bytes memory returnData ) { preExecHookReturnData = returnData; @@ -559,8 +528,7 @@ contract UpgradeableModularAccount is } // Check to see if there is a postExec hook set for this preExec hook - FunctionReference postExecHook = - getAccountStorage().permittedCalls[permittedCallKey].associatedPostPermittedCallHooks[preExecHook]; + FunctionReference postExecHook = hooks.associatedPostHooks[preExecHook]; if (FunctionReference.unwrap(postExecHook) != 0) { postHooksToRun[postExecHooksLength].postExecHook = postExecHook; postHooksToRun[postExecHooksLength].preExecHookReturnData = preExecHookReturnData; @@ -573,6 +541,16 @@ contract UpgradeableModularAccount is ++i; } } + + // Copy post-only hooks to the end of the array + uint256 postOnlyHooksLength = hooks.postOnlyHooks.length(); + for (uint256 i = 0; i < postOnlyHooksLength;) { + postHooksToRun[postExecHooksLength].postExecHook = _toFunctionReference(hooks.postOnlyHooks.at(i)); + unchecked { + ++postExecHooksLength; + ++i; + } + } } function _doCachedPostExecHooks(PostExecToRun[] memory postHooksToRun) internal { diff --git a/src/libraries/AccountStorage.sol b/src/libraries/AccountStorage.sol index 7439dd56..070065d0 100644 --- a/src/libraries/AccountStorage.sol +++ b/src/libraries/AccountStorage.sol @@ -20,7 +20,7 @@ struct PluginData { StoredInjectedHook[] injectedHooks; } -// A version of IPluginManager. InjectedHook used to track injected hooks in storage. +// A version of IPluginManager.InjectedHook used to track injected hooks in storage. // Omits the hookApplyData field, which is not needed for storage, and flattens the struct. struct StoredInjectedHook { // The plugin that provides the hook @@ -37,9 +37,7 @@ struct StoredInjectedHook { // to interact with another plugin installed on the account. struct PermittedCallData { bool callPermitted; - EnumerableSet.Bytes32Set prePermittedCallHooks; - // bytes21 key = pre exec hook function reference - mapping(FunctionReference => FunctionReference) associatedPostPermittedCallHooks; + HookGroup permittedCallHooks; } // Represents data associated with a plugin's permission to use `executeFromPluginExternal` @@ -52,6 +50,14 @@ struct PermittedExternalCallData { mapping(bytes4 => bool) permittedSelectors; } +// Represets a set of pre- and post- hooks. Used to store both execution hooks and permitted call hooks. +struct HookGroup { + EnumerableSet.Bytes32Set preHooks; + // bytes21 key = pre hook function reference + mapping(FunctionReference => FunctionReference) associatedPostHooks; + EnumerableSet.Bytes32Set postOnlyHooks; +} + // Represents data associated with a specifc function selector. struct SelectorData { // The plugin that implements this execution function. @@ -63,9 +69,7 @@ struct SelectorData { EnumerableSet.Bytes32Set preUserOpValidationHooks; EnumerableSet.Bytes32Set preRuntimeValidationHooks; // The execution hooks for this function selector. - EnumerableSet.Bytes32Set preExecHooks; - // bytes21 key = pre exec hook function reference - mapping(FunctionReference => FunctionReference) associatedPostExecHooks; + HookGroup executionHooks; } struct AccountStorage { @@ -86,7 +90,7 @@ struct AccountStorage { } function getAccountStorage() pure returns (AccountStorage storage _storage) { - assembly { + assembly ("memory-safe") { _storage.slot := _ACCOUNT_STORAGE_SLOT } } diff --git a/test/account/AccountExecHooks.t.sol b/test/account/AccountExecHooks.t.sol new file mode 100644 index 00000000..fbc13f82 --- /dev/null +++ b/test/account/AccountExecHooks.t.sol @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import { + IPlugin, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + ManifestExecutionHook, + ManifestFunction, + PluginManifest +} from "../../src/interfaces/IPlugin.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MockPlugin} from "../mocks/MockPlugin.sol"; +import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; +import {OptimizedTest} from "../utils/OptimizedTest.sol"; + +contract AccountExecHooksTest is OptimizedTest { + using ECDSA for bytes32; + + EntryPoint public entryPoint; + SingleOwnerPlugin public singleOwnerPlugin; + MSCAFactoryFixture public factory; + + UpgradeableModularAccount public account; + + MockPlugin public mockPlugin1; + bytes32 public manifestHash1; + + bytes4 internal constant _EXEC_SELECTOR = bytes4(uint32(1)); + uint8 internal constant _PRE_HOOK_FUNCTION_ID_1 = 1; + uint8 internal constant _POST_HOOK_FUNCTION_ID_2 = 2; + uint8 internal constant _PRE_HOOK_FUNCTION_ID_3 = 3; + uint8 internal constant _POST_HOOK_FUNCTION_ID_4 = 4; + + PluginManifest public m1; + + /// @dev Note that we strip hookApplyData from InjectedHooks in this event for gas savings + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + IPluginManager.InjectedHook[] injectedHooks + ); + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + // emitted by MockPlugin + event ReceivedCall(bytes msgData, uint256 msgValue); + + function setUp() public { + entryPoint = new EntryPoint(); + singleOwnerPlugin = _deploySingleOwnerPlugin(); + factory = new MSCAFactoryFixture(entryPoint, singleOwnerPlugin); + + // Create an account with "this" as the owner, so we can execute along the runtime path with regular + // solidity semantics + account = factory.createAccount(address(this), 0); + + m1.executionFunctions.push(_EXEC_SELECTOR); + + m1.runtimeValidationFunctions.push( + ManifestAssociatedFunction({ + executionSelector: _EXEC_SELECTOR, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }) + ); + } + + function test_preExecHook_install() public { + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Expected execution: [1, null] + function test_preExecHook_run() public { + test_preExecHook_install(); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + } + + function test_preExecHook_uninstall() public { + test_preExecHook_install(); + + _uninstallPlugin(mockPlugin1); + } + + function test_execHookPair_install() public { + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_execHookPair_run() public { + test_execHookPair_install(); + + vm.expectEmit(true, true, true, true); + // pre hook call + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + vm.expectEmit(true, true, true, true); + // exec call + emit ReceivedCall(abi.encodePacked(_EXEC_SELECTOR), 0); + vm.expectEmit(true, true, true, true); + // post hook call + emit ReceivedCall( + abi.encodeCall(IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, "")), + 0 // msg value in call to plugin + ); + + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + } + + function test_execHookPair_uninstall() public { + test_execHookPair_install(); + + _uninstallPlugin(mockPlugin1); + } + + function test_postOnlyExecHook_install() public { + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyExecHook_run() public { + test_postOnlyExecHook_install(); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall(IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, "")), + 0 // msg value in call to plugin + ); + + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + } + + function test_postOnlyExecHook_uninstall() public { + test_postOnlyExecHook_install(); + + _uninstallPlugin(mockPlugin1); + } + + function _installPlugin1WithHooks( + bytes4 selector, + ManifestFunction memory preHook, + ManifestFunction memory postHook + ) internal returns (MockPlugin) { + m1.executionHooks.push(ManifestExecutionHook(selector, preHook, postHook)); + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + return mockPlugin1; + } + + function _uninstallPlugin(MockPlugin plugin) internal { + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onUninstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + + account.uninstallPlugin(address(plugin), bytes(""), bytes(""), new bytes[](0)); + } +} diff --git a/test/account/AccountPermittedCallHooks.t.sol b/test/account/AccountPermittedCallHooks.t.sol new file mode 100644 index 00000000..841760bf --- /dev/null +++ b/test/account/AccountPermittedCallHooks.t.sol @@ -0,0 +1,269 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import { + IPlugin, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + ManifestExecutionHook, + ManifestFunction, + PluginManifest +} from "../../src/interfaces/IPlugin.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MockPlugin} from "../mocks/MockPlugin.sol"; +import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; +import {OptimizedTest} from "../utils/OptimizedTest.sol"; + +contract AccountPermittedCallHooksTest is OptimizedTest { + using ECDSA for bytes32; + + EntryPoint public entryPoint; + SingleOwnerPlugin public singleOwnerPlugin; + MSCAFactoryFixture public factory; + + UpgradeableModularAccount public account; + + MockPlugin public mockPlugin1; + bytes32 public manifestHash1; + + bytes4 internal constant _EXEC_SELECTOR = bytes4(uint32(1)); + uint8 internal constant _PRE_HOOK_FUNCTION_ID_1 = 1; + uint8 internal constant _POST_HOOK_FUNCTION_ID_2 = 2; + uint8 internal constant _PRE_HOOK_FUNCTION_ID_3 = 3; + uint8 internal constant _POST_HOOK_FUNCTION_ID_4 = 4; + + PluginManifest public m1; + + /// @dev Note that we strip hookApplyData from InjectedHooks in this event for gas savings + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + IPluginManager.InjectedHook[] injectedHooks + ); + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + // emitted by MockPlugin + event ReceivedCall(bytes msgData, uint256 msgValue); + + function setUp() public { + entryPoint = new EntryPoint(); + singleOwnerPlugin = _deploySingleOwnerPlugin(); + factory = new MSCAFactoryFixture(entryPoint, singleOwnerPlugin); + + // Create an account with "this" as the owner, so we can execute along the runtime path with regular + // solidity semantics + account = factory.createAccount(address(this), 0); + + m1.executionFunctions.push(_EXEC_SELECTOR); + + m1.runtimeValidationFunctions.push( + ManifestAssociatedFunction({ + executionSelector: _EXEC_SELECTOR, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }) + ); + + m1.permittedExecutionSelectors.push(_EXEC_SELECTOR); + } + + function test_prePermittedCallHook_install() public { + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + } + + /// @dev Plugin hook pair(s): [1, null] + /// Expected execution: [1, null] + function test_prePermittedCallHook_run() public { + test_prePermittedCallHook_install(); + + vm.startPrank(address(mockPlugin1)); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + + account.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + function test_prePermittedCallHook_uninstall() public { + test_prePermittedCallHook_install(); + + _uninstallPlugin(mockPlugin1); + } + + function test_permittedCallHookPair_install() public { + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + } + + /// @dev Plugin hook pair(s): [1, 2] + /// Expected execution: [1, 2] + function test_permittedCallHookPair_run() public { + test_permittedCallHookPair_install(); + + vm.startPrank(address(mockPlugin1)); + + vm.expectEmit(true, true, true, true); + // pre hook call + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + vm.expectEmit(true, true, true, true); + // exec call + emit ReceivedCall(abi.encodePacked(_EXEC_SELECTOR), 0); + vm.expectEmit(true, true, true, true); + // post hook call + emit ReceivedCall( + abi.encodeCall(IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, "")), + 0 // msg value in call to plugin + ); + + account.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + function test_permittedCallHookPair_uninstall() public { + test_permittedCallHookPair_install(); + + _uninstallPlugin(mockPlugin1); + } + + function test_postOnlyPermittedCallHook_install() public { + _installPlugin1WithHooks( + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + } + + /// @dev Plugin hook pair(s): [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyPermittedCallHook_run() public { + test_postOnlyPermittedCallHook_install(); + + vm.startPrank(address(mockPlugin1)); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall(IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, "")), + 0 // msg value in call to plugin + ); + + account.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + function test_postOnlyPermittedCallHook_uninstall() public { + test_postOnlyPermittedCallHook_install(); + + _uninstallPlugin(mockPlugin1); + } + + function _installPlugin1WithHooks(ManifestFunction memory preHook1, ManifestFunction memory postHook1) + internal + { + m1.permittedCallHooks.push(ManifestExecutionHook(_EXEC_SELECTOR, preHook1, postHook1)); + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _installPlugin1WithHooks( + ManifestFunction memory preHook1, + ManifestFunction memory postHook1, + ManifestFunction memory preHook2, + ManifestFunction memory postHook2 + ) internal { + m1.permittedCallHooks.push(ManifestExecutionHook(_EXEC_SELECTOR, preHook1, postHook1)); + m1.permittedCallHooks.push(ManifestExecutionHook(_EXEC_SELECTOR, preHook2, postHook2)); + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _uninstallPlugin(MockPlugin plugin) internal { + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onUninstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + + account.uninstallPlugin(address(plugin), bytes(""), bytes(""), new bytes[](0)); + } +} diff --git a/test/mocks/MockPlugin.sol b/test/mocks/MockPlugin.sol index 273d5335..1a71e999 100644 --- a/test/mocks/MockPlugin.sol +++ b/test/mocks/MockPlugin.sol @@ -39,7 +39,7 @@ contract MockPlugin is ERC165 { pure returns (function() internal pure returns (PluginManifest memory) fnOut) { - assembly { + assembly ("memory-safe") { fnOut := fnIn } } @@ -82,7 +82,8 @@ contract MockPlugin is ERC165 { || msg.sig == IPlugin.preExecutionHook.selector ) { // return 0 for userOp/runtimeVal case, return bytes("") for preExecutionHook case - assembly { + assembly ("memory-safe") { + mstore(0, 0) return(0x00, 0x20) } }