diff --git a/nolearn/lasagne.py b/nolearn/lasagne.py index 5c2c64c..bb7ccdf 100644 --- a/nolearn/lasagne.py +++ b/nolearn/lasagne.py @@ -41,8 +41,9 @@ def negative_log_likelihood(output, prediction): class BatchIterator(object): - def __init__(self, batch_size): + def __init__(self, batch_size, forced_even=False): self.batch_size = batch_size + self.forced_even = forced_even def __call__(self, X, y=None, test=False): self.X, self.y = X, y @@ -55,6 +56,8 @@ def __iter__(self): for i in range((n_samples + bs - 1) / bs): sl = slice(i * bs, (i + 1) * bs) Xb = self.X[sl] + if self.forced_even and len(Xb) != bs: + continue if self.y is not None: yb = self.y[sl] else: