Skip to content

Commit

Permalink
made pickle multi-processing safe
Browse files Browse the repository at this point in the history
  • Loading branch information
justanhduc committed May 15, 2021
1 parent 2c4140b commit 123ff28
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions neural_monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from shutil import copyfile
from collections import namedtuple, deque
import functools
from multiprocessing import Lock

from typing import (
List,
Expand Down Expand Up @@ -47,6 +48,7 @@
_TRACKS = collections.OrderedDict()
hooks = {}
lock = utils.ReadWriteLock()
plock = Lock()
Git = namedtuple('Git', ('branch', 'commit_id', 'commit_message', 'commit_datetime', 'commit_user', 'commit_email'))

# setup logger
Expand Down Expand Up @@ -331,7 +333,8 @@ def __setattr__(self, attr, val):
def initialize(self, model_name: Optional[str] = None, root: Optional[str] = None,
current_folder: Optional[str] = None, print_freq: Optional[int] = 1,
num_iters: Optional[int] = None, prefix: Optional[str] = 'run',
use_tensorboard: Optional[bool] = True, with_git: Optional[bool] = False) -> None:
use_tensorboard: Optional[bool] = True, with_git: Optional[bool] = False,
not_found_warn=True) -> None:
"""
:param model_name:
Expand Down Expand Up @@ -396,7 +399,7 @@ def initialize(self, model_name: Optional[str] = None, root: Optional[str] = Non

if os.path.exists(self.current_folder):
lock.acquire_read()
self.load_state()
self.load_state(not_found_warn=not_found_warn)
lock.release_read()
else:
os.makedirs(self.current_folder, exist_ok=True)
Expand Down Expand Up @@ -426,47 +429,59 @@ def initialize(self, model_name: Optional[str] = None, root: Optional[str] = Non

self._thread.start()

def load_state(self, not_found_warn=False) -> None:
    """
    Restores the monitor state from the pickled log of the current run.

    ``self.current_run`` is always set to the base name of
    ``self.current_folder``. The log is then read through
    ``self.read_log()`` while holding both the multiprocessing lock
    (``plock``) and the read side of the read/write lock (``lock``).
    Every record found in the log is copied onto the monitor; a missing
    record leaves the corresponding attribute untouched.

    :param not_found_warn:
        if ``True``, a warning (with traceback) is logged for a missing
        log file and for each missing record. If ``False``, these cases
        are skipped silently.
        Default: ``False``.
    :return:
        ``None``.
    """

    def _warn_missing(key):
        # Called from inside an ``except`` block, so ``exc_info=True``
        # attaches the KeyError currently being handled.
        if not_found_warn:
            root_logger.warning(f'No record found for `{key}`', exc_info=True)

    self.current_run = os.path.basename(self.current_folder)

    try:
        # Both locks are released in ``finally``: a missing log file is an
        # expected situation here, and leaving ``plock`` held after the
        # exception would block every later acquirer of ``plock`` forever.
        plock.acquire()
        try:
            lock.acquire_read()
            try:
                log = self.read_log()
            finally:
                lock.release_read()
        finally:
            plock.release()
    except FileNotFoundError:
        if not_found_warn:
            root_logger.warning(f'`{self._log_file}` not found in `{self.file_folder}`', exc_info=True)
        return

    try:
        self.num_stats = log['num']
    except KeyError:
        _warn_missing('num')

    try:
        # NOTE(review): this assigns to `num_stats` a second time, replacing
        # the value loaded from log['num'] above. It looks like a copy-paste
        # slip for a separate matrix-statistics attribute -- confirm against
        # the class definition before changing the target.
        self.num_stats = log['mat']
    except KeyError:
        _warn_missing('mat')

    try:
        self.hist_stats = log['hist']
    except KeyError:
        _warn_missing('hist')

    # A value of `num_iters` supplied by the caller takes precedence over
    # the one stored in the log.
    if self.num_iters is None:
        try:
            self.num_iters = log['num_iters']
        except KeyError:
            _warn_missing('num_iters')

    try:
        self.iter = log['iter']
    except KeyError:
        _warn_missing('iter')

    try:
        self.epoch = log['epoch']
    except KeyError:
        if self.num_iters:
            # No stored epoch: derive it from the iteration counter.
            self.epoch = self.iter // self.num_iters
        else:
            _warn_missing('epoch')

def _get_new_folder(self, path):
runs = [folder for folder in os.listdir(path) if folder.startswith(self.prefix)]
Expand Down Expand Up @@ -1251,6 +1266,7 @@ def _flush(self):
# scatter point set(s)
self._scatter(points)

plock.acquire()
lock.acquire_write()
with open(os.path.join(self.file_folder, self._log_file), 'wb') as f:
dump_dict = {
Expand All @@ -1264,6 +1280,7 @@ def _flush(self):
pkl.dump(dump_dict, f, pkl.HIGHEST_PROTOCOL)
f.close()
lock.release_write()
plock.release()

iter_show = 'Epoch {} Iteration {}/{} ({:.2f}%)'.format(
epoch + 1, it % self.num_iters, self.num_iters,
Expand Down

0 comments on commit 123ff28

Please sign in to comment.