Skip to content

Commit

Permalink
Merge pull request #95 from rizar/next_epoch
Browse files Browse the repository at this point in the history
Reintroducing the `next_epoch` method
  • Loading branch information
bartvm committed Jan 14, 2015
2 parents 3a0c945 + e4fee30 commit ab8c6b6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 11 deletions.
41 changes: 31 additions & 10 deletions blocks/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def reset(self, state):
self.close(state)
return self.open()

def next_epoch(self, state):
"""Switches the dataset state to the next epoch.
The default implementation for this method is to reset the state.
Returns
-------
state : object
The state for the next epoch.
"""
return self.reset(state)

def close(self, state):
"""Cleanly close the dataset e.g. close file handles."""
pass
Expand Down Expand Up @@ -215,6 +228,11 @@ def close(self):
"""Gracefully close the data stream, e.g. releasing file handles."""
pass

@abstractmethod
def next_epoch(self):
"""Switch the data stream to the next epoch."""
pass

@abstractmethod
def get_epoch_iterator(self, as_dict=False):
return DataIterator(self, self.iteration_scheme.get_request_iterator()
Expand Down Expand Up @@ -255,6 +273,7 @@ def __init__(self, dataset, **kwargs):
super(DataStream, self).__init__(**kwargs)
self.dataset = dataset
self.data_state = self.dataset.open()
self._fresh_state = True

@property
def sources(self):
Expand All @@ -265,22 +284,21 @@ def close(self):

def reset(self):
self.data_state = self.dataset.reset(self.data_state)
self._fresh_state = True

def next_epoch(self):
self.data_state = self.dataset.next_epoch(self.data_state)

def get_data(self, request=None):
"""Get data from the dataset."""
return self.dataset.get_data(self.data_state, request)

def get_epoch_iterator(self, **kwargs):
"""Get an epoch iterator for the data stream.
Notes
-----
This also calls the data stream's :meth:`reset` method. If you have
a data stream where a single epoch doesn't iterate over the entire
data set, you should overwrite this method.
"""
self.reset()
"""Get an epoch iterator for the data stream."""
if not self._fresh_state:
self.next_epoch()
else:
self._fresh_state = False
return super(DataStream, self).get_epoch_iterator(**kwargs)


Expand All @@ -300,6 +318,9 @@ def close(self):
def reset(self):
self.data_stream.reset()

def next_epoch(self):
self.data_stream.next_epoch()

def get_epoch_iterator(self, **kwargs):
"""Get an epoch iterator for the wrapped data set.
Expand Down
58 changes: 57 additions & 1 deletion tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from six.moves import zip

from blocks.datasets import ContainerDataset, DataStreamMapping
from blocks.datasets import ContainerDataset, DataStream, DataStreamMapping
from blocks.datasets.schemes import ConstantScheme, BatchSizeScheme


def test_dataset():
Expand Down Expand Up @@ -41,3 +42,58 @@ def test_sources_selection():
stream = ContainerDataset({'features': features, 'targets': targets},
sources=('targets',)).get_default_stream()
assert list(stream.get_epoch_iterator()) == list(zip(targets))


def test_data_driven_epochs():

class TestDataset(ContainerDataset):
sources = ('data',)
default_scheme = ConstantScheme(1)

def __init__(self):
self.data = [[1, 2, 3, 4],
[5, 6, 7, 8]]

def open(self):
epoch_iter = iter(self.data)
data_iter = iter(next(epoch_iter))
return (epoch_iter, data_iter)

def next_epoch(self, state):
try:
data_iter = iter(next(state[0]))
return (state[0], data_iter)
except StopIteration:
return self.open()


def get_data(self, state, request):
data = []
for i in range(request):
data.append(next(state[1]))
return (data,)

epochs = []
epochs.append([([1],), ([2],), ([3],), ([4],)])
epochs.append([([5],), ([6],), ([7],), ([8],)])
stream = TestDataset().get_default_stream()
assert list(stream.get_epoch_iterator()) == epochs[0]
assert list(stream.get_epoch_iterator()) == epochs[1]
assert list(stream.get_epoch_iterator()) == epochs[0]

stream.reset()
for i, epoch in zip(range(2), stream.epochs):
assert list(epoch) == epochs[i]

# test scheme reseting between epochs
class TestScheme(BatchSizeScheme):

def get_request_iterator(self):
return iter([1, 2, 1, 3])

epochs = []
epochs.append([([1],), ([2, 3],), ([4],)])
epochs.append([([5],), ([6, 7],), ([8],)])
stream = DataStream(TestDataset(), iteration_scheme=TestScheme())
for i, epoch in zip(range(2), stream.epochs):
assert list(epoch) == epochs[i]

0 comments on commit ab8c6b6

Please sign in to comment.