Skip to content

Commit

Permalink
Merge pull request #4524 from niboshi/bp-4461-master
Browse files Browse the repository at this point in the history
[backport] Save snapshot in the OS default permission
  • Loading branch information
kmaehashi committed Mar 27, 2018
2 parents f659c20 + e4e32da commit cb74ad5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
13 changes: 6 additions & 7 deletions chainer/training/extensions/_snapshot.py
Expand Up @@ -81,12 +81,11 @@ def snapshot(trainer):
def _snapshot_object(trainer, target, filename, savefun):
fn = filename.format(trainer)
prefix = 'tmp' + fn
fd, tmppath = tempfile.mkstemp(prefix=prefix, dir=trainer.out)

tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
tmppath = os.path.join(tmpdir, fn)
try:
savefun(tmppath, target)
except Exception:
os.close(fd)
os.remove(tmppath)
raise
os.close(fd)
shutil.move(tmppath, os.path.join(trainer.out, fn))
shutil.move(tmppath, os.path.join(trainer.out, fn))
finally:
shutil.rmtree(tmpdir)
@@ -1,3 +1,4 @@
import os
import unittest

import mock
Expand All @@ -21,4 +22,30 @@ def test_trigger(self):
self.assertEqual(snapshot.trigger, (1, 'epoch'))


class TestSnapshotSaveFile(unittest.TestCase):

def setUp(self):
self.trainer = testing.get_trainer_with_mock_updater()
self.trainer.out = '.'
self.trainer._done = True

def tearDown(self):
if os.path.exists('myfile.dat'):
os.remove('myfile.dat')

def test_save_file(self):
snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
snapshot(self.trainer)

self.assertTrue(os.path.exists('myfile.dat'))

def test_clean_up_tempdir(self):
snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
snapshot(self.trainer)

left_tmps = [fn for fn in os.listdir('.')
if fn.startswith('tmpmyfile.dat')]
self.assertEqual(len(left_tmps), 0)


testing.run_module(__name__, __file__)

0 comments on commit cb74ad5

Please sign in to comment.