Skip to content

Commit

Permalink
fix unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
NotHaozi committed Jun 16, 2023
1 parent 6d6b853 commit b61b1aa
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/dygraph_to_static/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -413,6 +414,7 @@ def verify_predict(self):
),
)

@ast_only_test
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand All @@ -426,6 +428,7 @@ def test_resnet(self):
)
self.verify_predict()

@ast_only_test
def test_resnet_composite_backward(self):
core._set_prim_backward_enabled(True)
static_loss = self.train(to_static=True)
Expand All @@ -440,6 +443,7 @@ def test_resnet_composite_backward(self):
),
)

@ast_only_test
def test_resnet_composite_forward_backward(self):
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
Expand All @@ -454,6 +458,7 @@ def test_resnet_composite_forward_backward(self):
),
)

@ast_only_test
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
Expand Down
4 changes: 4 additions & 0 deletions test/dygraph_to_static/test_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -412,6 +413,7 @@ def verify_predict(self):
),
)

@ast_only_test
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand All @@ -425,6 +427,7 @@ def test_resnet(self):
)
self.verify_predict()

@ast_only_test
def test_resnet_composite(self):
core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("batch_norm")
Expand All @@ -440,6 +443,7 @@ def test_resnet_composite(self):
),
)

@ast_only_test
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
Expand Down
2 changes: 2 additions & 0 deletions test/dygraph_to_static/test_rollback.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test

import paddle
from paddle.jit.dy2static.program_translator import StaticFunction
Expand Down Expand Up @@ -88,6 +89,7 @@ class TestRollBackNet(unittest.TestCase):
def setUp(self):
paddle.set_device("cpu")

@ast_only_test
def test_net(self):
net = paddle.jit.to_static(Net())
x = paddle.randn([3, 4])
Expand Down
3 changes: 3 additions & 0 deletions test/dygraph_to_static/test_save_inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test

import paddle
from paddle import fluid
Expand Down Expand Up @@ -53,6 +54,7 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_save_inference_model(self):
fc_size = 20
x_data = np.random.random((fc_size, fc_size)).astype('float32')
Expand Down Expand Up @@ -145,6 +147,7 @@ def load_and_run_inference(


class TestPartialProgramRaiseError(unittest.TestCase):
@ast_only_test
def test_param_type(self):
paddle.jit.enable_to_static(True)
x_data = np.random.random((20, 20)).astype('float32')
Expand Down
3 changes: 3 additions & 0 deletions test/dygraph_to_static/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test
from test_fetch_feed import Linear

import paddle
Expand Down Expand Up @@ -115,6 +116,7 @@ def test_save_load_same_result(self):
dygraph_loss.numpy(), static_loss.numpy(), rtol=1e-05
)

@ast_only_test
def test_save_load_prim(self):
with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
Expand Down Expand Up @@ -155,6 +157,7 @@ def test_save_load_prim(self):
self.assertIn("pool2d", load_op_type_list)
np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05)

@ast_only_test
def test_save_load_prim_with_hook(self):
with fluid.dygraph.guard(place):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
Expand Down

0 comments on commit b61b1aa

Please sign in to comment.