From 01a298d1bc428f2e663c58d8509e80ec5c73e82c Mon Sep 17 00:00:00 2001 From: Avital Oliver Date: Mon, 23 Mar 2020 14:07:10 +0100 Subject: [PATCH 1/2] Improve FAQ I fixed some formatting issues and clarified some answers. I think the FAQ could still use some more love though... --- docs/faq.md | 49 ++++++++++++++++++++----------------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/docs/faq.md b/docs/faq.md index 253315da3..266cdc962 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -3,6 +3,7 @@ This FAQ is compiled from various questions asked on github issues, mailing list, chatrooms, and from personal conversations. --- + **Question: How to initialize a Flax model?** **Context:** Modules seem to have just one function `apply` that takes as @@ -12,8 +13,9 @@ additional parameters, so initialization happens by only calling `init_by_shape`. **Answer:** You'll typically use a `init_by_shape` call to init models as it -doesn't perform any actual computation - it just traces all the shapes and inits -submodules. +runs `apply` in "init mode", which doesn't perform any actual computation - +it just traces all the shapes and inits submodules. + --- **Question: The Model abstraction is very lightweight, is it necessary?** @@ -22,14 +24,17 @@ submodules. of `init`/`apply` methods? **Answer:** Modules have an `init` and `call` function that can be used if you -need use them. So you can do things like `Dense.call(params, X, ...)`. +need use them. So you can do things like `Dense.call(params, X, ...)`. `Model` just +wraps parameters and the `apply` function together in a way that's JAX-aware, +so you can just pass a model instance into JAX-transformed functions +without thinking about `static_argnums`. --- **Question: Does Flax name the tensors like Tensorflow? (e.g. `orthogonal_conv/kernel:0`)?** -**Answer:** Parameters are just tensors so they don't really have a name. +**Answer:** Parameters are just numpy arrays so they don't really have a name. We do use a path notation in a few places which is based on the nested dict structure that models use: ```python @@ -46,7 +51,7 @@ structure that models use: **Answer:** You can directly write submodules and just nest them in a higher module - Flax takes care of submodule initialization for you based on tracing shapes, etc. for initializing them. For instance, the Conv and Batchnorm layers -inside the resnet model are submodules themselves. +inside the resnet model are submodules themselves. --- @@ -95,7 +100,8 @@ The models.ResNet.partial call is a little different but similar in spirit - it' **Question: What is `nn.stateful()` for?** -**Answer:** Example: with `nn.stateful() as init_state`. +**Answer:** Example: `with nn.stateful() as init_state`. + Flax uses with scopes to manage state (like batchnorm statistics) and JAX's functional RNGs for stochastic layers (like dropout). It's a bit weird compared to pytorch state and RNG but pretty simple to use at the top level and @@ -127,16 +133,6 @@ Foo2 and Foo3 you don't need to rewrite the module -- you can just "hack away". --- -**Question: How to save and load model parameters in Flax?*** -**Answer:** A simple way of doing this is along the lines of: -```python -for param_name in trained_model.params: - if param_name in new_model.params: - new_model.params[param_name] = trained_model.params[param_name] -``` - ---- - **Question: When should I use Module.shared() and when not?** **Answer:** Iterating over a submodule in a module function may lead to errors @@ -167,16 +163,6 @@ something like this: --- -**Question: How to execute part of a computation only every X iterations?** - -**Answer:** You could pass in a counter or a flag boolean to the model that -causes the computation to be run or not. If it's a JAX `static_arg` you can -just write normal python if-else, if you want there to be only a single compute -graph you can use the more cumbersome `lax.cond...` though unless you know you -really need it maybe better to do the former. - ---- - **Question: How to perform computations for a Flax module only occasionally?** **Context:** I am trying to create a new Flax module that instantiates a @@ -195,11 +181,15 @@ You could write something like this in your train function: orthogonalized_model = jax.tree_map(orthogonalize_param, optimizer.target) optimizer = optimizer.replace(target=orthogonalized_model) ``` + --- -**Question: How to get the full module's parameters?** +**Question: How to get a submodule's parameters?** **Answer:** + +Assuming you're doing this within a module, + ```python nn.Embed(.., name='vocab') embedding_matrix = self.get_param('vocab')['embedding'] @@ -242,9 +232,10 @@ precision in cases like batchnorm. --- -**Question: Do FLAX models have static or dynamic shapes?** +**Question: Do Flax models have static or dynamic shapes?** **Answer:** They have static shapes. A model is created from an initial shape, -and it is not directly possible to change this. +and it is not directly possible to change this. (This is a limitation of XLA +and thus of JAX) --- From ef851aaa823772093eae8e6a3d5d6b0e7458ca78 Mon Sep 17 00:00:00 2001 From: Avital Oliver Date: Mon, 23 Mar 2020 14:48:34 +0100 Subject: [PATCH 2/2] Responses to @marcvanzee's feedback --- docs/faq.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/faq.md b/docs/faq.md index 266cdc962..8b6c32891 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -23,8 +23,8 @@ it just traces all the shapes and inits submodules. **Context:** Why not having a separate function for defining `params` and a pair of `init`/`apply` methods? -**Answer:** Modules have an `init` and `call` function that can be used if you -need use them. So you can do things like `Dense.call(params, X, ...)`. `Model` just +**Answer:** Modules have an `init` and `call` function that can be used if needed. +So you can do things like `Dense.call(params, X, ...)`. `Model` just wraps parameters and the `apply` function together in a way that's JAX-aware, so you can just pass a model instance into JAX-transformed functions without thinking about `static_argnums`. @@ -236,6 +236,6 @@ precision in cases like batchnorm. **Answer:** They have static shapes. A model is created from an initial shape, and it is not directly possible to change this. (This is a limitation of XLA -and thus of JAX) +and thus of JAX.) ---