WIP: Minibatch adjustor #174

Closed
wants to merge 3 commits into from
55 changes: 54 additions & 1 deletion pylearn2/training_algorithms/sgd.py
@@ -597,12 +597,16 @@ def __call__(self, algorithm):
algorithm.learning_rate.set_value(new_lr)

class MomentumAdjustor(TrainExtension):
"""
Scales the momentum coefficient linearly, starting at epoch start and
saturating to final_momentum at epoch saturate.
"""
def __init__(self, final_momentum, start, saturate):
"""
final_momentum: the momentum coefficient to use at the end
of learning.
start: the epoch on which to start growing the momentum coefficient.
saturate: the epoch on which the moment should reach its final value
saturate: the epoch on which the momentum should reach its final value
"""

if saturate < start:
@@ -636,6 +640,55 @@ def current_momentum(self):
alpha = 1.
return self._init_momentum * (1.-alpha)+alpha*self.final_momentum

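For reference, both this adjustor and the new BatchSizeAdjustor below compute the same kind of linear ramp between a start and a saturate epoch. A minimal standalone sketch of that schedule (linear_schedule is a hypothetical helper written here only for illustration, not part of the diff):

def linear_schedule(epoch, start, saturate, init_value, final_value):
    # Linearly interpolate from init_value to final_value between the
    # start and saturate epochs, clamping outside that window.
    if saturate == start:
        return final_value if epoch >= start else init_value
    alpha = float(epoch - start) / float(saturate - start)
    alpha = min(max(alpha, 0.), 1.)
    return init_value * (1. - alpha) + alpha * final_value

# e.g. start=5, saturate=10, init_value=0.5, final_value=0.99:
# epoch 5 -> 0.5, epoch 7 -> 0.696, epoch 10 and later -> 0.99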
class BatchSizeAdjustor(TrainExtension):
"""
This class adjusts the minibatch size at each epoch, growing or decaying it
linearly from the algorithm's initial batch size toward final_batch_size.
"""
def __init__(self, final_batch_size, start, saturate):
"""
final_batch_size: the batch size to use at the end of learning.
start: the epoch on which to start adjusting the minibatch size.
saturate: the epoch on which the batch_size should reach its final value
"""
if saturate < start:
raise TypeError("Batch size can't saturate at its final value before it starts changing.")

self.__dict__.update(locals())
del self.self
self._initialized = False
self._count = 0

def on_monitor(self, model, dataset, algorithm):
if not self._initialized:
self._init_batch_size = algorithm.batch_size
if self.final_batch_size < self._init_batch_size:
print "Decaying minibatch size."
else:
print "Growing minibatch size."

self._initialized = True
self._count += 1
algorithm.batch_size = self.current_batch_size()

def current_batch_size(self):
w = self.saturate - self.start

if w == 0:
# saturate=start, so just jump straight to the final batch size
if self._count >= self.start:
return self.final_batch_size
return self._init_batch_size

alpha = float(self._count - self.start) / float(w)
if alpha < 0.:
alpha = 0.
if alpha > 1.:
alpha = 1.
# Interpolate linearly and round down to an integer batch size.
batch_size = int(np.floor(self._init_batch_size * (1. - alpha) + alpha * self.final_batch_size))
print "Current minibatch size: ", batch_size
return batch_size

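For illustration, here is a hypothetical usage sketch (not part of the diff) of how the new extension might be attached to a pylearn2 Train object. The model and train_set objects are placeholders, and the SGD/Train constructor arguments are abbreviated, so the exact call may differ from the real API:

from pylearn2.train import Train
from pylearn2.training_algorithms.sgd import SGD, BatchSizeAdjustor

# model and train_set are assumed to be an already-built pylearn2 model and
# dataset; other required SGD/Train arguments are omitted for brevity.
algorithm = SGD(learning_rate=0.01, batch_size=32)
extensions = [BatchSizeAdjustor(final_batch_size=256, start=2, saturate=10)]
train = Train(dataset=train_set, model=model, algorithm=algorithm,
              extensions=extensions)
train.main_loop()

With these numbers the batch size stays at 32 until epoch 2, then grows linearly and reaches 256 at epoch 10.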
class OneOverEpoch(TrainExtension):
"""
Scales the learning rate like one over # epochs