Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved interface code generation in slither.utils.code_generation #1802

Merged
merged 20 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 143 additions & 38 deletions slither/utils/code_generation.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,170 @@
# Functions for generating Solidity code
from typing import TYPE_CHECKING, Optional

from slither.utils.type import convert_type_for_solidity_signature_to_string
from slither.utils.type import (
convert_type_for_solidity_signature_to_string,
export_nested_types_from_variable,
export_return_type_from_variable,
)
from slither.core.solidity_types import (
Type,
UserDefinedType,
MappingType,
ArrayType,
ElementaryType,
)
from slither.core.declarations import Structure, Enum, Contract

if TYPE_CHECKING:
from slither.core.declarations import FunctionContract, Structure, Contract
from slither.core.declarations import FunctionContract, CustomErrorContract
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.local_variable import LocalVariable


def generate_interface(contract: "Contract") -> str:
# pylint: disable=too-many-arguments
def generate_interface(
contract: "Contract",
unroll_structs: bool = True,
include_events: bool = True,
include_errors: bool = True,
include_enums: bool = True,
include_structs: bool = True,
) -> str:
"""
Generates code for a Solidity interface to the contract.
Args:
contract: A Contract object
contract: A Contract object.
unroll_structs: Whether to use structures' underlying types instead of the user-defined type (default: True).
include_events: Whether to include event signatures in the interface (default: True).
include_errors: Whether to include custom error signatures in the interface (default: True).
include_enums: Whether to include enum definitions in the interface (default: True).
include_structs: Whether to include struct definitions in the interface (default: True).

Returns:
A string with the code for an interface, with function stubs for all public or external functions and
state variables, as well as any events, custom errors and/or structs declared in the contract.
"""
interface = f"interface I{contract.name} {{\n"
for event in contract.events:
name, args = event.signature
interface += f" event {name}({', '.join(args)});\n"
for error in contract.custom_errors:
args = [
convert_type_for_solidity_signature_to_string(arg.type)
.replace("(", "")
.replace(")", "")
for arg in error.parameters
]
interface += f" error {error.name}({', '.join(args)});\n"
for enum in contract.enums:
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n"
for struct in contract.structures:
interface += generate_struct_interface_str(struct)
if include_events:
for event in contract.events:
name, args = event.signature
interface += f" event {name}({', '.join(args)});\n"
if include_errors:
for error in contract.custom_errors:
interface += f" error {generate_custom_error_interface(error, unroll_structs)};\n"
if include_enums:
for enum in contract.enums:
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n"
if include_structs:
for struct in contract.structures:
interface += generate_struct_interface_str(struct, indent=4)
for var in contract.state_variables_entry_points:
interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n"
interface += f" function {generate_interface_variable_signature(var, unroll_structs)};\n"
for func in contract.functions_entry_points:
if func.is_constructor or func.is_fallback or func.is_receive:
continue
interface += f" function {generate_interface_function_signature(func)};\n"
interface += (
f" function {generate_interface_function_signature(func, unroll_structs)};\n"
)
interface += "}\n\n"
return interface


def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]:
def generate_interface_variable_signature(
var: "StateVariable", unroll_structs: bool = True
) -> Optional[str]:
if var.visibility in ["private", "internal"]:
return None
if unroll_structs:
params = [
convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "")
for x in export_nested_types_from_variable(var)
]
returns = [
convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "")
for x in export_return_type_from_variable(var)
]
else:
_, params, _ = var.signature
params = [p + " memory" if p in ["bytes", "string"] else p for p in params]
returns = []
_type = var.type
while isinstance(_type, MappingType):
_type = _type.type_to
while isinstance(_type, (ArrayType, UserDefinedType)):
_type = _type.type
ret = str(_type)
if isinstance(_type, Structure) or (isinstance(_type, Type) and _type.is_dynamic):
ret += " memory"
elif isinstance(_type, Contract):
ret = "address"
returns.append(ret)
return f"{var.name}({','.join(params)}) external returns ({', '.join(returns)})"


def generate_interface_function_signature(
func: "FunctionContract", unroll_structs: bool = True
) -> Optional[str]:
"""
Generates a string of the form:
func_name(type1,type2) external {payable/view/pure} returns (type3)

Args:
func: A FunctionContract object
unroll_structs: Determines whether structs are unrolled into underlying types (default: True)

Returns:
The function interface as a str (contains the return values).
Returns None if the function is private or internal, or is a constructor/fallback/receive.
"""

name, parameters, return_vars = func.signature
def format_var(var: "LocalVariable", unroll: bool) -> str:
if unroll:
return (
convert_type_for_solidity_signature_to_string(var.type)
.replace("(", "")
.replace(")", "")
)
if isinstance(var.type, ArrayType) and isinstance(
var.type.type, (UserDefinedType, ElementaryType)
):
return (
convert_type_for_solidity_signature_to_string(var.type)
.replace("(", "")
.replace(")", "")
+ f" {var.location}"
)
if isinstance(var.type, UserDefinedType):
if isinstance(var.type.type, (Structure, Enum)):
return f"{str(var.type.type)} memory"
if isinstance(var.type.type, Contract):
return "address"
if var.type.is_dynamic:
return f"{var.type} {var.location}"
return str(var.type)

