-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make output types jax-array #187
Conversation
Codecov Report
@@ Coverage Diff @@
## main #187 +/- ##
==========================================
+ Coverage 92.21% 92.24% +0.02%
==========================================
Files 48 48
Lines 3353 3353
==========================================
+ Hits 3092 3093 +1
+ Misses 261 260 -1
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
scico/functional/_denoiser.py
Outdated
@@ -110,7 +111,7 @@ def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: | |||
# undo squeezing, if neccessary | |||
y = y.reshape(x_in_shape) | |||
|
|||
return y | |||
return jax.device_put(y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type here was a numpy ndarray? I would have expected hcb.call
to keep things within the jax world.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right, the device_put wasn't necessary here. Fixed it.
Fixes output types of svmbir and bm3d to be always jaxarrays. Fixes issue #93