Skip to content

Commit

Permalink
Fixed broken test
Browse files Browse the repository at this point in the history
  • Loading branch information
bclarkson-code committed Jan 14, 2024
1 parent 4a96bc9 commit f7e8b1f
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tests/test_tensor_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy

import numpy as np

from tricycle_v2.binary import badd, bdiv, bmul, bsub
Expand All @@ -13,6 +15,11 @@ def test_can_add_tensors():

assert np.allclose(tensor_1 + tensor_2, badd(tensor_1, tensor_2))

before = deepcopy(tensor_1)
tensor_1 += 1

assert np.allclose(tensor_1, uadd(before, 1))


def test_can_subtract_tensors():
tensor_1 = to_tensor(np.arange(12).reshape(3, 4))
Expand All @@ -22,6 +29,11 @@ def test_can_subtract_tensors():

assert np.allclose(tensor_1 - tensor_2, bsub(tensor_1, tensor_2))

before = deepcopy(tensor_1)
tensor_1 -= 1

assert np.allclose(tensor_1, usub(before, 1))


def test_can_multiply_tensors():
tensor_1 = to_tensor(np.arange(12).reshape(3, 4))
Expand All @@ -31,15 +43,25 @@ def test_can_multiply_tensors():

assert np.allclose(tensor_1 * tensor_2, bmul(tensor_1, tensor_2))

before = deepcopy(tensor_1)
tensor_1 *= 2

assert np.allclose(tensor_1, umul(before, 2))


def test_can_divide_tensors():
tensor_1 = to_tensor(np.arange(1, 13).reshape(3, 4))
tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4))
tensor_1 = to_tensor(np.arange(1, 13).reshape(3, 4).astype(float))
tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4).astype(float))

assert np.allclose(tensor_1 / 2, udiv(tensor_1, 2))

assert np.allclose(tensor_1 / tensor_2, bdiv(tensor_1, tensor_2))

before = deepcopy(tensor_1)
tensor_1 /= 2.0

assert np.allclose(tensor_1, udiv(before, 2))


def test_can_pow_tensors():
tensor_1 = to_tensor(np.arange(12).reshape(3, 4))
Expand Down

0 comments on commit f7e8b1f

Please sign in to comment.