Revert "[Cherry-pick] Fix Paddle-TRT UT fails (PaddlePaddle#61605)"
This reverts commit 867ab0d.
hanhaowen-mt committed May 13, 2024
1 parent b077291 commit d8df156
Showing 3 changed files with 16 additions and 18 deletions.
28 changes: 13 additions & 15 deletions test/ir/inference/program_config.py
@@ -275,7 +275,6 @@ def generate_weight():
         self.outputs = outputs
         self.input_type = input_type
         self.no_cast_list = [] if no_cast_list is None else no_cast_list
-        self.supported_cast_type = [np.float32, np.float16]
 
     def __repr__(self):
         log_str = ''
@@ -293,9 +292,11 @@ def __repr__(self):
         return log_str
 
     def set_input_type(self, _type: np.dtype) -> None:
-        assert (
-            _type in self.supported_cast_type or _type is None
-        ), "PaddleTRT only supports FP32 / FP16 IO"
+        assert _type in [
+            np.float32,
+            np.float16,
+            None,
+        ], "PaddleTRT only supports FP32 / FP16 IO"
 
         ver = paddle.inference.get_trt_compile_version()
         trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
@@ -308,14 +309,15 @@ def set_input_type(self, _type: np.dtype) -> None:
     def get_feed_data(self) -> Dict[str, Dict[str, Any]]:
         feed_data = {}
         for name, tensor_config in self.inputs.items():
-            data = tensor_config.data
+            do_casting = (
+                self.input_type is not None and name not in self.no_cast_list
+            )
             # Cast to target input_type
-            if (
-                self.input_type is not None
-                and name not in self.no_cast_list
-                and data.dtype in self.supported_cast_type
-            ):
-                data = data.astype(self.input_type)
+            data = (
+                tensor_config.data.astype(self.input_type)
+                if do_casting
+                else tensor_config.data
+            )
             # Truncate FP32 tensors to FP16 precision for FP16 test stability
             if data.dtype == np.float32 and name not in self.no_cast_list:
                 data = data.astype(np.float16).astype(np.float32)
@@ -332,14 +334,10 @@ def _cast(self) -> None:
         for name, inp in self.inputs.items():
             if name in self.no_cast_list:
                 continue
-            if inp.dtype not in self.supported_cast_type:
-                continue
             inp.convert_type_inplace(self.input_type)
         for name, weight in self.weights.items():
             if name in self.no_cast_list:
                 continue
-            if weight.dtype not in self.supported_cast_type:
-                continue
             weight.convert_type_inplace(self.input_type)
         return self
 
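For context, the code restored by this revert decides whether to cast each feed tensor with a single conditional expression and then truncates FP32 inputs to FP16 precision so FP16 test runs stay numerically stable. A minimal standalone sketch of that truncation step follows; it uses only NumPy, and the helper name is illustrative rather than part of ProgramConfig.

import numpy as np

def truncate_to_fp16_precision(data: np.ndarray) -> np.ndarray:
    # Round-trip FP32 values through FP16 so they carry FP16 precision
    # while keeping the FP32 dtype the rest of the pipeline expects.
    if data.dtype == np.float32:
        return data.astype(np.float16).astype(np.float32)
    return data

x = np.array([0.1, 1.0005, 3.14159], dtype=np.float32)
print(truncate_to_fp16_precision(x))  # values rounded to FP16 precision, dtype stays float32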
5 changes: 3 additions & 2 deletions test/ir/inference/test_trt_convert_assign.py
@@ -120,8 +120,9 @@ def clear_dynamic_shape():
             self.dynamic_shape.opt_input_shape = {}
 
         def generate_trt_nodes_num(attrs, dynamic_shape):
-            # Static shape does not support 0 or 1 dim's input
-            if not dynamic_shape and (self.dims == 1 or self.dims == 0):
+            if not dynamic_shape and (
+                self.has_bool_dtype or self.dims == 1 or self.dims == 0
+            ):
                 return 0, 4
             return 1, 2
 
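The restored condition expects no TensorRT conversion whenever the shape is static and the input is 0-D, 1-D, or has a bool dtype. A hedged sketch of that decision logic in isolation is shown below; the wrapper function and sample calls are illustrative, the attribute names mirror the test, and the node-count pairs simply echo the values in the diff.

def expected_trt_nodes(dynamic_shape: bool, dims: int, has_bool_dtype: bool):
    # Cases TensorRT cannot convert in static-shape mode keep the fallback
    # counts from the diff; otherwise the converted counts are expected.
    if not dynamic_shape and (has_bool_dtype or dims in (0, 1)):
        return 0, 4
    return 1, 2

assert expected_trt_nodes(dynamic_shape=False, dims=1, has_bool_dtype=False) == (0, 4)
assert expected_trt_nodes(dynamic_shape=True, dims=2, has_bool_dtype=False) == (1, 2)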
1 change: 0 additions & 1 deletion test/ir/inference/test_trt_convert_cast.py
@@ -118,7 +118,6 @@ def generate_input(type):
                     )
                 },
                 outputs=["cast_output_data1"],
-                no_cast_list=["input_data"],
             )
 
             yield program_config
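Removing no_cast_list=["input_data"] means this input is no longer exempt from ProgramConfig's FP32/FP16 casting and will be converted like any other tensor. A hypothetical illustration of what such an exemption list does, separate from the real ProgramConfig API:

import numpy as np

def maybe_cast(name, data, target_dtype, no_cast_list):
    # Tensors named in no_cast_list keep their original dtype.
    if name in no_cast_list:
        return data
    return data.astype(target_dtype)

inp = np.zeros((2, 3), dtype=np.float32)
print(maybe_cast("input_data", inp, np.float16, ["input_data"]).dtype)  # float32 (exempt)
print(maybe_cast("input_data", inp, np.float16, []).dtype)              # float16 (cast)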
