Skip to content

Commit

Permalink
test(axes): Add tensordot tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 15, 2024
1 parent 3dacb89 commit 21c3b01
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,51 @@ def test_linalg_solve_incompatible_left():
np.linalg.solve(arrA, arrb)


def test_ts_to_einsum_int_axes():
a_str, b_str = axes._tensordot_to_einsum(3, 3, 2).split(",")
# expecting 'abc,bcf
assert a_str[1] == b_str[0]
assert a_str[2] == b_str[1]
assert a_str[0] not in b_str
assert b_str[2] not in a_str


def test_ts_to_einsum_list_axes():
a_str, b_str = axes._tensordot_to_einsum(3, 3, [[1], [2]]).split(",")
# expecting 'abcd,efbh
assert a_str[0] not in b_str
assert a_str[1] == b_str[2]
assert a_str[2] not in b_str
assert a_str[3] not in b_str
assert b_str[0] not in a_str
assert b_str[1] not in a_str
assert b_str[3] not in a_str


def test_tensordot_int_axes():
...
axes_a = {"ax_a": 0, "ax_b": [1, 2]}
axes_b = {"ax_b": [0, 1], "ax_c": 2}
arr = np.arange(8).reshape((2, 2, 2))
arr_a = AxesArray(arr, axes_a)
arr_b = AxesArray(arr, axes_b)
result = np.tensordot(arr_a, arr_b, 2)
super_result = np.tensordot(arr, arr, 2)
expected_axes = {"ax_a": 0, "ax_c": 1}
assert result.axes == expected_axes
assert_array_equal(result, super_result)


def test_tensordot_list_axes():
...
axes_a = {"ax_a": 0, "ax_b": [1, 2]}
axes_b = {"ax_c": [0, 1], "ax_b": 2}
arr = np.arange(8).reshape((2, 2, 2))
arr_a = AxesArray(arr, axes_a)
arr_b = AxesArray(arr, axes_b)
result = np.tensordot(arr_a, arr_b, [[1], [2]])
super_result = np.tensordot(arr, arr, 2)
expected_axes = {"ax_a": 0, "ax_b": 1, "ax_c": [2, 3]}
assert result.axes == expected_axes
assert_array_equal(result, super_result)


def test_einsum_implicit():
Expand Down

0 comments on commit 21c3b01

Please sign in to comment.