Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve slither-flat #1125

Merged
merged 3 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope
self._custom_errors: Dict[str, "CustomErrorContract"] = {}

# The only str is "*"
self._using_for: Dict[Union[str, Type], List[str]] = {}
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._kind: Optional[str] = None
self._is_interface: bool = False

Expand Down Expand Up @@ -245,7 +245,7 @@ def events_as_dict(self) -> Dict[str, "Event"]:
###################################################################################

@property
def using_for(self) -> Dict[Union[str, Type], List[str]]:
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for

# endregion
Expand Down
2 changes: 1 addition & 1 deletion slither/solc_parsing/solidity_types/type_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t
def parse_type(
t: Union[Dict, UnknownType],
caller_context: Union[CallerContextExpression, "SlitherCompilationUnitSolc"],
):
) -> Type:
"""
caller_context can be a SlitherCompilationUnitSolc because we recursively call the function
and go up in the context's scope. If we are really lost we just go over the SlitherCompilationUnitSolc
Expand Down
45 changes: 24 additions & 21 deletions slither/tools/flattening/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,28 +104,31 @@ def main():
args = parse_args()

slither = Slither(args.filename, **vars(args))
flat = Flattening(
slither,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private,
export_path=args.dir,
pragma_solidity=args.pragma_solidity,
)

try:
strategy = Strategy[args.strategy]
except KeyError:
to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)"
logger.error(to_log)
return
flat.export(
strategy=strategy,
target=args.contract,
json=args.json,
zip=args.zip,
zip_type=args.zip_type,
)
for compilation_unit in slither.compilation_units:

flat = Flattening(
compilation_unit,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private,
export_path=args.dir,
pragma_solidity=args.pragma_solidity,
)

try:
strategy = Strategy[args.strategy]
except KeyError:
to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)"
logger.error(to_log)
return
flat.export(
strategy=strategy,
target=args.contract,
json=args.json,
zip=args.zip,
zip_type=args.zip_type,
)


if __name__ == "__main__":
Expand Down
152 changes: 113 additions & 39 deletions slither/tools/flattening/flattening.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import logging
import re
import uuid
from collections import namedtuple
from enum import Enum as PythonEnum
from pathlib import Path
from typing import List, Set, Dict, Optional

from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import SolidityFunction, EnumContract, StructureContract
from slither.core.declarations.contract import Contract
from slither.core.slither_core import SlitherCore
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.top_level import TopLevel
from slither.core.solidity_types import MappingType, ArrayType
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.exceptions import SlitherException
from slither.slithir.operations import NewContract, TypeConversion, SolidityCall
from slither.slithir.operations import NewContract, TypeConversion, SolidityCall, InternalCall
from slither.tools.flattening.export.export import (
Export,
export_as_json,
Expand Down Expand Up @@ -44,15 +48,16 @@ class Flattening:
# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods
def __init__(
self,
slither: SlitherCore,
compilation_unit: SlitherCompilationUnit,
external_to_public=False,
remove_assert=False,
private_to_internal=False,
export_path: Optional[str] = None,
pragma_solidity: Optional[str] = None,
):
self._source_codes: Dict[Contract, str] = {}
self._slither: SlitherCore = slither
self._source_codes_top_level: Dict[TopLevel, str] = {}
self._compilation_unit: SlitherCompilationUnit = compilation_unit
self._external_to_public = external_to_public
self._remove_assert = remove_assert
self._use_abi_encoder_v2 = False
Expand All @@ -63,20 +68,32 @@ def __init__(

self._check_abi_encoder_v2()

for contract in slither.contracts:
for contract in compilation_unit.contracts:
self._get_source_code(contract)

self._get_source_code_top_level(compilation_unit.structures_top_level)
self._get_source_code_top_level(compilation_unit.enums_top_level)
self._get_source_code_top_level(compilation_unit.variables_top_level)
self._get_source_code_top_level(compilation_unit.functions_top_level)

def _get_source_code_top_level(self, elems: List[TopLevel]) -> None:
for elem in elems:
src_mapping = elem.source_mapping
content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]]
start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"]
self._source_codes_top_level[elem] = content[start:end]

def _check_abi_encoder_v2(self):
"""
Check if ABIEncoderV2 is required
Set _use_abi_encorder_v2
:return:
"""
for compilation_unit in self._slither.compilation_units:
for p in compilation_unit.pragma_directives:
if "ABIEncoderV2" in str(p.directive):
self._use_abi_encoder_v2 = True
return
for p in self._compilation_unit.pragma_directives:
if "ABIEncoderV2" in str(p.directive):
self._use_abi_encoder_v2 = True
return

def _get_source_code(
self, contract: Contract
Expand All @@ -88,7 +105,7 @@ def _get_source_code(
:return:
"""
src_mapping = contract.source_mapping
content = self._slither.source_code[src_mapping["filename_absolute"]]
content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]]
start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"]

