Add automatic giant array chunking in msgpack checkpoints. #947
copybara-service[bot] merged 1 commit into google:master
Conversation
Codecov Report
@@ Coverage Diff @@
## master #947 +/- ##
==========================================
+ Coverage 80.93% 81.01% +0.08%
==========================================
Files 55 55
Lines 4421 4473 +52
==========================================
+ Hits 3578 3624 +46
- Misses 843 849 +6
Continue to review full report at Codecov.
jheek
left a comment
Looks good, just some minor comments.
flax/serialization.py
I'm always a bit worried about the serialisation logic operating on jax arrays. Should we just move things to numpy at the very start of serialisation?
sure, added a pass to do this first.
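The "move to numpy first" pass discussed above can be sketched roughly as follows. This is a hypothetical illustration, not the exact code in flax/serialization.py; the helper name and the way array leaves are detected are assumptions.

```python
import numpy as np

def np_convert_in_place(d):
    """Recursively replace array leaves of a nested dict with numpy arrays.

    Hypothetical sketch of a pre-serialization pass: converting jax arrays
    to numpy up front means the chunking/serialization logic below only
    ever sees numpy arrays. The `__array__` check is an assumption about
    how array-like leaves are detected.
    """
    for key, value in d.items():
        if isinstance(value, dict):
            np_convert_in_place(value)
        elif hasattr(value, "__array__"):  # jax and numpy arrays both qualify
            d[key] = np.asarray(value)
    return d
```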
flax/serialization.py
nit: this loop duplicates the tuple_to_dict and range logic. I think it's nicer to have:
chunks = [flatarr[i: i + chunksize] for i in range(0, flatarr.size, chunksize)]
data['chunks'] = _tuple_to_dict(chunks)
yeah, agreed - replaced it.
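The suggested chunking can be made reversible by also recording the original shape, roughly as below. This is a minimal sketch built around the list comprehension from the comment above; the function names, the dict layout, and the tiny chunk size are illustrative assumptions, not the actual flax implementation (which must keep each encoded buffer under msgpack's 2^32 - 1 byte limit).

```python
import numpy as np

def chunk_array(arr, chunk_size):
    """Split an array into flat chunks of at most chunk_size elements.

    Mirrors the suggested comprehension: slice the flattened array at
    chunk_size strides. The shape is stored so the pass is reversible.
    """
    flatarr = arr.reshape(-1)
    chunks = [flatarr[i: i + chunk_size]
              for i in range(0, flatarr.size, chunk_size)]
    return {'shape': arr.shape, 'chunks': chunks}

def unchunk_array(data):
    """Reassemble the original array by concatenating and reshaping."""
    return np.concatenate(data['chunks']).reshape(data['shape'])
```

A round trip with a toy chunk size of 4 splits a 10-element array into chunks of 4, 4, and 2 elements, and `unchunk_array` recovers the original exactly.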
msgpack can only support leaf encoded buffer sizes up to 2^32 - 1 bytes. Some giant embedding arrays exceed this, so this PR adds an automatic, reversible array-chunking pass to msgpack serialization. This PR does -not- break compatibility with existing checkpoints.