How to use apply with additional parameters? #723

Closed
Lando-L opened this issue Sep 5, 2023 · 1 comment

Lando-L commented Sep 5, 2023

I am having trouble understanding how the apply method works when the transformed function takes additional parameters, as it does when using the hk.nets.MLP module.

I tried to implement a simple CNN architecture: several convolutional blocks followed by a few fully connected layers. Describing the network architecture and initialising the parameters works as expected.

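import haiku as hk

class SimpleCNN(hk.Module):
    def __init__(self):
        super().__init__()  # Haiku modules must call the base constructor
        self.cnn_blocks = ...
        self.fc_layers = hk.nets.MLP([...])

    def __call__(self, inputs, dropout_rate, rng):
        # hk.nets.MLP takes dropout_rate and rng as optional call arguments
        return self.fc_layers(self.cnn_blocks(inputs), dropout_rate, rng)

def forward(inputs, dropout_rate, rng):
    return SimpleCNN()(inputs, dropout_rate, rng)

net = hk.transform(forward)
# No dropout during initialisation, so both extra arguments are None
params = net.init(sample_rng, sample, None, None)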

When it comes to applying the network, I get confused. As I understand it, the apply method takes params and rng, followed by the arguments the untransformed function expects. In the case of SimpleCNN, those would be inputs, dropout_rate, and rng. Do I have to pass in two rngs? If so, what is the purpose of the first rng?

Lando-L commented Sep 16, 2023

It looks like I missed Haiku's random API.

Lando-L closed this as completed Sep 16, 2023