diff --git a/lib/forge-std b/lib/forge-std index 066ff16c..2f112697 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 066ff16c5c03e6f931cd041fd366bc4be1fae82a +Subproject commit 2f112697506eab12d433a65fdc31a639548fe365 diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol index 6baac4e1..5f68ea32 100644 --- a/src/account/AccountLoupe.sol +++ b/src/account/AccountLoupe.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.19; import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {IAccountLoupe} from "../interfaces/IAccountLoupe.sol"; @@ -11,11 +12,13 @@ import { AccountStorage, getAccountStorage, getPermittedCallKey, + HookGroup, toFunctionReferenceArray } from "../libraries/AccountStorage.sol"; import {FunctionReference} from "../libraries/FunctionReferenceLib.sol"; abstract contract AccountLoupe is IAccountLoupe { + using EnumerableMap for EnumerableMap.Bytes32ToUintMap; using EnumerableSet for EnumerableSet.AddressSet; error ManifestDiscrepancy(address plugin); @@ -47,23 +50,7 @@ abstract contract AccountLoupe is IAccountLoupe { /// @inheritdoc IAccountLoupe function getExecutionHooks(bytes4 selector) external view returns (ExecutionHooks[] memory execHooks) { - AccountStorage storage _storage = getAccountStorage(); - - FunctionReference[] memory preExecHooks = - toFunctionReferenceArray(_storage.selectorData[selector].executionHooks.preHooks); - - uint256 numHooks = preExecHooks.length; - execHooks = new ExecutionHooks[](numHooks); - - for (uint256 i = 0; i < numHooks;) { - execHooks[i].preExecHook = preExecHooks[i]; - execHooks[i].postExecHook = - _storage.selectorData[selector].executionHooks.associatedPostHooks[preExecHooks[i]]; - - unchecked { - ++i; - } - } + execHooks = _getHooks(getAccountStorage().selectorData[selector].executionHooks); } /// @inheritdoc IAccountLoupe @@ -72,25 +59,8 @@ abstract contract AccountLoupe is IAccountLoupe { view returns (ExecutionHooks[] memory execHooks) { - AccountStorage storage _storage = getAccountStorage(); - bytes24 key = getPermittedCallKey(callingPlugin, selector); - - FunctionReference[] memory prePermittedCallHooks = - toFunctionReferenceArray(_storage.permittedCalls[key].permittedCallHooks.preHooks); - - uint256 numHooks = prePermittedCallHooks.length; - execHooks = new ExecutionHooks[](numHooks); - - for (uint256 i = 0; i < numHooks;) { - execHooks[i].preExecHook = prePermittedCallHooks[i]; - execHooks[i].postExecHook = - _storage.permittedCalls[key].permittedCallHooks.associatedPostHooks[prePermittedCallHooks[i]]; - - unchecked { - ++i; - } - } + execHooks = _getHooks(getAccountStorage().permittedCalls[key].permittedCallHooks); } /// @inheritdoc IAccountLoupe @@ -112,4 +82,67 @@ abstract contract AccountLoupe is IAccountLoupe { function getInstalledPlugins() external view returns (address[] memory pluginAddresses) { pluginAddresses = getAccountStorage().plugins.values(); } + + function _getHooks(HookGroup storage hooks) internal view returns (ExecutionHooks[] memory execHooks) { + uint256 preExecHooksLength = hooks.preHooks.length(); + uint256 postOnlyExecHooksLength = hooks.postOnlyHooks.length(); + uint256 maxExecHooksLength = postOnlyExecHooksLength; + + // There can only be as many associated post hooks to run as there are pre hooks. + for (uint256 i = 0; i < preExecHooksLength;) { + (, uint256 count) = hooks.preHooks.at(i); + unchecked { + maxExecHooksLength += (count + 1); + ++i; + } + } + + // Overallocate on length - not all of this may get filled up. We set the correct length later. + execHooks = new ExecutionHooks[](maxExecHooksLength); + uint256 actualExecHooksLength; + + for (uint256 i = 0; i < preExecHooksLength;) { + (bytes32 key,) = hooks.preHooks.at(i); + FunctionReference preExecHook = FunctionReference.wrap(bytes21(key)); + + uint256 associatedPostExecHooksLength = hooks.associatedPostHooks[preExecHook].length(); + if (associatedPostExecHooksLength > 0) { + for (uint256 j = 0; j < associatedPostExecHooksLength;) { + execHooks[actualExecHooksLength].preExecHook = preExecHook; + (key,) = hooks.associatedPostHooks[preExecHook].at(j); + execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key)); + + unchecked { + ++actualExecHooksLength; + ++j; + } + } + } else { + execHooks[actualExecHooksLength].preExecHook = preExecHook; + + unchecked { + ++actualExecHooksLength; + } + } + + unchecked { + ++i; + } + } + + for (uint256 i = 0; i < postOnlyExecHooksLength;) { + (bytes32 key,) = hooks.postOnlyHooks.at(i); + execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key)); + + unchecked { + ++actualExecHooksLength; + ++i; + } + } + + // Trim the exec hooks array to the actual length, since we may have overallocated. + assembly ("memory-safe") { + mstore(execHooks, actualExecHooksLength) + } + } } diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol index 3fcf2c30..6f47c55a 100644 --- a/src/account/PluginManagerInternals.sol +++ b/src/account/PluginManagerInternals.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity ^0.8.19; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; @@ -27,13 +28,11 @@ import { } from "../interfaces/IPlugin.sol"; abstract contract PluginManagerInternals is IPluginManager { - using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableMap for EnumerableMap.Bytes32ToUintMap; using EnumerableSet for EnumerableSet.AddressSet; error ArrayLengthMismatch(); - error ExecuteFromPluginAlreadySet(bytes4 selector, address plugin); error ExecutionFunctionAlreadySet(bytes4 selector); - error ExecutionHookAlreadySet(bytes4 selector, FunctionReference hook); error InvalidDependenciesProvided(); error InvalidPluginManifest(); error MissingPluginDependency(address dependency); @@ -41,19 +40,16 @@ abstract contract PluginManagerInternals is IPluginManager { error NullPlugin(); error PluginAlreadyInstalled(address plugin); error PluginDependencyViolation(address plugin); - error PermittedCallHookAlreadySet(bytes4 selector, address plugin, FunctionReference hook); error PluginInstallCallbackFailed(address plugin, bytes revertReason); error PluginInterfaceNotSupported(address plugin); error PluginNotInstalled(address plugin); - error PreRuntimeValidationHookAlreadySet(bytes4 selector, FunctionReference hook); - error PreUserOpValidationHookAlreadySet(bytes4 selector, FunctionReference hook); error RuntimeValidationFunctionAlreadySet(bytes4 selector, FunctionReference validationFunction); error UserOpValidationFunctionAlreadySet(bytes4 selector, FunctionReference validationFunction); error PluginApplyHookCallbackFailed(address providingPlugin, bytes revertReason); error PluginUnapplyHookCallbackFailed(address providingPlugin, bytes revertReason); modifier notNullFunction(FunctionReference functionReference) { - if (functionReference == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + if (functionReference.isEmpty()) { revert NullFunctionReference(); } _; @@ -90,7 +86,7 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (_selectorData.userOpValidation != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + if (!_selectorData.userOpValidation.isEmpty()) { revert UserOpValidationFunctionAlreadySet(selector, validationFunction); } @@ -112,7 +108,7 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - if (_selectorData.runtimeValidation != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + if (!_selectorData.runtimeValidation.isEmpty()) { revert RuntimeValidationFunctionAlreadySet(selector, validationFunction); } @@ -133,7 +129,7 @@ abstract contract PluginManagerInternals is IPluginManager { { SelectorData storage _selectorData = getAccountStorage().selectorData[selector]; - _addHooks(_selectorData.executionHooks, selector, preExecHook, postExecHook); + _addHooks(_selectorData.executionHooks, preExecHook, postExecHook); } function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook) @@ -170,7 +166,7 @@ abstract contract PluginManagerInternals is IPluginManager { bytes24 permittedCallKey = getPermittedCallKey(plugin, selector); PermittedCallData storage _permittedCalldata = getAccountStorage().permittedCalls[permittedCallKey]; - _addHooks(_permittedCalldata.permittedCallHooks, selector, preExecHook, postExecHook); + _addHooks(_permittedCalldata.permittedCallHooks, preExecHook, postExecHook); } function _removePermittedCallHooks( @@ -185,48 +181,39 @@ abstract contract PluginManagerInternals is IPluginManager { _removeHooks(_permittedCallData.permittedCallHooks, preExecHook, postExecHook); } - function _addHooks( - HookGroup storage hooks, - bytes4 selector, - FunctionReference preExecHook, - FunctionReference postExecHook - ) internal { - if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - // add pre or pre/post pair of exec hooks - if (!hooks.preHooks.add(_toSetValue(preExecHook))) { - revert ExecutionHookAlreadySet(selector, preExecHook); - } + function _addHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook) + internal + { + if (!preExecHook.isEmpty()) { + _addOrIncrement(hooks.preHooks, _toSetValue(preExecHook)); - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - hooks.associatedPostHooks[preExecHook] = postExecHook; + if (!postExecHook.isEmpty()) { + _addOrIncrement(hooks.associatedPostHooks[preExecHook], _toSetValue(postExecHook)); } } else { - if (postExecHook == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + if (postExecHook.isEmpty()) { // both pre and post hooks cannot be null revert NullFunctionReference(); } - hooks.postOnlyHooks.add(_toSetValue(postExecHook)); + _addOrIncrement(hooks.postOnlyHooks, _toSetValue(postExecHook)); } } function _removeHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook) internal { - if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - // remove pre or pre/post pair of exec hooks - - // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - hooks.preHooks.remove(_toSetValue(preExecHook)); + if (!preExecHook.isEmpty()) { + _removeOrDecrement(hooks.preHooks, _toSetValue(preExecHook)); - if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - hooks.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; + if (!postExecHook.isEmpty()) { + _removeOrDecrement(hooks.associatedPostHooks[preExecHook], _toSetValue(postExecHook)); } } else { - // THe case where both pre and post hooks are null was checked during installation. + // The case where both pre and post hooks are null was checked during installation. // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - hooks.postOnlyHooks.remove(_toSetValue(postExecHook)); + _removeOrDecrement(hooks.postOnlyHooks, _toSetValue(postExecHook)); } } @@ -234,13 +221,10 @@ abstract contract PluginManagerInternals is IPluginManager { internal notNullFunction(preUserOpValidationHook) { - if ( - !getAccountStorage().selectorData[selector].preUserOpValidationHooks.add( - _toSetValue(preUserOpValidationHook) - ) - ) { - revert PreUserOpValidationHookAlreadySet(selector, preUserOpValidationHook); - } + _addOrIncrement( + getAccountStorage().selectorData[selector].preUserOpValidationHooks, + _toSetValue(preUserOpValidationHook) + ); } function _removePreUserOpValidationHook(bytes4 selector, FunctionReference preUserOpValidationHook) @@ -248,7 +232,8 @@ abstract contract PluginManagerInternals is IPluginManager { notNullFunction(preUserOpValidationHook) { // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - getAccountStorage().selectorData[selector].preUserOpValidationHooks.remove( + _removeOrDecrement( + getAccountStorage().selectorData[selector].preUserOpValidationHooks, _toSetValue(preUserOpValidationHook) ); } @@ -257,13 +242,10 @@ abstract contract PluginManagerInternals is IPluginManager { internal notNullFunction(preRuntimeValidationHook) { - if ( - !getAccountStorage().selectorData[selector].preRuntimeValidationHooks.add( - _toSetValue(preRuntimeValidationHook) - ) - ) { - revert PreRuntimeValidationHookAlreadySet(selector, preRuntimeValidationHook); - } + _addOrIncrement( + getAccountStorage().selectorData[selector].preRuntimeValidationHooks, + _toSetValue(preRuntimeValidationHook) + ); } function _removePreRuntimeValidationHook(bytes4 selector, FunctionReference preRuntimeValidationHook) @@ -271,7 +253,8 @@ abstract contract PluginManagerInternals is IPluginManager { notNullFunction(preRuntimeValidationHook) { // May ignore return value, as the manifest hash is validated to ensure that the hook exists. - getAccountStorage().selectorData[selector].preRuntimeValidationHooks.remove( + _removeOrDecrement( + getAccountStorage().selectorData[selector].preRuntimeValidationHooks, _toSetValue(preRuntimeValidationHook) ); } @@ -855,6 +838,25 @@ abstract contract PluginManagerInternals is IPluginManager { emit PluginUninstalled(plugin, onUninstallSuccess); } + function _addOrIncrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal { + (bool success, uint256 value) = map.tryGet(key); + map.set(key, success ? value + 1 : 0); + } + + /// @return True if the key was removed or its value was decremented, false if the key was not found. + function _removeOrDecrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal returns (bool) { + (bool success, uint256 value) = map.tryGet(key); + if (!success) { + return false; + } + if (value == 0) { + map.remove(key); + } else { + map.set(key, value - 1); + } + return true; + } + function _toSetValue(FunctionReference functionReference) internal pure returns (bytes32) { return bytes32(FunctionReference.unwrap(functionReference)); } diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol index 0e0541af..0caf3b13 100644 --- a/src/account/UpgradeableModularAccount.sol +++ b/src/account/UpgradeableModularAccount.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.19; import {BaseAccount} from "@eth-infinitism/account-abstraction/core/BaseAccount.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; @@ -31,6 +32,7 @@ contract UpgradeableModularAccount is PluginManagerInternals, UUPSUpgradeable { + using EnumerableMap for EnumerableMap.Bytes32ToUintMap; using EnumerableSet for EnumerableSet.Bytes32Set; struct PostExecToRun { @@ -363,22 +365,23 @@ contract UpgradeableModularAccount is UserOperation calldata userOp, bytes32 userOpHash ) internal returns (uint256 validationData) { - if (userOpValidationFunction == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + if (userOpValidationFunction.isEmpty()) { revert UserOpValidationFunctionMissing(selector); } uint256 currentValidationData; // Do preUserOpValidation hooks - EnumerableSet.Bytes32Set storage preUserOpValidationHooks = + EnumerableMap.Bytes32ToUintMap storage preUserOpValidationHooks = getAccountStorage().selectorData[selector].preUserOpValidationHooks; uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length(); for (uint256 i = 0; i < preUserOpValidationHooksLength;) { - // FunctionReference preUserOpValidationHook = preUserOpValidationHooks[i]; + (bytes32 key,) = preUserOpValidationHooks.at(i); + FunctionReference preUserOpValidationHook = _toFunctionReference(key); - if (!_toFunctionReference(preUserOpValidationHooks.at(i)).isEmptyOrMagicValue()) { - (address plugin, uint8 functionId) = _toFunctionReference(preUserOpValidationHooks.at(i)).unpack(); + if (!preUserOpValidationHook.isEmptyOrMagicValue()) { + (address plugin, uint8 functionId) = preUserOpValidationHook.unpack(); try IPlugin(plugin).preUserOpValidationHook(functionId, userOp, userOpHash) returns ( uint256 returnData ) { @@ -432,12 +435,13 @@ contract UpgradeableModularAccount is AccountStorage storage _storage = getAccountStorage(); FunctionReference runtimeValidationFunction = _storage.selectorData[msg.sig].runtimeValidation; // run all preRuntimeValidation hooks - EnumerableSet.Bytes32Set storage preRuntimeValidationHooks = + EnumerableMap.Bytes32ToUintMap storage preRuntimeValidationHooks = getAccountStorage().selectorData[msg.sig].preRuntimeValidationHooks; uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length(); for (uint256 i = 0; i < preRuntimeValidationHooksLength;) { - FunctionReference preRuntimeValidationHook = _toFunctionReference(preRuntimeValidationHooks.at(i)); + (bytes32 key,) = preRuntimeValidationHooks.at(i); + FunctionReference preRuntimeValidationHook = _toFunctionReference(key); if (!preRuntimeValidationHook.isEmptyOrMagicValue()) { (address plugin, uint8 functionId) = preRuntimeValidationHook.unpack(); @@ -469,7 +473,7 @@ contract UpgradeableModularAccount is revert RuntimeValidationFunctionReverted(plugin, functionId, revertReason); } } else { - if (runtimeValidationFunction == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + if (runtimeValidationFunction.isEmpty()) { revert RuntimeValidationFunctionMissing(msg.sig); } else if (runtimeValidationFunction == FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY) { revert InvalidConfiguration(); @@ -501,40 +505,60 @@ contract UpgradeableModularAccount is internal returns (PostExecToRun[] memory postHooksToRun) { - uint256 postExecHooksLength = 0; uint256 preExecHooksLength = hooks.preHooks.length(); - // Over-allocate on length, but not all of this may get filled up. - postHooksToRun = new PostExecToRun[](preExecHooksLength + hooks.postOnlyHooks.length()); + uint256 postOnlyHooksLength = hooks.postOnlyHooks.length(); + uint256 maxPostExecHooksLength = postOnlyHooksLength; + + // There can only be as many associated post hooks to run as there are pre hooks. for (uint256 i = 0; i < preExecHooksLength;) { - FunctionReference preExecHook = _toFunctionReference(hooks.preHooks.at(i)); + (, uint256 count) = hooks.preHooks.at(i); + unchecked { + maxPostExecHooksLength += (count + 1); + ++i; + } + } - if (preExecHook.isEmptyOrMagicValue()) { - if (preExecHook == FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY) { - revert AlwaysDenyRule(); - } - // Function reference cannot be 0. If RUNTIME_VALIDATION_BYPASS, revert since it's an invalid - // configuration. - revert InvalidConfiguration(); + // Overallocate on length - not all of this may get filled up. We set the correct length later. + postHooksToRun = new PostExecToRun[](maxPostExecHooksLength); + uint256 actualPostHooksToRunLength; + + // Copy post-only hooks to the array. + for (uint256 i = 0; i < postOnlyHooksLength;) { + (bytes32 key,) = hooks.postOnlyHooks.at(i); + postHooksToRun[actualPostHooksToRunLength].postExecHook = _toFunctionReference(key); + unchecked { + ++actualPostHooksToRunLength; + ++i; } + } - (address plugin, uint8 functionId) = preExecHook.unpack(); - bytes memory preExecHookReturnData; - try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, data) returns ( - bytes memory returnData - ) { - preExecHookReturnData = returnData; - } catch (bytes memory revertReason) { - revert PreExecHookReverted(plugin, functionId, revertReason); + // Then run the pre hooks and copy the associated post hooks (along with their pre hook's return data) to + // the array. + for (uint256 i = 0; i < preExecHooksLength;) { + (bytes32 key,) = hooks.preHooks.at(i); + FunctionReference preExecHook = _toFunctionReference(key); + + if (preExecHook.isEmptyOrMagicValue()) { + // The function reference must be PRE_HOOK_ALWAYS_DENY in this case, because zero and any other + // magic value is unassignable here. + revert AlwaysDenyRule(); } - // Check to see if there is a postExec hook set for this preExec hook - FunctionReference postExecHook = hooks.associatedPostHooks[preExecHook]; - if (FunctionReference.unwrap(postExecHook) != 0) { - postHooksToRun[postExecHooksLength].postExecHook = postExecHook; - postHooksToRun[postExecHooksLength].preExecHookReturnData = preExecHookReturnData; - unchecked { - ++postExecHooksLength; + uint256 associatedPostExecHooksLength = hooks.associatedPostHooks[preExecHook].length(); + if (associatedPostExecHooksLength > 0) { + for (uint256 j = 0; j < associatedPostExecHooksLength;) { + (key,) = hooks.associatedPostHooks[preExecHook].at(j); + postHooksToRun[actualPostHooksToRunLength].postExecHook = _toFunctionReference(key); + postHooksToRun[actualPostHooksToRunLength].preExecHookReturnData = + _runPreExecHook(preExecHook, data); + + unchecked { + ++actualPostHooksToRunLength; + ++j; + } } + } else { + _runPreExecHook(preExecHook, data); } unchecked { @@ -542,37 +566,41 @@ contract UpgradeableModularAccount is } } - // Copy post-only hooks to the end of the array - uint256 postOnlyHooksLength = hooks.postOnlyHooks.length(); - for (uint256 i = 0; i < postOnlyHooksLength;) { - postHooksToRun[postExecHooksLength].postExecHook = _toFunctionReference(hooks.postOnlyHooks.at(i)); - unchecked { - ++postExecHooksLength; - ++i; - } + // Trim the post hook array to the actual length, since we may have overallocated. + assembly ("memory-safe") { + mstore(postHooksToRun, actualPostHooksToRunLength) } } + function _runPreExecHook(FunctionReference preExecHook, bytes calldata data) + internal + returns (bytes memory preExecHookReturnData) + { + (address plugin, uint8 functionId) = preExecHook.unpack(); + try IPlugin(plugin).preExecutionHook(functionId, msg.sender, msg.value, data) returns ( + bytes memory returnData + ) { + preExecHookReturnData = returnData; + } catch (bytes memory revertReason) { + revert PreExecHookReverted(plugin, functionId, revertReason); + } + } + + /// @dev Associated post hooks are run in reverse order of their pre hooks. function _doCachedPostExecHooks(PostExecToRun[] memory postHooksToRun) internal { uint256 postHooksToRunLength = postHooksToRun.length; - for (uint256 i = 0; i < postHooksToRunLength;) { - PostExecToRun memory postHookToRun = postHooksToRun[i]; - FunctionReference postExecHook = postHookToRun.postExecHook; - if (postExecHook == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { - // Reached the end of runnable postExec hooks, stop. - // Array may be over-allocated. - return; + for (uint256 i = postHooksToRunLength; i > 0;) { + unchecked { + --i; } + + PostExecToRun memory postHookToRun = postHooksToRun[i]; (address plugin, uint8 functionId) = postHookToRun.postExecHook.unpack(); // solhint-disable-next-line no-empty-blocks try IPlugin(plugin).postExecutionHook(functionId, postHookToRun.preExecHookReturnData) {} catch (bytes memory revertReason) { revert PostExecHookReverted(plugin, functionId, revertReason); } - - unchecked { - ++i; - } } } diff --git a/src/libraries/AccountStorage.sol b/src/libraries/AccountStorage.sol index 070065d0..36d2845d 100644 --- a/src/libraries/AccountStorage.sol +++ b/src/libraries/AccountStorage.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.19; import {IPlugin} from "../interfaces/IPlugin.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {FunctionReference} from "../libraries/FunctionReferenceLib.sol"; @@ -52,10 +53,10 @@ struct PermittedExternalCallData { // Represets a set of pre- and post- hooks. Used to store both execution hooks and permitted call hooks. struct HookGroup { - EnumerableSet.Bytes32Set preHooks; + EnumerableMap.Bytes32ToUintMap preHooks; // bytes21 key = pre hook function reference - mapping(FunctionReference => FunctionReference) associatedPostHooks; - EnumerableSet.Bytes32Set postOnlyHooks; + mapping(FunctionReference => EnumerableMap.Bytes32ToUintMap) associatedPostHooks; + EnumerableMap.Bytes32ToUintMap postOnlyHooks; } // Represents data associated with a specifc function selector. @@ -66,8 +67,8 @@ struct SelectorData { FunctionReference userOpValidation; FunctionReference runtimeValidation; // The pre validation hooks for this function selector. - EnumerableSet.Bytes32Set preUserOpValidationHooks; - EnumerableSet.Bytes32Set preRuntimeValidationHooks; + EnumerableMap.Bytes32ToUintMap preUserOpValidationHooks; + EnumerableMap.Bytes32ToUintMap preRuntimeValidationHooks; // The execution hooks for this function selector. HookGroup executionHooks; } @@ -100,15 +101,21 @@ function getPermittedCallKey(address addr, bytes4 selector) pure returns (bytes2 } // Helper function to get all elements of a set into memory. -using EnumerableSet for EnumerableSet.Bytes32Set; +using EnumerableMap for EnumerableMap.Bytes32ToUintMap; -function toFunctionReferenceArray(EnumerableSet.Bytes32Set storage set) +function toFunctionReferenceArray(EnumerableMap.Bytes32ToUintMap storage map) view returns (FunctionReference[] memory) { - FunctionReference[] memory result = new FunctionReference[](set.length()); - for (uint256 i = 0; i < set.length(); i++) { - result[i] = FunctionReference.wrap(bytes21(set.at(i))); + uint256 length = map.length(); + FunctionReference[] memory result = new FunctionReference[](length); + for (uint256 i = 0; i < length;) { + (bytes32 key,) = map.at(i); + result[i] = FunctionReference.wrap(bytes21(key)); + + unchecked { + ++i; + } } return result; } diff --git a/src/libraries/FunctionReferenceLib.sol b/src/libraries/FunctionReferenceLib.sol index 0bef591c..897f5bd1 100644 --- a/src/libraries/FunctionReferenceLib.sol +++ b/src/libraries/FunctionReferenceLib.sol @@ -25,6 +25,10 @@ library FunctionReferenceLib { functionId = uint8(bytes1(underlying << 160)); } + function isEmpty(FunctionReference fr) internal pure returns (bool) { + return fr == _EMPTY_FUNCTION_REFERENCE; + } + function isEmptyOrMagicValue(FunctionReference fr) internal pure returns (bool) { return FunctionReference.unwrap(fr) <= bytes21(uint168(2)); } diff --git a/test/account/AccountExecHooks.t.sol b/test/account/AccountExecHooks.t.sol index fbc13f82..cd10df41 100644 --- a/test/account/AccountExecHooks.t.sol +++ b/test/account/AccountExecHooks.t.sol @@ -15,7 +15,7 @@ import { } from "../../src/interfaces/IPlugin.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {SingleOwnerPlugin} from "../../src/plugins/owner/SingleOwnerPlugin.sol"; -import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/libraries/FunctionReferenceLib.sol"; import {MockPlugin} from "../mocks/MockPlugin.sol"; import {MSCAFactoryFixture} from "../mocks/MSCAFactoryFixture.sol"; @@ -31,7 +31,9 @@ contract AccountExecHooksTest is OptimizedTest { UpgradeableModularAccount public account; MockPlugin public mockPlugin1; + MockPlugin public mockPlugin2; bytes32 public manifestHash1; + bytes32 public manifestHash2; bytes4 internal constant _EXEC_SELECTOR = bytes4(uint32(1)); uint8 internal constant _PRE_HOOK_FUNCTION_ID_1 = 1; @@ -40,6 +42,7 @@ contract AccountExecHooksTest is OptimizedTest { uint8 internal constant _POST_HOOK_FUNCTION_ID_4 = 4; PluginManifest public m1; + PluginManifest public m2; /// @dev Note that we strip hookApplyData from InjectedHooks in this event for gas savings event PluginInstalled( @@ -200,11 +203,223 @@ contract AccountExecHooksTest is OptimizedTest { _uninstallPlugin(mockPlugin1); } + function test_overlappingExecHookPairs_install() public { + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install a second plugin that applies the first plugin's hook pair to the same selector. + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _PRE_HOOK_FUNCTION_ID_1); + dependencies[1] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 1 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairs_run() public { + test_overlappingExecHookPairs_install(); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called just once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.postExecutionHook.selector, _POST_HOOK_FUNCTION_ID_2, ""), + 1 + ); + + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + } + + function test_overlappingExecHookPairs_uninstall() public { + test_overlappingExecHookPairs_install(); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.postExecutionHook.selector, _POST_HOOK_FUNCTION_ID_2, ""), + 1 + ); + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + } + + function test_overlappingExecHookPairsOnPost_install() public { + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_3, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingExecHookPairsOnPost_run() public { + test_overlappingExecHookPairsOnPost_install(); + + // Expect each pre hook to be called once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_3, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.postExecutionHook.selector, _POST_HOOK_FUNCTION_ID_2, ""), + 2 + ); + + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + } + + function test_overlappingExecHookPairsOnPost_uninstall() public { + test_overlappingExecHookPairsOnPost_install(); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(this), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.postExecutionHook.selector, _POST_HOOK_FUNCTION_ID_2, ""), + 1 + ); + (bool success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + } + function _installPlugin1WithHooks( bytes4 selector, ManifestFunction memory preHook, ManifestFunction memory postHook - ) internal returns (MockPlugin) { + ) internal { m1.executionHooks.push(ManifestExecutionHook(selector, preHook, postHook)); mockPlugin1 = new MockPlugin(m1); manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); @@ -223,8 +438,40 @@ contract AccountExecHooksTest is OptimizedTest { dependencies: new FunctionReference[](0), injectedHooks: new IPluginManager.InjectedHook[](0) }); + } + + function _installPlugin2WithHooks( + bytes4 selector, + ManifestFunction memory preHook, + ManifestFunction memory postHook, + FunctionReference[] memory dependencies + ) internal { + if (preHook.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + m2.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + } + if (postHook.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + m2.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + } - return mockPlugin1; + m2.executionHooks.push(ManifestExecutionHook(selector, preHook, postHook)); + + mockPlugin2 = new MockPlugin(m2); + manifestHash2 = keccak256(abi.encode(mockPlugin2.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin2), manifestHash2, dependencies, new IPluginManager.InjectedHook[](0) + ); + + account.installPlugin({ + plugin: address(mockPlugin2), + manifestHash: manifestHash2, + pluginInitData: bytes(""), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); } function _uninstallPlugin(MockPlugin plugin) internal { diff --git a/test/account/AccountPermittedCallHooks.t.sol b/test/account/AccountPermittedCallHooks.t.sol index 841760bf..a90eaae0 100644 --- a/test/account/AccountPermittedCallHooks.t.sol +++ b/test/account/AccountPermittedCallHooks.t.sol @@ -208,6 +208,143 @@ contract AccountPermittedCallHooksTest is OptimizedTest { _uninstallPlugin(mockPlugin1); } + function test_overlappingPermittedCallHookPairs_install() public { + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 2] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairs_run() public { + test_overlappingPermittedCallHookPairs_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called just once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.postExecutionHook.selector, _POST_HOOK_FUNCTION_ID_2, ""), + 1 + ); + + account.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + function test_overlappingPermittedCallHookPairs_uninstall() public { + test_overlappingPermittedCallHookPairs_install(); + + _uninstallPlugin(mockPlugin1); + } + + function test_overlappingPermittedCallHookPairsOnPost_install() public { + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_3, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + } + + /// @dev Plugin hook pair(s): [1, 2], [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingPermittedCallHookPairsOnPost_run() public { + test_overlappingPermittedCallHookPairsOnPost_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect each pre hook to be called once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_3, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.postExecutionHook.selector, _POST_HOOK_FUNCTION_ID_2, ""), + 2 + ); + + account.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + function test_overlappingPermittedCallHookPairsOnPost_uninstall() public { + test_overlappingPermittedCallHookPairsOnPost_install(); + + _uninstallPlugin(mockPlugin1); + } + function _installPlugin1WithHooks(ManifestFunction memory preHook1, ManifestFunction memory postHook1) internal {