Skip to content

Commit

Permalink
feat: torch backend basic operation tests (#1306)
Browse files Browse the repository at this point in the history
Signed-off-by: agaraman0 <agaraman0@gmail.com>
  • Loading branch information
agaraman0 committed Mar 29, 2023
1 parent 69c7a77 commit d427754
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/units/computation_backends/torch_backend/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch

Expand Down Expand Up @@ -100,3 +101,46 @@ def test_minmax_normalize(array, t_range, x_range, result):
tensor=array, t_range=t_range, x_range=x_range
)
assert torch.allclose(output, result)


def test_reshape():
a = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
b = TorchCompBackend.reshape(a, (2, 3))
assert torch.equal(b, torch.tensor([[1, 2, 3], [4, 5, 6]]))


def test_copy():
a = torch.tensor([1, 2, 3])
b = TorchCompBackend.copy(a)
assert torch.equal(a, b)


def test_stack():
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
stacked = TorchCompBackend.stack([a, b], dim=0)
assert torch.equal(stacked, torch.tensor([[1, 2, 3], [4, 5, 6]]))


def test_empty_all():
shape = (2, 3)
dtype = torch.float32
device = 'cpu'
a = TorchCompBackend.empty(shape, dtype, device)
assert a.shape == shape and a.dtype == dtype and a.device.type == device


def test_to_numpy():
a = torch.tensor([1, 2, 3])
b = TorchCompBackend.to_numpy(a)
assert np.array_equal(b, np.array(a))


def test_none_value():
assert torch.isnan(TorchCompBackend.none_value())


def test_detach():
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = TorchCompBackend.detach(a)
assert not b.requires_grad

0 comments on commit d427754

Please sign in to comment.