Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make output types jax-array #187

Merged
merged 4 commits into from
Jan 25, 2022
Merged

Make output types jax-array #187

merged 4 commits into from
Jan 25, 2022

Conversation

smajee
Copy link
Contributor

@smajee smajee commented Jan 24, 2022

Fixes output types of svmbir and bm3d to be always jaxarrays. Fixes issue #93

@smajee smajee requested a review from bwohlberg January 24, 2022 21:32
@codecov
Copy link

codecov bot commented Jan 24, 2022

Codecov Report

Merging #187 (504225a) into main (6f7d3a4) will increase coverage by 0.02%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #187      +/-   ##
==========================================
+ Coverage   92.21%   92.24%   +0.02%     
==========================================
  Files          48       48              
  Lines        3353     3353              
==========================================
+ Hits         3092     3093       +1     
+ Misses        261      260       -1     
Flag Coverage Δ
unittests 92.24% <100.00%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
scico/linop/radon_svmbir.py 88.75% <100.00%> (ø)
scico/numpy/_util.py 97.05% <0.00%> (+2.94%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 6f7d3a4...504225a. Read the comment docs.

@bwohlberg bwohlberg added the improvement Improvement of existing code, including addressing of omissions or inconsistencies label Jan 25, 2022
@@ -110,7 +111,7 @@ def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
# undo squeezing, if neccessary
y = y.reshape(x_in_shape)

return y
return jax.device_put(y)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type here was a numpy ndarray? I would have expected hcb.call to keep things within the jax world.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, the device_put wasn't necessary here. Fixed it.

@bwohlberg bwohlberg linked an issue Jan 25, 2022 that may be closed by this pull request
@smajee smajee merged commit d2b3e46 into main Jan 25, 2022
@smajee smajee deleted the smajee/fix_output_dtype branch January 25, 2022 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
improvement Improvement of existing code, including addressing of omissions or inconsistencies
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Output types of operators
2 participants