
Regularization Not taken into account when composing Models #5318

Closed
2 tasks done
unrealwill opened this issue Feb 7, 2017 · 1 comment

Comments


unrealwill commented Feb 7, 2017

import numpy as np
from keras.layers import Input, Dense, ActivityRegularization
from keras.models import Model

def BuggyReg():
    inp = Input(batch_shape=(None, 10))
    out = ActivityRegularization(100000.0, 100000.0)(Dense(100)(inp))
    model = Model([inp], [out])

    # This works fine and takes regularization into account:
    # model.compile("adam", "mse")
    # return model

    inp = Input(batch_shape=(None, 10))
    out2 = model(inp)
    model2 = Model([inp], [out2])

    # This doesn't work correctly: regularization is not taken into account.
    model2.compile("adam", "mse")
    return model2


m = BuggyReg()
keras/engine/topology.py:379: UserWarning: The `regularizers` property of layers/models is deprecated. Regularization losses are now managed via the `losses` layer/model property.
  warnings.warn('The `regularizers` property of layers/models '

In [4]: m.fit(np.random.randn(10000, 10), np.ones((10000, 100)))
10000/10000 [==============================] - 0s - loss: 0.7956   # should be on the order of 1000000
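For reference, the discrepancy shows up directly in the `losses` property that the deprecation warning points to. A minimal diagnostic sketch, assuming the Keras API of this era (variable names are illustrative):

from keras.layers import Input, Dense, ActivityRegularization
from keras.models import Model

# Inner model with an activity regularizer.
inp = Input(batch_shape=(None, 10))
out = ActivityRegularization(100000.0, 100000.0)(Dense(100)(inp))
inner = Model([inp], [out])

# Outer model that simply wraps the inner one.
inp2 = Input(batch_shape=(None, 10))
outer = Model([inp2], [inner(inp2)])

# The inner model registers the regularization loss,
# but the composed model does not pick it up:
print(len(inner.losses))  # expected: 1
print(len(outer.losses))  # observed: 0 -- the loss is dropped on composition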

A quick search for the warning (which just appeared) points to the earlier release note:
#4703
The refactoring of regularization is quite recent, and I couldn't find any more information on how to obtain the correct behaviour.

I need to be able to compose models and, at the same time, apply regularization.
This bug was quite tricky to catch: some regularization layers like BatchNormalization work fine, and the inner model works fine on its own, but the combined model is insensitive to regularization.

Can you please advise?

Thank you!

  • Check that you are up-to-date with the master branch of Keras. You can update with:
    pip install git+git://github.com/fchollet/keras.git --upgrade --no-deps

  • If running on Theano, check that you are up-to-date with the master branch of Theano. You can update with:
    pip install git+git://github.com/Theano/Theano.git --upgrade --no-deps


unrealwill (Author) commented Feb 10, 2017

Answering my own question. There is probably a bug in the Keras `Layer` class to investigate and correct. (Can someone make the correction, @fchollet, and close the issue? Thanks.)
Using my own fixed ActivityRegularization2 instead of ActivityRegularization works as intended.

from keras import regularizers
from keras.engine.topology import Layer


class ActivityRegularization2(Layer):
    """Layer that applies an update to the cost function based on input activity.
    # Arguments
        l1: L1 regularization factor (positive float).
        l2: L2 regularization factor (positive float).
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    """

    def __init__(self, l1=0., l2=0., **kwargs):
        self.supports_masking = True
        self.l1 = l1
        self.l2 = l2

        super(ActivityRegularization2, self).__init__(**kwargs)
        self.activity_regularizer = regularizers.L1L2Regularizer(l1=l1, l2=l2)
        # We drop the now-obsolete `regularizers` attribute:
        # self.regularizers = [self.activity_regularizer]

    # We call add_loss inside call, so the loss is registered on this
    # layer and is collected even when the model is nested in another model.
    def call(self, x, mask=None):
        self.add_loss(self.activity_regularizer(x), x)
        return x

    def get_config(self):
        config = {'l1': self.l1,
                  'l2': self.l2}
        base_config = super(ActivityRegularization2, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
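With the fixed layer, the original reproduction behaves as intended. A minimal usage sketch, under the same assumptions as above (the huge factors make the regularization term dominate the loss):

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

inp = Input(batch_shape=(None, 10))
out = ActivityRegularization2(100000.0, 100000.0)(Dense(100)(inp))
inner = Model([inp], [out])

inp2 = Input(batch_shape=(None, 10))
model2 = Model([inp2], [inner(inp2)])
model2.compile("adam", "mse")

# The composed model now carries the regularization term, so the
# training loss comes out on the order of 1e6 rather than ~0.8.
model2.fit(np.random.randn(10000, 10), np.ones((10000, 100)))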
