diff --git a/README.md b/README.md
index 5debd6b..42f7b15 100644
--- a/README.md
+++ b/README.md
@@ -164,7 +164,7 @@ def optimize_greedy(
         When assessing local greedy scores how much to weight the size of the
         tensors removed compared to the size of the tensor added::
 
-            score = size_ab - costmod * (size_a + size_b)
+            score = size_ab / costmod - (size_a + size_b) * costmod
 
         This can be a useful hyper-parameter to tune.
     temperature : float, optional
@@ -235,8 +235,8 @@ def optimize_random_greedy_track_flops(
     output,
     size_dict,
     ntrials=1,
-    costmod=1.0,
-    temperature=0.01,
+    costmod=(0.1, 4.0),
+    temperature=(0.001, 1.0),
     seed=None,
     simplify=True,
     use_ssa=False,
@@ -255,20 +255,21 @@ def optimize_random_greedy_track_flops(
         A dictionary mapping indices to their dimension.
     ntrials : int, optional
         The number of random greedy trials to perform. The default is 1.
-    costmod : float, optional
+    costmod : (float, float), optional
         When assessing local greedy scores how much to weight the size of the
         tensors removed compared to the size of the tensor added::
 
-            score = size_ab - costmod * (size_a + size_b)
+            score = size_ab / costmod - (size_a + size_b) * costmod
 
-        This can be a useful hyper-parameter to tune.
-    temperature : float, optional
+        It is sampled uniformly from the given range.
+    temperature : (float, float), optional
         When asessing local greedy scores, how much to randomly perturb the
         score. This is implemented as::
 
             score -> sign(score) * log(|score|) - temperature * gumbel()
 
-        which implements boltzmann sampling.
+        which implements boltzmann sampling. It is sampled log-uniformly from
+        the given range.
     seed : int, optional
         The seed for the random number generator.
     simplify : bool, optional
diff --git a/src/lib.rs b/src/lib.rs
index 29edbe0..2fdc7f4 100755
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -435,7 +435,7 @@ impl ContractionProcessor {
             } else {
                 0.0 as f32
             };
-            logsub(sab, log_coeff_a + logadd(sa, sb)) - gumbel
+            logsub(sab - log_coeff_a, logadd(sa, sb) + log_coeff_a) - gumbel
         };
 
         // cache all current nodes sizes as we go
@@ -945,14 +945,16 @@ fn optimize_random_greedy_track_flops(
     output: Vec<Node>,
     size_dict: Dict<Node, f32>,
     ntrials: usize,
-    costmod: Option<f32>,
-    temperature: Option<f32>,
+    costmod: Option<(f32, f32)>,
+    temperature: Option<(f32, f32)>,
     seed: Option<u64>,
     simplify: Option<bool>,
     use_ssa: Option<bool>,
 ) -> (Vec<Vec<Node>>, Score) {
     py.allow_threads(|| {
-        let temperature = temperature.unwrap_or(0.01);
+        let (costmodmin, costmodmax) = costmod.unwrap_or((0.1, 4.0));
+        let (tempmin, tempmax) = temperature.unwrap_or((0.001, 1.0));
+
         let mut rng = match seed {
             Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
             None => rand::rngs::StdRng::from_entropy(),
@@ -969,10 +971,26 @@ fn optimize_random_greedy_track_flops(
         let mut best_path = None;
         let mut best_flops = f32::INFINITY;
 
+        let logtempmin = f32::ln(tempmin);
+        let logtempmax = f32::ln(tempmax);
+
         for seed in seeds {
             let mut cp = cp0.clone();
+
+            // uniform sample for costmod
+            let costmod = match costmodmax - costmodmin {
+                0.0 => costmodmin,
+                diff => costmodmin + rng.gen::<f32>() * diff,
+            };
+
+            // log-uniform sample for temperature
+            let temperature = match logtempmax - logtempmin {
+                0.0 => tempmin,
+                diff => f32::exp(logtempmin + rng.gen::<f32>() * diff),
+            };
+
             // greedily contract each connected subgraph
-            cp.optimize_greedy(costmod, Some(temperature), Some(seed));
+            cp.optimize_greedy(Some(costmod), Some(temperature), Some(seed));
 
             // optimize any remaining disconnected terms
             cp.optimize_remaining_by_size();