Skip to content

Commit

Permalink
Merge pull request #7731 from niboshi/reporter-multithread
Browse files Browse the repository at this point in the history
Fix reporter for multi-thread use
  • Loading branch information
mergify[bot] committed Jul 12, 2019
2 parents 314548c + ab23f02 commit eabb66e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
24 changes: 17 additions & 7 deletions chainer/reporter.py
Expand Up @@ -3,6 +3,7 @@
import contextlib
import copy
import json
import threading
import typing as tp # NOQA
import warnings

Expand All @@ -17,6 +18,9 @@
import chainerx


_thread_local = threading.local()


def _copy_variable(value):
if isinstance(value, variable.Variable):
return copy.copy(value)
Expand Down Expand Up @@ -78,11 +82,11 @@ def __init__(self):

def __enter__(self):
"""Makes this reporter object current."""
_reporters.append(self)
_get_reporters().append(self)

def __exit__(self, exc_type, exc_value, traceback):
"""Recovers the previous reporter object to the current."""
_reporters.pop()
_get_reporters().pop()

@contextlib.contextmanager
def scope(self, observation):
Expand Down Expand Up @@ -171,12 +175,17 @@ def report(self, values, observer=None):
self.observation.update(values)


_reporters = [] # type: tp.Optional[tp.List[Reporter]]
def _get_reporters():
try:
reporters = _thread_local.reporters
except AttributeError:
reporters = _thread_local.reporters = []
return reporters


def get_current_reporter():
"""Returns the current reporter object."""
return _reporters[-1]
return _get_reporters()[-1]


def report(values, observer=None):
Expand Down Expand Up @@ -231,8 +240,9 @@ def __call__(self, x, y):
of the observed value.
"""
if _reporters:
current = _reporters[-1]
reporters = _get_reporters()
if reporters:
current = reporters[-1]
current.report(values, observer)


Expand All @@ -244,7 +254,7 @@ def report_scope(observation):
except that it does not make the reporter current redundantly.
"""
current = _reporters[-1]
current = _get_reporters()[-1]
old = current.observation
current.observation = observation
yield
Expand Down
32 changes: 32 additions & 0 deletions tests/chainer_tests/test_reporter.py
@@ -1,5 +1,7 @@
import contextlib
import tempfile
import threading
import time
import unittest

import numpy
Expand Down Expand Up @@ -28,6 +30,36 @@ def test_enter_exit(self):
self.assertIs(chainer.get_current_reporter(), reporter2)
self.assertIs(chainer.get_current_reporter(), reporter1)

def test_enter_exit_threadsafe(self):
# This test ensures reporter.__enter__ correctly stores the reporter
# in the thread-local storage.

def thread_func(reporter, record):
with reporter:
# Sleep for a tiny moment to cause an overlap of the context
# managers.
time.sleep(0.01)
record.append(chainer.get_current_reporter())

record1 = [] # The current repoter in each thread is stored here.
record2 = []
reporter1 = chainer.Reporter()
reporter2 = chainer.Reporter()
thread1 = threading.Thread(
target=thread_func,
args=(reporter1, record1))
thread2 = threading.Thread(
target=thread_func,
args=(reporter2, record2))
thread1.daemon = True
thread2.daemon = True
thread1.start()
thread2.start()
thread1.join()
thread2.join()
self.assertIs(record1[0], reporter1)
self.assertIs(record2[0], reporter2)

def test_scope(self):
reporter1 = chainer.Reporter()
reporter2 = chainer.Reporter()
Expand Down

0 comments on commit eabb66e

Please sign in to comment.