From 21c3b0100cde38886264f692d1201259feb90287 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 21:27:33 +0000 Subject: [PATCH] test(axes): Add tensordot tests --- test/utils/test_axes.py | 43 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 38b19350b..a59fd2891 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -559,12 +559,51 @@ def test_linalg_solve_incompatible_left(): np.linalg.solve(arrA, arrb) +def test_ts_to_einsum_int_axes(): + a_str, b_str = axes._tensordot_to_einsum(3, 3, 2).split(",") + # expecting 'abc,bcf + assert a_str[1] == b_str[0] + assert a_str[2] == b_str[1] + assert a_str[0] not in b_str + assert b_str[2] not in a_str + + +def test_ts_to_einsum_list_axes(): + a_str, b_str = axes._tensordot_to_einsum(3, 3, [[1], [2]]).split(",") + # expecting 'abcd,efbh + assert a_str[0] not in b_str + assert a_str[1] == b_str[2] + assert a_str[2] not in b_str + assert a_str[3] not in b_str + assert b_str[0] not in a_str + assert b_str[1] not in a_str + assert b_str[3] not in a_str + + def test_tensordot_int_axes(): - ... + axes_a = {"ax_a": 0, "ax_b": [1, 2]} + axes_b = {"ax_b": [0, 1], "ax_c": 2} + arr = np.arange(8).reshape((2, 2, 2)) + arr_a = AxesArray(arr, axes_a) + arr_b = AxesArray(arr, axes_b) + result = np.tensordot(arr_a, arr_b, 2) + super_result = np.tensordot(arr, arr, 2) + expected_axes = {"ax_a": 0, "ax_c": 1} + assert result.axes == expected_axes + assert_array_equal(result, super_result) def test_tensordot_list_axes(): - ... + axes_a = {"ax_a": 0, "ax_b": [1, 2]} + axes_b = {"ax_c": [0, 1], "ax_b": 2} + arr = np.arange(8).reshape((2, 2, 2)) + arr_a = AxesArray(arr, axes_a) + arr_b = AxesArray(arr, axes_b) + result = np.tensordot(arr_a, arr_b, [[1], [2]]) + super_result = np.tensordot(arr, arr, 2) + expected_axes = {"ax_a": 0, "ax_b": 1, "ax_c": [2, 3]} + assert result.axes == expected_axes + assert_array_equal(result, super_result) def test_einsum_implicit():