Expand Down Expand Up @@ -132,11 +149,9 @@ def _get_source_code(
if self._private_to_internal:
for variable in contract.state_variables_declared:
if variable.visibility == "private":
print(variable.source_mapping)
attributes_start = variable.source_mapping["start"]
attributes_end = attributes_start + variable.source_mapping["length"]
attributes = content[attributes_start:attributes_end]
print(attributes)
regex = re.search(r" private ", attributes)
if regex:
to_patch.append(
Expand Down Expand Up @@ -191,35 +206,54 @@ def _pragmas(self) -> str:
ret += f"pragma solidity {self._pragma_solidity};\n"
else:
# TODO support multiple compiler version
ret += f"pragma solidity {list(self._slither.crytic_compile.compilation_units.values())[0].compiler_version.version};\n"
ret += f"pragma solidity {list(self._compilation_unit.crytic_compile.compilation_units.values())[0].compiler_version.version};\n"

if self._use_abi_encoder_v2:
ret += "pragma experimental ABIEncoderV2;\n"
return ret

def _export_from_type(self, t, contract, exported, list_contract):
def _export_from_type(
self,
t: Type,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
):
if isinstance(t, UserDefinedType):
if isinstance(t.type, (EnumContract, StructureContract)):
if t.type.contract != contract and t.type.contract not in exported:
self._export_list_used_contracts(t.type.contract, exported, list_contract)
t_type = t.type
if isinstance(t_type, (EnumContract, StructureContract)):
if t_type.contract != contract and t_type.contract not in exported:
self._export_list_used_contracts(
t_type.contract, exported, list_contract, list_top_level
)
else:
assert isinstance(t.type, Contract)
if t.type != contract and t.type not in exported:
self._export_list_used_contracts(t.type, exported, list_contract)
self._export_list_used_contracts(
t.type, exported, list_contract, list_top_level
)
elif isinstance(t, MappingType):
self._export_from_type(t.type_from, contract, exported, list_contract)
self._export_from_type(t.type_to, contract, exported, list_contract)
self._export_from_type(t.type_from, contract, exported, list_contract, list_top_level)
self._export_from_type(t.type_to, contract, exported, list_contract, list_top_level)
elif isinstance(t, ArrayType):
self._export_from_type(t.type, contract, exported, list_contract)
self._export_from_type(t.type, contract, exported, list_contract, list_top_level)

def _export_list_used_contracts( # pylint: disable=too-many-branches
self, contract: Contract, exported: Set[str], list_contract: List[Contract]
self,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
):
# TODO: investigate why this happen
if not isinstance(contract, Contract):
return
if contract.name in exported:
return
exported.add(contract.name)
for inherited in contract.inheritance:
self._export_list_used_contracts(inherited, exported, list_contract)
self._export_list_used_contracts(inherited, exported, list_contract, list_top_level)

# Find all the external contracts called
externals = contract.all_library_calls + contract.all_high_level_calls
Expand All @@ -228,41 +262,67 @@ def _export_list_used_contracts( # pylint: disable=too-many-branches
externals = list({e[0] for e in externals if e[0] != contract})

for inherited in externals:
self._export_list_used_contracts(inherited, exported, list_contract)
self._export_list_used_contracts(inherited, exported, list_contract, list_top_level)

for list_libs in contract.using_for.values():
for lib_candidate_type in list_libs:
if isinstance(lib_candidate_type, UserDefinedType):
lib_candidate = lib_candidate_type.type
if isinstance(lib_candidate, Contract):
self._export_list_used_contracts(
lib_candidate, exported, list_contract, list_top_level
)

# Find all the external contracts use as a base type
local_vars = []
for f in contract.functions_declared:
local_vars += f.variables

for v in contract.variables + local_vars:
self._export_from_type(v.type, contract, exported, list_contract)
self._export_from_type(v.type, contract, exported, list_contract, list_top_level)

for s in contract.structures:
for elem in s.elems.values():
self._export_from_type(elem.type, contract, exported, list_contract)
self._export_from_type(elem.type, contract, exported, list_contract, list_top_level)

# Find all convert and "new" operation that can lead to use an external contract
for f in contract.functions_declared:
for ir in f.slithir_operations:
if isinstance(ir, NewContract):
if ir.contract_created != contract and not ir.contract_created in exported:
self._export_list_used_contracts(
ir.contract_created, exported, list_contract
ir.contract_created, exported, list_contract, list_top_level
)
if isinstance(ir, TypeConversion):
self._export_from_type(ir.type, contract, exported, list_contract)
self._export_from_type(
ir.type, contract, exported, list_contract, list_top_level
)

for read in ir.read:
if isinstance(read, TopLevel):
if read not in list_top_level:
list_top_level.append(read)
if isinstance(ir, InternalCall):
function_called = ir.function
if isinstance(function_called, FunctionTopLevel):
list_top_level.append(function_called)

if contract not in list_contract:
list_contract.append(contract)

def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts)
path = Path(self._export_path, f"{contract.name}.sol")
list_top_level: List[TopLevel] = []
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol")

content = ""
content += self._pragmas()

for listed_top_level in list_top_level:
content += self._source_codes_top_level[listed_top_level]
content += "\n"

for listed_contract in list_contracts:
content += self._source_codes[listed_contract]
content += "\n"
Expand All @@ -271,7 +331,7 @@ def _export_contract_with_inheritance(self, contract) -> Export:

def _export_most_derived(self) -> List[Export]:
ret: List[Export] = []
for contract in self._slither.contracts_derived:
for contract in self._compilation_unit.contracts_derived:
ret.append(self._export_contract_with_inheritance(contract))
return ret

Expand All @@ -281,8 +341,13 @@ def _export_all(self) -> List[Export]:
content = ""
content += self._pragmas()

for top_level_content in self._source_codes_top_level.values():
content += "\n"
content += top_level_content
content += "\n"

contract_seen = set()
contract_to_explore = list(self._slither.contracts)
contract_to_explore = list(self._compilation_unit.contracts)

# We only need the inheritance order here, as solc can compile
# a contract that use another contract type (ex: state variable) that he has not seen yet
Expand All @@ -303,9 +368,17 @@ def _export_all(self) -> List[Export]:

def _export_with_import(self) -> List[Export]:
exports: List[Export] = []
for contract in self._slither.contracts:
for contract in self._compilation_unit.contracts:
list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts)
list_top_level: List[TopLevel] = []
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)

if list_top_level:
logger.info(
"Top level objects are not yet supported with the local import flattening"
)
for elem in list_top_level:
logger.info(f"Missing {elem} for {contract.name}")

path = Path(self._export_path, f"{contract.name}.sol")

Expand Down Expand Up @@ -341,12 +414,13 @@ def export( # pylint: disable=too-many-arguments,too-few-public-methods
elif strategy == Strategy.LocalImport:
exports = self._export_with_import()
else:
contracts = self._slither.get_contract_from_name(target)
if len(contracts) != 1:
contracts = self._compilation_unit.get_contract_from_name(target)
if len(contracts) == 0:
logger.error(f"{target} not found")
return
contract = contracts[0]
exports = [self._export_contract_with_inheritance(contract)]
exports = []
for contract in contracts:
exports.append(self._export_contract_with_inheritance(contract))

if json:
export_as_json(exports, json)
Expand Down