Skip to content

Commit

Permalink
fix random greedy for python 3.5 & make deterministic for 3.6+
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Feb 28, 2020
1 parent 4323cfc commit 7f81269
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions opt_einsum/path_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,22 @@
import heapq
import math
import numbers
import random
import time
from collections import deque

# random.choices was introduced in python 3.6
try:
from random import choices as random_choices
from random import seed as random_seed
except ImportError:
import numpy as np

def random_choices(population, weights):
norm = sum(weights)
return np.random.choice(population, p=[w / norm for w in weights], size=1)

random_seed = np.random.seed

from . import helpers, paths

__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]
Expand Down Expand Up @@ -260,7 +272,7 @@ def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=
energies = [math.exp(-(c - cmin) / temperature) for c in costs]

# randomly choose a contraction based on energies
chosen, = random.choices(range(n), weights=energies)
chosen, = random_choices(range(n), weights=energies)
cost, k1, k2, k12 = choices.pop(chosen)

# put the other choise back in the heap
Expand Down Expand Up @@ -297,8 +309,8 @@ def _trial_greedy_ssa_path_and_cost(r, inputs, output, size_dict, choose_fn, cos
if r == 0:
# always start with the standard greedy approach
choose_fn = None
else:
random.seed(r)

random_seed(r)

ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
Expand Down

0 comments on commit 7f81269

Please sign in to comment.