Skip to content

Commit

Permalink
Merge pull request #113 from bartvm/constant_iterator
Browse files Browse the repository at this point in the history
Re-introduce constant iterator
  • Loading branch information
bartvm committed Jan 18, 2015
2 parents 3f4d5e4 + f719746 commit 752b30f
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions blocks/datasets/schemes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
from abc import ABCMeta, abstractmethod

import six
Expand Down Expand Up @@ -70,10 +69,26 @@ def __init__(self, batch_size, times=None):
self.times = times

def get_request_iterator(self):
if self.times is None:
return itertools.repeat(self.batch_size)
else:
return itertools.repeat(self.batch_size, self.times)
return ConstantIterator(self.batch_size, self.times)


class ConstantIterator(six.Iterator):
def __init__(self, batch_size, times=None):
self.batch_size = batch_size
self.times = times
if times is not None:
self.current = 0

def __iter__(self):
return self

def __next__(self):
if self.times is not None:
if self.current == self.times:
raise StopIteration
else:
self.current += 1
return self.batch_size


class SequentialScheme(BatchScheme):
Expand Down

0 comments on commit 752b30f

Please sign in to comment.