Skip to content

Commit

Permalink
contraction_list: only keep remaining for last 10 contractions
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jul 16, 2020
1 parent 2cbd28e commit c00c8bd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
23 changes: 17 additions & 6 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,22 @@ def __repr__(self):
" Complete contraction: {}\n".format(self.eq), " Naive scaling: {}\n".format(len(self.indices)),
" Optimized scaling: {}\n".format(max(self.scale_list)), " Naive FLOP count: {:.3e}\n".format(
self.naive_cost), " Optimized FLOP count: {:.3e}\n".format(self.opt_cost),
" Theoretical speedup: {:3.3f}\n".format(self.speedup),
" Theoretical speedup: {:.3e}\n".format(self.speedup),
" Largest intermediate: {:.3e} elements\n".format(self.largest_intermediate), "-" * 80 + "\n",
"{:>6} {:>11} {:>22} {:>37}\n".format(*header), "-" * 80
]

for n, contraction in enumerate(self.contraction_list):
inds, idx_rm, einsum_str, remaining, do_blas = contraction
remaining_str = ",".join(remaining) + "->" + self.output_subscript
path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str)
path_print.append("\n{:>4} {:>14} {:>22} {:>37}".format(*path_run))

if remaining is not None:
remaining_str = ",".join(remaining) + "->" + self.output_subscript
else:
remaining_str = "..."
size_remaining = max(0, 56 - max(22, len(einsum_str)))

path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str, size_remaining)
path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))

return "".join(path_print)

Expand Down Expand Up @@ -303,7 +309,12 @@ def contract_path(*operands, **kwargs):

einsum_str = ",".join(tmp_inputs) + "->" + idx_result

contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
if len(input_list) <= 10:
remaining = tuple(input_list)
else:
remaining = None

contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
contraction_list.append(contraction)

opt_cost = sum(cost_list)
Expand Down Expand Up @@ -529,7 +540,7 @@ def _core_contract(operands, contraction_list, backend='auto', evaluate_constant

# Start contraction loop
for num, contraction in enumerate(contraction_list):
inds, idx_rm, einsum_str, remaining, blas_flag = contraction
inds, idx_rm, einsum_str, _, blas_flag = contraction

# check if we are performing the pre-pass of an expression with constants,
# if so, break out upon finding first non-constant (None) operand
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_printing():
views = helpers.build_views(string)

ein = contract_path(string, *views)
assert len(str(ein[1])) == 726
assert len(str(ein[1])) == 728


@pytest.mark.parametrize("string", tests)
Expand Down

0 comments on commit c00c8bd

Please sign in to comment.