Skip to content

Commit

Permalink
Merge pull request #4528 from niboshi/tempfile-permission-log-report
Browse files Browse the repository at this point in the history
Fix temporary file permission issue in LogReport
  • Loading branch information
okuta committed Mar 29, 2018
2 parents d3bc261 + e98eead commit 74f2042
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 30 deletions.
18 changes: 6 additions & 12 deletions chainer/dataset/download.py
Expand Up @@ -2,11 +2,12 @@
import os
import shutil
import sys
import tempfile

import filelock
from six.moves.urllib import request

from chainer import utils


_dataset_root = os.environ.get('CHAINER_DATASET_ROOT',
os.path.expanduser('~/.chainer/dataset'))
Expand Down Expand Up @@ -100,15 +101,12 @@ def cached_download(url):
if os.path.exists(cache_path):
return cache_path

temp_root = tempfile.mkdtemp(dir=cache_root)
try:
with utils.tempdir(dir=cache_root) as temp_root:
temp_path = os.path.join(temp_root, 'dl')
sys.stderr.write('Downloading from {}...\n'.format(url))
request.urlretrieve(url, temp_path)
with filelock.FileLock(lock_path):
shutil.move(temp_path, cache_path)
finally:
shutil.rmtree(temp_root)

return cache_path

Expand Down Expand Up @@ -140,10 +138,6 @@ def cache_or_load_file(path, creator, loader):
if os.path.exists(path):
return loader(path)

file_name = os.path.basename(path)
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, file_name)

try:
os.makedirs(_dataset_root)
except OSError:
Expand All @@ -152,12 +146,12 @@ def cache_or_load_file(path, creator, loader):

lock_path = os.path.join(_dataset_root, '_create_lock')

try:
with utils.tempdir() as temp_dir:
file_name = os.path.basename(path)
temp_path = os.path.join(temp_dir, file_name)
content = creator(temp_path)
with filelock.FileLock(lock_path):
if not os.path.exists(path):
shutil.move(temp_path, path)
finally:
shutil.rmtree(temp_dir)

return content
8 changes: 2 additions & 6 deletions chainer/testing/serializer.py
@@ -1,8 +1,7 @@
import os
import shutil
import tempfile

from chainer import serializers
from chainer import utils


def save_and_load(src, dst, filename, saver, loader):
Expand All @@ -21,13 +20,10 @@ def save_and_load(src, dst, filename, saver, loader):
object.
"""
tempdir = tempfile.mkdtemp()
try:
with utils.tempdir() as tempdir:
path = os.path.join(tempdir, filename)
saver(path, src)
loader(path, dst)
finally:
shutil.rmtree(tempdir)


def save_and_load_npz(src, dst):
Expand Down
9 changes: 3 additions & 6 deletions chainer/training/extensions/_snapshot.py
@@ -1,9 +1,9 @@
import os
import shutil
import tempfile

from chainer.serializers import npz
from chainer.training import extension
from chainer import utils


def snapshot_object(target, filename, savefun=npz.save_npz):
Expand Down Expand Up @@ -82,10 +82,7 @@ def _snapshot_object(trainer, target, filename, savefun):
fn = filename.format(trainer)
prefix = 'tmp' + fn

tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
tmppath = os.path.join(tmpdir, fn)
try:
with utils.tempdir(prefix=prefix, dir=trainer.out) as tmpdir:
tmppath = os.path.join(tmpdir, fn)
savefun(tmppath, target)
shutil.move(tmppath, os.path.join(trainer.out, fn))
finally:
shutil.rmtree(tmpdir)
13 changes: 7 additions & 6 deletions chainer/training/extensions/log_report.py
@@ -1,7 +1,6 @@
import json
import os
import shutil
import tempfile
import warnings

import six
Expand All @@ -10,6 +9,7 @@
from chainer import serializer as serializer_module
from chainer.training import extension
from chainer.training import trigger as trigger_module
from chainer import utils


class LogReport(extension.Extension):
Expand Down Expand Up @@ -96,12 +96,13 @@ def __call__(self, trainer):
# write to the log file
if self._log_name is not None:
log_name = self._log_name.format(**stats_cpu)
fd, path = tempfile.mkstemp(prefix=log_name, dir=trainer.out)
with os.fdopen(fd, 'w') as f:
json.dump(self._log, f, indent=4)
with utils.tempdir(prefix=log_name, dir=trainer.out) as tempd:
path = os.path.join(tempd, 'log.json')
with open(path, 'w') as f:
json.dump(self._log, f, indent=4)

new_path = os.path.join(trainer.out, log_name)
shutil.move(path, new_path)
new_path = os.path.join(trainer.out, log_name)
shutil.move(path, new_path)

# reset the summary for the next output
self._init_summary()
Expand Down
16 changes: 16 additions & 0 deletions chainer/utils/__init__.py
@@ -1,3 +1,7 @@
import contextlib
import shutil
import tempfile

import numpy

# import classes and functions
Expand Down Expand Up @@ -31,3 +35,15 @@ def force_type(dtype, value):
return value.astype(dtype, copy=False)
else:
return value


@contextlib.contextmanager
def tempdir(**kwargs):
# A context manager that defines a lifetime of a temporary directory.
ignore_errors = kwargs.pop('ignore_errors', False)

temp_dir = tempfile.mkdtemp(**kwargs)
try:
yield temp_dir
finally:
shutil.rmtree(temp_dir, ignore_errors=ignore_errors)

0 comments on commit 74f2042

Please sign in to comment.