Skip to content

Commit

Permalink
Merge pull request #103 from rizar/as_dict_for_epochs
Browse files Browse the repository at this point in the history
Rename epochs to iterate_epochs and add as_dict argument
  • Loading branch information
rizar committed Jan 16, 2015
2 parents d40cce1 + 22450a0 commit 862ea8a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions blocks/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,7 @@ def get_epoch_iterator(self, as_dict=False):
if self.iteration_scheme else None,
as_dict=as_dict)

@property
def epochs(self):
def iterate_epochs(self, as_dict=False):
"""Allow iteration through all epochs.
Notes
Expand All @@ -256,7 +255,7 @@ def epochs(self):
"""
while True:
yield self.get_epoch_iterator()
yield self.get_epoch_iterator(as_dict=as_dict)


class DataStream(AbstractDataStream):
Expand Down
8 changes: 4 additions & 4 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def test_dataset():
assert list(epoch) == list(zip(data))

# Check if iterating over multiple epochs works
for i, epoch in zip(range(2), stream.epochs):
for i, epoch in zip(range(2), stream.iterate_epochs()):
assert list(epoch) == list(zip(data))
for i, epoch in enumerate(stream.epochs):
for i, epoch in enumerate(stream.iterate_epochs()):
assert list(epoch) == list(zip(data))
if i == 1:
break
Expand Down Expand Up @@ -82,7 +82,7 @@ def get_data(self, state, request):
assert list(stream.get_epoch_iterator()) == epochs[0]

stream.reset()
for i, epoch in zip(range(2), stream.epochs):
for i, epoch in zip(range(2), stream.iterate_epochs()):
assert list(epoch) == epochs[i]

# test scheme reseting between epochs
Expand All @@ -95,5 +95,5 @@ def get_request_iterator(self):
epochs.append([([1],), ([2, 3],), ([4],)])
epochs.append([([5],), ([6, 7],), ([8],)])
stream = DataStream(TestDataset(), iteration_scheme=TestScheme())
for i, epoch in zip(range(2), stream.epochs):
for i, epoch in zip(range(2), stream.iterate_epochs()):
assert list(epoch) == epochs[i]

0 comments on commit 862ea8a

Please sign in to comment.