Introduce a training API #5
Comments
FWIW I like this breakout - it combines the best of PyTorch's flexibility and provides logical callbacks during the training process. Do you see the

In line with what many other ML libs do, I agree that Axon should define a Protocol for datasets that any external data management library could implement support for. I also think including a basic in-memory implementation for testing, quick ramp-up, and providing a reference example would add great value.

IMO PyTorch has one of the most reasonable interfaces for working with custom datasets; I'm actually in the process of migrating one project from TF to PyTorch due to the complexity of implementing custom dataset generators.
@arpieb Yes! I actually think it would be best to adjust

I've really only ever worked in TensorFlow/TFDS, but I have heard there is more of a preference for PyTorch for some of these things. I haven't really considered what kind of dataset integration Axon should have, mainly because I figured we would want to create a more general abstraction that works for the whole ecosystem, but it's definitely worth some thought to have something convenient.
I think it's going to be hard to come up with a general abstraction for all use cases of Nx tensor collections. For example, Axon SL models will rely on sample-label pairs, which is not a requirement for most other data ops. A dataframe equivalent to pandas (exdas? nexdas? ;)) would need to be columnar to handle the mixture of numeric (Nx tensor) and non-numeric columns.

Ultimately all Python-based ML solutions seem to fall back to requiring a collection (`__len__`, `__getitem__`) or iterable (`__iter__`) to satisfy dataset needs, which in turn usually return one or more numpy arrays. Even DL4J forces any kind of transform pipeline or generator to ultimately spit out something implementing the INDArray interface (their attempt at replicating numpy.ndarray) before it can be handed off to a model.

Many UL models will not require labels, so tensors of rank 2+ would fill the need just fine for them.
This looks a lot like PyTorch Lightning! And that's a compliment. As much as I have distaste for much of the Python deep learning ecosystem, I think PTL has emerged from some early teething problems to be one of the most sane training APIs out there. Axon could do much worse than to emulate their approach to callbacks.

I agree wholeheartedly with @arpieb's take here (and in elixir-nx/nx#301) that defining some protocol for the inevitable dataframe libraries is essential here. Data loading and its complexities have always been a bugbear for PyTorch and TensorFlow.
Given the core components implemented in #1, we can implement an efficient, simple, but flexible training API. I am proposing an API similar to `trax.supervised.training` under the `Axon.Training` namespace that represents a general supervised training pipeline for models.

Training Behaviour
We can consider the training loop to take the following inputs:

- `model_state` - parameters, discussed in a future issue for state management and model initialization
- `optimizer` - encapsulates both optimizer state and the update step, discussed in a future issue
- `train_objective` (note I'm not using `Task` to avoid confusion with Elixir tasks) - an objective (loss) function parameterized by the input model such that `grad(model_state, objective)` differentiates the model parameters w.r.t. the input model
- `eval_objective` - metrics for evaluating model performance on validation sets (loss, accuracy, MSE, MAE, etc.) and some associated state for monitoring training progress
- `dataset` - inputs and labels
- `options` - miscellaneous

and to perform the following algorithm (this is half pseudocode, half Elixir):
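A minimal sketch of that loop is below. It is half pseudocode in the same spirit: `grad/2`, `update/3`, and `apply_updates/2` are placeholders standing in for pieces discussed in other issues, not an existing Axon API.

```elixir
def train(model_state, optimizer, train_objective, eval_objective, dataset, opts \\ []) do
  epochs = Keyword.get(opts, :epochs, 1)

  Enum.reduce(1..epochs, {model_state, optimizer}, fn _epoch, acc ->
    Enum.reduce(dataset, acc, fn {inputs, labels}, {params, optim} ->
      # Differentiate the training objective w.r.t. the model parameters
      gradients = grad(params, &train_objective.(&1, inputs, labels))

      # The optimizer turns gradients (plus its own state) into parameter updates
      {updates, optim} = update(optim, gradients, params)
      params = apply_updates(params, updates)

      # Evaluate metrics so training progress can be monitored
      _metrics = eval_objective.(params, inputs, labels)

      {params, optim}
    end)
  end)
end
```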
It's common to use metrics as an easy way to monitor training, so we can introduce a `metrics` object which encapsulates metric state and metric evaluation functions:
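One possible shape for such an object, sketched here with invented names, is a struct that pairs each metric's running state with a function that folds new batch results into that state:

```elixir
defmodule Metrics do
  defstruct state: %{}, funs: %{}

  # `funs` maps a metric name to a fold function: (acc, y_true, y_pred) -> acc
  def new(funs) when is_map(funs) do
    %__MODULE__{state: Map.new(funs, fn {name, _fun} -> {name, 0.0} end), funs: funs}
  end

  # Fold a new batch of predictions into every metric's running state
  def update(%__MODULE__{state: state, funs: funs} = metrics, y_true, y_pred) do
    state = Map.new(funs, fn {name, fun} -> {name, fun.(state[name], y_true, y_pred)} end)
    %{metrics | state: state}
  end
end
```

Keeping state and evaluation function together means the training loop only ever calls `Metrics.update/3` and never needs to know which metrics are being tracked.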
We can further extend this API with `before_x` and `after_x` callbacks (writing checkpoints, plotting graphs, etc.):
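For illustration only (the hook names and the `state.epoch`/`state.metrics` fields are assumptions), the hooks could be plain functions keyed by hook point that receive and return the current training state:

```elixir
callbacks = %{
  before_epoch: fn state -> IO.puts("epoch #{state.epoch} starting"); state end,
  after_epoch: fn state -> IO.inspect(state.metrics, label: "epoch #{state.epoch} metrics"); state end
}

# Invoked at each hook point inside the loop; missing hooks are simply skipped
run_hook = fn callbacks, point, state ->
  case Map.get(callbacks, point) do
    nil -> state
    fun -> fun.(state)
  end
end
```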
For more flexibility, we can extract each train step into its own function; this facilitates easier writing of custom training loops:
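A sketch of that extraction, reusing the same hypothetical `grad/2`, `update/3`, and `apply_updates/2` placeholders from above:

```elixir
# One optimization step over a single {inputs, labels} batch
def train_step({model_state, optimizer}, {inputs, labels}, train_objective) do
  gradients = grad(model_state, &train_objective.(&1, inputs, labels))
  {updates, optimizer} = update(optimizer, gradients, model_state)
  {apply_updates(model_state, updates), optimizer}
end

# A custom loop is then just a reduction over the dataset:
# Enum.reduce(dataset, {model_state, optimizer}, &train_step(&2, &1, objective))
```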
Given this framework, the training API would have at a minimum the following callbacks:
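As one possibility only (the callback names below are illustrative guesses, not the proposal itself), the behaviour could look something like:

```elixir
defmodule TrainingLoop do
  @callback init(opts :: keyword()) :: map()
  @callback train_step(state :: map(), batch :: {Nx.Tensor.t(), Nx.Tensor.t()}) :: map()
  @callback eval_step(state :: map(), batch :: {Nx.Tensor.t(), Nx.Tensor.t()}) :: map()
  @callback finalize(state :: map()) :: map()
end
```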
I left a lot of key pieces out because I believe this motivates discussion about how best to separate concerns between modules to provide both maximum flexibility and ease of use.
Objectives
The spirit of autograd is that writing a machine learning model is as simple as defining a differentiable objective function. I believe that is a principle we should stick to, so I've separated the idea of an objective into what could be a separate module, behaviour, function, etc. Objectives need to encapsulate both evaluation objectives and training objectives. They need to be capable of supporting parameterization by a model. I think they should also contain information about associated metrics and evaluation criteria that are tracked during training. Objectives could possibly be defined as a behaviour with two methods: `predict` and `loss`, where `loss` depends on `predict` and `predict` represents a model definition. I'm not sure I really like that idea, but objectives definitely deserve a well-thought-out discussion in a separate issue.
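A minimal sketch of that two-callback idea; the module names, typespecs, and the mean-squared-error example are assumptions made here for illustration:

```elixir
defmodule Objective do
  @callback predict(params :: map(), inputs :: Nx.Tensor.t()) :: Nx.Tensor.t()
  @callback loss(params :: map(), inputs :: Nx.Tensor.t(), labels :: Nx.Tensor.t()) :: Nx.Tensor.t()
end

defmodule MSEObjective do
  @behaviour Objective

  # `predict` is the model definition: here a simple affine map
  @impl true
  def predict(params, inputs), do: Nx.add(Nx.dot(inputs, params[:w]), params[:b])

  # `loss` is defined in terms of `predict`, keeping the objective differentiable end to end
  @impl true
  def loss(params, inputs, labels) do
    diff = Nx.subtract(predict(params, inputs), labels)
    Nx.mean(Nx.multiply(diff, diff))
  end
end
```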
Optimizers and Updates
From a design standpoint, updates and optimizers should be included separately. However, from a performance standpoint, I think you might want to fuse gradient calculation with updates; I believe this could be possible by silently wrapping both `update` and the `grad(objective)` call in another `defn` somewhere, because `defn` calls are inlined and compiled. Optimizers as separate modules is a pretty common pattern, so I would go for a behaviour here with common implementations built on the primitive `updates.ex`.
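A rough sketch of what that fusion could look like, assuming Nx.Defn's `grad/2`; the plain SGD update here is only a stand-in for whatever `updates.ex` would provide:

```elixir
defmodule FusedStep do
  import Nx.Defn

  # Gradient and parameter update live in one defn, so both are inlined
  # and compiled together rather than crossing a function boundary.
  defn step(params, inputs, labels, learning_rate) do
    gradients = grad(params, fn p -> loss(p, inputs, labels) end)
    Nx.subtract(params, Nx.multiply(learning_rate, gradients))
  end

  defnp loss(params, inputs, labels) do
    preds = Nx.dot(inputs, params)
    diff = Nx.subtract(preds, labels)
    Nx.mean(Nx.multiply(diff, diff))
  end
end
```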
State
There is a lot of state to keep track of in the above example: model state, optimizer state, metric state, evaluation state, etc. I think it makes sense to wrap state into a common API so stateful parameters can be handled flexibly. Another advantage of implementing this is that we can limit assumptions about actual state management solutions in practice, so users can implement their own if they choose.
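As a purely hypothetical sketch (the struct and field names are invented), a common wrapper could be as small as:

```elixir
defmodule TrainingState do
  defstruct model_state: %{}, optimizer_state: %{}, metrics: %{}, epoch: 0, step: 0

  # Funneling all changes through one function makes it easy to swap in other
  # state management strategies (ETS, processes, checkpoints) without touching the loop.
  def update(%__MODULE__{} = state, changes) when is_map(changes) do
    struct(state, changes)
  end
end
```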
Dataset
The above just lists a dataset as containing batches. I would basically try to represent this as a stream that can be consumed. I don't think dataset implementations belong in this library, but I think Axon should enforce some standard for what datasets look like.
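For example, under the assumption that a dataset is just an Enumerable of `{inputs, labels}` batches, a dummy in-memory stream could look like this (shapes and data are placeholders):

```elixir
batch_size = 32

dataset =
  Stream.repeatedly(fn ->
    # Dummy batch: any source that yields {inputs, labels} tensor pairs would do
    inputs = Nx.iota({batch_size, 784}, type: {:f, 32})
    labels = Nx.broadcast(Nx.tensor(0.0), {batch_size, 10})
    {inputs, labels}
  end)
  |> Stream.take(100)
```

Because the training loop only reduces over the dataset, anything that implements Enumerable - a file-backed stream, a wrapped GenStage producer, etc. - could satisfy the standard.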
Conclusion
I believe this lays out a plan for integrating higher-level APIs moving forward. Obviously this is incredibly general, because it inherently depends on implementation details of the aspects listed above that are not yet included. However, I believe starting with a training API is a sensible way to figure out how to split up the rest of the work.