
Create Trainer class #91

Closed
levithomason opened this issue Dec 21, 2015 · 1 comment

Comments

@levithomason
Owner

The network train method is getting large. It also breaks the pattern set by the Neuron and Layer train methods: Neuron.train() updates the weights, and Layer.train() invokes each Neuron.train() method, both of which make sense. Following that pattern, Network.train() should invoke the Layer.train() methods.

Because the overly complex Network.train() method occupies that name, the method that actually invokes the Layer.train() methods is awkwardly named correct().

The Network.train() method should be pulled into a class, Trainer. It should return a function that takes in a Network and trains it. It would also house the training options, default callbacks, and any other settings or config related to training (such as batch and online learning). The Network.correct() method could then be renamed, more appropriately, to Network.train(). In the future, the Trainer may also support training different types of networks (convolutional, etc.).

The Trainer API may end up looking something like this. It would allow us to have various training strategies:

const shortTrain = new Trainer({maxEpochs: 100})
const accurateTrain = new Trainer({errorThreshold: 0.000001})

// stops training if error is not going down
let lastError = Infinity
const improvingTrain = new Trainer({
  onProgress: (error, epoch) => {
    if (error > lastError) return false
    lastError = error
  }
})

// assume we have a net already made

shortTrain(someNetwork)
accurateTrain(someNetwork)
improvingTrain(someNetwork)

These are off-the-cuff toy examples to demonstrate the pattern.
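A minimal sketch of how such a Trainer might be implemented is below. This is an assumption, not the project's actual API: the option names (maxEpochs, errorThreshold, onProgress) come from the examples above, and the network's train() method is a stand-in for "run one epoch and return the error".

```javascript
// Hypothetical sketch only. Returning a function from the constructor is
// what makes `new Trainer(opts)` directly callable, as in the examples above.
class Trainer {
  constructor(options = {}) {
    const opts = {
      maxEpochs: 1000,       // assumed default
      errorThreshold: 0.001, // assumed default
      onProgress: () => {},
      ...options,
    }

    // The returned function takes a Network and trains it.
    return network => {
      let error = Infinity
      for (let epoch = 1; epoch <= opts.maxEpochs; epoch++) {
        // Assumed placeholder: one pass over the training data, returns error.
        error = network.train()

        // Matches the improvingTrain example: returning false stops training.
        if (opts.onProgress(error, epoch) === false) break
        if (error <= opts.errorThreshold) break
      }
      return error
    }
  }
}
```

With this shape, `const shortTrain = new Trainer({maxEpochs: 100})` produces a plain function, so `shortTrain(someNetwork)` works exactly as written in the examples above.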

Something like shortTrain can be used to quickly test whether a Network can train in a defined amount of time. accurateTrain could be used to see whether a network can reach a certain level of accuracy. improvingTrain could be used to test how long a network can improve before regressing. All of these trainers could take in a single network config and generate performance stats for a given network in a clean, reusable way.

The trainers can even be shared as part of challenges. See if your network can "beat the xyz trainer".

@levithomason
Owner Author

Batch mode might be supported something like this:

  /**
   * @param {boolean|number} [options.batch] Use batch, online, or mini-batch
   *   learning modes.
   *
   *   Batch `true`: Connection weights are updated once after iterating
   *   through all the training samples in the training data (an epoch).
   *
   *   Online `false`: Connection weights are updated after every training
   *   sample in the training data.
   *
   *   Mini-batch `<number>`: Connection weights are updated every `<number>`
   *   training samples.
   */
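One way the three modes in that doc comment could be resolved inside a Trainer is sketched below. This is a hypothetical illustration: `net.accumulate(sample)` and `net.applyWeightUpdates()` are assumed placeholder methods for gradient accumulation and the actual weight update.

```javascript
// Hypothetical sketch of the batch option's three modes.
function trainEpoch(net, samples, batch = false) {
  // Resolve the effective batch size:
  //   true  -> whole epoch (batch learning)
  //   false -> 1 (online learning)
  //   number -> mini-batch of that size
  const batchSize = batch === true ? samples.length
    : batch === false ? 1
    : batch

  samples.forEach((sample, i) => {
    net.accumulate(sample)
    // Apply weight updates at the end of each batch, and always at the
    // end of the epoch so no accumulated gradient is left over.
    if ((i + 1) % batchSize === 0 || i === samples.length - 1) {
      net.applyWeightUpdates()
    }
  })
}
```

For 10 samples, `batch: true` would apply weights once per epoch, `batch: false` ten times, and `batch: 3` four times (after samples 3, 6, 9, and the final partial batch).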
