diff --git a/ivy/functional/frontends/torch/reduction_ops.py b/ivy/functional/frontends/torch/reduction_ops.py index e69de29bb2d1d..eb700ee099942 100644 --- a/ivy/functional/frontends/torch/reduction_ops.py +++ b/ivy/functional/frontends/torch/reduction_ops.py @@ -0,0 +1,33 @@ +import ivy + + +def argmax(input, dim=None, keepdim=False): + return ivy.argmax(input, axis=dim, keepdims=keepdim) + + +def argmin(input, dim=None, keepdim=False): + return ivy.argmin(input, axis=dim, keepdims=keepdim) + + +def amax(input, dim=None, keepdim=False, *, out=None): + return ivy.max(input, axis=dim, keepdims=keepdim, out=out) + + +def amin(input, dim=None, keepdim=False, *, out=None): + return ivy.min(input, axis=dim, keepdims=keepdim, out=out) + + +def all(input, dim=None, keepdim=False, *, out=None): + input_dtype = ivy.as_ivy_dtype(input.dtype) + ret = ivy.all(input, axis=dim, keepdims=keepdim, out=out) + if ivy.is_uint_dtype(input_dtype): + ret = ivy.astype(ret, input_dtype, out=out) + return ret + + +def any(input, dim=None, keepdim=False, *, out=None): + input_dtype = ivy.as_ivy_dtype(input.dtype) + ret = ivy.any(input, axis=dim, keepdims=keepdim, out=out) + if ivy.is_uint_dtype(input_dtype): + ret = ivy.astype(ret, input_dtype, out=out) + return ret diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py index e69de29bb2d1d..2ac3786d4acf1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py @@ -0,0 +1,241 @@ +# global +import numpy as np +from hypothesis import given, strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_cmd_line_args + + +@handle_cmd_line_args +@given( + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.argmax" + ), + keepdims=st.booleans(), +) +def test_torch_argmax( + dtype_input_axis, + as_variable, + num_positional_args, + native_array, + keepdims, + fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=False, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="argmax", + input=np.asarray(x, dtype=input_dtype), + dim=axis, + keepdim=keepdims, + ) + + +@handle_cmd_line_args +@given( + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.argmin" + ), + keepdims=st.booleans(), +) +def test_torch_argmin( + dtype_input_axis, + as_variable, + num_positional_args, + native_array, + keepdims, + fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=False, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="argmin", + input=np.asarray(x, dtype=input_dtype), + dim=axis, + keepdim=keepdims, + ) + + +@handle_cmd_line_args +@given( + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.amax" + ), + keepdims=st.booleans(), +) +def test_torch_amax( + dtype_input_axis, + as_variable, + num_positional_args, + native_array, + keepdims, + with_out, + fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="amax", + input=np.asarray(x, dtype=input_dtype), + dim=axis, + keepdim=keepdims, + out=None, + ) + + +@handle_cmd_line_args +@given( + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.amin" + ), + keepdims=st.booleans(), +) +def test_torch_amin( + dtype_input_axis, + as_variable, + num_positional_args, + native_array, + keepdims, + with_out, + fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="amin", + input=np.asarray(x, dtype=input_dtype), + dim=axis, + keepdim=keepdims, + out=None, + ) + + +@handle_cmd_line_args +@given( + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + allow_inf=False, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.all" + ), + keepdims=st.booleans(), +) +def test_torch_all( + dtype_input_axis, + as_variable, + num_positional_args, + native_array, + keepdims, + with_out, + fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="all", + input=np.asarray(x, dtype=input_dtype), + dim=axis, + keepdim=keepdims, + out=None, + ) + + +@handle_cmd_line_args +@given( + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + allow_inf=False, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.any" + ), + keepdims=st.booleans(), +) +def test_torch_any( + dtype_input_axis, + as_variable, + num_positional_args, + native_array, + keepdims, + with_out, + fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="any", + input=np.asarray(x, dtype=input_dtype), + dim=axis, + keepdim=keepdims, + out=None, + )