Skip to content

Commit

Permalink
Add some util tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
elinorbgr committed Jul 29, 2015
1 parent c02f663 commit e1df873
Showing 1 changed file with 60 additions and 1 deletion.
61 changes: 60 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! A set of utility method to combine networks.

use std::iter::repeat;
use std::marker::PhantomData;

use num::Float;
use num::{Float, zero};

use Compute;
use {Method, UnsupervisedTrain, SupervisedTrain, BackpropTrain};
Expand Down Expand Up @@ -174,4 +175,62 @@ impl<F: Float> Compute<F> for FixedOutput<F> {
fn output_size(&self) -> usize {
self.output.len()
}
}

/// A network that simply returns its input
pub struct Identity {
size: usize
}

impl Identity {
/// Creates a new identity network of given size
pub fn new(size: usize) -> Identity {
Identity {
size: size
}
}
}

impl<F: Float> Compute<F> for Identity {
fn compute(&self, input: &[F]) -> Vec<F> {
let mut out = input.to_owned();
out.truncate(self.size);
let outsize = out.len();
if outsize < self.size { out.extend(repeat(zero::<F>()).take(self.size - outsize)); }
out
}

fn input_size(&self) -> usize {
self.size
}

fn output_size(&self) -> usize {
self.size
}
}

#[cfg(test)]
mod tests {
use super::{Identity, Chain, Parallel};

use Compute;

#[test]
fn identity() {
let id = Identity::new(4);
assert_eq!(id.compute(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]), [1.0f32, 2.0, 3.0, 4.0]);
assert_eq!(id.compute(&[1.0f32, 2.0]), [1.0f32, 2.0, 0.0, 0.0]);
}

#[test]
fn chain() {
let ch = Chain::new(Identity::new(4), Identity::new(6));
assert_eq!(ch.compute(&[1.0f32, 2.0, 3.0]), [1.0f32, 2.0, 3.0, 0.0, 0.0, 0.0])
}

#[test]
fn parallel() {
let ch = Parallel::new(Identity::new(4), Identity::new(2));
assert_eq!(ch.compute(&[1.0f32, 2.0, 3.0]), [1.0f32, 2.0, 3.0, 0.0, 1.0, 2.0])
}
}

0 comments on commit e1df873

Please sign in to comment.