Skip to content
This repository has been archived by the owner on May 7, 2020. It is now read-only.

Commit

Permalink
Add "preload" option to Polyps912 dataset (#3)
Browse files Browse the repository at this point in the history
* Add "preload" option to Polyps912 dataset
* Add tests for preloading
  • Loading branch information
lamblin authored and fvisin committed Jul 4, 2017
1 parent 46260ef commit 29451cb
Showing 1 changed file with 89 additions and 21 deletions.
110 changes: 89 additions & 21 deletions dataset_loaders/images/polyps912.py
Expand Up @@ -31,6 +31,9 @@ class Polyps912Dataset(ThreadedDataset):
which_set: string
A string in ['train', 'val', 'valid', 'test'], corresponding to
the set to be returned.
preload: bool
Whether to preload all the images in memory (as a list of
ndarrays) to minimize disk access.
References
----------
Expand Down Expand Up @@ -75,7 +78,7 @@ def filenames(self):
self._filenames = filenames
return self._filenames

def __init__(self, which_set='train', *args, **kwargs):
def __init__(self, which_set='train', preload=False, *args, **kwargs):

# Put which_set in canonical form: training, validation or testing
if which_set in ("train", "training"):
Expand All @@ -87,12 +90,57 @@ def __init__(self, which_set='train', *args, **kwargs):
else:
raise ValueError("Unknown set requested: %s" % which_set)

self.preload = preload

# Define the images and mask paths
self.image_path = os.path.join(self.path, self.which_set, 'images')
self.mask_path = os.path.join(self.path, self.which_set, 'masks2')

if self.preload:
self._preload_data()

super(Polyps912Dataset, self).__init__(*args, **kwargs)

def _load_image(self, image_batch, mask_batch, img_name, prefix=None):
    """Read one image/mask pair from disk and append it to the batches.

    The image is read from ``<self.image_path>/<img_name>.bmp``, cast to
    ``floatX`` and divided by 255. The mask is read from
    ``<self.mask_path>/<img_name>.tif`` and converted to ``int32``.

    Parameters
    ----------
    image_batch: list
        List that receives the image array (modified in place).
    mask_batch: list
        List that receives the mask array (modified in place).
    img_name: string
        Name of the image to load, without its file extension.
    prefix: string (optional)
        Accepted so callers can pass the sequence prefix; it is not
        used when loading.
    """
    from skimage import io

    # Image first, then mask, so a failing read leaves both lists untouched.
    image_file = os.path.join(self.image_path, img_name + ".bmp")
    image = io.imread(image_file).astype(floatX) / 255.

    mask_file = os.path.join(self.mask_path, img_name + ".tif")
    mask = np.array(io.imread(mask_file), dtype='int32')

    image_batch.append(image)
    mask_batch.append(mask)

def _preload_data(self):
    """Read every image/mask pair of the set into memory.

    After this call:

    * ``self.image_all`` is a list holding one image ndarray per file,
    * ``self.mask_all`` is a list holding the matching mask ndarrays,
    * ``self.image_name_to_idx`` is a dict mapping each image name to
      its position in those two lists.

    Both lists follow the order of ``self.filenames``.
    """
    images, masks, name_to_idx = [], [], {}

    # Publish the containers before loading, so they exist on the
    # instance even if a read fails part-way through.
    self.image_all = images
    self.mask_all = masks
    self.image_name_to_idx = name_to_idx

    for position, name in enumerate(self.filenames):
        self._load_image(images, masks, name)
        name_to_idx[name] = position

def get_names(self):
    """Return the file names grouped by prefix/subset.

    This dataset has a single group, stored under the key ``'default'``,
    whose value is ``self.filenames``.
    """
    names_by_subset = dict(default=self.filenames)
    return names_by_subset
Expand All @@ -106,22 +154,17 @@ def load_sequence(self, sequence):
labels, their subset (i.e. category, clip, prefix) and their
filenames.
"""
from skimage import io
image_batch, mask_batch, filename_batch = [], [], []

for prefix, img_name in sequence:

img = io.imread(os.path.join(self.image_path, img_name + ".bmp"))
img = img.astype(floatX) / 255.

mask = np.array(io.imread(os.path.join(self.mask_path,
img_name + ".tif")))
mask = mask.astype('int32')

# Add to minibatch
image_batch.append(img)
mask_batch.append(mask)
filename_batch.append(img_name)
if self.preload:
for prefix, img_name in sequence:
idx = self.image_name_to_idx[img_name]
image_batch.append(self.image_all[idx])
mask_batch.append(self.mask_all[idx])
filename_batch.append(img_name)
else:
for prefix, img_name in sequence:
self._load_image(image_batch, mask_batch, img_name, prefix)
filename_batch.append(img_name)

ret = {}
ret['data'] = np.array(image_batch)
Expand All @@ -131,7 +174,18 @@ def load_sequence(self, sequence):
return ret


def test():
def test(preload=False):
# Instrument Polyps912Dataset._load_image to count its calls
orig_load_image = Polyps912Dataset._load_image

def _load_image(*args, **kwargs):
_load_image.count += 1
return orig_load_image(*args, **kwargs)
_load_image.count = 0

Polyps912Dataset._load_image = _load_image

start_build = time.time()
trainiter = Polyps912Dataset(
which_set='train',
batch_size=10,
Expand All @@ -142,7 +196,8 @@ def test():
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=False)
use_threads=False,
preload=preload)

validiter = Polyps912Dataset(
which_set='valid',
Expand All @@ -152,7 +207,8 @@ def test():
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=False)
use_threads=False,
preload=preload)

testiter = Polyps912Dataset(
which_set='test',
Expand All @@ -162,7 +218,8 @@ def test():
return_one_hot=True,
return_01c=True,
return_list=True,
use_threads=False)
use_threads=False,
preload=preload)

# Get number of classes
nclasses = trainiter.nclasses
Expand Down Expand Up @@ -192,6 +249,9 @@ def test():
test_nsamples, test_batch_size, test_nbatches))

start = time.time()
print("Time to build{} the datasets: {}".format(
" and preload" if preload else "",
start - start_build))
tot = 0
max_epochs = 1

Expand Down Expand Up @@ -219,9 +279,17 @@ def test():
tot += part
print("Minibatch %s time: %s (%s)" % (str(mb), part, tot))

if preload:
expected_count = train_nsamples + valid_nsamples + test_nsamples
else:
expected_count = train_nsamples * max_epochs
assert _load_image.count == expected_count, (
_load_image.count, expected_count)


def run_tests():
test()
for preload in (False, True):
test(preload)


if __name__ == '__main__':
Expand Down

0 comments on commit 29451cb

Please sign in to comment.