ENH: same function to dump/load zipped + unzipped
We carry around the file handle to avoid race conditions with open files
on the disk.
GaelVaroquaux committed Dec 22, 2011
1 parent 3a42862 commit dbcf1e6
Showing 3 changed files with 107 additions and 59 deletions.
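For context, a minimal sketch of the unified API this commit introduces (the paths are illustrative):

    import numpy as np
    from joblib import dump, load

    obj = {'a': np.arange(10)}
    dump(obj, '/tmp/obj.pkl')               # uncompressed: one .npy file per array
    dump(obj, '/tmp/obj.pkl', zipped=True)  # compressed: a single zip file
    restored = load('/tmp/obj.pkl')         # load() detects the format from magic bytes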
2 changes: 0 additions & 2 deletions joblib/__init__.py
@@ -105,8 +105,6 @@
from .hashing import hash
from .numpy_pickle import dump
from .numpy_pickle import load
from .numpy_pickle import dumpz
from .numpy_pickle import loadz
from .parallel import Parallel
from .parallel import delayed
from .parallel import cpu_count
146 changes: 98 additions & 48 deletions joblib/numpy_pickle.py
@@ -13,11 +13,13 @@
import shutil
import tempfile
import zipfile
import cStringIO
import warnings

if sys.version_info[0] == 3:
    from pickle import _Unpickler as Unpickler
    from io import BytesIO
else:
    from cStringIO import StringIO as BytesIO
    from pickle import Unpickler


@@ -82,14 +84,15 @@ class NumpyUnpickler(Unpickler):
"""
dispatch = Unpickler.dispatch.copy()

def __init__(self, filename, mmap_mode=None):
def __init__(self, filename, file_handle=None, mmap_mode=None):
self._filename = os.path.basename(filename)
self.mmap_mode = mmap_mode
self._dirname = os.path.dirname(filename)
file_handle = self._open_file(self._filename)
if isinstance(file_handle, basestring):
# To handle memmap, we need to have file names
file_handle = open(file_handle, 'rb')
if file_handle is None:
file_handle = self._open_file(self._filename)
if isinstance(file_handle, basestring):
# To handle memmap, we need to have file names
file_handle = open(file_handle, 'rb')
self.file_handle = file_handle
Unpickler.__init__(self, self.file_handle)
import numpy as np
@@ -129,33 +132,59 @@ class ZipNumpyUnpickler(NumpyUnpickler):
""" A subclass of our Unpickler to unpickle on the fly from zips.
"""

def __init__(self, filename):
kwargs = dict(compression=zipfile.ZIP_DEFLATED, mode='r')
def __init__(self, file_handle):
kwargs = dict(compression=zipfile.ZIP_DEFLATED)
if sys.version_info >= (2, 5):
kwargs['allowZip64'] = True
self._zip_file = zipfile.ZipFile(filename, **kwargs)
self._zip_file = zipfile.ZipFile(file_handle, **kwargs)
NumpyUnpickler.__init__(self, 'joblib_dump.pkl',
mmap_mode=None)

def _open_file(self, name):
"Return the path of the given file in our store"
decompression_buffer = cStringIO.StringIO(
decompression_buffer = BytesIO(
self._zip_file.read(os.path.join('dump_file', name)))
return decompression_buffer
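As an aside, a standalone sketch of the buffer trick used in _open_file above (the archive and member names are illustrative):

    import io
    import zipfile

    # Read one member of a zip archive into an in-memory buffer that can
    # then be handed to the unpickler like an ordinary open file.
    with zipfile.ZipFile('archive.zip', 'r') as zip_file:
        buffer = io.BytesIO(zip_file.read('dump_file/joblib_dump.pkl'))
    payload = buffer.read()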


###############################################################################
# Utility functions

def dump(value, filename):
def dump(value, filename, zipped=False):
""" Persist an arbitrary Python object into a filename, with numpy arrays
saved as separate .npy files.
Parameters
-----------
value: any Python object
The object to store to disk
filename: string
The name of the file in which it is to be stored
zipped: boolean, optional
Whether to compress the data on the disk or not
Returns
-------
filenames: list of strings
The list of file names in which the data is stored. If zipped
is false, each array is stored in a different file.
See Also
--------
joblib.load : corresponding loader
Notes
-----
Zipped files take extra disk space during the dump and extra
memory during the load.
"""
if zipped:
return _dump_zipped(value, filename)
else:
return _dump(value, filename)
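A hedged usage sketch of the two code paths (the exact .npy file names are joblib internals):

    import numpy as np
    from joblib import dump

    files = dump({'a': np.arange(3)}, '/tmp/obj.pkl')
    # uncompressed: the pickle file plus one .npy file per array
    files = dump({'a': np.arange(3)}, '/tmp/obj.pkl', zipped=True)
    # zipped: a single-element list, ['/tmp/obj.pkl']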


def _dump(value, filename):
try:
pickler = NumpyPickler(filename)
pickler.dump(value)
@@ -166,37 +195,9 @@ def dump(value, filename):
return pickler._filenames


def load(filename, mmap_mode=None):
""" Reconstruct a Python object and the numpy arrays it contains from
a persisted file.
This function loads the numpy array files saved separately. If
the mmap_mode argument is given, it is passed to np.load and
arrays are loaded as memmaps. As a consequence, the reconstructed
object might not match the original pickled object.
See Also
--------
joblib.dump : function to save an object
joblib.loadz: load from a compressed file dump
"""
try:
unpickler = NumpyUnpickler(filename, mmap_mode=mmap_mode)
obj = unpickler.load()
finally:
if 'unpickler' in locals() and hasattr(unpickler, 'file'):
unpickler.file.close()
return obj


def dumpz(value, filename):
def _dump_zipped(value, filename):
""" Persist an arbitrary Python object into a compressed zip
filename.
See Also
--------
joblib.loadz : corresponding loader
joblib.dump : function to save an object
"""
kwargs = dict(compression=zipfile.ZIP_DEFLATED, mode='w')
if sys.version_info >= (2, 5):
@@ -207,7 +208,7 @@ def dumpz(value, filename):
tmp_dir = tempfile.mkdtemp(prefix='joblib-',
dir=os.path.dirname(filename))
try:
dump(value, os.path.join(tmp_dir, 'joblib_dump.pkl'))
_dump(value, os.path.join(tmp_dir, 'joblib_dump.pkl'))
for sub_file in os.listdir(tmp_dir):
# We use a different arcname (archive name) to avoid having
# the name of our tmp_dir in the archive
@@ -220,21 +221,70 @@ def dumpz(value, filename):
return [filename]
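For illustration, a self-contained sketch of the archive-building step above (the helper name is hypothetical; joblib does this inline):

    import os
    import zipfile

    def build_archive(tmp_dir, archive_name):
        # Store each dumped file under a fixed 'dump_file/' prefix so that
        # the random tmp_dir name never leaks into the archive.
        with zipfile.ZipFile(archive_name, 'w',
                             compression=zipfile.ZIP_DEFLATED,
                             allowZip64=True) as zip_file:
            for sub_file in os.listdir(tmp_dir):
                zip_file.write(os.path.join(tmp_dir, sub_file),
                               arcname=os.path.join('dump_file', sub_file))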


def loadz(filename):
def load(filename, mmap_mode=None):
""" Reconstruct a Python object and the numpy arrays it contains from
a persisted zipped dump.
a persisted file.
Parameters
-----------
filename: string
The name of the file from which to load the object
mmap_mode: {None, 'r+', 'r', 'w+', 'c'}, optional
If not None, the arrays are memory-mapped from the disk. This
mode has no effect for zipped files. Note that in this
case the reconstructed object might no longer match exactly
the originally pickled object.
Returns
-------
result: any Python object
The object stored in the file.
See Also
--------
joblib.dump : function to save an object
Notes
-----
This function loads the numpy array files saved separately. If
the mmap_mode argument is given, it is passed to np.load and
arrays are loaded as memmaps. As a consequence, the reconstructed
object might not match the original pickled object.
"""
# Code to detect zip files
_ZIP_PREFIX = 'PK\x03\x04'
try:
# Py3k compatibility
from numpy.compat import asbytes
_ZIP_PREFIX = asbytes(_ZIP_PREFIX)
except ImportError:
pass

file_handle = open(filename, 'rb')
if file_handle.read(len(_ZIP_PREFIX)) == _ZIP_PREFIX:
if mmap_mode is not None:
warnings.warn('file "%(filename)s" appears to be a zip, '
'ignoring mmap_mode "%(mmap_mode)s" flag passed'
% locals(),
Warning, stacklevel=2)
unpickler = ZipNumpyUnpickler(file_handle=file_handle)
else:
# Pickling needs file-handles at the beginning of the file
file_handle.seek(0)
unpickler = NumpyUnpickler(filename,
file_handle=file_handle,
mmap_mode=mmap_mode)

try:
unpickler = ZipNumpyUnpickler(filename)
obj = unpickler.load()
finally:
if 'unpickler' in locals() and hasattr(unpickler, 'file'):
unpickler._zip_file.close()
if 'unpickler' in locals():
if hasattr(unpickler, 'file'):
unpickler.file.close()
if hasattr(unpickler, '_zip_file'):
unpickler._zip_file.close()
return obj
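The format detection above boils down to a four-byte magic-number check; a minimal standalone version (the helper name is hypothetical):

    _ZIP_PREFIX = b'PK\x03\x04'  # magic bytes at the start of every zip file

    def is_zip(filename):
        with open(filename, 'rb') as file_handle:
            return file_handle.read(len(_ZIP_PREFIX)) == _ZIP_PREFIX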


18 changes: 9 additions & 9 deletions joblib/test/test_numpy_pickle.py
@@ -113,11 +113,10 @@ def test_standard_types():
#""" Test pickling and saving with standard types.
#"""
filename = env['filename']
for dump, load in [(numpy_pickle.dump, numpy_pickle.load),
(numpy_pickle.dumpz, numpy_pickle.loadz)]:
for zipped in [True, False]:
for member in typelist:
dump(member, filename)
_member = load(filename)
numpy_pickle.dump(member, filename, zipped=zipped)
_member = numpy_pickle.load(filename)
# We compare the pickled instance to the reloaded one only if it
# can be compared to a copied one
if member == copy.deepcopy(member):
@@ -128,20 +127,21 @@ def test_standard_types():
def test_numpy_persistence():
filename = env['filename']
a = np.random.random(10)
for dump, load in [(numpy_pickle.dump, numpy_pickle.load),
(numpy_pickle.dumpz, numpy_pickle.loadz)]:
for zipped in [True, False]:
for obj in (a,), (a, a), [a, a, a]:
filenames = dump(obj, filename)
if dump is numpy_pickle.dump:
filenames = numpy_pickle.dump(obj, filename, zipped=zipped)
if not zipped:
# Check that one file was created per array
yield nose.tools.assert_equal, len(filenames), len(obj) + 1
# Check that these files do exist
for file in filenames:
yield nose.tools.assert_true, \
os.path.exists(os.path.join(env['dir'], file))
else:
yield nose.tools.assert_equal, len(filenames), 1

# Unpickle the object
obj_ = load(filename)
obj_ = numpy_pickle.load(filename)
# Check that the items are indeed arrays
for item in obj_:
yield nose.tools.assert_true, isinstance(item, np.ndarray)
