diff --git a/dataset_iterator/datasetIO/memoryIO.py b/dataset_iterator/datasetIO/memoryIO.py
index 43555f0..dee30dc 100644
--- a/dataset_iterator/datasetIO/memoryIO.py
+++ b/dataset_iterator/datasetIO/memoryIO.py
@@ -1,16 +1,64 @@
 from .datasetIO import DatasetIO
 import threading
+import numpy as np
+from multiprocessing import managers, shared_memory, Value
+from ..shared_memory import to_shm, get_idx_from_shm
+
+_MEMORYIO_SHM_MANAGER = {}
+_MEMORYIO_UID = None
 
 class MemoryIO(DatasetIO):
-    def __init__(self, datasetIO:DatasetIO):
+    def __init__(self, datasetIO: DatasetIO, use_shm: bool = True):
         super().__init__()
         self.datasetIO = datasetIO
         self.__lock__ = threading.Lock()
-        self.datasets=dict()
+        self.datasets = dict()
+        self.use_shm = use_shm
+        global _MEMORYIO_UID
+        if _MEMORYIO_UID is None:
+            try:
+                _MEMORYIO_UID = Value("i", 0)
+            except OSError: # the OS does not allow multiprocessing primitives here; fall back to a plain int for indexing
+                _MEMORYIO_UID = 0
+
+        if isinstance(_MEMORYIO_UID, int):
+            self.uid = _MEMORYIO_UID
+            _MEMORYIO_UID += 1
+        else:
+            # multiprocessing.Value += x is not process-safe, so read and increment under the lock.
+            with _MEMORYIO_UID.get_lock():
+                self.uid = _MEMORYIO_UID.value
+                _MEMORYIO_UID.value += 1
+        if use_shm:
+            self._start_shm_manager()
+
+    def _start_shm_manager(self):
+        global _MEMORYIO_SHM_MANAGER
+        _MEMORYIO_SHM_MANAGER[self.uid] = managers.SharedMemoryManager()
+        _MEMORYIO_SHM_MANAGER[self.uid].start()
+        self.shm_manager_on = True
+
+    def _stop_shm_manager(self):
+        global _MEMORYIO_SHM_MANAGER
+        if _MEMORYIO_SHM_MANAGER[self.uid] is not None:
+            _MEMORYIO_SHM_MANAGER[self.uid].shutdown()
+            _MEMORYIO_SHM_MANAGER[self.uid].join()
+            _MEMORYIO_SHM_MANAGER[self.uid] = None
+        self.shm_manager_on = False
+
+    def _to_shm(self, array):
+        global _MEMORYIO_SHM_MANAGER
+        shapes, dtypes, shm_name, _ = to_shm(_MEMORYIO_SHM_MANAGER[self.uid], array)
+        return shapes[0], dtypes[0], shm_name
 
     def close(self):
+        if self.use_shm:
+            for shma in self.datasets.values():
+                shma.unlink()
         self.datasets.clear()
         self.datasetIO.close()
+        if self.use_shm and self.shm_manager_on:
+            self._stop_shm_manager()
 
     def get_dataset_paths(self, channel_keyword, group_keyword):
         return self.datasetIO.get_dataset_paths(channel_keyword, group_keyword)
@@ -18,8 +66,13 @@ def get_dataset_paths(self, channel_keyword, group_keyword):
     def get_dataset(self, path):
         if path not in self.datasets:
             with self.__lock__:
+                if self.use_shm and not self.shm_manager_on:
+                    self._start_shm_manager()
                 if path not in self.datasets:
-                    self.datasets[path] = self.datasetIO.get_dataset(path)[:] # load into memory
+                    if self.use_shm:
+                        self.datasets[path] = ShmArrayWrapper(*self._to_shm(self.datasetIO.get_dataset(path)[:]))
+                    else:
+                        self.datasets[path] = ArrayWrapper(self.datasetIO.get_dataset(path)[:]) # load into memory
         return self.datasets[path]
 
     def get_attribute(self, path, attribute_name):
@@ -36,3 +89,36 @@ def __contains__(self, key):
 
     def get_parent_path(self, path):
         self.datasetIO.get_parent_path(path)
+
+
+class ArrayWrapper:
+    def __init__(self, array):
+        self.array = array
+        self.shape = array.shape
+
+    def __getitem__(self, item):
+        return np.copy(self.array[item])
+
+    def __len__(self):
+        return self.shape[0]
+
+
+class ShmArrayWrapper:
+    def __init__(self, shape, dtype, shm_name):
+        self.shape = shape
+        self.dtype = dtype
+        self.shm_name = shm_name
+
+    def __getitem__(self, item):
+        assert isinstance(item, (int, np.integer)), f"only integer indexing is supported, received: {item} of type: {type(item)}"
+        return get_idx_from_shm(item, (self.shape,), (self.dtype,), self.shm_name, array_idx=0)
+
+    def __len__(self):
+        return self.shape[0]
+
+    def unlink(self):
+        try:
+            existing_shm = shared_memory.SharedMemory(self.shm_name)
+            existing_shm.unlink()
+        except Exception:
+            pass
\ No newline at end of file
diff --git a/dataset_iterator/hard_sample_mining.py b/dataset_iterator/hard_sample_mining.py
index f70eadb..368c891 100644
--- a/dataset_iterator/hard_sample_mining.py
+++ b/dataset_iterator/hard_sample_mining.py
@@ -40,11 +40,12 @@ def close(self):
         self.iterator.close()
 
     def on_epoch_begin(self, epoch, logs=None):
-        self.wait_for_me.clear() # will block
+        if self.proba_per_metric is not None:
+            self.wait_for_me.clear() # will block
 
     def on_epoch_end(self, epoch, logs=None):
-        if self.period==1 or (epoch + 1 + self.start_epoch) % self.period == 0:
-            if (epoch > 0 or not self.skip_first) and epoch + self.start_epoch >= self.start_from_epoch:
+        if self.period == 1 or (epoch + 1 + self.start_epoch) % self.period == 0:
+            if (epoch > 0 or not self.skip_first) and epoch + 1 + self.start_epoch >= self.start_from_epoch:
                 self.target_iterator.close()
                 self.iterator.open()
                 metrics = self.compute_metrics()
@@ -63,7 +64,7 @@ def on_epoch_end(self, epoch, logs=None):
                 proba = self.proba_per_metric
                 # set the probability on the iterator; with multiprocessing (OrderedEnqueuer), this is only taken into account at the next epoch, as the iterator has already been sent to the worker processes at this stage
                 self.target_iterator.set_index_probability(proba)
-                self.wait_for_me.set() # release block
+        self.wait_for_me.set() # release block
 
     def on_train_end(self, logs=None):
         self.close()
diff --git a/setup.py b/setup.py
index dc9bad2..f34f9cd 100644
--- a/setup.py
+++ b/setup.py
@@ -5,14 +5,14 @@
 
 setuptools.setup(
     name="dataset_iterator",
-    version="0.4.1",
+    version="0.4.2",
     author="Jean Ollion",
     author_email="jean.ollion@polytechnique.org",
     description="Keras-style data iterator for images contained in dataset files such as hdf5 or PIL readable files. Images can be contained in several files.",
     long_description=long_description,
    long_description_content_type="text/markdown",
     url="https://github.com/jeanollion/dataset_iterator.git",
-    download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.1/dataset_iterator-0.4.1.tar.gz',
+    download_url='https://github.com/jeanollion/dataset_iterator/releases/download/v0.4.2/dataset_iterator-0.4.2.tar.gz',
     keywords=['Iterator', 'Dataset', 'Image', 'Numpy'],
     packages=setuptools.find_packages(),
     classifiers=[ #https://pypi.org/classifiers/
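Usage note (not part of the patch): a minimal sketch of how the new use_shm path in MemoryIO is expected to be exercised. `base_io` and the dataset path "/channel0" are hypothetical placeholders for a concrete DatasetIO backend and a real dataset path.

    from dataset_iterator.datasetIO.memoryIO import MemoryIO

    base_io = ...                              # hypothetical: any concrete DatasetIO implementation
    mem_io = MemoryIO(base_io, use_shm=True)   # starts a per-instance SharedMemoryManager
    dataset = mem_io.get_dataset("/channel0")  # first access copies the array into shared memory
    sample = dataset[0]                        # ShmArrayWrapper only accepts integer indices
    n_samples = len(dataset)                   # shape[0] of the cached array
    mem_io.close()                             # unlinks the shared-memory blocks and stops the manager

With use_shm=False the same calls go through ArrayWrapper instead, which keeps the array in process memory and returns a copy on each indexing call.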