Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
195 lines (172 sloc) 6.63 KB
use bit_set::BitSet;
use std::collections::BTreeSet;
use std::fmt::Debug;
use std::fmt;
use rand::{Rng, thread_rng};
use rand::distributions::{Range, IndependentSample};
/// Small positive `f32` tolerance constant (`1e-16`).
///
/// NOTE(review): this is many orders of magnitude below `f32::EPSILON` (~1.19e-7), so when
/// compared against values near 1 it behaves almost like an exact-zero test — confirm intended.
pub const EPSILON: f32 = 1.0e-16;
/// One entry of a sampled sequence (see `Objective::sample_sequence`).
///
/// `PartialEq`, `Eq` and `Hash` are derived so sampled sequences can be compared, asserted on,
/// and stored in sets/maps; each applies only when `E` supports it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum SampleElement<E> {
    /// A specific dependent element drawn for this position.
    Dependent(E),
    /// A position filled by a non-dependent element; its identity is not recorded.
    NonDependent,
}
use self::SampleElement::*;
pub trait Objective: Sized {
type Element: Ord + Into<usize> + Clone + Copy;
type State: Default + Clone;
fn max_gamma(&self) -> Option<f32>;
fn elements(&self) -> Vec<Self::Element>;
fn benefit(&self, s: &BitSet, state: &Self::State) -> Result<f32, CurvError<Self>>;
/// The marginal gain $\delta(u \mid S, \psi)$
///
/// Returns `0` for $u \in S$.
fn delta(&self,
u: Self::Element,
s: &BitSet,
state: &Self::State)
-> Result<f32, CurvError<Self>>;
fn nabla(&self,
u: Self::Element,
v: Self::Element,
s: &BitSet,
state: &Self::State)
-> Result<f32, CurvError<Self>>;
fn depends(&self,
u: Self::Element,
state: &Self::State)
-> Result<BTreeSet<Self::Element>, CurvError<Self>>;
fn insert_mut(&self, u: Self::Element, s: &mut Self::State) -> Result<(), CurvError<Self>>;
fn insert(&self, u: Self::Element, s: &Self::State) -> Result<Self::State, CurvError<Self>> {
let mut state = s.clone();
self.insert_mut(u, &mut state)?;
Ok(state)
}
/// Uniformly sample a sequence of `k` elements from about element `u` consistent with initial
/// state `state`. *The state is not updated, so if the dependent set is state-dependent,
/// future elements may not be valid selections.*
fn sample_sequence(&self,
u: Self::Element,
k: usize,
bias: Option<f32>,
sol: &BitSet,
state: &Self::State)
-> Result<Vec<SampleElement<Self::Element>>, CurvError<Self>> {
let mut sample = Vec::with_capacity(k);
let mut deps = self.depends(u, state)?
.into_iter()
.filter(|&v| !sol.contains(v.into()))
.collect::<Vec<_>>();
let mut rng = thread_rng();
rng.shuffle(&mut deps);
let bias =
bias.unwrap_or_else(|| deps.len() as f32 / (self.elements().len() - sol.len()) as f32);
let uniform = Range::new(0.0, 1.0);
for _ in 0..k {
if uniform.ind_sample(&mut rng) <= bias && !deps.is_empty() {
// take a dependent element
sample.push(Dependent(deps.pop().unwrap()));
} else {
sample.push(NonDependent);
}
}
Ok(sample)
}
/// Compute the total primal curvature of an element `u` after the first `k` elements of
/// `sequence` have been added. If the `sequence` does not have at least `k` elements,
/// `CurvError::SampleTooSmall` is returned.
fn gamma(&self,
u: Self::Element,
k: usize,
mut sol: BitSet,
sequence: Vec<SampleElement<Self::Element>>,
mut state: Self::State)
-> Result<(f32, BitSet, Self::State), CurvError<Self>> {
if sequence.len() < k {
return Err(CurvError::SampleTooSmall(sequence.len(), k));
}
let mut prod = 1f32;
for e in sequence.iter().take(k) {
prod *= match e {
&Dependent(v) => self.nabla(u, v, &sol, &state)?,
// proof of this relation is in the 2017 notebook, pg. 8-9
//
// the gist of it is that the non-dependence of x implies a pair of f_u(S \cup {x})
// = f_u(S)-style relations, which are used to show ∇(u, v | S) = ∇(u, v | S \cup
// {x}).
&NonDependent => 1f32,
};
if let &Dependent(v) = e {
self.insert_mut(v, &mut state)?;
sol.insert(v.into());
}
}
Ok((prod, sol, state))
}
/// Compute the sequence of total primal curvature values for `1..k` elements of `sequence`.
/// Unlike `gamma`, this does not return the final solution or state. `gamma(0|S)` is omitted
/// as it is constant `1`.
fn gamma_seq(&self,
u: Self::Element,
k: usize,
mut sol: BitSet,
sequence: Vec<SampleElement<Self::Element>>,
mut state: Self::State)
-> Result<Vec<f32>, CurvError<Self>> {
if sequence.len() < k {
return Err(CurvError::SampleTooSmall(sequence.len(), k));
}
let mut prod = 1f32;
let mut seq = Vec::with_capacity(k);
for e in sequence.iter().take(k) {
prod *= match e {
&Dependent(v) => self.nabla(u, v, &sol, &state)?,
// proof of this relation is in the 2017 notebook, pg. 8-9
//
// the gist of it is that the non-dependence of x implies a pair of f_u(S \cup {x})
// = f_u(S)-style relations, which are used to show ∇(u, v | S) = ∇(u, v | S \cup
// {x}).
&NonDependent => 1f32,
};
seq.push(prod);
if let &Dependent(v) = e {
self.insert_mut(v, &mut state)?;
sol.insert(v.into());
}
}
Ok(seq)
}
}
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum CurvError<O: Objective> {
NoCandidates(O::Element),
SampleTooSmall(usize, usize),
Other(String),
}
// Written by hand: `#[derive(Clone)]` would demand `O: Clone`, but `Objective` does not imply
// `Clone` (only `O::Element` is guaranteed to be `Copy`).
impl<O: Objective> Clone for CurvError<O> {
    /// Duplicates the error; only the `Other` message needs an allocation.
    fn clone(&self) -> Self {
        match *self {
            CurvError::Other(ref msg) => CurvError::Other(msg.clone()),
            CurvError::SampleTooSmall(given, requested) => {
                CurvError::SampleTooSmall(given, requested)
            }
            CurvError::NoCandidates(elem) => CurvError::NoCandidates(elem),
        }
    }
}
// Written by hand: `#[derive(Debug)]` would demand `O: Debug`, which `Objective` does not imply.
impl<O: Objective> Debug for CurvError<O> {
    /// Writes the same message produced by converting the error into a `String`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = String::from(self.clone());
        f.write_str(&message)
    }
}
/// Renders a [`CurvError`] as a human-readable message.
impl<O: Objective> From<CurvError<O>> for String {
    fn from(err: CurvError<O>) -> Self {
        match err {
            // `Other` already carries its message; hand it back without reformatting.
            CurvError::Other(msg) => msg,
            CurvError::SampleTooSmall(given, requested) => format!(
                "Provided sample too small (given: {}, requested: {})",
                given, requested
            ),
            CurvError::NoCandidates(node) => {
                let index: usize = node.into();
                format!("Node {} has no neighbors", index)
            }
        }
    }
}