Test for svmbir interface hang on GPU device #344

Closed
bwohlberg opened this issue Sep 21, 2022 · 2 comments
Labels: bug (Something isn't working), tests (Pertaining to SCICO tests)

Comments

@bwohlberg
Collaborator

The test functions test_prox and test_prox_weights in scico/test/linop/test_radon_svmbir.py hang when run on a GPU. A temporary workaround using @pytest.mark.skipif will be implemented, but the cause of the problem needs to be determined and addressed.
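A minimal sketch of what such a `@pytest.mark.skipif` workaround could look like. The issue does not show the actual commit, so the GPU check via `jax.default_backend()`, the `ON_GPU` name, and the placeholder test bodies are assumptions made for illustration only.

```python
import jax
import pytest

# Assumption: "running on a GPU" is detected through JAX's default backend.
# The real workaround in the SCICO test suite may use a different check.
ON_GPU = jax.default_backend() == "gpu"


@pytest.mark.skipif(ON_GPU, reason="test hangs on GPU devices (issue #344)")
def test_prox():
    # Placeholder body: the real test exercises the svmbir prox.
    pass


@pytest.mark.skipif(ON_GPU, reason="test hangs on GPU devices (issue #344)")
def test_prox_weights():
    # Placeholder body: the real test exercises the weighted svmbir prox.
    pass
```

Because the condition is evaluated at collection time, the tests are reported as skipped on a GPU host rather than blocking the whole run.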

bwohlberg added the bug and tests labels Sep 21, 2022
bwohlberg pushed a commit that referenced this issue Sep 21, 2022
bwohlberg added a commit that referenced this issue Sep 21, 2022
* Ensure pip versions of jaxlib/jax installed

* Workaround for #344
@bwohlberg
Collaborator Author

bwohlberg commented Apr 27, 2023

The tests in question call scico.test.functional.prox_test, which calls scico.test.functional.prox_solve, which in turn calls scico.solver.minimize:

def prox_solve(v, v0, f, alpha):
    """Evaluate the alpha-scaled proximal operator of f at v, using v0 as an
    initial point for the optimization."""
    fnc = lambda x: prox_func(x, v, f, alpha)
    fmn = minimize(
        fnc,
        v0,
        method="Nelder-Mead",
        options={"maxiter": 1000, "xatol": 1e-9, "fatol": 1e-9},
    )
    return fmn.x.reshape(v.shape), fmn.fun
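For readers unfamiliar with this kind of test helper, the following is a CPU-only sketch of the same idea using plain NumPy and `scipy.optimize.minimize` rather than `scico.solver.minimize`. It assumes the standard proximal objective 0.5 * ||x - v||^2 + alpha * f(x); the body of `prox_func` is not shown in the issue, and the names `prox_objective`, `prox_numeric`, and `half_sq_l2` are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize


def prox_objective(x, v, f, alpha):
    # Assumed form of the proximal objective (not taken from the issue).
    return 0.5 * np.sum((x - v) ** 2) + alpha * f(x)


def prox_numeric(v, v0, f, alpha):
    """Numerically evaluate the alpha-scaled prox of f at v, starting from v0."""
    res = minimize(
        prox_objective,
        v0,
        args=(v, f, alpha),
        method="Nelder-Mead",
        options={"maxiter": 1000, "xatol": 1e-9, "fatol": 1e-9},
    )
    return res.x.reshape(v.shape), res.fun


def half_sq_l2(x):
    # f(x) = 0.5 * ||x||^2, whose prox has the closed form v / (1 + alpha).
    return 0.5 * np.sum(x ** 2)


v = np.array([1.5, -0.2])
alpha = 0.5
x_num, _ = prox_numeric(v, v.copy(), half_sq_l2, alpha)
x_ref = v / (1.0 + alpha)  # closed-form prox used as the reference
```

Comparing `x_num` against `x_ref` is the kind of check a prox test performs: a derivative-free solver gives an independent estimate of the minimizer that an analytic prox implementation should match.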

When run on a GPU, the code appears to hang at the jax.experimental.host_callback call in scico.solver.minimize:

scico/scico/solver.py

Lines 243 to 248 in 090e607

    # HCB call with side effects to get the OptimizeResult on the same device it was called
    res.x = hcb.call(
        fun,
        arg=x0,
        result_shape=x0,  # From Jax-docs: This can be an object that has .shape and .dtype attributes
    )
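The issue does not say how the hang was traced to this call. One generic way to locate a hang in Python code is the standard-library `faulthandler` module, which can dump the tracebacks of all threads if a call has not returned after a timeout. The helper name `run_with_hang_dump` below is hypothetical and is not part of SCICO.

```python
import faulthandler
import sys


def run_with_hang_dump(fn, *args, timeout=60.0, **kwargs):
    """Call fn; if it has not returned after `timeout` seconds, dump the
    tracebacks of all threads to stderr (without terminating the process)."""
    faulthandler.dump_traceback_later(timeout, exit=False, file=sys.__stderr__)
    try:
        return fn(*args, **kwargs)
    finally:
        # Cancel the pending dump once the call returns or raises.
        faulthandler.cancel_dump_traceback_later()
```

Wrapping a suspect call, for example `run_with_hang_dump(prox_solve, v, v0, f, alpha, timeout=30.0)`, prints where each thread is blocked once the timeout elapses, which helps narrow a hang down to a specific line.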

@bwohlberg bwohlberg self-assigned this May 8, 2023
@bwohlberg
Collaborator Author

Closing since the problem seems to be within the jax host callback implementation.
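For context, `jax.experimental.host_callback` has since been deprecated and removed from JAX, and `jax.pure_callback` is the documented replacement for running host-side Python from JAX code. The sketch below shows the general pattern of calling a SciPy optimizer on the host and returning the result as a JAX array. It is an illustration only, not the change made in SCICO, and nothing in this issue confirms that it avoids the hang; `host_minimize` and `objective` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize


def objective(x):
    # Hypothetical objective with its minimum at x = 1 in every coordinate.
    return np.sum((x - 1.0) ** 2)


def host_minimize(x0):
    """Run a SciPy Nelder-Mead solve on the host and return a JAX array."""

    def fun(x0_host):
        x0_host = np.asarray(x0_host)
        res = minimize(
            objective,
            x0_host.ravel().astype(np.float64),
            method="Nelder-Mead",
        )
        # The callback must return the shape and dtype declared below.
        return res.x.reshape(x0_host.shape).astype(x0_host.dtype)

    result_spec = jax.ShapeDtypeStruct(x0.shape, x0.dtype)
    return jax.pure_callback(fun, result_spec, x0)


x = host_minimize(jnp.zeros(2))
```

As with the original `hcb.call`, the declared result shape and dtype mirror those of `x0`, so the solution comes back as a JAX array matching the initial point.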
