def get_mask(shape, device, prob, seed=None):
    """Return a binary dropout mask of ``shape`` on ``device``.

    Each element is 0.0 with probability ``prob`` (dropped) and 1.0
    otherwise (kept). ``seed`` makes the uniform draw reproducible
    when given.
    """
    mask = ivy.where(
        ivy.random_uniform(shape=shape, device=device, seed=seed) < prob,
        0.0,
        1.0,
    )
    return mask


@with_supported_dtypes({"2.4.1 and below": ("float32", "float64")}, "paddle")
@to_ivy_arrays_and_back
def dropout(x, p=0.5, axis=None, training=True, mode="upscale_in_train", name=None):
    """Paddle ``nn.functional.dropout`` frontend.

    Zeroes elements of ``x`` with probability ``p``. ``axis`` restricts the
    mask to one dimension (0 or 1), which is then broadcast over ``x``;
    ``None`` (or a 2-element list of axes) masks every element independently.
    In ``"upscale_in_train"`` mode the kept elements are rescaled by
    ``1 / (1 - p)`` during training; otherwise (``"downscale_in_infer"``)
    the output is scaled by ``1 - p`` at inference time. ``name`` is
    accepted for paddle API compatibility and ignored.

    Raises:
        ValueError: if ``axis`` is not 0, 1, None, or a 2-element list.
    """
    # Build the drop mask first. The None / list cases must be tested before
    # any numeric comparison: the previous ordering evaluated ``axis > 1``
    # first, which raises TypeError for the default ``axis=None``.
    if axis is None or (isinstance(axis, list) and len(axis) == 2):
        mask = get_mask(shape=x.shape, device=ivy.dev(x), prob=p, seed=None)
    elif axis == 0:
        mask = get_mask(shape=(x.shape[0], 1), device=ivy.dev(x), prob=p)
        mask = ivy.broadcast_to(mask, x.shape)
    elif axis == 1:
        mask = get_mask(shape=(1, x.shape[1]), device=ivy.dev(x), prob=p)
        mask = ivy.broadcast_to(mask, x.shape)
    else:
        # Covers axis > 1 as before, and also negative / otherwise
        # unsupported values that previously fell through every branch
        # and left ``mask`` undefined (NameError).
        raise ValueError("Axis value can only be 0 or 1 or None.")
    if mode == "upscale_in_train":
        if training:
            out = ivy.multiply(x, mask)
            if p == 1.0:
                # Everything is dropped; avoid the 1 / (1 - p) division
                # by zero — the all-zero mask already yields zeros.
                ret = out
            else:
                ret = ivy.multiply(out, 1.0 / (1.0 - p))
        else:
            ret = x
    else:
        if training:
            ret = ivy.multiply(x, mask)
        else:
            ret = ivy.multiply(x, (1.0 - p))
    return ret
# dropout
@handle_frontend_test(
    fn_tree="paddle.nn.functional.common.dropout",
    d_type_and_x=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("float"),
        num_arrays=1,
        shared_dtype=True,
        min_value=2,
        max_value=5,
        # The frontend's axis=0 / axis=1 branches index x.shape[1] and
        # broadcast a (d0, 1) / (1, d1) mask against x, so the input must
        # be 2-D. The previous min/max_num_dims=1 generated 1-D arrays,
        # which fail for every drawn axis value.
        min_num_dims=2,
        max_num_dims=2,
        min_dim_size=2,
    ),
    p=st.floats(min_value=0.0, max_value=1.0),
    axis=st.integers(min_value=0, max_value=1),
    training=st.booleans(),
    mode=st.one_of(
        *[st.just(seq) for seq in ["upscale_in_train", "downscale_in_infer"]]
    ),
)
def test_paddle_dropout(
    *,
    d_type_and_x,
    p,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    training,
    axis,
    mode,
):
    """Test the paddle ``dropout`` frontend against native paddle."""
    dtype, x = d_type_and_x
    helpers.test_frontend_function(
        input_dtypes=dtype,
        p=p,
        frontend=frontend,
        fn_tree=fn_tree,
        test_flags=test_flags,
        on_device=on_device,
        # Dropout draws its mask from the backend RNG, so element values
        # cannot be expected to match paddle's own RNG; only check
        # shapes/dtypes of the result.
        test_values=False,
        x=x[0],
        training=training,
        axis=axis,
        mode=mode,
    )