Skip to content

Commit

Permalink
Fixed cumulative functions (#127)
Browse files Browse the repository at this point in the history
* Fixed cumulative functions

* Fix formatting

* Resolve comments
  • Loading branch information
williamma12 authored and devin-petersohn committed Oct 8, 2018
1 parent 68f2f59 commit ab2ccef
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,8 @@ def cummax(self, axis=None, skipna=True, *args, **kwargs):
The cumulative maximum of the DataFrame.
"""
axis = pandas.DataFrame()._get_axis_number(axis) if axis is not None else 0
if axis:
self._validate_dtypes()
return DataFrame(
data_manager=self._data_manager.cummax(axis=axis, skipna=skipna, **kwargs)
)
Expand All @@ -1027,6 +1029,8 @@ def cummin(self, axis=None, skipna=True, *args, **kwargs):
The cumulative minimum of the DataFrame.
"""
axis = pandas.DataFrame()._get_axis_number(axis) if axis is not None else 0
if axis:
self._validate_dtypes()
return DataFrame(
data_manager=self._data_manager.cummin(axis=axis, skipna=skipna, **kwargs)
)
Expand All @@ -1042,6 +1046,7 @@ def cumprod(self, axis=None, skipna=True, *args, **kwargs):
The cumulative product of the DataFrame.
"""
axis = pandas.DataFrame()._get_axis_number(axis) if axis is not None else 0
self._validate_dtypes(numeric_only=True)
return DataFrame(
data_manager=self._data_manager.cumprod(axis=axis, skipna=skipna, **kwargs)
)
Expand All @@ -1057,6 +1062,7 @@ def cumsum(self, axis=None, skipna=True, *args, **kwargs):
The cumulative sum of the DataFrame.
"""
axis = pandas.DataFrame()._get_axis_number(axis) if axis is not None else 0
self._validate_dtypes(numeric_only=True)
return DataFrame(
data_manager=self._data_manager.cumsum(axis=axis, skipna=skipna, **kwargs)
)
Expand Down Expand Up @@ -4629,3 +4635,14 @@ def _validate_other(self, other, axis):
"given {1}".format(len(self.columns), len(other))
)
return other

def _validate_dtypes(self, numeric_only=False):
"""Helper method to check that all the dtypes are the same"""
dtype = self.dtypes[0]
for t in self.dtypes:
if numeric_only and not is_numeric_dtype(t):
raise TypeError("{0} is not a numeric data type".format(t))
elif not numeric_only and t != dtype:
raise TypeError(
"Cannot compare type '{0}' with type '{1}'".format(t, dtype)
)

0 comments on commit ab2ccef

Please sign in to comment.