Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT-#493: Extensible. #494

Merged
merged 4 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions lux/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import lux


class LuxDataFrame(pd.DataFrame):
class LuxDataFrameMixin:
"""
A mixin for pd.DataFrame subclasses (see LuxDataFrame) that supports all dataframe operations while housing other variables and functions for generating visual recommendations.
"""
Expand Down Expand Up @@ -58,6 +58,8 @@ class LuxDataFrame(pd.DataFrame):
]

def __init__(self, *args, **kw):
super().__init__(*args, **kw)

self._history = History()
self._intent = []
self._inferred_intent = []
Expand All @@ -66,7 +68,6 @@ def __init__(self, *args, **kw):
self._current_vis = []
self._prev = None
self._widget = None
super(LuxDataFrame, self).__init__(*args, **kw)

self.table_name = ""
if lux.config.SQLconnection == "":
Expand All @@ -92,20 +93,6 @@ def __init__(self, *args, **kw):
self._type_override = {}
warnings.formatwarning = lux.warning_format

@property
def _constructor(self):
return LuxDataFrame

@property
def _constructor_sliced(self):
def f(*args, **kwargs):
s = LuxSeries(*args, **kwargs)
for attr in self._metadata: # propagate metadata
s.__dict__[attr] = getattr(self, attr, None)
return s

return f

@property
def history(self):
return self._history
Expand Down Expand Up @@ -174,23 +161,23 @@ def expire_metadata(self) -> None:
## Override Pandas ##
#####################
def __getattr__(self, name):
ret_value = super(LuxDataFrame, self).__getattr__(name)
ret_value = super().__getattr__(name)
self.expire_metadata()
self.expire_recs()
return ret_value

def _set_axis(self, axis, labels):
super(LuxDataFrame, self)._set_axis(axis, labels)
super()._set_axis(axis, labels)
self.expire_metadata()
self.expire_recs()

def _update_inplace(self, *args, **kwargs):
super(LuxDataFrame, self)._update_inplace(*args, **kwargs)
super()._update_inplace(*args, **kwargs)
self.expire_metadata()
self.expire_recs()

def _set_item(self, key, value):
super(LuxDataFrame, self)._set_item(key, value)
super()._set_item(key, value)
self.expire_metadata()
self.expire_recs()

Expand Down Expand Up @@ -847,13 +834,13 @@ def save_as_html(self, filename: str = "export.html", output=False):

# Overridden Pandas Functions
def head(self, n: int = 5):
ret_val = super(LuxDataFrame, self).head(n)
ret_val = super().head(n)
ret_val._prev = self
ret_val._history.append_event("head", n=5)
return ret_val

def tail(self, n: int = 5):
ret_val = super(LuxDataFrame, self).tail(n)
ret_val = super().tail(n)
ret_val._prev = self
ret_val._history.append_event("tail", n=5)
return ret_val
Expand All @@ -864,11 +851,27 @@ def groupby(self, *args, **kwargs):
history_flag = True
if "history" in kwargs:
del kwargs["history"]
groupby_obj = super(LuxDataFrame, self).groupby(*args, **kwargs)
groupby_obj = super().groupby(*args, **kwargs)
for attr in self._metadata:
groupby_obj.__dict__[attr] = getattr(self, attr, None)
if history_flag:
groupby_obj._history = groupby_obj._history.copy()
groupby_obj._history.append_event("groupby", *args, **kwargs)
groupby_obj.pre_aggregated = True
return groupby_obj


class LuxDataFrame(LuxDataFrameMixin, pd.DataFrame):
    """Concrete Lux dataframe: ``LuxDataFrameMixin`` combined with ``pd.DataFrame``."""

    @property
    def _constructor(self):
        """pandas subclassing hook: the class to use for same-dimensional results."""
        return LuxDataFrame

    @property
    def _constructor_sliced(self):
        """pandas subclassing hook: return a factory producing ``LuxSeries`` objects.

        Each series built by the factory receives a copy of every attribute
        named in ``self._metadata``, taken from this frame.
        """

        def build_series(*args, **kwargs):
            series = LuxSeries(*args, **kwargs)
            # Values are written directly into the instance __dict__; any
            # metadata name missing on the frame is stored as None.
            series.__dict__.update(
                {attr: getattr(self, attr, None) for attr in self._metadata}
            )
            return series

        return build_series
32 changes: 13 additions & 19 deletions lux/core/groupby.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pandas as pd


class LuxGroupBy(pd.core.groupby.groupby.GroupBy):

class LuxGroupByMixin:
_metadata = [
"_intent",
"_inferred_intent",
Expand All @@ -25,68 +24,63 @@ class LuxGroupBy(pd.core.groupby.groupby.GroupBy):
"_type_override",
]

def __init__(self, *args, **kwargs):
super(LuxGroupBy, self).__init__(*args, **kwargs)

