Skip to content

Commit

Permalink
Merge d87ef88 into 0b34ed0
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Apr 12, 2018
2 parents 0b34ed0 + d87ef88 commit 679d126
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 33 deletions.
13 changes: 7 additions & 6 deletions chainer/training/extensions/_snapshot.py
Expand Up @@ -81,11 +81,12 @@ def snapshot(trainer):
def _snapshot_object(trainer, target, filename, savefun):
fn = filename.format(trainer)
prefix = 'tmp' + fn

tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
tmppath = os.path.join(tmpdir, fn)
fd, tmppath = tempfile.mkstemp(prefix=prefix, dir=trainer.out)
try:
savefun(tmppath, target)
shutil.move(tmppath, os.path.join(trainer.out, fn))
finally:
shutil.rmtree(tmpdir)
except Exception:
os.close(fd)
os.remove(tmppath)
raise
os.close(fd)
shutil.move(tmppath, os.path.join(trainer.out, fn))
@@ -1,4 +1,3 @@
import os
import unittest

import mock
Expand All @@ -22,30 +21,4 @@ def test_trigger(self):
self.assertEqual(snapshot.trigger, (1, 'epoch'))


class TestSnapshotSaveFile(unittest.TestCase):

def setUp(self):
self.trainer = testing.get_trainer_with_mock_updater()
self.trainer.out = '.'
self.trainer._done = True

def tearDown(self):
if os.path.exists('myfile.dat'):
os.remove('myfile.dat')

def test_save_file(self):
snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
snapshot(self.trainer)

self.assertTrue(os.path.exists('myfile.dat'))

def test_clean_up_tempdir(self):
snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
snapshot(self.trainer)

left_tmps = [fn for fn in os.listdir('.')
if fn.startswith('tmpmyfile.dat')]
self.assertEqual(len(left_tmps), 0)


testing.run_module(__name__, __file__)

0 comments on commit 679d126

Please sign in to comment.