forked from fff-rs/juice
Add trainer subsystem with SGD and Adam optimizers (fff-rs#177)
Co-authored-by: Mikhail Balakhno <{ID}+{username}@users.noreply.github.com>
Co-authored-by: Bernhard Schuster <bernhard@ahoi.io>
1 parent 77eb655 · commit caf87f2
Showing 7 changed files with 569 additions and 0 deletions.
@@ -0,0 +1,5 @@
mod trainer;
mod optimizer;

pub use optimizer::*;
pub use trainer::*;
@@ -0,0 +1,145 @@
//! Adam optimizer.
//!
//! Computes the update Vᵢ from the parameter gradient ∇ᵢ as:
//!   Mᵢ = β₁Mᵢ₋₁ + (1-β₁)∇ᵢ,
//!   Sᵢ = β₂Sᵢ₋₁ + (1-β₂)∇ᵢ⊙∇ᵢ,
//!   M₀ = 0,
//!   S₀ = 0,
//!   M̂ᵢ = Mᵢ/(1-β₁ᵗ),
//!   Ŝᵢ = Sᵢ/(1-β₂ᵗ),
//!   Vᵢ = M̂ᵢ⊘(√Ŝᵢ+ε),
//! where:
//!   ⊙ - pointwise multiplication,
//!   ⊘ - pointwise division,
//!   β₁, β₂ - averaging parameters (typically set to 0.9 and 0.999 respectively),
//!   ε - small constant to prevent division by zero (typically 1e-8).
//!
//! (Note that the update Vᵢ is then additionally scaled by Trainer using global and param-specific
//! learning rates.)

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use crate::train::Optimizer;
use crate::util::native_backend;
use crate::weight::FillerType;
use co::prelude::*;

#[derive(Clone, Debug)]
pub struct AdamConfig {
    pub beta1: f32,
    pub beta2: f32,
    pub epsilon: f32,
}

pub struct Adam {
    // First gradient moment (Mᵢ).
    first_moments: HashMap<usize, SharedTensor<f32>>,
    // Second gradient moment (Sᵢ).
    second_moments: HashMap<usize, SharedTensor<f32>>,

    // Original β₁ as well as raised to the t-th power (β₁ᵗ).
    beta1: f32,
    beta1_nth: f32,
    // Original β₂ as well as raised to the t-th power (β₂ᵗ).
    beta2: f32,
    beta2_nth: f32,

    epsilon: f32,
}

impl Default for AdamConfig {
    fn default() -> Self {
        AdamConfig {
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1.0e-8,
        }
    }
}

impl Adam {
    pub fn new(config: &AdamConfig) -> Self {
        Adam {
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            beta1: config.beta1,
            beta1_nth: config.beta1,
            beta2: config.beta2,
            beta2_nth: config.beta2,
            epsilon: config.epsilon,
        }
    }
}

// TODO: Rewrite with backend ops (requires element-wise square and square root support).
impl<B: IBackend> Optimizer<B> for Adam {
    fn adjust_weight_change(
        &mut self,
        backend: &B,
        weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>,
    ) {
        let native = native_backend();

        for (key, change) in weight_changes {
            let mut change_ref = change.borrow_mut();

            let first_moment = self.first_moments.entry(*key).or_insert_with(|| {
                let mut tensor = SharedTensor::new(change_ref.desc());
                FillerType::fill_constant(&mut tensor, 0.0);
                tensor
            });
            let second_moment = self.second_moments.entry(*key).or_insert_with(|| {
                let mut tensor = SharedTensor::new(change_ref.desc());
                FillerType::fill_constant(&mut tensor, 0.0);
                tensor
            });

            // Make sure the params shape didn't change under us.
            assert_eq!(change_ref.desc().size(), first_moment.desc().size());
            assert_eq!(change_ref.desc().size(), second_moment.desc().size());

            let len = change_ref.desc().size();

            let change_slice = change_ref
                .read_write(native.device())
                .unwrap()
                .as_mut_slice::<f32>();
            let first_moment_slice = first_moment
                .read_write(native.device())
                .unwrap()
                .as_mut_slice::<f32>();
            let second_moment_slice = second_moment
                .read_write(native.device())
                .unwrap()
                .as_mut_slice::<f32>();

            // We can rewrite the matrix equations at the top of this file in an element-wise form:
            //   Mᵢ[j] = β₁Mᵢ₋₁[j] + (1-β₁)∇ᵢ[j],
            //   Sᵢ[j] = β₂Sᵢ₋₁[j] + (1-β₂)∇ᵢ[j]²,
            //   Vᵢ[j] = (Mᵢ[j]/(1-β₁ᵗ)) / (√(Sᵢ[j]/(1-β₂ᵗ)) + ε).
            for j in 0..len {
                // ∇ᵢ[j].
                let w = change_slice[j];
                // Mᵢ[j], M̂ᵢ[j].
                let m = self.beta1 * first_moment_slice[j] + (1.0 - self.beta1) * w;
                let m_hat = m / (1.0 - self.beta1_nth);
                // Sᵢ[j], Ŝᵢ[j].
                let s = self.beta2 * second_moment_slice[j] + (1.0 - self.beta2) * w * w;
                let s_hat = s / (1.0 - self.beta2_nth);
                // Vᵢ[j].
                let v = m_hat / (s_hat.sqrt() + self.epsilon);

                assert!(!v.is_nan());

                change_slice[j] = v;
                first_moment_slice[j] = m;
                second_moment_slice[j] = s;
            }
        }

        self.beta1_nth *= self.beta1;
        self.beta2_nth *= self.beta2;
    }
}
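To make the per-element update concrete, here is a small standalone sketch of a single Adam step on one scalar gradient. The adam_step helper below is hypothetical, written only to mirror the inner loop of adjust_weight_change; it is not part of juice.

// Hypothetical scalar version of the per-element Adam update (step t).
fn adam_step(grad: f32, m_prev: f32, s_prev: f32, t: i32,
             beta1: f32, beta2: f32, epsilon: f32) -> (f32, f32, f32) {
    let m = beta1 * m_prev + (1.0 - beta1) * grad;        // Mᵢ[j]
    let s = beta2 * s_prev + (1.0 - beta2) * grad * grad; // Sᵢ[j]
    let m_hat = m / (1.0 - beta1.powi(t));                // M̂ᵢ[j]
    let s_hat = s / (1.0 - beta2.powi(t));                // Ŝᵢ[j]
    let v = m_hat / (s_hat.sqrt() + epsilon);             // Vᵢ[j]
    (v, m, s)
}

fn main() {
    // First step (t = 1) with the default config and gradient 0.5:
    // M₁ = 0.1·0.5 = 0.05, M̂₁ = 0.05/(1-0.9) = 0.5,
    // S₁ = 0.001·0.25, Ŝ₁ = 0.25, so V₁ ≈ 0.5/√0.25 = 1.
    let (v, m, s) = adam_step(0.5, 0.0, 0.0, 1, 0.9, 0.999, 1.0e-8);
    assert!((m - 0.05).abs() < 1.0e-6);
    assert!((v - 1.0).abs() < 1.0e-3);
    println!("v = {}, m = {}, s = {}", v, m, s);
}

Bias correction is why the first update has roughly unit scale regardless of the raw gradient magnitude: both moment estimates start at zero, and dividing by (1-βᵗ) compensates for that initialization.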
@@ -0,0 +1,60 @@
mod adam;
mod sgd_momentum;

use std::cell::RefCell;
use std::collections::HashMap;
use std::default::Default;
use std::rc::Rc;

use crate::coblas::plugin::Copy;
use crate::util::Axpby;
use co::prelude::*;

use adam::Adam;
use sgd_momentum::SgdWithMomentum;

// Expose configs publicly.
pub use adam::AdamConfig;
pub use sgd_momentum::SgdWithMomentumConfig;

// A gradient descent optimization algorithm.
pub trait Optimizer<B: IBackend> {
    // Called on each minibatch training cycle. Takes all weight gradients computed during
    // backpropagation (indexed by an opaque key which is guaranteed to be stable for the
    // duration of the program).
    // Modifies the changes in-place; modified changes will then be applied to the weights:
    //   W = W - α•change,
    // where α is the learning rate (combined from global and param-specific rates).
    fn adjust_weight_change(
        &mut self,
        backend: &B,
        weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>,
    );
}

#[derive(Clone, Debug)]
pub enum OptimizerConfig {
    SgdWithMomentum(SgdWithMomentumConfig),
    Adam(AdamConfig),
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        OptimizerConfig::SgdWithMomentum(Default::default())
    }
}

pub fn optimizer_from_config<B: IBackend + Axpby<f32> + Copy<f32>>(
    config: &OptimizerConfig,
) -> Box<dyn Optimizer<B>> {
    match config {
        OptimizerConfig::SgdWithMomentum(sgd_config) => Box::new(SgdWithMomentum::new(sgd_config)),
        OptimizerConfig::Adam(adam_config) => Box::new(Adam::new(adam_config)),
    }
}

impl From<SgdWithMomentumConfig> for OptimizerConfig {
    fn from(c: SgdWithMomentumConfig) -> Self {
        OptimizerConfig::SgdWithMomentum(c)
    }
}

impl From<AdamConfig> for OptimizerConfig {
    fn from(c: AdamConfig) -> Self {
        OptimizerConfig::Adam(c)
    }
}
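As a sketch of the trait's contract: an optimizer that leaves the weight changes untouched reduces the trainer's update rule to plain gradient descent (W = W - α∇). PlainSgd below is hypothetical and only illustrates the interface; it is not part of this commit.

// Hypothetical minimal Optimizer: uses the raw gradient as the update (V = ∇).
struct PlainSgd;

impl<B: IBackend> Optimizer<B> for PlainSgd {
    fn adjust_weight_change(
        &mut self,
        _backend: &B,
        _weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>,
    ) {
        // Leaving every change as-is means the trainer applies W = W - α·∇.
    }
}

Callers would normally go through the config types instead; for some backend type B satisfying the IBackend + Axpby<f32> + Copy<f32> bounds, the From impls above allow, e.g.:

let config: OptimizerConfig = AdamConfig::default().into();
let mut optimizer = optimizer_from_config::<B>(&config);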
@@ -0,0 +1,84 @@
//! SGD with momentum.
//!
//! Computes the update Vᵢ from the parameter gradient ∇ᵢ as:
//!   Vᵢ = (1-β)Vᵢ₋₁ + β∇ᵢ,
//!   V₀ = 0,
//! where:
//!   β is the momentum parameter (typically set to 0.1).
//!
//! (Note that the update Vᵢ is then additionally scaled by Trainer using global and param-specific
//! learning rates.)

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use crate::coblas::plugin::Copy;
use crate::train::Optimizer;
use crate::util::{native_scalar, Axpby};
use crate::weight::FillerType;
use co::prelude::*;

#[derive(Clone, Debug)]
pub struct SgdWithMomentumConfig {
    pub momentum: f32,
}

pub struct SgdWithMomentum {
    history: HashMap<usize, SharedTensor<f32>>,
    // Precomputed tensor constants.
    zero: SharedTensor<f32>,
    momentum: SharedTensor<f32>,
    one_minus_momentum: SharedTensor<f32>,
}

impl Default for SgdWithMomentumConfig {
    fn default() -> Self {
        SgdWithMomentumConfig { momentum: 0.1 }
    }
}

impl SgdWithMomentum {
    pub fn new(config: &SgdWithMomentumConfig) -> Self {
        SgdWithMomentum {
            history: HashMap::new(),
            zero: native_scalar(0.0),
            momentum: native_scalar(config.momentum),
            one_minus_momentum: native_scalar(1.0 - config.momentum),
        }
    }
}

impl<B: IBackend + Axpby<f32> + Copy<f32>> Optimizer<B> for SgdWithMomentum {
    fn adjust_weight_change(
        &mut self,
        backend: &B,
        weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>,
    ) {
        for (key, change) in weight_changes {
            let mut change_ref = change.borrow_mut();

            let history = self.history.entry(*key).or_insert_with(|| {
                let mut tensor = SharedTensor::new(change_ref.desc());
                FillerType::fill_constant(&mut tensor, 0.0);
                tensor
            });

            // Make sure the params shape didn't change under us.
            assert_eq!(history.desc().size(), change_ref.desc().size());

            // Compute Vᵢ = (1-β)Vᵢ₋₁ + β∇ᵢ.
            backend
                .axpby(
                    &self.momentum,
                    &change_ref,
                    &self.one_minus_momentum,
                    history,
                )
                .unwrap();

            // Copy Vᵢ to the weight change, which should hold the return value.
            backend.copy(history, &mut change_ref).unwrap();
        }
    }
}
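To see the recurrence behave as an exponential moving average, here is a scalar sketch of the same update (the momentum_step helper is hypothetical, independent of the juice backend types). With a constant gradient, the history V converges to that gradient at a rate set by (1-β):

// Hypothetical scalar version of Vᵢ = (1-β)Vᵢ₋₁ + β∇ᵢ.
fn momentum_step(v_prev: f32, grad: f32, beta: f32) -> f32 {
    (1.0 - beta) * v_prev + beta * grad
}

fn main() {
    let beta = 0.1;  // default SgdWithMomentumConfig::momentum
    let mut v = 0.0; // V₀ = 0
    for _ in 0..50 {
        v = momentum_step(v, 1.0, beta);
    }
    // With constant ∇ = 1, V after i steps is 1 - (1-β)ⁱ; (0.9)⁵⁰ ≈ 0.005.
    assert!((v - 1.0).abs() < 0.01);
    println!("V after 50 steps: {}", v);
}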