Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace dm-tree with optree #19306

Merged
merged 9 commits into from Mar 15, 2024
Merged

Replace dm-tree with optree #19306

merged 9 commits into from Mar 15, 2024

Conversation

james77777778
Copy link
Contributor

@james77777778 james77777778 commented Mar 14, 2024

Related to #18442
Related to #18614

This PR refactors keras.utils.tree to use optree instead of dm-tree

  • Exports APIs (is_nested, flatten, unflatten_as, map_structure, map_structure_up_to, assert_same_structure, pack_sequence_as, lists_to_tuples) with the path keras.utils.tree.*
  • Add docstrings (mostly borrowed from dm-tree and tf.nest)
  • Adds unit tests
  • Eliminates tf.nest in the codebase (excluding legacy code)

I have verified that exported APIs should meet the requirements of keras_cv and keras_nlp

@codecov-commenter
Copy link

codecov-commenter commented Mar 14, 2024

Codecov Report

Attention: Patch coverage is 88.60104% with 22 lines in your changes are missing coverage. Please review.

Project coverage is 75.75%. Comparing base (c8700f4) to head (06648f8).
Report is 102 commits behind head on master.

Files Patch % Lines
keras/utils/tree.py 89.13% 11 Missing and 4 partials ⚠️
keras/backend/tensorflow/layer.py 20.00% 4 Missing ⚠️
keras/utils/tracking.py 85.71% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19306      +/-   ##
==========================================
- Coverage   80.14%   75.75%   -4.39%     
==========================================
  Files         341      366      +25     
  Lines       36163    40208    +4045     
  Branches     7116     7811     +695     
==========================================
+ Hits        28982    30460    +1478     
- Misses       5578     8062    +2484     
- Partials     1603     1686      +83     
Flag Coverage Δ
keras 75.60% <88.08%> (-4.39%) ⬇️
keras-jax 59.98% <76.68%> (-3.08%) ⬇️
keras-numpy 54.56% <73.05%> (-2.53%) ⬇️
keras-tensorflow 61.47% <85.49%> (-3.18%) ⬇️
keras-torch 60.57% <73.57%> (-3.30%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@fchollet
Copy link
Member

Thanks for the PR. Do you observe performance difference? How long does it take to run the unit test suite with the PyTorch backend before and after the change, for instance?

@james77777778
Copy link
Contributor Author

james77777778 commented Mar 14, 2024

Thanks for the PR. Do you observe performance difference? How long does it take to run the unit test suite with the PyTorch backend before and after the change, for instance?

I didn't observe a performance difference in the unit tests.
Therefore, I benchmarked the performance using the method in #18569

Env:

  • torch==2.2.1+cu121
backend cpu/cuda jit_compile master branch optree
torch cpu False 14ms 12ms
True 17ms 15ms
cuda False 2~3ms 3ms
True 5ms 5ms

There is no difference observed when using cuda, but there is a slight improvement when using cpu
However, even with jit_compile=True, it still runs slower compared to when jit_compile=False

Logs:

[2024-03-14 13:41:05,140] torch._dynamo.convert_frame: [WARNING]    function: 'resume_in___call__' (/home/hongyu/workspace/keras/keras/layers/layer.py:695)
[2024-03-14 13:41:05,140] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:05,140] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:05,140] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:05,233] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:05,233] torch._dynamo.convert_frame: [WARNING]    function: '__call__' (/home/hongyu/workspace/keras/keras/ops/operation.py:31)
[2024-03-14 13:41:05,233] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:05,233] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:05,233] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING]    function: 'fill_in' (/home/hongyu/workspace/keras/keras/ops/symbolic_arguments.py:31)
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING]    function: 'call' (/home/hongyu/workspace/keras/keras/models/functional.py:571)
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:05,778] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:05,779] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:05,779] torch._dynamo.convert_frame: [WARNING]    function: '__call__' (/home/hongyu/workspace/keras/keras/layers/layer.py:692)
[2024-03-14 13:41:05,779] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:05,779] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:05,779] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:05,780] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:05,780] torch._dynamo.convert_frame: [WARNING]    function: '_setattr_hook' (/home/hongyu/workspace/keras/keras/backend/torch/layer.py:28)
[2024-03-14 13:41:05,780] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:05,780] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:05,780] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:06,635] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:06,635] torch._dynamo.convert_frame: [WARNING]    function: 'maybe_convert' (/home/hongyu/workspace/keras/keras/layers/layer.py:699)
[2024-03-14 13:41:06,635] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:06,635] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:06,635] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2024-03-14 13:41:07,629] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2024-03-14 13:41:07,629] torch._dynamo.convert_frame: [WARNING]    function: '_get_own_losses' (/home/hongyu/workspace/keras/keras/layers/layer.py:1061)
[2024-03-14 13:41:07,629] torch._dynamo.convert_frame: [WARNING]    last reason: ___check_global_state()
[2024-03-14 13:41:07,629] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2024-03-14 13:41:07,629] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.

It seems that there are still some frames that cannot be converted successfully

keras/utils/tree.py Outdated Show resolved Hide resolved
keras/utils/tracking.py Show resolved Hide resolved
keras/ops/core_test.py Outdated Show resolved Hide resolved
keras/utils/tree.py Outdated Show resolved Hide resolved
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for the contribution!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Mar 15, 2024
@fchollet fchollet merged commit e2b43e2 into keras-team:master Mar 15, 2024
6 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Mar 15, 2024
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Mar 18, 2024
@james77777778 james77777778 deleted the optree branch March 20, 2024 03:15
@ngam ngam mentioned this pull request Mar 21, 2024
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
PR Queue
Assigned Reviewer
Development

Successfully merging this pull request may close these issues.

None yet

4 participants