Skip to content

Commit

Permalink
[Unity] Enable ccache for nn.SourceModule
Browse files Browse the repository at this point in the history
This PR is a mirror PR of apache#16176, and enables ccache in `nn.SourceModule` compilation.
  • Loading branch information
cyx-6 committed Nov 30, 2023
1 parent d52a9bf commit 1fec2b3
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 14 deletions.
52 changes: 42 additions & 10 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_cc():
return None


def create_shared(output, objects, options=None, cc=None):
def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
"""Create shared library.
Parameters
Expand All @@ -80,13 +80,19 @@ def create_shared(output, objects, options=None, cc=None):
cc : Optional[str]
The compiler command.
cwd : Optional[str]
The urrent working directory.
ccache_env : Optional[Dict[str, str]]
The environment variable for ccache. Set `None` to disable ccache by default.
"""
cc = cc or get_cc()

if _is_linux_like():
_linux_compile(output, objects, options, cc, compile_shared=True)
_linux_compile(output, objects, options, cc, cwd, ccache_env, compile_shared=True)
elif _is_windows_like():
_windows_compile(output, objects, options)
_windows_compile(output, objects, options, cwd, ccache_env)
else:
raise ValueError("Unsupported platform")

Expand Down Expand Up @@ -139,7 +145,7 @@ def create_staticlib(output, inputs, ar=None):
raise ValueError("Unsupported platform")


def create_executable(output, objects, options=None, cc=None):
def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
"""Create executable binary.
Parameters
Expand All @@ -155,13 +161,19 @@ def create_executable(output, objects, options=None, cc=None):
cc : Optional[str]
The compiler command.
cwd : Optional[str]
The urrent working directory.
ccache_env : Optional[Dict[str, str]]
The environment variable for ccache. Set `None` to disable ccache by default.
"""
cc = cc or get_cc()

if _is_linux_like():
_linux_compile(output, objects, options, cc)
_linux_compile(output, objects, options, cc, cwd, ccache_env)
elif _is_windows_like():
_windows_compile(output, objects, options)
_windows_compile(output, objects, options, cwd, ccache_env)
else:
raise ValueError("Unsupported platform")

Expand Down Expand Up @@ -275,7 +287,9 @@ def _fcompile(outputs, objects, options=None):
return _fcompile


def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
def _linux_compile(
output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False
):
cmd = [compile_cmd]
if compile_cmd != "nvcc":
if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
Expand All @@ -294,7 +308,15 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
cmd += objects
if options:
cmd += options
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
env = None
if ccache_env is not None:
if shutil.which("ccache"):
cmd.insert(0, "ccache")
env = os.environ.copy()
env.update(ccache_env)
else:
raise ValueError("ccache not found")
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
Expand All @@ -303,7 +325,7 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
raise RuntimeError(msg)


def _windows_compile(output, objects, options):
def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
cmd = ["clang"]
cmd += ["-O2"]

Expand All @@ -318,9 +340,19 @@ def _windows_compile(output, objects, options):
cmd += objects
if options:
cmd += options
env = None
if ccache_env is not None:
if shutil.which("ccache"):
cmd.insert(0, "ccache")
env = os.environ.copy()
env.update(ccache_env)
else:
raise ValueError("ccache not found")

try:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env
)
(out, _) = proc.communicate()
except FileNotFoundError:
raise RuntimeError(
Expand Down
19 changes: 15 additions & 4 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
impure external function callings, inplace mutation, etc.
"""
import os
import shutil
import sys
import tempfile
from collections import OrderedDict
Expand Down Expand Up @@ -589,15 +590,23 @@ def _detect_output_suffix(output_format: str) -> str:
compile_options.append("-DDMLC_USE_FOPEN64=0")
compile_options.append("-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>")
with tempfile.TemporaryDirectory() as temp_dir:
source_file = os.path.join(temp_dir, f"main{source_suffix}")
with open(source_file, "w", encoding="utf-8") as file:
source_file = f"main{source_suffix}"
with open(
os.path.join(temp_dir, f"main{source_suffix}"), "w", encoding="utf-8"
) as file:
file.write(source_code)
output_file = os.path.join(temp_dir, f"main{output_suffix}")
output_file = f"main{output_suffix}"
if shutil.which("ccache"):
ccache_env = {"CCACHE_COMPILERCHECK": "content"}
else:
ccache_env = None
_cc.create_shared(
output=output_file,
objects=[source_file],
options=compile_options,
cc=compiler,
cwd=temp_dir,
ccache_env=ccache_env,
)
func_names: List[str] = []
func_specs: List[_spec.ExternFunctionSpec] = []
Expand All @@ -606,7 +615,9 @@ def _detect_output_suffix(output_format: str) -> str:
func_specs.append(func_spec)
if func_spec.symbol is None:
func_spec.symbol = func_name
library = load_static_library(output_file, func_names=func_names)
library = load_static_library(
os.path.join(temp_dir, f"main{output_suffix}"), func_names=func_names
)
module_spec = _spec.ExternModuleSpec(
library=library,
functions=func_specs,
Expand Down
79 changes: 79 additions & 0 deletions tests/python/contrib/test_ccache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test contrib.cc with ccache"""
import os
import pytest
import shutil
import tempfile
import tvm
from tvm.contrib.cc import create_shared, create_executable, _is_linux_like, _is_windows_like


def _src_gen(text):
return """
#include <iostream>
int main() {
std::cout << "text";
return 0;
}""".replace(
"text", text
)


def _compile(f_create, text, output):
with tempfile.TemporaryDirectory() as temp_dir:
src_path = os.path.join(temp_dir, "src.cpp")
with open(src_path, "w", encoding="utf-8") as file:
file.write(_src_gen(text))
log_path = os.path.join(temp_dir, "log.txt")
ccache_env = {
"CCACHE_COMPILERCHECK": "content",
"CCACHE_LOGFILE": log_path,
}
f_create(output, ["src.cpp"], ["-c"], cwd=temp_dir, ccache_env=ccache_env)
with open(log_path, "r", encoding="utf-8") as file:
log = file.read()
return log


@pytest.mark.skipif(shutil.which("ccache") is None, reason="ccache not installed")
def test_shared():
if _is_linux_like():
_ = _compile(create_shared, "shared", "main.o")
log = _compile(create_shared, "shared", "main.o")
assert "Succeeded getting cached result" in log
elif _is_windows_like():
_ = _compile(create_shared, "shared", "main.obj")
log = _compile(create_shared, "shared", "main.obj")
assert "Succeeded getting cached result" in log


@pytest.mark.skipif(shutil.which("ccache") is None, reason="ccache not installed")
def test_executable():
if _is_linux_like():
_ = _compile(create_executable, "executable", "main")
log = _compile(create_executable, "executable", "main")
assert "Succeeded getting cached result" in log
elif _is_windows_like():
_ = _compile(create_executable, "executable", "main.exe")
log = _compile(create_executable, "executable", "main.exe")
assert "Succeeded getting cached result" in log


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 1fec2b3

Please sign in to comment.