Merged
51 changes: 21 additions & 30 deletions docs/faq.md
@@ -3,6 +3,7 @@
This FAQ is compiled from questions asked in GitHub issues, on the mailing list, in chatrooms, and in personal conversations.

---

**Question: How to initialize a Flax model?**

**Context:** Modules seem to have just one function `apply` that takes as
@@ -12,24 +13,28 @@ additional parameters, so initialization happens by only calling
`init_by_shape`.

**Answer:** You'll typically use a `init_by_shape` call to init models as it
doesn't perform any actual computation - it just traces all the shapes and inits
submodules.
runs `apply` in "init mode", which doesn't perform any actual computation -
it just traces all the shapes and inits submodules.

---

**Question: The Model abstraction is very lightweight, is it necessary?**

**Context:** Why not have a separate function for defining `params` and a pair
of `init`/`apply` methods?

**Answer:** Modules have an `init` and `call` function that can be used if you
need use them. So you can do things like `Dense.call(params, X, ...)`.
**Answer:** Modules have an `init` and `call` function that can be used if needed.
So you can do things like `Dense.call(params, X, ...)`. `Model` just
wraps parameters and the `apply` function together in a way that's JAX-aware,
so you can just pass a model instance into JAX-transformed functions
without thinking about `static_argnums`.
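
As a rough illustration of what "JAX-aware" can mean here, the sketch below registers a small container class as a pytree so that an instance can be passed straight into `jax.jit` and `jax.grad`. `TinyModel` and `dense_apply` are hypothetical names invented for this example; this is not Flax's `Model` implementation, only a minimal stand-in built on `jax.tree_util.register_pytree_node_class`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class TinyModel:
    """Hypothetical stand-in that bundles an apply function with its parameters."""

    def __init__(self, apply_fn, params):
        self.apply_fn = apply_fn
        self.params = params

    def __call__(self, x):
        return self.apply_fn(self.params, x)

    def tree_flatten(self):
        # Parameters are the traced leaves; the function is static auxiliary data.
        return (self.params,), self.apply_fn

    @classmethod
    def tree_unflatten(cls, apply_fn, children):
        (params,) = children
        return cls(apply_fn, params)


def dense_apply(params, x):
    # Plain dense layer: x @ kernel + bias.
    return x @ params["kernel"] + params["bias"]


@jax.jit
def loss_fn(model, x, y):
    # The model instance is an ordinary argument; no static_argnums needed.
    return jnp.mean((model(x) - y) ** 2)


model = TinyModel(dense_apply, {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))})
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

loss = loss_fn(model, x, y)
# Gradients come back in the same container, with gradient arrays as leaves.
grads = jax.grad(loss_fn)(model, x, y)
```

Because the function is carried as auxiliary data rather than as a leaf, JAX treats it as part of the static structure and only traces the parameter arrays.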

---

**Question: Does Flax name the tensors like Tensorflow?
(e.g. `orthogonal_conv/kernel:0`)?**

**Answer:** Parameters are just tensors so they don't really have a name.
**Answer:** Parameters are just numpy arrays so they don't really have a name.
We do use a path notation in a few places which is based on the nested dict
structure that models use:
```python
@@ -46,7 +51,7 @@ structure that models use:
**Answer:** You can directly write submodules and just nest them in a higher
module - Flax takes care of submodule initialization for you based on tracing
shapes, etc. for initializing them. For instance, the Conv and Batchnorm layers
inside the resnet model are submodules themselves.
inside the resnet model are submodules themselves.

---

@@ -95,7 +100,8 @@ The models.ResNet.partial call is a little different but similar in spirit - it'

**Question: What is `nn.stateful()` for?**

**Answer:** Example: with `nn.stateful() as init_state`.
**Answer:** Example: `with nn.stateful() as init_state`.

Flax uses `with` scopes to manage state (like batchnorm statistics) and JAX's
functional RNGs for stochastic layers (like dropout). It's a bit unusual
compared to PyTorch's state and RNG handling but pretty simple to use at the top level and
@@ -127,16 +133,6 @@ Foo2 and Foo3 you don't need to rewrite the module -- you can just "hack away".

---

**Question: How to save and load model parameters in Flax?**
**Answer:** A simple way of doing this is along the lines of:
```python
for param_name in trained_model.params:
    if param_name in new_model.params:
        new_model.params[param_name] = trained_model.params[param_name]
```

---

**Question: When should I use Module.shared() and when not?**

**Answer:** Iterating over a submodule in a module function may lead to errors
@@ -167,16 +163,6 @@ something like this:

---

**Question: How to execute part of a computation only every X iterations?**

**Answer:** You could pass a counter or a boolean flag to the model that
controls whether the computation runs. If it is a static argument under
`jax.jit` (`static_argnums`), you can just write a normal Python if-else. If
you want only a single compute graph, you can use the more cumbersome
`lax.cond`, though unless you know you really need it, the former is probably
the better choice.
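
The two options can be sketched roughly as follows. The function names (`step_static`, `step_cond`) and the "extra" computation (doubling the input) are placeholders made up for illustration; only `jax.jit`, `static_argnums`, and `lax.cond` are real JAX APIs.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnums=1)
def step_static(x, do_extra):
    # do_extra is a static Python bool, so a plain `if` works.
    # JAX compiles one version of the function per distinct value.
    if do_extra:
        x = x * 2.0  # placeholder for the occasional computation
    return x + 1.0


@jax.jit
def step_cond(x, step, every):
    # step and every are traced values, so the branch is chosen inside
    # a single compiled graph via lax.cond.
    return lax.cond(
        step % every == 0,
        lambda v: v * 2.0 + 1.0,  # branch with the extra computation
        lambda v: v + 1.0,        # regular branch
        x,
    )


x = jnp.array(1.0)
with_extra = step_static(x, True)
without_extra = step_static(x, False)
cond_hit = step_cond(x, 10, 5)
cond_miss = step_cond(x, 11, 5)
```

The static version compiles twice (once per flag value) but keeps the code simple; the `lax.cond` version compiles once and requires both branches to return values of the same shape and dtype.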

---

**Question: How to perform computations for a Flax module only occasionally?**

**Context:** I am trying to create a new Flax module that instantiates a
@@ -195,11 +181,15 @@ You could write something like this in your train function:
orthogonalized_model = jax.tree_map(orthogonalize_param, optimizer.target)
optimizer = optimizer.replace(target=orthogonalized_model)
```
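
The FAQ does not show what `orthogonalize_param` does, so the following is only a guess at one possible shape for it: a function that replaces every 2-D array with the Q factor of its QR decomposition and leaves everything else untouched. The parameter dict is made up for the example, and the sketch uses `jax.tree_util.tree_map`, which is where recent JAX releases expose the tree-mapping function.

```python
import jax
import jax.numpy as jnp


def orthogonalize_param(param):
    """Assumed behaviour: orthonormalize the columns of 2-D parameters via QR."""
    if param.ndim != 2:
        return param  # biases and other non-matrix leaves pass through unchanged
    q, _ = jnp.linalg.qr(param)
    return q


# Hypothetical nested parameter dict standing in for `optimizer.target`.
params = {
    "dense": {
        "kernel": jnp.array([[2.0, 0.0], [1.0, 1.0], [0.0, 3.0]]),
        "bias": jnp.zeros((2,)),
    }
}

orthogonalized = jax.tree_util.tree_map(orthogonalize_param, params)

kernel = orthogonalized["dense"]["kernel"]
gram = kernel.T @ kernel  # close to the identity if the columns are orthonormal
```

Since the map runs outside the model, it can be applied only on the training steps where it is needed.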

---

**Question: How to get the full module's parameters?**
**Question: How to get a submodule's parameters?**

**Answer:**

Assuming you're doing this within a module,

```python
nn.Embed(.., name='vocab')
embedding_matrix = self.get_param('vocab')['embedding']
@@ -242,9 +232,10 @@ precision in cases like batchnorm.

---

**Question: Do FLAX models have static or dynamic shapes?**
**Question: Do Flax models have static or dynamic shapes?**

**Answer:** They have static shapes. A model is created from an initial shape,
and it is not directly possible to change this.
and it is not directly possible to change this. (This is a limitation of XLA
and thus of JAX.)
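
One way to see the static-shape behaviour in plain JAX is to count how often a jitted function gets traced. The counter and the function below are purely illustrative and not part of Flax.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function, not on cached calls


@jax.jit
def sum_of_squares(x):
    # This Python side effect runs once per trace, i.e. once per new input shape.
    trace_log.append(x.shape)
    return jnp.sum(x ** 2)


sum_of_squares(jnp.ones((4,)))
sum_of_squares(jnp.zeros((4,)))  # same shape and dtype: reuses the compiled code
traces_after_same_shape = len(trace_log)

sum_of_squares(jnp.ones((8,)))   # new shape: traced and compiled again
traces_after_new_shape = len(trace_log)
```

Each distinct input shape produces its own compiled program, which is why a model built for one shape cannot simply be fed another without recompilation.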

---