Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 68 additions & 35 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity ^0.8.19;

import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol";
import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {IAccountLoupe} from "../interfaces/IAccountLoupe.sol";
Expand All @@ -11,11 +12,13 @@ import {
AccountStorage,
getAccountStorage,
getPermittedCallKey,
HookGroup,
toFunctionReferenceArray
} from "../libraries/AccountStorage.sol";
import {FunctionReference} from "../libraries/FunctionReferenceLib.sol";

abstract contract AccountLoupe is IAccountLoupe {
using EnumerableMap for EnumerableMap.Bytes32ToUintMap;
using EnumerableSet for EnumerableSet.AddressSet;

error ManifestDiscrepancy(address plugin);
Expand Down Expand Up @@ -47,23 +50,7 @@ abstract contract AccountLoupe is IAccountLoupe {

/// @inheritdoc IAccountLoupe
function getExecutionHooks(bytes4 selector) external view returns (ExecutionHooks[] memory execHooks) {
AccountStorage storage _storage = getAccountStorage();

FunctionReference[] memory preExecHooks =
toFunctionReferenceArray(_storage.selectorData[selector].executionHooks.preHooks);

uint256 numHooks = preExecHooks.length;
execHooks = new ExecutionHooks[](numHooks);

for (uint256 i = 0; i < numHooks;) {
execHooks[i].preExecHook = preExecHooks[i];
execHooks[i].postExecHook =
_storage.selectorData[selector].executionHooks.associatedPostHooks[preExecHooks[i]];

unchecked {
++i;
}
}
execHooks = _getHooks(getAccountStorage().selectorData[selector].executionHooks);
}

/// @inheritdoc IAccountLoupe
Expand All @@ -72,25 +59,8 @@ abstract contract AccountLoupe is IAccountLoupe {
view
returns (ExecutionHooks[] memory execHooks)
{
AccountStorage storage _storage = getAccountStorage();

bytes24 key = getPermittedCallKey(callingPlugin, selector);

FunctionReference[] memory prePermittedCallHooks =
toFunctionReferenceArray(_storage.permittedCalls[key].permittedCallHooks.preHooks);

uint256 numHooks = prePermittedCallHooks.length;
execHooks = new ExecutionHooks[](numHooks);

for (uint256 i = 0; i < numHooks;) {
execHooks[i].preExecHook = prePermittedCallHooks[i];
execHooks[i].postExecHook =
_storage.permittedCalls[key].permittedCallHooks.associatedPostHooks[prePermittedCallHooks[i]];

unchecked {
++i;
}
}
execHooks = _getHooks(getAccountStorage().permittedCalls[key].permittedCallHooks);
}

/// @inheritdoc IAccountLoupe
Expand All @@ -112,4 +82,67 @@ abstract contract AccountLoupe is IAccountLoupe {
function getInstalledPlugins() external view returns (address[] memory pluginAddresses) {
pluginAddresses = getAccountStorage().plugins.values();
}

function _getHooks(HookGroup storage hooks) internal view returns (ExecutionHooks[] memory execHooks) {
uint256 preExecHooksLength = hooks.preHooks.length();
uint256 postOnlyExecHooksLength = hooks.postOnlyHooks.length();
uint256 maxExecHooksLength = postOnlyExecHooksLength;

// There can only be as many associated post hooks to run as there are pre hooks.
for (uint256 i = 0; i < preExecHooksLength;) {
(, uint256 count) = hooks.preHooks.at(i);
unchecked {
maxExecHooksLength += (count + 1);
++i;
}
}

// Overallocate on length - not all of this may get filled up. We set the correct length later.
execHooks = new ExecutionHooks[](maxExecHooksLength);
uint256 actualExecHooksLength;

for (uint256 i = 0; i < preExecHooksLength;) {
(bytes32 key,) = hooks.preHooks.at(i);
FunctionReference preExecHook = FunctionReference.wrap(bytes21(key));

uint256 associatedPostExecHooksLength = hooks.associatedPostHooks[preExecHook].length();
if (associatedPostExecHooksLength > 0) {
for (uint256 j = 0; j < associatedPostExecHooksLength;) {
execHooks[actualExecHooksLength].preExecHook = preExecHook;
(key,) = hooks.associatedPostHooks[preExecHook].at(j);
execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key));

unchecked {
++actualExecHooksLength;
++j;
}
}
} else {
execHooks[actualExecHooksLength].preExecHook = preExecHook;

unchecked {
++actualExecHooksLength;
}
}

unchecked {
++i;
}
}

for (uint256 i = 0; i < postOnlyExecHooksLength;) {
(bytes32 key,) = hooks.postOnlyHooks.at(i);
execHooks[actualExecHooksLength].postExecHook = FunctionReference.wrap(bytes21(key));

unchecked {
++actualExecHooksLength;
++i;
}
}

// Trim the exec hooks array to the actual length, since we may have overallocated.
assembly ("memory-safe") {
mstore(execHooks, actualExecHooksLength)
}
}
}
104 changes: 53 additions & 51 deletions src/account/PluginManagerInternals.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.19;

import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";
import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol";

Expand All @@ -27,33 +28,28 @@ import {
} from "../interfaces/IPlugin.sol";

abstract contract PluginManagerInternals is IPluginManager {
using EnumerableSet for EnumerableSet.Bytes32Set;
using EnumerableMap for EnumerableMap.Bytes32ToUintMap;
using EnumerableSet for EnumerableSet.AddressSet;

error ArrayLengthMismatch();
error ExecuteFromPluginAlreadySet(bytes4 selector, address plugin);
error ExecutionFunctionAlreadySet(bytes4 selector);
error ExecutionHookAlreadySet(bytes4 selector, FunctionReference hook);
error InvalidDependenciesProvided();
error InvalidPluginManifest();
error MissingPluginDependency(address dependency);
error NullFunctionReference();
error NullPlugin();
error PluginAlreadyInstalled(address plugin);
error PluginDependencyViolation(address plugin);
error PermittedCallHookAlreadySet(bytes4 selector, address plugin, FunctionReference hook);
error PluginInstallCallbackFailed(address plugin, bytes revertReason);
error PluginInterfaceNotSupported(address plugin);
error PluginNotInstalled(address plugin);
error PreRuntimeValidationHookAlreadySet(bytes4 selector, FunctionReference hook);
error PreUserOpValidationHookAlreadySet(bytes4 selector, FunctionReference hook);
error RuntimeValidationFunctionAlreadySet(bytes4 selector, FunctionReference validationFunction);
error UserOpValidationFunctionAlreadySet(bytes4 selector, FunctionReference validationFunction);
error PluginApplyHookCallbackFailed(address providingPlugin, bytes revertReason);
error PluginUnapplyHookCallbackFailed(address providingPlugin, bytes revertReason);

modifier notNullFunction(FunctionReference functionReference) {
if (functionReference == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
if (functionReference.isEmpty()) {
revert NullFunctionReference();
}
_;
Expand Down Expand Up @@ -90,7 +86,7 @@ abstract contract PluginManagerInternals is IPluginManager {
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (_selectorData.userOpValidation != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
if (!_selectorData.userOpValidation.isEmpty()) {
revert UserOpValidationFunctionAlreadySet(selector, validationFunction);
}

Expand All @@ -112,7 +108,7 @@ abstract contract PluginManagerInternals is IPluginManager {
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

if (_selectorData.runtimeValidation != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
if (!_selectorData.runtimeValidation.isEmpty()) {
revert RuntimeValidationFunctionAlreadySet(selector, validationFunction);
}

Expand All @@ -133,7 +129,7 @@ abstract contract PluginManagerInternals is IPluginManager {
{
SelectorData storage _selectorData = getAccountStorage().selectorData[selector];

_addHooks(_selectorData.executionHooks, selector, preExecHook, postExecHook);
_addHooks(_selectorData.executionHooks, preExecHook, postExecHook);
}

function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook)
Expand Down Expand Up @@ -170,7 +166,7 @@ abstract contract PluginManagerInternals is IPluginManager {
bytes24 permittedCallKey = getPermittedCallKey(plugin, selector);
PermittedCallData storage _permittedCalldata = getAccountStorage().permittedCalls[permittedCallKey];

_addHooks(_permittedCalldata.permittedCallHooks, selector, preExecHook, postExecHook);
_addHooks(_permittedCalldata.permittedCallHooks, preExecHook, postExecHook);
}

function _removePermittedCallHooks(
Expand All @@ -185,70 +181,59 @@ abstract contract PluginManagerInternals is IPluginManager {
_removeHooks(_permittedCallData.permittedCallHooks, preExecHook, postExecHook);
}

function _addHooks(
HookGroup storage hooks,
bytes4 selector,
FunctionReference preExecHook,
FunctionReference postExecHook
) internal {
if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
// add pre or pre/post pair of exec hooks
if (!hooks.preHooks.add(_toSetValue(preExecHook))) {
revert ExecutionHookAlreadySet(selector, preExecHook);
}
function _addHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook)
internal
{
if (!preExecHook.isEmpty()) {
_addOrIncrement(hooks.preHooks, _toSetValue(preExecHook));

if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
hooks.associatedPostHooks[preExecHook] = postExecHook;
if (!postExecHook.isEmpty()) {
_addOrIncrement(hooks.associatedPostHooks[preExecHook], _toSetValue(postExecHook));
}
} else {
if (postExecHook == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
if (postExecHook.isEmpty()) {
// both pre and post hooks cannot be null
revert NullFunctionReference();
}

hooks.postOnlyHooks.add(_toSetValue(postExecHook));
_addOrIncrement(hooks.postOnlyHooks, _toSetValue(postExecHook));
}
}

function _removeHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook)
internal
{
if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
// remove pre or pre/post pair of exec hooks

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
hooks.preHooks.remove(_toSetValue(preExecHook));
if (!preExecHook.isEmpty()) {
_removeOrDecrement(hooks.preHooks, _toSetValue(preExecHook));

if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) {
hooks.associatedPostHooks[preExecHook] = FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE;
if (!postExecHook.isEmpty()) {
_removeOrDecrement(hooks.associatedPostHooks[preExecHook], _toSetValue(postExecHook));
}
} else {
// THe case where both pre and post hooks are null was checked during installation.
// The case where both pre and post hooks are null was checked during installation.

// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
hooks.postOnlyHooks.remove(_toSetValue(postExecHook));
_removeOrDecrement(hooks.postOnlyHooks, _toSetValue(postExecHook));
}
}

function _addPreUserOpValidationHook(bytes4 selector, FunctionReference preUserOpValidationHook)
internal
notNullFunction(preUserOpValidationHook)
{
if (
!getAccountStorage().selectorData[selector].preUserOpValidationHooks.add(
_toSetValue(preUserOpValidationHook)
)
) {
revert PreUserOpValidationHookAlreadySet(selector, preUserOpValidationHook);
}
_addOrIncrement(
getAccountStorage().selectorData[selector].preUserOpValidationHooks,
_toSetValue(preUserOpValidationHook)
);
}

function _removePreUserOpValidationHook(bytes4 selector, FunctionReference preUserOpValidationHook)
internal
notNullFunction(preUserOpValidationHook)
{
// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
getAccountStorage().selectorData[selector].preUserOpValidationHooks.remove(
_removeOrDecrement(
getAccountStorage().selectorData[selector].preUserOpValidationHooks,
_toSetValue(preUserOpValidationHook)
);
}
Expand All @@ -257,21 +242,19 @@ abstract contract PluginManagerInternals is IPluginManager {
internal
notNullFunction(preRuntimeValidationHook)
{
if (
!getAccountStorage().selectorData[selector].preRuntimeValidationHooks.add(
_toSetValue(preRuntimeValidationHook)
)
) {
revert PreRuntimeValidationHookAlreadySet(selector, preRuntimeValidationHook);
}
_addOrIncrement(
getAccountStorage().selectorData[selector].preRuntimeValidationHooks,
_toSetValue(preRuntimeValidationHook)
);
}

function _removePreRuntimeValidationHook(bytes4 selector, FunctionReference preRuntimeValidationHook)
internal
notNullFunction(preRuntimeValidationHook)
{
// May ignore return value, as the manifest hash is validated to ensure that the hook exists.
getAccountStorage().selectorData[selector].preRuntimeValidationHooks.remove(
_removeOrDecrement(
getAccountStorage().selectorData[selector].preRuntimeValidationHooks,
_toSetValue(preRuntimeValidationHook)
);
}
Expand Down Expand Up @@ -855,6 +838,25 @@ abstract contract PluginManagerInternals is IPluginManager {
emit PluginUninstalled(plugin, onUninstallSuccess);
}

function _addOrIncrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal {
(bool success, uint256 value) = map.tryGet(key);
map.set(key, success ? value + 1 : 0);
}

/// @return True if the key was removed or its value was decremented, false if the key was not found.
function _removeOrDecrement(EnumerableMap.Bytes32ToUintMap storage map, bytes32 key) internal returns (bool) {
(bool success, uint256 value) = map.tryGet(key);
if (!success) {
return false;
}
if (value == 0) {
map.remove(key);
} else {
map.set(key, value - 1);
}
return true;
}

function _toSetValue(FunctionReference functionReference) internal pure returns (bytes32) {
return bytes32(FunctionReference.unwrap(functionReference));
}
Expand Down
Loading