Skip to content

Commit

Permalink
ENH: datetime: Unify datetime/timedelta type promotion
Browse files Browse the repository at this point in the history
Now it always goes to the more precise unit.
  • Loading branch information
Mark Wiebe committed Jun 8, 2011
1 parent f986fd4 commit 0800fe3
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 122 deletions.
13 changes: 0 additions & 13 deletions numpy/core/src/multiarray/_datetime.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,6 @@ datetime_metadata_divides(
PyArray_Descr *divisor,
int strict_with_nonlinear_units);

/*
* Computes the GCD of the two date-time metadata values. Raises
* an exception if there is no reasonable GCD, such as with
* years and days.
*
* Returns a capsule with the GCD metadata.
*/
NPY_NO_EXPORT PyObject *
compute_datetime_metadata_greatest_common_divisor(
PyArray_Descr *type1,
PyArray_Descr *type2,
int strict_with_nonlinear_units);

/*
* Computes the conversion factor to convert data with 'src_meta' metadata
* into data with 'dst_meta' metadata, not taking into account the events.
Expand Down
88 changes: 41 additions & 47 deletions numpy/core/src/multiarray/datetime.c
Original file line number Diff line number Diff line change
Expand Up @@ -1632,11 +1632,12 @@ datetime_metadata_divides(
}


NPY_NO_EXPORT PyObject *
static PyObject *
compute_datetime_metadata_greatest_common_divisor(
PyArray_Descr *type1,
PyArray_Descr *type2,
int strict_with_nonlinear_units)
int strict_with_nonlinear_units1,
int strict_with_nonlinear_units2)
{
PyArray_DatetimeMetaData *meta1, *meta2, *dt_data;
NPY_DATETIMEUNIT base;
Expand Down Expand Up @@ -1688,7 +1689,7 @@ compute_datetime_metadata_greatest_common_divisor(
base = NPY_FR_M;
num1 *= 12;
}
else if (strict_with_nonlinear_units) {
else if (strict_with_nonlinear_units1) {
goto incompatible_units;
}
else {
Expand All @@ -1701,19 +1702,34 @@ compute_datetime_metadata_greatest_common_divisor(
base = NPY_FR_M;
num2 *= 12;
}
else if (strict_with_nonlinear_units) {
else if (strict_with_nonlinear_units2) {
goto incompatible_units;
}
else {
base = meta1->base;
/* Don't multiply num2 since there is no even factor */
}
}
else if (meta1->base == NPY_FR_M ||
meta1->base == NPY_FR_B ||
meta2->base == NPY_FR_M ||
meta2->base == NPY_FR_B) {
if (strict_with_nonlinear_units) {
else if (meta1->base == NPY_FR_M) {
if (strict_with_nonlinear_units1) {
goto incompatible_units;
}
else {
base = meta2->base;
/* Don't multiply num1 since there is no even factor */
}
}
else if (meta2->base == NPY_FR_M) {
if (strict_with_nonlinear_units2) {
goto incompatible_units;
}
else {
base = meta1->base;
/* Don't multiply num2 since there is no even factor */
}
}
else if (meta1->base == NPY_FR_B || meta2->base == NPY_FR_B) {
if (strict_with_nonlinear_units1 || strict_with_nonlinear_units2) {
goto incompatible_units;
}
else {
Expand Down Expand Up @@ -1801,28 +1817,38 @@ units_overflow: {
}

/*
* Uses type1's type_num and the gcd of the metadata to create
* the result type.
* Both type1 and type2 must be either NPY_DATETIME or NPY_TIMEDELTA.
* Applies the type promotion rules between the two types, returning
* the promoted type.
*/
static PyArray_Descr *
datetime_gcd_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
NPY_NO_EXPORT PyArray_Descr *
datetime_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
{
int type_num1, type_num2;
PyObject *gcdmeta;
PyArray_Descr *dtype;
int is_datetime;

type_num1 = type1->type_num;
type_num2 = type2->type_num;

is_datetime = (type_num1 == NPY_DATETIME || type_num2 == NPY_DATETIME);

/*
* Get the metadata GCD, being strict about nonlinear units for
* timedelta and relaxed for datetime.
*/
gcdmeta = compute_datetime_metadata_greatest_common_divisor(
type1, type2,
type1->type_num == NPY_TIMEDELTA);
type_num1 == NPY_TIMEDELTA,
type_num2 == NPY_TIMEDELTA);
if (gcdmeta == NULL) {
return NULL;
}

/* Create a DATETIME or TIMEDELTA dtype */
dtype = PyArray_DescrNewFromType(type1->type_num);
dtype = PyArray_DescrNewFromType(is_datetime ? NPY_DATETIME :
NPY_TIMEDELTA);
if (dtype == NULL) {
Py_DECREF(gcdmeta);
return NULL;
Expand All @@ -1847,39 +1873,7 @@ datetime_gcd_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
Py_DECREF(gcdmeta);

return dtype;
}

/*
* Both type1 and type2 must be either NPY_DATETIME or NPY_TIMEDELTA.
* Applies the type promotion rules between the two types, returning
* the promoted type.
*/
NPY_NO_EXPORT PyArray_Descr *
datetime_type_promotion(PyArray_Descr *type1, PyArray_Descr *type2)
{
int type_num1, type_num2;

type_num1 = type1->type_num;
type_num2 = type2->type_num;

if (type_num1 == NPY_DATETIME) {
if (type_num2 == NPY_DATETIME) {
return datetime_gcd_type_promotion(type1, type2);
}
else if (type_num2 == NPY_TIMEDELTA) {
Py_INCREF(type1);
return type1;
}
}
else if (type_num1 == NPY_TIMEDELTA) {
if (type_num2 == NPY_DATETIME) {
Py_INCREF(type2);
return type2;
}
else if (type_num2 == NPY_TIMEDELTA) {
return datetime_gcd_type_promotion(type1, type2);
}
}

PyErr_SetString(PyExc_RuntimeError,
"Called datetime_type_promotion on non-datetype type");
Expand Down
102 changes: 59 additions & 43 deletions numpy/core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -2243,8 +2243,8 @@ timedelta_dtype_with_copied_meta(PyArray_Descr *dtype)
* int + m8[<A>] => m8[<A>] + m8[<A>]
* M8[<A>] + int => M8[<A>] + m8[<A>]
* int + M8[<A>] => m8[<A>] + M8[<A>]
* M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>]
* m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>]
* M8[<A>] + m8[<B>] => M8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)]
* m8[<A>] + M8[<B>] => m8[gcd(<A>,<B>)] + M8[gcd(<A>,<B>)]
* TODO: Non-linear time unit cases require highly special-cased loops
* M8[<A>] + m8[Y|M|B]
* m8[Y|M|B] + M8[<A>]
Expand Down Expand Up @@ -2287,16 +2287,20 @@ PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
/* m8[<A>] + M8[<B>] => m8[<B>] + M8[<B>] */
/* m8[<A>] + M8[<B>] => m8[gcd(<A>,<B>)] + M8[gcd(<A>,<B>)] */
else if (type_num2 == NPY_DATETIME) {
/* Make a new NPY_TIMEDELTA, and copy type2's metadata */
out_dtypes[0] = timedelta_dtype_with_copied_meta(
PyArray_DESCR(operands[1]));
out_dtypes[1] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[1] == NULL) {
return -1;
}
/* Make a new NPY_TIMEDELTA, and copy the datetime's metadata */
out_dtypes[0] = timedelta_dtype_with_copied_meta(out_dtypes[1]);
if (out_dtypes[0] == NULL) {
Py_DECREF(out_dtypes[1]);
out_dtypes[1] = NULL;
return -1;
}
out_dtypes[1] = PyArray_DESCR(operands[1]);
Py_INCREF(out_dtypes[1]);
out_dtypes[2] = out_dtypes[1];
Py_INCREF(out_dtypes[2]);
}
Expand All @@ -2317,10 +2321,25 @@ PyUFunc_AdditionTypeResolution(PyUFuncObject *ufunc,
}
}
else if (type_num1 == NPY_DATETIME) {
/* M8[<A>] + m8[<B>] => M8[<A>] + m8[<A>] */
/* M8[<A>] + m8[<B>] => M8[gcd(<A>,<B>)] + m8[gcd(<A>,<B>)] */
if (type_num2 == NPY_TIMEDELTA) {
out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[0] == NULL) {
return -1;
}
/* Make a new NPY_TIMEDELTA, and copy the datetime's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(out_dtypes[0]);
if (out_dtypes[1] == NULL) {
Py_DECREF(out_dtypes[0]);
out_dtypes[0] = NULL;
return -1;
}
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
/* M8[<A>] + int => M8[<A>] + m8[<A>] */
if (type_num2 == NPY_TIMEDELTA ||
PyTypeNum_ISINTEGER(type_num2) ||
else if (PyTypeNum_ISINTEGER(type_num2) ||
PyTypeNum_ISBOOL(type_num2)) {
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(
Expand Down Expand Up @@ -2421,7 +2440,7 @@ type_reso_error: {
* m8[<A>] - int => m8[<A>] - m8[<A>]
* int - m8[<A>] => m8[<A>] - m8[<A>]
* M8[<A>] - int => M8[<A>] - m8[<A>]
* M8[<A>] - m8[<B>] => M8[<A>] - m8[<A>]
* M8[<A>] - m8[<B>] => M8[gcd(<A>,<B>)] - m8[gcd(<A>,<B>)]
* TODO: Non-linear time unit cases require highly special-cased loops
* M8[<A>] - m8[Y|M|B]
*/
Expand Down Expand Up @@ -2480,10 +2499,25 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc,
}
}
else if (type_num1 == NPY_DATETIME) {
/* M8[<A>] - m8[<B>] => M8[<A>] - m8[<A>] */
/* M8[<A>] - m8[<B>] => M8[gcd(<A>,<B>)] - m8[gcd(<A>,<B>)] */
if (type_num2 == NPY_TIMEDELTA) {
out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[0] == NULL) {
return -1;
}
/* Make a new NPY_TIMEDELTA, and copy the datetime's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(out_dtypes[0]);
if (out_dtypes[1] == NULL) {
Py_DECREF(out_dtypes[0]);
out_dtypes[0] = NULL;
return -1;
}
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
/* M8[<A>] - int => M8[<A>] - m8[<A>] */
if (type_num2 == NPY_TIMEDELTA ||
PyTypeNum_ISINTEGER(type_num2) ||
else if (PyTypeNum_ISINTEGER(type_num2) ||
PyTypeNum_ISBOOL(type_num2)) {
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[1] = timedelta_dtype_with_copied_meta(
Expand All @@ -2498,39 +2532,21 @@ PyUFunc_SubtractionTypeResolution(PyUFuncObject *ufunc,

type_num2 = NPY_TIMEDELTA;
}
/* M8[<A>] - M8[<A>] (producing m8[<A>])*/
/* M8[<A>] - M8[<B>] => M8[gcd(<A>,<B>)] - M8[gcd(<A>,<B>)] */
else if (type_num2 == NPY_DATETIME) {
PyArray_DatetimeMetaData *meta1, *meta2;

meta1 = get_datetime_metadata_from_dtype(
PyArray_DESCR(operands[0]));
if (meta1 == NULL) {
out_dtypes[0] = PyArray_PromoteTypes(PyArray_DESCR(operands[0]),
PyArray_DESCR(operands[1]));
if (out_dtypes[0] == NULL) {
return -1;
}
meta2 = get_datetime_metadata_from_dtype(
PyArray_DESCR(operands[1]));
if (meta2 == NULL) {
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[2] = timedelta_dtype_with_copied_meta(out_dtypes[0]);
if (out_dtypes[2] == NULL) {
Py_DECREF(out_dtypes[0]);
return -1;
}

/* If the metadata matches up, the subtraction is ok */
if (meta1->num == meta2->num &&
meta1->base == meta2->base &&
meta1->events == meta2->events) {
out_dtypes[0] = PyArray_DESCR(operands[1]);
Py_INCREF(out_dtypes[0]);
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
/* Make a new NPY_TIMEDELTA, and copy type1's metadata */
out_dtypes[2] = timedelta_dtype_with_copied_meta(
PyArray_DESCR(operands[0]));
if (out_dtypes[2] == NULL) {
return -1;
}
}
else {
goto type_reso_error;
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
}
else {
goto type_reso_error;
Expand Down
Loading

0 comments on commit 0800fe3

Please sign in to comment.