name, _, _ = func.signature
if (
func not in func.contract.functions_entry_points
or func.is_constructor
or func.is_fallback
or func.is_receive
):
return None
view = " view" if func.view else ""
view = " view" if func.view and not func.pure else ""
pure = " pure" if func.pure else ""
payable = " payable" if func.payable else ""
returns = [
convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "")
for ret in func.returns
]
parameters = [
convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "")
for param in func.parameters
]
returns = [format_var(ret, unroll_structs) for ret in func.returns]
parameters = [format_var(param, unroll_structs) for param in func.parameters]
_interface_signature_str = (
name + "(" + ",".join(parameters) + ") external" + payable + pure + view
)
if len(return_vars) > 0:
if len(returns) > 0:
_interface_signature_str += " returns (" + ",".join(returns) + ")"
return _interface_signature_str


def generate_struct_interface_str(struct: "Structure") -> str:
def generate_struct_interface_str(struct: "Structure", indent: int = 0) -> str:
"""
Generates code for a structure declaration in an interface of the form:
struct struct_name {
Expand All @@ -92,13 +173,37 @@ def generate_struct_interface_str(struct: "Structure") -> str:
... ...
}
Args:
struct: A Structure object
struct: A Structure object.
indent: Number of spaces to indent the code block with.

Returns:
The structure declaration code as a string.
"""
definition = f" struct {struct.name} {{\n"
spaces = ""
for _ in range(0, indent):
spaces += " "
definition = f"{spaces}struct {struct.name} {{\n"
for elem in struct.elems_ordered:
definition += f" {elem.type} {elem.name};\n"
definition += " }\n"
if isinstance(elem.type, UserDefinedType):
if isinstance(elem.type.type, (Structure, Enum)):
definition += f"{spaces} {elem.type.type} {elem.name};\n"
elif isinstance(elem.type.type, Contract):
definition += f"{spaces} address {elem.name};\n"
else:
definition += f"{spaces} {elem.type} {elem.name};\n"
definition += f"{spaces}}}\n"
return definition


def generate_custom_error_interface(
error: "CustomErrorContract", unroll_structs: bool = True
) -> str:
args = [
convert_type_for_solidity_signature_to_string(arg.type).replace("(", "").replace(")", "")
if unroll_structs
else str(arg.type.type)
if isinstance(arg.type, UserDefinedType) and isinstance(arg.type.type, (Structure, Enum))
else str(arg.type)
for arg in error.parameters
]
return f"{error.name}({', '.join(args)})"
9 changes: 9 additions & 0 deletions tests/unit/utils/test_code_generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from solc_select import solc_select

Expand All @@ -21,3 +22,11 @@ def test_interface_generation() -> None:
expected = file.read()

assert actual == expected

actual = generate_interface(sl.get_contract_from_name("TestContract")[0], unroll_structs=False)
expected_path = os.path.join(TEST_DATA_DIR, "TEST_generated_code_not_unrolled.sol")

with open(expected_path, "r", encoding="utf-8") as file:
expected = file.read()

assert actual == expected
17 changes: 13 additions & 4 deletions tests/unit/utils/test_data/code_generation/CodeGeneration.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ contract TestContract is I {
uint public stateA;
uint private stateB;
address public immutable owner = msg.sender;
mapping(address => mapping(uint => St)) public structs;
mapping(address => mapping(uint => St)) public structsMap;
St[] public structsArray;
I public otherI;

event NoParams();
event Anonymous() anonymous;
Expand All @@ -23,6 +25,10 @@ contract TestContract is I {
uint v;
}

struct Nested{
St st;
}

function err0() public {
revert ErrorSimple();
}
Expand All @@ -44,13 +50,16 @@ contract TestContract is I {
function newSt(uint x) public returns (St memory) {
St memory st;
st.v = x;
structs[msg.sender][x] = st;
structsMap[msg.sender][x] = st;
return st;
}
function getSt(uint x) public view returns (St memory) {
return structs[msg.sender][x];
return structsMap[msg.sender][x];
}
function removeSt(St memory st) public {
delete structs[msg.sender][st.v];
delete structsMap[msg.sender][st.v];
}
function setOtherI(I _i) public {
otherI = _i;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@ interface ITestContract {
struct St {
uint256 v;
}
struct Nested {
St st;
}
function stateA() external returns (uint256);
function owner() external returns (address);
function structs(address,uint256) external returns (uint256);
function structsMap(address,uint256) external returns (uint256);
function structsArray(uint256) external returns (uint256);
function otherI() external returns (address);
function err0() external;
function err1() external;
function err2(uint256,uint256) external;
function newSt(uint256) external returns (uint256);
function getSt(uint256) external view returns (uint256);
function removeSt(uint256) external;
function setOtherI(address) external;
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
interface ITestContract {
event NoParams();
event Anonymous();
event OneParam(address);
event OneParamIndexed(address);
error ErrorWithEnum(SomeEnum);
error ErrorSimple();
error ErrorWithArgs(uint256, uint256);
error ErrorWithStruct(St);
enum SomeEnum { ONE, TWO, THREE }
struct St {
uint256 v;
}
struct Nested {
St st;
}
function stateA() external returns (uint256);
function owner() external returns (address);
function structsMap(address,uint256) external returns (St memory);
function structsArray(uint256) external returns (St memory);
function otherI() external returns (address);
function err0() external;
function err1() external;
function err2(uint256,uint256) external;
function newSt(uint256) external returns (St memory);
function getSt(uint256) external view returns (St memory);
function removeSt(St memory) external;
function setOtherI(address) external;
}