diff --git a/ivy/functional/frontends/torch/comparison_ops.py b/ivy/functional/frontends/torch/comparison_ops.py index f9614486a27ca..2a9d6ca3cfc49 100644 --- a/ivy/functional/frontends/torch/comparison_ops.py +++ b/ivy/functional/frontends/torch/comparison_ops.py @@ -49,3 +49,11 @@ def eq(input, other, *, out=None): def argsort(input, dim=-1, descending=False): return ivy.argsort(input, axis=dim, descending=descending) + + +def greater_equal(input, other, *, out=None): + ret = ivy.greater_equal(input, other, out=out) + return ret + + +ge = greater_equal diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py index f7fe05bf62c86..fb34d44e0af04 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py @@ -175,3 +175,44 @@ def test_torch_argsort( dim=axis, descending=descending, ) + + +# greater_equal +@handle_cmd_line_args +@given( + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=tuple( + set(ivy_np.valid_numeric_dtypes).intersection( + set(ivy_torch.valid_numeric_dtypes) + ), + ), + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.greater_equal" + ), +) +def test_torch_greater_equal( + dtype_and_inputs, + as_variable, + with_out, + num_positional_args, + native_array, + fw, +): + input_dtype, inputs = dtype_and_inputs + helpers.test_frontend_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="greater_equal", + input=np.asarray(inputs[0], dtype=input_dtype[0]), + other=np.asarray(inputs[1], dtype=input_dtype[1]), + out=None, + )