def aggregate(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).aggregate(*args, **kwargs)
ret_val = super().aggregate(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def _agg_general(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self)._agg_general(*args, **kwargs)
ret_val = super()._agg_general(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def _cython_agg_general(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self)._cython_agg_general(*args, **kwargs)
ret_val = super()._cython_agg_general(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def get_group(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).get_group(*args, **kwargs)
ret_val = super().get_group(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated
return ret_val

def filter(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).filter(*args, **kwargs)
ret_val = super().filter(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated
return ret_val

def apply(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).apply(*args, **kwargs)
ret_val = super().apply(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated
return ret_val

def size(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).size(*args, **kwargs)
ret_val = super().size(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def __getitem__(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).__getitem__(*args, **kwargs)
ret_val = super().__getitem__(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

agg = aggregate


class LuxDataFrameGroupBy(LuxGroupBy, pd.core.groupby.generic.DataFrameGroupBy):
def __init__(self, *args, **kwargs):
super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs)
class LuxDataFrameGroupBy(LuxGroupByMixin, pd.core.groupby.DataFrameGroupBy):
pass


class LuxSeriesGroupBy(LuxGroupBy, pd.core.groupby.generic.SeriesGroupBy):
def __init__(self, *args, **kwargs):
super(LuxSeriesGroupBy, self).__init__(*args, **kwargs)
class LuxSeriesGroupBy(LuxGroupByMixin, pd.core.groupby.SeriesGroupBy):
pass
55 changes: 29 additions & 26 deletions lux/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import pandas as pd
import lux
Expand All @@ -23,7 +24,7 @@
from typing import Dict, Union, List, Callable


class LuxSeries(pd.Series):
class LuxSeriesMixin:
"""
A mixin for pd.Series subclasses (see LuxSeries) that supports all 1-D Series operations
"""
Expand Down Expand Up @@ -66,34 +67,13 @@ class LuxSeries(pd.Series):
}

def __init__(self, *args, **kw):
super(LuxSeries, self).__init__(*args, **kw)
super().__init__(*args, **kw)
for attr in self._metadata:
if attr in self._default_metadata:
self.__dict__[attr] = self._default_metadata[attr]()
else:
self.__dict__[attr] = None

@property
def _constructor(self):
return LuxSeries

@property
def _constructor_expanddim(self):
from lux.core.frame import LuxDataFrame

def f(*args, **kwargs):
df = LuxDataFrame(*args, **kwargs)
for attr in self._metadata:
# if attr in self._default_metadata:
# default = self._default_metadata[attr]
# else:
# default = None
df.__dict__[attr] = getattr(self, attr, None)
return df

f._get_axis_number = LuxDataFrame._get_axis_number
return f

def to_pandas(self) -> pd.Series:
"""
Convert Lux Series to Pandas Series
Expand Down Expand Up @@ -123,15 +103,15 @@ def unique(self):
if self.unique_values and self.name in self.unique_values.keys():
return np.array(self.unique_values[self.name])
else:
return super(LuxSeries, self).unique()
return super().unique()

def _ipython_display_(self):
from IPython.display import display
from IPython.display import clear_output
import ipywidgets as widgets
from lux.core.frame import LuxDataFrame

series_repr = super(LuxSeries, self).__repr__()
series_repr = super().__repr__()

ldf = LuxDataFrame(self)

Expand Down Expand Up @@ -252,11 +232,34 @@ def groupby(self, *args, **kwargs):
history_flag = True
if "history" in kwargs:
del kwargs["history"]
groupby_obj = super(LuxSeries, self).groupby(*args, **kwargs)
groupby_obj = super().groupby(*args, **kwargs)
for attr in self._metadata:
groupby_obj.__dict__[attr] = getattr(self, attr, None)
if history_flag:
groupby_obj._history = groupby_obj._history.copy()
groupby_obj._history.append_event("groupby", *args, **kwargs)
groupby_obj.pre_aggregated = True
return groupby_obj


class LuxSeries(LuxSeriesMixin, pd.Series):
    """Concrete Lux series: ``LuxSeriesMixin`` combined with ``pd.Series``."""

    @property
    def _constructor(self):
        """pandas subclassing hook: the class to use for same-dimensional results."""
        return LuxSeries

    @property
    def _constructor_expanddim(self):
        """pandas subclassing hook: return a factory producing ``LuxDataFrame`` objects.

        Each frame built by the factory receives a copy of every attribute
        named in ``self._metadata``, taken from this series.
        """
        # Function-scope import — presumably avoids a circular import between
        # lux.core.frame and this module; confirm before moving it.
        from lux.core.frame import LuxDataFrame

        def build_frame(*args, **kwargs):
            frame = LuxDataFrame(*args, **kwargs)
            # The series' current values are copied as-is (None when absent);
            # entries in _default_metadata are not consulted here.
            frame.__dict__.update(
                {attr: getattr(self, attr, None) for attr in self._metadata}
            )
            return frame

        # Expose the frame's axis-number lookup on the factory itself.
        build_frame._get_axis_number = LuxDataFrame._get_axis_number
        return build_frame
Loading