-
Notifications
You must be signed in to change notification settings - Fork 32
/
path_greedy.py
60 lines (50 loc) · 1.55 KB
/
path_greedy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import functools
from ..core import ContractionTree, jitter_dict
from ..hyperoptimizers.hyper import register_hyper_function
from .path_basic import get_optimize_greedy
# Greedy path optimizer obtained from ``path_basic.get_optimize_greedy()``,
# with ``use_ssa=True`` pre-bound; ``trial_greedy`` consumes its result as an
# SSA path (``ssa_path=``). NOTE(review): exact signature/return format of the
# underlying optimizer is defined in path_basic — confirm there.
ssa_greedy_optimize = functools.partial(get_optimize_greedy(), use_ssa=True)
# ------------------------------ GREEDY HYPER ------------------------------- #
def trial_greedy(
    inputs,
    output,
    size_dict,
    random_strength=0.0,
    temperature=0.0,
    costmod=1.0,
):
    """Run a single greedy path search and wrap it in a ``ContractionTree``.

    Parameters
    ----------
    inputs, output, size_dict
        Passed through to ``ssa_greedy_optimize`` and
        ``ContractionTree.from_path``.
    random_strength : float, optional
        When non-zero, the greedy search is run on
        ``jitter_dict(size_dict, random_strength)`` rather than on
        ``size_dict`` itself.
    temperature : float, optional
        Forwarded unchanged to ``ssa_greedy_optimize``.
    costmod : float, optional
        Forwarded unchanged to ``ssa_greedy_optimize``.

    Returns
    -------
    ContractionTree
        Built from the SSA path found by the greedy search, using the
        original (un-jittered) ``size_dict``.
    """
    # The jittered sizes only steer the greedy search; the tree itself is
    # constructed from the true sizes below.
    search_sizes = (
        jitter_dict(size_dict, random_strength)
        if random_strength != 0.0
        else size_dict
    )

    path = ssa_greedy_optimize(
        inputs,
        output,
        search_sizes,
        temperature=temperature,
        costmod=costmod,
    )

    return ContractionTree.from_path(inputs, output, size_dict, ssa_path=path)
# Fully exploratory greedy hyper-driver: size jitter, temperature and costmod
# are all searched over.
register_hyper_function(
    name="greedy",
    ssa_func=trial_greedy,
    space=dict(
        random_strength=dict(type="FLOAT_EXP", min=0.001, max=1.0),
        temperature=dict(type="FLOAT_EXP", min=0.001, max=1.0),
        costmod=dict(type="FLOAT", min=0.1, max=4.0),
    ),
)
# Greedy, but less exploratory -> better suited to a small number of runs:
# size jitter is switched off (``random_strength`` fixed at 0.0) and the
# temperature/costmod ranges are narrower than for "greedy".
register_hyper_function(
    name="random-greedy",
    ssa_func=trial_greedy,
    space=dict(
        temperature=dict(type="FLOAT_EXP", min=0.001, max=0.1),
        costmod=dict(type="FLOAT", min=0.5, max=3.0),
    ),
    constants=dict(random_strength=0.0),
)