Fix generate_default_trace post API updates.
mtrofin committed Mar 3, 2022
1 parent fc54961 commit cad0a9d
Showing 3 changed files with 57 additions and 28 deletions.
35 changes: 21 additions & 14 deletions compiler_opt/rl/compilation_runner.py
@@ -98,6 +98,16 @@ def wait(self):
     self._event.wait()
 
 
+def kill_process_ignore_exceptions(p: subprocess.Popen):
+  # kill the process and ignore exceptions. Exceptions would be thrown if the
+  # process has already been killed/finished (which is inherently in a race
+  # condition with us killing it)
+  try:
+    p.kill()
+  finally:
+    return  # pylint: disable=lost-exception
+
+
 class WorkerCancellationManager:
   """A thread-safe object that can be used to signal cancellation.
@@ -113,28 +123,20 @@ def __init__(self):
     self._done = False
     self._lock = threading.Lock()
 
-  def _kill(self, p: subprocess.Popen):
-    # kill the process and ignore any exceptions due to e.g. this being in a
-    # race condition with the process terminating.
-    try:
-      p.kill()
-    finally:
-      return  # pylint: disable=lost-exception
-
   def register_process(self, p: subprocess.Popen):
     """Register a process for potential cancellation."""
     with self._lock:
       if not self._done:
         self._processes.add(p)
         return
-    self._kill(p)
+    kill_process_ignore_exceptions(p)
 
   def signal(self):
     """Cancel any pending work."""
     with self._lock:
       self._done = True
       for p in self._processes:
-        self._kill(p)
+        kill_process_ignore_exceptions(p)
 
   def unregister_process(self, p: subprocess.Popen):
     with self._lock:
@@ -144,7 +146,7 @@ def unregister_process(self, p: subprocess.Popen):
 
 def start_cancellable_process(
     cmdline: List[str],
-    timeout: int,
+    timeout: float,
     cancellation_manager: Optional[WorkerCancellationManager],
     want_output: bool = False) -> Optional[bytes]:
   """Start a cancellable process.
@@ -167,9 +169,14 @@
   if cancellation_manager:
     cancellation_manager.register_process(p)
 
-  retcode = p.wait(timeout=timeout)
-  if cancellation_manager:
-    cancellation_manager.unregister_process(p)
+  try:
+    retcode = p.wait(timeout=timeout)
+  except subprocess.TimeoutExpired as e:
+    kill_process_ignore_exceptions(p)
+    raise e
+  finally:
+    if cancellation_manager:
+      cancellation_manager.unregister_process(p)
   if retcode != 0:
     raise ProcessKilledError(
     ) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
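
Net effect of this file's changes: a child process that outlives `timeout` is now killed before `subprocess.TimeoutExpired` reaches the caller, instead of being leaked, and the cancellation manager is always unregistered via the `finally` clause. A minimal caller-side sketch of the new contract (illustrative only, not part of the commit; assumes the package is importable and a `sleep` binary is on PATH):

import subprocess

from compiler_opt.rl import compilation_runner

try:
  compilation_runner.start_cancellable_process(
      ['sleep', '10'], timeout=0.5, cancellation_manager=None)
except subprocess.TimeoutExpired:
  # The child has already been killed inside start_cancellable_process,
  # so nothing is left running by the time the exception propagates here.
  print('compile step timed out')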
15 changes: 15 additions & 0 deletions compiler_opt/rl/compilation_runner_test.py
@@ -15,10 +15,13 @@
 
 """Tests for compiler_opt.rl.compilation_runner."""
 
+import os
 import string
 import subprocess
+import time
 from unittest import mock
 
+from absl import flags
 import tensorflow as tf
 
 from google.protobuf import text_format
@@ -224,5 +227,17 @@ def test_start_subprocess_output(self):
       self.fail('output should have been non-empty')
     self.assertNotEmpty(output_str)
 
+  def test_timeout_kills_process(self):
+    sentinel_file = os.path.join(flags.FLAGS.test_tmpdir,
+                                 'test_timeout_kills_test_file')
+    if os.path.exists(sentinel_file):
+      os.remove(sentinel_file)
+    with self.assertRaises(subprocess.TimeoutExpired):
+      compilation_runner.start_cancellable_process(
+          ['bash', '-c', 'sleep 1s ; touch ' + sentinel_file],
+          timeout=0.5, cancellation_manager=None)
+    time.sleep(2)
+    self.assertFalse(os.path.exists(sentinel_file))
+
 
 if __name__ == '__main__':
   tf.test.main()
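
The new test relies on a sentinel-file pattern: the child sleeps past the timeout and would then create a file, so if the kill works, the file never appears even after waiting out the child's original schedule. A self-contained sketch of the same mechanism against the raw subprocess API (illustrative only; assumes a POSIX system with bash available):

import os
import subprocess
import tempfile
import time

sentinel = os.path.join(tempfile.mkdtemp(), 'sentinel')
p = subprocess.Popen(['bash', '-c', 'sleep 1 ; touch ' + sentinel])
try:
  p.wait(timeout=0.5)
except subprocess.TimeoutExpired:
  p.kill()  # without this, 'touch' would still run once the sleep ends
p.wait()  # reap the killed child
time.sleep(2)  # wait past the point where the child would have touched
assert not os.path.exists(sentinel)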
35 changes: 21 additions & 14 deletions compiler_opt/tools/generate_default_trace.py
@@ -19,13 +19,17 @@
 import os
 import queue
 import random
+# see https://bugs.python.org/issue33315 - we do need these types, but must
+# currently use them as string annotations
+from typing import List, Tuple, Optional  # pylint:disable=unused-import
 
 from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
 from tf_agents.system import system_multiprocessing as multiprocessing
 
+from compiler_opt.rl import compilation_runner
 from compiler_opt.rl.inlining import inlining_runner
 
 flags.DEFINE_string('data_path', None, 'Path to folder containing IR files.')
@@ -48,7 +52,9 @@
 FLAGS = flags.FLAGS
 
 
-def worker(runner, work_queue: queue.Queue, results_queue: queue.Queue):
+def worker(runner: compilation_runner.CompilationRunner,
+           work_queue: 'queue.Queue[Tuple[str, ...]]',
+           results_queue: 'queue.Queue[Optional[List[str]]]'):
   """What each worker process does.
 
   Each worker picks a workitem from the work_queue, process it, and deposits
@@ -67,11 +73,11 @@ def worker(runner, work_queue: queue.Queue, results_queue: queue.Queue):
     except queue.Empty:
       return
     try:
-      record = runner.collect_data(module_triple, '', None)
-      results_queue.put((module_triple, record))
+      (records, _, _) = runner.collect_data(module_triple, '', None)
+      results_queue.put(records)
     except:  # pylint: disable=bare-except
       logging.error('Failed to compile %s.', module_triple)
-      results_queue.put((module_triple, None))
+      results_queue.put(None)
 
 
 def main(_):
@@ -97,7 +103,7 @@ def main(_):
   sizes_and_paths = [(os.path.getsize(p + '.bc'), p) for p in module_paths]
   sizes_and_paths.sort(reverse=True)
   sorted_module_paths = [p for _, p in sizes_and_paths]
-  file_paths = [
+  module_specs = [
       tuple([p + suffix for suffix in file_suffix]) for p in sorted_module_paths
   ]
 
@@ -107,10 +113,10 @@
   with tf.io.TFRecordWriter(FLAGS.output_path) as file_writer:
     ctx = multiprocessing.get_context()
     m = ctx.Manager()
-    results_queue = m.Queue()
-    work_queue = m.Queue()
-    for path in file_paths:
-      work_queue.put(path)
+    results_queue: 'queue.Queue[Optional[List[str]]]' = m.Queue()
+    work_queue: 'queue.Queue[Tuple[str, ...]]' = m.Queue()
+    for module_spec in module_specs:
+      work_queue.put(module_spec)
     processes = [
         ctx.Process(
             target=functools.partial(worker, runner, work_queue, results_queue))
@@ -121,13 +127,14 @@
       p.start()
 
     total_successful_examples = 0
-    total_work = len(file_paths)
+    total_work = len(module_specs)
     total_failed_examples = 0
     for _ in range(0, total_work):
-      _, record = results_queue.get()
-      if record and len(record[0]) > 0:
+      records = results_queue.get()
+      if records:
         total_successful_examples += 1
-        file_writer.write(record[0][0])
+        for r in records:
+          file_writer.write(r)
       else:
         total_failed_examples += 1
 
@@ -137,7 +144,7 @@
                 total_failed_examples, total_work)
 
   print('%d of %d modules succeeded.' %
-        (total_successful_examples, len(file_paths)))
+        (total_successful_examples, len(module_specs)))
   for p in processes:
     p.join()

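The worker/main queue contract also changed: `collect_data` now returns a tuple whose first element is a list of serialized records, workers put that list (or `None` on failure) on the results queue, and `main` writes every record instead of only the first. A standalone sketch of the consumer side of this protocol (hypothetical names and data, single-threaded for brevity):

import queue
from typing import List, Optional

results_queue: 'queue.Queue[Optional[List[str]]]' = queue.Queue()

# A successful worker puts a (possibly multi-record) list of serialized
# protos; a failed one puts None.
results_queue.put(['record-a', 'record-b'])
results_queue.put(None)

succeeded = failed = 0
for _ in range(2):
  records = results_queue.get()
  if records:
    succeeded += 1  # every record would be written, not just records[0]
  else:
    failed += 1
assert (succeeded, failed) == (1, 1)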