This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

RFC: Elegy/Treex Ecosystem Next Versions #69

Open
cgarciae opened this issue Mar 29, 2022 · 8 comments

Comments

@cgarciae
Owner

cgarciae commented Mar 29, 2022

Here are some ideas for the Treeo, Treex, and Elegy libraries which hopefully add some quality-of-life improvements so they can stand the test of time a bit better.

Immutability

Treeo/Treex has adopted a mutable/stateful design in favor of simplicity. While careful propagation of the mutated state inside jitted functions guarantees overall immutable behaviour thanks to pytree cloning, there are some downsides:

  • Asymmetry between traced (jitted, vmapped, etc.) and non-traced functions: stateful operations can mutate the original object in non-traced functions, whereas this cannot happen in traced functions.
  • There are no hints for the user that state needs to be propagated.

Proposal

Add an Immutable mixin to Treeo and have Treex use it for its base Treex class. This work already started in cgarciae/treeo#13 and will do the following:

  1. Enforce immutability via __setattr__ by raising a RuntimeError when a field is updated.
  2. Expose a replace(**kwargs) -> Tree method that replaces the values of the desired fields and returns a new object.
  3. Expose a mutable(method="__call__")(*args, **kwargs) -> (output, Tree) method that lets you call another method containing mutable operations in an immutable fashion.

Creating an immutable Tree via the Immutable mixin would look like this:

import treeo as to

class MyTree(to.Tree, to.Immutable):
    ...

Additionally, Treeo could also expose an ImmutableTree class so that users who are not comfortable with mixins could write:

class MyTree(to.ImmutableTree):
    ...
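For intuition, the core behavior of such a mixin can be sketched in a few lines of plain Python. This is a simplified illustration of the three points above, not the actual treeo implementation:

```python
# Simplified sketch of an Immutable mixin (illustrative only, not treeo's code).
import copy

class Immutable:
    _initialized = False  # class-level default; flipped per-instance after init

    def __setattr__(self, name, value):
        if self._initialized:
            raise RuntimeError(f"Tree is immutable, cannot update field '{name}'")
        object.__setattr__(self, name, value)

    def replace(self, **kwargs):
        # return a copy with the given fields swapped out
        new = copy.copy(self)
        for name, value in kwargs.items():
            object.__setattr__(new, name, value)
        return new

class MyTree(Immutable):
    def __init__(self, n):
        self.n = n
        object.__setattr__(self, "_initialized", True)

tree = MyTree(n=1)
tree2 = tree.replace(n=10)  # tree.n stays 1, tree2.n is 10
```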

Examples

Field updates

Mutably you would update a field like this:

tree.n = 10

Whereas in the immutable version you use replace and get a new tree:

tree = tree.replace(n=10)

Stateful Methods

Now if your Tree class had some stateful method such as:

def acc_sum(self, x):
    self.n += x
    return self.n

Mutably you could simply use it like this:

output = tree.acc_sum(x)

Now if your tree is immutable you would use mutable, which lets you run this method but captures the updates in a new instance that is returned along with the output of the method:

output, tree = tree.mutable(method="acc_sum")(x)

Alternatively you could also use it as a function transformation via treeo.mutable like this:

output, tree = treeo.mutable(tree.acc_sum)(tree, x)
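The function-transformation form can be pictured as copy-then-run. The sketch below is a simplified stand-in for treeo's actual mechanics, and the Counter class is made up for demonstration:

```python
# Simplified stand-in for a mutable() transform (not treeo's real implementation):
# run an unbound stateful method on a deep copy and return (output, new_tree).
import copy

def mutable(fn):
    def wrapper(tree, *args, **kwargs):
        new_tree = copy.deepcopy(tree)  # mutations land on the copy
        output = fn(new_tree, *args, **kwargs)
        return output, new_tree
    return wrapper

class Counter:
    def __init__(self):
        self.n = 0

    def acc_sum(self, x):
        self.n += x
        return self.n

tree = Counter()
output, tree2 = mutable(Counter.acc_sum)(tree, 5)
# output == 5 and tree2.n == 5, while tree.n is still 0
```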

Random State

Treex's Modules currently treat random state as just another piece of internal state; because it is hidden, it is actually harder to reason about and can cause a variety of issues such as:

  • Changing state when you don't want it to do so
  • Freezing state by accident if you forget to propagate updates

Proposal

Remove the Rng kind and create an apply method similar to (but simpler than) Flax's apply with the following signature:

def apply(
    self, 
    key: Optional[PRNGKey], 
    *args, 
    method="__call__",
    mutable: bool = True,
    **kwargs
) -> (Output, Treex)

As you can see, this method accepts an optional key as its first argument followed by the *args and **kwargs for the method. Regular usage would change from:

y = model(x)

to

y, model = model.apply(key, x)

However, if the module is stateless and doesn't require RNG state you can still call the module directly.

Losses and Metrics

Current Losses and Metrics in Treex (which actually come from Elegy) are great! Since losses and metrics are mostly just Pytrees with simple state, it would be nice to extract them into their own library and, with some minor refactoring, build a framework-independent losses and metrics library that could be used by anyone in the JAX ecosystem. We could eventually create a library called jax_tools (or something) that contains utilities such as Loss and Metric interfaces plus implementations of common losses and metrics, and maybe other utilities.

As for the Metric API, I was recently looking at clu from the Flax team and found some nice ideas that could make the implementation of distributed code simpler.

Proposal

Make Metric immutable and update its API to:

class Metric(ABC):
    @abstractmethod
    def update(self: M, **kwargs) -> M:
        ...

    @abstractmethod
    def reset(self: M) -> M:
        ...

    @abstractmethod
    def compute(self) -> tp.Any:
        ...

    @abstractmethod
    def aggregate(self: M) -> M:
        ...
        # could even default to:
        # jax.tree_map(lambda x: jnp.sum(x, axis=0), self)

    def merge(self: M, other: M) -> M:
        stacked = jax.tree_map(lambda *xs: jnp.stack(xs), self, other)
        return stacked.aggregate()

    def batch_updates(self: M, **kwargs) -> M:
        return self.reset().update(**kwargs)

Very similar to the Keras API, with the exception of the aggregate method, which is incredibly useful when syncing devices in a distributed setup.
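As a concrete illustration, a hypothetical Mean metric written against this interface might look like the following. The Mean class here is invented for the example, not an existing Treex class:

```python
# Hypothetical Mean metric following the proposed immutable interface.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Mean:
    total: float = 0.0
    count: int = 0

    def update(self, value: float) -> "Mean":
        return replace(self, total=self.total + value, count=self.count + 1)

    def reset(self) -> "Mean":
        return Mean()

    def compute(self) -> float:
        return self.total / self.count

    def merge(self, other: "Mean") -> "Mean":
        # same effect as stacking both pytrees and summing over the new axis
        return Mean(self.total + other.total, self.count + other.count)

    def batch_updates(self, value: float) -> "Mean":
        return self.reset().update(value)

m = Mean().update(2.0).update(4.0)
# m.compute() == 3.0; merging two device copies sums their sufficient statistics
```

Because every method returns a new instance, the metric composes naturally with jit and distributed aggregation.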

Elegy Model

Nothing concrete for the moment, but I am thinking about a PyTorch Lightning-like architecture with the following properties:

  • The creation of an ElegyModule class (analogous to the LightningModule) that would centralize all the JAX-related parts of the training process. More specifically, it would be a Pytree and would expose a framework-agnostic API, which means Treeo's Kind system would no longer be used.
  • Model would now be a regular non-pytree Python object containing a state: ElegyModule field that it maintains and updates in place.
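The split could look roughly like this. This is a hypothetical skeleton showing only the shape of the design; none of the method bodies or names are final:

```python
# Hypothetical skeleton of the proposed split: ElegyModule is an immutable
# pytree-style object holding the JAX training logic, while Model is a plain
# Python wrapper that keeps the state and updates it in place.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ElegyModule:
    step: int = 0

    def train_step(self, batch) -> "ElegyModule":
        # a real implementation would compute grads and update params here
        return replace(self, step=self.step + 1)

class Model:
    def __init__(self, state: ElegyModule):
        self.state = state  # regular attribute, updated in place

    def fit(self, batches):
        for batch in batches:
            self.state = self.state.train_step(batch)

model = Model(ElegyModule())
model.fit([1, 2, 3])
# model.state.step == 3
```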

@lkhphuc
Contributor

lkhphuc commented Apr 2, 2022

I like the proposal and am excited about the upcoming changes.

Immutability and RNGs

I think immutability, and therefore apply(*, rng), is the right move forward.
JAX and its ecosystem are highly compositional by nature of functional programming, so inevitably I will need to use a tool provided by another JAX-based library.
Having trained models with elegy for a while, I have created various bugs around the implicit stateful interface, which I think I would not have made with the immutable approach.

Model will now be a regular non-pytree Python object that would contain a state: ElegyModule field that it would maintain and update inplace.

Elegy

Should it be called a Trainer then, like PytorchLightning?

Loss and metrics

I don't use the current loss and metric modules since I mostly use the low-level API. A separate JAX-based library would ideally provide a functional API as well, so that it could be used in the low-level API as well as in barebones JAX or other frameworks.

Somewhat related is first-class support for probabilistic programming in elegy, i.e. Distrax. I have increasingly replaced my loss functions with calls like p_x.log_prob(x) etc.
Currently this already mostly works thanks to the JAX-based functional design. However, there are still some sharp edges, like not being able to use the .summary() method, etc.

@cgarciae
Owner Author

cgarciae commented Apr 5, 2022

@lkhphuc After a bit of work, #70 is passing. This reworks all of Treex's abstractions to adopt the immutable API proposed here. Here is an example using apply:

treex/examples/cnn.py

Lines 39 to 53 in e73ac40

def loss_fn(
    params: tp.Optional[Model],
    key: tp.Optional[jnp.ndarray],
    model: Model,
    losses_and_metrics: tx.LossesAndMetrics,
    x: jnp.ndarray,
    y: jnp.ndarray,
) -> tp.Tuple[jnp.ndarray, tp.Tuple[Model, tx.LossesAndMetrics]]:
    if params is not None:
        model = model.merge(params)
    preds, model = model.apply(key, x)
    loss, losses_and_metrics = losses_and_metrics.loss_and_update(target=y, preds=preds)
    return loss, (model, losses_and_metrics)

This applies for the Optimizer as well:

treex/examples/cnn.py

Lines 57 to 75 in e73ac40

def train_step(
    key: jnp.ndarray,
    model: Model,
    optimizer: tx.Optimizer,
    losses_and_metrics: tx.LossesAndMetrics,
    x: jnp.ndarray,
    y: jnp.ndarray,
) -> tp.Tuple[Model, tx.Optimizer, tx.LossesAndMetrics]:
    print("JITTTTING")
    params = model.parameters()
    grads, (model, losses_and_metrics) = jax.grad(loss_fn, has_aux=True)(
        params, key, model, losses_and_metrics, x, y
    )
    params, optimizer = optimizer.update(grads, params)
    model = model.merge(params)
    return model, optimizer, losses_and_metrics

toplevel_mutable

It wasn't in this proposal, but when train_step, test_step, and friends are used as methods of a Module/Tree, it becomes a little cumbersome and error-prone to use .replace everywhere to update self. Instead I experimented with this toplevel_mutable decorator, which in the next example creates a copy of self and temporarily makes it mutable while keeping all subtrees immutable:

@jax.jit
@tx.toplevel_mutable
def train_step(
    self: M,
    x: jnp.ndarray,
    y: jnp.ndarray,
) -> M:
    print("JITTTTING")
    params = self.module.parameters()
    loss_key, self.key = jax.random.split(self.key)
    grads, self = jax.grad(self.loss_fn, has_aux=True)(params, loss_key, x, y)
    params, self.optimizer = self.optimizer.update(grads, params)
    self.module = self.module.merge(params)
    return self

From the outside this pattern looks really nice:

model = model.train_step(x, y)

Similar to mutable, methods decorated with toplevel_mutable don't modify the original object so the API is kept immutable from an outside perspective.
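The decorator's observable behaviour can be sketched as follows. These are assumed, simplified semantics rather than the actual treex implementation, and the TrainState class is hypothetical:

```python
# Simplified sketch of a toplevel_mutable-style decorator (assumed semantics,
# not treex's real implementation): the method runs on a shallow copy of self,
# so attribute reassignments never touch the caller's object.
import copy
import functools

def toplevel_mutable(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        new_self = copy.copy(self)  # top-level copy; subtrees stay shared
        return method(new_self, *args, **kwargs)
    return wrapper

class TrainState:
    def __init__(self, step=0):
        self.step = step

    @toplevel_mutable
    def train_step(self):
        self.step += 1  # mutates the copy, not the original
        return self

state = TrainState()
state2 = state.train_step()
# state.step == 0 while state2.step == 1
```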

@nalzok

nalzok commented May 13, 2022

Hello, I just found this awesome library so my opinion is probably not very important, but here are my two cents:

Now if your tree is immutable you would use mutable, which lets you run this method but captures the updates in a new instance that is returned along with the output of the method:

output, tree = tree.mutable(method="acc_sum")(x)

Alternatively you could also use it as a function transformation via treeo.mutable like this:

output, tree = treeo.mutable(tree.acc_sum)(tree, x)

I think it makes more sense to call the method mutate instead of mutable. Generally, methods should be named after verbs. Please consider changing the name before you make a release!

Should it be called a Trainer then, like PytorchLightning?

Yeah, I was confused for a moment by the name elegy.Model since machine learning "models" typically aren't bundled with loss, metrics, and optimizer. elegy.Trainer sounds like a great name.

The creation of an ElegyModule class (analogous to the LightningModule) that would centralize all the JAX-related parts of the training process. More specifically it would be a Pytree and would expose a framework agnostic API, this means Treeo's Kind system would not be used now.

Could you elaborate on why Treeo's kind system will no longer work? As you said, ElegyModule is just a PyTree, which I assume treeo can work seamlessly with.

@cgarciae
Owner Author

Hey @nalzok, thanks for taking the time to write this, opinions of any kind are welcome! This comment will also serve as an update on how the implementation evolved:

I think it makes more sense to call the method mutate instead of mutable

Given the proposal also had an apply method, ultimately it was simpler to have a mutable: bool argument in apply, which defaults to True, so the previous examples look identical with apply.

Yeah, I was confused for a moment by the name elegy.Model since machine learning "models" typically aren't bundled with loss, metrics, and optimizer. elegy.Trainer sounds like a great name.

I too like the name Trainer, however I am hesitant to make the change since it will break code that just uses the high-level API. Maybe we could rename it to Trainer and have Model as an alias for backward compatibility.

Could you elaborate on why Treeo's kind system will no longer work? As you said, ElegyModule is just a PyTree, which I assume treeo can work seamlessly with.

The thing is that Treeo's Kinds are additional metadata added to the pytree leaves in order to create more powerful filters; this mirrored Flax's collections. While they simplified parts of the implementation a lot, users have to learn this additional framework. The solution is to have regular pytrees and have the user override a couple of additional methods (this can be automated for specific frameworks).

This is currently being implemented in poets-ai/elegy#232; here is an update on the resulting APIs:

| API | Methods | Description |
| --- | --- | --- |
| Core API | train_step, test_step, pred_step, init_step | User has full control, max flexibility; no logging or distributed strategies for free. |
| Managed API | managed_train_step, managed_test_step, managed_pred_step, managed_init_step | Similar to PyTorch Lightning, thus sufficiently flexible; gets logging and distributed strategies; has to define methods that specify how to get/set parameters and batch statistics. |
| High Level API | init, apply | User just specifies how to perform initialization and the forward pass; gets losses, metrics, distributed strategies, and logging for free; has to define methods that specify how to get/set parameters and batch statistics. Note: this API is mostly used to simplify the creation of framework-specific implementations (flax, haiku, etc.); not clear if it should be exposed to users. |

@nalzok

nalzok commented May 16, 2022

Given the proposal also had an apply method, ultimately it was simpler to have a mutable: bool argument in apply, which defaults to True, so the previous examples look identical with apply.

Cool. The name apply(mutable=...) is also consistent with Flax's Module.apply, just without the variables parameter. This will make things easier for those who have some experience with Flax.

Maybe we could rename it to Trainer and have Model as an alias for backward compatibility.

Yes, please. We can also emit a DeprecationWarning when the name Model is used, so that we can remove that name in a future major release.

This is currently being implemented in poets-ai/elegy#232, here is an update the resulting APIs:

I see. Currently I don't quite understand how these work due to the lack of API documentation, but hopefully we can have some detailed documentation after things stabilize a little bit.

Regarding the documentation, do you think we should deprecate re-exporting the Treeo and Treex APIs, or at least discourage users from using the re-exported APIs? This would mirror Keras, which doesn't re-export TensorFlow's API: users still need to import tensorflow when using Keras. More concretely, I am suggesting changing the example

import jax
import optax
import elegy as eg


class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

to something like

import jax
import jax_metrics as jm
import optax
import treeo as to
import treex as tx
import elegy as eg


class MLP(tx.Module):
    @to.compact
    def __call__(self, x):
        x = tx.Linear(300)(x)
        x = jax.nn.relu(x)
        x = tx.Linear(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        jm.losses.Crossentropy(),
        jm.regularizers.L2(l=1e-5),
    ],
    metrics=jm.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)

This way, users will have an easier time finding the document for a lower-level API in the corresponding package, and we don't need to duplicate the same documentation in several places.

More importantly, it can help clarify the "level" of an API, e.g. eg.Model is higher-level compared to tx.Module and jm.metrics.Accuracy since it's the overarching trainer. While re-exporting the API from the dependencies can make things more convenient because users don't need to remember which function/package comes from which package, I'm afraid it might cause conceptual confusion in the long run because users will adopt a flattened mental model (elegy, treex, treeo, optax, jax_metrics) instead of understanding the hierarchical structure of (elegy(treex(treeo), optax, jax_metrics)).

@cgarciae
Owner Author

Regarding the documentation, do you think we should deprecate re-exporting the Treeo and Treex APIs, or at least discourage users from using the re-exported APIs?

Yes, definitely. If possible I want to make Treex an optional dependency, I want Elegy to embrace the "Framework Agnostic" slogan for real. Concretely elegy.Module will not be treex.Module so people will have to import treex or whatever framework they want to use.

The question for jax_metrics is interesting, should we re-export losses, metrics, and regularizers? I am inclined to say yes.

@nalzok

nalzok commented May 17, 2022

Ah yes, I think it's fine to re-export jax_metrics since the functions live in some submodules, i.e. we have jm.losses.Crossentropy() instead of jm.Crossentropy(). (I would say jm.losses.CrossEntropy() is a better name though, otherwise the naming convention isn't really consistent)
