Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/DexManagerFacet.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ Facets that use swapping inherit from the `Swapper.sol` contract which checks th

## Caution

The DEX Manager manages which contracts and functions can be executed through the LiFi main contract. This can be updated by a single admin key which if compromised could lead to malicious code being added to the allow list.
The DEX Manager manages which contracts and functions can be executed through the LI.FI main contract. This can be updated by a single admin key which if compromised could lead to malicious code being added to the allow list.
1 change: 1 addition & 0 deletions src/Errors/GenericErrors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ error NoTransferToNullAddress();
error NativeAssetTransferFailed();
error InvalidContract();
error InvalidConfig();
error OnlyContractOwner();
9 changes: 8 additions & 1 deletion src/LiFiDiamond.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ import { LibDiamond } from "./Libraries/LibDiamond.sol";
import { IDiamondCut } from "./Interfaces/IDiamondCut.sol";

contract LiFiDiamond {
// LiFiDiamond specific errors
error FunctionDoesNotExist();

// ---------------------------

constructor(address _contractOwner, address _diamondCutFacet) payable {
LibDiamond.setContractOwner(_contractOwner);

Expand Down Expand Up @@ -35,7 +40,9 @@ contract LiFiDiamond {

// get facet from function selector
address facet = ds.selectorToFacetAndPosition[msg.sig].facetAddress;
require(facet != address(0), "Diamond: Function does not exist");
if (facet == address(0)) {
revert FunctionDoesNotExist();
}

// Execute external function from facet using delegatecall and return any value.
// solhint-disable-next-line no-inline-assembly
Expand Down
48 changes: 37 additions & 11 deletions src/Libraries/LibBytes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ pragma solidity 0.8.13;
library LibBytes {
// solhint-disable no-inline-assembly

// LibBytes specific errors
error SliceOverflow();
error SliceOutOfBounds();
error AddressOutOfBounds();
error UintOutOfBounds();

// -------------------------

function concat(bytes memory _preBytes, bytes memory _postBytes) internal pure returns (bytes memory) {
bytes memory tempBytes;

Expand Down Expand Up @@ -216,8 +224,8 @@ library LibBytes {
uint256 _start,
uint256 _length
) internal pure returns (bytes memory) {
require(_length + 31 >= _length, "slice_overflow");
require(_bytes.length >= _start + _length, "slice_outOfBounds");
if (_length + 31 < _length) revert SliceOverflow();
if (_bytes.length < _start + _length) revert SliceOutOfBounds();

bytes memory tempBytes;

Expand Down Expand Up @@ -277,7 +285,9 @@ library LibBytes {
}

function toAddress(bytes memory _bytes, uint256 _start) internal pure returns (address) {
require(_bytes.length >= _start + 20, "toAddress_outOfBounds");
if (_bytes.length < _start + 20) {
revert AddressOutOfBounds();
}
address tempAddress;

assembly {
Expand All @@ -288,7 +298,9 @@ library LibBytes {
}

function toUint8(bytes memory _bytes, uint256 _start) internal pure returns (uint8) {
require(_bytes.length >= _start + 1, "toUint8_outOfBounds");
if (_bytes.length < _start + 1) {
revert UintOutOfBounds();
}
uint8 tempUint;

assembly {
Expand All @@ -299,7 +311,9 @@ library LibBytes {
}

function toUint16(bytes memory _bytes, uint256 _start) internal pure returns (uint16) {
require(_bytes.length >= _start + 2, "toUint16_outOfBounds");
if (_bytes.length < _start + 2) {
revert UintOutOfBounds();
}
uint16 tempUint;

assembly {
Expand All @@ -310,7 +324,9 @@ library LibBytes {
}

function toUint32(bytes memory _bytes, uint256 _start) internal pure returns (uint32) {
require(_bytes.length >= _start + 4, "toUint32_outOfBounds");
if (_bytes.length < _start + 4) {
revert UintOutOfBounds();
}
uint32 tempUint;

assembly {
Expand All @@ -321,7 +337,9 @@ library LibBytes {
}

function toUint64(bytes memory _bytes, uint256 _start) internal pure returns (uint64) {
require(_bytes.length >= _start + 8, "toUint64_outOfBounds");
if (_bytes.length < _start + 8) {
revert UintOutOfBounds();
}
uint64 tempUint;

assembly {
Expand All @@ -332,7 +350,9 @@ library LibBytes {
}

function toUint96(bytes memory _bytes, uint256 _start) internal pure returns (uint96) {
require(_bytes.length >= _start + 12, "toUint96_outOfBounds");
if (_bytes.length < _start + 12) {
revert UintOutOfBounds();
}
uint96 tempUint;

assembly {
Expand All @@ -343,7 +363,9 @@ library LibBytes {
}

function toUint128(bytes memory _bytes, uint256 _start) internal pure returns (uint128) {
require(_bytes.length >= _start + 16, "toUint128_outOfBounds");
if (_bytes.length < _start + 16) {
revert UintOutOfBounds();
}
uint128 tempUint;

assembly {
Expand All @@ -354,7 +376,9 @@ library LibBytes {
}

function toUint256(bytes memory _bytes, uint256 _start) internal pure returns (uint256) {
require(_bytes.length >= _start + 32, "toUint256_outOfBounds");
if (_bytes.length < _start + 32) {
revert UintOutOfBounds();
}
uint256 tempUint;

assembly {
Expand All @@ -365,7 +389,9 @@ library LibBytes {
}

function toBytes32(bytes memory _bytes, uint256 _start) internal pure returns (bytes32) {
require(_bytes.length >= _start + 32, "toBytes32_outOfBounds");
if (_bytes.length < _start + 32) {
revert UintOutOfBounds();
}
bytes32 tempBytes32;

assembly {
Expand Down
79 changes: 60 additions & 19 deletions src/Libraries/LibDiamond.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,25 @@
pragma solidity 0.8.13;

import { IDiamondCut } from "../Interfaces/IDiamondCut.sol";
import { OnlyContractOwner } from "../Errors/GenericErrors.sol";

library LibDiamond {
bytes32 internal constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage");

// Diamond specific errors
error IncorrectFacetCutAction();
error NoSelectorsInFace();
error FunctionAlreadyExists();
error FacetAddressIsZero();
error FacetAddressIsNotZero();
error FacetContainsNoCode();
error FunctionDoesNotExist();
error FunctionIsImmutable();
error InitZeroButCalldataNotEmpty();
error CalldataEmptyButInitNotZero();
error InitReverted();
// ----------------

struct FacetAddressAndPosition {
address facetAddress;
uint96 functionSelectorPosition; // position in facetFunctionSelectors.functionSelectors array
Expand Down Expand Up @@ -53,7 +68,7 @@ library LibDiamond {
}

function enforceIsContractOwner() internal view {
require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner");
if (msg.sender != diamondStorage().contractOwner) revert OnlyContractOwner();
}

event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata);
Expand All @@ -73,17 +88,21 @@ library LibDiamond {
} else if (action == IDiamondCut.FacetCutAction.Remove) {
removeFunctions(_diamondCut[facetIndex].facetAddress, _diamondCut[facetIndex].functionSelectors);
} else {
revert("LibDiamondCut: Incorrect FacetCutAction");
revert IncorrectFacetCutAction();
}
}
emit DiamondCut(_diamondCut, _init, _calldata);
initializeDiamondCut(_init, _calldata);
}

function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
if (_functionSelectors.length == 0) {
revert NoSelectorsInFace();
}
DiamondStorage storage ds = diamondStorage();
require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
if (_facetAddress == address(0)) {
revert FacetAddressIsZero();
}
uint96 selectorPosition = uint96(ds.facetFunctionSelectors[_facetAddress].functionSelectors.length);
// add new facet address if it does not exist
if (selectorPosition == 0) {
Expand All @@ -92,16 +111,22 @@ library LibDiamond {
for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
bytes4 selector = _functionSelectors[selectorIndex];
address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress;
require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
if (oldFacetAddress != address(0)) {
revert FunctionAlreadyExists();
}
addFunction(ds, selector, selectorPosition, _facetAddress);
selectorPosition++;
}
}

function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
if (_functionSelectors.length == 0) {
revert NoSelectorsInFace();
}
DiamondStorage storage ds = diamondStorage();
require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
if (_facetAddress == address(0)) {
revert FacetAddressIsZero();
}
uint96 selectorPosition = uint96(ds.facetFunctionSelectors[_facetAddress].functionSelectors.length);
// add new facet address if it does not exist
if (selectorPosition == 0) {
Expand All @@ -110,18 +135,24 @@ library LibDiamond {
for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
bytes4 selector = _functionSelectors[selectorIndex];
address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress;
require(oldFacetAddress != _facetAddress, "LibDiamondCut: Can't replace function with same function");
if (oldFacetAddress == _facetAddress) {
revert FunctionAlreadyExists();
}
removeFunction(ds, oldFacetAddress, selector);
addFunction(ds, selector, selectorPosition, _facetAddress);
selectorPosition++;
}
}

function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
if (_functionSelectors.length == 0) {
revert NoSelectorsInFace();
}
DiamondStorage storage ds = diamondStorage();
// if function does not exist then do nothing and return
require(_facetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
if (_facetAddress != address(0)) {
revert FacetAddressIsNotZero();
}
for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
bytes4 selector = _functionSelectors[selectorIndex];
address oldFacetAddress = ds.selectorToFacetAndPosition[selector].facetAddress;
Expand All @@ -130,7 +161,7 @@ library LibDiamond {
}

function addFacet(DiamondStorage storage ds, address _facetAddress) internal {
enforceHasContractCode(_facetAddress, "LibDiamondCut: New facet has no code");
enforceHasContractCode(_facetAddress);
ds.facetFunctionSelectors[_facetAddress].facetAddressPosition = ds.facetAddresses.length;
ds.facetAddresses.push(_facetAddress);
}
Expand All @@ -151,9 +182,13 @@ library LibDiamond {
address _facetAddress,
bytes4 _selector
) internal {
require(_facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
if (_facetAddress == address(0)) {
revert FunctionDoesNotExist();
}
// an immutable function is a function defined directly in a diamond
require(_facetAddress != address(this), "LibDiamondCut: Can't remove immutable function");
if (_facetAddress == address(this)) {
revert FunctionIsImmutable();
}
// replace selector with last selector, then delete last selector
uint256 selectorPosition = ds.selectorToFacetAndPosition[_selector].functionSelectorPosition;
uint256 lastSelectorPosition = ds.facetFunctionSelectors[_facetAddress].functionSelectors.length - 1;
Expand Down Expand Up @@ -184,11 +219,15 @@ library LibDiamond {

function initializeDiamondCut(address _init, bytes memory _calldata) internal {
if (_init == address(0)) {
require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty");
if (_calldata.length != 0) {
revert InitZeroButCalldataNotEmpty();
}
} else {
require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)");
if (_calldata.length == 0) {
revert CalldataEmptyButInitNotZero();
}
if (_init != address(this)) {
enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
enforceHasContractCode(_init);
}
// solhint-disable-next-line avoid-low-level-calls
(bool success, bytes memory error) = _init.delegatecall(_calldata);
Expand All @@ -197,18 +236,20 @@ library LibDiamond {
// bubble up the error
revert(string(error));
} else {
revert("LibDiamondCut: _init function reverted");
revert InitReverted();
}
}
}
}

function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
function enforceHasContractCode(address _contract) internal view {
uint256 contractSize;
// solhint-disable-next-line no-inline-assembly
assembly {
contractSize := extcodesize(_contract)
}
require(contractSize > 0, _errorMessage);
if (contractSize == 0) {
revert FacetContainsNoCode();
}
}
}