Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch timeout mechanism to subprocess.run #659

Merged
merged 10 commits into from Jan 20, 2023
23 changes: 11 additions & 12 deletions backend/coreapp/compiler_wrapper.py
Expand Up @@ -15,7 +15,6 @@

from coreapp.platforms import Platform
import coreapp.util as util
from coreapp.util import exception_on_timeout

from .error import AssemblyError, CompilationError
from .models.scratch import Asm, Assembly
Expand Down Expand Up @@ -164,11 +163,11 @@ def compile_code(
# Run compiler
try:
st = round(time.time() * 1000)
compile_proc = exception_on_timeout(
settings.COMPILATION_TIMEOUT_SECONDS
)(sandbox.run_subprocess)(
compile_proc = sandbox.run_subprocess(
cc_cmd,
mounts=[compiler.path],
mounts=(
[compiler.path] if compiler.platform != platforms.DUMMY else []
),
shell=True,
env={
"PATH": PATH,
Expand All @@ -182,6 +181,7 @@ def compile_code(
"MWCIncludes": "/tmp",
"TMPDIR": "/tmp",
},
timeout=settings.COMPILATION_TIMEOUT_SECONDS,
)
et = round(time.time() * 1000)
logging.debug(f"Compilation finished in: {et - st} ms")
Expand All @@ -195,8 +195,8 @@ def compile_code(
# Shlex issue?
logging.debug("Compilation failed: %s", e)
raise CompilationError(str(e))
except TimeoutError as e:
raise CompilationError(str(e))
except subprocess.TimeoutExpired as e:
raise CompilationError("Compilation failed: timeout expired")

if not object_path.exists():
error_msg = (
Expand Down Expand Up @@ -245,9 +245,7 @@ def assemble_asm(platform: Platform, asm: Asm) -> Assembly:

# Run assembler
try:
assemble_proc = exception_on_timeout(settings.ASSEMBLY_TIMEOUT_SECONDS)(
sandbox.run_subprocess
)(
assemble_proc = sandbox.run_subprocess(
platform.assemble_cmd,
mounts=[],
shell=True,
Expand All @@ -256,11 +254,12 @@ def assemble_asm(platform: Platform, asm: Asm) -> Assembly:
"INPUT": sandbox.rewrite_path(asm_path),
"OUTPUT": sandbox.rewrite_path(object_path),
},
timeout=settings.ASSEMBLY_TIMEOUT_SECONDS,
)
except subprocess.CalledProcessError as e:
raise AssemblyError.from_process_error(e)
except TimeoutError as e:
raise AssemblyError(str(e))
except subprocess.TimeoutExpired as e:
raise AssemblyError("Timeout expired")

# Assembly failed
if assemble_proc.returncode != 0:
Expand Down
13 changes: 13 additions & 0 deletions backend/coreapp/compilers.py
Expand Up @@ -33,6 +33,8 @@
SWITCH,
)

import platform as platform_stdlib

logger = logging.getLogger(__name__)

CONFIG_PY = "config.py"
Expand Down Expand Up @@ -98,6 +100,12 @@ def available(self) -> bool:
return settings.DUMMY_COMPILER


@dataclass(frozen=True)
class DummyLongRunningCompiler(DummyCompiler):
def available(self) -> bool:
return settings.DUMMY_COMPILER and platform_stdlib.system() != "Windows"


@dataclass(frozen=True)
class ClangCompiler(Compiler):
flags: ClassVar[Flags] = COMMON_CLANG_FLAGS
Expand Down Expand Up @@ -167,6 +175,10 @@ def preset_from_name(name: str) -> Optional[Preset]:

DUMMY = DummyCompiler(id="dummy", platform=platforms.DUMMY, cc="")

DUMMY_LONGRUNNING = DummyLongRunningCompiler(
id="dummy_longrunning", platform=platforms.DUMMY, cc="sleep 3600"
)

# GBA
AGBCC = GCCCompiler(
id="agbcc",
Expand Down Expand Up @@ -699,6 +711,7 @@ def preset_from_name(name: str) -> Optional[Preset]:

_all_compilers: List[Compiler] = [
DUMMY,
DUMMY_LONGRUNNING,
# GBA
AGBCC,
OLD_AGBCC,
Expand Down
7 changes: 1 addition & 6 deletions backend/coreapp/decompiler_wrapper.py
Expand Up @@ -6,7 +6,6 @@

from coreapp.m2c_wrapper import M2CError, M2CWrapper
from coreapp.platforms import Platform
from coreapp.util import exception_on_timeout
from django.conf import settings

logger = logging.getLogger(__name__)
Expand All @@ -31,11 +30,7 @@ def decompile(
if len(asm.splitlines()) > MAX_M2C_ASM_LINES:
return "/* Too many lines to decompile; please run m2c manually */"
try:
ret = exception_on_timeout(settings.DECOMPILATION_TIMEOUT_SECONDS)(
M2CWrapper.decompile
)(asm, context, compiler, platform.arch)
except TimeoutError as e:
ret = f"/* Timeout error while running m2c */\n{default_source_code}"
ret = M2CWrapper.decompile(asm, context, compiler, platform.arch)
except M2CError as e:
ret = f"{e}\n{default_source_code}"
except Exception:
Expand Down
19 changes: 8 additions & 11 deletions backend/coreapp/diff_wrapper.py
Expand Up @@ -8,7 +8,6 @@

from coreapp.platforms import DUMMY, Platform
from coreapp.flags import ASMDIFF_FLAG_PREFIX
from coreapp.util import exception_on_timeout
from django.conf import settings

from .compiler_wrapper import DiffResult, PATH
Expand Down Expand Up @@ -100,17 +99,16 @@ def get_objdump_target_function_flags(
raise NmError(f"No nm command for {platform.id}")

try:
nm_proc = exception_on_timeout(settings.OBJDUMP_TIMEOUT_SECONDS)(
sandbox.run_subprocess
)(
nm_proc = sandbox.run_subprocess(
[platform.nm_cmd] + [sandbox.rewrite_path(target_path)],
shell=True,
env={
"PATH": PATH,
},
timeout=settings.OBJDUMP_TIMEOUT_SECONDS,
)
except TimeoutError as e:
raise NmError(str(e))
except subprocess.TimeoutExpired as e:
raise NmError("Timeout expired")
except subprocess.CalledProcessError as e:
raise NmError.from_process_error(e)

Expand Down Expand Up @@ -165,19 +163,18 @@ def run_objdump(

if platform.objdump_cmd:
try:
objdump_proc = exception_on_timeout(
settings.OBJDUMP_TIMEOUT_SECONDS
)(sandbox.run_subprocess)(
objdump_proc = sandbox.run_subprocess(
platform.objdump_cmd.split()
+ flags
+ [sandbox.rewrite_path(target_path)],
shell=True,
env={
"PATH": PATH,
},
timeout=settings.OBJDUMP_TIMEOUT_SECONDS,
)
except TimeoutError as e:
raise ObjdumpError(str(e))
except subprocess.TimeoutExpired as e:
raise ObjdumpError("Timeout expired")
except subprocess.CalledProcessError as e:
raise ObjdumpError.from_process_error(e)
else:
Expand Down
6 changes: 3 additions & 3 deletions backend/coreapp/sandbox.py
Expand Up @@ -69,9 +69,7 @@ def sandbox_command(self, mounts: List[Path], env: Dict[str, str]) -> List[str]:
"--env", "PATH=/usr/bin:/bin",
"--cwd", "/tmp",
"--rlimit_fsize", "soft",
"--rlimit_nofile", "soft",
"--rlimit_cpu", "30", # seconds
"--time_limit", "30", # seconds
"--rlimit_nofile", "soft",
# the following are settings that can be removed once we are done with wine
"--bindmount_ro", f"{settings.WINEPREFIX}:/wine",
"--env", "WINEDEBUG=-all",
Expand All @@ -98,6 +96,7 @@ def run_subprocess(
mounts: Optional[List[Path]] = None,
env: Optional[Dict[str, str]] = None,
shell: bool = False,
timeout: Optional[float] = None,
) -> subprocess.CompletedProcess[str]:
mounts = mounts if mounts is not None else []
env = env if env is not None else {}
Expand Down Expand Up @@ -129,4 +128,5 @@ def run_subprocess(
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=timeout,
)
32 changes: 30 additions & 2 deletions backend/coreapp/tests.py
Expand Up @@ -12,7 +12,7 @@
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase
from coreapp.compilers import Language
from coreapp.compilers import DummyCompiler, Language
from coreapp import compilers, platforms

from coreapp.compiler_wrapper import CompilerWrapper
Expand Down Expand Up @@ -523,7 +523,7 @@ def test_dummy_compiler(self) -> None:
len(result.elf_object), 0, "The compilation result should be non-null"
)

@parameterized.expand(input=[(c,) for c in compilers.available_compilers()]) # type: ignore
@parameterized.expand(input=[(c,) for c in compilers.available_compilers() if not isinstance(c, DummyCompiler)], skip_on_empty=True) # type: ignore
def test_all_compilers(self, compiler: Compiler) -> None:
"""
Ensure that we can run a simple compilation/diff for all available compilers
Expand Down Expand Up @@ -555,6 +555,34 @@ def test_all_compilers(self, compiler: Compiler) -> None:
self.assertTrue("rows" in diff)
self.assertGreater(len(diff["rows"]), 0)

@requiresCompiler(compilers.DUMMY_LONGRUNNING)
def test_compiler_timeout(self) -> None:
with self.settings(COMPILATION_TIMEOUT_SECONDS=3):
scratch_dict = {
"compiler": compilers.DUMMY_LONGRUNNING.id,
"platform": platforms.DUMMY.id,
"context": "",
"target_asm": "asm(AAAAAAAA)",
}

# Test that we can create a scratch
scratch = self.create_scratch(scratch_dict)

compile_dict = {
"slug": scratch.slug,
"compiler": compilers.DUMMY_LONGRUNNING.id,
"compiler_flags": "",
"source_code": "source(AAAAAAAA)",
}

# Test that we can compile a scratch
response = self.client.post(
reverse("scratch-compile", kwargs={"pk": scratch.slug}), compile_dict
)

self.assertFalse(response.json()["success"])
self.assertIn("timeout expired", response.json()["compiler_output"].lower())


class DecompilationTests(BaseTestCase):
@requiresCompiler(GCC281)
Expand Down
65 changes: 1 addition & 64 deletions backend/coreapp/util.py
@@ -1,20 +1,8 @@
import hashlib
import logging
import time
import dill

import django

import multiprocessing
import functools
import platform

# For reasons of thread safety, guincorn refuses to let us join processes forked from a worker thread
# To get around this, we opt to spawn a fresh process instead.
mp = multiprocessing.get_context("spawn")

from typing import Tuple, TypeVar, Callable, Any, cast
from queue import Queue
from typing import Tuple

logger = logging.getLogger(__name__)

Expand All @@ -24,54 +12,3 @@

def gen_hash(key: Tuple[str, ...]) -> str:
return hashlib.sha256(str(key + (_startup_time,)).encode("utf-8")).hexdigest()


F = TypeVar("F", bound=Callable[..., Any])

# Python 3.10+ should allow this to be typed more concretely
# (see https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators)

# Windows requires multiprocessing processes to be in top-level scope
def worker(queue: Queue[Any], func: bytes, args: Any, kwargs: Any) -> Any:
try:
# As we're in a new, spawn'ed environment, we have to do the bare minimum initalization ourselved
# (i.e. the django app registry)
django.setup()

ret = dill.loads(func)(*args, **kwargs)
queue.put(ret)
except Exception as e:
queue.put(e)


def exception_on_timeout(timeout_seconds: float) -> Callable[[F], F]:
def timeout_inner(func: F) -> F:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# If the timeout is 0 or less, call the function directly without a timeout
if timeout_seconds <= 0:
return func(*args, **kwargs)

queue: Queue[Any] = mp.Queue()

# On Windows, multiprocessing uses pickle under the hood to serialize arguments
# It doesn't play nicely with arbitary functions, so we explicitly use its
# more versatile cousin (dill) to handle the serialization ourselves
p = mp.Process(target=worker, args=(queue, dill.dumps(func), args, kwargs))
p.start()
p.join(timeout_seconds)

if p.is_alive():
# The process has hanged - terminate, and throw an error
p.terminate()
p.join()
raise TimeoutError("Process timed out")
else:
ret = queue.get()
if isinstance(ret, Exception):
raise ret
return ret

return cast(F, wrapper)

return timeout_inner
4 changes: 0 additions & 4 deletions backend/decompme/settings.py
Expand Up @@ -33,7 +33,6 @@
COMPILATION_CACHE_SIZE=(int, 100),
WINEPREFIX=(str, "/tmp/wine"),
COMPILATION_TIMEOUT_SECONDS=(int, 10),
DECOMPILATION_TIMEOUT_SECONDS=(int, 5),
ASSEMBLY_TIMEOUT_SECONDS=(int, 3),
OBJDUMP_TIMEOUT_SECONDS=(int, 3),
TIMEOUT_SCALE_FACTOR=(int, 1),
Expand Down Expand Up @@ -211,8 +210,5 @@
COMPILATION_TIMEOUT_SECONDS = (
env("COMPILATION_TIMEOUT_SECONDS", int) * TIMEOUT_SCALE_FACTOR
)
DECOMPILATION_TIMEOUT_SECONDS = (
env("DECOMPILATION_TIMEOUT_SECONDS", int) * TIMEOUT_SCALE_FACTOR
)
ASSEMBLY_TIMEOUT_SECONDS = env("ASSEMBLY_TIMEOUT_SECONDS", int) * TIMEOUT_SCALE_FACTOR
OBJDUMP_TIMEOUT_SECONDS = env("OBJDUMP_TIMEOUT_SECONDS", int) * TIMEOUT_SCALE_FACTOR