black format master (apache#6494)
Matthew Brookhart authored and kevinthesun committed Sep 18, 2020
1 parent 0cb30ad commit 3ac8815
Showing 25 changed files with 588 additions and 430 deletions.
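
The changes below are mechanical: they apply black's formatting to the Python sources, normalizing string literals to double quotes and re-wrapping long expressions. As a rough, non-authoritative sketch, the same normalization can be reproduced on any one of the touched snippets with black's Python API (assuming the black package is installed; FileMode() applies black's defaults, which may not exactly match the line length configured for this repository):

# Hedged sketch, not part of this commit: normalize one line from the diff
# the way black does. black.FileMode() uses black's default settings.
import black

src = "raise ValueError(f'base_dir exists: {base_dir}')\n"
formatted = black.format_str(src, mode=black.FileMode())
print(formatted)  # raise ValueError(f"base_dir exists: {base_dir}")

Pointing black at the whole source tree (for example, the python/ directory) reproduces the bulk edit in the same way.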
1 change: 1 addition & 0 deletions python/tvm/contrib/cc.py
@@ -204,6 +204,7 @@ def _linux_compile(output, objects, options, compile_cmd="g++", compile_shared=F
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)


def _windows_shared(output, objects, options):
cmd = ["clang"]
cmd += ["-O2", "-flto=full", "-fuse-ld=lld-link"]
61 changes: 37 additions & 24 deletions python/tvm/micro/artifact.py
@@ -63,10 +63,10 @@ def unarchive(cls, archive_path, base_dir):
The unarchived artifact.
"""
if os.path.exists(base_dir):
raise ValueError(f'base_dir exists: {base_dir}')
raise ValueError(f"base_dir exists: {base_dir}")

base_dir_parent, base_dir_name = os.path.split(base_dir)
temp_dir = os.path.join(base_dir_parent, f'__tvm__{base_dir_name}')
temp_dir = os.path.join(base_dir_parent, f"__tvm__{base_dir_name}")
os.mkdir(temp_dir)
try:
with tarfile.open(archive_path) as tar_f:
@@ -75,32 +75,36 @@ def unarchive(cls, archive_path, base_dir):
temp_dir_contents = os.listdir(temp_dir)
if len(temp_dir_contents) != 1:
raise ArtifactBadArchiveError(
'Expected exactly 1 subdirectory at root of archive, got '
f'{temp_dir_contents!r}')
"Expected exactly 1 subdirectory at root of archive, got "
f"{temp_dir_contents!r}"
)

metadata_path = os.path.join(temp_dir, temp_dir_contents[0], 'metadata.json')
metadata_path = os.path.join(temp_dir, temp_dir_contents[0], "metadata.json")
if not metadata_path:
raise ArtifactBadArchiveError('No metadata.json found in archive')
raise ArtifactBadArchiveError("No metadata.json found in archive")

with open(metadata_path) as metadata_f:
metadata = json.load(metadata_f)

version = metadata.get('version')
version = metadata.get("version")
if version != cls.ENCODING_VERSION:
raise ArtifactBadArchiveError(
f'archive version: expect {cls.EXPECTED_VERSION}, found {version}')
f"archive version: expect {cls.EXPECTED_VERSION}, found {version}"
)

os.rename(os.path.join(temp_dir, temp_dir_contents[0]), base_dir)

artifact_cls = cls
for sub_cls in cls.__subclasses__():
if (sub_cls.ARTIFACT_TYPE is not None and
sub_cls.ARTIFACT_TYPE == metadata.get('artifact_type')):
if sub_cls.ARTIFACT_TYPE is not None and sub_cls.ARTIFACT_TYPE == metadata.get(
"artifact_type"
):
artifact_cls = sub_cls
break

return artifact_cls.from_unarchived(
base_dir, metadata['labelled_files'], metadata['metadata'])
base_dir, metadata["labelled_files"], metadata["metadata"]
)
finally:
shutil.rmtree(temp_dir)

@@ -128,7 +132,7 @@ def __init__(self, base_dir, labelled_files, metadata):
for f in files:
f_path = os.path.join(self.base_dir, f)
if not os.path.lexists(f_path):
raise ArtifactFileNotFoundError(f'{f} (label {label}): not found at {f_path}')
raise ArtifactFileNotFoundError(f"{f} (label {label}): not found at {f_path}")

if os.path.islink(f_path):
link_path = os.path.readlink(f_path)
@@ -140,7 +144,8 @@ def __init__(self, base_dir, labelled_files, metadata):
link_fullpath = os.path.realpath(link_fullpath)
if not link_fullpath.startswith(self.base_dir):
raise ArtifactBadSymlinkError(
f'{f} (label {label}): symlink points outside artifact tree')
f"{f} (label {label}): symlink points outside artifact tree"
)

def abspath(self, rel_path):
"""Return absolute path to the member with the given relative path."""
@@ -168,29 +173,37 @@ def archive(self, archive_path):
The value of archive_path, after potentially making the computation described above.
"""
if os.path.isdir(archive_path):
archive_path = os.path.join(archive_path, f'{os.path.basename(self.base_dir)}.tar')
archive_path = os.path.join(archive_path, f"{os.path.basename(self.base_dir)}.tar")

archive_name = os.path.splitext(os.path.basename(archive_path))[0]
with tarfile.open(archive_path, 'w') as tar_f:
with tarfile.open(archive_path, "w") as tar_f:

def _add_file(name, data, f_type):
tar_info = tarfile.TarInfo(name=name)
tar_info.type = f_type
data_bytes = bytes(data, 'utf-8')
data_bytes = bytes(data, "utf-8")
tar_info.size = len(data)
tar_f.addfile(tar_info, io.BytesIO(data_bytes))

_add_file(f'{archive_name}/metadata.json',
json.dumps({'version': self.ENCODING_VERSION,
'labelled_files': self.labelled_files,
'metadata': self.metadata},
indent=2,
sort_keys=True),
tarfile.REGTYPE)
_add_file(
f"{archive_name}/metadata.json",
json.dumps(
{
"version": self.ENCODING_VERSION,
"labelled_files": self.labelled_files,
"metadata": self.metadata,
},
indent=2,
sort_keys=True,
),
tarfile.REGTYPE,
)
for dir_path, _, files in os.walk(self.base_dir):
for f in files:
file_path = os.path.join(dir_path, f)
archive_file_path = os.path.join(
archive_name, os.path.relpath(file_path, self.base_dir))
archive_name, os.path.relpath(file_path, self.base_dir)
)
if not os.path.islink(file_path):
tar_f.add(file_path, archive_file_path, recursive=False)
continue
63 changes: 31 additions & 32 deletions python/tvm/micro/build.py
@@ -34,7 +34,7 @@ def __init__(self, root=None, debug=False):
if debug or root is not None:
with util.TempDirectory.set_keep_for_debug():
self.tempdir = util.tempdir(custom_path=root)
_LOG.info('Created debug mode workspace at: %s', self.tempdir.temp_dir)
_LOG.info("Created debug mode workspace at: %s", self.tempdir.temp_dir)
else:
self.tempdir = util.tempdir()

@@ -50,50 +50,49 @@ def path(self):


# Required C runtime libraries, in link order.
CRT_RUNTIME_LIB_NAMES = ['utvm_rpc_server', 'utvm_rpc_common', 'common']
CRT_RUNTIME_LIB_NAMES = ["utvm_rpc_server", "utvm_rpc_common", "common"]


TVM_ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
TVM_ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))


CRT_ROOT_DIR = os.path.join(TVM_ROOT_DIR, 'src', 'runtime', 'crt')
CRT_ROOT_DIR = os.path.join(TVM_ROOT_DIR, "src", "runtime", "crt")


RUNTIME_LIB_SRC_DIRS = (
[os.path.join(CRT_ROOT_DIR, n) for n in CRT_RUNTIME_LIB_NAMES] +
[os.path.join(TVM_ROOT_DIR,
'3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/'
'libraries/crc16')])
RUNTIME_LIB_SRC_DIRS = [os.path.join(CRT_ROOT_DIR, n) for n in CRT_RUNTIME_LIB_NAMES] + [
os.path.join(
TVM_ROOT_DIR,
"3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/" "libraries/crc16",
)
]


RUNTIME_SRC_REGEX = re.compile(r'^.*\.cc?$', re.IGNORECASE)
RUNTIME_SRC_REGEX = re.compile(r"^.*\.cc?$", re.IGNORECASE)


_CRT_DEFAULT_OPTIONS = {
'ccflags': ['-std=c++11'],
'ldflags': ['-std=gnu++14'],
'include_dirs': [
f'{TVM_ROOT_DIR}/include',
f'{TVM_ROOT_DIR}/3rdparty/dlpack/include',
f'{TVM_ROOT_DIR}/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/'
'TARGET_SDK_11/libraries/crc16/',
f'{TVM_ROOT_DIR}/3rdparty/dmlc-core/include',
f'{CRT_ROOT_DIR}/include'
"ccflags": ["-std=c++11"],
"ldflags": ["-std=gnu++14"],
"include_dirs": [
f"{TVM_ROOT_DIR}/include",
f"{TVM_ROOT_DIR}/3rdparty/dlpack/include",
f"{TVM_ROOT_DIR}/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/"
"TARGET_SDK_11/libraries/crc16/",
f"{TVM_ROOT_DIR}/3rdparty/dmlc-core/include",
f"{CRT_ROOT_DIR}/include",
],
'profile': {
'common': ['-Wno-unused-variable']
}
"profile": {"common": ["-Wno-unused-variable"]},
}


def default_options(target_include_dir):
"""Return default opts passed to Compile commands."""
bin_opts = copy.deepcopy(_CRT_DEFAULT_OPTIONS)
bin_opts['include_dirs'].append(target_include_dir)
bin_opts["include_dirs"].append(target_include_dir)
lib_opts = copy.deepcopy(bin_opts)
lib_opts['profile']['common'].append('-Werror')
lib_opts['cflags'] = ['-Wno-error=incompatible-pointer-types']
return {'bin_opts': bin_opts, 'lib_opts': lib_opts}
lib_opts["profile"]["common"].append("-Werror")
lib_opts["cflags"] = ["-Wno-error=incompatible-pointer-types"]
return {"bin_opts": bin_opts, "lib_opts": lib_opts}


def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=None):
@@ -121,17 +120,17 @@ def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=No
lib_opts = _CRT_DEFAULT_OPTIONS if lib_opts is None else lib_opts
bin_opts = _CRT_DEFAULT_OPTIONS if bin_opts is None else bin_opts

mod_build_dir = workspace.relpath(os.path.join('build', 'module'))
mod_build_dir = workspace.relpath(os.path.join("build", "module"))
os.makedirs(mod_build_dir)
mod_src_dir = workspace.relpath(os.path.join('src', 'module'))
mod_src_dir = workspace.relpath(os.path.join("src", "module"))
os.makedirs(mod_src_dir)
mod_src_path = os.path.join(mod_src_dir, 'module.c')
module.save(mod_src_path, 'cc')
mod_src_path = os.path.join(mod_src_dir, "module.c")
module.save(mod_src_path, "cc")

libs = []
for lib_src_dir in RUNTIME_LIB_SRC_DIRS:
lib_name = os.path.basename(lib_src_dir)
lib_build_dir = workspace.relpath(f'build/{lib_name}')
lib_build_dir = workspace.relpath(f"build/{lib_name}")
os.makedirs(lib_build_dir)

lib_srcs = []
@@ -143,6 +142,6 @@ def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=No

libs.append(compiler.library(mod_build_dir, [mod_src_path], lib_opts))

runtime_build_dir = workspace.relpath(f'build/runtime')
runtime_build_dir = workspace.relpath(f"build/runtime")
os.makedirs(runtime_build_dir)
return compiler.binary(runtime_build_dir, libs, bin_opts)
31 changes: 19 additions & 12 deletions python/tvm/micro/class_factory.py
@@ -35,8 +35,12 @@ class ClassFactory:
# When not None, the superclass from which all cls must derive.
SUPERCLASS = None

def __init__(self, cls: typing.Callable, init_args: typing.List[JsonSerializable],
init_kw: typing.Dict[str, JsonSerializable]):
def __init__(
self,
cls: typing.Callable,
init_args: typing.List[JsonSerializable],
init_kw: typing.Dict[str, JsonSerializable],
):
self.cls = cls
self.init_args = init_args
self.init_kw = init_kw
@@ -55,13 +59,15 @@ def instantiate(self):

@property
def to_json(self):
return json.dumps({
'cls': '.'.join([self.cls.__module__, self.cls.__name__]),
'init_args': self.init_args,
'init_kw': self.init_kw,
})
return json.dumps(
{
"cls": ".".join([self.cls.__module__, self.cls.__name__]),
"init_args": self.init_args,
"init_kw": self.init_kw,
}
)

EXPECTED_KEYS = ('cls', 'init_args', 'init_kw')
EXPECTED_KEYS = ("cls", "init_args", "init_kw")

@classmethod
def from_json(cls, data):
Expand All @@ -84,14 +90,15 @@ def from_json(cls, data):
"""
obj = json.loads(data)
if not isinstance(obj, dict):
raise SerializedFactoryError(f'deserialized json payload: want dict, got: {obj!r}')
raise SerializedFactoryError(f"deserialized json payload: want dict, got: {obj!r}")

for key in cls.EXPECTED_KEYS:
if key not in obj:
raise SerializedFactoryError(
f'deserialized json payload: expect key {key}, got: {obj!r}')
f"deserialized json payload: expect key {key}, got: {obj!r}"
)

cls_package_name, cls_name = obj['cls'].rsplit('.', 1)
cls_package_name, cls_name = obj["cls"].rsplit(".", 1)
cls_package = importlib.import_module(cls_package_name)
cls_obj = getattr(cls_package, cls_name)
return cls(cls_obj, obj['init_args'], obj['init_kw'])
return cls(cls_obj, obj["init_args"], obj["init_kw"])
