Closure scoping for nested objax.Functions #237

Closed
mathDR opened this issue Mar 16, 2022 · 2 comments

mathDR commented Mar 16, 2022

Hi, I am trying to optimize a list of Gaussian Process models: I create an optimizer for each "model" in the list, then run the optimizers in a loop.

This is not working, so I put together a smaller example that, if I can figure out how to make it work, would show much of what is wrong with my code.

I want to extend the objax.Function() example by doing something like:

import objax
import jax.numpy as jn

models = []
for i in range(2):
    models.append(objax.nn.Linear(2, 3))

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

new_funcs = []
for m in models:
    new_funcs.append(objax.Function(f1,m.vars()))

I know this example doesn't work (for a lot of reasons), but I am trying to understand how to make it work. That is: how can I apply models[i].vars() to f1 inside the new_funcs loop, so that when I run new_funcs[0](x,y) and new_funcs[1](x,y), I get different values?

Because of Python's closure scoping, I think every function in new_funcs ends up using the last model in models, right?
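In plain Python terms that suspicion is right: `f1` looks up the free variable `m` when it is called, not when it is defined, so after the loop every wrapper evaluates the last model. A minimal sketch of the effect, using a hypothetical `Scale` class as a stand-in for `objax.nn.Linear` so that no objax is needed:

```python
class Scale:
    """Hypothetical stand-in for a model: multiplies its input by k."""

    def __init__(self, k):
        self.k = k

    def __call__(self, x):
        return self.k * x


models = [Scale(1.0), Scale(10.0)]

late = []
for m in models:
    # `m` is not stored in the lambda; it is looked up each time the
    # lambda runs, and after the loop it refers to models[-1].
    late.append(lambda x: m(x))

print([f(2.0) for f in late])  # [20.0, 20.0]: both use the last model
```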

@AlexeyKurakin AlexeyKurakin self-assigned this Mar 16, 2022
@AlexeyKurakin
Member

Would something like the following work?

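import objax

models = []
for i in range(2):
    models.append(objax.nn.Linear(2, 3))

def loss(x, y, m):
    return ((m(x) - y) ** 2).mean()

new_funcs = []
for m in models:
    # The default argument `model=m` is evaluated when the lambda is created,
    # so each function keeps its own model instead of the last one.
    new_funcs.append(objax.Function(lambda x, y, model=m: loss(x, y, model), m.vars()))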
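A related option, not mentioned in the thread, is `functools.partial` from the standard library, which also stores the model when each function is created. The sketch below uses a hypothetical `Scale` stand-in and a scalar squared error instead of objax; passing `functools.partial(loss, m=m)` together with `m.vars()` to `objax.Function` should work the same way since it only needs a callable, but that combination is an untested assumption here.

```python
import functools


class Scale:
    """Hypothetical stand-in for a model: multiplies its input by k."""

    def __init__(self, k):
        self.k = k

    def __call__(self, x):
        return self.k * x


def loss(x, y, m):
    # Squared error of model m, on scalars instead of arrays.
    return (m(x) - y) ** 2


models = [Scale(1.0), Scale(10.0)]

# partial() keeps a reference to the current `m` inside the partial object,
# so each entry is tied to a different model.
new_funcs = [functools.partial(loss, m=m) for m in models]

print([f(2.0, 0.0) for f in new_funcs])  # [4.0, 400.0]
```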

Another alternative:

import objax

models = []
for i in range(2):
    models.append(objax.nn.Linear(2, 3))

new_funcs = []
for m in models:
    # `model=m` binds the current model when `loss` is defined.
    def loss(x, y, model=m):
        return ((model(x) - y) ** 2).mean()
    new_funcs.append(objax.Function(loss, m.vars()))
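Another common Python pattern, not from the thread, is a small factory function: each call creates a fresh scope, so the returned closure captures that call's `model`. In this sketch `make_loss` and `Scale` are hypothetical names; in the objax setting one would presumably write `objax.Function(make_loss(m), m.vars())`.

```python
class Scale:
    """Hypothetical stand-in for a model: multiplies its input by k."""

    def __init__(self, k):
        self.k = k

    def __call__(self, x):
        return self.k * x


def make_loss(model):
    # `model` is a local of this particular call; the inner function closes
    # over it, and every call to make_loss gets its own binding.
    def loss(x, y):
        return (model(x) - y) ** 2

    return loss


models = [Scale(1.0), Scale(10.0)]
new_funcs = [make_loss(m) for m in models]

print([f(2.0, 0.0) for f in new_funcs])  # [4.0, 400.0]
```

Unlike the default-argument version, the model here cannot be replaced by accidentally passing an extra positional argument to the returned function.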

See also this example of how to make Python loops work with closures:

@mathDR
Author

mathDR commented Mar 16, 2022

This is perfect. Thanks!

@mathDR mathDR closed this as completed Mar 16, 2022