diff --git a/ivy/functional/frontends/paddle/nn/functional/common.py b/ivy/functional/frontends/paddle/nn/functional/common.py index 400779ca05f25..321c7da2a45a4 100644 --- a/ivy/functional/frontends/paddle/nn/functional/common.py +++ b/ivy/functional/frontends/paddle/nn/functional/common.py @@ -109,3 +109,8 @@ def interpolate( def linear(x, weight, bias=None, name=None): weight = ivy.swapaxes(weight, -1, -2) return ivy.linear(x, weight, bias=bias) + +@to_ivy_arrays_and_back +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +def class_center_sample ( label, num_classes, num_samples, group=None ): + return ivy.class_center_sample( label, num_classes, num_samples, group=group ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py index 4afbd027ddf77..c07190df5d33d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py @@ -416,3 +416,33 @@ def test_linear( weight=weight, bias=bias, ) + +# class_center_sample +@handle_frontend_test( + fn_tree="paddle.nn.functional.common.class_center_sample", + dtype_x_weight_bias=x_and_linear( + dtypes=helpers.get_dtypes("float", full=False), + ), +) +def test_class_center_sample( + *, + dtype_x_weight_bias, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + dtype, x, weight, bias = dtype_x_weight_bias + weight = ivy.swapaxes(weight, -1, -2) + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x, + weight=weight, + bias=bias, + ) \ No newline at end of file