
Why can we use the parameter "input_shape = (2,)", which is undefined in __init__? #57

Open
neilteng opened this issue Jun 12, 2020 · 1 comment

Comments


neilteng commented Jun 12, 2020

In 4-3 — this is not a bug, just a question I don't understand. Why can we call model.add(Linear(units = 1, input_shape = (2,))) when input_shape is not a parameter of the __init__ method?

import tensorflow as tf
from tensorflow.keras import layers, models

class Linear(layers.Layer):
    def __init__(self, units=32, **kwargs):
#         super(Linear, self).__init__(**kwargs)
        super().__init__(**kwargs)
        self.units = units
    
    # The trainable parameters are created in the build method.
    # Since input_shape is only needed inside build,
    # we do not need to store it in __init__.
    def build(self, input_shape): 
        self.w = self.add_weight("w",shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True) # The name argument ("w") is required here, or an error will be raised
        self.b = self.add_weight("b",shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)
        super().build(input_shape) # Identical to self.built = True

    # The logic of forward propagation is defined in call method, and is called by __call__ method
    @tf.function
    def call(self, inputs): 
        return tf.matmul(inputs, self.w) + self.b
    
    # Define a customized get_config method so the model can be saved in h5 format;
    # this is needed when a model built with the Functional API contains a custom Layer.
    def get_config(self):  
        config = super().get_config()
        config.update({'units': self.units})
        return config

tf.keras.backend.clear_session()

model = models.Sequential()
# Note: Keras prepends the batch dimension to input_shape itself, so we don't
# have to fill in None for the dimension representing the number of samples.
model.add(Linear(units = 1,input_shape = (2,)))  
print("model.input_shape: ",model.input_shape)
print("model.output_shape: ",model.output_shape)
model.summary()
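The deferred-build mechanics above can be mimicked in plain Python without TensorFlow. This is a minimal sketch under stated assumptions: LazyLinear and its zero-initialized weights are illustrative stand-ins, not the Keras implementation.

```python
class LazyLinear:
    """Toy layer that creates its weights on first call, like Keras build()."""
    def __init__(self, units=1, **kwargs):
        self.units = units
        # input_shape, if given, is absorbed by the "base" machinery;
        # weight creation still waits until build() sees real data.
        self.input_shape = kwargs.pop("input_shape", None)
        self.built = False

    def build(self, input_dim):
        # Weight shapes depend on the last input dimension,
        # exactly like the TF version's (input_shape[-1], units).
        self.w = [[0.0] * self.units for _ in range(input_dim)]
        self.b = [0.0] * self.units
        self.built = True

    def __call__(self, inputs):
        if not self.built:
            self.build(len(inputs[0]))  # lazy build on first call
        # y = x @ w + b, written out with plain loops
        return [
            [sum(x[i] * self.w[i][j] for i in range(len(x))) + self.b[j]
             for j in range(self.units)]
            for x in inputs
        ]

layer = LazyLinear(units=1, input_shape=(2,))
out = layer([[1.0, 2.0]])  # build() runs here, with input_dim=2
```

The point is that __init__ never needs input_shape for its own logic; the shape information only matters once build runs against actual inputs.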
@lsl1229840757

"input_shape" is the attribute of it's superclass, which is defined in init method of layers.Layer.

when we call the init method of Linear, "input_shape = (2,)" is regarded as an item of "**kwargs" which is keyword args
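This forwarding can be seen in a minimal, framework-free sketch (Base here stands in for layers.Layer; the class names are illustrative):

```python
class Base:
    def __init__(self, input_shape=None):
        # The base class, like layers.Layer, knows about input_shape.
        self.input_shape = input_shape

class Child(Base):
    def __init__(self, units=32, **kwargs):
        # input_shape=(2,) is not in this signature, so it lands in **kwargs
        # and is passed through to the base class untouched.
        super().__init__(**kwargs)
        self.units = units

c = Child(units=1, input_shape=(2,))
# c.input_shape is (2,), set by Base, even though Child never names it.
```

This is why forgetting super().__init__(**kwargs) in a custom layer would break the input_shape mechanism: the keyword would never reach the base class.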
