RFC: Elegy/Treex Ecosystem Next Versions #69
Comments
I like the proposal and am excited about the upcoming changes.
Immutability and RNGs
I think Immutability and therefore
Elegy
Should it be called a
Loss and metrics
I don't use the current loss and metric module as I mostly use the low-level API. A separate JAX-based library would ideally provide a functional API as well, so that it could be used in the low-level API as well as in bare-bones setups or other frameworks. A bit related is first-class support for probabilistic programming in Elegy as well, i.e |
@lkhphuc After a bit of work, #70 is passing. This reworks all of Treex's abstractions to adopt an immutable API as proposed here. Here is an example using Lines 39 to 53 in e73ac40
This applies for the Lines 57 to 75 in e73ac40
toplevel_immutable
It wasn't in this proposal but, when Lines 88 to 105 in e73ac40
From the outside this pattern looks really nice: Line 190 in e73ac40
Similar to |
Hello, I just found this awesome library so my opinion is probably not very important, but here are my two cents:
I think it makes more sense to call the method
Yeah, I was confused for a moment by the name
Could you elaborate on why Treeo's kind system will no longer work? As you said, |
Hey @nalzok, thanks for taking the time to write this, opinions of any kind are welcome! This comment will also serve as an update on how the implementation evolved:
Given the proposal also had an
I too like the name
The thing is that Treeo's Kinds are additional metadata added to the pytree leaves in order to create more powerful filters; this mirrored Flax's collections. While they simplified parts of the implementation a lot, users have to learn this additional framework. The solution is to have regular pytrees and have the user override a couple of additional methods (this can be automated for specific frameworks). This is currently being implemented in poets-ai/elegy#232; here is an update on the resulting APIs:
|
Cool. The name
Yes, please. We can also emit a
I see. Currently I don't quite understand how these work due to the lack of API documentation, but hopefully we can have some detailed documentation after things stabilize a little bit. Regarding the documentation, do you think we should deprecate re-exporting the API for Treeo and Treex, or at least discourage users from using the re-exported APIs? Just like Keras doesn't re-export the API of TensorFlow, and users still need to

    import jax
    import optax
    import elegy as eg

    class MLP(eg.Module):
        @eg.compact
        def __call__(self, x):
            x = eg.Linear(300)(x)
            x = jax.nn.relu(x)
            x = eg.Linear(10)(x)
            return x

    model = eg.Model(
        module=MLP(),
        loss=[
            eg.losses.Crossentropy(),
            eg.regularizers.L2(l=1e-5),
        ],
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.rmsprop(1e-3),
    )

to something like

    import jax
    import jax_metrics as jm
    import optax
    import treeo as to
    import treex as tx
    import elegy as eg

    class MLP(tx.Module):
        @to.compact
        def __call__(self, x):
            x = tx.Linear(300)(x)
            x = jax.nn.relu(x)
            x = tx.Linear(10)(x)
            return x

    model = eg.Model(
        module=MLP(),
        loss=[
            jm.losses.Crossentropy(),
            jm.regularizers.L2(l=1e-5),
        ],
        metrics=jm.metrics.Accuracy(),
        optimizer=optax.rmsprop(1e-3),
    )

This way, users will have an easier time finding the documentation for a lower-level API in the corresponding package, and we don't need to duplicate the same documentation in several places. More importantly, it can help clarify the "level" of an API, e.g. |
Yes, definitely. If possible I want to make Treex an optional dependency, I want Elegy to embrace the "Framework Agnostic" slogan for real. Concretely The question for |
Ah yes, I think it's fine to re-export |
Here are some ideas for the Treeo, Treex, and Elegy libraries which hopefully add some quality-of-life improvements so they can stand the test of time a bit better.
Immutability
Treeo/Treex has adopted a mutable/stateful design in favor of simplicity. While careful propagation of the mutated state inside jitted functions guarantees an overall immutable behaviour thanks to pytree cloning, there are some downsides:
Proposal
Add an `Immutable` mixin in Treeo and have Treex use it for its base `Treex` class. This work has already started in cgarciae/treeo#13 and will do the following:
- Enforce immutability via `__setattr__` by raising a `RuntimeError` when a field is being updated.
- Expose a `replace(**kwargs) -> Tree` method that lets you replace the values of the desired fields but returns a new object.
- Expose a `mutable(method="__call__")(*args, **kwargs) -> (output, Tree)` method that lets you call another method that includes mutable operations in an immutable fashion.
Creating an immutable Tree via the `Immutable` mixin would look like this:
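To make the three behaviours above concrete, here is a minimal pure-Python sketch. It is not Treeo's implementation: the `Immutable` class below, the `_unlocked` flag, and the `Counter` example are hypothetical stand-ins, and real Trees would also be pytrees.

```python
import copy


class Immutable:
    """Hypothetical mixin: attribute assignment raises unless explicitly unlocked."""

    def __setattr__(self, name, value):
        # Assignment is only allowed while `mutable` has unlocked this instance.
        if not self.__dict__.get("_unlocked", False):
            raise RuntimeError(
                f"{type(self).__name__} is immutable, cannot set field '{name}'"
            )
        object.__setattr__(self, name, value)

    def replace(self, **kwargs):
        """Return a shallow copy with the given fields replaced; self is untouched."""
        new = copy.copy(self)
        new.__dict__.update(kwargs)  # bypasses __setattr__ on purpose
        return new

    def mutable(self, method="__call__"):
        """Run a mutating method on a copy and return (output, new_tree)."""

        def wrapper(*args, **kwargs):
            new = copy.deepcopy(self)
            new.__dict__["_unlocked"] = True
            try:
                output = getattr(new, method)(*args, **kwargs)
            finally:
                new.__dict__.pop("_unlocked")  # lock the copy again
            return output, new

        return wrapper


class Counter(Immutable):
    def __init__(self, n=0):
        # Fields are set once at construction, bypassing the guard.
        object.__setattr__(self, "n", n)

    def increment(self):
        self.n = self.n + 1  # written in a mutable style
        return self.n


counter = Counter()
counter2 = counter.replace(n=10)
output, counter3 = counter2.mutable(method="increment")()
```

Here `counter` and `counter2` are left unchanged, and the update performed by `increment` only shows up in `counter3`.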
Additionally, Treeo could also expose an `ImmutableTree` class, so users who are not comfortable with mixins could do it like this:
Examples
Field updates
Mutably, you would update a field like this:
Whereas in the immutable version you use `replace` and get a new `tree`:
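As an analogy using only the standard library (this is not Treeo's API), frozen dataclasses show the same contrast between assigning to a field and getting a new object back. The class names `MutableTree` and `FrozenTree` are made up for this illustration.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass
class MutableTree:
    x: int = 1


@dataclass(frozen=True)
class FrozenTree:
    x: int = 1


# Mutable style: the field is updated in place.
tree = MutableTree()
tree.x = 2

# Immutable style: assignment fails, `replace` returns a new object.
frozen = FrozenTree()
try:
    frozen.x = 2
    assignment_failed = False
except FrozenInstanceError:
    assignment_failed = True

new_frozen = replace(frozen, x=2)
```

The original `frozen` instance keeps `x == 1`; only `new_frozen` carries the updated value.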
Stateful Methods
Now if your Tree class had some stateful method such as:
Mutably, you could simply use it like this:
Now if your tree is immutable you would use `mutable`, which lets you run this method, but the updates are captured in a new instance which is returned along with the output of the method:
Alternatively, you could also use it as a function transformation via `treeo.mutable` like this:
Random State
Treex's `Module`s currently treat random state simply as internal state; because it is hidden, it is actually a bit more difficult to reason about and can cause a variety of issues such as:
Proposal
Remove the `Rng` kind and create an `apply` method similar (but simpler) to Flax's `apply` with the following signature:
As you see, this method accepts an optional `key` as its first argument and then just the `*args` and `**kwargs` for the function. Regular usage would change from:
to
However, if the module is stateless and doesn't require RNG state you can still call the module directly.
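A rough sketch of what an explicit-key `apply` could look like, written with plain JAX and a frozen dataclass rather than Treex. The `Dropout` class, its `calls` field, and the `(output, new_module)` return convention are assumptions made for illustration, not the proposed signature.

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Dropout:
    """Hypothetical immutable module with a tiny piece of state (`calls`)."""

    rate: float = 0.5
    calls: int = 0

    def apply(self, key, x):
        # Assumed convention: explicit key in, (output, new module) out.
        if key is None:
            # No key given: behave deterministically and leave state untouched.
            return x, self
        keep = jax.random.bernoulli(key, 1.0 - self.rate, x.shape)
        y = jnp.where(keep, x / (1.0 - self.rate), 0.0)
        return y, replace(self, calls=self.calls + 1)


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 3))
module = Dropout(rate=0.5)

y1, module1 = module.apply(key, x)
y2, _ = module.apply(key, x)  # same key, same mask
y0, module0 = module.apply(None, x)  # no randomness requested
```

Because the key is an argument rather than hidden state, calling `apply` twice with the same key gives the same output, and the caller decides when to split keys.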
Losses and Metrics
Current Losses and Metrics in Treex (which actually come from Elegy) are great! Since losses and metrics are mostly just Pytrees with simple state, it would be nice to extract them into their own library and, with some minor refactoring, build a framework-independent losses and metrics library that could be used by anyone in the JAX ecosystem. We could eventually create a library called `jax_tools` (or something) that contains utilities such as a `Loss` and `Metric` interface + implementations of common losses and metrics, and maybe other utilities.
As for the Metric API, I was recently looking at clu from the Flax team and found some nice ideas that could make the implementation of distributed code simpler.
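One idea in this spirit is to treat a metric as a small immutable value whose partial results can be combined. The sketch below is plain Python and framework-free; the `Mean` class and the method names `update`, `merge` and `compute` are illustrative assumptions, not the API of clu or of any Elegy/Treex release.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Mean:
    """Hypothetical immutable metric: running mean stored as (total, count)."""

    total: float = 0.0
    count: int = 0

    def update(self, values):
        # Returns a new metric instead of mutating the current one.
        values = list(values)
        return replace(
            self,
            total=self.total + sum(values),
            count=self.count + len(values),
        )

    def merge(self, other):
        # Combining partial states is what makes cross-device syncing simple.
        return Mean(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count if self.count else 0.0


a = Mean().update([1.0, 2.0, 3.0])  # partial state from one "device"
b = Mean().update([5.0])  # partial state from another "device"
combined = a.merge(b)
```

Merging the states and then computing gives the mean over all four values, which differs from averaging the two per-device means.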
Proposal
Make `Metric` immutable and update its API to:
Very similar to the Keras API with the exception of the `aggregate` method, which is incredibly useful when syncing devices on a distributed setup.
Elegy Model
Nothing concrete for the moment, but I am thinking of a PyTorch Lightning-like architecture which would have the following properties:
- `ElegyModule` class (analogous to the `LightningModule`) that would centralize all the JAX-related parts of the training process. More specifically, it would be a Pytree and would expose a framework-agnostic API; this means Treeo's Kind system would not be used now.
- `Model` will now be a regular non-pytree Python object that would contain a `state: ElegyModule` field that it would maintain and update in place.
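The split between an immutable module and a mutable wrapper could be sketched roughly as follows. Everything here is hypothetical: `ToyModule`, `train_step`, `fit`, and the toy quadratic loss are stand-ins chosen to keep the example dependency-free, not a proposed Elegy API.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyModule:
    """Stand-in for an ElegyModule: immutable state plus pure step functions."""

    step: int = 0
    weight: float = 0.0

    def train_step(self, batch, lr=0.5):
        # Toy objective: move `weight` towards the batch mean.
        target = sum(batch) / len(batch)
        loss = (self.weight - target) ** 2
        grad = 2.0 * (self.weight - target)
        new = replace(self, step=self.step + 1, weight=self.weight - lr * grad)
        return loss, new


class Model:
    """Regular mutable Python object that owns and rebinds the immutable state."""

    def __init__(self, module):
        self.state = module

    def fit(self, batches):
        losses = []
        for batch in batches:
            # The only mutation is rebinding `self.state` to the new module.
            loss, self.state = self.state.train_step(batch)
            losses.append(loss)
        return losses


initial = ToyModule()
model = Model(initial)
losses = model.fit([[1.0, 3.0], [1.0, 3.0]])
```

The pure `train_step` is the part that could be jitted or distributed, while `Model` keeps the convenient stateful interface for users.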