Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: datetime: Unify datetime/timedelta type promotion #86

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 0 additions & 13 deletions numpy/core/src/multiarray/_datetime.h
Expand Up @@ -93,19 +93,6 @@ datetime_metadata_divides(
PyArray_Descr *divisor,
int strict_with_nonlinear_units);

/*
* Computes the GCD of the two date-time metadata values. Raises
* an exception if there is no reasonable GCD, such as with
* years and days.
*
* Returns a capsule with the GCD metadata.
*/
NPY_NO_EXPORT PyObject *
compute_datetime_metadata_greatest_common_divisor(
PyArray_Descr *type1,
PyArray_Descr *type2,
int strict_with_nonlinear_units);

/*
* Computes the conversion factor to convert data with 'src_meta' metadata
* into data with 'dst_meta' metadata, not taking into account the events.
Expand Down
88 changes: 41 additions & 47 deletions numpy/core/src/multiarray/datetime.c
Expand Up @@ -1632,11 +1632,12 @@ datetime_metadata_divides(
}


NPY_NO_EXPORT PyObject *
static PyObject *
compute_datetime_metadata_greatest_common_divisor(
PyArray_Descr *type1,
PyArray_Descr *type2,
int strict_with_nonlinear_units)
int strict_with_nonlinear_units1,
int strict_with_nonlinear_units2)
{
PyArray_DatetimeMetaData *meta1, *meta2, *dt_data;
NPY_DATETIMEUNIT base;
Expand Down Expand Up @@ -1688,7 +1689,7 @@ compute_datetime_metadata_greatest_common_divisor(
base = NPY_FR_M;
num1 *= 12;
}
else if (strict_with_nonlinear_units) {
else if (strict_with_nonlinear_units1) {
goto incompatible_units;
}
else {
Expand All @@ -1701,19 +1702,34 @@ compute_datetime_metadata_greatest_common_divisor(
base = NPY_FR_M;
num2 *= 12;
}
else if (strict_with_nonlinear_units) {
else if (strict_with_nonlinear_units2) {
goto incompatible_units;
}
else {
base = meta1->base;
/* Don't multiply num2 since there is no even factor */
}
}
else if (meta1->base == NPY_FR_M ||
meta1->base == NPY_FR_B ||
meta2->base == NPY_FR_M ||
meta2->base == NPY_FR_B) {
if (strict_with_nonlinear_units) {
else if (meta1->base == NPY_FR_M) {
if (strict_with_nonlinear_units1) {
goto incompatible_units;
}
else {
base = meta2->base;
/* Don't multiply num1 since there is no even factor */
}
}
else if (meta2->base == NPY_FR_M) {
if (strict_with_nonlinear_units2) {
goto incompatible_units;
}
else {
base = meta1->base;
/* Don't multiply num2 since there is no even factor */
}
}
else if (meta1->base == NPY_FR_B || meta2->base == NPY_FR_B) {
if (strict_with_nonlinear_units1 || strict_with_nonlinear_units2) {
goto incompatible_units;
}
else {
Expand Down Expand Up @@ -1801,28 +1817,38 @@ units_overflow: {
}

/*
* Uses type1's type_num and the gcd of the metadata to create
* the result type.
* Both type1 and type2 must be either NPY_DATETIME or NPY_TIMEDELTA.
* Applies the type promotion rules between the two types, returning
* the promoted type.
*/
static PyArray_Descr *
datetime_gcd_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
NPY_NO_EXPORT PyArray_Descr *
datetime_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
{
int type_num1, type_num2;
PyObject *gcdmeta;
PyArray_Descr *dtype;
int is_datetime;

type_num1 = type1->type_num;
type_num2 = type2->type_num;

is_datetime = (type_num1 == NPY_DATETIME || type_num2 == NPY_DATETIME);

/*
* Get the metadata GCD, being strict about nonlinear units for
* timedelta and relaxed for datetime.
*/
gcdmeta = compute_datetime_metadata_greatest_common_divisor(
type1, type2,
type1->type_num == NPY_TIMEDELTA);
type_num1 == NPY_TIMEDELTA,
type_num2 == NPY_TIMEDELTA);
if (gcdmeta == NULL) {
return NULL;
}

/* Create a DATETIME or TIMEDELTA dtype */
dtype = PyArray_DescrNewFromType(type1->type_num);
dtype = PyArray_DescrNewFromType(is_datetime ? NPY_DATETIME :
NPY_TIMEDELTA);
if (dtype == NULL) {
Py_DECREF(gcdmeta);
return NULL;
Expand All @@ -1847,39 +1873,7 @@ datetime_gcd_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
Py_DECREF(gcdmeta);

return dtype;
}

/*
* Both type1 and type2 must be either NPY_DATETIME or NPY_TIMEDELTA.
* Applies the type promotion rules between the two types, returning
* the promoted type.
*/
NPY_NO_EXPORT PyArray_Descr *
datetime_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
{
int type_num1, type_num2;

type_num1 = type1->type_num;
type_num2 = type2->type_num;

if (type_num1 == NPY_DATETIME) {
if (type_num2 == NPY_DATETIME) {
return datetime_gcd_type_promotion(type1, type2);
}
else if (type_num2 == NPY_TIMEDELTA) {
Py_INCREF(type1);
return type1;
}
}
else if (type_num1 == NPY_TIMEDELTA) {
if (type_num2 == NPY_DATETIME) {
Py_INCREF(type2);
return type2;
}
else if (type_num2 == NPY_TIMEDELTA) {
return datetime_gcd_type_promotion(type1, type2);
}
}

PyErr_SetString(PyExc_RuntimeError,
"Called datetime_type_promotion on non-datetype type");
Expand Down
102 changes: 59 additions & 43 deletions numpy/core/src/umath/ufunc_object.c
Expand Up @@ -2243,8 +2243,8 @@ timedelta_dtype_with_copied_meta(PyArray_Descr *dtype)
* int + m8[<A>] => m8[<A>] + m8[<A>]
* M8[<A>] + int => M8[<A>] + m8[<A>]
* int + M8[<A>] => m8[<A>] + M8[<A>]
* M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>]
* m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>]
* M8[<A>] + m8[<B>] => M8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)]
* m8[<A>] + M8[<B>] => m8[gcd(<A>,<B>)] + M8[gcd(<A>,<B>)]
* TODO: Non-linear time unit cases require highly special-cased loops
* M8[<A>] + m8[Y|M|B]
* m8[Y|M|B] + M8[<A>]
Expand Down Expand Up @@ -2287,16 +2287,20 @@ PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
/* m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>] */
/* m8[<A>] + M8[<B>] => m8[gcd(<A>,<B>)] + M8[gcd(<A>,<B>)] */
else if (type_num2 == NPY_DATETIME) {
/* Make a new NPY_TIMEDELTA, and copy type2's metadata */
out_dtypes[0] = timedelta_dtype_with_copied_meta(
PyArray_DESCR(operands[1]));
out_dtypes[1] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[1] == NULL) {
return -1;
}
/* Make a new NPY_TIMEDELTA, and copy the datetime's metadata */
out_dtypes[0] = timedelta_dtype_with_copied_meta(out_dtypes[1]);
if (out_dtypes[0] == NULL) {
Py_DECREF(out_dtypes[1]);
out_dtypes[1] = NULL;
return -1;
}
out_dtypes[1] = PyArray_DESCR(operands[1]);
Py_INCREF(out_dtypes[1]);
out_dtypes[2] = out_dtypes[1];
Py_INCREF(out_dtypes[2]);
}
Expand All @@ -2317,10 +2321,25 @@ PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
}
}
else if (type_num1 == NPY_DATETIME) {
/* M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>] */
/* M8[<A>] + m8[<B>] => M8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)] */
if (type_num2 == NPY_TIMEDELTA) {
out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[0] == NULL) {
return -1;
}
/* Make a new NPY_TIMEDELTA, and copy the datetime's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(out_dtypes[0]);
if (out_dtypes[1] == NULL) {
Py_DECREF(out_dtypes[0]);
out_dtypes[0] = NULL;
return -1;
}
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
/* M8[<A>] + int => M8[<A>] + m8[<A>] */
if (type_num2 == NPY_TIMEDELTA ||
PyTypeNum_ISINTEGER(type_num2) ||
else if (PyTypeNum_ISINTEGER(type_num2) ||
PyTypeNum_ISBOOL(type_num2)) {
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(
Expand Down Expand Up @@ -2421,7 +2440,7 @@ type_reso_error: {
* m8[<A>] - int => m8[<A>] - m8[<A>]
* int - m8[<A>] => m8[<A>] - m8[<A>]
* M8[<A>] - int => M8[<A>] - m8[<A>]
* M8[<A>] - m8[<B>] => M8[<A>] - m8[<A>]
* M8[<A>] - m8[<B>] => M8[gcd(<A>,<B>)] - m8[gcd(<A>,<B>)]
* TODO: Non-linear time unit cases require highly special-cased loops
* M8[<A>] - m8[Y|M|B]
*/
Expand Down Expand Up @@ -2480,10 +2499,25 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc,
}
}
else if (type_num1 == NPY_DATETIME) {
/* M8[<A>] - m8[<B>] => M8[<A>] - m8[<A>] */
/* M8[<A>] - m8[<B>] => M8[gcd(<A>,<B>)] - m8[gcd(<A>,<B>)] */
if (type_num2 == NPY_TIMEDELTA) {
out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[0] == NULL) {
return -1;
}
/* Make a new NPY_TIMEDELTA, and copy the datetime's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(out_dtypes[0]);
if (out_dtypes[1] == NULL) {
Py_DECREF(out_dtypes[0]);
out_dtypes[0] = NULL;
return -1;
}
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
/* M8[<A>] - int => M8[<A>] - m8[<A>] */
if (type_num2 == NPY_TIMEDELTA ||
PyTypeNum_ISINTEGER(type_num2) ||
else if (PyTypeNum_ISINTEGER(type_num2) ||
PyTypeNum_ISBOOL(type_num2)) {
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(
Expand All @@ -2498,39 +2532,21 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc,

type_num2 = NPY_TIMEDELTA;
}
/* M8[<A>] - M8[<A>] (producing m8[<A>])*/
/* M8[<A>] - M8[<B>] => M8[gcd(<A>,<B>)] - M8[gcd(<A>,<B>)] */
else if (type_num2 == NPY_DATETIME) {
PyArray_DatetimeMetaData *meta1, *meta2;

meta1 = get_datetime_metadata_from_dtype(
PyArray_DESCR(operands[0]));
if (meta1 == NULL) {
out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[0] == NULL) {
return -1;
}
meta2 = get_datetime_metadata_from_dtype(
PyArray_DESCR(operands[1]));
if (meta2 == NULL) {
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[2] = timedelta_dtype_with_copied_meta(out_dtypes[0]);
if (out_dtypes[2] == NULL) {
Py_DECREF(out_dtypes[0]);
return -1;
}

/* If the metadata matches up, the subtraction is ok */
if (meta1->num == meta2->num &&
meta1->base == meta2->base &&
meta1->events == meta2->events) {
out_dtypes[0] = PyArray_DESCR(operands[1]);
Py_INCREF(out_dtypes[0]);
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[2] = timedelta_dtype_with_copied_meta(
PyArray_DESCR(operands[0]));
if (out_dtypes[2] == NULL) {
return -1;
}
}
else {
goto type_reso_error;
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
}
else {
goto type_reso_error;
Expand Down