removed scoped thread pool
jackm321 committed May 8, 2015
1 parent b6e9312 commit 26aff11
Showing 4 changed files with 18 additions and 206 deletions.
Cargo.toml: 3 changes (1 addition & 2 deletions)
@@ -1,7 +1,7 @@
[package]

name = "nn"
version = "0.1.5"
version = "0.1.6"
authors = ["Jack Montgomery <jackm321@gmail.com>"]
repository = "https://github.com/jackm321/RustNN"
documentation = "https://jackm321.github.io/RustNN/doc/nn/"
@@ -17,6 +17,5 @@ keywords = ["nn", "neural-network", "classifier", "backpropagation",

[dependencies]
rand = "0.3.7"
threadpool = "0.1.3"
rustc-serialize = "0.3.12"
time = "0.1.24"
README.md: 14 changes (6 additions & 8 deletions)
@@ -10,11 +10,10 @@ An easy to use neural network library written in Rust.

## Description
RustNN is a [feedforward neural network ](http://en.wikipedia.org/wiki/Feedforward_neural_network)
-library that uses parallelization to quickly learn over large datasets. The library
+library. The library
generates fully connected multi-layer artificial neural networks that
are trained via [backpropagation](http://en.wikipedia.org/wiki/Backpropagation).
-Networks can be trained using a incremental training mode or they
-can be trained (optionally in parallel) using a batch training mode.
+Networks are trained using an incremental training mode.

## XOR example

@@ -28,7 +27,7 @@ given examples. See the documentation for the `NN` and `Trainer` structs
for more details.

```rust
-use nn::{NN, HaltCondition, LearningMode};
+use nn::{NN, HaltCondition};

// create examples of the XOR function
// the network is trained on tuples of vectors where the first vector
@@ -51,16 +50,15 @@ let mut net = NN::new(&[2, 3, 1]);
// see the documentation for the Trainer struct for more info on what each method does
net.train(&examples)
.halt_condition( HaltCondition::Epochs(10000) )
-.learning_mode( LearningMode::Incremental )
.log_interval( Some(100) )
-.momentum(0.1)
-.rate(0.3)
+.momentum( 0.1 )
+.rate( 0.3 )
.go();

// evaluate the network to see if it learned the XOR function
for &(ref inputs, ref outputs) in examples.iter() {
let results = net.run(inputs);
-let (result, key) = (Float::round(results[0]), outputs[0]);
+let (result, key) = (results[0].round(), outputs[0]);
assert!(result == key);
}
```
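The README example halts after a fixed number of epochs, but the `HaltCondition::{ Epochs, MSE, Timer }` imports in src/lib.rs below show two other stopping rules. Here is a hedged sketch of both, reusing the README's setup; the `0.01` error threshold is an illustrative value rather than one from the repository, while `Timer(Duration::seconds(5))` mirrors the removed timed test further down:

```rust
extern crate nn;
extern crate time;

use nn::{NN, HaltCondition};
use time::Duration;

fn main() {
    let examples = [
        (vec![0f64, 0f64], vec![0f64]),
        (vec![0f64, 1f64], vec![1f64]),
        (vec![1f64, 0f64], vec![1f64]),
        (vec![1f64, 1f64], vec![0f64]),
    ];
    let mut net = NN::new(&[2, 3, 1]);

    // stop once the epoch's mean squared error falls to 0.01 or below
    // (0.01 is an illustrative threshold, not a value from the repository)
    net.train(&examples)
        .halt_condition( HaltCondition::MSE(0.01) )
        .go();

    // or stop after roughly five seconds of wall-clock training,
    // as the removed xor_timed test did
    net.train(&examples)
        .halt_condition( HaltCondition::Timer(Duration::seconds(5)) )
        .go();
}
```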
src/lib.rs: 173 changes (11 additions & 162 deletions)
@@ -2,11 +2,10 @@
//!
//! # Description
//! nn is a [feedforward neural network ](http://en.wikipedia.org/wiki/Feedforward_neural_network)
-//! library that uses parallelization to quickly learn over large datasets. The library
+//! library. The library
//! generates fully connected multi-layer artificial neural networks that
//! are trained via [backpropagation](http://en.wikipedia.org/wiki/Backpropagation).
-//! Networks can be trained using an incremental training mode or they
-//! can be trained (optionally in parallel) using a batch training mode.
+//! Networks are trained using an incremental training mode.
//!
//! # XOR example
//!
@@ -20,8 +19,7 @@
//! for more details.
//!
//! ```rust
-//! # use std::num::Float;
-//! use nn::{NN, HaltCondition, LearningMode};
+//! use nn::{NN, HaltCondition};
//!
//! // create examples of the XOR function
//! // the network is trained on tuples of vectors where the first vector
@@ -44,33 +42,27 @@
//! // see the documentation for the Trainer struct for more info on what each method does
//! net.train(&examples)
//! .halt_condition( HaltCondition::Epochs(10000) )
-//! .learning_mode( LearningMode::Incremental )
//! .log_interval( Some(100) )
-//! .momentum(0.1)
-//! .rate(0.3)
+//! .momentum( 0.1 )
+//! .rate( 0.3 )
//! .go();
//!
//! // evaluate the network to see if it learned the XOR function
//! for &(ref inputs, ref outputs) in examples.iter() {
//! let results = net.run(inputs);
-//! let (result, key) = (Float::round(results[0]), outputs[0]);
+//! let (result, key) = (results[0].round(), outputs[0]);
//! assert!(result == key);
//! }
//! ```

extern crate rand;
-extern crate threadpool;
extern crate rustc_serialize;
extern crate time;

use HaltCondition::{ Epochs, MSE, Timer };
-use LearningMode::{ Incremental, Batch };
+use LearningMode::{ Incremental };
use std::iter::{Zip, Enumerate};
use std::slice;
-use threadpool::{ScopedPool};
-use std::sync::mpsc::channel;
-use std::sync::{Arc, RwLock};
-use std::cmp;
use rustc_serialize::json;
use time::{ Duration, PreciseTime };
use rand::Rng;
Expand All @@ -94,11 +86,7 @@ pub enum HaltCondition {
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum LearningMode {
/// train the network Incrementally (updates weights after each example)
-Incremental,
-/// train the network in batch (updates weights only at then end of each epoch)
-/// batch training can be parallelized and thus the Batch constructor takes a `u32`
-/// that specifies the number of threads to use when training the network
-Batch(u32)
+Incremental
}

/// Used to specify options that dictate how a network will be trained
@@ -175,19 +163,7 @@ impl<'a,'b> Trainer<'a,'b> {
}
/// Specifies what [mode](http://en.wikipedia.org/wiki/Backpropagation#Modes_of_learning) to train the network in.
/// `Incremental` means update the weights in the network after every example.
-/// `Batch(t)` means run the network on all examples given and accumulate weight
-/// updates along the way but don't actually change the weights in the
-/// network until all of the examples have been run. Batch training can be
-/// parallelized, so the `t` in the `Batch(t)` constructor specifies how
-/// many threads to use while training the network.
pub fn learning_mode(&mut self, learning_mode: LearningMode) -> &mut Trainer<'a,'b> {
-match learning_mode {
-Batch(threads) if threads < 1 => {
-panic!("the number of threads in Batch training mode must be at least 1")
-}
-_ => ()
-}

self.learning_mode = learning_mode;
self
}
@@ -201,8 +177,7 @@ impl<'a,'b> Trainer<'a,'b> {
self.rate,
self.momentum,
self.log_interval,
-self.halt_condition,
-self.learning_mode,
+self.halt_condition
)
}

@@ -301,7 +276,7 @@ impl NN {
}

fn train_details(&mut self, examples: &[(Vec<f64>, Vec<f64>)], rate: f64, momentum: f64, log_interval: Option<u32>,
-halt_condition: HaltCondition, learning_mode: LearningMode) -> f64 {
+halt_condition: HaltCondition) -> f64 {

// check that input and output sizes are correct
let input_layer_size = self.num_inputs;
@@ -315,11 +290,7 @@ impl NN {
}
}

-match learning_mode {
-Incremental => self.train_incremental(examples, rate, momentum, log_interval, halt_condition),
-Batch(threads) => self.train_batch(examples, rate, momentum, log_interval, halt_condition, threads)
-}
-
+self.train_incremental(examples, rate, momentum, log_interval, halt_condition)
}

fn train_incremental(&mut self, examples: &[(Vec<f64>, Vec<f64>)], rate: f64, momentum: f64, log_interval: Option<u32>,
@@ -371,113 +342,6 @@ impl NN {
training_error_rate
}

-fn train_batch(&mut self, examples: &[(Vec<f64>, Vec<f64>)], rate: f64, momentum: f64, log_interval: Option<u32>,
-halt_condition: HaltCondition, mut threads: u32) -> f64 {
-
-threads = cmp::min(threads, examples.len() as u32);
-
-let mut prev_deltas = self.make_weights_tracker(0.0f64);
-let mut epochs = 0;
-let mut training_error_rate = 0.0f64;
-
-let mut split_examples = Vec::new();
-{
-let small_num = examples.len() / threads as usize;
-let large_num = small_num + 1;
-let larges = examples.len() % threads as usize;
-let mut prev_start = 0;
-for i in 0..threads {
-let start = prev_start;
-let end = start + if i < larges as u32 { large_num } else { small_num };
-prev_start = end;
-let slc = &examples[start..end];
-split_examples.push(slc);
-}
-}
-
-let pool = ScopedPool::new(threads);
-let self_lock = Arc::new(RwLock::new(self));
-let (tx, rx) = channel();
-let start_time = PreciseTime::now();
-
-loop {
-
-if epochs > 0 {
-// log error rate if necessary
-match log_interval {
-Some(interval) if epochs % interval == 0 => {
-println!("error rate: {}", training_error_rate);
-},
-_ => (),
-}
-
-// check if we've met the halt condition yet
-match halt_condition {
-Epochs(epochs_halt) => {
-if epochs == epochs_halt { break }
-},
-MSE(target_error) => {
-if training_error_rate <= target_error { break }
-},
-Timer(duration) => {
-let now = PreciseTime::now();
-if start_time.to(now) >= duration { break }
-}
-}
-}
-
-training_error_rate = 0f64;
-
-// init batch data
-let mut batch_weight_updates =
-self_lock.read().unwrap().make_weights_tracker(0.0f64);
-
-// run each example using the thread pool
-for examples in split_examples.iter() {
-let self_lock = self_lock.clone();
-let tx = tx.clone();
-
-let mut local_weight_updates = self_lock.read().unwrap().make_weights_tracker(0.0f64);
-let mut local_error_rate = 0.0f64;
-
-pool.execute(move || {
-let read_self = self_lock.read().unwrap();
-
-for &(ref inputs, ref targets) in examples.iter() {
-let results = read_self.do_run(&inputs);
-let new_weight_updates =
-read_self.calculate_weight_updates(&results, &targets);
-
-let new_error_rate = calculate_error(&results, &targets);
-
-update_batch_data(&mut local_weight_updates, &new_weight_updates);
-local_error_rate += new_error_rate;
-}
-
-tx.send((local_weight_updates, local_error_rate)).unwrap();
-});
-}
-
-// collect the results from the thread pool
-for _ in 0..threads {
-let (weight_updates, error_rate) = rx.recv().unwrap();
-training_error_rate += error_rate;
-update_batch_data(&mut batch_weight_updates, &weight_updates);
-}
-
-// update weights in the network
-self_lock.write().unwrap()
-.update_weights(&batch_weight_updates,
-&mut prev_deltas,
-rate, momentum);
-
-epochs += 1;
-}
-
-training_error_rate
-}


fn do_run(&self, inputs: &[f64]) -> Vec<Vec<f64>> {
let mut results = Vec::new();
results.push(inputs.to_vec());
@@ -599,21 +463,6 @@ fn sigmoid(y: f64) -> f64 {
1f64 / (1f64 + (-y).exp())
}

-// adds new network weight updates into the updates already collected
-fn update_batch_data(batch_data: &mut Vec<Vec<Vec<f64>>> , network_weight_updates: &Vec<Vec<Vec<f64>>>) {
-for layer_index in 0..batch_data.len() {
-let mut batch_layer = &mut batch_data[layer_index];
-let layer_weight_updates = &network_weight_updates[layer_index];
-for node_index in 0..batch_layer.len() {
-let mut batch_node = &mut batch_layer[node_index];
-let node_weight_updates = &layer_weight_updates[node_index];
-for weight_index in 0..batch_node.len() {
-batch_node[weight_index] += node_weight_updates[weight_index];
-}
-}
-}
-}


// takes two arrays and enumerates the iterator produced by zipping each of
// their iterators together
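For reference, the deleted `train_batch` divided the example set into one near-equal contiguous chunk per worker thread before handing the chunks to the scoped pool. A standalone sketch of just that splitting step, distilled into a hypothetical free function (not an API the crate exposes):

```rust
// Distilled from the removed train_batch: split `examples` into `threads`
// contiguous chunks whose sizes differ by at most one. The first
// `len % threads` chunks receive one extra element.
fn split_examples<T>(examples: &[T], threads: usize) -> Vec<&[T]> {
    let small_num = examples.len() / threads;   // base chunk size
    let larges = examples.len() % threads;      // chunks that get one extra
    let mut chunks = Vec::with_capacity(threads);
    let mut prev_start = 0;
    for i in 0..threads {
        let end = prev_start + if i < larges { small_num + 1 } else { small_num };
        chunks.push(&examples[prev_start..end]);
        prev_start = end;
    }
    chunks
}

fn main() {
    // 7 examples over 3 threads -> chunk sizes 3, 2, 2
    let data = [1, 2, 3, 4, 5, 6, 7];
    let chunks = split_examples(&data, 3);
    assert_eq!(chunks.iter().map(|c| c.len()).collect::<Vec<_>>(), vec![3, 2, 2]);
}
```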
tests/xor.rs: 34 changes (0 additions & 34 deletions)
@@ -2,7 +2,6 @@ extern crate nn;
extern crate time;

use nn::{NN, HaltCondition, LearningMode};
-use time::Duration;

#[test]
fn xor_4layers() {
@@ -36,36 +35,3 @@ fn xor_4layers() {
assert!(result == key);
}
}

-#[test]
-fn xor_timed() {
-// create examples of the xor function
-let examples = [
-(vec![0f64, 0f64], vec![0f64]),
-(vec![0f64, 1f64], vec![1f64]),
-(vec![1f64, 0f64], vec![1f64]),
-(vec![1f64, 1f64], vec![0f64]),
-];
-
-// create a new neural network
-let mut net1 = NN::new(&[2,3,1]);
-
-// train the network
-net1.train(&examples)
-.halt_condition( HaltCondition::Timer(Duration::seconds(5)) )
-.log_interval(Some(1000))
-.learning_mode( LearningMode::Batch(2) )
-.momentum(0.1)
-.go();
-
-// make sure json encoding/decoding works as expected
-let json = net1.to_json();
-let net2 = NN::from_json(&json);
-
-// test the trained network
-for &(ref inputs, ref outputs) in examples.iter() {
-let results = net2.run(inputs);
-let (result, key) = (results[0].round(), outputs[0]);
-assert!(result == key);
-}
-}
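The deleted `xor_timed` test also exercised the JSON round trip, which is independent of batch training and survives this commit. A minimal sketch of that round trip as its own test; the test name is hypothetical, and the epoch-based halt condition is borrowed from the README example:

```rust
extern crate nn;

use nn::{NN, HaltCondition};

// Hypothetical replacement test: keeps the to_json/from_json check from the
// removed xor_timed test, but halts on epochs instead of a timer.
#[test]
fn json_round_trip() {
    let examples = [
        (vec![0f64, 0f64], vec![0f64]),
        (vec![0f64, 1f64], vec![1f64]),
        (vec![1f64, 0f64], vec![1f64]),
        (vec![1f64, 1f64], vec![0f64]),
    ];
    let mut net1 = NN::new(&[2, 3, 1]);
    net1.train(&examples)
        .halt_condition( HaltCondition::Epochs(10000) )
        .go();

    // encode to JSON, decode a copy, and check the copy still computes XOR
    let json = net1.to_json();
    let net2 = NN::from_json(&json);
    for &(ref inputs, ref outputs) in examples.iter() {
        assert!(net2.run(inputs)[0].round() == outputs[0]);
    }
}
```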
