
Introduce a training API #5

Closed
seanmor5 opened this issue Mar 4, 2021 · 4 comments · Fixed by #17

Comments

@seanmor5
Contributor

seanmor5 commented Mar 4, 2021

Given the core components implemented in #1, we can implement an efficient, simple, but flexible training API. I am proposing an API similar to trax.supervised.training under the Axon.Training namespace that represents a general supervised training pipeline for models.

Training Behaviour

We can consider the training loop to take the following inputs:

  • model_state - parameters, discussed in a future issue for state management and model initialization
  • optimizer - encapsulates both optimizer state, and the update step, discussed in a future issue
  • train_objective (note I'm not using Task to avoid confusion with Elixir tasks) - an objective (loss) function parameterized by the input model such that grad(model_state, objective) computes the gradient of the objective w.r.t. the model parameters
  • eval_objective - metrics for evaluating model performance on validation sets (loss, accuracy, MSE, MAE, etc.) and some associated state for monitoring training progress
  • dataset - inputs and labels
  • options - miscellaneous

and to perform the following algorithm (this is half pseudocode, half Elixir):

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for _epoch <- 1..epochs do
    for {input, target} <- dataset do
      gradients = grad(model_state, train_objective)
      model_state = update(model_state, gradients, optimizer)
    end
    evaluate(model_state, eval_objective)
  end
end

It's common to use metrics as an easy way to monitor training, so we can introduce a metrics object which encapsulates metric state and metric evaluation functions:

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for _epoch <- 1..epochs do
    for {input, target} <- dataset do
      gradients = grad(model_state, train_objective)
      model_state = update(model_state, gradients, optimizer)
      metrics(model_state, train_objective)
    end
    evaluate(model_state, eval_objective)
  end
end

We can further extend this API with before_x and after_x callbacks (writing checkpoints, plotting graphs, etc.):

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for _epoch <- 1..epochs do
    before_epoch(model_state)

    for {input, target} <- dataset do
      before_batch(model_state)

      gradients = grad(model_state, train_objective)
      model_state = update(model_state, gradients, optimizer)
      metrics(model_state, train_objective)

      after_batch(model_state)
    end
    evaluate(model_state, eval_objective)

    after_epoch(model_state)
  end
end

For more flexibility, we can extract each train step into a method, this facilitates easier writing of custom training loops:

def train_on_batch(batch, model_state, train_objective, optimizer) do
  before_batch(model_state)

  gradients = grad(model_state, &train_objective(&1, batch))
  model_state = update(model_state, gradients, optimizer)
  metrics(model_state, train_objective)

  after_batch(model_state)
end

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]
  steps = options[:steps] || :unlimited # run until the dataset is exhausted

  for _epoch <- 1..epochs do
    before_epoch(model_state)

    for batch <- dataset, until: steps do
      train_on_batch(batch, model_state, train_objective, optimizer)
    end
    evaluate(model_state, eval_objective)

    after_epoch(model_state)
  end
end

Given this framework, the training API would have at a minimum the following callbacks:

defmodule Axon.Training do  
  # Runs before each epoch
  @callback before_epoch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs after each epoch
  @callback after_epoch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs before each batch
  @callback before_batch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs after each batch
  @callback after_batch(model_state) :: {:ok, model_state} | {:error, reason}

  # Runs a single train step; this can also be `defn` for working with infeed/outfeed
  @callback train_on_batch(batch, model_state, train_objective, optimizer) :: model_state

  # Runs a training loop to convergence
  @callback train(model_state, optimizer, train_objective, eval_objective, dataset, options) :: {:ok, model_state} | {:error, reason}
end
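
To make the contract concrete, here is a sketch of what a user-defined trainer implementing this behaviour might look like (the module name and the save_checkpoint helper are hypothetical, and the remaining callbacks are elided):

```elixir
defmodule MyTrainer do
  @behaviour Axon.Training

  @impl true
  def before_epoch(model_state), do: {:ok, model_state}

  @impl true
  def after_epoch(model_state) do
    # e.g. write a checkpoint or plot metrics here
    save_checkpoint(model_state)
    {:ok, model_state}
  end

  # before_batch/1, after_batch/1, train_on_batch/4, and train/6
  # would be implemented similarly, or fall back to defaults
end
```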

I left a lot of key pieces out because I believe this motivates discussion about how to best separate concerns between modules to provide both maximum flexibility, as well as ease-of-use.

Objectives

The spirit of autograd is that writing a machine learning model is as simple as defining a differentiable objective function. I believe that is a principle we should stick to, so I've separated the idea of an objective into what could be a separate module, behaviour, function, etc. Objectives need to encapsulate both evaluation objectives and training objectives, and they need to support parameterization by a model. I think they should also carry information about associated metrics and evaluation criteria tracked during training. Objectives could possibly be defined as a behaviour with two methods: predict and loss, where loss depends on predict and predict represents a model definition. I'm not sure I really like that idea, but objectives definitely deserve a well-thought-out discussion in a separate issue.
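
As a sketch only, the predict/loss idea could be expressed as a behaviour along these lines (all names and types here are hypothetical, following the loose typespec style used above):

```elixir
defmodule Axon.Objective do
  # predict/2 represents the model definition
  @callback predict(model_state, input) :: prediction

  # loss/3 builds on predict/2 to produce a differentiable scalar
  @callback loss(model_state, input, target) :: loss

  # metrics tracked alongside the loss during training
  @callback metrics() :: [atom()]
end
```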

Optimizers and Updates

From a design standpoint, updates and optimizers should be defined separately. From a performance standpoint, however, you might want to fuse gradient calculation with updates. I believe this is possible by silently wrapping both update and grad(objective) in another defn somewhere, since defn calls are inlined and compiled. Optimizers as separate modules are a pretty common pattern, so I would go for a behaviour here, with common implementations built on the primitives in updates.ex.
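
A rough sketch of the fusion idea, assuming grad/2 from Nx.Defn plus a hypothetical objective/2 and Optimizer.update/3; because defn calls are inlined, the gradient computation and the update would compile into a single computation:

```elixir
defn train_step(model_state, optimizer_state, batch) do
  gradients = grad(model_state, &objective(&1, batch))
  Optimizer.update(model_state, optimizer_state, gradients)
end
```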

State

There is a lot of state to keep track of in the above example: model state, optimizer state, metric state, evaluation state, etc. I think it makes sense to wrap state in a common API so stateful parameters can be handled flexibly. Another advantage is that we limit assumptions about actual state management solutions in practice, so users can implement their own if they choose.
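
For illustration, the various pieces of state could be gathered into a single struct, something like this hypothetical sketch:

```elixir
defmodule Axon.Training.State do
  defstruct [:model_state, :optimizer_state, :metric_state, :eval_state,
             epoch: 0, step: 0]
end
```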

Dataset

The above just lists a dataset as containing batches. I would basically try to represent this as a stream that can be consumed. I don't think dataset implementations fall in this library, but I think Axon should enforce some standard for what datasets look like.
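
For example, an in-memory dataset can already be expressed as a stream of {input, target} batches with nothing but the standard library (tensor construction is omitted here for brevity):

```elixir
inputs  = Enum.to_list(1..8)
targets = Enum.map(inputs, &(&1 * 2))

dataset =
  inputs
  |> Enum.zip(targets)
  |> Stream.chunk_every(4)

# consuming the stream yields two batches of four {input, target} pairs
Enum.count(dataset) # => 2
```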

Conclusion

I believe this lays out a plan for integrating higher-level APIs moving forward. Obviously this is incredibly general, because it inherently depends on implementation details of the pieces listed above. However, I believe starting with a training API is a sensible way to figure out how to split up the rest of the work.

@arpieb
Contributor

arpieb commented Mar 10, 2021

FWIW I like this breakout - it combines PyTorch-style flexibility with logical callbacks during the training process. Do you see the eval_objective providing a mechanism for early stopping, or should that be another callback?

In line with what many other ML libs do, I agree that Axon should define a Protocol for datasets that any external data management library could implement support for. I also think including a basic in-memory implementation for testing, quick ramp-up, and providing a reference example would add great value.

IMO PyTorch has one of the most reasonable interfaces for working with custom datasets; I'm actually in the process of migrating one project from TF to PyTorch due to the complexity of implementing custom dataset generators.

@seanmor5
Contributor Author

@arpieb Yes! I actually think it would be best to adjust after_epoch to take the evaluation state as well so early stopping can be implemented as a callback. That would also be useful for implementing model checkpointing based on validation metrics.
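
A heavily hedged sketch of what that callback-based early stopping might look like, assuming after_epoch also receives the evaluation state and that a :halt return stops the loop (none of this is settled API):

```elixir
def after_epoch(model_state, eval_state) do
  if eval_state.val_loss < eval_state.best_loss do
    # improvement: keep going (and perhaps checkpoint here)
    {:ok, model_state}
  else
    # no improvement: stop training early
    {:halt, model_state}
  end
end
```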

I've really only ever worked in TensorFlow/TFDS, but I have heard there is more of a preference for PyTorch for some of these things. I haven't really considered what kind of dataset integration Axon should have, mainly because I figured we would want to create a more general abstraction that works for the whole ecosystem, but it's definitely worth some thought to have something convenient.

@arpieb
Contributor

arpieb commented Mar 10, 2021

I think it's going to be hard to come up with a general abstraction for all use cases of Nx tensor collections. For example, Axon SL models will rely on sample-label pairs, which is not a requirement for most other data ops. A dataframe equivalent to pandas (exdas? nexdas? ;)) would need to be columnar to handle the mixture of numeric (Nx tensor) and non-numeric columns.

Ultimately all Python-based ML solutions seem to fall back to requiring a collection (len, get_item) or iterable (iter) to satisfy dataset needs, which in turn usually return one or more numpy arrays. Even DL4j forces any kind of transform pipeline or generator to ultimately spit up something implementing the INDArray interface (their attempt at replicating numpy.ndarray) before it can be handed off to a model.

Many UL models will not require labels, so 2+ rank tensors would fill the need just fine for them.

@cigrainger
Member

This looks a lot like PyTorch Lightning! And that's a compliment. As much as I have distaste for much of the Python deep learning ecosystem, I think PTL has emerged from some early teething problems to be one of the most sane training APIs out there. Axon could do much worse than to emulate their approach to callbacks.

I agree wholeheartedly with @arpieb's take here (and in elixir-nx/nx#301) that defining some protocol for the inevitable dataframe libraries is essential here. Dataloading and its complexities have always been a bugbear for PyTorch and TensorFlow.

3 participants