Skip to content

Add automatic giant array chunking in msgpack checkpoints.#947

Merged
copybara-service[bot] merged 1 commit intogoogle:masterfrom
levskaya:msgpackfix
Jan 28, 2021
Merged

Add automatic giant array chunking in msgpack checkpoints.#947
copybara-service[bot] merged 1 commit intogoogle:masterfrom
levskaya:msgpackfix

Conversation

@levskaya
Copy link
Collaborator

msgpack can only support total leaf encoded buffers sizes of max length 2^32-1 Some giant embedding arrays exceed this, so add an automatic reversible array chunking pass to msgpack serialization.

This PR should -not- break compatibility with existing checkpoints.

@codecov-io
Copy link

codecov-io commented Jan 27, 2021

Codecov Report

Merging #947 (4ecb68a) into master (836946a) will increase coverage by 0.08%.
The diff coverage is 88.88%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #947      +/-   ##
==========================================
+ Coverage   80.93%   81.01%   +0.08%     
==========================================
  Files          55       55              
  Lines        4421     4473      +52     
==========================================
+ Hits         3578     3624      +46     
- Misses        843      849       +6     
Impacted Files Coverage Δ
flax/serialization.py 86.33% <88.88%> (+1.01%) ⬆️
flax/linen/module.py 95.01% <0.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 836946a...4ecb68a. Read the comment docs.

Copy link
Contributor

@jheek jheek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good just some minor comments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm always a bit worried about the serialisation logic operating on jax arrays. Should we just move things to numpy at the very start of serialisation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, added a pass to do this first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this loop duplicates the tuple_to_dict and range logic. I think it's nicer to have:

chunks = [flatarr[i: i + chunksize] for i in range(0, flatarr.size, chunksize)]
data['chunks'] = _tuple_to_dict(chunks)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, agreed - replaced it.

msgpack can only support total leaf encoded buffers sizes
of max length 2^32-1  Some giant embedding arrays exceed
this, so add an automatic reversible array chunking pass
to msgpack serialization.

This PR does -not- break compatibility with existing
checkpoints.
@copybara-service copybara-service bot merged commit 09a7c91 into google:master Jan 28, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants