Guided Backpropagation in Keras #1777

Closed
mbz opened this issue Feb 21, 2016 · 23 comments

@mbz

mbz commented Feb 21, 2016

I'm trying to implement Saliency Maps and Guided Backpropagation in Keras, based on the following Lasagne code:

https://github.com/Lasagne/Recipes/blob/master/examples/Saliency%20Maps%20and%20Guided%20Backpropagation.ipynb

I managed to make the first part (Saliency Maps) work as follows:

import numpy as np
import matplotlib.pyplot as plt
import theano
import theano.tensor as T

def compile_saliency_function(model):
    """
    Compiles a function to compute the saliency maps and predicted classes
    for a given minibatch of input images.
    """
    inp = model.layers[0].get_input()
    outp = model.layers[-1].get_output()
    max_outp = T.max(outp, axis=1)
    # theano.grad needs a scalar cost; samples are independent, so summing
    # the per-sample maxima still yields one saliency map per sample
    saliency = theano.grad(max_outp.sum(), wrt=inp)
    max_class = T.argmax(outp, axis=1)
    return theano.function([inp], [saliency, max_class])

def show_images(img_original, saliency, max_class, title):
    classes = [str(x) for x in range(10)]
    # get out the first map and class from the mini-batch
    saliency = saliency[0]
    # reverse the channel axis (e.g. BGR -> RGB) and move channels last for imshow
    saliency = saliency[::-1].transpose(1, 2, 0)
    max_class = max_class[0]
    # plot the original image and the three saliency map variants
    plt.figure(figsize=(10, 10), facecolor='w')
    plt.suptitle("Class: " + classes[max_class] + ". Saliency: " + title)
    plt.subplot(2, 2, 1)
    plt.title('input')
    plt.imshow(img_original)
    plt.subplot(2, 2, 2)
    plt.title('abs. saliency')
    plt.imshow(np.abs(saliency).max(axis=-1), cmap='gray')
    plt.subplot(2, 2, 3)
    plt.title('pos. saliency')
    x = (np.maximum(0, saliency) / saliency.max())[:,:,0]
    plt.imshow(x)
    plt.subplot(2, 2, 4)
    plt.title('neg. saliency')
    x = (np.maximum(0, -saliency) / -saliency.min())[:,:,0]
    plt.imshow(x)
    # plt.show()
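For reference, the three map variants computed in `show_images` can be written as a standalone NumPy helper. This is a sketch, not the Lasagne recipe: `saliency_variants` is a hypothetical name, it reduces channels with a max for all three maps (the code above takes channel 0 for the positive and negative maps), and it adds a small epsilon so a map with no positive (or no negative) entries yields zeros instead of a division by zero.

```python
import numpy as np

def saliency_variants(saliency):
    """Return (abs, pos, neg) maps for one H x W x C saliency array.

    Each map is reduced over channels with a max and scaled to [0, 1].
    The epsilon guards against dividing by zero when, e.g., the map has
    no positive entries at all.
    """
    eps = 1e-12
    abs_map = np.abs(saliency).max(axis=-1)
    pos_map = np.maximum(0, saliency).max(axis=-1)
    neg_map = np.maximum(0, -saliency).max(axis=-1)
    return tuple(m / (m.max() + eps) for m in (abs_map, pos_map, neg_map))

# Example on random data standing in for a real saliency map
rng = np.random.default_rng(0)
s = rng.normal(size=(4, 4, 3))
abs_map, pos_map, neg_map = saliency_variants(s)
```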

Now, I'm working on the second part (Guided Backpropagation) but it doesn't work. Here is the code:

import keras
import theano

class ModifiedBackprop(object):

    def __init__(self, nonlinearity):
        self.nonlinearity = nonlinearity
        self.ops = {}  # memoizes an OpFromGraph instance per tensor type

    def __call__(self, x):
        # OpFromGraph is opaque to Theano optimizations, so we need to move
        # things to GPU ourselves if needed.
        if theano.sandbox.cuda.cuda_enabled:
            maybe_to_gpu = theano.sandbox.cuda.as_cuda_ndarray_variable
        else:
            maybe_to_gpu = lambda x: x
        # We move the input to GPU if needed.
        x = maybe_to_gpu(x)
        # We note the tensor type of the input variable to the nonlinearity
        # (mainly dimensionality and dtype); we need to create a fitting Op.
        tensor_type = x.type
        # If we did not create a suitable Op yet, this is the time to do so.
        if tensor_type not in self.ops:
            # For the graph, we create an input variable of the correct type:
            inp = tensor_type()
            # We pass it through the nonlinearity (and move to GPU if needed).
            outp = maybe_to_gpu(self.nonlinearity(inp))
            # Then we fix the forward expression...
            op = theano.OpFromGraph([inp], [outp])
            # ...and replace the gradient with our own (defined in a subclass).
            op.grad = self.grad
            # Finally, we memoize the new Op
            self.ops[tensor_type] = op
        # And apply the memoized Op to the input we got.
        return self.ops[tensor_type](x)

class GuidedBackprop(ModifiedBackprop):
    def grad(self, inputs, out_grads):
        (inp,) = inputs
        (grd,) = out_grads
        dtype = inp.dtype
        # Keep the gradient only where the forward input AND the incoming
        # gradient are both positive.
        return (grd * (inp > 0).astype(dtype) * (grd > 0).astype(dtype),)

modded_relu = GuidedBackprop(keras.activations.relu)  # important: only instantiate this once!
for layer in model.layers:
    if 'activation' in layer.get_config() and layer.get_config()['activation'] == 'relu':
        layer.activation = modded_relu
        # layer.activation = theano.function([],[])

I've tested the code above in Theano and it works, so my guess is that something is wrong with the way I'm replacing the activation in the layers (last snippet). Any idea what I'm doing wrong?
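To make the difference concrete, here is a NumPy sketch (the helper `relu_backward` is hypothetical, not part of Theano or Keras) of the per-element rule that `GuidedBackprop.grad` implements, next to ordinary ReLU backprop and the "deconvnet" rule described in Springenberg et al. (arXiv:1412.6806). Guided backprop keeps a gradient entry only where both the forward input and the incoming gradient are positive.

```python
import numpy as np

def relu_backward(inp, grad_out, rule="vanilla"):
    """Gradient of a ReLU w.r.t. its input under three backward rules.

    inp:      the ReLU's forward input (pre-activation)
    grad_out: gradient arriving from the layer above
    """
    fwd_mask = (inp > 0).astype(grad_out.dtype)       # where the unit was active
    bwd_mask = (grad_out > 0).astype(grad_out.dtype)  # where the gradient is positive
    if rule == "vanilla":    # ordinary backprop
        return grad_out * fwd_mask
    if rule == "deconvnet":  # ignores the forward input
        return grad_out * bwd_mask
    if rule == "guided":     # both conditions, as in GuidedBackprop.grad
        return grad_out * fwd_mask * bwd_mask
    raise ValueError(f"unknown rule: {rule}")

inp = np.array([-1.0, 2.0, 3.0, -4.0])
g = np.array([0.5, -0.5, 1.0, 2.0])
vanilla = relu_backward(inp, g, "vanilla")      # keeps entries 1 and 2
deconv = relu_backward(inp, g, "deconvnet")     # keeps entries 0, 2 and 3
guided = relu_backward(inp, g, "guided")        # keeps only entry 2
```

If the guided map you get looks identical to the plain saliency map, the backward pass is effectively still using the `vanilla` rule.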

@mbz
Author

mbz commented Feb 25, 2016

hm, still doesn't work for me. I still see no difference.

By the way, I don't understand why taking relu from its original function would be different from getting it from the layer's activation.
In layers/core.py, the code sets both to the same thing:

def __init__(self, activation, **kwargs):
    super(Activation, self).__init__(**kwargs)
    self.activation = activations.get(activation)

On Thu, Feb 25, 2016 at 1:49 PM, francescopittaluga <notifications@github.com> wrote:

The following works. It's kind of hacky tho.

r = 1  # Set to the index of one of the layers in the Keras model with relu activation
relu = model.layers[r].activation
modded_relu = GuidedBackprop(relu)
relu_layers = [layer for layer in model.layers if getattr(layer, 'activation', None) is relu]
for layer in relu_layers:
    layer.activation = modded_relu
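The snippet selects layers by object identity rather than by config name, which works if every `activation='relu'` layer holds the very same function object (what `activations.get` returning a module-level function would imply; this is an assumption about the Keras version in use). A pure-Python sketch of the selection step, using hypothetical stand-ins (`Dense`, `Flatten`, `guided_relu`) instead of real Keras objects:

```python
# Hypothetical stand-ins for illustration; these are not Keras objects.
def relu(x):
    return max(0.0, x)

def guided_relu(x):
    # Same forward pass; in the real code only the gradient differs.
    return relu(x)

class Dense:
    def __init__(self, activation):
        self.activation = activation

class Flatten:
    pass  # has no `activation` attribute at all

layers = [Dense(relu), Flatten(), Dense(relu), Dense(None)]

# Select by identity; getattr's default handles layers without `activation`.
relu_layers = [l for l in layers if getattr(l, "activation", None) is relu]
for layer in relu_layers:
    layer.activation = guided_relu
```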



@FlorianImagia

Hi,

Thanks for sharing your code.
I've succeeded in making guided backprop work.
You should not compile your model BEFORE modifying the nonlinearities.
If, like me, you are loading a JSON file, you'll have to modify Keras to avoid compiling the model at load time.
And no need to compile the model after the modifications.
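A toy illustration of why the order matters, assuming (as with Theano-backed Keras of that time) that compiling builds the computation once and captures whatever `activation` is set at that moment. `ToyLayer` and `marked_relu` are hypothetical, not Keras classes:

```python
class ToyLayer:
    """Hypothetical layer that builds its computation once at compile
    time and reuses it afterwards, like a symbolic framework."""

    def __init__(self, activation):
        self.activation = activation
        self._fn = None

    def compile(self):
        act = self.activation        # captured now, not looked up later
        self._fn = lambda x: act(x)

    def __call__(self, x):
        return self._fn(x)

def relu(x):
    return max(0.0, x)

def marked_relu(x):
    # Distinguishable replacement so we can tell which function ran.
    return ("modified", relu(x))

layer = ToyLayer(relu)
layer.compile()
layer.activation = marked_relu      # swapped AFTER compiling
before_recompile = layer(-2.0)      # still runs the old relu
layer.compile()
after_recompile = layer(-2.0)       # now runs marked_relu
```

Under this model, any graph built after the swap picks up the new activation; whether an explicit second compile is needed depends on when your Keras version actually builds the graph.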

@Sandy4321
Copy link

Are there full code example
On Mar 24, 2016 16:44, "FlorianImagia" notifications@github.com wrote:

Hi,

Thanks for sharing your code.
I've succeeded in making guided backprop work.
You should not compile your model BEFORE modifying the nonlinearities.
If, like me, you are loading a JSON file, you'll have to modify Keras to avoid compiling the model at load time.
Don't forget to compile after the modifications.



@FlorianImagia

If you don't compile the model before the modifications, as I said, the code given by mbz is fairly complete.
You can fill in the missing parts from the Lasagne implementation (90% of the code is identical):
https://github.com/Lasagne/Recipes/blob/master/examples/Saliency%20Maps%20and%20Guided%20Backpropagation.ipynb

@michaelosthege

michaelosthege commented Aug 22, 2016

Hi,
I'm trying to make this work with both theano and tensorflow backends, but there are some problems with the code above.
I had to modify the compile_saliency_function to use the keras.backend and call the sum in a different way:

keras.backend.gradients(keras.backend.sum(max_outp), inp)

At the moment I can make saliency maps with default backpropagation independent of the backend.
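Why summing over the batch is fine: gradient functions need a scalar, and since each sample's maximum depends only on that sample, the gradient of the sum still yields one saliency row per sample. A NumPy check with a hypothetical linear model (`logits = x @ W`, random weights), where each sample's row should equal the weight column of its predicted class:

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_classes, batch = 5, 3, 4
W = rng.normal(size=(n_features, n_classes))  # hypothetical linear model
x = rng.normal(size=(batch, n_features))

def objective(x):
    # Same reduction as compile_saliency_function: max over classes, sum over batch
    return (x @ W).max(axis=1).sum()

# Analytic gradient: row b is the weight column of sample b's top class.
max_class = (x @ W).argmax(axis=1)
saliency = W[:, max_class].T                  # shape (batch, n_features)

# Central finite differences of the same scalar objective
eps = 1e-6
numeric = np.zeros_like(x)
for b in range(batch):
    for i in range(n_features):
        step = np.zeros_like(x)
        step[b, i] = eps
        numeric[b, i] = (objective(x + step) - objective(x - step)) / (2 * eps)
```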

For Guided Backprop the problem is the ModifiedBackprop class which heavily depends on theano APIs. (theano.sandbox.cuda.cuda_enabled, theano.sandbox.cuda.as_cuda_ndarray_variable, theano.OpFromGraph)
I couldn't find keras equivalents for these and my feeling is that this may have to be done in a completely different way.
In addition to that I only got it to work if I pass the GuidedBackprop instance as the activation parameter when I create the layers:

modifiedRelu = GuidedBackprop(keras.activations.relu)
model = keras.models.Sequential()
model.add(keras.layers.ZeroPadding2D((1,1),input_shape=(3,224,224)))
model.add(keras.layers.Convolution2D(64, 3, 3, activation=modifiedRelu))
model.add(keras.layers.ZeroPadding2D((1,1)))

(substituting them in an existing model did not actually substitute them.)

Does someone have an idea for backend-invariant implementation of ModifiedBackprop for substitution in existing models?

Related: tensorflow implementation using tf.gradient_override_map, which is a lot more elegant. (Does theano have something like tf.gradient_override_map?)

@ssierral
@michaelosthege Did you try the saliency example? I am using the same function and get the following error:

def compile_saliency_function(model):
    """
    Compiles a function to compute the saliency maps and predicted classes
    for a given minibatch of input images.
    """
    inp = model.layers[0].input
    outp = model.layers[-1].output
    max_outp = K.T.max(outp, axis=1)
    saliency = K.gradients(K.sum(max_outp), inp)
    max_class = K.T.argmax(outp, axis=1)
    return K.function([inp, K.learning_phase()], [saliency, max_class])

compile_saliency_function(new_model)([X_train[:20], 0])[0]

Output log:

DisconnectedInputError:  
Backtrace when that variable is created:
[...]
---> 10     saliency = K.gradients(K.sum(max_outp), inp)
[...]
  File "/home/ssierral/.conda/envs/sierraenv/lib/python2.7/site-packages/keras/models.py", line 197, in model_from_json
    return layer_from_config(config, custom_objects=custom_objects)
[...]

BTW, my model is composed of an embedding layer, a 1D CNN, a K-max pooling function and a softmax layer. Thank you.

@michaelosthege

@inulinux12 I reproduced the notebook from the lasagne link.
Currently I'm on vacation with barely enough internet to write this reply, but

  1. I don't recall encountering a DisconnectedInputError when I did this weeks ago
  2. you should take the derivative before the softmax layer (i.e. of the pre-softmax outputs), because if you take the gradient of the softmax output, you ask "show me what maximizes my target class and minimizes everything else" instead of "show me what maximizes my target class"

hope this helps
cheers
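Point 2 can be checked numerically. For a hypothetical linear model with logits `z = x @ W`, the gradient of the target logit is just `W[:, c]`, while the gradient of the target softmax probability is `W @ (p_c * (e_c - p))`, which subtracts the other classes' weights. A NumPy sketch:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(1)
W = rng.normal(size=(6, 4))   # hypothetical linear model: logits z = x @ W
x = rng.normal(size=6)
z = x @ W
c = int(z.argmax())           # target (predicted) class
p = softmax(z)

# Gradient of the target logit: only the target class's weights matter.
grad_logit = W[:, c]

# Gradient of the target probability: dp_c/dz = p_c * (e_c - p),
# so every other class enters with a negative weight.
dp_dz = p[c] * (np.eye(len(p))[c] - p)
grad_prob = W @ dp_dz
```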

@ssierral

ssierral commented Sep 28, 2016

It seems that using K-max pooling with 1D convolutions is breaking the model, and Theano is unable to compute the gradient. I have tried a basic MNIST CNN and the code works perfectly, but the 1D CNN for text does not.

EDIT:
It seems it was the embedding layer; I now have my code running perfectly.
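One plausible explanation (an assumption; the thread does not confirm it): an `Embedding` layer's input is integer token indices, and a table lookup is not differentiable with respect to those indices, so there is no gradient path back to the model input. A common workaround is to take the gradient with respect to the embedding output and reduce it per token. A NumPy sketch with a hypothetical embedding table `E` and linear head `W`:

```python
import numpy as np

rng = np.random.default_rng(2)
vocab, dim, seq_len, n_classes = 20, 8, 5, 3
E = rng.normal(size=(vocab, dim))                # hypothetical embedding table
W = rng.normal(size=(seq_len * dim, n_classes))  # hypothetical dense head
tokens = np.array([3, 7, 7, 1, 12])              # integer input: no gradient here

emb = E[tokens]                                  # (seq_len, dim): differentiable
logits = emb.reshape(-1) @ W
c = int(logits.argmax())

# For this linear head, d logit_c / d emb is the matching slice of W.
grad_emb = W[:, c].reshape(seq_len, dim)

# Collapse the embedding dimension to get one score per token.
token_saliency = np.linalg.norm(grad_emb, axis=1)  # gradient magnitude
token_grad_x_input = (grad_emb * emb).sum(axis=1)  # gradient x input variant
```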

@joeliven

joeliven commented Oct 6, 2016

Hi, quick question regarding this if any of you have a second... I am trying to implement compile_saliency_function (the basic version for now, not the guided one) from the code above,
but when I actually compile the Theano function my program just seems to stall (it has been compiling the function for 10+ minutes now).
Has anybody experienced a similar problem, or does anyone know why this might happen? Note that I am using the last layer of the network (softmax) and running on a CPU for now.
Thanks!!
Joel

@jf003320018
@michaelosthege Did you manage to implement guided backprop with the TensorFlow backend in Keras? If so, could you share your code? I want to implement it but have no idea how. Thank you.

@michaelosthege

@jf003320018 No, this did not work because of missing APIs (see my comment of Aug 22). I only managed with the theano backend. All essentials for that are already posted.

Not exactly guided backprop, but potentially relevant for understanding what your network does: DeepLIFT and Quiver.

@AvantiShri

^ FYI the DeepLIFT framework does implement Guided Backprop, and as of a few days ago there is a tensorflow branch. You can find an example of the different importance score methods implemented within DeepLIFT here: https://github.com/kundajelab/deeplift/blob/tensorflow/examples/public/tal_gata/TAL-GATA%20simulation.ipynb.

Heads up: DeepLIFT relies on defining a "reference" input which represents an input that lacks any interesting features. If you don't have a good sense of what the reference input is, I suggest exploring different methods. DeepLIFT has mostly been developed in the context of genomics where the notion of a reference is better-defined than for images; as the creator of DeepLIFT, I recommend using Guided Backprop if you are working on image-like data, at least until we figure out what a good reference for images is (all-zeros doesn't work well, as dark patches can be an interesting feature sometimes). Also, please let me know if you run into any issues. The tensorflow branch passes my unit tests but I have not stress-tested it yet.
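For intuition on the reference input: for a purely linear model `f(x) = w @ x + b`, reference-based contributions reduce to `w * (x - ref)` and sum exactly to `f(x) - f(ref)`. This is only the linear special case, not the DeepLIFT library itself; `w`, `b`, and the "dataset mean" reference below are made up. It shows why an all-zeros reference gives zero credit to zero-valued (e.g. dark) features:

```python
import numpy as np

w = np.array([2.0, -1.0, 0.5])   # hypothetical linear model weights
b = 0.3

def f(v):
    return w @ v + b

x = np.array([1.0, 0.0, 4.0])    # feature 1 is "dark" (zero-valued)
refs = {
    "zeros": np.zeros(3),
    "dataset_mean": np.array([0.5, 0.5, 0.5]),  # hypothetical average input
}

contributions = {}
for name, ref in refs.items():
    contrib = w * (x - ref)       # exact contributions for a linear model
    # Summation-to-delta: contributions add up to the change in output.
    assert np.isclose(contrib.sum(), f(x) - f(ref))
    contributions[name] = contrib
```

With the zeros reference the dark feature gets exactly zero credit; with the mean reference it gets 0.5.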

@AvantiShri

(Also feel free to shoot me an email if you are interested in support for layers other than the ones currently supported)

@jf003320018
@michaelosthege Sorry to disturb you again. Today I implemented this code in Keras using the Theano backend, as described in https://github.com/Lasagne/Recipes/blob/master/examples/Saliency%20Maps%20and%20Guided%20Backpropagation.ipynb. But I find that the results of guided backprop are the same as those of the saliency map, so it seems that we cannot just use ModifiedBackprop and GuidedBackprop from Lasagne.
Since you have reproduced the results in Keras using the Theano backend, I'd like to know how you did it. Thank you very much.

@johnny5550822
@michaelosthege @jf003320018 Regarding guided backpropagation: Michael, you mentioned that the necessary APIs are missing. Did you let the Keras authors know, so that someone can work on them?

FYI, guided backpropagation is very straightforward in Torch7 (like 5-10 lines). It seems difficult to make it work here.

@michaelosthege
@johnny550822 The missing APIs are in TensorFlow, so this is not a Keras issue to begin with. If you look closely, it is related to the architectural differences between the two frameworks.
Considering that there are alternatives such as DeepLIFT (or using the other backend), I don't think it is worth the struggle.

@johnny5550822

@michaelosthege :( oh...no. Can you briefly tell me what is missing? Because this may relate to other things that I want to try out with tensorflow+keras... (I am new to them)

@michaelosthege

michaelosthege commented Dec 20, 2016 via email

@johnny5550822

@michaelosthege thanks!

@Andrjusha
Did anyone succeed in implementing guided backpropagation in Keras, not just a basic saliency map defined by the plain gradient? There should be a large difference between the basic version and guided backpropagation, but in my implementation I can't get any difference between the two approaches.
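One way to sanity-check an implementation: on a small network the two maps should differ whenever some active hidden unit receives a negative gradient. A NumPy sketch with a hypothetical one-hidden-layer ReLU network (`W1`, `W2` are random, not a trained model) that computes both the plain input gradient and the guided one:

```python
import numpy as np

rng = np.random.default_rng(3)
W1 = rng.normal(size=(6, 16))   # hypothetical net: x -> ReLU(x @ W1) -> @ W2 -> logits
W2 = rng.normal(size=(16, 3))
x = rng.normal(size=6)

def input_gradient(x, c, guided=False):
    """Gradient of logit c w.r.t. x; guided=True applies the guided ReLU rule."""
    pre = x @ W1                     # hidden pre-activations
    g_h = W2[:, c]                   # d logit_c / d hidden activations
    mask = (pre > 0).astype(float)   # ordinary ReLU backward mask
    if guided:
        mask = mask * (g_h > 0)      # also drop negative incoming gradients
    return W1 @ (g_h * mask)

logits = np.maximum(x @ W1, 0) @ W2
c = int(logits.argmax())
vanilla = input_gradient(x, c)
guided = input_gradient(x, c, guided=True)
```

If the "guided" map returned by a Keras model is numerically identical to the plain gradient, a likely cause is that the modified activation never made it into the compiled graph.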

@vinayakumarr

saliency map

My code for computing the gradients used to visualize the saliency map is given below:

from keras import backend as K
import theano

def compile_saliency_function(model):
    """
    Compiles a function to compute the saliency maps and predicted classes
    for a given minibatch of input images.
    """
    inp = model.layers[0].input
    #print("-----------------------input-----------------------------")
    #print(inp)
    outp = model.layers[-1].output
    #print("-----------------------output----------------------------")
    #print(outp)
    max_outp = K.T.max(outp, axis=1)
    #print(max_outp)
    saliency = K.gradients(K.sum(max_outp), inp)
    #print(saliency.eval())
    max_class = K.T.argmax(outp, axis=1)
    print(max_class)
    v1 = K.function([inp, K.learning_phase()], [saliency, max_class])
    return v1

v = compile_saliency_function(model)([X_train[:1], 0])[0]
print(v)

I am getting a 421 vector (because I have 43 features), but I want to get 424 (where 4 is the number of classes). Could you please tell me where I should make changes?
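If the goal is one saliency vector per class, that is the full Jacobian of the outputs with respect to the input: compute one gradient per output unit instead of only for the summed maximum. With the Keras backend this would presumably mean one `K.gradients` call per class, e.g. on `K.sum(outp[:, k])`; that part is untested here. A NumPy sketch of the idea with a hypothetical one-hidden-layer ReLU network, using the 43 features and 4 classes from the question:

```python
import numpy as np

rng = np.random.default_rng(4)
n_features, n_hidden, n_classes = 43, 32, 4   # 43 features / 4 classes as in the question
W1 = rng.normal(size=(n_features, n_hidden))  # hypothetical weights
W2 = rng.normal(size=(n_hidden, n_classes))

def per_class_saliency(x):
    """Return a (n_classes, n_features) matrix: row k is d logit_k / d x."""
    pre = x @ W1
    mask = (pre > 0).astype(float)            # ReLU backward mask
    # One backward pass per class, like one gradient call per output unit
    return np.stack([W1 @ (W2[:, k] * mask) for k in range(n_classes)])

x = rng.normal(size=n_features)
J = per_class_saliency(x)
```

The per-class loop mirrors "one gradient call per class"; for this tiny network the same matrix could also be written in one step as `((W1 * mask) @ W2).T`.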

NotAndOr added a commit to NotAndOr/bookmarks that referenced this issue May 3, 2017
move keras-team/keras#1777 from 20161231.md to 20160521.md.
move https://github.com/Lasagne/Recipes/blob/master/examples/Saliency%20Maps%20and%20Guided%20Backpropagation.ipynb from 20161231.md to 20160521.md.
move https://github.com/kundajelab/deeplift/blob/tensorflow/examples/public/tal_gata/TAL-GATA%20simulation.ipynb from 20161231.md to 20160521.md.
move http://stackoverflow.com/questions/38340791/guided-back-propagation-in-tensorflow from 20161231.md to 20160521.md.
move https://arxiv.org/abs/1412.6806 from 20161231.md to 20160521.md.
move http://www.cs.toronto.edu/~guerzhoy/321/lec/W07/HowConvNetsSee.pdf from 20161231.md to 20160521.md.
move tensorflow/tensorflow#6422 from 20161231.md to 20160521.md.
move http://stackoverflow.com/questions/39793505/in-tensorflow-is-it-possible-to-use-different-learning-rate-for-different-part/39793644#39793644 from 20161231.md to 20160521.md.
move http://stackoverflow.com/questions/35298326/freeze-some-variables-scopes-in-tensorflow-stop-gradient-vs-passing-variables from 20161231.md to 20160521.md.
move https://arxiv.org/abs/1608.00530 from 20161231.md to 20160521.md.
move https://github.com/artvandelay/Deep_Inside_Convolutional_Networks from 20161231.md to 20160521.md.
move https://arxiv.org/abs/1611.05418 from 20161231.md to 20160521.md.
move https://arxiv.org/abs/1704.07911 from 20161231.md to 20160521.md.
move https://blogs.nvidia.com/blog/2017/04/27/how-nvidias-neural-net-makes-decisions from 20161231.md to 20160521.md.
move http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=7578651 from 20161231.md to 20160521.md.
move http://web.mit.edu/zoya/www/visual_attention_reading.pdf from 20161231.md to 20160521.md.
move https://github.com/kjw0612/awesome-deep-vision#visual-attention-and-saliency from 20161231.md to 20160521.md.
@experiencor

@Andrjusha I have ported the implementation of Guided Backprop from TensorFlow to Keras.

https://github.com/experiencor/deep-viz-keras

@stale stale bot added the stale label Sep 26, 2017
@stale

stale bot commented Sep 26, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

NotAndOr added a commit to NotAndOr/bookmarks that referenced this issue Feb 19, 2023