diff --git a/docs/DexManagerFacet.md b/docs/DexManagerFacet.md index 5cc7a3bd9..438e6e8a6 100644 --- a/docs/DexManagerFacet.md +++ b/docs/DexManagerFacet.md @@ -8,4 +8,4 @@ Facets that use swapping inherit from the `Swapper.sol` contract which checks th ## Caution -The DEX Manager manages which contracts and functions can be executed through the LiFi main contract. This can be updated by a single admin key which if compromised could lead to malicious code being added to the allow list. +The DEX Manager manages which contracts and functions can be executed through the LI.FI main contract. This can be updated by a single admin key which if compromised could lead to malicious code being added to the allow list. diff --git a/src/Errors/GenericErrors.sol b/src/Errors/GenericErrors.sol index e4a282bdf..e990e53b5 100644 --- a/src/Errors/GenericErrors.sol +++ b/src/Errors/GenericErrors.sol @@ -15,3 +15,4 @@ error NoTransferToNullAddress(); error NativeAssetTransferFailed(); error InvalidContract(); error InvalidConfig(); +error OnlyContractOwner(); diff --git a/src/LiFiDiamond.sol b/src/LiFiDiamond.sol index bcfdff4f1..8ad432608 100644 --- a/src/LiFiDiamond.sol +++ b/src/LiFiDiamond.sol @@ -5,6 +5,11 @@ import { LibDiamond } from "./Libraries/LibDiamond.sol"; import { IDiamondCut } from "./Interfaces/IDiamondCut.sol"; contract LiFiDiamond { + // LiFiDiamond specific errors + error FunctionDoesNotExist(); + + // --------------------------- + constructor(address _contractOwner, address _diamondCutFacet) payable { LibDiamond.setContractOwner(_contractOwner); @@ -35,7 +40,9 @@ contract LiFiDiamond { // get facet from function selector address facet = ds.selectorToFacetAndPosition[msg.sig].facetAddress; - require(facet != address(0), "Diamond: Function does not exist"); + if (facet == address(0)) { + revert FunctionDoesNotExist(); + } // Execute external function from facet using delegatecall and return any value. // solhint-disable-next-line no-inline-assembly diff --git a/src/Libraries/LibBytes.sol b/src/Libraries/LibBytes.sol index 40e4dcfa7..64ffba82b 100644 --- a/src/Libraries/LibBytes.sol +++ b/src/Libraries/LibBytes.sol @@ -4,6 +4,14 @@ pragma solidity 0.8.13; library LibBytes { // solhint-disable no-inline-assembly + // LibBytes specific errors + error SliceOverflow(); + error SliceOutOfBounds(); + error AddressOutOfBounds(); + error UintOutOfBounds(); + + // ------------------------- + function concat(bytes memory _preBytes, bytes memory _postBytes) internal pure returns (bytes memory) { bytes memory tempBytes; @@ -216,8 +224,8 @@ library LibBytes { uint256 _start, uint256 _length ) internal pure returns (bytes memory) { - require(_length + 31 >= _length, "slice_overflow"); - require(_bytes.length >= _start + _length, "slice_outOfBounds"); + if (_length + 31 < _length) revert SliceOverflow(); + if (_bytes.length < _start + _length) revert SliceOutOfBounds(); bytes memory tempBytes; @@ -277,7 +285,9 @@ library LibBytes { } function toAddress(bytes memory _bytes, uint256 _start) internal pure returns (address) { - require(_bytes.length >= _start + 20, "toAddress_outOfBounds"); + if (_bytes.length < _start + 20) { + revert AddressOutOfBounds(); + } address tempAddress; assembly { @@ -288,7 +298,9 @@ library LibBytes { } function toUint8(bytes memory _bytes, uint256 _start) internal pure returns (uint8) { - require(_bytes.length >= _start + 1, "toUint8_outOfBounds"); + if (_bytes.length < _start + 1) { + revert UintOutOfBounds(); + } uint8 tempUint; assembly { @@ -299,7 +311,9 @@ library LibBytes { } function toUint16(bytes memory _bytes, uint256 _start) internal pure returns (uint16) { - require(_bytes.length >= _start + 2, "toUint16_outOfBounds"); + if (_bytes.length < _start + 2) { + revert UintOutOfBounds(); + } uint16 tempUint; assembly { @@ -310,7 +324,9 @@ library LibBytes { } function toUint32(bytes memory _bytes, uint256 _start) internal pure returns (uint32) { - require(_bytes.length >= _start + 4, "toUint32_outOfBounds"); + if (_bytes.length < _start + 4) { + revert UintOutOfBounds(); + } uint32 tempUint; assembly { @@ -321,7 +337,9 @@ library LibBytes { } function toUint64(bytes memory _bytes, uint256 _start) internal pure returns (uint64) { - require(_bytes.length >= _start + 8, "toUint64_outOfBounds"); + if (_bytes.length < _start + 8) { + revert UintOutOfBounds(); + } uint64 tempUint; assembly { @@ -332,7 +350,9 @@ library LibBytes { } function toUint96(bytes memory _bytes, uint256 _start) internal pure returns (uint96) { - require(_bytes.length >= _start + 12, "toUint96_outOfBounds"); + if (_bytes.length < _start + 12) { + revert UintOutOfBounds(); + } uint96 tempUint; assembly { @@ -343,7 +363,9 @@ library LibBytes { } function toUint128(bytes memory _bytes, uint256 _start) internal pure returns (uint128) { - require(_bytes.length >= _start + 16, "toUint128_outOfBounds"); + if (_bytes.length < _start + 16) { + revert UintOutOfBounds(); + } uint128 tempUint; assembly { @@ -354,7 +376,9 @@ library LibBytes { } function toUint256(bytes memory _bytes, uint256 _start) internal pure returns (uint256) { - require(_bytes.length >= _start + 32, "toUint256_outOfBounds"); + if (_bytes.length < _start + 32) { + revert UintOutOfBounds(); + } uint256 tempUint; assembly { @@ -365,7 +389,9 @@ library LibBytes { } function toBytes32(bytes memory _bytes, uint256 _start) internal pure returns (bytes32) { - require(_bytes.length >= _start + 32, "toBytes32_outOfBounds"); + if (_bytes.length < _start + 32) { + revert UintOutOfBounds(); + } bytes32 tempBytes32; assembly { diff --git a/src/Libraries/LibDiamond.sol b/src/Libraries/LibDiamond.sol index 297333a80..51f1266ce 100644 --- a/src/Libraries/LibDiamond.sol +++ b/src/Libraries/LibDiamond.sol @@ -2,10 +2,25 @@ pragma solidity 0.8.13; import { IDiamondCut } from "../Interfaces/IDiamondCut.sol"; +import { OnlyContractOwner } from "../Errors/GenericErrors.sol"; library LibDiamond { bytes32 internal constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage"); + // Diamond specific errors + error IncorrectFacetCutAction(); + error NoSelectorsInFace(); + error FunctionAlreadyExists(); + error FacetAddressIsZero(); + error FacetAddressIsNotZero(); + error FacetContainsNoCode(); + error FunctionDoesNotExist(); + error FunctionIsImmutable(); + error InitZeroButCalldataNotEmpty(); + error CalldataEmptyButInitNotZero(); + error InitReverted(); + // ---------------- + struct FacetAddressAndPosition { address facetAddress; uint96 functionSelectorPosition; // position in facetFunctionSelectors.functionSelectors array @@ -53,7 +68,7 @@ library LibDiamond { } function enforceIsContractOwner() internal view { - require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner"); + if (msg.sender != diamondStorage().contractOwner) revert OnlyContractOwner(); } event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata); @@ -73,7 +88,7 @@ library LibDiamond { } else if (action == IDiamondCut.FacetCutAction.Remove) { removeFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors); } else { - revert("LibDiamondCut: Incorrect FacetCutAction"); + revert IncorrectFacetCutAction(); } } emit DiamondCut(_diamondCut, _init, _calldata); @@ -81,9 +96,13 @@ library LibDiamond { } function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal { - require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); + if (_functionSelectors.length == 0) { + revert NoSelectorsInFace(); + } DiamondStorage storage ds = diamondStorage(); - require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)"); + if (_facetAddress == address(0)) { + revert FacetAddressIsZero(); + } uint96 selectorPosition = uint96(ds.facetFunctionSelectors[_facetAddress].functionSelectors.length); // add new facet address if it does not exist if (selectorPosition == 0) { @@ -92,16 +111,22 @@ library LibDiamond { for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) { bytes4 selector = _functionSelectors[selectorIndex]; address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress; - require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists"); + if (oldFacetAddress != address(0)) { + revert FunctionAlreadyExists(); + } addFunction(ds, selector, selectorPosition, _facetAddress); selectorPosition++; } } function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal { - require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); + if (_functionSelectors.length == 0) { + revert NoSelectorsInFace(); + } DiamondStorage storage ds = diamondStorage(); - require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)"); + if (_facetAddress == address(0)) { + revert FacetAddressIsZero(); + } uint96 selectorPosition = uint96(ds.facetFunctionSelectors[_facetAddress].functionSelectors.length); // add new facet address if it does not exist if (selectorPosition == 0) { @@ -110,7 +135,9 @@ library LibDiamond { for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) { bytes4 selector = _functionSelectors[selectorIndex]; address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress; - require(oldFacetAddress != _facetAddress, "LibDiamondCut: Can't replace function with same function"); + if (oldFacetAddress == _facetAddress) { + revert FunctionAlreadyExists(); + } removeFunction(ds, oldFacetAddress, selector); addFunction(ds, selector, selectorPosition, _facetAddress); selectorPosition++; @@ -118,10 +145,14 @@ library LibDiamond { } function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal { - require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut"); + if (_functionSelectors.length == 0) { + revert NoSelectorsInFace(); + } DiamondStorage storage ds = diamondStorage(); // if function does not exist then do nothing and return - require(_facetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)"); + if (_facetAddress != address(0)) { + revert FacetAddressIsNotZero(); + } for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) { bytes4 selector = _functionSelectors[selectorIndex]; address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress; @@ -130,7 +161,7 @@ library LibDiamond { } function addFacet(DiamondStorage storage ds, address _facetAddress) internal { - enforceHasContractCode(_facetAddress, "LibDiamondCut: New facet has no code"); + enforceHasContractCode(_facetAddress); ds.facetFunctionSelectors[_facetAddress].facetAddressPosition = ds.facetAddresses.length; ds.facetAddresses.push(_facetAddress); } @@ -151,9 +182,13 @@ library LibDiamond { address _facetAddress, bytes4 _selector ) internal { - require(_facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist"); + if (_facetAddress == address(0)) { + revert FunctionDoesNotExist(); + } // an immutable function is a function defined directly in a diamond - require(_facetAddress != address(this), "LibDiamondCut: Can't remove immutable function"); + if (_facetAddress == address(this)) { + revert FunctionIsImmutable(); + } // replace selector with last selector, then delete last selector uint256 selectorPosition = ds.selectorToFacetAndPosition[_selector].functionSelectorPosition; uint256 lastSelectorPosition = ds.facetFunctionSelectors[_facetAddress].functionSelectors.length - 1; @@ -184,11 +219,15 @@ library LibDiamond { function initializeDiamondCut(address _init, bytes memory _calldata) internal { if (_init == address(0)) { - require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty"); + if (_calldata.length != 0) { + revert InitZeroButCalldataNotEmpty(); + } } else { - require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)"); + if (_calldata.length == 0) { + revert CalldataEmptyButInitNotZero(); + } if (_init != address(this)) { - enforceHasContractCode(_init, "LibDiamondCut: _init address has no code"); + enforceHasContractCode(_init); } // solhint-disable-next-line avoid-low-level-calls (bool success, bytes memory error) = _init.delegatecall(_calldata); @@ -197,18 +236,20 @@ library LibDiamond { // bubble up the error revert(string(error)); } else { - revert("LibDiamondCut: _init function reverted"); + revert InitReverted(); } } } } - function enforceHasContractCode(address _contract, string memory _errorMessage) internal view { + function enforceHasContractCode(address _contract) internal view { uint256 contractSize; // solhint-disable-next-line no-inline-assembly assembly { contractSize := extcodesize(_contract) } - require(contractSize > 0, _errorMessage); + if (contractSize == 0) { + revert FacetContainsNoCode(); + } } }