np.vectorize interaction with max infers wrong otype #18239
Description
When np.vectorize is used to wrap a function that uses Python's built-in max function, the output dtype is 'int' even though the input dtype is 'float'. Surprisingly, it works fine with the min function.
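A minimal sketch of why min appears to work, based on the documented np.vectorize behaviour that, when `otypes` is not given, the output dtype is determined by calling the function on the first input element. Python's built-in `max` and `min` return one of their arguments unchanged, so the result's type depends on which argument wins for that first element. `first_value` and the lambdas are illustrative names, not part of the report:

```python
import numpy as np

first_value = 3.0  # first element of the arrears_months column in the report

# 3.0 - 3.1 is slightly negative, so max returns the int literal 0 ...
max_result = max(first_value - 3.1, 0)
# ... while min returns the (negative) float difference.
min_result = min(first_value - 3.1, 0)
print(type(max_result).__name__, type(min_result).__name__)

values = np.array([3.0, 3.1, 4.3, 4.1, 5.0])
# np.vectorize infers the output dtype from the first call only, so the max
# version becomes an integer array and the min version a float array.
max_out = np.vectorize(lambda x: max(x - 3.1, 0))(values)
min_out = np.vectorize(lambda x: min(x - 3.1, 0))(values)
print(max_out.dtype.kind, min_out.dtype.kind)
```

In other words, min is not immune: it only looks correct here because the first element happens to produce a float.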
In the code below, df_test['duration'], computed by applying the vectorized compute_delinquent_duration function, is an int array. The other two implementations, using pandas apply or map, both return the correct float type.
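A sketch of why the map version (duration3) still ends up as float, assuming plain Python floats stand in for the column values: map keeps each result with its own type, and the container built afterwards inspects all of them instead of only the first. The sketch uses `np.asarray` to show the upcast; pandas performs the same kind of whole-column inference when it builds the duration3 column from the list.

```python
import numpy as np

def compute_delinquent_duration3(arrears_months):
    return max(arrears_months - 3.1, 0)

values = [3.0, 3.1, 4.3, 4.1, 5.0]

# map calls the function once per element and keeps every result as-is,
# so the list mixes the int 0 (from 3.0) with floats.
results = list(map(compute_delinquent_duration3, values))
print([type(r).__name__ for r in results])

# Building an array from the whole list looks at every value, and a mix of
# int and float is upcast to float64.
as_array = np.asarray(results)
print(as_array.dtype)
```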
I know that with np.vectorize I can pass otypes=[float] to force the output to be float, but I think the default behaviour could cause unexpected results for users who are not aware of this issue.
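A minimal sketch of the workaround, comparing the default inference with an explicit `otypes`. Note that `otypes` takes a list of types (or a string of type codes), so a single type is wrapped in a list; the input values mirror the report's data:

```python
import numpy as np

def compute_delinquent_duration(arrears_months):
    return max(arrears_months - 3.1, 0)

values = np.array([3.0, 3.1, 4.3, 4.1, 5.0])

# Default: the output dtype is inferred from the first call, which returns the int 0,
# so every float result is cast to int.
inferred = np.vectorize(compute_delinquent_duration)(values)

# With otypes the output dtype is fixed up front and no inference happens.
forced = np.vectorize(compute_delinquent_duration, otypes=[float])(values)

print(inferred.dtype.kind, inferred)
print(forced.dtype, forced)
```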
Reproducing code example:
import numpy as np
import pandas as pd

# bad --> np.vectorize combined with max returns int
def compute_delinquent_duration(arrears_months):
    return max(arrears_months - 3.1, 0)  # returns int --> bad
    # return min(arrears_months - 3.1, 0)  # returns float --> good

compute_delinquent_duration = np.vectorize(compute_delinquent_duration)

# good --> returns float
def compute_delinquent_duration2(row):
    return max(row['arrears_months'] - 3.1, 0)

# good --> returns float
def compute_delinquent_duration3(arrears_months):
    return max(arrears_months - 3.1, 0)

df_test = pd.DataFrame(columns=['arrears_months'], data=[3, 3.1, 4.3, 4.1, 5])  # returns int for compute_delinquent_duration
# df_test = pd.DataFrame(columns=['arrears_months'], data=[4.2, 3, 4, 4, 5])  # returns float for compute_delinquent_duration
df_test['duration'] = compute_delinquent_duration(df_test['arrears_months'])
df_test['duration2'] = df_test.apply(compute_delinquent_duration2, axis=1)
df_test['duration3'] = list(map(compute_delinquent_duration3, df_test['arrears_months']))
print(df_test)
print(df_test.dtypes)
Error:
No error message, but wrong return type.
Output:
arrears_months duration duration2 duration3
0 3.0 0 0.0 0.0
1 3.1 0 0.0 0.0
2 4.3 1 1.2 1.2
3 4.1 0 1.0 1.0
4 5.0 1 1.9 1.9
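A sketch reproducing the integer duration column above, assuming (per the np.vectorize behaviour) that the integers come from casting float results to the inferred int dtype. Casting to int truncates toward zero, which is why the 4.1 row shows 0 while duration2 shows 1.0: 4.1 - 3.1 is slightly below 1.0 in floating point. The sketch uses `np.maximum`, the elementwise NumPy counterpart of the built-in `max`, which keeps a float dtype for float input; using it instead of np.vectorize is an alternative not discussed in the report.

```python
import numpy as np

arrears_months = np.array([3.0, 3.1, 4.3, 4.1, 5.0])

# Elementwise max against 0 stays float64 for float input.
clipped = np.maximum(arrears_months - 3.1, 0)
print(clipped)  # the 4.1 row is just below 1.0

# Casting to int truncates toward zero, matching the duration column above.
truncated = clipped.astype(int)
print(truncated)
```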
NumPy/Python version information:
NumPy 1.18.2, Python 3.7.6 (default, Dec 30 2019, 19:38:28)