【PIR API adaptor No.206、207】 Migrate paddle.sign/sinh into pir (Paddl…
enkilee committed Oct 30, 2023
1 parent b500d06 commit f653b39
Showing 4 changed files with 12 additions and 7 deletions.
python/paddle/tensor/math.py: 2 changes (1 addition, 1 deletion)
@@ -4575,7 +4575,7 @@ def sign(x, name=None):
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[ 1., 0., -1., 1.])
"""
- if in_dynamic_mode():
+ if in_dynamic_or_pir_mode():
return _C_ops.sign(x)
else:
check_variable_and_dtype(
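Note: the one-line edit above, and the identical edit in ops.py below, is the whole API-side migration. After the change, sign dispatches roughly as in the sketch that follows. Only the two lines in the hunk are verbatim; the legacy static-graph branch is elided on this page, so its body here (dtype list, LayerHelper, append_op) is the customary pattern and should be read as an assumption, not the exact source.

# Sketch of paddle.sign after the change; paddle.sinh in ops.py follows the
# same pattern. Assumes math.py's module-level imports: _C_ops,
# in_dynamic_or_pir_mode, check_variable_and_dtype, LayerHelper.
def sign(x, name=None):
    if in_dynamic_or_pir_mode():
        # Eager execution and new-IR (PIR) program building both call the
        # generated C++ op binding directly.
        return _C_ops.sign(x)
    else:
        # Legacy static graph. This branch is elided in the hunk above; the
        # dtype list and helper calls are the customary pattern, shown here
        # as an assumption rather than a verbatim copy.
        check_variable_and_dtype(
            x, 'x', ['float16', 'float32', 'float64'], 'sign'
        )
        helper = LayerHelper('sign', **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
            type='sign', inputs={'X': [x]}, outputs={'Out': [out]}
        )
        return out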
python/paddle/tensor/ops.py: 2 changes (1 addition, 1 deletion)
@@ -1010,7 +1010,7 @@ def sinh(x, name=None):
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[-0.41075233, -0.20133601, 0.10016675, 0.30452031])
"""
- if in_dynamic_mode():
+ if in_dynamic_or_pir_mode():
return _C_ops.sinh(x)
else:
check_variable_and_dtype(
test/legacy_test/test_activation_op.py: 7 changes (6 additions, 1 deletion)
@@ -915,10 +915,13 @@ def setUp(self):

self.convert_input_output()

+ def test_check_output(self):
+ self.check_output(check_pir=True)
+
def test_check_grad(self):
if self.dtype == np.float16:
return
- self.check_grad(['X'], 'Out')
+ self.check_grad(['X'], 'Out', check_pir=True)


class TestSinh_Complex64(TestSinh):
@@ -945,6 +948,7 @@ def test_dygraph(self):
z_expected = np.sinh(np_x)
np.testing.assert_allclose(z, z_expected, rtol=1e-05)

+ @test_with_pir_api
def test_api(self):
with static_guard():
test_data_shape = [11, 17]
@@ -985,6 +989,7 @@ def test_backward(self):


class TestSinhOpError(unittest.TestCase):
+ @test_with_pir_api
def test_errors(self):
with static_guard():
with program_guard(Program()):
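Note: the test-side changes follow a single pattern, which test_sign_op.py below repeats: OpTest checks opt into the new IR via check_pir=True, and static-graph API and error tests are rerun under PIR through the test_with_pir_api decorator. Below is a condensed, self-contained sketch of that pattern; the check_pir arguments and the decorator usage come from the diff, while the setUp body, the decorator's import path, and the static-graph test body are assumptions.

import unittest

import numpy as np
import paddle
from op_test import OpTest  # the test/legacy_test harness used by these tests
from paddle.pir_utils import test_with_pir_api  # import path is an assumption


class TestSinhPirSketch(OpTest):
    def setUp(self):
        # Condensed setUp; the real TestSinh covers more dtypes and shapes.
        self.op_type = "sinh"
        self.python_api = paddle.sinh
        x = np.random.uniform(0.1, 1.0, [11, 17]).astype("float32")
        self.inputs = {'X': x}
        self.outputs = {'Out': np.sinh(x)}

    def test_check_output(self):
        # check_pir=True additionally runs the check with the new IR enabled.
        self.check_output(check_pir=True)

    def test_check_grad(self):
        # The gradient check exercises the PIR path as well.
        self.check_grad(['X'], 'Out', check_pir=True)


class TestSinhStaticAPISketch(unittest.TestCase):
    @test_with_pir_api  # rerun the body with PIR on after the legacy-IR pass
    def test_api(self):
        paddle.enable_static()
        try:
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data('x', shape=[11, 17], dtype='float32')
                out = paddle.sinh(x)
                exe = paddle.static.Executor(paddle.CPUPlace())
                x_np = np.random.uniform(0.1, 1.0, [11, 17]).astype("float32")
                (res,) = exe.run(feed={'x': x_np}, fetch_list=[out])
            np.testing.assert_allclose(res, np.sinh(x_np), rtol=1e-05)
        finally:
            paddle.disable_static()


if __name__ == '__main__':
    unittest.main()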
test/legacy_test/test_sign_op.py: 8 changes (4 additions, 4 deletions)
@@ -34,10 +34,10 @@ def setUp(self):
self.outputs = {'Out': np.sign(self.inputs['X'])}

def test_check_output(self):
- self.check_output()
+ self.check_output(check_pir=True)

def test_check_grad(self):
- self.check_grad(['X'], 'Out')
+ self.check_grad(['X'], 'Out', check_pir=True)


class TestSignFP16Op(TestSignOp):
@@ -70,10 +70,10 @@ def setUp(self):
self.place = core.CUDAPlace(0)

def test_check_output(self):
- self.check_output_with_place(self.place)
+ self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
- self.check_grad_with_place(self.place, ['X'], 'Out')
+ self.check_grad_with_place(self.place, ['X'], 'Out', check_pir=True)


class TestSignAPI(unittest.TestCase):
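Note: for readers unfamiliar with the decorator, test_with_pir_api conceptually runs the decorated test body twice, once under the legacy static-graph IR and once with PIR switched on, so one test covers both executors. The sketch below illustrates only that idea; it is not Paddle's implementation, and the IrGuard name and import path are assumptions.

import functools

from paddle.pir_utils import IrGuard  # guard name and path are assumptions


def test_with_pir_api_sketch(func):
    """Illustrative stand-in for the test_with_pir_api decorator used above.

    The wrapped test body runs twice: once under the legacy static-graph IR
    (the default) and once inside an IR guard that switches the process to
    PIR. A sketch of the idea only, not Paddle's actual implementation.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)  # first pass: legacy IR
        with IrGuard():        # second pass: PIR enabled for the duration
            func(*args, **kwargs)

    return wrapper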
