Skip to content

Commit

Permalink
fix: feature broken in brownie
Browse files Browse the repository at this point in the history
feat: feature support for hardhat
  • Loading branch information
joaosantos15 committed May 21, 2021
1 parent dbb2959 commit 9973ce7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 20 deletions.
27 changes: 16 additions & 11 deletions mythx_cli/fuzz/ide/brownie.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mythx_cli.fuzz.exceptions import BuildArtifactsError
from mythx_cli.fuzz.ide.generic import IDEArtifacts, JobBuilder

from ...util import sol_files_by_directory
from ...util import sol_files_by_directory, get_content_from_file
from ...util import files_by_directory

LOGGER = logging.getLogger("mythx-cli")
Expand All @@ -18,20 +18,20 @@ def __init__(self, build_dir=None, targets=None, map_to_original_source=False):
if targets:
include = []
for target in targets:
if not map_to_original_source:
LOGGER.debug(f"Mapping instrumented code")
include.extend(files_by_directory(target, ".sol"))
else:
# We replace .sol with .sol.original in case the target is a file and not a directory
target = target.replace(".sol", ".sol.original")
LOGGER.debug(f"Mapping original code, {target}")
include.extend(files_by_directory(target, ".sol.original"))
# if not map_to_original_source:
LOGGER.debug(f"Mapping instrumented code")
include.extend(files_by_directory(target, ".sol"))
# else:
# # We replace .sol with .sol.original in case the target is a file and not a directory
# target = target.replace(".sol", ".sol.original")
# LOGGER.debug(f"Mapping original code, {target}")
# include.extend(files_by_directory(target, ".sol.original"))
self._include = include

self._build_dir = build_dir or Path("./build/contracts")
build_files_by_source_file = self._get_build_artifacts(self._build_dir)

self._contracts, self._sources = self.fetch_data(build_files_by_source_file)
self._contracts, self._sources = self.fetch_data(build_files_by_source_file, map_to_original_source)

@property
def contracts(self):
Expand All @@ -41,7 +41,7 @@ def contracts(self):
def sources(self):
return self._sources

def fetch_data(self, build_files_by_source_file):
def fetch_data(self, build_files_by_source_file, map_to_original_source=False):
result_contracts = {}
result_sources = {}
for source_file, contracts in build_files_by_source_file.items():
Expand Down Expand Up @@ -83,6 +83,11 @@ def fetch_data(self, build_files_by_source_file):
"source": target_file["source"],
"ast": target_file["ast"],
}

if map_to_original_source and Path(source_file_dep+".original").is_file():
# we check if the current source file has a non instrumented version
# if it does, we include that one as the source code
result_sources[source_file_dep]["source"] = get_content_from_file(source_file_dep+".original")
return result_contracts, result_sources

@staticmethod
Expand Down
24 changes: 18 additions & 6 deletions mythx_cli/fuzz/ide/hardhat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,27 @@
from mythx_cli.fuzz.exceptions import BuildArtifactsError
from mythx_cli.fuzz.ide.generic import IDEArtifacts, JobBuilder

from ...util import sol_files_by_directory
from ...util import sol_files_by_directory, LOGGER, files_by_directory, get_content_from_file


class HardhatArtifacts(IDEArtifacts):
def __init__(self, build_dir=None, targets=None):
def __init__(self, build_dir=None, targets=None, map_to_original_source=False):
self._include = []
if targets:
include = []
for target in targets:
# if not map_to_original_source:
include.extend(sol_files_by_directory(target))
# else:
# # We replace .sol with .sol.original in case the target is a file and not a directory
# target = target.replace(".sol", ".sol.original")
# LOGGER.debug(f"Mapping original code, {target}")
# include.extend(files_by_directory(target, ".sol.original"))
self._include = include

print("-----> ",self._include)
self._build_dir = Path(build_dir).absolute() or Path("./artifacts").absolute()
self._contracts, self._sources = self.fetch_data()
self._contracts, self._sources = self.fetch_data(map_to_original_source)

@property
def contracts(self):
Expand All @@ -29,7 +36,7 @@ def contracts(self):
def sources(self):
return self._sources

def fetch_data(self):
def fetch_data(self, map_to_original_source=False):
result_contracts = {}
result_sources = {}

Expand Down Expand Up @@ -98,12 +105,17 @@ def fetch_data(self):
"ast": data["ast"],
}

if map_to_original_source and Path(source_file_dep+".original").is_file():
# we check if the current source file has a non instrumented version
# if it does, we include that one as the source code
result_sources[source_file_dep]["source"] = get_content_from_file(source_file_dep+".original")

return result_contracts, result_sources


class HardhatJob:
def __init__(self, target: List[str], build_dir: Path):
artifacts = HardhatArtifacts(build_dir, targets=target)
def __init__(self, target: List[str], build_dir: Path, map_to_original_source: bool):
artifacts = HardhatArtifacts(build_dir, targets=target, map_to_original_source=map_to_original_source)
self._jb = JobBuilder(artifacts)
self.payload = None

Expand Down
2 changes: 1 addition & 1 deletion mythx_cli/fuzz/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def fuzz_run(ctx, address, more_addresses, corpus_target, map_to_original_source
artifacts = BrownieJob(target, analyze_config["build_directory"], map_to_original_source=map_to_original_source)
artifacts.generate_payload()
elif ide == IDE.HARDHAT:
artifacts = HardhatJob(target, analyze_config["build_directory"])
artifacts = HardhatJob(target, analyze_config["build_directory"], map_to_original_source=map_to_original_source)
artifacts.generate_payload()
elif ide == IDE.TRUFFLE:
raise click.exceptions.UsageError(
Expand Down
15 changes: 13 additions & 2 deletions mythx_cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def write_or_print(ctx, data: str, mode="a+") -> None:
LOGGER.debug(f"Writing data to {ctx['output']}")
outfile.write(data + "\n")


def sol_files_by_directory(target_path: AnyStr) -> List:
"""Gathers all the .sol files inside the target path
including sub-directories and returns them as a List.
Expand All @@ -260,6 +261,7 @@ def sol_files_by_directory(target_path: AnyStr) -> List:
"""
return files_by_directory(target_path, ".sol")


def files_by_directory(target_path: AnyStr, extension: AnyStr) -> List:
"""Gathers all the target extension files inside the target path
including sub-directories and returns them as a List.
Expand All @@ -281,7 +283,7 @@ def files_by_directory(target_path: AnyStr, extension: AnyStr) -> List:
else:
""" If it's a valid target extension file there is no need to search further and we just append it to our
list to be returned, removing the .original extension, leaving only the .sol """
target_files.append(target_path.replace(".original",""))
target_files.append(target_path.replace(".original", ""))
source_dir = os.walk(target_path)
for sub_dir in source_dir:
if len(sub_dir[2]) > 0:
Expand All @@ -298,5 +300,14 @@ def files_by_directory(target_path: AnyStr, extension: AnyStr) -> List:
file_name = file_prefix + "/" + file
LOGGER.debug(f"Found target extension file: {file_name}")
# We remove the .original extension, added by Scribble
target_files.append(file_name.replace(".original",""))
target_files.append(file_name.replace(".original", ""))
return target_files


def get_content_from_file(file_path: AnyStr) -> AnyStr:
reader = open(file_path)
try:
source_code = reader.read()
finally:
reader.close()
return source_code

0 comments on commit 9973ce7

Please sign in to comment.