From d7690ab50a93f88a1fe8dc7a16799f578ca7bc28 Mon Sep 17 00:00:00 2001 From: Alex Rothberg Date: Sun, 28 Dec 2014 22:27:48 -0500 Subject: [PATCH] Added forced_even to BatchIterator ensure equal batch size --- nolearn/lasagne.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nolearn/lasagne.py b/nolearn/lasagne.py index 5c2c64c..bb7ccdf 100644 --- a/nolearn/lasagne.py +++ b/nolearn/lasagne.py @@ -41,8 +41,9 @@ def negative_log_likelihood(output, prediction): class BatchIterator(object): - def __init__(self, batch_size): + def __init__(self, batch_size, forced_even=False): self.batch_size = batch_size + self.forced_even = forced_even def __call__(self, X, y=None, test=False): self.X, self.y = X, y @@ -55,6 +56,8 @@ def __iter__(self): for i in range((n_samples + bs - 1) / bs): sl = slice(i * bs, (i + 1) * bs) Xb = self.X[sl] + if self.forced_even and len(Xb) != bs: + continue if self.y is not None: yb = self.y[sl] else: