Keras update after single batch which exceeds the GPU memory #3556
Comments
I think it's solved by BVLC/caffe#1663.
I solved this by changing optimizer.py.
@wx405557858 I'm curious how you did this. I hacked something together that seems to work, but I'd be interested in a better way. It might also be useful to have this in Keras. Here is how I did it:

```python
from keras import backend as K
from keras.optimizers import Optimizer


class NadamAccum(Optimizer):
    '''
    Nesterov Adam optimizer: Much like Adam is essentially RMSprop with momentum,
    Nadam is Adam RMSprop with Nesterov momentum.

    Default parameters follow those provided in the paper.
    It is recommended to leave the parameters of this optimizer
    at their default values.

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor.

    # References
        - [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
        - [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
    '''
    def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-8, schedule_decay=0.004, accum_iters=1, **kwargs):
        super(NadamAccum, self).__init__(**kwargs)
        self.__dict__.update(locals())
        self.iterations = K.variable(0.)
        self.m_schedule = K.variable(1.)
        self.lr = K.variable(lr)
        self.beta_1 = K.variable(beta_1)
        self.beta_2 = K.variable(beta_2)
        self.schedule_decay = schedule_decay
        self.accum_iters = K.variable(accum_iters)

    def get_updates(self, params, constraints, loss):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        t = (self.iterations + 1.)/self.accum_iters
        accum_switch = K.floor((self.accum_iters - K.mod(self.iterations + 1., self.accum_iters))/self.accum_iters)

        # Due to the recommendations in [2], i.e. warming momentum schedule
        momentum_cache_t = self.beta_1 * (1. - 0.5 * (K.pow(0.96, t * self.schedule_decay)))
        momentum_cache_t_1 = self.beta_1 * (1. - 0.5 * (K.pow(0.96, (t + 1) * self.schedule_decay)))
        m_schedule_new = self.m_schedule * momentum_cache_t
        m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
        self.updates.append((self.m_schedule, accum_switch*m_schedule_new + (1-accum_switch)*self.m_schedule))

        shapes = [x.shape for x in K.batch_get_value(params)]
        ms = [K.zeros(shape) for shape in shapes]
        vs = [K.zeros(shape) for shape in shapes]
        gs = [K.zeros(shape) for shape in shapes]

        self.weights = [self.iterations] + ms + vs

        for p, gp, m, v, ga in zip(params, grads, ms, vs, gs):
            g = (ga + gp)/self.accum_iters
            # the following equations given in [1]
            g_prime = g / (1. - m_schedule_new)
            m_t = self.beta_1 * m + (1. - self.beta_1) * g
            m_t_prime = m_t / (1. - m_schedule_next)
            v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
            v_t_prime = v_t / (1. - K.pow(self.beta_2, t))
            m_t_bar = (1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

            self.updates.append(K.update(m, (1-accum_switch)*m + accum_switch*m_t))
            self.updates.append(K.update(v, (1-accum_switch)*v + accum_switch*v_t))
            self.updates.append(K.update(ga, (1-accum_switch)*(ga + gp)))

            p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
            new_p = p_t

            # apply constraints
            if p in constraints:
                c = constraints[p]
                new_p = c(new_p)
            self.updates.append(K.update(p, (1-accum_switch)*p + accum_switch*new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'epsilon': self.epsilon,
                  'schedule_decay': self.schedule_decay,
                  'accum_iters': self.accum_iters}
        base_config = super(NadamAccum, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
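The `accum_switch` expression above packs the "is this the last batch of the window?" test into floor/mod arithmetic. As a quick sanity check, here is a pure-Python replica (assuming `K.floor` and `K.mod` behave like `math.floor` and Python's `%` for the non-negative values involved; `nadam_accum_switch` is just an illustrative name), compared against the more explicit `(iterations + 1) % accum_iters == 0` test:

```python
import math


def nadam_accum_switch(iterations, accum_iters):
    """Replica of floor((accum_iters - mod(iterations + 1, accum_iters)) / accum_iters)."""
    return math.floor((accum_iters - (iterations + 1.0) % accum_iters) / accum_iters)


# The floor/mod form should agree with the explicit modulo test everywhere.
for accum_iters in (1, 2, 4, 5):
    for iterations in range(20):
        explicit = 1 if (iterations + 1) % accum_iters == 0 else 0
        assert nadam_accum_switch(iterations, accum_iters) == explicit

print([nadam_accum_switch(it, 4) for it in range(8)])  # [0, 0, 0, 1, 0, 0, 0, 1]
```

So the parameters, `m`, `v` and `m_schedule` are only written on every accum_iters-th batch; on the other batches the blend `(1 - accum_switch) * x + accum_switch * x_new` leaves them unchanged.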
@the-moliver Yeah, we did exactly the same! I have a flag calculated as `(self.iterations % accum_iters) == 0`, which becomes 1 once every accum_iters batches. Maybe we could write a wrapper that wraps any optimizer and changes its updates based on accum_iters, or just implement an `_accum` version of each optimizer; there are only a few optimizers.

```python
class Adam_accumulate(Optimizer):
    '''Adam accumulate optimizer.

    Default parameters follow those provided in the original paper. Wait for several mini-batch to update

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor.

    # References
        - [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
    '''
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-8, accum_iters=5, **kwargs):
        super(Adam_accumulate, self).__init__(**kwargs)
        self.__dict__.update(locals())
        self.iterations = K.variable(0)
        self.lr = K.variable(lr)
        self.beta_1 = K.variable(beta_1)
        self.beta_2 = K.variable(beta_2)
        self.accum_iters = K.variable(accum_iters)

    def get_updates(self, params, constraints, loss):
        grads = self.get_gradients(loss, params)
        self.updates = [(self.iterations, self.iterations + 1)]

        t = self.iterations + 1
        print t.eval()
        lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

        ms = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        vs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        gs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        self.weights = ms + vs

        for p, g, m, v, gg in zip(params, grads, ms, vs, gs):

            flag = K.equal(self.iterations % self.accum_iters, 0)

            gg_t = (1 - flag) * (gg + g)
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / self.accum_iters
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / self.accum_iters)
            p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append((m, flag * m_t + (1 - flag) * m))
            self.updates.append((v, flag * v_t + (1 - flag) * m))
            self.updates.append((gg, gg_t))

            new_p = p_t
            # apply constraints
            if p in constraints:
                c = constraints[p]
                new_p = c(new_p)
            self.updates.append((p, new_p))
        # print self.updates
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'epsilon': self.epsilon}
        base_config = super(Adam_accumulate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
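The two flag conventions used in this thread fire on different batches. Assuming `iterations` holds the number of batches already processed when the flag is evaluated (the reading that the comment in the later `AdamAccumulate` version uses), this small pure-Python sketch (`update_batches` is a hypothetical helper) lists the 1-based batch numbers on which the parameter update is applied:

```python
def update_batches(num_batches, accum_iters, offset):
    """Return the 1-based batch numbers on which the parameter update fires.

    offset=0 mimics flag = (iterations % accum_iters == 0),
    offset=1 mimics flag = ((iterations + 1) % accum_iters == 0).
    `iterations` counts batches already processed (0 before the first batch).
    """
    fired = []
    for iterations in range(num_batches):
        if (iterations + offset) % accum_iters == 0:
            fired.append(iterations + 1)  # batch number is iterations + 1
    return fired


print(update_batches(9, 4, offset=0))  # [1, 5, 9]
print(update_batches(9, 4, offset=1))  # [4, 8]
```

With `iterations % accum_iters == 0`, the very first batch triggers an update on its own (the accumulator `gg` is still empty), so only the later windows contain accum_iters batches; the `(iterations + 1)` form updates on the last batch of every full window.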
@wx405557858 I tried using your code, but the loss seems to explode.
@raghakot It works for my model, and I assumed it would work in general. Does the loss converge with the regular Adam optimizer in your case?
Yes, it converges with regular Adam. @the-moliver's version seems to work too.
@raghakot Thanks for pointing that out. I'm not sure what the exact problem is, but it's nice to know @the-moliver's solution works for you.
Set `flag = K.cast(flag, dtype='float32')` and it works. Thanks @wx405557858!
Thank you for sharing. I am new here and had some trouble at first. What is the relation between
final batch_size = accum_iters * original batch_size
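As a concrete check of this relation, the sketch below (NumPy, with a toy linear-regression loss chosen only for illustration) compares the gradient of one batch of 32 samples with the average of 8 micro-batch gradients of 4 samples each, i.e. the batch_size=4 / 32-sample setup from the original question. The two agree exactly when the loss is a per-sample mean and all micro-batches have the same size:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w = rng.normal(size=3)


def mse_grad(w, Xb, yb):
    """Gradient of 0.5 * mean((Xb @ w - yb) ** 2) with respect to w."""
    return Xb.T @ (Xb @ w - yb) / len(yb)


# One large batch of 32 samples.
big_batch_grad = mse_grad(w, X, y)

# accum_iters = 8 micro-batches of 4 samples each (4 * 8 = 32).
accum_iters, micro_batch = 8, 4
accum = np.zeros_like(w)
for k in range(accum_iters):
    sl = slice(k * micro_batch, (k + 1) * micro_batch)
    accum += mse_grad(w, X[sl], y[sl])
accumulated_grad = accum / accum_iters  # average of the micro-batch mean gradients

print(np.allclose(big_batch_grad, accumulated_grad))  # True
```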
Hi @wx405557858, could you please show your
@soon-will the
@wx405557858 shouldn't the `m` in the `v` update be `v`?
@the-moliver, I am getting an error that K.floor doesn't exist on this line: Were K.floor and K.mod recently removed from the Keras backend? I can't find them here: https://github.com/fchollet/keras/tree/master/keras/backend
@jackkwok
This feature is extremely useful and should be added to the official repository.
Code by @wx405557858 with fixes. I checked it in my project and it seemed to work fine:

```python
from keras.optimizers import Optimizer
from keras import backend as K
import numpy as np


class Adam_accumulate(Optimizer):
    '''Adam accumulate optimizer.

    Default parameters follow those provided in the original paper. Wait for several mini-batch to update

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor.

    # References
        - [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
    '''
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-8, accum_iters=10, **kwargs):
        super(Adam_accumulate, self).__init__(**kwargs)
        self.__dict__.update(locals())
        self.iterations = K.variable(0)
        self.lr = K.variable(lr)
        self.beta_1 = K.variable(beta_1)
        self.beta_2 = K.variable(beta_2)
        self.accum_iters = K.variable(accum_iters)

    def get_updates(self, params, constraints, loss):
        grads = self.get_gradients(loss, params)
        self.updates = [(self.iterations, self.iterations + 1)]

        t = self.iterations + 1
        lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

        ms = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        vs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        gs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        self.weights = ms + vs

        for p, g, m, v, gg in zip(params, grads, ms, vs, gs):

            flag = K.equal(self.iterations % self.accum_iters, 0)
            flag = K.cast(flag, dtype='float32')

            gg_t = (1 - flag) * (gg + g)
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / self.accum_iters
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / self.accum_iters)
            p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append((m, flag * m_t + (1 - flag) * m))
            self.updates.append((v, flag * v_t + (1 - flag) * v))
            self.updates.append((gg, gg_t))

            new_p = p_t
            # apply constraints
            if p in constraints:
                c = constraints[p]
                new_p = c(new_p)
            self.updates.append((p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'epsilon': self.epsilon}
        base_config = super(Adam_accumulate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
Thanks @ZFTurbo for the fixes. This is a version of the code for Keras 2.0.8 with the constraints issue and the get_updates parameters fixed:

```python
from keras.optimizers import Optimizer
from keras import backend as K
import numpy as np


class Adam_accumulate(Optimizer):
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-8, accum_iters=20, **kwargs):
        super(Adam_accumulate, self).__init__(**kwargs)
        self.__dict__.update(locals())
        self.iterations = K.variable(0)
        self.lr = K.variable(lr)
        self.beta_1 = K.variable(beta_1)
        self.beta_2 = K.variable(beta_2)
        self.accum_iters = K.variable(accum_iters)

    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [(self.iterations, self.iterations + 1)]

        t = self.iterations + 1
        lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

        ms = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        vs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        gs = [K.variable(np.zeros(K.get_value(p).shape)) for p in params]
        self.weights = ms + vs

        for p, g, m, v, gg in zip(params, grads, ms, vs, gs):

            flag = K.equal(self.iterations % self.accum_iters, 0)
            flag = K.cast(flag, dtype='float32')

            gg_t = (1 - flag) * (gg + g)
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * (gg + flag * g) / self.accum_iters
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square((gg + flag * g) / self.accum_iters)
            p_t = p - flag * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append((m, flag * m_t + (1 - flag) * m))
            self.updates.append((v, flag * v_t + (1 - flag) * v))
            self.updates.append((gg, gg_t))

            new_p = p_t
            # apply constraints (Keras 2 stores the constraint on the variable itself;
            # there is no `constraints` dict in this signature)
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)
            self.updates.append((p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'epsilon': self.epsilon}
        base_config = super(Adam_accumulate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
Hi guys, thanks for the previous code. I have been trying to replicate the same for SGD with Nesterov momentum:

```python
from keras import backend as K
from keras.legacy import interfaces
from keras.optimizers import Optimizer


class SGDAccum(Optimizer):
    """Stochastic gradient descent optimizer.

    Includes support for momentum,
    learning rate decay, and Nesterov momentum.

    # Arguments
        lr: float >= 0. Learning rate.
        momentum: float >= 0. Parameter updates momentum.
        decay: float >= 0. Learning rate decay over each update.
        nesterov: boolean. Whether to apply Nesterov momentum.
    """

    def __init__(self, lr=0.01, momentum=0., decay=0.,
                 nesterov=False, accum_iters=1, **kwargs):
        super(SGDAccum, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.momentum = K.variable(momentum, name='momentum')
            self.decay = K.variable(decay, name='decay')
            self.accum_iters = K.variable(accum_iters)
        self.initial_decay = decay
        self.nesterov = nesterov

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))

        accum_switch = K.equal(self.iterations % self.accum_iters, 0)
        accum_switch = K.cast(accum_switch, dtype='float32')

        # momentum
        shapes = [K.int_shape(p) for p in params]
        moments = [K.zeros(shape) for shape in shapes]
        temp_grads = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + moments
        for p, cg, m, tg in zip(params, grads, moments, temp_grads):
            g = cg + tg
            v = self.momentum * m - (lr * g / self.accum_iters)  # velocity
            self.updates.append(K.update(m, (1 - accum_switch) * m + accum_switch * v))
            self.updates.append(K.update(tg, (1 - accum_switch) * g))

            if self.nesterov:
                new_p = p + self.momentum * v - (lr * g / self.accum_iters)
            else:
                new_p = p + v

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, (1 - accum_switch) * p + accum_switch * new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'momentum': float(K.get_value(self.momentum)),
                  'decay': float(K.get_value(self.decay)),
                  'nesterov': self.nesterov,
                  'accum_iters': self.accum_iters}
        base_config = super(SGDAccum, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```

Can someone please verify that it looks about right?
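One way to check the update rule is outside Keras. This NumPy sketch mirrors the velocity and Nesterov formulas in SGDAccum, but it is not a reproduction of the Keras graph: it fires the update on the last micro-batch of each window (the `(iterations + 1)` convention), and the toy loss, data and helper names are made up for illustration. Accumulating 4 micro-batches of 2 samples should then match plain momentum SGD on batches of 8:

```python
import numpy as np


def grad_fn(p, Xb, yb):
    """Gradient of 0.5 * mean((Xb @ p - yb) ** 2) with respect to p."""
    return Xb.T @ (Xb @ p - yb) / len(yb)


def momentum_step(p, m, grad, lr, momentum, nesterov):
    """SGD step written like SGDAccum: v = momentum * m - lr * grad."""
    v = momentum * m - lr * grad
    new_p = p + momentum * v - lr * grad if nesterov else p + v
    return new_p, v


def sgd_accum(p, batches, accum_iters, lr, momentum, nesterov):
    """Sum micro-batch gradients; apply one averaged step per window."""
    m, acc = np.zeros_like(p), np.zeros_like(p)
    for i, (Xb, yb) in enumerate(batches):
        acc += grad_fn(p, Xb, yb)
        if (i + 1) % accum_iters == 0:  # last micro-batch of the window
            p, m = momentum_step(p, m, acc / accum_iters, lr, momentum, nesterov)
            acc = np.zeros_like(p)  # reset the accumulator
    return p


rng = np.random.default_rng(1)
X, y = rng.normal(size=(24, 3)), rng.normal(size=24)
micro = [(X[i:i + 2], y[i:i + 2]) for i in range(0, 24, 2)]  # 12 batches of 2
large = [(X[i:i + 8], y[i:i + 8]) for i in range(0, 24, 8)]  # 3 batches of 8

p0 = np.zeros(3)
p_accum = sgd_accum(p0, micro, accum_iters=4, lr=0.1, momentum=0.9, nesterov=True)
p_large = sgd_accum(p0, large, accum_iters=1, lr=0.1, momentum=0.9, nesterov=True)
print(np.allclose(p_accum, p_large))  # True
```

The equivalence holds because the parameters do not move inside a window, so every micro-batch gradient in that window is taken at the same point as the large-batch gradient.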
@gamers5a your function doesn't work in the latest Keras version. There were too many changes to the Adam optimizer between versions 1.2.1 and 2.0.8. I hope someone fixes it as well. @viig99 I believe your function works just fine. Here are the logs of 3 runs:
SGD (default, batch=32):
SGDAccum (accum_iters=1, batch=32):
SGDAccum (accum_iters=2, batch=16):
But there is a problem with the model.save() method:
That will have to be included in the optimizers.py file, in the serialize and deserialize methods. I would like to point out that batch accumulation is an incredibly useful option and should be provided in the main package. Can we improve the visibility of this, or is there a better or preferred way to restructure the code?
@viig99 maybe you can try adding your changes directly to the SGD optimizer in the official repository as a pull request, because SGDAccum with default
https://www.hastebin.com/efabasizas.py is the one I was using. I am pretty sure there are better ways of doing this; for now I am saving the weights and restarting networks from those weights.
@noagarcia @viig99 However, even though I could save the model, loading it still ended with an error:
First of all, I'm very happy that I found this thread, great stuff! Thanks all for sharing :) I'm wondering whether, performance-wise, it isn't better to use K.switch instead of
For example, something in this spirit:

```python
maybe_assign_params = K.switch(
    K.equal(self.iterations % self.accum_iters, 0),
    K.update(p, new_p),
    K.update_add(tiny_dummy_param, 0)  # or some other dummy no-op
)
self.updates.append(maybe_assign_params)
```

to avoid doing a K.update of every parameter into itself on (n-1)/n of the steps.
Can it be used with batch normalization, or do I need to change it a bit?
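One point worth keeping in mind: accumulation only changes when the optimizer applies updates, while each forward pass still sees batch_size samples. A BatchNormalization layer in training mode therefore normalizes each pass with statistics from the small batch, not from the effective batch. A NumPy illustration with made-up data (batch_size=4, accum_iters=8):

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(loc=1.0, scale=2.0, size=(32, 5))  # one effective batch of 32

# Per-feature mean a BatchNormalization layer would see on the full batch.
full_mean = batch.mean(axis=0)

# With batch_size=4 and accum_iters=8, each forward pass only sees 4 samples.
micro_means = np.array([batch[i:i + 4].mean(axis=0) for i in range(0, 32, 4)])

# The micro-batch means average back to the full-batch mean...
print(np.allclose(micro_means.mean(axis=0), full_mean))  # True
# ...but each individual pass normalizes with a noisier estimate.
print(np.abs(micro_means - full_mean).max())
```

So the optimizers in this thread can be used with BatchNormalization, but the normalization itself still behaves like batch_size=4.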
@rydevera3 @phobrain I am using this code to test the optimizer:

```python
import keras.backend as K
import numpy as np
import tensorflow as tf
import random as rn

# Reproducibility
# https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
np.random.seed(42)
rn.seed(12345)
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                              inter_op_parallelism_threads=1)
tf.set_random_seed(1234)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)

from keras import models, layers
from keras.optimizers import Adam
# AdamAccumulate (an accumulating Adam that takes lr and accum_iters) must also be defined or imported here.

model = models.Sequential()
model.add(layers.Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(16, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(16, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

from keras.datasets import mnist
from keras.utils import to_categorical

(train_images, train_labels), _ = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)

# Two more copies of the model starting from identical weights.
model_2 = models.clone_model(model)
model_2.set_weights(model.get_weights())
model_3 = models.clone_model(model)
model_3.set_weights(model.get_weights())

optimizer = Adam(lr=0.0001)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'])
print('\nTraining with Adam, 1st run:')
model.fit(train_images, train_labels, epochs=5, batch_size=32, shuffle=False)

optimizer_2 = Adam(lr=0.0001)
model_2.compile(
    optimizer=optimizer_2,
    loss='categorical_crossentropy',
    metrics=['accuracy'])
print('\nTraining with Adam, 2nd run:')
model_2.fit(train_images, train_labels, epochs=5, batch_size=32, shuffle=False)

# batch_size=4 with accum_iters=8 gives an effective batch size of 32.
optimizer_3 = AdamAccumulate(lr=0.0001, accum_iters=8)
model_3.compile(
    optimizer=optimizer_3,
    loss='categorical_crossentropy',
    metrics=['accuracy'])
print('\nTraining with AdamAccumulate:')
model_3.fit(train_images, train_labels, epochs=5, batch_size=4, shuffle=False)
```

I also ran it with these env variables:
What I got:
As you can see, Adam reproduces itself exactly, but AdamAccumulate gives different results. I noticed some mistakes in the optimizer code and will post my version later; I just need to fix some strange behavior. TF code is hard to debug :)
Hey everyone, I've corrected some bugs in @nik-ko's implementation (mainly the learning rate, which wasn't adjusting correctly). Here it is:
And using @alexeydevederkin's test, everything seems to work almost perfectly:
With @Dutil's code, I don't see my earlier-mentioned "complaint about something being used twice", though other model details have changed since then, so that could be the cause. I get reasonable results with my siamese model using keyword vectors, doubling the batch of 1024. In the same siamese model using VGG16, doubling a batch of 32, on the first try my held-back positive test cases all had the same value (0.01187402), which is binary-correct but too fishy. Rerunning, I got two creditable epochs with hold-out testing in between. But I see about the same run profile as with Adagrad, so I'm wondering whether it makes sense (blindly QA'ing for now):
adagrad 11/4: 15080/15080 3414s 226ms/step
AdamAcc 11/21: 15394/15394 3190s 207ms/step
I will try @alexeydevederkin's version next.
My version of the Adam optimizer with accumulated gradients (slightly different from @Dutil's, with closer results to

```python
import keras.backend as K
from keras.legacy import interfaces
from keras.optimizers import Optimizer


class AdamAccumulate(Optimizer):

    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=None, decay=0., amsgrad=False, accum_iters=1, **kwargs):
        if accum_iters < 1:
            raise ValueError('accum_iters must be >= 1')
        super(AdamAccumulate, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad
        self.accum_iters = K.variable(accum_iters, K.dtype(self.iterations))
        self.accum_iters_float = K.cast(self.accum_iters, K.floatx())

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr

        completed_updates = K.cast(K.tf.floordiv(self.iterations, self.accum_iters), K.floatx())

        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * completed_updates))

        t = completed_updates + 1

        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))

        # self.iterations incremented after processing a batch
        # batch:              1 2 3 4 5 6 7 8 9
        # self.iterations:    0 1 2 3 4 5 6 7 8
        # update_switch = 1:        x       x    (if accum_iters=4)
        update_switch = K.equal((self.iterations + 1) % self.accum_iters, 0)
        update_switch = K.cast(update_switch, K.floatx())

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        gs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            vhats = [K.zeros(1) for _ in params]

        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat, tg in zip(params, grads, ms, vs, vhats, gs):

            sum_grad = tg + g
            avg_grad = sum_grad / self.accum_iters_float

            m_t = (self.beta_1 * m) + (1. - self.beta_1) * avg_grad
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(avg_grad)

            if self.amsgrad:
                vhat_t = K.maximum(vhat, v_t)
                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, (1 - update_switch) * vhat + update_switch * vhat_t))
            else:
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append(K.update(m, (1 - update_switch) * m + update_switch * m_t))
            self.updates.append(K.update(v, (1 - update_switch) * v + update_switch * v_t))
            self.updates.append(K.update(tg, (1 - update_switch) * sum_grad))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, (1 - update_switch) * p + update_switch * new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon,
                  'amsgrad': self.amsgrad}
        base_config = super(AdamAccumulate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```

Tests:
I'm not very familiar with TensorFlow, but maybe it could be further improved (for speed) by using conditional updates instead of updating variables with the same values.
With @alexeydevederkin's version on the VGG case, Python 2.7:
@phobrain This seems to be an issue with the different behavior of division in Python 2. You could try changing the computation of `completed_updates` to `completed_updates = K.cast(K.tf.floordiv(self.iterations, self.accum_iters), K.floatx())`. Does it work now in Python 2?
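To see why floor division matters here, this pure-Python sketch (the helper names are illustrative only) computes `completed_updates` and the Adam step count `t = completed_updates + 1` used for the bias-corrected rate `lr_t`, with accum_iters=4. With floor division, `t` stays constant within each accumulation window; plain `/` on integers gives different results in Python 2 (floor) and Python 3 (true division):

```python
import math


def completed_updates(iterations, accum_iters):
    # Number of full accumulation windows already applied (floor division).
    return iterations // accum_iters


def adam_lr_t(lr, t, beta_1=0.9, beta_2=0.999):
    # Bias-corrected learning rate, same formula as lr_t in AdamAccumulate.
    return lr * math.sqrt(1. - beta_2 ** t) / (1. - beta_1 ** t)


accum_iters = 4
ts = [completed_updates(it, accum_iters) + 1 for it in range(8)]
print(ts)  # [1, 1, 1, 1, 2, 2, 2, 2]

# lr_t only changes when a new window starts.
lrs = [adam_lr_t(0.001, t) for t in ts]
print(lrs[0] == lrs[3], lrs[3] == lrs[4])  # True False

# Python 3 semantics: / is true division, // is floor division.
print(5 / 4, 5 // 4)  # 1.25 1
```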
It works, and the epoch 1 positive holdout accuracy is 'normal' (76%; 82% is maybe the highest epoch 1 result I've seen). Again, do these timings make sense? Run time is the same as with plain Adagrad. Given my naivete, I wonder if it could be me, but I can't see any way to screw it up. :-) ... Aha, unless I'm supposed to double batch_size as well?
I let @Dutil's version run a few epochs until the holdout tests got worse, and it didn't make a difference in the accuracy range. Since I have a BatchNormalization layer, the results are not rigorous, so I will try without it next (using keyword vectors, since VGG is so slow... which is where I need it in the end, since two 224x224 pics at a time means batch=32), so I can compare the two AdamAccumulate versions, and for a while stop wondering whether Khashoggi was the reincarnation of Archduke Franz Ferdinand. Here are the epoch 1 holdout pos/neg results for both versions, in the same range as different Adadelta runs:
Epoch 4 of @Dutil's version, where I bailed:
Keyword (binary) vector results, with the same training pairs of pics, batch_size=1024, and no BatchNormalization after the first case below. Each 'epoch' below is 3 epochs, and runs continue until crude criteria are no longer satisfied.
Adadelta w/ BatchNormalization [Epochs/runs restarted due to decrease in holdout accuracy]
Adadelta
accum_iters=2, batch_size = 1024 (as above cases)
accum_iters=2, batch_size = 512
accum_iters=3, batch_size = 1024
accum_iters=2, batch_size = 1024
accum_iters=2, batch_size=512
accum_iters=3, batch_size = 1024
accum_iters=4, batch_size = 1024
The test case answers my question about batch size: it should be reduced by the accum_iters factor:
The run time of an optimizer with accumulation should be similar to the run time of an optimizer without accumulation at the same batch_size (but not at the same effective batch size). For example, run time of
I would guess that the way we tweak optimizers here won't work with
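The arithmetic behind this: the number of forward/backward passes per epoch depends only on the batch_size passed to fit(), while accum_iters only decides how many of those passes end with a parameter update. A pure-Python sketch (`epoch_counts` is a hypothetical helper), using the 60000-sample MNIST setup from the test script above and assuming the accumulation counter lines up with epoch boundaries:

```python
import math


def epoch_counts(num_samples, batch_size, accum_iters):
    """Forward/backward passes and optimizer updates per epoch."""
    passes = math.ceil(num_samples / batch_size)  # one per batch fed to fit()
    updates = passes // accum_iters               # one per full accumulation window
    return passes, updates


print(epoch_counts(60000, 32, 1))  # (1875, 1875): plain batch of 32
print(epoch_counts(60000, 4, 8))   # (15000, 1875): effective batch of 32
```

So an epoch with batch_size=4 and accum_iters=8 costs roughly as many passes as plain training with batch_size=4, even though it makes only as many updates as batch_size=32.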
An answer to my naive expectation of a different epoch time is that the same number of cases are being processed; the only difference is the accounting. I realized the thing to do is to try batch_size=64 with VGG16, i.e. 2x what I can fit in memory, and, forgetting to comment out BatchNorm again, I get
Retrying without BatchNorm.
batch_size=64, accum_iters=2: in one run, positive test accuracy was always below 80% and dropped to the 50s after a few epochs.
batch_size=96, accum_iters=3
batch_size=128, accum_iters=4 [got low-memory messages; OOM failure with a larger batch]
Adam with BatchNorm: batch_size=128 fits with Adam, it turns out.
lr=0.00125
NNs are far more fun than horse racing, because the horses are real. In this case, apparently even snails will do. Some morbid labeling of keyword-vector-net-generated pairs while waiting makes life bloom anew; the pics won't render in Chrome by default, since the site is not https: https://forums.craigslist.org/?ID=295644868 I suspect the limitations in accuracy depend more on the types of per-pic data than on batch size or net topology, though BatchNorm gives a tantalizing boost to the convergence rate of positive holdouts with the keywords (above), so I'm hoping a leverageable insight will dawn from that. Histograms plus keyword vectors get positive accuracy up to around 92% (faster runs allow a big sample), and it seems a convolutional method should get closer to that than ~85%. In the end, I'll mix and match the methods dynamically according to AI personality requirements when interacting.
Here is my solution that works for any optimizer! (with the TensorFlow backend)

```python
import sys
import tensorflow
from tensorflow.keras import backend as K


def convert_to_accumulate_gradient_optimizer(orig_optimizer, update_params_frequency, accumulate_sum_or_mean=True):
    if update_params_frequency < 1:
        raise ValueError('update_params_frequency must be >= 1')
    print('update_params_frequency: %s' % update_params_frequency)
    print('accumulate_sum_or_mean: %s' % accumulate_sum_or_mean)
    orig_get_gradients = orig_optimizer.get_gradients
    orig_get_updates = orig_optimizer.get_updates
    accumulated_iterations = K.variable(0, dtype='int64', name='accumulated_iterations')
    orig_optimizer.accumulated_iterations = accumulated_iterations

    def updated_get_gradients(self, loss, params):
        return self.accumulate_gradient_accumulators

    def updated_get_updates(self, loss, params):
        self.accumulate_gradient_accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        updates_accumulated_iterations = K.update_add(accumulated_iterations, 1)
        new_grads = orig_get_gradients(loss, params)
        if not accumulate_sum_or_mean:
            new_grads = [g / K.cast(update_params_frequency, K.dtype(g)) for g in new_grads]
        self.updated_grads = [K.update_add(p, g) for p, g in zip(self.accumulate_gradient_accumulators, new_grads)]

        def update_function():
            with tensorflow.control_dependencies(orig_get_updates(loss, params)):
                reset_grads = [K.update(p, K.zeros(K.int_shape(p), dtype=K.dtype(p))) for p in self.accumulate_gradient_accumulators]
            return tensorflow.group(*(reset_grads + [updates_accumulated_iterations]))

        def just_store_function():
            return tensorflow.group(*[updates_accumulated_iterations])

        update_switch = K.equal((updates_accumulated_iterations) % update_params_frequency, 0)

        with tensorflow.control_dependencies(self.updated_grads):
            self.updates = [K.switch(update_switch, update_function, just_store_function)]
            return self.updates

    orig_optimizer.get_gradients = updated_get_gradients.__get__(orig_optimizer, type(orig_optimizer))
    orig_optimizer.get_updates = updated_get_updates.__get__(orig_optimizer, type(orig_optimizer))
```

And simple unit tests:

```python
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import backend as K
import numpy as np
import pytest
import tensorflow as tf


def get_simple_linear_model(orig_optimizer, update_params_frequency, accumulate_sum_or_mean):
    inputs = Input(shape=(1, ), dtype='float32')
    outputs = Dense(1, use_bias=False, kernel_initializer='ones')(inputs)
    model = Model(inputs=inputs, outputs=outputs)
    convert_to_accumulate_gradient_optimizer(orig_optimizer, update_params_frequency=update_params_frequency,
                                             accumulate_sum_or_mean=accumulate_sum_or_mean)

    def y_loss(y_true, y_pred):
        return K.mean(y_pred)

    def get_w():
        return model.get_weights()[0][0][0]

    def get_sgd_iteration():
        return orig_optimizer.get_weights()[orig_optimizer.weights.index(orig_optimizer.iterations)]

    model.compile(optimizer=orig_optimizer, loss=y_loss)
    return model, get_w, get_sgd_iteration


def test_update_just_when_need():
    model, get_w, get_sgd_iteration = get_simple_linear_model(SGD(lr=1.0), 2, False)
    w_before_call = get_w()
    model.fit(x=np.array([[2.0]], dtype=np.float32), y=np.array([[0.0]], dtype=np.float32), batch_size=1)
    w_after_first_call = get_w()
    global_step_after_first_call = get_sgd_iteration()
    model.fit(x=np.array([[3.0]], dtype=np.float32), y=np.array([[0.0]], dtype=np.float32), batch_size=1)
    w_after_second_call = get_w()
    global_step_after_second_call = get_sgd_iteration()
    assert global_step_after_first_call == 0
    assert global_step_after_second_call == 1
    assert w_before_call == 1.0
    assert w_after_first_call == 1.0
    assert w_after_second_call == -1.5


def test_reset_after_update():
    model, get_w, get_sgd_iteration = get_simple_linear_model(SGD(lr=1.0), 1, False)
    model.fit(x=np.array([[2.0]], dtype=np.float32), y=np.array([[0.0]], dtype=np.float32), batch_size=1)
    model.fit(x=np.array([[3.0]], dtype=np.float32), y=np.array([[0.0]], dtype=np.float32), batch_size=1)
    w_after_second_call = get_w()
    assert w_after_second_call == -4.0
```
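The same wrapping idea can be illustrated without TensorFlow. Below is a hypothetical, framework-free `AccumulatingOptimizer` (not part of any library) that buffers gradients and calls a wrapped step function once every `update_params_frequency` calls, optionally averaging instead of summing. Fed the same inputs as the two unit tests above (w starts at 1 and the loss is mean(w * x), so the gradient is x), it reproduces their expected weights:

```python
class AccumulatingOptimizer:
    """Hypothetical wrapper: buffer gradients, step once every k calls."""

    def __init__(self, step_fn, update_params_frequency, accumulate_sum_or_mean=True):
        if update_params_frequency < 1:
            raise ValueError('update_params_frequency must be >= 1')
        self.step_fn = step_fn
        self.k = update_params_frequency
        self.use_sum = accumulate_sum_or_mean  # True: sum, False: mean
        self.accumulated = None
        self.calls = 0
        self.inner_steps = 0  # plays the role of the wrapped optimizer's `iterations`

    def apply(self, params, grads):
        if not self.use_sum:
            grads = [g / self.k for g in grads]
        if self.accumulated is None:
            self.accumulated = [0.0 for _ in grads]
        self.accumulated = [a + g for a, g in zip(self.accumulated, grads)]
        self.calls += 1
        if self.calls % self.k != 0:
            return params  # just store the gradients
        params = self.step_fn(params, self.accumulated)
        self.inner_steps += 1
        self.accumulated = [0.0 for _ in grads]  # reset after the real update
        return params


def sgd_step(params, grads, lr=1.0):
    return [p - lr * g for p, g in zip(params, grads)]


opt = AccumulatingOptimizer(sgd_step, update_params_frequency=2, accumulate_sum_or_mean=False)
w = [1.0]
w = opt.apply(w, [2.0])
print(w, opt.inner_steps)  # [1.0] 0
w = opt.apply(w, [3.0])
print(w, opt.inner_steps)  # [-1.5] 1
```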
@noamwies Thanks for sharing the code. I think the following line should be corrected:
as
My implementation, which rewrites the optimizer:
@alexeydevederkin I am getting the error: upon running your code. Could you please help me with my problem? I am running it on Python 3.7 and TF2. Also, TF doesn't have Keras legacy interfaces; how could we adapt your code for TensorFlow? (I installed Keras just for this optimizer.) Thanks a lot in advance.
I have the same problem: I'm trying to get a gradient accumulation optimizer working with Keras and TF2, without success so far.
@viig99 When using the SGD accumulation function, I am getting the error: Can you suggest what could be the cause? Thanks
It seems that none of this code runs with TensorFlow Keras. I changed the code to work with TF Keras (e.g. changing the imports from keras to tf.keras). The code compiles, but I couldn't run it properly (it looks like it gets stuck without doing anything).
'tensorflow.python.keras.optimizer_v2.OptimizerV2' was introduced in TensorFlow 1.13. The design of 'OptimizerV2' seems to be an overhaul of the original 'Optimizer' class. I think the code snippets above only work with the old 'Optimizer' class, i.e. only with tf.keras optimizers in TensorFlow 1.12 or lower.
Thanks for the info, @jkjung-avt. I am trying to work with OptimizerV2, but it is indeed not easy.
Hi, could anyone show how to use this code for BERT fine-tuning? Should I just replace BERT's optimization.py with it, or do something else? Thanks
@652994331 Are you able to run your code with TF Keras? I suspect it does not work once the code is converted to TF Keras, but please let me know if it works on your side. Thanks.
For both keras and tf.keras, you can refer to this:
Has anyone encountered this problem while using AdamAccumulate?
@Pari-singh I encountered this problem too and am still stuck on it. Were you able to solve it? If so, please tell me.
Did you verify that the implementation works well in terms of expected performance and runtime? It would be nice to know. Currently, for accumulated gradients I typically modify the train step to handle when and how gradients are updated. However, your approach might be more convenient, as it makes it possible to still use model.fit() or model.fit_generator() for the training loop.
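A minimal sketch of that train-step approach, in NumPy on a toy noiseless linear-regression problem (all names and data here are illustrative, not from any library): gradients from each micro-batch are summed, and the averaged gradient is applied once accum_iters micro-batches have been seen. When the number of micro-batches per epoch isn't a multiple of accum_iters, this version applies the leftover partial window at the end of the epoch; dropping it would be an equally valid design choice.

```python
import numpy as np


def train_epoch(w, X, y, batch_size, accum_iters, lr):
    """One epoch of plain SGD on 0.5 * mean squared error with gradient
    accumulation; a trailing partial window is flushed at the end."""
    acc, n_acc = np.zeros_like(w), 0
    for start in range(0, len(y), batch_size):
        Xb, yb = X[start:start + batch_size], y[start:start + batch_size]
        acc += Xb.T @ (Xb @ w - yb) / len(yb)  # micro-batch gradient
        n_acc += 1
        if n_acc == accum_iters:
            w = w - lr * acc / n_acc  # apply the averaged gradient
            acc, n_acc = np.zeros_like(w), 0
    if n_acc:  # leftover micro-batches at the end of the epoch
        w = w - lr * acc / n_acc
    return w


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w

w = np.zeros(3)
for epoch in range(200):
    w = train_epoch(w, X, y, batch_size=4, accum_iters=8, lr=0.1)
print(np.round(w, 3))
```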
Can Keras support updating the parameters after a relatively large batch that would exceed GPU memory if fed in all at once?
Currently my model can only be fed batch_size=4 samples at a time because of the 12 GB of GPU memory. The loss has difficulty decreasing with batch_size=4, so I want to update the parameters after every 32 samples. Will Keras be able to support this? It seems that Caffe can.
Thanks!