
How to obtain the gradient of each parameter in the last epoch of training #2226

Closed
tuming1990 opened this issue Apr 7, 2016 · 38 comments

@tuming1990

I want to obtain the gradient of each parameter in the last epoch of training. Is there a way to do so in Keras?

Thanks,
Ming

@philipperemy

You can get the outputs of a particular layer as described here:
http://keras.io/faq/#how-can-i-visualize-the-output-of-an-intermediate-layer

The parameters (weights and so on) are easily retrieved in your model object.

To compute the gradient you can use this code:

import keras.backend as K
import numpy as np

X = K.placeholder(ndim=2)  # specify the right placeholder for your input
Y = K.sum(K.square(X))  # loss function
fn = K.function([X], K.gradients(Y, [X]))  # callable that evaluates the gradient of Y w.r.t. X

That's a partial answer. I hope this helps.

@gideonite

How do you call fn? I tried fn(model.layers[k].input) where model.layers[k] is a layers.core.Dense.

@philipperemy

philipperemy commented May 14, 2016

Here is an example where you can call the function on a 2x2 matrix. I hope this helps.

import keras.backend as K
import numpy as np

X = K.placeholder(ndim=2)
Y = K.sum(K.square(K.round(X)))
fn = K.function([X], K.gradients(Y, [X]))
print(fn([np.ones((2, 2), dtype=np.float32)]))

@philipperemy

@gideonite If you can make it work on a toy example, please let me know

@gideonite

@philipperemy I wound up printing out the weights of my model during training to debug and have since moved on. Thank you, I appreciate the help.

@aravindr93

aravindr93 commented Jun 24, 2016

@philipperemy and @gideonite
I'm facing a very similar issue. I wish to compute the derivatives w.r.t. the parameters of the network so that I can write the update equations outside the Keras loop. I'm currently using the TensorFlow backend, but can switch if the other backend has some special functionality that will help. A rough code snippet is below:

import tensorflow as tf
import keras.backend as K
from keras.models import Sequential

model = Sequential()
model.add(....)  # add a couple of layers here
x = tf.placeholder(tf.float32, shape=(None, out_dim))
y = model(x)
loss = K.sum(K.square(y - target))  # just think of any standard loss fn

I need the gradient of the loss w.r.t. each parameter in the neural network for a reinforcement learning application, so the model.fit-style functions are not useful for me. Roughly, what I want to accomplish is:
param_grad = tf.gradients(loss, model_params)

The problem is that I am unable to get the symbolic model_params. If I do model.get_weights() or something similar, it gets me the numeric weights and not the symbolic ones. Would appreciate some help.

@philipperemy

From what I know, it's very hard to do it in Keras.

In your case, I strongly advise you to use TensorFlow directly, WITHOUT Keras. It's much easier.

If you're interested in the final layer and if you use the MSE, you can always reverse-engineer the backpropagation function to find the gradient but that's a very specific case:

w[i+1] - w[i] = -learning_rate * dE/dW

Then your gradient would be:

dE/dW = (w[i] - w[i+1])/learning_rate
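
As a rough, self-contained sketch of that reverse-engineering idea (assuming plain SGD with no momentum or decay, so the weight update is exactly -learning_rate * dE/dW):

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

lr = 0.01
model = Sequential([Dense(1, input_dim=2)])
model.compile(optimizer=SGD(lr=lr), loss='mse')

x_batch = np.random.randn(8, 2)
y_batch = np.random.randn(8, 1)

w_before = [w.copy() for w in model.get_weights()]
model.train_on_batch(x_batch, y_batch)  # one plain SGD step
w_after = model.get_weights()

# dE/dW = (w[i] - w[i+1]) / learning_rate, per weight tensor
grads = [(wb - wa) / lr for wb, wa in zip(w_before, w_after)]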

When you do model.get_weights(), your weights are evaluated. If you want the symbolic ones, do something like:

for layer in model.layers:
    layer.W or layer.b
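
A sketch of wiring those symbolic weights into K.gradients (assuming the symbolic loss tensor from the rough snippet above; layer.trainable_weights is used here since it also covers layers without .W / .b attributes):

import keras.backend as K

params = []
for layer in model.layers:
    params += layer.trainable_weights  # symbolic weight/bias tensors

param_grads = K.gradients(loss, params)  # symbolic dloss/dparam tensors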

Hope this helps. Let me know if you can make it in Keras!

@nejlag

nejlag commented Aug 9, 2016

Hi @philipperemy.
I tried to do the same thing, but the weights do not change! I posted the whole thing here.

@ebanner

ebanner commented Sep 5, 2016

Here's how I did it:

def get_gradients(model):
    """Return the gradient of every trainable weight in model

    Parameters
    -----------
    model : a keras model instance

    First, find all tensors which are trainable in the model. Surprisingly,
    `model.trainable_weights` will return tensors for which
    trainable=False has been set on their layer (last time I checked), hence the extra check.
    Next, get the gradients of the loss with respect to the weights.

    """
    weights = [tensor for tensor in model.trainable_weights if model.get_layer(tensor.name[:-2]).trainable]
    optimizer = model.optimizer

    return optimizer.get_gradients(model.total_loss, weights)

@davidljung

@ebanner What is model.total_loss? My model (Theano 0.9.0dev2) object has no such attribute - it only seems to have a .loss attribute and that is just the string name (e.g. "mse").

@ebanner

ebanner commented Sep 21, 2016

@davidljung model.total_loss is a tensor containing the loss, which is determined by the type of loss you are using. I'm guessing you did not compile your model first? Here's a minimal example using categorical crossentropy loss:

from keras.layers import Input, Dense
from keras.models import Model

input = Input(shape=(2,))
probs = Dense(2, activation='softmax', name='probs')(input)

model = Model(input=input, output=probs)
model.compile(optimizer='sgd', loss='categorical_crossentropy')

model.total_loss
# ==> Elemwise{mul,no_inplace}.0

I am also using theano 0.9.0dev2 for the record.

@jf003320018

jf003320018 commented Oct 24, 2016

@ebanner I have tried your method, but I do not obtain the values; instead I get things like: [Elemwise{add,no_inplace}.0, GpuFromHost.0, GpuFromHost.0, GpuFromHost.0].
Could you help me?

@NikitaRomanov

@ebanner I have tried your method too but got the same result as @jf003320018. What can we do about this?

@ebanner

ebanner commented Nov 8, 2016

OK here's a full working example from start to finish. Hopefully this will clear things up. What you do with the gradient tensors is define a keras function to evaluate those tensors for a particular setting of the model's inputs. Then call the function on a particular setting of the inputs!

Define model

from keras.layers import Input, Dense
from keras.models import Model

input = Input(shape=[2])
probs = Dense(1, activation='sigmoid')(input)

model = Model(input=input, output=probs)
model.compile(optimizer='sgd', loss='binary_crossentropy')

Get gradient tensors

weights = model.trainable_weights # weight tensors
weights = [weight for weight in weights if model.get_layer(weight.name[:-2]).trainable] # filter down weights tensors to only ones which are trainable
gradients = model.optimizer.get_gradients(model.total_loss, weights) # gradient tensors

print(weights)
# ==> [dense_1_W, dense_1_b]

Define keras function to return gradients

import keras.backend as K

input_tensors = [model.inputs[0], # input data
                 model.sample_weights[0], # how much to weight each sample by
                 model.targets[0], # labels
                 K.learning_phase(), # train or test mode
]

get_gradients = K.function(inputs=input_tensors, outputs=gradients)

Get gradients of weights for particular (X, sample_weight, y, learning_mode) tuple

from keras.utils.np_utils import to_categorical

inputs = [[[1, 2]], # X
          [1], # sample weights
          [[1]], # y
          0 # learning phase in TEST mode
]

print(list(zip(weights, get_gradients(inputs))))
# ==> [(dense_1_W, array([[-0.42342907],
#                         [-0.84685814]], dtype=float32)),
#      (dense_1_b, array([-0.42342907], dtype=float32))]

@jf003320018

@ebanner Thank you for your answer. It really works. But could you tell me how to calculate the gradients for two or more samples simultaneously? Because I do not know the meaning of 'model.sample_weights', I cannot modify the code. Thank you very much.

@ebanner

ebanner commented Nov 8, 2016

sample_weight is documented here:

sample_weight: optional array of the same length as x, containing weights to apply to the model's loss for each sample...

Passing a value of 1 for each sample gives all samples equal importance in the eyes of the optimizer.

As for scaling up the example to an arbitrary number of samples, see this example (using the same function defined in my previous post):

Get gradients of weights for particular (X, sample_weight, y, learning_mode) tuple

import numpy as np
from keras.utils.np_utils import to_categorical

nb_sample = 10

inputs = [np.random.randn(nb_sample, 2), # X
          np.ones(nb_sample), # sample weights
          np.random.randint(2, size=[nb_sample, 1]), # y
          0 # learning phase in TEST mode
]

print(list(zip(weights, get_gradients(inputs))))
# ==> [(dense_2_W, array([[-0.1869444 ],
#                         [ 0.34009627]], dtype=float32)),
#      (dense_2_b, array([ 0.17382634], dtype=float32))]

@jf003320018

@ebanner The problem is solved. Thank you very much.

@jchen114

jchen114 commented Feb 2, 2017

How would I apply the gradients that I retrieve using this onto a separate model with the same parameters?

@patyork

patyork commented Feb 3, 2017

model1.layers[i].set_weights(model1.layers[i].get_weights() + model2_layer_i__gradients)
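
A more explicit sketch, assuming grads is the list of gradient arrays obtained with the get_gradients function earlier in this thread (in the same order as model1.get_weights()) and both models share the same architecture; a plain SGD-style update would be:

lr = 0.01  # hypothetical learning rate
new_weights = [w - lr * g for w, g in zip(model1.get_weights(), grads)]
model1.set_weights(new_weights)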

@pokey

pokey commented Feb 13, 2017

In case anyone is trying to do this with the Sequential model and seeing errors like:

AttributeError: 'Sequential' object has no attribute 'total_loss'

you can use model.model in place of model in such places. Eg

gradients = model.optimizer.get_gradients(model.model.total_loss, weights)

@gokceneraslan

I sent a PR to visualize grads via TensorBoard, see #6313.

stale bot added the stale label Jul 18, 2017

@stale

stale bot commented Jul 18, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

@shaifugpt

@ebanner I have a query regarding the learning phase (TEST/TRAIN mode) in the inputs. What is its effect? I get the same set of gradients in both modes.

@embanner

embanner commented Aug 21, 2017

@shaifugpt Here's the documentation for learning phase. https://keras.io/backend/

learning_phase()
Returns the learning phase flag.

The learning phase flag is a bool tensor (0 = test, 1 = train) to be passed as input to any Keras function that uses a different behavior at train time and test time.

Returns

Learning phase (scalar integer tensor or Python integer).

For instance, when learning_phase=1 a dropout layer will actually perform dropout (i.e. zero out each input activation with some probability), whereas if learning_phase=0 a dropout layer will instead scale the input accordingly (as opposed to zeroing it out).

You're getting the same gradients in both cases because none of the layers you are using depend on the learning phase.
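
As a minimal sketch in the style of the examples above, a model containing a Dropout layer does depend on the learning phase, so its gradients differ between phase 0 and phase 1:

import numpy as np
import keras.backend as K
from keras.layers import Input, Dense, Dropout
from keras.models import Model

inp = Input(shape=(2,))
out = Dense(1, activation='sigmoid')(Dropout(0.5)(inp))
model = Model(input=inp, output=out)
model.compile(optimizer='sgd', loss='binary_crossentropy')

gradients = model.optimizer.get_gradients(model.total_loss, model.trainable_weights)
get_gradients = K.function(inputs=[model.inputs[0], model.sample_weights[0],
                                   model.targets[0], K.learning_phase()],
                           outputs=gradients)

batch = [np.ones((4, 2)), np.ones(4), np.ones((4, 1))]
print(get_gradients(batch + [0]))  # learning phase 0 (test)
print(get_gradients(batch + [1]))  # learning phase 1 (train): dropout active, gradients differ and vary run to run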

@shaifugpt

@ebanner If the first layer in the model is a merge layer, we need to pass two sets of inputs, one for each branch. I am passing them as:

inputs = [[[trainX[len(trainX)-40:len(trainX)],trainX[len(trainX)-40:len(trainX)]]], # X
      [1], # sample weights
      [trainY[len(trainX)-40:len(trainX)]], # y
      1 # learning phase in Train mode
]

Then,

   grads = get_gradients(inputs)

gives the error TypeError: unhashable type: 'list'.

What is the correct way to pass them?

@RyanCV

RyanCV commented Sep 8, 2017

I got the following error when running the code above. I am using Keras 2.0.6 with Theano 0.9.0. How can I solve this? Thanks. @ebanner

weights = [weight for weight in weights if model.get_layer(weight.name[:-2]).trainable] # filter down weights tensors to only ones which are trainable
AttributeError: 'NoneType' object has no attribute 'trainable'

@guoxiaolu

Thanks, it is useful @ebanner. In Keras 2.0.6 the weight names have changed, e.g. 'conv2d_1/kernel:0'. @RyanCV
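
A sketch of adapting the earlier filtering line to that Keras 2.x naming scheme (assuming variable names of the form '<layer_name>/kernel:0' or '<layer_name>/bias:0'):

weights = [w for w in model.trainable_weights
           if model.get_layer(w.name.split('/')[0]).trainable]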

@mathieumb

@shaifugpt

I just ran into the same error for the same reason (multiple inputs) and passing them sequentially seems to fix it.

I.e.:

input_tensors = [model.inputs[0],
                 model.inputs[1],  # etc.
                 model.sample_weights[0], # how much to weight each sample by
                 model.targets[0], # labels
                 K.learning_phase(), # train or test mode
]

Then:

inputs = [inputs[0],
          inputs[1],
          [1], # sample weights
          trainY, # y
          0 # learning phase in TEST mode
]

@sachinruk

sachinruk commented Apr 17, 2018

For anyone looking for this with Keras 2.0 onwards this is the syntax:

import keras.backend as K

weights = model.trainable_weights # weight tensors
gradients = model.optimizer.get_gradients(model.total_loss, weights) # gradient tensors
input_tensors = model.inputs + model.sample_weights + model.targets + [K.learning_phase()]
get_gradients = K.function(inputs=input_tensors, outputs=gradients)
inputs = [x, x_off, np.ones(len(x)), y, 0]
grads = get_gradients(inputs)

Doing this is no longer necessary, and it gave me an error: weights = [weight for weight in weights if model.get_layer(weight.name[:-2]).trainable] # filter down weights tensors to only ones which are trainable

Also note that my model had two X variables, hence why I have: x, x_off.

@AmitLit

AmitLit commented May 16, 2018

@sachinruk Absolutely perfect!

Thank you

@ptiwald

ptiwald commented Sep 1, 2018

@sachinruk Is it possible to calculate the gradients using "sub-losses", too? By sub-losses I mean having 2 or more outputs (and losses, i.e. total_loss = loss_1 + loss_2) and then doing something like model.optimizer.get_gradients(model.LOSS_1, weights)?
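
One hedged sketch for this (assuming an older graph-mode Keras where model.loss_functions, model.targets and model.outputs are populated after compile(); exact attribute names may vary between versions): rebuild the per-output loss tensor yourself and differentiate only that.

import keras.backend as K

loss_1 = K.mean(model.loss_functions[0](model.targets[0], model.outputs[0]))
grads_loss_1 = model.optimizer.get_gradients(loss_1, model.trainable_weights)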

@getamu

getamu commented Oct 2, 2018

@ebanner


I am getting InvalidArgumentError: transpose expects a vector of size 4. But input(1) is a vector of size 3 when I follow your example above for multiple samples. Is it because my training input data is 3D? I am using word embeddings.

@Youmna-H

What if there are multiple outputs, so model.total_loss consists of multiple losses and model.targets contains multiple labels, but we are only interested in one target? Specifying model.targets[0][0] to select the first target does not work.

@michelleowen

@shaifugpt @mathieumb did you figure out how to make this work with multiple inputs? I also got the same error.

@shaifugpt

@michelleowen Try something like:
input_tensors = [model1.inputs[0],
                 model1.inputs[1],
                 model1.sample_weights[0], # how much to weight each sample by
                 model1.targets[0], # labels
                 K.learning_phase(),]

@shaifugpt

@ebanner Is it possible to compute the gradient with respect to a specific weight connection of a layer, rather than all weights?

@zhanglin010

@shaifugpt Hi, may I know whether you found a way to compute the gradient with respect to a specific weight?
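
One possible workaround (a sketch, assuming a Keras 2.x Dense layer that exposes a .kernel attribute; the layer index here is hypothetical): take the gradient with respect to the whole kernel tensor and then index the single connection of interest.

grads = model.optimizer.get_gradients(model.total_loss, [model.layers[1].kernel])
grad_w00 = grads[0][0, 0]  # symbolic gradient for the (0, 0) weight entry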

@hnsl

hnsl commented Aug 15, 2021

I found this old issue via Google.

Here's the "modern" version:

def get_gradients(model : tf.keras.Model, x, y_true):
    """Return the gradient of every trainable weight in model"""
    with tf.GradientTape() as tape:
        loss = model.compiled_loss(y_true, model(x))

    return tape.gradient(loss, model.trainable_weights)
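
Example usage of the function above (a sketch; it assumes the model has been compiled so that compiled_loss is populated, and that x and y_true have compatible shapes):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(2,))])
model.compile(optimizer='sgd', loss='binary_crossentropy')

x = np.random.randn(8, 2).astype('float32')
y_true = np.random.randint(2, size=(8, 1)).astype('float32')

grads = get_gradients(model, x, y_true)
print([g.shape for g in grads])  # one gradient tensor per trainable weight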
