From f648e9e625fd5b60905dd2c118075b1c89020e90 Mon Sep 17 00:00:00 2001 From: GiBaeg Kim Date: Tue, 23 Aug 2022 04:06:52 +0900 Subject: [PATCH] Add binary_cross_entropy in functional.frontends.torch (#2310) * add frontend.torch.loss_functions and BCE * add test_loss_functions and edit bce * revert nn.loss_functions to loss_functions * Edit formating and edit test code * Edit formatting * Update test_loss_functions.py Update fn_name in helpers.test_frontend_function to fn_tree * Delete statistical.py * Update test_loss_functions.py * Revert "Delete statistical.py" * Update loss_fuctions and test code * Update loss_functions formating * Update test exclude_min and max * Update reviewed code and test code Co-authored-by: jiahanxie353 <765130715@qq.com> --- .../frontends/torch/loss_functions.py | 55 +++++++ .../test_torch/test_loss_functions.py | 152 ++++++++++++++---- 2 files changed, 177 insertions(+), 30 deletions(-) diff --git a/ivy/functional/frontends/torch/loss_functions.py b/ivy/functional/frontends/torch/loss_functions.py index 655a0ec43d187..02e3647be41f4 100644 --- a/ivy/functional/frontends/torch/loss_functions.py +++ b/ivy/functional/frontends/torch/loss_functions.py @@ -2,6 +2,41 @@ import ivy +def _get_reduction_func(reduction): + if reduction == 'none': + ret = lambda x : x + elif reduction == 'mean': + ret = ivy.mean + elif reduction == 'sum': + ret = ivy.sum + else: + raise ValueError("{} is not a valid value for reduction".format(reduction)) + return ret + + +def _legacy_get_string(size_average, reduce): + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = 'mean' + elif reduce: + ret = 'sum' + else: + ret = 'none' + return ret + + +def _get_reduction(reduction, + size_average=None, + reduce=None): + if size_average is not None or reduce is not None: + return _get_reduction_func(_legacy_get_string(size_average, reduce)) + else: + return _get_reduction_func(reduction) + + def cross_entropy( input, target, @@ -16,3 +51,23 @@ def cross_entropy( cross_entropy.unsupported_dtypes = ("float16",) + + +def binary_cross_entropy( + input, + target, + weight=None, + size_average=None, + reduce=None, + reduction='mean' +): + reduction = _get_reduction(reduction, size_average, reduce) + result = ivy.binary_cross_entropy(target, input, epsilon=0.0) + + if weight is not None: + result = ivy.multiply(weight, result) + result = reduction(result) + return result + + +binary_cross_entropy.unsupported_dtypes = ('float16', 'float64') diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py index bb9925468369f..e44e85d3c2b1e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py @@ -15,34 +15,34 @@ set(ivy_np.valid_float_dtypes).intersection( set(ivy_torch.valid_float_dtypes) ) - ), - min_value=0, - max_value=1, - allow_inf=False, - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, + ), + min_value=0, + max_value=1, + allow_inf=False, + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, ), dtype_and_target=helpers.dtype_and_values( available_dtypes=tuple( set(ivy_np.valid_float_dtypes).intersection( set(ivy_torch.valid_float_dtypes) ) - ), - min_value=1.0013580322265625e-05, - max_value=1, - allow_inf=False, - exclude_min=True, - exclude_max=True, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - ), - as_variable=helpers.list_of_length(x=st.booleans(), length=2), + ), + min_value=1.0013580322265625e-05, + max_value=1, + allow_inf=False, + exclude_min=True, + exclude_max=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + ), + as_variable=helpers.list_of_length(x=st.booleans(), length=2), num_positional_args=helpers.num_positional_args( fn_name="ivy.functional.frontends.torch.cross_entropy" - ), - native_array=helpers.list_of_length(x=st.booleans(), length=2), + ), + native_array=helpers.list_of_length(x=st.booleans(), length=2), ) def test_torch_cross_entropy( dtype_and_input, @@ -54,14 +54,106 @@ def test_torch_cross_entropy( inputs_dtype, input = dtype_and_input target_dtype, target = dtype_and_target helpers.test_frontend_function( - input_dtypes=[inputs_dtype, target_dtype], - as_variable_flags=as_variable, - with_out=False, - num_positional_args=num_positional_args, - native_array_flags=native_array, - fw=fw, - frontend="torch", - fn_tree="nn.functional.cross_entropy", - input=np.asarray(input, dtype=inputs_dtype), - target=np.asarray(target, dtype=target_dtype), + input_dtypes=[inputs_dtype, target_dtype], + as_variable_flags=as_variable, + with_out=False, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="nn.functional.cross_entropy", + input=np.asarray(input, dtype=inputs_dtype), + target=np.asarray(target, dtype=target_dtype), + ) + + +# binary_cross_entropy +@given( + dtype_and_true=helpers.dtype_and_values( + available_dtypes=tuple( + set(ivy_np.valid_float_dtypes).intersection( + set(ivy_torch.valid_float_dtypes) + ) + ), + min_value=0.0, + max_value=1.0, + large_value_safety_factor=1.0, + small_value_safety_factor=1.0, + allow_inf=False, + exclude_min=True, + exclude_max=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + ), + dtype_and_pred=helpers.dtype_and_values( + available_dtypes=tuple( + set(ivy_np.valid_float_dtypes).intersection( + set(ivy_torch.valid_float_dtypes) + ) + ), + min_value=1.0013580322265625e-05, + max_value=1.0, + large_value_safety_factor=1.0, + small_value_safety_factor=1.0, + allow_inf=False, + exclude_min=True, + exclude_max=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + ), + dtype_and_weight=helpers.dtype_and_values( + available_dtypes=tuple( + set(ivy_np.valid_float_dtypes).intersection( + set(ivy_torch.valid_float_dtypes) + ) + ), + min_value=1.0013580322265625e-05, + max_value=1.0, + allow_inf=False, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + ), + size_average=st.booleans(), + reduce=st.booleans(), + reduction=st.sampled_from(["mean", "none", "sum", None]), + as_variable=helpers.list_of_length(x=st.booleans(), length=3), + num_positional_args=helpers.num_positional_args( + fn_name="ivy.functional.frontends.torch.binary_cross_entropy" + ), + native_array=helpers.list_of_length(x=st.booleans(), length=3), +) +def test_binary_cross_entropy( + dtype_and_true, + dtype_and_pred, + dtype_and_weight, + size_average, + reduce, + reduction, + as_variable, + num_positional_args, + native_array, + fw, +): + pred_dtype, pred = dtype_and_pred + true_dtype, true = dtype_and_true + weight_dtype, weight = dtype_and_weight + + helpers.test_frontend_function( + input_dtypes=[pred_dtype, true_dtype, weight_dtype], + as_variable_flags=as_variable, + with_out=False, + num_positional_args=num_positional_args, + native_array_flags=native_array, + fw=fw, + frontend="torch", + fn_tree="nn.functional.binary_cross_entropy", + input=np.asarray(pred, dtype=pred_dtype), + target=np.asarray(true, dtype=true_dtype), + weight=np.asarray(weight, dtype=weight_dtype), + size_average=size_average, + reduce=reduce, + reduction=reduction, )