Add trainer subsystem with SGD and Adam optimizers (fff-rs#177)
Co-authored-by: Mikhail Balakhno <{ID}+{username}@users.noreply.github.com>
Co-authored-by: Bernhard Schuster <bernhard@ahoi.io>
3 people authored and Mikhail Balakhno committed Feb 25, 2024
1 parent 0e2dea5 commit 6e8b558
Showing 7 changed files with 569 additions and 0 deletions.
1 change: 1 addition & 0 deletions juice/src/lib.rs
@@ -122,6 +122,7 @@ pub mod layers;
pub mod net;
pub mod solver;
pub mod solvers;
pub mod train;
pub mod weight;

mod capnp_util;
6 changes: 6 additions & 0 deletions juice/src/net/config.rs
@@ -24,3 +24,9 @@ impl From<SequentialConfig> for LayerConfig {
LayerConfig::Sequential(c)
}
}

impl From<LinearConfig> for LayerConfig {
fn from(c: LinearConfig) -> Self {
LayerConfig::Linear(c)
}
}
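
A quick illustration of what the new impl enables: any API that accepts Into<LayerConfig> can now take a LinearConfig directly. The helper below is purely hypothetical and assumes both types are re-exported from juice::net:

use juice::net::{LayerConfig, LinearConfig};

// Hypothetical helper: with the From impl above, a LinearConfig can be
// converted wherever a LayerConfig is expected.
fn as_layer_config(cfg: LinearConfig) -> LayerConfig {
    cfg.into()
}
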
5 changes: 5 additions & 0 deletions juice/src/train/mod.rs
@@ -0,0 +1,5 @@
mod trainer;
mod optimizer;

pub use optimizer::*;
pub use trainer::*;
145 changes: 145 additions & 0 deletions juice/src/train/optimizer/adam.rs
@@ -0,0 +1,145 @@
//! Adam optimizer.
//!
//! Computes the update Vᵢ from the parameter gradient ∇ᵢ as:
//! Mᵢ = β₁Mᵢ₋₁ + (1-β₁)∇ᵢ,
//! Sᵢ = β₂Sᵢ₋₁ + (1-β₂)∇ᵢ⊙∇ᵢ,
//! M₀ = 0,
//! S₀ = 0,
//! M̂ᵢ = Mᵢ/(1-β₁ᵗ),
//! Ŝᵢ = Sᵢ/(1-β₂ᵗ),
//! Vᵢ = M̂ᵢ⊘(√Ŝᵢ+ε),
//! where:
//! ⊙ - pointwise multiplication,
//! ⊘ - pointwise division,
//! β₁, β₂ - averaging parameters (typically set to 0.9 and 0.999 respectively),
//! ε - small constant to prevent division by zero (typically 1e-8).
//!
//! (Note that the update Vᵢ is then additionally scaled by Trainer using global and param-specific
//! learning rates.)

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use crate::train::Optimizer;
use crate::util::native_backend;
use crate::weight::FillerType;
use co::prelude::*;

#[derive(Clone, Debug)]
pub struct AdamConfig {
pub beta1: f32,
pub beta2: f32,
pub epsilon: f32,
}

pub struct Adam {
// First gradient moment (Mᵢ).
first_moments: HashMap<usize, SharedTensor<f32>>,
// Second gradient moment (Sᵢ).
second_moments: HashMap<usize, SharedTensor<f32>>,

// Original β₁ as well as raised to t-th power (β₁ᵗ).
beta1: f32,
beta1_nth: f32,
// Original β₂ as well as raised to t-th power (β₂ᵗ).
beta2: f32,
beta2_nth: f32,

epsilon: f32,
}

impl Default for AdamConfig {
fn default() -> Self {
AdamConfig {
beta1: 0.9,
beta2: 0.999,
epsilon: 1.0e-8,
}
}
}

impl Adam {
pub fn new(config: &AdamConfig) -> Self {
Adam {
first_moments: HashMap::new(),
second_moments: HashMap::new(),
beta1: config.beta1,
beta1_nth: config.beta1,
beta2: config.beta2,
beta2_nth: config.beta2,
epsilon: config.epsilon,
}
}
}

// TODO: Rewrite with backend ops (requires element-wise square and square root support).
impl<B: IBackend> Optimizer<B> for Adam {
fn adjust_weight_change(
&mut self,
backend: &B,
weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>,
) {
let native = native_backend();

for (key, change) in weight_changes {
let mut change_ref = change.borrow_mut();

let first_moment = self.first_moments.entry(*key).or_insert_with(|| {
let mut tensor = SharedTensor::new(change_ref.desc());
FillerType::fill_constant(&mut tensor, 0.0);
tensor
});
let second_moment = self.second_moments.entry(*key).or_insert_with(|| {
let mut tensor = SharedTensor::new(change_ref.desc());
FillerType::fill_constant(&mut tensor, 0.0);
tensor
});

// Make sure the params shape didn't change under us.
assert_eq!(change_ref.desc().size(), first_moment.desc().size());
assert_eq!(change_ref.desc().size(), second_moment.desc().size());

let len = change_ref.desc().size();

let change_slice = change_ref
.read_write(native.device())
.unwrap()
.as_mut_slice::<f32>();
let first_moment_slice = first_moment
.read_write(native.device())
.unwrap()
.as_mut_slice::<f32>();
let second_moment_slice = second_moment
.read_write(native.device())
.unwrap()
.as_mut_slice::<f32>();

// We can rewrite the matrix equations at the top of this file in an element-wise form:
// Mᵢ[j] = β₁Mᵢ₋₁[j] + (1-β₁)∇ᵢ[j]
// Sᵢ[j] = β₂Sᵢ₋₁[j] + (1-β₂)∇ᵢ[j]²
// Vᵢ[j] = Mᵢ[j] / ((1-β₁ᵗ)•(√(Sᵢ[j]/(1-β₂ᵗ)) + ε))
for j in 0..len {
// ∇ᵢ[j].
let w = change_slice[j];
// Mᵢ[j], M̂ᵢ[j].
let m = self.beta1 * first_moment_slice[j] + (1.0 - self.beta1) * w;
let m_hat = m / (1.0 - self.beta1_nth);
// Sᵢ[j], Ŝᵢ[j].
let s = self.beta2 * second_moment_slice[j] + (1.0 - self.beta2) * w * w;
let s_hat = s / (1.0 - self.beta2_nth);
// Vᵢ[j].
let v = m_hat / (s_hat.sqrt() + self.epsilon);

assert!(!v.is_nan());

change_slice[j] = v;
first_moment_slice[j] = m;
second_moment_slice[j] = s;
}
}

self.beta1_nth *= self.beta1;
self.beta2_nth *= self.beta2;
}
}
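
A note on the first step: with zero-initialized moments, the bias correction cancels the (1-β) factors, so M̂₁ = ∇₁, Ŝ₁ = ∇₁⊙∇₁ and V₁ = ∇₁ ⊘ (|∇₁| + ε), i.e. roughly the sign of the gradient. The test-style sketch below exercises that property; it is not part of this commit and assumes it sits inside this module (Adam itself is not re-exported from the crate):

#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::rc::Rc;

    use co::prelude::*;

    use super::{Adam, AdamConfig};
    use crate::train::Optimizer;
    use crate::util::native_backend;
    use crate::weight::FillerType;

    #[test]
    fn first_step_is_bias_corrected() {
        let backend = native_backend();
        let mut adam = Adam::new(&AdamConfig::default());

        // A stand-in gradient of all ones, as backpropagation would produce it.
        let mut gradient = SharedTensor::<f32>::new(&vec![4]);
        FillerType::fill_constant(&mut gradient, 1.0);

        let mut changes = HashMap::new();
        changes.insert(0usize, Rc::new(RefCell::new(gradient)));

        adam.adjust_weight_change(&backend, &changes);

        // For a unit gradient the first update is 1/(1+ε) ≈ 1.
        let change = changes[&0].borrow();
        let values = change.read(backend.device()).unwrap().as_slice::<f32>();
        for &v in values {
            assert!((v - 1.0).abs() < 1e-3);
        }
    }
}
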
60 changes: 60 additions & 0 deletions juice/src/train/optimizer/mod.rs
@@ -0,0 +1,60 @@
mod adam;
mod sgd_momentum;

use std::rc::Rc;
use std::cell::RefCell;
use std::collections::HashMap;
use std::default::Default;

use crate::coblas::plugin::Copy;
use co::prelude::*;
use crate::util::Axpby;

use adam::Adam;
use sgd_momentum::SgdWithMomentum;

// Expose configs publicly.
pub use adam::AdamConfig;
pub use sgd_momentum::SgdWithMomentumConfig;

// A gradient descent optimization algorithm.
pub trait Optimizer<B: IBackend> {
// Called on each minibatch training cycle. Takes all weight gradients computed during
// backpropagation (indexed by an opaque key which is guaranteed to be stable for the
// duration of the program).
// Modifies the changes in-place; modified changes will then be applied to the weights:
// W = W - α•change,
// where α is the learning rate (combined from global and param-specific rates).
fn adjust_weight_change(&mut self, backend: &B, weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>);
}

#[derive(Clone, Debug)]
pub enum OptimizerConfig {
SgdWithMomentum(SgdWithMomentumConfig),
Adam(AdamConfig),
}

impl Default for OptimizerConfig {
fn default() -> Self {
OptimizerConfig::SgdWithMomentum(Default::default())
}
}

pub fn optimizer_from_config<B: IBackend + Axpby<f32> + Copy<f32>>(config: &OptimizerConfig) -> Box<dyn Optimizer<B>> {
match config {
OptimizerConfig::SgdWithMomentum(sgd_config) => Box::new(SgdWithMomentum::new(sgd_config)),
OptimizerConfig::Adam(adam_config) => Box::new(Adam::new(adam_config)),
}
}

impl From<SgdWithMomentumConfig> for OptimizerConfig {
fn from(c: SgdWithMomentumConfig) -> Self {
OptimizerConfig::SgdWithMomentum(c)
}
}

impl From<AdamConfig> for OptimizerConfig {
fn from(c: AdamConfig) -> Self {
OptimizerConfig::Adam(c)
}
}
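
For completeness, a sketch of how a caller outside the crate might wire this up through the public surface (optimizer_from_config plus the From impls). The tensor shape and the key 0 are made up for illustration, and the example assumes juice::util::native_backend and juice::weight::FillerType as used elsewhere in the crate:

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use coaster::prelude::*;
use juice::train::{optimizer_from_config, AdamConfig, Optimizer, OptimizerConfig};
use juice::util::native_backend;
use juice::weight::FillerType;

fn main() {
    let backend = native_backend();

    // Any concrete config converts into the enum via the From impls above.
    let config: OptimizerConfig = AdamConfig::default().into();
    let mut optimizer: Box<dyn Optimizer<Backend<Native>>> = optimizer_from_config(&config);

    // One fake gradient tensor under an arbitrary stable key, the way the
    // Trainer is expected to hand gradients over.
    let mut gradient = SharedTensor::<f32>::new(&vec![8]);
    FillerType::fill_constant(&mut gradient, 0.5);
    let mut changes = HashMap::new();
    changes.insert(0usize, Rc::new(RefCell::new(gradient)));

    // The optimizer rewrites the gradients in place into weight changes,
    // which the Trainer then scales by the learning rate and applies.
    optimizer.adjust_weight_change(&backend, &changes);
}
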
84 changes: 84 additions & 0 deletions juice/src/train/optimizer/sgd_momentum.rs
@@ -0,0 +1,84 @@
//! SGD with momentum.
//!
//! Computes the update Vᵢ from the parameter gradient ∇ᵢ as:
//! Vᵢ = (1-β)Vᵢ₋₁ + β∇ᵢ,
//! V₀ = 0,
//! where:
//! β is the momentum parameter (typically set to 0.1).
//!
//! (Note that the update Vᵢ is then additionally scaled by Trainer using global and param-specific
//! learning rates.)

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use crate::coblas::plugin::Copy;
use crate::train::Optimizer;
use crate::util::{native_scalar, Axpby};
use crate::weight::FillerType;
use co::prelude::*;

#[derive(Clone, Debug)]
pub struct SgdWithMomentumConfig {
pub momentum: f32,
}

pub struct SgdWithMomentum {
history: HashMap<usize, SharedTensor<f32>>,
// Precomputed tensor constants.
zero: SharedTensor<f32>,
momentum: SharedTensor<f32>,
one_minus_momentum: SharedTensor<f32>,
}

impl Default for SgdWithMomentumConfig {
fn default() -> Self {
SgdWithMomentumConfig { momentum: 0.1 }
}
}

impl SgdWithMomentum {
pub fn new(config: &SgdWithMomentumConfig) -> Self {
SgdWithMomentum {
history: HashMap::new(),
zero: native_scalar(0.0),
momentum: native_scalar(config.momentum),
one_minus_momentum: native_scalar(1.0 - config.momentum),
}
}
}

impl<B: IBackend + Axpby<f32> + Copy<f32>> Optimizer<B> for SgdWithMomentum {
fn adjust_weight_change(
&mut self,
backend: &B,
weight_changes: &HashMap<usize, Rc<RefCell<SharedTensor<f32>>>>,
) {
for (key, change) in weight_changes {
let mut change_ref = change.borrow_mut();

let history = self.history.entry(*key).or_insert_with(|| {
let mut tensor = SharedTensor::new(change_ref.desc());
FillerType::fill_constant(&mut tensor, 0.0);
tensor
});

// Make sure the params shape didn't change under us.
assert_eq!(history.desc().size(), change_ref.desc().size());

// Compute Vᵢ = (1-β)Vᵢ₋₁ + β∇ᵢ.
backend
.axpby(
&self.momentum,
&change_ref,
&self.one_minus_momentum,
history,
)
.unwrap();

// Copy Vᵢ to the weight change which should hold the return value.
backend.copy(history, &mut change_ref).unwrap();
}
}
}
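
Since β weights the new gradient rather than the history, a constant gradient g yields V₁ = 0.1·g and V₂ = 0.9·V₁ + 0.1·g = 0.19·g, converging towards g over time. The test-style sketch below checks those two steps; it is not part of this commit and assumes it sits inside this module (SgdWithMomentum is not re-exported from the crate):

#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::rc::Rc;

    use co::prelude::*;

    use super::{SgdWithMomentum, SgdWithMomentumConfig};
    use crate::train::Optimizer;
    use crate::util::native_backend;
    use crate::weight::FillerType;

    // Build a fresh constant "gradient" for each step, since the optimizer
    // overwrites the change buffer in place.
    fn constant_change(value: f32) -> Rc<RefCell<SharedTensor<f32>>> {
        let mut tensor = SharedTensor::<f32>::new(&vec![3]);
        FillerType::fill_constant(&mut tensor, value);
        Rc::new(RefCell::new(tensor))
    }

    #[test]
    fn momentum_accumulates_history() {
        let backend = native_backend();
        let mut sgd = SgdWithMomentum::new(&SgdWithMomentumConfig::default());

        // Step 1: history starts at zero, so V₁ = 0.1·g.
        let mut changes = HashMap::new();
        changes.insert(0usize, constant_change(1.0));
        sgd.adjust_weight_change(&backend, &changes);
        {
            let v1 = changes[&0].borrow();
            let s = v1.read(backend.device()).unwrap().as_slice::<f32>();
            assert!((s[0] - 0.1).abs() < 1e-6);
        }

        // Step 2 with the same gradient and key: V₂ = 0.9·V₁ + 0.1·g = 0.19·g.
        let mut changes = HashMap::new();
        changes.insert(0usize, constant_change(1.0));
        sgd.adjust_weight_change(&backend, &changes);
        let v2 = changes[&0].borrow();
        let s = v2.read(backend.device()).unwrap().as_slice::<f32>();
        assert!((s[0] - 0.19).abs() < 1e-6);
    }
}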
