Skip to content

Commit

Permalink
fix empty tensordot axes, raise dp max cost_cap
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Nov 4, 2020
1 parent c75b8e7 commit d94dd38
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
20 changes: 12 additions & 8 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,18 @@ def _core_contract(operands, contraction_list, backend='auto', evaluate_constant

tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

# Find indices to contract over
left_pos, right_pos = [], []
for s in idx_rm:
left_pos.append(input_left.find(s))
right_pos.append(input_right.find(s))

# Construct the axes tuples in a canonical order
axes = tuple(zip(*sorted(zip(left_pos, right_pos))))
if idx_rm:
# Find indices to contract over
left_pos, right_pos = [], []
for s in idx_rm:
left_pos.append(input_left.find(s))
right_pos.append(input_right.find(s))

# Construct the axes tuples in a canonical order
axes = tuple(zip(*sorted(zip(left_pos, right_pos))))
else:
# Ensure axes is always pair of tuples
axes = ((), ())

# Contract!
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
output = set(symbol2int[c] for c in output)
size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
size_dict = [size_dict[j] for j in range(len(size_dict))]
naive_cost = functools.reduce(operator.mul, size_dict)
naive_cost = len(inputs) * functools.reduce(operator.mul, size_dict)

inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

Expand Down

0 comments on commit d94dd38

Please sign in to comment.