
Is there a good way to save/load & compress/decompress model weights? #18

Closed
chris-chris opened this issue Mar 1, 2020 · 10 comments

@chris-chris
Contributor

Hey, this is Chris.
I'm using this open-source project for my own work.

https://github.com/chris-chris/haiku-scalable-example

Since I'm new to JAX and haiku, I have some questions.

Is there a good way to save/load & compress/decompress & serialize model weights?

  • save/load model (network only or weight only)
  • compress/decompress weights
  • serialize

I think serialization is an important issue on scalability. Can you give me some keywords or hints about this issue?

Thanks!

@trevorcai
Contributor

Hey, we're intentionally un-opinionated here. I will note:

  1. Haiku params (and network state) are transparent dictionaries of JAX jnp.ndarrays.
  2. jnp.ndarray converts to np.ndarray, so when using non-bfloat16 dtypes, anything that can save NumPy arrays will work here.
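A quick sanity check of point 2 (the array contents here are just an illustrative example):

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)   # a JAX array on the default device
y = np.asarray(x)     # pulled to host as a plain np.ndarray

assert isinstance(y, np.ndarray)
assert np.array_equal(y, np.array([0.0, 1.0, 2.0, 3.0]))
```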

There are a few options we've seen work well:

  • Directly pickle the params dict. Upside: it just works; downside: it may not be space-efficient, and it comes with the usual pickle caveats.
  • Use np.save or np.savez to store the ndarrays in a flat format, and save the tree structure via either pickle or a stable serialized format (protobuf, JSON, YAML, you name it).
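For illustration, the first option can be sketched like this (the `params` dict below is a made-up stand-in for real Haiku params, which are nested mappings of arrays):

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in for Haiku params: module name -> parameter name -> array.
params = {"linear": {"w": np.ones((3, 2)), "b": np.zeros(2)}}

ckpt = os.path.join(tempfile.mkdtemp(), "params.pkl")
with open(ckpt, "wb") as f:
    pickle.dump(params, f)

with open(ckpt, "rb") as f:
    restored = pickle.load(f)

assert np.array_equal(restored["linear"]["w"], params["linear"]["w"])
```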

I'll look into extending either the Transformer or ResNet example with checkpointing, so we have a concrete piece of code that we can point people to as an example.

I'll leave this issue open until that lands - hope this helps!

@chris-chris
Contributor Author

Thanks for the help! @trevorcai

Your advice helped me a lot!

I'm planning to try serialization via protobuf over gRPC communication,
and for checkpointing, I'll wait for your examples :)

@trevorcai trevorcai self-assigned this Mar 2, 2020
@chris-chris
Contributor Author

chris-chris commented Mar 4, 2020

https://github.com/chris-chris/haiku-scalable-example/pull/1/files

Thanks! @trevorcai

I made an encoder & decoder for the Haiku model weights and trajectories as gRPC protobuf messages.
I noticed that you use frozendict as the data structure for model weights,
and there was a comment on this data type.

Is this data type going to be deprecated?

# TODO(lenamartens) Deprecate type

https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/data_structures.py#L80

@trevorcai
Contributor

Hey, nice job! That's correct, we'd like to replace it with the FlatMapping class below it.
The bit is ready to be flipped, we'll look to flip it soon if we can.

@trevorcai
Contributor

Quick update - it turns out the bit is not ready to be flipped; there are a couple of edge cases that need to be fixed first. We don't have the time to look into this for now, so don't expect it to flip in the near future.

@asmith26

asmith26 commented Oct 3, 2020

I've been using https://github.com/cloudpipe/cloudpickle/ which seems to be working well.

@NightMachinery

What library is recommended for directly serializing the params dict? What are the caveats? I think adding these to the docs would be a nice addition, or at least links to other good docs on serialization in Python.

@trevorcai
Contributor

If you use HAIKU_FLATMAPPING=0, then Haiku checkpointing is as simple as serializing dicts of np.ndarrays; any solution that works for that will work for Haiku.
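As a sketch of what "any solution that works for that" can look like, here is np.savez with flattened "module/name" keys (the `params` dict below is a made-up stand-in for real Haiku params):

```python
import os
import tempfile

import numpy as np

# Made-up stand-in for Haiku params: module name -> parameter name -> array.
params = {"linear": {"w": np.ones((3, 2)), "b": np.zeros(2)}}

# Flatten to "module/name" keys so np.savez can store everything in one file.
flat = {f"{mod}/{name}": arr
        for mod, sub in params.items()
        for name, arr in sub.items()}

path = os.path.join(tempfile.mkdtemp(), "params.npz")
np.savez(path, **flat)

# Rebuild the nested structure on load.
restored = {}
with np.load(path) as loaded:
    for key in loaded.files:
        mod, name = key.split("/")
        restored.setdefault(mod, {})[name] = loaded[key]

assert np.array_equal(restored["linear"]["w"], params["linear"]["w"])
```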

The transformer example is a simple demonstration of pickle-ing the entire state:
https://github.com/deepmind/dm-haiku/blob/main/examples/transformer/train.py#L168-L218

Two years on, the vast majority of people at DeepMind use np.save to store the np.ndarrays in a flat format, and save the tree structure separately through pickle or a specialized internal format (that I don't know the details of because I use pickle).

@NightMachinery

@trevorcai Is there an example of saving the tree structure and then loading the np.ndarrays back into it?

@trevorcai
Contributor

import os
import pickle

import jax
import numpy as np


def save(ckpt_dir: str, state) -> None:
    # Write every leaf array, back to back, into a single .npy file.
    with open(os.path.join(ckpt_dir, "arrays.npy"), "wb") as f:
        for x in jax.tree_leaves(state):
            np.save(f, x, allow_pickle=False)

    # Pickle the tree structure separately, with dummy values as leaves.
    tree_struct = jax.tree_map(lambda t: 0, state)
    with open(os.path.join(ckpt_dir, "tree.pkl"), "wb") as f:
        pickle.dump(tree_struct, f)


def restore(ckpt_dir):
    with open(os.path.join(ckpt_dir, "tree.pkl"), "rb") as f:
        tree_struct = pickle.load(f)

    # Read the arrays back in the same order they were written.
    leaves, treedef = jax.tree_flatten(tree_struct)
    with open(os.path.join(ckpt_dir, "arrays.npy"), "rb") as f:
        flat_state = [np.load(f) for _ in leaves]

    return jax.tree_unflatten(treedef, flat_state)

Typed out in the comment box, so it may need some adjustment to actually run.
