add costmod and temperature sampling
jcmgray committed May 7, 2024
1 parent 8436553 commit 958e4a8
Showing 2 changed files with 32 additions and 13 deletions.
17 changes: 9 additions & 8 deletions README.md
@@ -164,7 +164,7 @@ def optimize_greedy(
When assessing local greedy scores, how much to weight the size of the
tensors removed compared to the size of the tensor added::
score = size_ab - costmod * (size_a + size_b)
score = size_ab / costmod - (size_a + size_b) * costmod
This can be a useful hyper-parameter to tune.
temperature : float, optional
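The change above swaps a linear penalty for a multiplicative one. As a minimal sketch (hypothetical `greedy_score` helper; sizes are assumed to be precomputed floats), the new scoring rule is:

```python
def greedy_score(size_ab, size_a, size_b, costmod=1.0):
    """Score a candidate contraction of tensors a and b into ab.

    Lower is better: costmod > 1 discounts the size of the tensor
    produced and amplifies the reward for removing the inputs, while
    costmod < 1 does the opposite.
    """
    return size_ab / costmod - (size_a + size_b) * costmod
```

At `costmod=1.0` both the old and new forms reduce to `size_ab - (size_a + size_b)`, so tuning `costmod` away from 1 is what distinguishes them.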
@@ -235,8 +235,8 @@ def optimize_random_greedy_track_flops(
output,
size_dict,
ntrials=1,
costmod=1.0,
temperature=0.01,
costmod=(0.1, 4.0),
temperature=(0.001, 1.0),
seed=None,
simplify=True,
use_ssa=False,
@@ -255,20 +255,21 @@
A dictionary mapping indices to their dimension.
ntrials : int, optional
The number of random greedy trials to perform. The default is 1.
costmod : float, optional
costmod : (float, float), optional
When assessing local greedy scores, how much to weight the size of the
tensors removed compared to the size of the tensor added::
score = size_ab - costmod * (size_a + size_b)
score = size_ab / costmod - (size_a + size_b) * costmod
This can be a useful hyper-parameter to tune.
It is sampled uniformly from the given range.
temperature : float, optional
temperature : (float, float), optional
When assessing local greedy scores, how much to randomly perturb the
score. This is implemented as::
score -> sign(score) * log(|score|) - temperature * gumbel()
which implements Boltzmann sampling.
which implements Boltzmann sampling. It is sampled log-uniformly from
the given range.
seed : int, optional
The seed for the random number generator.
simplify : bool, optional
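The Gumbel perturbation described in the docstring can be sketched in plain Python (hypothetical `perturbed_score` helper; assumes a nonzero raw score, and samples Gumbel(0, 1) via the inverse-CDF trick):

```python
import math
import random

def perturbed_score(score, temperature, rng=random):
    """Apply score -> sign(score) * log(|score|) - temperature * gumbel().

    Working with the log of the magnitude keeps the perturbation
    relative, so one temperature scale behaves sensibly across very
    different tensor sizes.
    """
    # Gumbel(0, 1) sample: -log(-log(U)) for U ~ Uniform(0, 1).
    gumbel = -math.log(-math.log(rng.random()))
    sign = 1.0 if score >= 0.0 else -1.0
    return sign * math.log(abs(score)) - temperature * gumbel
```

With `temperature=0.0` the noise term vanishes and the transformation is a monotonic reshaping of the raw score, so the greedy choice is unchanged.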
28 changes: 23 additions & 5 deletions src/lib.rs
@@ -435,7 +435,7 @@ impl ContractionProcessor {
} else {
0.0 as f32
};
logsub(sab, log_coeff_a + logadd(sa, sb)) - gumbel
logsub(sab - log_coeff_a, logadd(sa, sb) + log_coeff_a) - gumbel
};

// cache all current nodes sizes as we go
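The Rust closure works entirely in log space. A rough Python equivalent (hypothetical `logadd`/`logsub` helpers with their usual meanings, and assuming the subtraction stays positive so its log is defined) shows that the new expression equals the log of the README formula, with `log_coeff` standing in for `ln(costmod)`:

```python
import math

def logadd(x, y):
    # log(exp(x) + exp(y)), computed stably.
    m = max(x, y)
    return m + math.log(math.exp(x - m) + math.exp(y - m))

def logsub(x, y):
    # log(exp(x) - exp(y)); assumes x > y.
    return x + math.log1p(-math.exp(y - x))

def log_score(size_ab, size_a, size_b, costmod):
    sab, sa, sb = (math.log(s) for s in (size_ab, size_a, size_b))
    log_coeff = math.log(costmod)
    # log(size_ab / costmod) minus log((size_a + size_b) * costmod):
    return logsub(sab - log_coeff, logadd(sa, sb) + log_coeff)
```

For example, `math.exp(log_score(100.0, 2.0, 3.0, 0.5))` recovers `100 / 0.5 - (2 + 3) * 0.5 = 197.5`, matching the plain-space formula.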
@@ -945,14 +945,16 @@ fn optimize_random_greedy_track_flops(
output: Vec<char>,
size_dict: Dict<char, f32>,
ntrials: usize,
costmod: Option<f32>,
temperature: Option<f32>,
costmod: Option<(f32, f32)>,
temperature: Option<(f32, f32)>,
seed: Option<u64>,
simplify: Option<bool>,
use_ssa: Option<bool>,
) -> (Vec<Vec<Node>>, Score) {
py.allow_threads(|| {
let temperature = temperature.unwrap_or(0.01);
let (costmodmin, costmodmax) = costmod.unwrap_or((0.1, 4.0));
let (tempmin, tempmax) = temperature.unwrap_or((0.001, 1.0));

let mut rng = match seed {
Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
None => rand::rngs::StdRng::from_entropy(),
@@ -969,10 +971,26 @@
let mut best_path = None;
let mut best_flops = f32::INFINITY;

let logtempmin = f32::ln(tempmin);
let logtempmax = f32::ln(tempmax);

for seed in seeds {
let mut cp = cp0.clone();

// uniform sample for costmod
let costmod = match costmodmax - costmodmin {
0.0 => costmodmin,
diff => costmodmin + rng.gen::<f32>() * diff,
};

// log-uniform sample for temperature
let temperature = match logtempmax - logtempmin {
0.0 => tempmin,
diff => f32::exp(logtempmin + rng.gen::<f32>() * diff),
};

// greedily contract each connected subgraph
cp.optimize_greedy(costmod, Some(temperature), Some(seed));
cp.optimize_greedy(Some(costmod), Some(temperature), Some(seed));
// optimize any remaining disconnected terms
cp.optimize_remaining_by_size();

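The per-trial sampling loop above can be mirrored in Python (hypothetical `sample_hyperparams` helper, defaults taken from the new signature): `costmod` is drawn uniformly, while `temperature` is drawn log-uniformly so each decade of its range is equally likely:

```python
import math
import random

def sample_hyperparams(costmod=(0.1, 4.0), temperature=(0.001, 1.0), rng=random):
    """Sample per-trial greedy hyperparameters from the given ranges."""
    cmin, cmax = costmod
    tmin, tmax = temperature
    # uniform sample for costmod (degenerate range -> fixed value)
    cm = cmin if cmax == cmin else cmin + rng.random() * (cmax - cmin)
    # log-uniform sample for temperature
    if tmax == tmin:
        temp = tmin
    else:
        logtmin, logtmax = math.log(tmin), math.log(tmax)
        temp = math.exp(logtmin + rng.random() * (logtmax - logtmin))
    return cm, temp
```

Each trial then runs the greedy optimizer with its own `(cm, temp)` pair, and the cheapest path found over `ntrials` trials is kept.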
