From f986c76cada21f777738c769af4c096b7b22179c Mon Sep 17 00:00:00 2001
From: jinimukh <46768380+jinimukh@users.noreply.github.com>
Date: Sun, 1 Nov 2020 23:04:07 -0800
Subject: [PATCH] add black to travis (#127)

* add black to travis
* reformat all code and adjust test
* remove .idea
* fix contributing doc
* small change in contributing
* update
* reformat, update command to fix version
* remove dev dependencies
---
 .travis.yml | 1 +
 CONTRIBUTING.md | 5 +-
 doc/conf.py | 50 +-
 lux/__init__.py | 2 +-
 lux/_config/config.py | 269 ++--
 lux/_version.py | 2 +-
 lux/action/__init__.py | 3 +-
 lux/action/column_group.py | 65 +-
 lux/action/correlation.py | 116 +-
 lux/action/custom.py | 34 +-
 lux/action/enhance.py | 96 +-
 lux/action/filter.py | 175 +--
 lux/action/generalize.py | 138 ++-
 lux/action/row_group.py | 57 +-
 lux/action/similarity.py | 90 +-
 lux/action/univariate.py | 123 +-
 lux/core/__init__.py | 11 +-
 lux/core/frame.py | 1583 +++++++++++++-----
 lux/core/series.py | 63 +-
 lux/executor/Executor.py | 15 +-
 lux/executor/PandasExecutor.py | 383 +++---
 lux/executor/SQLExecutor.py | 189 ++-
 lux/executor/__init__.py | 3 +-
 lux/history/__init__.py | 3 +-
 lux/history/event.py | 31 +-
 lux/history/history.py | 53 +-
 lux/interestingness/__init__.py | 3 +-
 lux/interestingness/interestingness.py | 573 +++++----
 lux/processor/Compiler.py | 868 +++++++------
 lux/processor/Parser.py | 185 +--
 lux/processor/Validator.py | 128 +-
 lux/processor/__init__.py | 3 +-
 lux/utils/__init__.py | 3 +-
 lux/utils/date_utils.py | 220 ++--
 lux/utils/message.py | 26 +-
 lux/utils/utils.py | 116 +-
 lux/vis/Clause.py | 240 ++--
 lux/vis/Vis.py | 578 +++++----
 lux/vis/VisList.py | 603 +++++----
 lux/vis/__init__.py | 4 +-
 lux/vislib/__init__.py | 3 +-
 lux/vislib/altair/AltairChart.py | 161 ++-
 lux/vislib/altair/AltairRenderer.py | 185 +--
 lux/vislib/altair/BarChart.py | 184 +--
 lux/vislib/altair/Heatmap.py | 101 +-
 lux/vislib/altair/Histogram.py | 121 +-
 lux/vislib/altair/LineChart.py | 95 +-
 lux/vislib/altair/ScatterChart.py | 96 +-
 lux/vislib/altair/__init__.py | 3 +-
 requirements-dev.txt | 1 +
 requirements.txt | 5 +-
 setup.py | 50 +-
 tests/__init__.py | 3 +-
 tests/context.py | 7 +-
 tests/test_action.py | 243 ++--
 tests/test_compiler.py | 458 ++++---
 tests/test_config.py | 230 ++--
 tests/test_dates.py | 127 +-
 tests/test_display.py | 17 +-
 tests/test_error_warning.py | 38 +-
 tests/test_executor.py | 202 ++-
 tests/test_interestingness.py | 232 ++--
 tests/test_maintainence.py | 71 +-
 tests/test_nan.py | 16 +-
 tests/test_pandas.py | 17 +-
 tests/test_pandas_coverage.py | 410 ++++--
 tests/test_parser.py | 105 +-
 tests/test_performance.py | 65 +-
 tests/test_type.py | 168 +--
 tests/test_vis.py | 205 ++-
 70 files changed, 6253 insertions(+), 4476 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index fcd12c05..98bde1cf 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -7,6 +7,7 @@ install:
   - pip install git+https://github.com/lux-org/lux-widget
 # command to run tests
 script:
+  - black --target-version py37 --check .
   - python -m pytest tests/*.py
   - pytest --cov-report term --cov=lux tests/
 after_success:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8f068c4f..ac05767b 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -45,9 +45,12 @@ You can run them locally to make sure that your changes are working and do not b
 python -m pytest tests/*.py
 ```
 
+# Code Formatting
+In order to keep our codebase clean and readable, we are following PEP8 guidelines.
+To help us maintain and check code style, we are using [black](https://github.com/psf/black). Simply run `black .` before committing. Failing to do so may cause the style check on Travis to fail. This package should have been installed for you.
+
 # Submitting a Pull Request
 
-You can commit your code and push to your forked repo. Once all of your local changes have been tested and is working, you are ready to submit a PR. For Lux, we use the "Squash and Merge" strategy to merge in PR, which means that even if you make a lot of small commits in your PR, they will all get squashed into a single commit associated with the PR. Please make sure that comments and unnecessary file changes are not committed as part of the PR by looking at the "File Changes" diff view on the pull request page.
+You can commit your code and push to your forked repo. Once all of your local changes have been tested and formatted, you are ready to submit a PR. For Lux, we use the "Squash and Merge" strategy to merge in PRs, which means that even if you make a lot of small commits in your PR, they will all get squashed into a single commit associated with the PR. Please make sure that comments and unnecessary file changes are not committed as part of the PR by looking at the "File Changes" diff view on the pull request page.
 
 Once the pull request is submitted, the maintainer will get notified and review your pull request. They may ask for additional changes or comment on the PR. You can always make updates to your pull request after submitting it.
diff --git a/doc/conf.py b/doc/conf.py
index 5983aca3..03862439 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -1,5 +1,5 @@
 # Copyright 2019-2020 The Lux Authors.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -19,7 +19,8 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
 import subprocess
-subprocess.call(['sh', './docbuild.sh'])
+
+subprocess.call(["sh", "./docbuild.sh"])
 # -- Path setup --------------------------------------------------------------
 # If extensions (or modules to document with autodoc) are in another directory,
@@ -28,18 +29,19 @@
 # import os
 import sys
-sys.path.insert(0, os.path.abspath('..'))
-sys.path.insert(0, os.path.abspath('..'))
+
+sys.path.insert(0, os.path.abspath(".."))
+sys.path.insert(0, os.path.abspath(".."))
 
 # -- Project information -----------------------------------------------------
 
-project = 'Lux'
-copyright = '2020, Doris Jung-Lin Lee'
-author = 'Doris Jung-Lin Lee'
+project = "Lux"
+copyright = "2020, Doris Jung-Lin Lee"
+author = "Doris Jung-Lin Lee"
 
 # The full version, including alpha/beta/rc tags
-release = '0.1.2'
+release = "0.1.2"
 
 # -- General configuration ---------------------------------------------------
 
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.coverage', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.githubpages', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx.ext.mathjax', - 'sphinx_automodapi.automodapi', - 'sphinx_automodapi.automodsumm' + "sphinx.ext.autodoc", + "sphinx.ext.coverage", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.githubpages", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", + "sphinx_automodapi.automodapi", + "sphinx_automodapi.automodsumm", ] -autodoc_default_flags = ['members', "inherited-members"] +autodoc_default_flags = ["members", "inherited-members"] autodoc_member_order = "groupwise" autosummary_generate = True numpydoc_show_class_members = False # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -88,7 +90,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_logo = "_static/logo.png" html_theme_options = {"style_nav_header_background": "#19177c"} @@ -97,4 +99,4 @@ # further. For a list of options available for each theme, see the # documentation. # -master_doc = 'index' +master_doc = "index" diff --git a/lux/__init__.py b/lux/__init__.py index 8688b3d2..92d39840 100644 --- a/lux/__init__.py +++ b/lux/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/lux/_config/config.py b/lux/_config/config.py index 12a7271b..0c1e967f 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -1,7 +1,7 @@ -''' +""" This config file was largely borrowed from Pandas config.py set_action functionality. 
For more resources, see https://github.com/pandas-dev/pandas/blob/master/pandas/_config -''' +""" from collections import namedtuple from typing import Any, Callable, Dict, Iterable, List, Optional import warnings @@ -14,159 +14,170 @@ update_actions: Dict[str, bool] = {} update_actions["flag"] = False + class OptionError(AttributeError, KeyError): """ Exception for pandas.options, backwards compatible with KeyError checks """ - + + def _get_action(pat: str, silent: bool = False): - return _registered_actions[pat] + return _registered_actions[pat] + class DictWrapper: - def __init__(self, d: Dict[str, Any], prefix: str = ""): - object.__setattr__(self, "d", d) - object.__setattr__(self, "prefix", prefix) - def __init__(self, d: Dict[str, RegisteredOption], prefix: str = ""): - object.__setattr__(self, "d", d) - object.__setattr__(self, "prefix", prefix) - - def __getattr__(self, name: str): - """ - Gets a specific registered action by id - - Parameters - ---------- - name : str - the name of the action - Return - ------- - DictWrapper object for the action - """ - prefix = object.__getattribute__(self, "prefix") - if prefix: - prefix += "." - prefix += name - try: - v = object.__getattribute__(self, "d")[name] - except KeyError as err: - raise OptionError("No such option") from err - if isinstance(v, dict): - return DictWrapper(v, prefix) - else: - return _get_action(prefix) - - def __getactions__(self): - """ - Gathers all currently registered actions in a list of DictWrapper - - Return - ------- - List of DictWrapper objects that are registered - """ - l = [] - for name in self.__dir__(): - l.append(self.__getattr__(name)) - return l - - def __len__(self): - return len(list(self.d.keys())) - - def __dir__(self) -> Iterable[str]: - return list(self.d.keys()) + def __init__(self, d: Dict[str, Any], prefix: str = ""): + object.__setattr__(self, "d", d) + object.__setattr__(self, "prefix", prefix) + + def __init__(self, d: Dict[str, RegisteredOption], prefix: str = ""): + object.__setattr__(self, "d", d) + object.__setattr__(self, "prefix", prefix) + + def __getattr__(self, name: str): + """ + Gets a specific registered action by id + + Parameters + ---------- + name : str + the name of the action + Return + ------- + DictWrapper object for the action + """ + prefix = object.__getattribute__(self, "prefix") + if prefix: + prefix += "." 
+ prefix += name + try: + v = object.__getattribute__(self, "d")[name] + except KeyError as err: + raise OptionError("No such option") from err + if isinstance(v, dict): + return DictWrapper(v, prefix) + else: + return _get_action(prefix) + + def __getactions__(self): + """ + Gathers all currently registered actions in a list of DictWrapper + + Return + ------- + List of DictWrapper objects that are registered + """ + l = [] + for name in self.__dir__(): + l.append(self.__getattr__(name)) + return l + + def __len__(self): + return len(list(self.d.keys())) + + def __dir__(self) -> Iterable[str]: + return list(self.d.keys()) + actions = DictWrapper(_registered_actions) def register_action( - name: str = "", + name: str = "", action: Callable[[Any], Any] = None, display_condition: Optional[Callable[[Any], Any]] = None, *args, ) -> None: - """ - Registers the provided action globally in lux - - Parameters - ---------- - name : str - the name of the action - action : Callable[[Any], Any] - the function used to generate the recommendations - display_condition : Callable[[Any], Any] - the function to check whether or not the function should be applied - args: Any - any additional arguments the function may require - """ - name = name.lower() - if action: - is_callable(action) - - if display_condition: - is_callable(display_condition) - _registered_actions[name] = RegisteredOption( - name=name, action=action, display_condition=display_condition, args=args - ) - update_actions["flag"] = True + """ + Registers the provided action globally in lux + + Parameters + ---------- + name : str + the name of the action + action : Callable[[Any], Any] + the function used to generate the recommendations + display_condition : Callable[[Any], Any] + the function to check whether or not the function should be applied + args: Any + any additional arguments the function may require + """ + name = name.lower() + if action: + is_callable(action) + + if display_condition: + is_callable(display_condition) + _registered_actions[name] = RegisteredOption( + name=name, action=action, display_condition=display_condition, args=args + ) + update_actions["flag"] = True + def remove_action( - name: str = "", + name: str = "", ) -> None: - """ - Removes the provided action globally in lux + """ + Removes the provided action globally in lux + + Parameters + ---------- + name : str + the name of the action to remove + """ + name = name.lower() + if name not in _registered_actions: + raise ValueError(f"Option '{name}' has not been registered") - Parameters - ---------- - name : str - the name of the action to remove - """ - name = name.lower() - if name not in _registered_actions: - raise ValueError(f"Option '{name}' has not been registered") + del _registered_actions[name] + update_actions["flag"] = True - del _registered_actions[name] - update_actions["flag"] = True def is_callable(obj) -> bool: - """ - Parameters - ---------- - obj: Any - the object to be checked - - Returns - ------- - validator : bool - returns True if object is callable - raises ValueError otherwise. - """ - if not callable(obj): - raise ValueError("Value must be a callable") - return True + """ + Parameters + ---------- + obj: Any + the object to be checked + + Returns + ------- + validator : bool + returns True if object is callable + raises ValueError otherwise. 
+ """ + if not callable(obj): + raise ValueError("Value must be a callable") + return True + class Config: + def __init__(self): + self._default_display = "pandas" + + @property + def default_display(self): + return self._default_display + + @default_display.setter + def default_display(self, type: str) -> None: + """ + Set the widget display to show Pandas by default or Lux by default + Parameters + ---------- + type : str + Default display type, can take either the string `lux` or `pandas` (regardless of capitalization) + """ + if type.lower() == "lux": + self._default_display = "lux" + elif type.lower() == "pandas": + self._default_display = "pandas" + else: + warnings.warn( + "Unsupported display type. Default display option should either be `lux` or `pandas`.", + stacklevel=2, + ) - def __init__(self): - self._default_display = "pandas" - - @property - def default_display(self): - return self._default_display - - @default_display.setter - def default_display(self, type:str) -> None: - """ - Set the widget display to show Pandas by default or Lux by default - Parameters - ---------- - type : str - Default display type, can take either the string `lux` or `pandas` (regardless of capitalization) - """ - if (type.lower()=="lux"): - self._default_display = "lux" - elif (type.lower()=="pandas"): - self._default_display = "pandas" - else: - warnings.warn("Unsupported display type. Default display option should either be `lux` or `pandas`.",stacklevel=2) config = Config() diff --git a/lux/_version.py b/lux/_version.py index 173cee8a..924b8842 100644 --- a/lux/_version.py +++ b/lux/_version.py @@ -2,4 +2,4 @@ # coding: utf-8 version_info = (0, 2, 0) -__version__ = ".".join(map(str, version_info)) \ No newline at end of file +__version__ = ".".join(map(str, version_info)) diff --git a/lux/action/__init__.py b/lux/action/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/action/__init__.py +++ b/lux/action/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/lux/action/column_group.py b/lux/action/column_group.py index ba1b9e5a..710cea95 100644 --- a/lux/action/column_group.py +++ b/lux/action/column_group.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -21,27 +21,44 @@ from lux.vis.VisList import VisList import pandas as pd + def column_group(ldf): - recommendation = {"action":"Column Groups", - "description":"Shows charts of possible visualizations with respect to the column-wise index."} - collection = [] - ldf_flat = ldf - if isinstance(ldf.columns,pd.DatetimeIndex): - ldf_flat.columns = ldf_flat.columns.format() - ldf_flat = ldf_flat.reset_index() #use a single shared ldf_flat so that metadata doesn't need to be computed for every vis - if (ldf.index.nlevels==1): - if ldf.index.name: - index_column_name = ldf.index.name - else: - index_column_name = "index" - if isinstance(ldf.columns,pd.DatetimeIndex): - ldf.columns = ldf.columns.to_native_types() - for attribute in ldf.columns: - if ldf[attribute].dtype!="object" and (attribute!="index"): - vis = Vis([lux.Clause(index_column_name, data_type = "nominal", data_model = "dimension", aggregation=None), lux.Clause(str(attribute), data_type = "quantitative", aggregation=None)]) - collection.append(vis) - vlst = VisList(collection,ldf_flat) - # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated ldf - - recommendation["collection"] = vlst - return recommendation \ No newline at end of file + recommendation = { + "action": "Column Groups", + "description": "Shows charts of possible visualizations with respect to the column-wise index.", + } + collection = [] + ldf_flat = ldf + if isinstance(ldf.columns, pd.DatetimeIndex): + ldf_flat.columns = ldf_flat.columns.format() + ldf_flat = ( + ldf_flat.reset_index() + ) # use a single shared ldf_flat so that metadata doesn't need to be computed for every vis + if ldf.index.nlevels == 1: + if ldf.index.name: + index_column_name = ldf.index.name + else: + index_column_name = "index" + if isinstance(ldf.columns, pd.DatetimeIndex): + ldf.columns = ldf.columns.to_native_types() + for attribute in ldf.columns: + if ldf[attribute].dtype != "object" and (attribute != "index"): + vis = Vis( + [ + lux.Clause( + index_column_name, + data_type="nominal", + data_model="dimension", + aggregation=None, + ), + lux.Clause( + str(attribute), data_type="quantitative", aggregation=None + ), + ] + ) + collection.append(vis) + vlst = VisList(collection, ldf_flat) + # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated ldf + + recommendation["collection"] = vlst + return recommendation diff --git a/lux/action/correlation.py b/lux/action/correlation.py index b5f23fee..5d51ba01 100644 --- a/lux/action/correlation.py +++ b/lux/action/correlation.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -22,60 +22,76 @@ # change ignore_transpose to false for now. def correlation(ldf: LuxDataFrame, ignore_transpose: bool = True): - ''' - Generates bivariate visualizations that represent all pairwise relationships in the data. + """ + Generates bivariate visualizations that represent all pairwise relationships in the data. + + Parameters + ---------- + ldf : LuxDataFrame + LuxDataFrame with underspecified intent. - Parameters - ---------- - ldf : LuxDataFrame - LuxDataFrame with underspecified intent. 
+ ignore_transpose: bool + Boolean flag to ignore pairs of attributes whose transpose are already computed (i.e., {X,Y} will be ignored if {Y,X} is already computed) - ignore_transpose: bool - Boolean flag to ignore pairs of attributes whose transpose are already computed (i.e., {X,Y} will be ignored if {Y,X} is already computed) + Returns + ------- + recommendations : Dict[str,obj] + object with a collection of visualizations that result from the Correlation action. + """ - Returns - ------- - recommendations : Dict[str,obj] - object with a collection of visualizations that result from the Correlation action. - ''' + import numpy as np - import numpy as np - filter_specs = utils.get_filter_specs(ldf._intent) - intent = [lux.Clause("?", data_model="measure"), lux.Clause("?", data_model="measure")] - intent.extend(filter_specs) - vlist = VisList(intent,ldf) - recommendation = {"action": "Correlation", - "description": "Show relationships between two
<p class='highlight-descriptor'>quantitative</p>
attributes."} - ignore_rec_flag = False - if (len(ldf)<5): # Doesn't make sense to compute correlation if less than 4 data values - ignore_rec_flag = True - # Then use the data populated in the vis list to compute score - for vis in vlist: - measures = vis.get_attr_by_data_model("measure") - if len(measures) < 2: raise ValueError( - f"Can not compute correlation between {[x.attribute for x in ldf.columns]} since less than 2 measure values present.") - msr1 = measures[0].attribute - msr2 = measures[1].attribute + filter_specs = utils.get_filter_specs(ldf._intent) + intent = [ + lux.Clause("?", data_model="measure"), + lux.Clause("?", data_model="measure"), + ] + intent.extend(filter_specs) + vlist = VisList(intent, ldf) + recommendation = { + "action": "Correlation", + "description": "Show relationships between two
<p class='highlight-descriptor'>quantitative</p>
attributes.", + } + ignore_rec_flag = False + if ( + len(ldf) < 5 + ): # Doesn't make sense to compute correlation if less than 4 data values + ignore_rec_flag = True + # Then use the data populated in the vis list to compute score + for vis in vlist: + measures = vis.get_attr_by_data_model("measure") + if len(measures) < 2: + raise ValueError( + f"Can not compute correlation between {[x.attribute for x in ldf.columns]} since less than 2 measure values present." + ) + msr1 = measures[0].attribute + msr2 = measures[1].attribute - if (ignore_transpose): - check_transpose = check_transpose_not_computed(vlist, msr1, msr2) - else: - check_transpose = True - if (check_transpose): - vis.score = interestingness(vis, ldf) - else: - vis.score = -1 - if (ignore_rec_flag): - recommendation["collection"] = [] - return recommendation - vlist = vlist.topK(15) - recommendation["collection"] = vlist - return recommendation + if ignore_transpose: + check_transpose = check_transpose_not_computed(vlist, msr1, msr2) + else: + check_transpose = True + if check_transpose: + vis.score = interestingness(vis, ldf) + else: + vis.score = -1 + if ignore_rec_flag: + recommendation["collection"] = [] + return recommendation + vlist = vlist.topK(15) + recommendation["collection"] = vlist + return recommendation def check_transpose_not_computed(vlist: VisList, a: str, b: str): - transpose_exist = list(filter(lambda x: (x._inferred_intent[0].attribute == b) and (x._inferred_intent[1].attribute == a), vlist)) - if (len(transpose_exist) > 0): - return transpose_exist[0].score == -1 - else: - return False + transpose_exist = list( + filter( + lambda x: (x._inferred_intent[0].attribute == b) + and (x._inferred_intent[1].attribute == a), + vlist, + ) + ) + if len(transpose_exist) > 0: + return transpose_exist[0].score == -1 + else: + return False diff --git a/lux/action/custom.py b/lux/action/custom.py index 114990e7..c709d34b 100644 --- a/lux/action/custom.py +++ b/lux/action/custom.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,8 +18,9 @@ from lux.executor.SQLExecutor import SQLExecutor import lux + def custom(ldf): - ''' + """ Generates user-defined vis based on the intent. Parameters @@ -31,22 +32,25 @@ def custom(ldf): ------- recommendations : Dict[str,obj] object with a collection of visualizations that result from the Distribution action. - ''' - recommendation = {"action": "Current Vis", - "description": "Shows the list of visualizations generated based on user specified intent"} + """ + recommendation = { + "action": "Current Vis", + "description": "Shows the list of visualizations generated based on user specified intent", + } recommendation["collection"] = ldf.current_vis vlist = ldf.current_vis PandasExecutor.execute(vlist, ldf) - for vis in vlist: - vis.score = interestingness(vis,ldf) + for vis in vlist: + vis.score = interestingness(vis, ldf) # ldf.clear_intent() vlist.sort(remove_invalid=True) return recommendation + def custom_actions(ldf): - ''' + """ Generates user-defined vis based on globally defined actions. Parameters @@ -58,15 +62,19 @@ def custom_actions(ldf): ------- recommendations : Dict[str,obj] object with a collection of visualizations that were previously registered. 
- ''' - if (lux.actions.__len__() > 0): + """ + if lux.actions.__len__() > 0: recommendations = [] for action_name in lux.actions.__dir__(): - display_condition = lux.actions.__getattr__(action_name).display_condition - if display_condition is None or (display_condition is not None and display_condition(ldf)): + display_condition = lux.actions.__getattr__(action_name).display_condition + if display_condition is None or ( + display_condition is not None and display_condition(ldf) + ): args = lux.actions.__getattr__(action_name).args if args: - recommendation = lux.actions.__getattr__(action_name).action(ldf, args) + recommendation = lux.actions.__getattr__(action_name).action( + ldf, args + ) else: recommendation = lux.actions.__getattr__(action_name).action(ldf) recommendations.append(recommendation) diff --git a/lux/action/enhance.py b/lux/action/enhance.py index 3ad44d72..ffdc2423 100644 --- a/lux/action/enhance.py +++ b/lux/action/enhance.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,45 +17,57 @@ from lux.processor.Compiler import Compiler from lux.utils import utils + def enhance(ldf): - ''' - Given a set of vis, generates possible visualizations when an additional attribute is added to the current vis. - - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. - - Returns - ------- - recommendations : Dict[str,obj] - object with a collection of visualizations that result from the Enhance action. - ''' - - filters = utils.get_filter_specs(ldf._intent) - # Collect variables that already exist in the intent - attr_specs = list(filter(lambda x: x.value=="" and x.attribute!="Record", ldf._intent)) - fltr_str = [fltr.attribute+fltr.filter_op+str(fltr.value) for fltr in filters] - attr_str = [clause.attribute for clause in attr_specs] - intended_attrs = '
<p class="highlight-intent">'+', '.join(attr_str+fltr_str)+'</p>
' - if (len(attr_specs)==1): - recommendation = {"action":"Enhance", - "description":f"Augmenting current {intended_attrs} intent with additional attribute."} - elif(len(attr_specs)==2): - recommendation = {"action":"Enhance", - "description":f"Further breaking down current {intended_attrs} intent by additional attribute."} - elif(len(attr_specs)>2): # if there are too many column attributes, return don't generate Enhance recommendations - recommendation = {"action":"Enhance"} - recommendation["collection"] = [] - return recommendation - intent = ldf._intent.copy() - intent = filters + attr_specs - intent.append("?") - vlist = lux.vis.VisList.VisList(intent,ldf) - - # Then use the data populated in the vis list to compute score - for vis in vlist: vis.score = interestingness(vis,ldf) - - vlist = vlist.topK(15) - recommendation["collection"] = vlist - return recommendation \ No newline at end of file + """ + Given a set of vis, generates possible visualizations when an additional attribute is added to the current vis. + + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent. + + Returns + ------- + recommendations : Dict[str,obj] + object with a collection of visualizations that result from the Enhance action. + """ + + filters = utils.get_filter_specs(ldf._intent) + # Collect variables that already exist in the intent + attr_specs = list( + filter(lambda x: x.value == "" and x.attribute != "Record", ldf._intent) + ) + fltr_str = [fltr.attribute + fltr.filter_op + str(fltr.value) for fltr in filters] + attr_str = [clause.attribute for clause in attr_specs] + intended_attrs = ( + '
<p class="highlight-intent">' + ", ".join(attr_str + fltr_str) + "</p>
" + ) + if len(attr_specs) == 1: + recommendation = { + "action": "Enhance", + "description": f"Augmenting current {intended_attrs} intent with additional attribute.", + } + elif len(attr_specs) == 2: + recommendation = { + "action": "Enhance", + "description": f"Further breaking down current {intended_attrs} intent by additional attribute.", + } + elif ( + len(attr_specs) > 2 + ): # if there are too many column attributes, return don't generate Enhance recommendations + recommendation = {"action": "Enhance"} + recommendation["collection"] = [] + return recommendation + intent = ldf._intent.copy() + intent = filters + attr_specs + intent.append("?") + vlist = lux.vis.VisList.VisList(intent, ldf) + + # Then use the data populated in the vis list to compute score + for vis in vlist: + vis.score = interestingness(vis, ldf) + + vlist = vlist.topK(15) + recommendation["collection"] = vlist + return recommendation diff --git a/lux/action/filter.py b/lux/action/filter.py index b792260a..f0972722 100644 --- a/lux/action/filter.py +++ b/lux/action/filter.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,83 +19,104 @@ from lux.processor.Compiler import Compiler from lux.utils import utils + def filter(ldf): - ''' - Iterates over all possible values of a categorical variable and generates visualizations where each categorical value filters the data. + """ + Iterates over all possible values of a categorical variable and generates visualizations where each categorical value filters the data. + + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent. + + Returns + ------- + recommendations : Dict[str,obj] + object with a collection of visualizations that result from the Filter action. + """ + filters = utils.get_filter_specs(ldf._intent) + filter_values = [] + output = [] + # if fltr is specified, create visualizations where data is filtered by all values of the fltr's categorical variable + column_spec = utils.get_attrs_specs(ldf.current_vis[0]._inferred_intent) + column_spec_attr = map(lambda x: x.attribute, column_spec) + if len(filters) == 1: + # get unique values for all categorical values specified and creates corresponding filters + fltr = filters[0] + + if ldf.data_type_lookup[fltr.attribute] == "nominal": + recommendation = { + "action": "Filter", + "description": f"Changing the
<p class='highlight-intent'>{fltr.attribute}</p>
filter to an alternative value.", + } + unique_values = ldf.unique_values[fltr.attribute] + filter_values.append(fltr.value) + # creates vis with new filters + for val in unique_values: + if val not in filter_values: + new_spec = column_spec.copy() + new_filter = lux.Clause(attribute=fltr.attribute, value=val) + new_spec.append(new_filter) + temp_vis = Vis(new_spec) + output.append(temp_vis) + elif ldf.data_type_lookup[fltr.attribute] == "quantitative": + recommendation = { + "action": "Filter", + "description": f"Changing the
<p class='highlight-intent'>{fltr.attribute}</p>
filter to an alternative inequality operation.", + } - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. + def get_complementary_ops(fltr_op): + if fltr_op == ">": + return "<=" + elif fltr_op == "<": + return ">=" + elif fltr_op == ">=": + return "<" + elif fltr_op == "<=": + return ">" + # TODO: need to support case where fltr_op is "=" --> auto-binned ranges - Returns - ------- - recommendations : Dict[str,obj] - object with a collection of visualizations that result from the Filter action. - ''' - filters = utils.get_filter_specs(ldf._intent) - filter_values = [] - output = [] - #if fltr is specified, create visualizations where data is filtered by all values of the fltr's categorical variable - column_spec = utils.get_attrs_specs(ldf.current_vis[0]._inferred_intent) - column_spec_attr = map(lambda x: x.attribute,column_spec) - if len(filters) == 1: - #get unique values for all categorical values specified and creates corresponding filters - fltr = filters[0] - - if (ldf.data_type_lookup[fltr.attribute]=="nominal"): - recommendation = {"action":"Filter", - "description":f"Changing the
<p class='highlight-intent'>{fltr.attribute}</p>
filter to an alternative value."} - unique_values = ldf.unique_values[fltr.attribute] - filter_values.append(fltr.value) - #creates vis with new filters - for val in unique_values: - if val not in filter_values: - new_spec = column_spec.copy() - new_filter = lux.Clause(attribute = fltr.attribute, value = val) - new_spec.append(new_filter) - temp_vis = Vis(new_spec) - output.append(temp_vis) - elif (ldf.data_type_lookup[fltr.attribute]=="quantitative"): - recommendation = {"action":"Filter", - "description":f"Changing the
<p class='highlight-intent'>{fltr.attribute}</p>
filter to an alternative inequality operation."} - def get_complementary_ops(fltr_op): - if (fltr_op=='>'): - return '<=' - elif (fltr_op=='<'): - return '>=' - elif (fltr_op=='>='): - return '<' - elif (fltr_op=='<='): - return '>' - # TODO: need to support case where fltr_op is "=" --> auto-binned ranges - # Create vis with complementary filter operations - new_spec = column_spec.copy() - new_filter = lux.Clause(attribute = fltr.attribute, filter_op=get_complementary_ops(fltr.filter_op),value = fltr.value) - new_spec.append(new_filter) - temp_vis = Vis(new_spec,score=1) - output.append(temp_vis) + # Create vis with complementary filter operations + new_spec = column_spec.copy() + new_filter = lux.Clause( + attribute=fltr.attribute, + filter_op=get_complementary_ops(fltr.filter_op), + value=fltr.value, + ) + new_spec.append(new_filter) + temp_vis = Vis(new_spec, score=1) + output.append(temp_vis) - else: #if no existing filters, create filters using unique values from all categorical variables in the dataset - intended_attrs = ', '.join([clause.attribute for clause in ldf._intent if clause.value=='' and clause.attribute!="Record"]) - recommendation = {"action":"Filter", - "description":f"Applying filters to the
<p class='highlight-intent'>{intended_attrs}</p>
intent."} - categorical_vars = [] - for col in list(ldf.columns): - # if cardinality is not too high, and attribute is not one of the X,Y (specified) column - if ldf.cardinality[col]<30 and col not in column_spec_attr: - categorical_vars.append(col) - for cat in categorical_vars: - unique_values = ldf.unique_values[cat] - for i in range(0, len(unique_values)): - new_spec = column_spec.copy() - new_filter = lux.Clause(attribute=cat, filter_op="=",value=unique_values[i]) - new_spec.append(new_filter) - temp_vis = Vis(new_spec) - output.append(temp_vis) - vlist = lux.vis.VisList.VisList(output,ldf) - for vis in vlist: - vis.score = interestingness(vis,ldf) - vlist = vlist.topK(15) - recommendation["collection"] = vlist - return recommendation \ No newline at end of file + else: # if no existing filters, create filters using unique values from all categorical variables in the dataset + intended_attrs = ", ".join( + [ + clause.attribute + for clause in ldf._intent + if clause.value == "" and clause.attribute != "Record" + ] + ) + recommendation = { + "action": "Filter", + "description": f"Applying filters to the
<p class='highlight-intent'>{intended_attrs}</p>
intent.", + } + categorical_vars = [] + for col in list(ldf.columns): + # if cardinality is not too high, and attribute is not one of the X,Y (specified) column + if ldf.cardinality[col] < 30 and col not in column_spec_attr: + categorical_vars.append(col) + for cat in categorical_vars: + unique_values = ldf.unique_values[cat] + for i in range(0, len(unique_values)): + new_spec = column_spec.copy() + new_filter = lux.Clause( + attribute=cat, filter_op="=", value=unique_values[i] + ) + new_spec.append(new_filter) + temp_vis = Vis(new_spec) + output.append(temp_vis) + vlist = lux.vis.VisList.VisList(output, ldf) + for vis in vlist: + vis.score = interestingness(vis, ldf) + vlist = vlist.topK(15) + recommendation["collection"] = vlist + return recommendation diff --git a/lux/action/generalize.py b/lux/action/generalize.py index 1588907c..c6096cc0 100644 --- a/lux/action/generalize.py +++ b/lux/action/generalize.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,72 +18,84 @@ from lux.utils import utils from lux.interestingness.interestingness import interestingness + def generalize(ldf): - ''' - Generates all possible visualizations when one attribute or filter from the current vis is removed. + """ + Generates all possible visualizations when one attribute or filter from the current vis is removed. + + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent. - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. + Returns + ------- + recommendations : Dict[str,obj] + object with a collection of visualizations that result from the Generalize action. + """ + # takes in a dataObject and generates a list of new dataObjects, each with a single measure from the original object removed + # --> return list of dataObjects with corresponding interestingness scores - Returns - ------- - recommendations : Dict[str,obj] - object with a collection of visualizations that result from the Generalize action. - ''' - # takes in a dataObject and generates a list of new dataObjects, each with a single measure from the original object removed - # --> return list of dataObjects with corresponding interestingness scores + output = [] + excluded_columns = [] + attributes = list( + filter(lambda x: x.value == "" and x.attribute != "Record", ldf._intent) + ) + filters = utils.get_filter_specs(ldf._intent) - output = [] - excluded_columns = [] - attributes = list(filter(lambda x: x.value=="" and x.attribute!="Record", ldf._intent)) - filters = utils.get_filter_specs(ldf._intent) + fltr_str = [fltr.attribute + fltr.filter_op + str(fltr.value) for fltr in filters] + attr_str = [clause.attribute for clause in attributes] + intended_attrs = ( + '
<p class="highlight-intent">' + ", ".join(attr_str + fltr_str) + "</p>
" + ) - fltr_str = [fltr.attribute+fltr.filter_op+str(fltr.value) for fltr in filters] - attr_str = [clause.attribute for clause in attributes] - intended_attrs = '
<p class="highlight-intent">'+', '.join(attr_str+fltr_str)+'</p>
' + recommendation = { + "action": "Generalize", + "description": f"Remove an attribute or filter from {intended_attrs}.", + } + # to observe a more general trend + # if we do no have enough column attributes or too many, return no vis. + if len(attributes) < 1 or len(attributes) > 4: + recommendation["collection"] = [] + return recommendation + # for each column specification, create a copy of the ldf's vis and remove the column specification + # then append the vis to the output + if len(attributes) > 1: + for clause in attributes: + columns = clause.attribute + if type(columns) == list: + for column in columns: + if column not in excluded_columns: + temp_vis = Vis(ldf.copy_intent(), score=1) + temp_vis.remove_column_from_spec(column, remove_first=True) + excluded_columns.append(column) + output.append(temp_vis) + elif type(columns) == str: + if columns not in excluded_columns: + temp_vis = Vis(ldf.copy_intent(), score=1) + temp_vis.remove_column_from_spec(columns, remove_first=True) + excluded_columns.append(columns) + output.append(temp_vis) + # for each filter specification, create a copy of the ldf's current vis and remove the filter specification, + # then append the vis to the output + for clause in filters: + # new_spec = ldf._intent.copy() + # new_spec.remove_column_from_spec(new_spec.attribute) + temp_vis = Vis( + ldf.current_vis[0]._inferred_intent.copy(), + source=ldf, + title="Overall", + score=0, + ) + temp_vis.remove_filter_from_spec(clause.value) + output.append(temp_vis) - recommendation = {"action":"Generalize", - "description":f"Remove an attribute or filter from {intended_attrs}."} - # to observe a more general trend - # if we do no have enough column attributes or too many, return no vis. - if(len(attributes)<1 or len(attributes)>4): - recommendation["collection"] = [] - return recommendation - #for each column specification, create a copy of the ldf's vis and remove the column specification - #then append the vis to the output - if (len(attributes)>1): - for clause in attributes: - columns = clause.attribute - if type(columns) == list: - for column in columns: - if column not in excluded_columns: - temp_vis = Vis(ldf.copy_intent(),score=1) - temp_vis.remove_column_from_spec(column, remove_first = True) - excluded_columns.append(column) - output.append(temp_vis) - elif type(columns) == str: - if columns not in excluded_columns: - temp_vis = Vis(ldf.copy_intent(),score=1) - temp_vis.remove_column_from_spec(columns, remove_first = True) - excluded_columns.append(columns) - output.append(temp_vis) - #for each filter specification, create a copy of the ldf's current vis and remove the filter specification, - #then append the vis to the output - for clause in filters: - #new_spec = ldf._intent.copy() - #new_spec.remove_column_from_spec(new_spec.attribute) - temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(),source = ldf,title="Overall",score=0) - temp_vis.remove_filter_from_spec(clause.value) - output.append(temp_vis) - - vlist = lux.vis.VisList.VisList(output,source=ldf) - # Ignore interestingness sorting since Generalize yields very few vis (preserve order of remove attribute, then remove filters) - # for vis in vlist: - # vis.score = interestingness(vis,ldf) + vlist = lux.vis.VisList.VisList(output, source=ldf) + # Ignore interestingness sorting since Generalize yields very few vis (preserve order of remove attribute, then remove filters) + # for vis in vlist: + # vis.score = interestingness(vis,ldf) - vlist.remove_duplicates() - vlist.sort(remove_invalid=True) - 
recommendation["collection"] = vlist - return recommendation \ No newline at end of file + vlist.remove_duplicates() + vlist.sort(remove_invalid=True) + recommendation["collection"] = vlist + return recommendation diff --git a/lux/action/row_group.py b/lux/action/row_group.py index a7d688ee..3fca5428 100644 --- a/lux/action/row_group.py +++ b/lux/action/row_group.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,29 +21,34 @@ from lux.vis.VisList import VisList import pandas as pd + def row_group(ldf): - recommendation = {"action":"Row Groups", - "description":"Shows charts of possible visualizations with respect to the row-wise index."} - collection = [] - - if (ldf.index.nlevels==1): - if (ldf.columns.name is not None): - dim_name = ldf.columns.name - else: - dim_name = "index" - for row_id in range(len(ldf)): - row = ldf.iloc[row_id,] - rowdf = row.reset_index() - # if (dim_name =="index"): #TODO: need to change this to auto-detect - # rowdf.data_type_lookup["index"]="nominal" - # rowdf.data_model_lookup["index"]="dimension" - # rowdf.cardinality["index"]=len(rowdf) - # if isinstance(ldf.columns,pd.DatetimeIndex): - # rowdf.data_type_lookup[dim_name]="temporal" - vis = Vis([dim_name,lux.Clause(row.name,aggregation=None)],rowdf) - collection.append(vis) - vlst = VisList(collection) - # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data - - recommendation["collection"] = vlst - return recommendation \ No newline at end of file + recommendation = { + "action": "Row Groups", + "description": "Shows charts of possible visualizations with respect to the row-wise index.", + } + collection = [] + + if ldf.index.nlevels == 1: + if ldf.columns.name is not None: + dim_name = ldf.columns.name + else: + dim_name = "index" + for row_id in range(len(ldf)): + row = ldf.iloc[ + row_id, + ] + rowdf = row.reset_index() + # if (dim_name =="index"): #TODO: need to change this to auto-detect + # rowdf.data_type_lookup["index"]="nominal" + # rowdf.data_model_lookup["index"]="dimension" + # rowdf.cardinality["index"]=len(rowdf) + # if isinstance(ldf.columns,pd.DatetimeIndex): + # rowdf.data_type_lookup[dim_name]="temporal" + vis = Vis([dim_name, lux.Clause(row.name, aggregation=None)], rowdf) + collection.append(vis) + vlst = VisList(collection) + # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data + + recommendation["collection"] = vlst + return recommendation diff --git a/lux/action/similarity.py b/lux/action/similarity.py index 8c369f93..c9871cbc 100644 --- a/lux/action/similarity.py +++ b/lux/action/similarity.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,14 +18,15 @@ import numpy as np from lux.vis.VisList import VisList + def similar_pattern(ldf, intent, topK=-1): - ''' + """ Generates visualizations with similar patterns to a query visualization. Parameters ---------- ldf : lux.core.frame - LuxDataFrame with underspecified intent. + LuxDataFrame with underspecified intent. intent: list[lux.Clause] intent for specifying the visual query for the similarity search. 
@@ -36,31 +37,34 @@ def similar_pattern(ldf, intent, topK=-1): Returns ------- recommendations : Dict[str,obj] - object with a collection of visualizations that result from the Similarity action - ''' + object with a collection of visualizations that result from the Similarity action + """ row_specs = list(filter(lambda x: x.value != "", intent)) - if(len(row_specs) == 1): - search_space_vc = VisList(ldf.current_vis.collection.copy(),ldf) + if len(row_specs) == 1: + search_space_vc = VisList(ldf.current_vis.collection.copy(), ldf) - query_vc = VisList(intent,ldf) + query_vc = VisList(intent, ldf) query_vis = query_vc[0] preprocess(query_vis) - #for loop to create assign euclidean distance - recommendation = {"action":"Similarity", - "description":"Show other charts that are visually similar to the Current vis."} + # for loop to create assign euclidean distance + recommendation = { + "action": "Similarity", + "description": "Show other charts that are visually similar to the Current vis.", + } for vis in search_space_vc: preprocess(vis) vis.score = euclidean_dist(query_vis, vis) search_space_vc.normalize_score(invert_order=True) - if(topK!=-1): + if topK != -1: search_space_vc = search_space_vc.topK(topK) recommendation["collection"] = search_space_vc return recommendation else: print("Query needs to have 1 row value") + def aggregate(vis): - ''' + """ Aggregates data values on the y axis so that the vis is a time series Parameters @@ -70,16 +74,22 @@ def aggregate(vis): Returns ------- None - ''' + """ if vis.get_attr_by_channel("x") and vis.get_attr_by_channel("y"): xAxis = vis.get_attr_by_channel("x")[0].attribute yAxis = vis.get_attr_by_channel("y")[0].attribute - vis.data = vis.data[[xAxis,yAxis]].groupby(xAxis,as_index=False).agg({yAxis:'mean'}).copy() + vis.data = ( + vis.data[[xAxis, yAxis]] + .groupby(xAxis, as_index=False) + .agg({yAxis: "mean"}) + .copy() + ) -def interpolate(vis,length): - ''' + +def interpolate(vis, length): + """ Interpolates the vis data so that the number of data points is fixed to a constant Parameters @@ -92,7 +102,7 @@ def interpolate(vis,length): Returns ------- None - ''' + """ if vis.get_attr_by_channel("x") and vis.get_attr_by_channel("y"): xAxis = vis.get_attr_by_channel("x")[0].attribute @@ -103,32 +113,40 @@ def interpolate(vis,length): xVals = vis.data[xAxis] n = length - interpolated_x_vals = [0.0]*(length) - interpolated_y_vals = [0.0]*(length) + interpolated_x_vals = [0.0] * (length) + interpolated_y_vals = [0.0] * (length) - granularity = (xVals[len(xVals)-1] - xVals[0]) / n + granularity = (xVals[len(xVals) - 1] - xVals[0]) / n count = 0 - for i in range(0,n): + for i in range(0, n): interpolated_x = xVals[0] + i * granularity interpolated_x_vals[i] = interpolated_x while xVals[count] < interpolated_x: - if(count < len(xVals)): + if count < len(xVals): count += 1 if xVals[count] == interpolated_x: interpolated_y_vals[i] = yVals[count] else: - x_diff = xVals[count] - xVals[count-1] - yDiff = yVals[count] - yVals[count-1] - interpolated_y_vals[i] = yVals[count-1] + (interpolated_x - xVals[count-1]) / x_diff * yDiff - vis.data = pd.DataFrame(list(zip(interpolated_x_vals, interpolated_y_vals)),columns = [xAxis, yAxis]) + x_diff = xVals[count] - xVals[count - 1] + yDiff = yVals[count] - yVals[count - 1] + interpolated_y_vals[i] = ( + yVals[count - 1] + + (interpolated_x - xVals[count - 1]) / x_diff * yDiff + ) + vis.data = pd.DataFrame( + list(zip(interpolated_x_vals, interpolated_y_vals)), + columns=[xAxis, yAxis], + ) + # interpolate dataset + 
def normalize(vis): - ''' + """ Normalizes the vis data so that the range of values is 0 to 1 for the vis Parameters @@ -138,17 +156,18 @@ def normalize(vis): Returns ------- None - ''' + """ if vis.get_attr_by_channel("y"): y_axis = vis.get_attr_by_channel("y")[0].attribute max = vis.data[y_axis].max() min = vis.data[y_axis].min() - if(max == min or (max-min<1)): + if max == min or (max - min < 1): return vis.data[y_axis] = (vis.data[y_axis] - min) / (max - min) + def euclidean_dist(query_vis, vis): - ''' + """ Calculates euclidean distance score for similarity between two visualizations Parameters @@ -162,7 +181,7 @@ def euclidean_dist(query_vis, vis): ------- score : float euclidean distance score - ''' + """ if query_vis.get_attr_by_channel("y") and vis.get_attr_by_channel("y"): @@ -177,8 +196,10 @@ def euclidean_dist(query_vis, vis): else: print("no y axis detected") return 0 + + def preprocess(vis): - ''' + """ Processes vis data to allow similarity comparisons between visualizations Parameters @@ -188,8 +209,7 @@ def preprocess(vis): Returns ------- None - ''' + """ aggregate(vis) interpolate(vis, 100) normalize(vis) - diff --git a/lux/action/univariate.py b/lux/action/univariate.py index 0d58c8bf..4eb0157e 100644 --- a/lux/action/univariate.py +++ b/lux/action/univariate.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,60 +16,77 @@ from lux.vis.VisList import VisList import lux from lux.utils import utils + + def univariate(ldf, *args): - ''' - Generates bar chart distributions of different attributes in the dataframe. + """ + Generates bar chart distributions of different attributes in the dataframe. + + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent. - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. + data_type_constraint: str + Controls the type of distribution chart that will be rendered. - data_type_constraint: str - Controls the type of distribution chart that will be rendered. + Returns + ------- + recommendations : Dict[str,obj] + object with a collection of visualizations that result from the Distribution action. + """ + import numpy as np - Returns - ------- - recommendations : Dict[str,obj] - object with a collection of visualizations that result from the Distribution action. - ''' - import numpy as np - if len(args) == 0: - data_type_constraint = "quantitative" - else: - data_type_constraint = args[0][0] + if len(args) == 0: + data_type_constraint = "quantitative" + else: + data_type_constraint = args[0][0] - filter_specs = utils.get_filter_specs(ldf._intent) - ignore_rec_flag = False - if (data_type_constraint== "quantitative"): - possible_attributes = [c for c in ldf.columns if ldf.data_type_lookup[c] == "quantitative" - and ldf.cardinality[c] > 5 - and c !="Number of Records"] - intent = [lux.Clause(possible_attributes)] - intent.extend(filter_specs) - recommendation = {"action":"Distribution", - "description":"Show univariate histograms of
<p class='highlight-descriptor'>quantitative</p>
attributes."} - if (len(ldf)<5): # Doesn't make sense to generate a histogram if there is less than 5 datapoints (pre-aggregated) - ignore_rec_flag = True - elif (data_type_constraint == "nominal"): - intent = [lux.Clause("?",data_type="nominal")] - intent.extend(filter_specs) - recommendation = {"action":"Occurrence", - "description":"Show frequency of occurrence for
<p class='highlight-descriptor'>categorical</p>
attributes."} - elif (data_type_constraint == "temporal"): - intent = [lux.Clause("?",data_type="temporal")] - intent.extend(filter_specs) - recommendation = {"action":"Temporal", - "description":"Show trends over
<p class='highlight-descriptor'>time-related</p>
attributes."} - if (len(ldf)<3): # Doesn't make sense to generate a line chart if there is less than 3 datapoints (pre-aggregated) - ignore_rec_flag = True - if (ignore_rec_flag): - recommendation["collection"] = [] - return recommendation - vlist = VisList(intent,ldf) - for vis in vlist: - vis.score = interestingness(vis,ldf) - # vlist = vlist.topK(15) # Basic visualizations should not be capped - vlist.sort() - recommendation["collection"] = vlist - return recommendation \ No newline at end of file + filter_specs = utils.get_filter_specs(ldf._intent) + ignore_rec_flag = False + if data_type_constraint == "quantitative": + possible_attributes = [ + c + for c in ldf.columns + if ldf.data_type_lookup[c] == "quantitative" + and ldf.cardinality[c] > 5 + and c != "Number of Records" + ] + intent = [lux.Clause(possible_attributes)] + intent.extend(filter_specs) + recommendation = { + "action": "Distribution", + "description": "Show univariate histograms of
<p class='highlight-descriptor'>quantitative</p>
attributes.", + } + if ( + len(ldf) < 5 + ): # Doesn't make sense to generate a histogram if there is less than 5 datapoints (pre-aggregated) + ignore_rec_flag = True + elif data_type_constraint == "nominal": + intent = [lux.Clause("?", data_type="nominal")] + intent.extend(filter_specs) + recommendation = { + "action": "Occurrence", + "description": "Show frequency of occurrence for
<p class='highlight-descriptor'>categorical</p>
attributes.", + } + elif data_type_constraint == "temporal": + intent = [lux.Clause("?", data_type="temporal")] + intent.extend(filter_specs) + recommendation = { + "action": "Temporal", + "description": "Show trends over
<p class='highlight-descriptor'>time-related</p>
attributes.", + } + if ( + len(ldf) < 3 + ): # Doesn't make sense to generate a line chart if there is less than 3 datapoints (pre-aggregated) + ignore_rec_flag = True + if ignore_rec_flag: + recommendation["collection"] = [] + return recommendation + vlist = VisList(intent, ldf) + for vis in vlist: + vis.score = interestingness(vis, ldf) + # vlist = vlist.topK(15) # Basic visualizations should not be capped + vlist.sort() + recommendation["collection"] = vlist + return recommendation diff --git a/lux/core/__init__.py b/lux/core/__init__.py index 771a533f..2d72ffb8 100644 --- a/lux/core/__init__.py +++ b/lux/core/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,14 +14,17 @@ import pandas as pd from .frame import LuxDataFrame -global originalDF; + +global originalDF # Keep variable scope of original pandas df originalDF = pd.core.frame.DataFrame + def setOption(overridePandas=True): - if (overridePandas): + if overridePandas: pd.DataFrame = pd.io.parsers.DataFrame = pd.core.frame.DataFrame = LuxDataFrame else: pd.DataFrame = pd.io.parsers.DataFrame = pd.core.frame.DataFrame = originalDF -setOption(overridePandas=True) \ No newline at end of file + +setOption(overridePandas=True) diff --git a/lux/core/frame.py b/lux/core/frame.py index bd7cea6a..6e45621b 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,715 +23,874 @@ from typing import Dict, Union, List, Callable import warnings import lux + + class LuxDataFrame(pd.DataFrame): - ''' - A subclass of pd.DataFrame that supports all dataframe operations while housing other variables and functions for generating visual recommendations. - ''' - # MUST register here for new properties!! 
- _metadata = ['_intent','data_type_lookup','data_type', - 'data_model_lookup','data_model','unique_values','cardinality','_rec_info', '_pandas_only', - '_min_max','plot_config', '_current_vis','_widget', '_recommendation','_prev','_history', '_saved_export'] - - def __init__(self,*args, **kw): - from lux.executor.PandasExecutor import PandasExecutor - self._history = History() - self._intent = [] - self._recommendation = {} - self._saved_export = None - self._current_vis = [] - self._prev = None - super(LuxDataFrame, self).__init__(*args, **kw) - - self.executor_type = "Pandas" - self.executor = PandasExecutor() - self.SQLconnection = "" - self.table_name = "" - - self._sampled = None - self._default_pandas_display = True - self._toggle_pandas_display = True - self._plot_config = None - self._message = Message() - self._pandas_only=False - # Metadata - self.data_type_lookup = None - self.data_type = None - self.data_model_lookup = None - self.data_model = None - self.unique_values = None - self.cardinality = None - self._min_max = None - self.pre_aggregated = None - - @property - def _constructor(self): - return LuxDataFrame - @property - def _constructor_sliced(self): - def f(*args, **kwargs): - s = LuxSeries(*args, **kwargs) - for attr in self._metadata: #propagate metadata - s.__dict__[attr] = getattr(self, attr, None) - return s - return f - - @property - def history(self): - return self._history - def maintain_metadata(self): - if (not hasattr(self,"_metadata_fresh") or not self._metadata_fresh ): # Check that metadata has not yet been computed - if (len(self)>0): #only compute metadata information if the dataframe is non-empty - self.executor.compute_stats(self) - self.executor.compute_dataset_metadata(self) - self._infer_structure() - self._metadata_fresh = True - def expire_recs(self): - self._recs_fresh = False - self.recommendation = {} - self.current_vis = None - self._widget = None - self._rec_info = None - self._sampled = None - def expire_metadata(self): - # Set metadata as null - self._metadata_fresh = False - self.data_type_lookup = None - self.data_type = None - self.data_model_lookup = None - self.data_model = None - self.unique_values = None - self.cardinality = None - self._min_max = None - self.pre_aggregated = None - - ##################### - ## Override Pandas ## - ##################### - def __getattr__(self, name): - ret_value = super(LuxDataFrame, self).__getattr__(name) - self.expire_metadata() - self.expire_recs() - return ret_value - def _set_axis(self, axis, labels): - super(LuxDataFrame, self)._set_axis(axis, labels) - self.expire_metadata() - self.expire_recs() - def _update_inplace(self,*args,**kwargs): - super(LuxDataFrame, self)._update_inplace(*args,**kwargs) - self.expire_metadata() - self.expire_recs() - def _set_item(self, key, value): - super(LuxDataFrame, self)._set_item(key, value) - self.expire_metadata() - self.expire_recs() - def _infer_structure(self): - # If the dataframe is very small and the index column is not a range index, then it is likely that this is an aggregated data - is_multi_index_flag = self.index.nlevels !=1 - not_int_index_flag = self.index.dtype !='int64' - small_df_flag = len(self)<100 - self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag - if ("Number of Records" in self.columns): - self.pre_aggregated = True - very_small_df_flag = len(self)<=10 - if (very_small_df_flag): - self.pre_aggregated = True - def set_executor_type(self, exe): - if (exe =="SQL"): - import pkgutil - if 
(pkgutil.find_loader("psycopg2") is None): - raise ImportError("psycopg2 is not installed. Run `pip install psycopg2' to install psycopg2 to enable the Postgres connection.") - else: - import psycopg2 - from lux.executor.SQLExecutor import SQLExecutor - self.executor = SQLExecutor - else: - from lux.executor.PandasExecutor import PandasExecutor - self.executor = PandasExecutor() - self.executor_type = exe - @property - def plot_config(self): - return self._plot_config - @plot_config.setter - def plot_config(self,config_func:Callable): - """ - Modify plot aesthetic settings to all visualizations in the dataframe display - Currently only supported for Altair visualizations - Parameters - ---------- - config_func : Callable - A function that takes in an AltairChart (https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) as input and returns an AltairChart as output - - Example - ---------- - Changing the color of marks and adding a title for all charts displayed for this dataframe - >>> df = pd.read_csv("lux/data/car.csv") - >>> def changeColorAddTitle(chart): - chart = chart.configure_mark(color="red") # change mark color to red - chart.title = "Custom Title" # add title to chart - return chart - >>> df.plot_config = changeColorAddTitle - >>> df - Change the opacity of all scatterplots displayed for this dataframe - >>> df = pd.read_csv("lux/data/olympic.csv") - >>> def changeOpacityScatterOnly(chart): - if chart.mark=='circle': - chart = chart.configure_mark(opacity=0.1) # lower opacity - return chart - >>> df.plot_config = changeOpacityScatterOnly - >>> df - """ - self._plot_config = config_func - self._recs_fresh=False - def clear_plot_config(self): - self._plot_config = None - self._recs_fresh=False - - @property - def intent(self): - return self._intent - @intent.setter - def intent(self, intent_input:Union[List[Union[str, Clause]],Vis]): - is_list_input = isinstance(intent_input,list) - is_vis_input = isinstance(intent_input,Vis) - if not (is_list_input or is_vis_input): - raise TypeError("Input intent must be either a list (of strings or lux.Clause) or a lux.Vis object." - "\nSee more at: https://lux-api.readthedocs.io/en/latest/source/guide/intent.html" - ) - if is_list_input: - self.set_intent(intent_input) - elif is_vis_input: - self.set_intent_as_vis(intent_input) - def clear_intent(self): - self.intent = [] - def set_intent(self, intent:List[Union[str, Clause]]): - """ - Main function to set the intent of the dataframe. - The intent input goes through the parser, so that the string inputs are parsed into a lux.Clause object. 
- - Parameters - ---------- - intent : List[str,Clause] - intent list, can be a mix of string shorthand or a lux.Clause object - - Notes - ----- - :doc:`../guide/clause` - """ - self.expire_recs() - self._intent = intent - self._parse_validate_compile_intent() - def _parse_validate_compile_intent(self): - from lux.processor.Parser import Parser - from lux.processor.Validator import Validator - self._intent = Parser.parse(self._intent) - Validator.validate_intent(self._intent,self) - self.maintain_metadata() - from lux.processor.Compiler import Compiler - self.current_vis = Compiler.compile_intent(self, self._intent) - - def copy_intent(self): - #creates a true copy of the dataframe's intent - output = [] - for clause in self._intent: - temp_clause = clause.copy_clause() - output.append(temp_clause) - return(output) - - def set_intent_as_vis(self,vis:Vis): - """ - Set intent of the dataframe as the Vis - - Parameters - ---------- - vis : Vis - """ - self.expire_recs() - self._intent = vis._inferred_intent - self._parse_validate_compile_intent() - - def to_pandas(self): - import lux.core - return lux.core.originalDF(self,copy=False) - - @property - def recommendation(self): - return self._recommendation - @recommendation.setter - def recommendation(self,recommendation:Dict): - self._recommendation = recommendation - @property - def current_vis(self): - return self._current_vis - @current_vis.setter - def current_vis(self,current_vis:Dict): - self._current_vis = current_vis - def __repr__(self): - # TODO: _repr_ gets called from _repr_html, need to get rid of this call - return "" - - ####################################################### - ########## SQL Metadata, type, model schema ########### - ####################################################### - - def set_SQL_connection(self, connection, t_name): - self.SQLconnection = connection - self.table_name = t_name - self.compute_SQL_dataset_metadata() - self.set_executor_type("SQL") - - def compute_SQL_dataset_metadata(self): - self.get_SQL_attributes() - for attr in list(self.columns): - self[attr] = None - self.data_type_lookup = {} - self.data_type = {} - #####NOTE: since we aren't expecting users to do much data processing with the SQL database, should we just keep this - ##### in the initialization and do it just once - self.compute_SQL_data_type() - self.compute_SQL_stats() - self.data_model_lookup = {} - self.data_model = {} - self.compute_data_model() - - def compute_SQL_stats(self): - # precompute statistics - self.unique_values = {} - self._min_max = {} - - self.get_SQL_unique_values() - #self.get_SQL_cardinality() - for attribute in self.columns: - if self.data_type_lookup[attribute] == 'quantitative': - self._min_max[attribute] = (self[attribute].min(), self[attribute].max()) - - def get_SQL_attributes(self): - if "." 
in self.table_name: - table_name = self.table_name[self.table_name.index(".")+1:] - else: - table_name = self.table_name - attr_query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS where TABLE_NAME = '{}'".format(table_name) - attributes = list(pd.read_sql(attr_query, self.SQLconnection)['column_name']) - for attr in attributes: - self[attr] = None - - def get_SQL_cardinality(self): - cardinality = {} - for attr in list(self.columns): - card_query = pd.read_sql("SELECT Count(Distinct({})) FROM {}".format(attr, self.table_name), self.SQLconnection) - cardinality[attr] = list(card_query["count"])[0] - self.cardinality = cardinality - - def get_SQL_unique_values(self): - unique_vals = {} - for attr in list(self.columns): - unique_query = pd.read_sql("SELECT Distinct({}) FROM {}".format(attr, self.table_name), self.SQLconnection) - unique_vals[attr] = list(unique_query[attr]) - self.unique_values = unique_vals - - def compute_SQL_data_type(self): - data_type_lookup = {} - sql_dtypes = {} - self.get_SQL_cardinality() - if "." in self.table_name: - table_name = self.table_name[self.table_name.index(".")+1:] - else: - table_name = self.table_name - #get the data types of the attributes in the SQL table - for attr in list(self.columns): - datatype_query = "SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{}' AND COLUMN_NAME = '{}'".format(table_name, attr) - datatype = list(pd.read_sql(datatype_query, self.SQLconnection)['data_type'])[0] - sql_dtypes[attr] = datatype - - data_type = {"quantitative":[], "nominal":[], "temporal":[]} - for attr in list(self.columns): - if str(attr).lower() in ["month", "year"]: - data_type_lookup[attr] = "temporal" - data_type["temporal"].append(attr) - elif sql_dtypes[attr] in ["character", "character varying", "boolean", "uuid", "text"]: - data_type_lookup[attr] = "nominal" - data_type["nominal"].append(attr) - elif sql_dtypes[attr] in ["integer", "real", "smallint", "smallserial", "serial"]: - if self.cardinality[attr] < 13: - data_type_lookup[attr] = "nominal" - data_type["nominal"].append(attr) - else: - data_type_lookup[attr] = "quantitative" - data_type["quantitative"].append(attr) - elif "time" in sql_dtypes[attr] or "date" in sql_dtypes[attr]: - data_type_lookup[attr] = "temporal" - data_type["temporal"].append(attr) - self.data_type_lookup = data_type_lookup - self.data_type = data_type - def _append_rec(self,rec_infolist,recommendations:Dict): - if (recommendations["collection"] is not None and len(recommendations["collection"])>0): - rec_infolist.append(recommendations) - def maintain_recs(self): - # `rec_df` is the dataframe to generate the recommendations on - # check to see if globally defined actions have been registered/removed - if (lux.update_actions["flag"] == True): - self._recs_fresh = False - show_prev = False # flag indicating whether rec_df is showing previous df or current self - if self._prev is not None: - rec_df = self._prev - rec_df._message = Message() - rec_df.maintain_metadata() # the prev dataframe may not have been printed before - last_event = self.history._events[-1].name - rec_df._message.add(f"Lux is visualizing the previous version of the dataframe before you applied {last_event}.") - show_prev = True - else: - rec_df = self - rec_df._message = Message() - # Add warning message if there exist ID fields - id_fields_str = "" - if (len(rec_df.data_type["id"])>0): - for id_field in rec_df.data_type["id"]: id_fields_str += f"{id_field}, " - id_fields_str = id_fields_str[:-2] - 
rec_df._message.add(f"{id_fields_str} is not visualized since it resembles an ID field.") - rec_df._prev = None # reset _prev - - if (not hasattr(rec_df,"_recs_fresh") or not rec_df._recs_fresh ): # Check that recs has not yet been computed - rec_infolist = [] - from lux.action.custom import custom - from lux.action.custom import custom_actions - from lux.action.correlation import correlation - from lux.action.univariate import univariate - from lux.action.enhance import enhance - from lux.action.filter import filter - from lux.action.generalize import generalize - from lux.action.row_group import row_group - from lux.action.column_group import column_group - if (rec_df.pre_aggregated): - if (rec_df.columns.name is not None): - rec_df._append_rec(rec_infolist, row_group(rec_df)) - rec_df._append_rec(rec_infolist, column_group(rec_df)) - else: - if rec_df.recommendation == {}: - # display conditions for default actions - no_vis = lambda ldf: (ldf.current_vis is None) or (ldf.current_vis is not None and len(ldf.current_vis) == 0) - one_current_vis = lambda ldf: ldf.current_vis is not None and len(ldf.current_vis) == 1 - multiple_current_vis = lambda ldf: ldf.current_vis is not None and len(ldf.current_vis) > 1 - - # globally register default actions - lux.register_action("correlation", correlation, no_vis) - lux.register_action("distribution", univariate, no_vis, "quantitative") - lux.register_action("occurrence", univariate, no_vis, "nominal") - lux.register_action("temporal", univariate, no_vis, "temporal") - - lux.register_action("enhance", enhance, one_current_vis) - lux.register_action("filter", filter, one_current_vis) - lux.register_action("generalize", generalize, one_current_vis) - - lux.register_action("custom", custom, multiple_current_vis) - - # generate vis from globally registered actions and append to dataframe - custom_action_collection = custom_actions(rec_df) - for rec in custom_action_collection: - rec_df._append_rec(rec_infolist, rec) - lux.update_actions["flag"] = False - - # Store _rec_info into a more user-friendly dictionary form - rec_df.recommendation = {} - for rec_info in rec_infolist: - action_type = rec_info["action"] - vlist = rec_info["collection"] - if (rec_df._plot_config): - if (rec_df.current_vis): - for vis in rec_df.current_vis: vis._plot_config = rec_df.plot_config - for vis in vlist: vis._plot_config = rec_df.plot_config - if (len(vlist)>0): - rec_df.recommendation[action_type] = vlist - rec_df._rec_info = rec_infolist - self._widget = rec_df.render_widget() - elif (show_prev): # re-render widget for the current dataframe if previous rec is not recomputed - self._widget = rec_df.render_widget() - self._recs_fresh = True - - - ####################################################### - ############## LuxWidget Result Display ############### - ####################################################### - @property - def widget(self): - if(self._widget): - return self._widget - @property - def exported(self) -> Union[Dict[str,VisList], VisList]: - """ - Get selected visualizations as exported Vis List - - Notes - ----- - Convert the _selectedVisIdxs dictionary into a programmable VisList - Example _selectedVisIdxs : - {'Correlation': [0, 2], 'Occurrence': [1]} - indicating the 0th and 2nd vis from the `Correlation` tab is selected, and the 1st vis from the `Occurrence` tab is selected. 
- - Returns - ------- - Union[Dict[str,VisList], VisList] - When there are no exported vis, return empty list -> [] - When all the exported vis is from the same tab, return a VisList of selected visualizations. -> VisList(v1, v2...) - When the exported vis is from the different tabs, return a dictionary with the action name as key and selected visualizations in the VisList. -> {"Enhance": VisList(v1, v2...), "Filter": VisList(v5, v7...), ..} - """ - if not hasattr(self,"_widget"): - warnings.warn( - "\nNo widget attached to the dataframe." - "Please assign dataframe to an output variable.\n" - "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips" - , stacklevel=2) - return [] - exported_vis_lst = self._widget._selectedVisIdxs - exported_vis = [] - if (exported_vis_lst=={}): - if self._saved_export: - return self._saved_export - warnings.warn( - "\nNo visualization selected to export.\n" - "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips" - ,stacklevel=2) - return [] - if len(exported_vis_lst) == 1 and "currentVis" in exported_vis_lst: - return self.current_vis - elif len(exported_vis_lst) > 1: - exported_vis = {} - if ("currentVis" in exported_vis_lst): - exported_vis["Current Vis"] = self.current_vis - for export_action in exported_vis_lst: - if (export_action != "currentVis"): - exported_vis[export_action] = VisList(list(map(self.recommendation[export_action].__getitem__, exported_vis_lst[export_action]))) - return exported_vis - elif len(exported_vis_lst) == 1 and ("currentVis" not in exported_vis_lst): - export_action = list(exported_vis_lst.keys())[0] - exported_vis = VisList(list(map(self.recommendation[export_action].__getitem__, exported_vis_lst[export_action]))) - self._saved_export = exported_vis - return exported_vis - else: - warnings.warn( - "\nNo visualization selected to export.\n" - "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips" - ,stacklevel=2) - return [] - - def remove_deleted_recs(self, change): - for action in self._widget.deletedIndices: - deletedSoFar = 0 - for index in self._widget.deletedIndices[action]: - self.recommendation[action].remove_index(index - deletedSoFar) - deletedSoFar += 1 - - def set_intent_on_click(self, change): - from IPython.display import display, clear_output - from lux.processor.Compiler import Compiler - - intent_action = list(self._widget.selectedIntentIndex.keys())[0] - vis = self.recommendation[intent_action][self._widget.selectedIntentIndex[intent_action][0]] - self.set_intent_as_vis(vis) - - self.maintain_metadata() - self.current_vis = Compiler.compile_intent(self, self._intent) - self.maintain_recs() - - with self.output: - clear_output() - display(self._widget) - - self._widget.observe(self.remove_deleted_recs, names='deletedIndices') - self._widget.observe(self.set_intent_on_click, names='selectedIntentIndex') - - def _repr_html_(self): - from IPython.display import display - from IPython.display import clear_output - import ipywidgets as widgets - - try: - if (self._pandas_only): - display(self.display_pandas()) - self._pandas_only=False - else: - if(self.index.nlevels>=2 or self.columns.nlevels >= 2): - warnings.warn( - "\nLux does not currently support dataframes " - "with hierarchical indexes.\n" - "Please convert the dataframe into a flat " - "table via `pandas.DataFrame.reset_index`.\n", - stacklevel=2, - ) - display(self.display_pandas()) - return - - if (len(self)<=0): - 
warnings.warn("\nLux can not operate on an empty dataframe.\nPlease check your input again.\n",stacklevel=2) - display(self.display_pandas()) - return - if (len(self.columns)<=1): - warnings.warn("\nLux defaults to Pandas when there is only a single column.",stacklevel=2) - display(self.display_pandas()) - return - self.maintain_metadata() - - if (self._intent!=[] and (not hasattr(self,"_compiled") or not self._compiled)): - from lux.processor.Compiler import Compiler - self.current_vis = Compiler.compile_intent(self, self._intent) - - if (lux.config.default_display == "lux"): - self._toggle_pandas_display = False - else: - self._toggle_pandas_display = True - - # df_to_display.maintain_recs() # compute the recommendations (TODO: This can be rendered in another thread in the background to populate self._widget) - self.maintain_recs() - - #Observers(callback_function, listen_to_this_variable) - self._widget.observe(self.remove_deleted_recs, names='deletedIndices') - self._widget.observe(self.set_intent_on_click, names='selectedIntentIndex') - - if len(self.recommendation) > 0: - # box = widgets.Box(layout=widgets.Layout(display='inline')) - button = widgets.Button(description="Toggle Pandas/Lux",layout=widgets.Layout(width='140px',top='5px')) - self.output = widgets.Output() - # box.children = [button,output] - # output.children = [button] - # display(box) - display(button, self.output) - def on_button_clicked(b): - with self.output: - if (b): - self._toggle_pandas_display = not self._toggle_pandas_display - clear_output() - if (self._toggle_pandas_display): - display(self.display_pandas()) - else: - # b.layout.display = "none" - display(self._widget) - # b.layout.display = "inline-block" - button.on_click(on_button_clicked) - on_button_clicked(None) - else: - warnings.warn("\nLux defaults to Pandas when there are no valid actions defined.",stacklevel=2) - display(self.display_pandas()) - - except(KeyboardInterrupt,SystemExit): - raise - except: - warnings.warn( - "\nUnexpected error in rendering Lux widget and recommendations. " - "Falling back to Pandas display.\n\n" - "Please report this issue on Github: https://github.com/lux-org/lux/issues " - ,stacklevel=2) - display(self.display_pandas()) - def display_pandas(self): - return self.to_pandas() - def render_widget(self, renderer:str ="altair", input_current_vis=""): - """ - Generate a LuxWidget based on the LuxDataFrame - - Structure of widgetJSON: - { - 'current_vis': {}, - 'recommendation': [ - { - 'action': 'Correlation', - 'description': "some description", - 'vspec': [ - {Vega-Lite spec for vis 1}, - {Vega-Lite spec for vis 2}, - ... - ] - }, - ... 
repeat for other actions - ] - } - Parameters - ---------- - renderer : str, optional - Choice of visualization rendering library, by default "altair" - input_current_vis : lux.LuxDataFrame, optional - User-specified current vis to override default Current Vis, by default - """ - check_import_lux_widget() - import luxwidget - widgetJSON = self.to_JSON(self._rec_info, input_current_vis=input_current_vis) - return luxwidget.LuxWidget( - currentVis=widgetJSON["current_vis"], - recommendations=widgetJSON["recommendation"], - intent=LuxDataFrame.intent_to_string(self._intent), - message = self._message.to_html() - ) - @staticmethod - def intent_to_JSON(intent): - from lux.utils import utils - - filter_specs = utils.get_filter_specs(intent) - attrs_specs = utils.get_attrs_specs(intent) - - intent = {} - intent['attributes'] = [clause.attribute for clause in attrs_specs] - intent['filters'] = [clause.attribute for clause in filter_specs] - return intent - @staticmethod - def intent_to_string(intent): - if (intent): - return ", ".join([clause.to_string() for clause in intent]) - else: - return "" - - def to_JSON(self, rec_infolist, input_current_vis=""): - widget_spec = {} - if (self.current_vis): - self.executor.execute(self.current_vis, self) - widget_spec["current_vis"] = LuxDataFrame.current_vis_to_JSON(self.current_vis, input_current_vis) - else: - widget_spec["current_vis"] = {} - widget_spec["recommendation"] = [] - - # Recommended Collection - recCollection = LuxDataFrame.rec_to_JSON(rec_infolist) - widget_spec["recommendation"].extend(recCollection) - return widget_spec - - @staticmethod - def current_vis_to_JSON(vlist, input_current_vis=""): - current_vis_spec = {} - numVC = len(vlist) #number of visualizations in the vis list - if (numVC==1): - current_vis_spec = vlist[0].render_VSpec() - elif (numVC>1): - pass - return current_vis_spec - - @staticmethod - def rec_to_JSON(recs): - rec_lst = [] - import copy - rec_copy = copy.deepcopy(recs) - for idx,rec in enumerate(rec_copy): - if (len(rec["collection"])>0): - rec["vspec"] = [] - for vis in rec["collection"]: - chart = vis.render_VSpec() - rec["vspec"].append(chart) - rec_lst.append(rec) - # delete DataObjectCollection since not JSON serializable - del rec_lst[idx]["collection"] - return rec_lst - - # Overridden Pandas Functions - def head(self, n: int = 5): - self._prev = self - self._history.append_event("head", n=5) - return super(LuxDataFrame, self).head(n) - - def tail(self, n: int = 5): - self._prev = self - self._history.append_event("tail", n=5) - return super(LuxDataFrame, self).tail(n) - - def info(self, *args, **kwargs): - self._pandas_only=True - self._history.append_event("info",*args, **kwargs) - return super(LuxDataFrame, self).info(*args, **kwargs) - - def describe(self, *args, **kwargs): - self._pandas_only=True - self._history.append_event("describe",*args, **kwargs) - return super(LuxDataFrame, self).describe(*args, **kwargs) + """ + A subclass of pd.DataFrame that supports all dataframe operations while housing other variables and functions for generating visual recommendations. + """ + + # MUST register here for new properties!! 
+ _metadata = [ + "_intent", + "data_type_lookup", + "data_type", + "data_model_lookup", + "data_model", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + ] + + def __init__(self, *args, **kw): + from lux.executor.PandasExecutor import PandasExecutor + + self._history = History() + self._intent = [] + self._recommendation = {} + self._saved_export = None + self._current_vis = [] + self._prev = None + super(LuxDataFrame, self).__init__(*args, **kw) + + self.executor_type = "Pandas" + self.executor = PandasExecutor() + self.SQLconnection = "" + self.table_name = "" + + self._sampled = None + self._default_pandas_display = True + self._toggle_pandas_display = True + self._plot_config = None + self._message = Message() + self._pandas_only = False + # Metadata + self.data_type_lookup = None + self.data_type = None + self.data_model_lookup = None + self.data_model = None + self.unique_values = None + self.cardinality = None + self._min_max = None + self.pre_aggregated = None + + @property + def _constructor(self): + return LuxDataFrame + + @property + def _constructor_sliced(self): + def f(*args, **kwargs): + s = LuxSeries(*args, **kwargs) + for attr in self._metadata: # propagate metadata + s.__dict__[attr] = getattr(self, attr, None) + return s + + return f + + @property + def history(self): + return self._history + + def maintain_metadata(self): + if ( + not hasattr(self, "_metadata_fresh") or not self._metadata_fresh + ): # Check that metadata has not yet been computed + if ( + len(self) > 0 + ): # only compute metadata information if the dataframe is non-empty + self.executor.compute_stats(self) + self.executor.compute_dataset_metadata(self) + self._infer_structure() + self._metadata_fresh = True + + def expire_recs(self): + self._recs_fresh = False + self.recommendation = {} + self.current_vis = None + self._widget = None + self._rec_info = None + self._sampled = None + + def expire_metadata(self): + # Set metadata as null + self._metadata_fresh = False + self.data_type_lookup = None + self.data_type = None + self.data_model_lookup = None + self.data_model = None + self.unique_values = None + self.cardinality = None + self._min_max = None + self.pre_aggregated = None + + ##################### + ## Override Pandas ## + ##################### + def __getattr__(self, name): + ret_value = super(LuxDataFrame, self).__getattr__(name) + self.expire_metadata() + self.expire_recs() + return ret_value + + def _set_axis(self, axis, labels): + super(LuxDataFrame, self)._set_axis(axis, labels) + self.expire_metadata() + self.expire_recs() + + def _update_inplace(self, *args, **kwargs): + super(LuxDataFrame, self)._update_inplace(*args, **kwargs) + self.expire_metadata() + self.expire_recs() + + def _set_item(self, key, value): + super(LuxDataFrame, self)._set_item(key, value) + self.expire_metadata() + self.expire_recs() + + def _infer_structure(self): + # If the dataframe is very small and the index column is not a range index, then it is likely that this is an aggregated data + is_multi_index_flag = self.index.nlevels != 1 + not_int_index_flag = self.index.dtype != "int64" + small_df_flag = len(self) < 100 + self.pre_aggregated = ( + is_multi_index_flag or not_int_index_flag + ) and small_df_flag + if "Number of Records" in self.columns: + self.pre_aggregated = True + very_small_df_flag = len(self) <= 10 + if very_small_df_flag: + self.pre_aggregated = 
True + + def set_executor_type(self, exe): + if exe == "SQL": + import pkgutil + + if pkgutil.find_loader("psycopg2") is None: + raise ImportError( + "psycopg2 is not installed. Run `pip install psycopg2' to install psycopg2 to enable the Postgres connection." + ) + else: + import psycopg2 + from lux.executor.SQLExecutor import SQLExecutor + + self.executor = SQLExecutor + else: + from lux.executor.PandasExecutor import PandasExecutor + + self.executor = PandasExecutor() + self.executor_type = exe + + @property + def plot_config(self): + return self._plot_config + + @plot_config.setter + def plot_config(self, config_func: Callable): + """ + Modify plot aesthetic settings to all visualizations in the dataframe display + Currently only supported for Altair visualizations + Parameters + ---------- + config_func : Callable + A function that takes in an AltairChart (https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) as input and returns an AltairChart as output + + Example + ---------- + Changing the color of marks and adding a title for all charts displayed for this dataframe + >>> df = pd.read_csv("lux/data/car.csv") + >>> def changeColorAddTitle(chart): + chart = chart.configure_mark(color="red") # change mark color to red + chart.title = "Custom Title" # add title to chart + return chart + >>> df.plot_config = changeColorAddTitle + >>> df + Change the opacity of all scatterplots displayed for this dataframe + >>> df = pd.read_csv("lux/data/olympic.csv") + >>> def changeOpacityScatterOnly(chart): + if chart.mark=='circle': + chart = chart.configure_mark(opacity=0.1) # lower opacity + return chart + >>> df.plot_config = changeOpacityScatterOnly + >>> df + """ + self._plot_config = config_func + self._recs_fresh = False + + def clear_plot_config(self): + self._plot_config = None + self._recs_fresh = False + + @property + def intent(self): + return self._intent + + @intent.setter + def intent(self, intent_input: Union[List[Union[str, Clause]], Vis]): + is_list_input = isinstance(intent_input, list) + is_vis_input = isinstance(intent_input, Vis) + if not (is_list_input or is_vis_input): + raise TypeError( + "Input intent must be either a list (of strings or lux.Clause) or a lux.Vis object." + "\nSee more at: https://lux-api.readthedocs.io/en/latest/source/guide/intent.html" + ) + if is_list_input: + self.set_intent(intent_input) + elif is_vis_input: + self.set_intent_as_vis(intent_input) + + def clear_intent(self): + self.intent = [] + + def set_intent(self, intent: List[Union[str, Clause]]): + """ + Main function to set the intent of the dataframe. + The intent input goes through the parser, so that the string inputs are parsed into a lux.Clause object. 
+ + Parameters + ---------- + intent : List[str,Clause] + intent list, can be a mix of string shorthand or a lux.Clause object + + Notes + ----- + :doc:`../guide/clause` + """ + self.expire_recs() + self._intent = intent + self._parse_validate_compile_intent() + + def _parse_validate_compile_intent(self): + from lux.processor.Parser import Parser + from lux.processor.Validator import Validator + + self._intent = Parser.parse(self._intent) + Validator.validate_intent(self._intent, self) + self.maintain_metadata() + from lux.processor.Compiler import Compiler + + self.current_vis = Compiler.compile_intent(self, self._intent) + + def copy_intent(self): + # creates a true copy of the dataframe's intent + output = [] + for clause in self._intent: + temp_clause = clause.copy_clause() + output.append(temp_clause) + return output + + def set_intent_as_vis(self, vis: Vis): + """ + Set intent of the dataframe as the Vis + + Parameters + ---------- + vis : Vis + """ + self.expire_recs() + self._intent = vis._inferred_intent + self._parse_validate_compile_intent() + + def to_pandas(self): + import lux.core + + return lux.core.originalDF(self, copy=False) + + @property + def recommendation(self): + return self._recommendation + + @recommendation.setter + def recommendation(self, recommendation: Dict): + self._recommendation = recommendation + + @property + def current_vis(self): + return self._current_vis + + @current_vis.setter + def current_vis(self, current_vis: Dict): + self._current_vis = current_vis + + def __repr__(self): + # TODO: _repr_ gets called from _repr_html, need to get rid of this call + return "" + + ####################################################### + ########## SQL Metadata, type, model schema ########### + ####################################################### + + def set_SQL_connection(self, connection, t_name): + self.SQLconnection = connection + self.table_name = t_name + self.compute_SQL_dataset_metadata() + self.set_executor_type("SQL") + + def compute_SQL_dataset_metadata(self): + self.get_SQL_attributes() + for attr in list(self.columns): + self[attr] = None + self.data_type_lookup = {} + self.data_type = {} + #####NOTE: since we aren't expecting users to do much data processing with the SQL database, should we just keep this + ##### in the initialization and do it just once + self.compute_SQL_data_type() + self.compute_SQL_stats() + self.data_model_lookup = {} + self.data_model = {} + self.compute_data_model() + + def compute_SQL_stats(self): + # precompute statistics + self.unique_values = {} + self._min_max = {} + + self.get_SQL_unique_values() + # self.get_SQL_cardinality() + for attribute in self.columns: + if self.data_type_lookup[attribute] == "quantitative": + self._min_max[attribute] = ( + self[attribute].min(), + self[attribute].max(), + ) + + def get_SQL_attributes(self): + if "." 
in self.table_name: + table_name = self.table_name[self.table_name.index(".") + 1 :] + else: + table_name = self.table_name + attr_query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS where TABLE_NAME = '{}'".format( + table_name + ) + attributes = list(pd.read_sql(attr_query, self.SQLconnection)["column_name"]) + for attr in attributes: + self[attr] = None + + def get_SQL_cardinality(self): + cardinality = {} + for attr in list(self.columns): + card_query = pd.read_sql( + "SELECT Count(Distinct({})) FROM {}".format(attr, self.table_name), + self.SQLconnection, + ) + cardinality[attr] = list(card_query["count"])[0] + self.cardinality = cardinality + + def get_SQL_unique_values(self): + unique_vals = {} + for attr in list(self.columns): + unique_query = pd.read_sql( + "SELECT Distinct({}) FROM {}".format(attr, self.table_name), + self.SQLconnection, + ) + unique_vals[attr] = list(unique_query[attr]) + self.unique_values = unique_vals + + def compute_SQL_data_type(self): + data_type_lookup = {} + sql_dtypes = {} + self.get_SQL_cardinality() + if "." in self.table_name: + table_name = self.table_name[self.table_name.index(".") + 1 :] + else: + table_name = self.table_name + # get the data types of the attributes in the SQL table + for attr in list(self.columns): + datatype_query = "SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{}' AND COLUMN_NAME = '{}'".format( + table_name, attr + ) + datatype = list( + pd.read_sql(datatype_query, self.SQLconnection)["data_type"] + )[0] + sql_dtypes[attr] = datatype + + data_type = {"quantitative": [], "nominal": [], "temporal": []} + for attr in list(self.columns): + if str(attr).lower() in ["month", "year"]: + data_type_lookup[attr] = "temporal" + data_type["temporal"].append(attr) + elif sql_dtypes[attr] in [ + "character", + "character varying", + "boolean", + "uuid", + "text", + ]: + data_type_lookup[attr] = "nominal" + data_type["nominal"].append(attr) + elif sql_dtypes[attr] in [ + "integer", + "real", + "smallint", + "smallserial", + "serial", + ]: + if self.cardinality[attr] < 13: + data_type_lookup[attr] = "nominal" + data_type["nominal"].append(attr) + else: + data_type_lookup[attr] = "quantitative" + data_type["quantitative"].append(attr) + elif "time" in sql_dtypes[attr] or "date" in sql_dtypes[attr]: + data_type_lookup[attr] = "temporal" + data_type["temporal"].append(attr) + self.data_type_lookup = data_type_lookup + self.data_type = data_type + + def _append_rec(self, rec_infolist, recommendations: Dict): + if ( + recommendations["collection"] is not None + and len(recommendations["collection"]) > 0 + ): + rec_infolist.append(recommendations) + + def maintain_recs(self): + # `rec_df` is the dataframe to generate the recommendations on + # check to see if globally defined actions have been registered/removed + if lux.update_actions["flag"] == True: + self._recs_fresh = False + show_prev = False # flag indicating whether rec_df is showing previous df or current self + if self._prev is not None: + rec_df = self._prev + rec_df._message = Message() + rec_df.maintain_metadata() # the prev dataframe may not have been printed before + last_event = self.history._events[-1].name + rec_df._message.add( + f"Lux is visualizing the previous version of the dataframe before you applied {last_event}." 
+ ) + show_prev = True + else: + rec_df = self + rec_df._message = Message() + # Add warning message if there exist ID fields + id_fields_str = "" + if len(rec_df.data_type["id"]) > 0: + for id_field in rec_df.data_type["id"]: + id_fields_str += f"{id_field}, " + id_fields_str = id_fields_str[:-2] + rec_df._message.add( + f"{id_fields_str} is not visualized since it resembles an ID field." + ) + rec_df._prev = None # reset _prev + + if ( + not hasattr(rec_df, "_recs_fresh") or not rec_df._recs_fresh + ): # Check that recs has not yet been computed + rec_infolist = [] + from lux.action.custom import custom + from lux.action.custom import custom_actions + from lux.action.correlation import correlation + from lux.action.univariate import univariate + from lux.action.enhance import enhance + from lux.action.filter import filter + from lux.action.generalize import generalize + from lux.action.row_group import row_group + from lux.action.column_group import column_group + + if rec_df.pre_aggregated: + if rec_df.columns.name is not None: + rec_df._append_rec(rec_infolist, row_group(rec_df)) + rec_df._append_rec(rec_infolist, column_group(rec_df)) + else: + if rec_df.recommendation == {}: + # display conditions for default actions + no_vis = lambda ldf: (ldf.current_vis is None) or ( + ldf.current_vis is not None and len(ldf.current_vis) == 0 + ) + one_current_vis = ( + lambda ldf: ldf.current_vis is not None + and len(ldf.current_vis) == 1 + ) + multiple_current_vis = ( + lambda ldf: ldf.current_vis is not None + and len(ldf.current_vis) > 1 + ) + + # globally register default actions + lux.register_action("correlation", correlation, no_vis) + lux.register_action( + "distribution", univariate, no_vis, "quantitative" + ) + lux.register_action("occurrence", univariate, no_vis, "nominal") + lux.register_action("temporal", univariate, no_vis, "temporal") + + lux.register_action("enhance", enhance, one_current_vis) + lux.register_action("filter", filter, one_current_vis) + lux.register_action("generalize", generalize, one_current_vis) + + lux.register_action("custom", custom, multiple_current_vis) + + # generate vis from globally registered actions and append to dataframe + custom_action_collection = custom_actions(rec_df) + for rec in custom_action_collection: + rec_df._append_rec(rec_infolist, rec) + lux.update_actions["flag"] = False + + # Store _rec_info into a more user-friendly dictionary form + rec_df.recommendation = {} + for rec_info in rec_infolist: + action_type = rec_info["action"] + vlist = rec_info["collection"] + if rec_df._plot_config: + if rec_df.current_vis: + for vis in rec_df.current_vis: + vis._plot_config = rec_df.plot_config + for vis in vlist: + vis._plot_config = rec_df.plot_config + if len(vlist) > 0: + rec_df.recommendation[action_type] = vlist + rec_df._rec_info = rec_infolist + self._widget = rec_df.render_widget() + elif ( + show_prev + ): # re-render widget for the current dataframe if previous rec is not recomputed + self._widget = rec_df.render_widget() + self._recs_fresh = True + + ####################################################### + ############## LuxWidget Result Display ############### + ####################################################### + @property + def widget(self): + if self._widget: + return self._widget + + @property + def exported(self) -> Union[Dict[str, VisList], VisList]: + """ + Get selected visualizations as exported Vis List + + Notes + ----- + Convert the _selectedVisIdxs dictionary into a programmable VisList + Example _selectedVisIdxs : + 
{'Correlation': [0, 2], 'Occurrence': [1]} + indicating the 0th and 2nd vis from the `Correlation` tab is selected, and the 1st vis from the `Occurrence` tab is selected. + + Returns + ------- + Union[Dict[str,VisList], VisList] + When there are no exported vis, return empty list -> [] + When all the exported vis is from the same tab, return a VisList of selected visualizations. -> VisList(v1, v2...) + When the exported vis is from the different tabs, return a dictionary with the action name as key and selected visualizations in the VisList. -> {"Enhance": VisList(v1, v2...), "Filter": VisList(v5, v7...), ..} + """ + if not hasattr(self, "_widget"): + warnings.warn( + "\nNo widget attached to the dataframe." + "Please assign dataframe to an output variable.\n" + "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips", + stacklevel=2, + ) + return [] + exported_vis_lst = self._widget._selectedVisIdxs + exported_vis = [] + if exported_vis_lst == {}: + if self._saved_export: + return self._saved_export + warnings.warn( + "\nNo visualization selected to export.\n" + "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips", + stacklevel=2, + ) + return [] + if len(exported_vis_lst) == 1 and "currentVis" in exported_vis_lst: + return self.current_vis + elif len(exported_vis_lst) > 1: + exported_vis = {} + if "currentVis" in exported_vis_lst: + exported_vis["Current Vis"] = self.current_vis + for export_action in exported_vis_lst: + if export_action != "currentVis": + exported_vis[export_action] = VisList( + list( + map( + self.recommendation[export_action].__getitem__, + exported_vis_lst[export_action], + ) + ) + ) + return exported_vis + elif len(exported_vis_lst) == 1 and ("currentVis" not in exported_vis_lst): + export_action = list(exported_vis_lst.keys())[0] + exported_vis = VisList( + list( + map( + self.recommendation[export_action].__getitem__, + exported_vis_lst[export_action], + ) + ) + ) + self._saved_export = exported_vis + return exported_vis + else: + warnings.warn( + "\nNo visualization selected to export.\n" + "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips", + stacklevel=2, + ) + return [] + + def remove_deleted_recs(self, change): + for action in self._widget.deletedIndices: + deletedSoFar = 0 + for index in self._widget.deletedIndices[action]: + self.recommendation[action].remove_index(index - deletedSoFar) + deletedSoFar += 1 + + def set_intent_on_click(self, change): + from IPython.display import display, clear_output + from lux.processor.Compiler import Compiler + + intent_action = list(self._widget.selectedIntentIndex.keys())[0] + vis = self.recommendation[intent_action][ + self._widget.selectedIntentIndex[intent_action][0] + ] + self.set_intent_as_vis(vis) + + self.maintain_metadata() + self.current_vis = Compiler.compile_intent(self, self._intent) + self.maintain_recs() + + with self.output: + clear_output() + display(self._widget) + + self._widget.observe(self.remove_deleted_recs, names="deletedIndices") + self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex") + + def _repr_html_(self): + from IPython.display import display + from IPython.display import clear_output + import ipywidgets as widgets + + try: + if self._pandas_only: + display(self.display_pandas()) + self._pandas_only = False + else: + if self.index.nlevels >= 2 or self.columns.nlevels >= 2: + warnings.warn( + "\nLux does not currently support dataframes " + 
"with hierarchical indexes.\n" + "Please convert the dataframe into a flat " + "table via `pandas.DataFrame.reset_index`.\n", + stacklevel=2, + ) + display(self.display_pandas()) + return + + if len(self) <= 0: + warnings.warn( + "\nLux can not operate on an empty dataframe.\nPlease check your input again.\n", + stacklevel=2, + ) + display(self.display_pandas()) + return + if len(self.columns) <= 1: + warnings.warn( + "\nLux defaults to Pandas when there is only a single column.", + stacklevel=2, + ) + display(self.display_pandas()) + return + self.maintain_metadata() + + if self._intent != [] and ( + not hasattr(self, "_compiled") or not self._compiled + ): + from lux.processor.Compiler import Compiler + + self.current_vis = Compiler.compile_intent(self, self._intent) + + if lux.config.default_display == "lux": + self._toggle_pandas_display = False + else: + self._toggle_pandas_display = True + + # df_to_display.maintain_recs() # compute the recommendations (TODO: This can be rendered in another thread in the background to populate self._widget) + self.maintain_recs() + + # Observers(callback_function, listen_to_this_variable) + self._widget.observe(self.remove_deleted_recs, names="deletedIndices") + self._widget.observe( + self.set_intent_on_click, names="selectedIntentIndex" + ) + + if len(self.recommendation) > 0: + # box = widgets.Box(layout=widgets.Layout(display='inline')) + button = widgets.Button( + description="Toggle Pandas/Lux", + layout=widgets.Layout(width="140px", top="5px"), + ) + self.output = widgets.Output() + # box.children = [button,output] + # output.children = [button] + # display(box) + display(button, self.output) + + def on_button_clicked(b): + with self.output: + if b: + self._toggle_pandas_display = ( + not self._toggle_pandas_display + ) + clear_output() + if self._toggle_pandas_display: + display(self.display_pandas()) + else: + # b.layout.display = "none" + display(self._widget) + # b.layout.display = "inline-block" + + button.on_click(on_button_clicked) + on_button_clicked(None) + else: + warnings.warn( + "\nLux defaults to Pandas when there are no valid actions defined.", + stacklevel=2, + ) + display(self.display_pandas()) + + except (KeyboardInterrupt, SystemExit): + raise + except: + warnings.warn( + "\nUnexpected error in rendering Lux widget and recommendations. " + "Falling back to Pandas display.\n\n" + "Please report this issue on Github: https://github.com/lux-org/lux/issues ", + stacklevel=2, + ) + display(self.display_pandas()) + + def display_pandas(self): + return self.to_pandas() + + def render_widget(self, renderer: str = "altair", input_current_vis=""): + """ + Generate a LuxWidget based on the LuxDataFrame + + Structure of widgetJSON: + { + 'current_vis': {}, + 'recommendation': [ + { + 'action': 'Correlation', + 'description': "some description", + 'vspec': [ + {Vega-Lite spec for vis 1}, + {Vega-Lite spec for vis 2}, + ... + ] + }, + ... 
repeat for other actions + ] + } + Parameters + ---------- + renderer : str, optional + Choice of visualization rendering library, by default "altair" + input_current_vis : lux.LuxDataFrame, optional + User-specified current vis to override default Current Vis, by default + """ + check_import_lux_widget() + import luxwidget + + widgetJSON = self.to_JSON(self._rec_info, input_current_vis=input_current_vis) + return luxwidget.LuxWidget( + currentVis=widgetJSON["current_vis"], + recommendations=widgetJSON["recommendation"], + intent=LuxDataFrame.intent_to_string(self._intent), + message=self._message.to_html(), + ) + + @staticmethod + def intent_to_JSON(intent): + from lux.utils import utils + + filter_specs = utils.get_filter_specs(intent) + attrs_specs = utils.get_attrs_specs(intent) + + intent = {} + intent["attributes"] = [clause.attribute for clause in attrs_specs] + intent["filters"] = [clause.attribute for clause in filter_specs] + return intent + + @staticmethod + def intent_to_string(intent): + if intent: + return ", ".join([clause.to_string() for clause in intent]) + else: + return "" + + def to_JSON(self, rec_infolist, input_current_vis=""): + widget_spec = {} + if self.current_vis: + self.executor.execute(self.current_vis, self) + widget_spec["current_vis"] = LuxDataFrame.current_vis_to_JSON( + self.current_vis, input_current_vis + ) + else: + widget_spec["current_vis"] = {} + widget_spec["recommendation"] = [] + + # Recommended Collection + recCollection = LuxDataFrame.rec_to_JSON(rec_infolist) + widget_spec["recommendation"].extend(recCollection) + return widget_spec + + @staticmethod + def current_vis_to_JSON(vlist, input_current_vis=""): + current_vis_spec = {} + numVC = len(vlist) # number of visualizations in the vis list + if numVC == 1: + current_vis_spec = vlist[0].render_VSpec() + elif numVC > 1: + pass + return current_vis_spec + + @staticmethod + def rec_to_JSON(recs): + rec_lst = [] + import copy + + rec_copy = copy.deepcopy(recs) + for idx, rec in enumerate(rec_copy): + if len(rec["collection"]) > 0: + rec["vspec"] = [] + for vis in rec["collection"]: + chart = vis.render_VSpec() + rec["vspec"].append(chart) + rec_lst.append(rec) + # delete DataObjectCollection since not JSON serializable + del rec_lst[idx]["collection"] + return rec_lst + + # Overridden Pandas Functions + def head(self, n: int = 5): + self._prev = self + self._history.append_event("head", n=5) + return super(LuxDataFrame, self).head(n) + + def tail(self, n: int = 5): + self._prev = self + self._history.append_event("tail", n=5) + return super(LuxDataFrame, self).tail(n) + + def info(self, *args, **kwargs): + self._pandas_only = True + self._history.append_event("info", *args, **kwargs) + return super(LuxDataFrame, self).info(*args, **kwargs) + + def describe(self, *args, **kwargs): + self._pandas_only = True + self._history.append_event("describe", *args, **kwargs) + return super(LuxDataFrame, self).describe(*args, **kwargs) diff --git a/lux/core/series.py b/lux/core/series.py index ecfd9aa2..d987b4e5 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,24 +13,45 @@ # limitations under the License. 
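#
# A minimal usage sketch (hypothetical, not part of this patch): pandas only
# carries custom attributes across operations when they are declared in
# `_metadata`, and slicing/expanding stays within the Lux types because
# `_constructor`, `_constructor_sliced`, and `_constructor_expanddim` return
# the Lux subclasses. Assuming pandas has been overridden (which happens on
# import via `setOption(overridePandas=True)` in lux/core/__init__.py):
#
#   import lux
#   import pandas as pd
#
#   df = pd.read_csv("lux/data/car.csv")  # constructs a LuxDataFrame
#   s = df["Horsepower"]                  # _constructor_sliced yields a LuxSeries
#   back = s.to_frame()                   # _constructor_expanddim rebuilds a LuxDataFrame
#
# The column name here is illustrative; any column of the CSV behaves the same.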
import pandas as pd + + class LuxSeries(pd.Series): - _metadata = ['_intent','data_type_lookup','data_type', - 'data_model_lookup','data_model','unique_values','cardinality','_rec_info', '_pandas_only', - '_min_max','plot_config', '_current_vis','_widget', '_recommendation','_prev','_history', '_saved_export'] - def __init__(self,*args, **kw): - super(LuxSeries, self).__init__(*args, **kw) - - @property - def _constructor(self): - return LuxSeries - - @property - def _constructor_expanddim(self): - from lux.core.frame import LuxDataFrame - def f(*args, **kwargs): - df = LuxDataFrame(*args, **kwargs) - for attr in self._metadata: - df.__dict__[attr] = getattr(self, attr, None) - return df - f._get_axis_number = super(LuxSeries, self)._get_axis_number - return f \ No newline at end of file + _metadata = [ + "_intent", + "data_type_lookup", + "data_type", + "data_model_lookup", + "data_model", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + ] + + def __init__(self, *args, **kw): + super(LuxSeries, self).__init__(*args, **kw) + + @property + def _constructor(self): + return LuxSeries + + @property + def _constructor_expanddim(self): + from lux.core.frame import LuxDataFrame + + def f(*args, **kwargs): + df = LuxDataFrame(*args, **kwargs) + for attr in self._metadata: + df.__dict__[attr] = getattr(self, attr, None) + return df + + f._get_axis_number = super(LuxSeries, self)._get_axis_number + return f diff --git a/lux/executor/Executor.py b/lux/executor/Executor.py index f2216131..972f6fb6 100644 --- a/lux/executor/Executor.py +++ b/lux/executor/Executor.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,26 +14,31 @@ from lux.vis.VisList import VisList from lux.utils import utils + + class Executor: """ Abstract class for the execution engine that fetches data for a given vis on a LuxDataFrame - """ + """ + def __init__(self): self.name = "Executor" def __repr__(self): return f"" + @staticmethod - def execute(vis_collection:VisList, ldf): + def execute(vis_collection: VisList, ldf): return NotImplemented @staticmethod def execute_aggregate(vis, ldf): return NotImplemented + @staticmethod def execute_binning(vis, ldf): return NotImplemented - + @staticmethod def execute_filter(vis, ldf): return NotImplemented @@ -61,4 +66,4 @@ def reverseMapping(self, map): for valKey in map: for val in map[valKey]: reverse_map[val] = valKey - return reverse_map \ No newline at end of file + return reverse_map diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index a8d17d36..64fa2e54 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,33 +25,42 @@ class PandasExecutor(Executor): - ''' + """ Given a Vis objects with complete specifications, fetch and process data using Pandas dataframe operations. 
- ''' + """ + def __init__(self): self.name = "PandasExecutor" def __repr__(self): return f"" + @staticmethod - def execute_sampling(ldf:LuxDataFrame): + def execute_sampling(ldf: LuxDataFrame): # General Sampling for entire dataframe SAMPLE_START = 10000 SAMPLE_CAP = 30000 SAMPLE_FRAC = 0.75 if len(ldf) > SAMPLE_CAP: - if (ldf._sampled is None): # memoize unfiltered sample df - ldf._sampled = ldf.sample(n = SAMPLE_CAP , random_state = 1) - ldf._message.add_unique(f"Large dataframe detected: Lux is only visualizing a random sample capped at {SAMPLE_CAP} rows.", priority=99) + if ldf._sampled is None: # memoize unfiltered sample df + ldf._sampled = ldf.sample(n=SAMPLE_CAP, random_state=1) + ldf._message.add_unique( + f"Large dataframe detected: Lux is only visualizing a random sample capped at {SAMPLE_CAP} rows.", + priority=99, + ) elif len(ldf) > SAMPLE_START: - if (ldf._sampled is None): # memoize unfiltered sample df - ldf._sampled = ldf.sample(frac= SAMPLE_FRAC, random_state = 1) - ldf._message.add_unique(f"Large dataframe detected: Lux is only visualizing a random sample of {len(ldf._sampled)} rows.", priority=99) + if ldf._sampled is None: # memoize unfiltered sample df + ldf._sampled = ldf.sample(frac=SAMPLE_FRAC, random_state=1) + ldf._message.add_unique( + f"Large dataframe detected: Lux is only visualizing a random sample of {len(ldf._sampled)} rows.", + priority=99, + ) else: ldf._sampled = ldf + @staticmethod - def execute(vislist:VisList, ldf:LuxDataFrame): - ''' + def execute(vislist: VisList, ldf: LuxDataFrame): + """ Given a VisList, fetch the data required to render the vis. 1) Apply filters 2) Retrieve relevant attribute @@ -68,36 +77,39 @@ def execute(vislist:VisList, ldf:LuxDataFrame): Returns ------- None - ''' + """ PandasExecutor.execute_sampling(ldf) for vis in vislist: - vis._vis_data = ldf._sampled # The vis data starts off being original or sampled dataframe + vis._vis_data = ( + ldf._sampled + ) # The vis data starts off being original or sampled dataframe filter_executed = PandasExecutor.execute_filter(vis) # Select relevant data based on attribute information attributes = set([]) for clause in vis._inferred_intent: - if (clause.attribute): - if (clause.attribute!="Record"): + if clause.attribute: + if clause.attribute != "Record": attributes.add(clause.attribute) # TODO: Add some type of cap size on Nrows ? 
vis._vis_data = vis.data[list(attributes)] - if (vis.mark =="bar" or vis.mark =="line"): - PandasExecutor.execute_aggregate(vis,isFiltered = filter_executed) - elif (vis.mark =="histogram"): + if vis.mark == "bar" or vis.mark == "line": + PandasExecutor.execute_aggregate(vis, isFiltered=filter_executed) + elif vis.mark == "histogram": PandasExecutor.execute_binning(vis) - elif (vis.mark =="scatter"): + elif vis.mark == "scatter": HBIN_START = 5000 - if (len(ldf)>HBIN_START): + if len(ldf) > HBIN_START: vis._postbin = True - ldf._message.add_unique(f"Large scatterplots detected: Lux is automatically binning scatterplots to heatmaps.", priority=98) + ldf._message.add_unique( + f"Large scatterplots detected: Lux is automatically binning scatterplots to heatmaps.", + priority=98, + ) # vis._mark = "heatmap" - # PandasExecutor.execute_2D_binning(vis) # Lazy Evaluation (Early pruning based on interestingness) - - + # PandasExecutor.execute_2D_binning(vis) # Lazy Evaluation (Early pruning based on interestingness) @staticmethod - def execute_aggregate(vis: Vis,isFiltered = True): - ''' + def execute_aggregate(vis: Vis, isFiltered=True): + """ Aggregate data points on an axis for bar or line charts Parameters @@ -110,92 +122,128 @@ def execute_aggregate(vis: Vis,isFiltered = True): Returns ------- None - ''' + """ import numpy as np x_attr = vis.get_attr_by_channel("x")[0] y_attr = vis.get_attr_by_channel("y")[0] has_color = False - groupby_attr ="" - measure_attr ="" - if (x_attr.aggregation is None or y_attr.aggregation is None): + groupby_attr = "" + measure_attr = "" + if x_attr.aggregation is None or y_attr.aggregation is None: return - if (y_attr.aggregation!=""): + if y_attr.aggregation != "": groupby_attr = x_attr measure_attr = y_attr agg_func = y_attr.aggregation - if (x_attr.aggregation!=""): + if x_attr.aggregation != "": groupby_attr = y_attr measure_attr = x_attr agg_func = x_attr.aggregation - if (groupby_attr.attribute in vis.data.unique_values.keys()): + if groupby_attr.attribute in vis.data.unique_values.keys(): attr_unique_vals = vis.data.unique_values[groupby_attr.attribute] - #checks if color is specified in the Vis + # checks if color is specified in the Vis if len(vis.get_attr_by_channel("color")) == 1: color_attr = vis.get_attr_by_channel("color")[0] color_attr_vals = vis.data.unique_values[color_attr.attribute] color_cardinality = len(color_attr_vals) - #NOTE: might want to have a check somewhere to not use categorical variables with greater than some number of categories as a Color variable---------------- + # NOTE: might want to have a check somewhere to not use categorical variables with greater than some number of categories as a Color variable---------------- has_color = True else: color_cardinality = 1 - - if (measure_attr!=""): - if (measure_attr.attribute=="Record"): + + if measure_attr != "": + if measure_attr.attribute == "Record": vis._vis_data = vis.data.reset_index() - #if color is specified, need to group by groupby_attr and color_attr + # if color is specified, need to group by groupby_attr and color_attr if has_color: - vis._vis_data = vis.data.groupby([groupby_attr.attribute, color_attr.attribute]).count().reset_index() - vis._vis_data = vis.data.rename(columns={"index":"Record"}) - vis._vis_data = vis.data[[groupby_attr.attribute,color_attr.attribute,"Record"]] + vis._vis_data = ( + vis.data.groupby([groupby_attr.attribute, color_attr.attribute]) + .count() + .reset_index() + ) + vis._vis_data = vis.data.rename(columns={"index": "Record"}) + vis._vis_data 
= vis.data[ + [groupby_attr.attribute, color_attr.attribute, "Record"] + ] else: - vis._vis_data = vis.data.groupby(groupby_attr.attribute).count().reset_index() - vis._vis_data = vis.data.rename(columns={"index":"Record"}) - vis._vis_data = vis.data[[groupby_attr.attribute,"Record"]] + vis._vis_data = ( + vis.data.groupby(groupby_attr.attribute).count().reset_index() + ) + vis._vis_data = vis.data.rename(columns={"index": "Record"}) + vis._vis_data = vis.data[[groupby_attr.attribute, "Record"]] else: - #if color is specified, need to group by groupby_attr and color_attr + # if color is specified, need to group by groupby_attr and color_attr if has_color: - groupby_result = vis.data.groupby([groupby_attr.attribute, color_attr.attribute]) + groupby_result = vis.data.groupby( + [groupby_attr.attribute, color_attr.attribute] + ) else: groupby_result = vis.data.groupby(groupby_attr.attribute) groupby_result = groupby_result.agg(agg_func) intermediate = groupby_result.reset_index() vis._vis_data = intermediate.__finalize__(vis.data) result_vals = list(vis.data[groupby_attr.attribute]) - #create existing group by attribute combinations if color is specified - #this is needed to check what combinations of group_by_attr and color_attr values have a non-zero number of elements in them + # create existing group by attribute combinations if color is specified + # this is needed to check what combinations of group_by_attr and color_attr values have a non-zero number of elements in them if has_color: res_color_combi_vals = [] result_color_vals = list(vis.data[color_attr.attribute]) for i in range(0, len(result_vals)): res_color_combi_vals.append([result_vals[i], result_color_vals[i]]) # For filtered aggregation that have missing groupby-attribute values, set these aggregated value as 0, since no datapoints - if (isFiltered or has_color and attr_unique_vals): + if isFiltered or has_color and attr_unique_vals: N_unique_vals = len(attr_unique_vals) - if (len(result_vals) != N_unique_vals*color_cardinality): + if len(result_vals) != N_unique_vals * color_cardinality: columns = vis.data.columns if has_color: - df = pd.DataFrame({columns[0]: attr_unique_vals*color_cardinality, columns[1]: pd.Series(color_attr_vals).repeat(N_unique_vals)}) - vis._vis_data = vis.data.merge(df, on=[columns[0],columns[1]], how='right', suffixes=['', '_right']) + df = pd.DataFrame( + { + columns[0]: attr_unique_vals * color_cardinality, + columns[1]: pd.Series(color_attr_vals).repeat( + N_unique_vals + ), + } + ) + vis._vis_data = vis.data.merge( + df, + on=[columns[0], columns[1]], + how="right", + suffixes=["", "_right"], + ) for col in columns[2:]: - vis.data[col] = vis.data[col].fillna(0) #Triggers __setitem__ - assert len(list(vis.data[groupby_attr.attribute])) == N_unique_vals*len(color_attr_vals), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute, color_attr.attribute}`." - vis._vis_data = vis.data.iloc[:,:3] # Keep only the three relevant columns not the *_right columns resulting from merge + vis.data[col] = vis.data[col].fillna( + 0 + ) # Triggers __setitem__ + assert len( + list(vis.data[groupby_attr.attribute]) + ) == N_unique_vals * len( + color_attr_vals + ), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute, color_attr.attribute}`." 
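
The right-merge-and-fill pattern above is what keeps every category (or category-color combination) present in the aggregated result, even when a filter leaves a category empty. A standalone sketch of the non-color case, with hypothetical data:

```python
import pandas as pd

all_categories = ["USA", "Japan", "Europe"]
# Aggregated result after a filter removed every "Europe" row:
agg = pd.DataFrame({"Origin": ["USA", "Japan"], "Record": [10, 4]})

full = pd.DataFrame({"Origin": all_categories})
padded = agg.merge(full, on="Origin", how="right", suffixes=["", "_right"])
for col in padded.columns[1:]:
    padded[col] = padded[col].fillna(0)
print(padded)  # Europe reappears with Record == 0
```

The zero-filled rows are what make filtered bar charts comparable to the unfiltered chart: both always span the same set of categories.
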
+                vis._vis_data = vis.data.iloc[
+                    :, :3
+                ]  # Keep only the three relevant columns not the *_right columns resulting from merge
             else:
                 df = pd.DataFrame({columns[0]: attr_unique_vals})
-
-                vis._vis_data = vis.data.merge(df, on=columns[0], how='right', suffixes=['', '_right'])
+
+                vis._vis_data = vis.data.merge(
+                    df, on=columns[0], how="right", suffixes=["", "_right"]
+                )
                 for col in columns[1:]:
                     vis.data[col] = vis.data[col].fillna(0)
-                assert len(list(vis.data[groupby_attr.attribute])) == N_unique_vals, f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute}`."
-                vis._vis_data = vis.data.sort_values(by=groupby_attr.attribute, ascending=True)
+                assert (
+                    len(list(vis.data[groupby_attr.attribute])) == N_unique_vals
+                ), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute}`."
+                vis._vis_data = vis.data.sort_values(
+                    by=groupby_attr.attribute, ascending=True
+                )
             vis._vis_data = vis.data.reset_index()
             vis._vis_data = vis.data.drop(columns="index")

     @staticmethod
     def execute_binning(vis: Vis):
-        '''
+        """
         Binning of data points for generating histograms

         Parameters
         ----------
         vis: lux.Vis
             lux.Vis object that represents a visualization
         ldf : lux.core.frame
             LuxDataFrame with specified intent.

         Returns
         -------
         None
-        '''
+        """
         import numpy as np
-        bin_attribute = list(filter(lambda x: x.bin_size!=0,vis._inferred_intent))[0]
+
+        bin_attribute = list(filter(lambda x: x.bin_size != 0, vis._inferred_intent))[0]
         if not np.isnan(vis.data[bin_attribute.attribute]).all():
-            series = vis.data[bin_attribute.attribute].dropna() # np.histogram breaks if array contain NaN
-            #TODO:binning runs for name attribte. Name attribute has datatype quantitative which is wrong.
-            counts,bin_edges = np.histogram(series,bins=bin_attribute.bin_size)
-            #bin_edges of size N+1, so need to compute bin_center as the bin location
-            bin_center = np.mean(np.vstack([bin_edges[0:-1],bin_edges[1:]]), axis=0)
+            series = vis.data[
+                bin_attribute.attribute
+            ].dropna()  # np.histogram breaks if the array contains NaN
+            # TODO: binning runs for the name attribute, which is wrongly typed as quantitative.
+            counts, bin_edges = np.histogram(series, bins=bin_attribute.bin_size)
+            # bin_edges has size N+1, so compute bin_center as the bin location
+            bin_center = np.mean(np.vstack([bin_edges[0:-1], bin_edges[1:]]), axis=0)
            # TODO: Should vis.data be a LuxDataFrame or a Pandas DataFrame?
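
For reference, a minimal sketch of what the `np.histogram` call above produces, using hypothetical data:

```python
import numpy as np

series = np.array([1.0, 2.0, 2.5, 3.0, 4.0, 4.5, 5.0, 9.0])
counts, bin_edges = np.histogram(series, bins=4)
# bin_edges has N+1 entries ([1., 3., 5., 7., 9.]), so the midpoint of each
# consecutive pair gives the x-position of each histogram bar.
bin_center = np.mean(np.vstack([bin_edges[0:-1], bin_edges[1:]]), axis=0)
print(counts)  # [3 3 1 1]
print(bin_center)  # [2. 4. 6. 8.]
```
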
- vis._vis_data = pd.DataFrame(np.array([bin_center,counts]).T,columns=[bin_attribute.attribute, "Number of Records"]) + vis._vis_data = pd.DataFrame( + np.array([bin_center, counts]).T, + columns=[bin_attribute.attribute, "Number of Records"], + ) @staticmethod def execute_filter(vis: Vis): - assert vis.data is not None, "execute_filter assumes input vis.data is populated (if not, populate with LuxDataFrame values)" + assert ( + vis.data is not None + ), "execute_filter assumes input vis.data is populated (if not, populate with LuxDataFrame values)" filters = utils.get_filter_specs(vis._inferred_intent) - - if (filters): + + if filters: # TODO: Need to handle OR logic for filter in filters: - vis._vis_data = PandasExecutor.apply_filter(vis.data, filter.attribute, filter.filter_op, filter.value) + vis._vis_data = PandasExecutor.apply_filter( + vis.data, filter.attribute, filter.filter_op, filter.value + ) return True else: return False + @staticmethod - def apply_filter(df: pd.DataFrame, attribute:str, op: str, val: object) -> pd.DataFrame: + def apply_filter( + df: pd.DataFrame, attribute: str, op: str, val: object + ) -> pd.DataFrame: """ Helper function for applying filter to a dataframe - + Parameters ---------- df : pandas.DataFrame @@ -246,69 +307,84 @@ def apply_filter(df: pd.DataFrame, attribute:str, op: str, val: object) -> pd.Da op : str Filter operation, '=', '<', '>', '<=', '>=', '!=' val : object - Filter value - + Filter value + Returns ------- df: pandas.DataFrame Dataframe resulting from the filter operation - """ - if (op == '='): + """ + if op == "=": return df[df[attribute] == val] - elif (op == '<'): + elif op == "<": return df[df[attribute] < val] - elif (op == '>'): + elif op == ">": return df[df[attribute] > val] - elif (op == '<='): + elif op == "<=": return df[df[attribute] <= val] - elif (op == '>='): + elif op == ">=": return df[df[attribute] >= val] - elif (op == '!='): + elif op == "!=": return df[df[attribute] != val] return df + @staticmethod def execute_2D_binning(vis: Vis): - pd.reset_option('mode.chained_assignment') - with pd.option_context('mode.chained_assignment', None): + pd.reset_option("mode.chained_assignment") + with pd.option_context("mode.chained_assignment", None): x_attr = vis.get_attr_by_channel("x")[0] y_attr = vis.get_attr_by_channel("y")[0] - - vis._vis_data.loc[:,"xBin"] = pd.cut(vis._vis_data[x_attr.attribute], bins=40) - vis._vis_data.loc[:,"yBin"] = pd.cut(vis._vis_data[y_attr.attribute], bins=40) + + vis._vis_data.loc[:, "xBin"] = pd.cut( + vis._vis_data[x_attr.attribute], bins=40 + ) + vis._vis_data.loc[:, "yBin"] = pd.cut( + vis._vis_data[y_attr.attribute], bins=40 + ) color_attr = vis.get_attr_by_channel("color") - if (len(color_attr)>0): + if len(color_attr) > 0: color_attr = color_attr[0] - groups = vis._vis_data.groupby(['xBin','yBin'])[color_attr.attribute] - if (color_attr.data_type == "nominal"): + groups = vis._vis_data.groupby(["xBin", "yBin"])[color_attr.attribute] + if color_attr.data_type == "nominal": # Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. 
In cases where there is ties across categories, pick the first item (.iat[0]) - result = groups.agg([("count","count"), - (color_attr.attribute,lambda x: pd.Series.mode(x).iat[0]) - ]).reset_index() - elif (color_attr.data_type == "quantitative"): + result = groups.agg( + [ + ("count", "count"), + (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]), + ] + ).reset_index() + elif color_attr.data_type == "quantitative": # Compute the average of all values in the bin - result = groups.agg([("count","count"), - (color_attr.attribute,"mean") - ]).reset_index() + result = groups.agg( + [("count", "count"), (color_attr.attribute, "mean")] + ).reset_index() result = result.dropna() else: - groups = vis._vis_data.groupby(['xBin','yBin'])[x_attr.attribute] - result = groups.agg("count").reset_index(name=x_attr.attribute) # .agg in this line throws SettingWithCopyWarning - result = result.rename(columns={x_attr.attribute:"count"}) - result = result[result["count"]!=0] + groups = vis._vis_data.groupby(["xBin", "yBin"])[x_attr.attribute] + result = groups.agg("count").reset_index( + name=x_attr.attribute + ) # .agg in this line throws SettingWithCopyWarning + result = result.rename(columns={x_attr.attribute: "count"}) + result = result[result["count"] != 0] # convert type to facilitate weighted correlation interestingess calculation - result.loc[:,"xBinStart"] = result["xBin"].apply(lambda x: x.left).astype('float') - result.loc[:,"xBinEnd"] = result["xBin"].apply(lambda x: x.right) + result.loc[:, "xBinStart"] = ( + result["xBin"].apply(lambda x: x.left).astype("float") + ) + result.loc[:, "xBinEnd"] = result["xBin"].apply(lambda x: x.right) - result.loc[:,"yBinStart"] = result["yBin"].apply(lambda x: x.left).astype('float') - result.loc[:,"yBinEnd"] = result["yBin"].apply(lambda x: x.right) + result.loc[:, "yBinStart"] = ( + result["yBin"].apply(lambda x: x.left).astype("float") + ) + result.loc[:, "yBinEnd"] = result["yBin"].apply(lambda x: x.right) + + vis._vis_data = result.drop(columns=["xBin", "yBin"]) - vis._vis_data = result.drop(columns=["xBin","yBin"]) ####################################################### ############ Metadata: data type, model ############# ####################################################### - def compute_dataset_metadata(self, ldf:LuxDataFrame): + def compute_dataset_metadata(self, ldf: LuxDataFrame): ldf.data_type_lookup = {} ldf.data_type = {} self.compute_data_type(ldf) @@ -316,10 +392,10 @@ def compute_dataset_metadata(self, ldf:LuxDataFrame): ldf.data_model = {} self.compute_data_model(ldf) - def compute_data_type(self, ldf:LuxDataFrame): + def compute_data_type(self, ldf: LuxDataFrame): for attr in list(ldf.columns): - temporal_var_list = ["month", "year","day","date","time"] - if (isinstance(attr,pd._libs.tslibs.timestamps.Timestamp)): + temporal_var_list = ["month", "year", "day", "date", "time"] + if isinstance(attr, pd._libs.tslibs.timestamps.Timestamp): # If timestamp, make the dictionary keys the _repr_ (e.g., TimeStamp('2020-04-05 00.000')--> '2020-04-05') ldf.data_type_lookup[attr] = "temporal" # elif any(var in str(attr).lower() for var in temporal_var_list): @@ -327,71 +403,81 @@ def compute_data_type(self, ldf:LuxDataFrame): ldf.data_type_lookup[attr] = "temporal" elif pd.api.types.is_float_dtype(ldf.dtypes[attr]): ldf.data_type_lookup[attr] = "quantitative" - elif pd.api.types.is_integer_dtype(ldf.dtypes[attr]): + elif pd.api.types.is_integer_dtype(ldf.dtypes[attr]): # See if integer value is quantitative or nominal by checking if the ratio of 
cardinality/data size is less than 0.4 and if there are less than 10 unique values - if (ldf.pre_aggregated): - if (ldf.cardinality[attr]==len(ldf)): + if ldf.pre_aggregated: + if ldf.cardinality[attr] == len(ldf): ldf.data_type_lookup[attr] = "nominal" - if ldf.cardinality[attr]/len(ldf) < 0.4 and ldf.cardinality[attr]<20: + if ( + ldf.cardinality[attr] / len(ldf) < 0.4 + and ldf.cardinality[attr] < 20 + ): ldf.data_type_lookup[attr] = "nominal" else: ldf.data_type_lookup[attr] = "quantitative" - if check_if_id_like(ldf,attr): + if check_if_id_like(ldf, attr): ldf.data_type_lookup[attr] = "id" # Eliminate this clause because a single NaN value can cause the dtype to be object elif pd.api.types.is_string_dtype(ldf.dtypes[attr]): - if check_if_id_like(ldf,attr): + if check_if_id_like(ldf, attr): ldf.data_type_lookup[attr] = "id" else: ldf.data_type_lookup[attr] = "nominal" - elif is_datetime_series(ldf.dtypes[attr]): #check if attribute is any type of datetime dtype + elif is_datetime_series( + ldf.dtypes[attr] + ): # check if attribute is any type of datetime dtype ldf.data_type_lookup[attr] = "temporal" else: - ldf.data_type_lookup[attr] = "nominal" + ldf.data_type_lookup[attr] = "nominal" # for attr in list(df.dtypes[df.dtypes=="int64"].keys()): # if self.cardinality[attr]>50: - if (ldf.index.dtype !='int64' and ldf.index.name): + if ldf.index.dtype != "int64" and ldf.index.name: ldf.data_type_lookup[ldf.index.name] = "nominal" ldf.data_type = self.mapping(ldf.data_type_lookup) from pandas.api.types import is_datetime64_any_dtype as is_datetime + non_datetime_attrs = [] for attr in ldf.columns: - if ldf.data_type_lookup[attr] == 'temporal' and not is_datetime(ldf[attr]): + if ldf.data_type_lookup[attr] == "temporal" and not is_datetime(ldf[attr]): non_datetime_attrs.append(attr) if len(non_datetime_attrs) == 1: warnings.warn( - f"\nLux detects that the attribute '{non_datetime_attrs[0]}' may be temporal.\n" - "In order to display visualizations for this attribute accurately, temporal attributes should be converted to Pandas Datetime objects.\n\n" - "Please consider converting this attribute using the pd.to_datetime function and providing a 'format' parameter to specify datetime format of the attribute.\n" - "For example, you can convert the 'month' attribute in a dataset to Datetime type via the following command:\n\n\t df['month'] = pd.to_datetime(df['month'], format='%m')\n\n" - "See more at: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_datetime.html\n" - ,stacklevel=2) + f"\nLux detects that the attribute '{non_datetime_attrs[0]}' may be temporal.\n" + "In order to display visualizations for this attribute accurately, temporal attributes should be converted to Pandas Datetime objects.\n\n" + "Please consider converting this attribute using the pd.to_datetime function and providing a 'format' parameter to specify datetime format of the attribute.\n" + "For example, you can convert the 'month' attribute in a dataset to Datetime type via the following command:\n\n\t df['month'] = pd.to_datetime(df['month'], format='%m')\n\n" + "See more at: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_datetime.html\n", + stacklevel=2, + ) elif len(non_datetime_attrs) > 1: warnings.warn( - f"\nLux detects that attributes {non_datetime_attrs} may be temporal.\n" - "In order to display visualizations for these attributes accurately, temporal attributes should be converted to Pandas Datetime objects.\n\n" - "Please consider converting these attributes using the 
pd.to_datetime function and providing a 'format' parameter to specify datetime format of the attribute.\n" - "For example, you can convert the 'month' attribute in a dataset to Datetime type via the following command:\n\n\t df['month'] = pd.to_datetime(df['month'], format='%m')\n\n" - "See more at: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_datetime.html\n" - ,stacklevel=2) - def compute_data_model(self, ldf:LuxDataFrame): + f"\nLux detects that attributes {non_datetime_attrs} may be temporal.\n" + "In order to display visualizations for these attributes accurately, temporal attributes should be converted to Pandas Datetime objects.\n\n" + "Please consider converting these attributes using the pd.to_datetime function and providing a 'format' parameter to specify datetime format of the attribute.\n" + "For example, you can convert the 'month' attribute in a dataset to Datetime type via the following command:\n\n\t df['month'] = pd.to_datetime(df['month'], format='%m')\n\n" + "See more at: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_datetime.html\n", + stacklevel=2, + ) + + def compute_data_model(self, ldf: LuxDataFrame): ldf.data_model = { "measure": ldf.data_type["quantitative"], - "dimension": ldf.data_type["nominal"] + ldf.data_type["temporal"] + ldf.data_type["id"] + "dimension": ldf.data_type["nominal"] + + ldf.data_type["temporal"] + + ldf.data_type["id"], } ldf.data_model_lookup = self.reverseMapping(ldf.data_model) - - def compute_stats(self, ldf:LuxDataFrame): + def compute_stats(self, ldf: LuxDataFrame): # precompute statistics ldf.unique_values = {} ldf._min_max = {} ldf.cardinality = {} for attribute in ldf.columns: - - if (isinstance(attribute,pd._libs.tslibs.timestamps.Timestamp)): + + if isinstance(attribute, pd._libs.tslibs.timestamps.Timestamp): # If timestamp, make the dictionary keys the _repr_ (e.g., TimeStamp('2020-04-05 00.000')--> '2020-04-05') attribute_repr = str(attribute._date_repr) else: @@ -399,19 +485,22 @@ def compute_stats(self, ldf:LuxDataFrame): ldf.unique_values[attribute_repr] = list(ldf[attribute_repr].unique()) ldf.cardinality[attribute_repr] = len(ldf.unique_values[attribute_repr]) - + # commenting this optimization out to make sure I can filter by cardinality when showing recommended vis # if ldf.dtypes[attribute] != "float64":# and not pd.api.types.is_datetime64_ns_dtype(self.dtypes[attribute]): # ldf.unique_values[attribute_repr] = list(ldf[attribute].unique()) # ldf.cardinality[attribute_repr] = len(ldf.unique_values[attribute]) - # else: + # else: # ldf.cardinality[attribute_repr] = 999 # special value for non-numeric attribute - + if ldf.dtypes[attribute] == "float64" or ldf.dtypes[attribute] == "int64": - ldf._min_max[attribute_repr] = (ldf[attribute].min(), ldf[attribute].max()) + ldf._min_max[attribute_repr] = ( + ldf[attribute].min(), + ldf[attribute].max(), + ) - if (ldf.index.dtype !='int64'): + if ldf.index.dtype != "int64": index_column_name = ldf.index.name ldf.unique_values[index_column_name] = list(ldf.index) - ldf.cardinality[index_column_name] = len(ldf.index) \ No newline at end of file + ldf.cardinality[index_column_name] = len(ldf.index) diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py index 65500d55..2cca392d 100644 --- a/lux/executor/SQLExecutor.py +++ b/lux/executor/SQLExecutor.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. 
-# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,10 +20,12 @@ from lux.utils import utils import math + class SQLExecutor(Executor): """ Given a Vis objects with complete specifications, fetch and process data using SQL operations. """ + def __init__(self): self.name = "Executor" self.selection = [] @@ -34,150 +36,225 @@ def __repr__(self): return f"" @staticmethod - def execute(vislist:VisList, ldf: LuxDataFrame): + def execute(vislist: VisList, ldf: LuxDataFrame): import pandas as pd - ''' + + """ Given a VisList, fetch the data required to render the vis 1) Apply filters 2) Retreive relevant attribute 3) return a DataFrame with relevant results - ''' + """ for vis in vislist: # Select relevant data based on attribute information attributes = set([]) for clause in vis._inferred_intent: - if (clause.attribute): - if (clause.attribute=="Record"): + if clause.attribute: + if clause.attribute == "Record": attributes.add(clause.attribute) - #else: + # else: attributes.add(clause.attribute) if vis.mark not in ["bar", "line", "histogram"]: where_clause, filterVars = SQLExecutor.execute_filter(vis) required_variables = attributes | set(filterVars) required_variables = ",".join(required_variables) - row_count = list(pd.read_sql("SELECT COUNT(*) FROM {} {}".format(ldf.table_name, where_clause), ldf.SQLconnection)['count'])[0] + row_count = list( + pd.read_sql( + "SELECT COUNT(*) FROM {} {}".format( + ldf.table_name, where_clause + ), + ldf.SQLconnection, + )["count"] + )[0] if row_count > 10000: - query = "SELECT {} FROM {} {} ORDER BY random() LIMIT 10000".format(required_variables, ldf.table_name, where_clause) + query = "SELECT {} FROM {} {} ORDER BY random() LIMIT 10000".format( + required_variables, ldf.table_name, where_clause + ) else: - query = "SELECT {} FROM {} {}".format(required_variables, ldf.table_name, where_clause) + query = "SELECT {} FROM {} {}".format( + required_variables, ldf.table_name, where_clause + ) data = pd.read_sql(query, ldf.SQLconnection) vis._vis_data = utils.pandas_to_lux(data) - if (vis.mark =="bar" or vis.mark =="line"): + if vis.mark == "bar" or vis.mark == "line": SQLExecutor.execute_aggregate(vis, ldf) - elif (vis.mark =="histogram"): + elif vis.mark == "histogram": SQLExecutor.execute_binning(vis, ldf) @staticmethod - def execute_aggregate(vis:Vis, ldf:LuxDataFrame): + def execute_aggregate(vis: Vis, ldf: LuxDataFrame): import pandas as pd + x_attr = vis.get_attr_by_channel("x")[0] y_attr = vis.get_attr_by_channel("y")[0] - groupby_attr ="" - measure_attr ="" - if (y_attr.aggregation!=""): + groupby_attr = "" + measure_attr = "" + if y_attr.aggregation != "": groupby_attr = x_attr measure_attr = y_attr agg_func = y_attr.aggregation - if (x_attr.aggregation!=""): + if x_attr.aggregation != "": groupby_attr = y_attr measure_attr = x_attr agg_func = x_attr.aggregation - - if (measure_attr!=""): - #barchart case, need count data for each group - if (measure_attr.attribute=="Record"): + + if measure_attr != "": + # barchart case, need count data for each group + if measure_attr.attribute == "Record": where_clause, filterVars = SQLExecutor.execute_filter(vis) - count_query = "SELECT {}, COUNT({}) FROM {} {} GROUP BY {}".format(groupby_attr.attribute, groupby_attr.attribute, ldf.table_name, where_clause, groupby_attr.attribute) + count_query = "SELECT {}, COUNT({}) FROM {} {} GROUP BY {}".format( + groupby_attr.attribute, + 
groupby_attr.attribute,
+                    ldf.table_name,
+                    where_clause,
+                    groupby_attr.attribute,
+                )
                 vis._vis_data = pd.read_sql(count_query, ldf.SQLconnection)
-                vis._vis_data = vis.data.rename(columns={"count":"Record"})
+                vis._vis_data = vis.data.rename(columns={"count": "Record"})
                 vis._vis_data = utils.pandas_to_lux(vis.data)
             else:
                 where_clause, filterVars = SQLExecutor.execute_filter(vis)
                 if agg_func == "mean":
-                    mean_query = "SELECT {}, AVG({}) as {} FROM {} {} GROUP BY {}".format(groupby_attr.attribute, measure_attr.attribute, measure_attr.attribute, ldf.table_name, where_clause, groupby_attr.attribute)
+                    mean_query = (
+                        "SELECT {}, AVG({}) as {} FROM {} {} GROUP BY {}".format(
+                            groupby_attr.attribute,
+                            measure_attr.attribute,
+                            measure_attr.attribute,
+                            ldf.table_name,
+                            where_clause,
+                            groupby_attr.attribute,
+                        )
+                    )
                     vis._vis_data = pd.read_sql(mean_query, ldf.SQLconnection)
                     vis._vis_data = utils.pandas_to_lux(vis.data)
                 if agg_func == "sum":
-                    mean_query = "SELECT {}, SUM({}) as {} FROM {} {} GROUP BY {}".format(groupby_attr.attribute, measure_attr.attribute, measure_attr.attribute, ldf.table_name, where_clause, groupby_attr.attribute)
+                    mean_query = (
+                        "SELECT {}, SUM({}) as {} FROM {} {} GROUP BY {}".format(
+                            groupby_attr.attribute,
+                            measure_attr.attribute,
+                            measure_attr.attribute,
+                            ldf.table_name,
+                            where_clause,
+                            groupby_attr.attribute,
+                        )
+                    )
                     vis._vis_data = pd.read_sql(mean_query, ldf.SQLconnection)
                     vis._vis_data = utils.pandas_to_lux(vis.data)
                 if agg_func == "max":
-                    mean_query = "SELECT {}, MAX({}) as {} FROM {} {} GROUP BY {}".format(groupby_attr.attribute, measure_attr.attribute, measure_attr.attribute, ldf.table_name, where_clause, groupby_attr.attribute)
+                    mean_query = (
+                        "SELECT {}, MAX({}) as {} FROM {} {} GROUP BY {}".format(
+                            groupby_attr.attribute,
+                            measure_attr.attribute,
+                            measure_attr.attribute,
+                            ldf.table_name,
+                            where_clause,
+                            groupby_attr.attribute,
+                        )
+                    )
                     vis._vis_data = pd.read_sql(mean_query, ldf.SQLconnection)
                     vis._vis_data = utils.pandas_to_lux(vis.data)
-            #pad empty categories with 0 counts after filter is applied
+            # pad empty categories with 0 counts after filter is applied
             all_attr_vals = ldf.unique_values[groupby_attr.attribute]
             result_vals = list(vis.data[groupby_attr.attribute])
-            if (len(result_vals) != len(all_attr_vals)):
+            if len(result_vals) != len(all_attr_vals):
                 # For filtered aggregation that have missing groupby-attribute values, set these aggregated value as 0, since no datapoints
                 for vals in all_attr_vals:
-                    if (vals not in result_vals):
-                        vis.data.loc[len(vis.data)] = [vals]+[0]*(len(vis.data.columns)-1)
+                    if vals not in result_vals:
+                        vis.data.loc[len(vis.data)] = [vals] + [0] * (
+                            len(vis.data.columns) - 1
+                        )
+
     @staticmethod
-    def execute_binning(vis:Vis, ldf:LuxDataFrame):
+    def execute_binning(vis: Vis, ldf: LuxDataFrame):
         import numpy as np
         import pandas as pd
-        bin_attribute = list(filter(lambda x: x.bin_size!=0,vis._inferred_intent))[0]
-        if not math.isnan(vis.data.min_max[bin_attribute.attribute][0]) and math.isnan(vis.data.min_max[bin_attribute.attribute][1]):
+
+        bin_attribute = list(filter(lambda x: x.bin_size != 0, vis._inferred_intent))[0]
+        # only bin when both the min and the max of the attribute are well-defined (not NaN)
+        if not math.isnan(vis.data.min_max[bin_attribute.attribute][0]) and not math.isnan(
+            vis.data.min_max[bin_attribute.attribute][1]
+        ):
             num_bins = bin_attribute.bin_size
             attr_min = min(ldf.unique_values[bin_attribute.attribute])
             attr_max = max(ldf.unique_values[bin_attribute.attribute])
             attr_type = type(ldf.unique_values[bin_attribute.attribute][0])
-            #need to calculate the bin edges before querying for the relevant data
-            bin_width = (attr_max-attr_min)/num_bins
+            # need to calculate the bin edges before querying for the relevant data
+            bin_width = (attr_max - attr_min) / num_bins
             upper_edges = []
             for e in range(1, num_bins):
-                curr_edge = attr_min + e*bin_width
+                curr_edge = attr_min + e * bin_width
                 if attr_type == int:
                     upper_edges.append(str(math.ceil(curr_edge)))
                 else:
                     upper_edges.append(str(curr_edge))
             upper_edges = ",".join(upper_edges)
             vis_filter, filter_vars = SQLExecutor.execute_filter(vis)
-            bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({}, '{}') FROM {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(bin_attribute.attribute, '{'+upper_edges+'}', ldf.table_name)
+            bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({}, '{}') FROM {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
+                bin_attribute.attribute, "{" + upper_edges + "}", ldf.table_name
+            )
             bin_count_data = pd.read_sql(bin_count_query, ldf.SQLconnection)
-            #counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size)
-            #binEdges of size N+1, so need to compute binCenter as the bin location
+            # counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size)
+            # binEdges of size N+1, so need to compute binCenter as the bin location
             upper_edges = [float(i) for i in upper_edges.split(",")]
             if attr_type == int:
-                bin_centers = np.array([math.ceil((attr_min+attr_min+bin_width)/2)])
+                bin_centers = np.array(
+                    [math.ceil((attr_min + attr_min + bin_width) / 2)]
+                )
             else:
-                bin_centers = np.array([(attr_min+attr_min+bin_width)/2])
-            bin_centers = np.append(bin_centers, np.mean(np.vstack([upper_edges[0:-1],upper_edges[1:]]), axis=0))
+                bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
+            bin_centers = np.append(
+                bin_centers,
+                np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
+            )
             if attr_type == int:
-                bin_centers = np.append(bin_centers, math.ceil((upper_edges[len(upper_edges)-1]+attr_max)/2))
+                bin_centers = np.append(
+                    bin_centers,
+                    math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
+                )
             else:
-                bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges)-1]+attr_max)/2)
+                bin_centers = np.append(
+                    bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2
+                )

             if len(bin_centers) > len(bin_count_data):
-                bucket_lables = bin_count_data['width_bucket'].unique()
-                for i in range(0,len(bin_centers)):
+                bucket_labels = bin_count_data["width_bucket"].unique()
+                for i in range(0, len(bin_centers)):
                     if i not in bucket_labels:
-                        bin_count_data = bin_count_data.append(pd.DataFrame([[i,0]], columns = bin_count_data.columns))
-            vis._vis_data = pd.DataFrame(np.array([bin_centers,list(bin_count_data['count'])]).T,columns=[bin_attribute.attribute, "Number of Records"])
+                        bin_count_data = bin_count_data.append(
+                            pd.DataFrame([[i, 0]], columns=bin_count_data.columns)
+                        )
+            vis._vis_data = pd.DataFrame(
+                np.array([bin_centers, list(bin_count_data["count"])]).T,
+                columns=[bin_attribute.attribute, "Number of Records"],
+            )
             vis._vis_data = utils.pandas_to_lux(vis.data)
-
+
     @staticmethod
-    #takes in a vis and returns an appropriate SQL WHERE clause that based on the filters specified in the vis's _inferred_intent
-    def execute_filter(vis:Vis):
+    # takes in a vis and returns an appropriate SQL WHERE clause based on the filters specified in the vis's _inferred_intent
+    def execute_filter(vis: Vis):
        where_clause 
= [] filters = utils.get_filter_specs(vis._inferred_intent) filter_vars = [] - if (filters): - for f in range(0,len(filters)): + if filters: + for f in range(0, len(filters)): if f == 0: where_clause.append("WHERE") else: where_clause.append("AND") - where_clause.extend([str(filters[f].attribute), str(filters[f].filter_op), "'" + str(filters[f].value) + "'"]) + where_clause.extend( + [ + str(filters[f].attribute), + str(filters[f].filter_op), + "'" + str(filters[f].value) + "'", + ] + ) if filters[f].attribute not in filter_vars: filter_vars.append(filters[f].attribute) if where_clause == []: - return("", []) + return ("", []) else: where_clause = " ".join(where_clause) - return(where_clause, filter_vars) \ No newline at end of file + return (where_clause, filter_vars) diff --git a/lux/executor/__init__.py b/lux/executor/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/executor/__init__.py +++ b/lux/executor/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/lux/history/__init__.py b/lux/history/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/history/__init__.py +++ b/lux/history/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/lux/history/event.py b/lux/history/event.py index 9313b8ae..48853aef 100644 --- a/lux/history/event.py +++ b/lux/history/event.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -class Event(): - """ - Event represents a single operation applied to the dataframe, with input arguments of operation recorded - """ - def __init__(self,name,*args,**kwargs): - self.name = name - self.args = args - self.kwargs = kwargs - def __repr__(self): - if (self.args==() and self.kwargs=={}): - return f"" - else: - return f"" \ No newline at end of file + +class Event: + """ + Event represents a single operation applied to the dataframe, with input arguments of operation recorded + """ + + def __init__(self, name, *args, **kwargs): + self.name = name + self.args = args + self.kwargs = kwargs + + def __repr__(self): + if self.args == () and self.kwargs == {}: + return f"" + else: + return f"" diff --git a/lux/history/history.py b/lux/history/history.py index 4e78cbe5..602d0d11 100644 --- a/lux/history/history.py +++ b/lux/history/history.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -15,25 +15,32 @@ from __future__ import annotations from typing import List, Union, Callable, Dict from lux.history.event import Event -class History(): - """ - History maintains a list of past Pandas operations performed on the dataframe - Currently only supports custom overridden functions (head, tail, info, describe) - """ - def __init__(self): - self._events=[] - - def __getitem__(self, key): - return self._events[key] - def __setitem__(self, key, value): - self._events[key] = value - def __len__(self): - return len(self._events) - def __repr__(self): - event_repr=[] - for event in self._events: - event_repr.append(event.__repr__()) - return "["+'\n'.join(event_repr)+"]" - def append_event(self,name,*args,**kwargs): - event = Event(name,*args,**kwargs) - self._events.append(event) \ No newline at end of file + + +class History: + """ + History maintains a list of past Pandas operations performed on the dataframe + Currently only supports custom overridden functions (head, tail, info, describe) + """ + + def __init__(self): + self._events = [] + + def __getitem__(self, key): + return self._events[key] + + def __setitem__(self, key, value): + self._events[key] = value + + def __len__(self): + return len(self._events) + + def __repr__(self): + event_repr = [] + for event in self._events: + event_repr.append(event.__repr__()) + return "[" + "\n".join(event_repr) + "]" + + def append_event(self, name, *args, **kwargs): + event = Event(name, *args, **kwargs) + self._events.append(event) diff --git a/lux/interestingness/__init__.py b/lux/interestingness/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/interestingness/__init__.py +++ b/lux/interestingness/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py index cc8c1927..9d175583 100644 --- a/lux/interestingness/interestingness.py +++ b/lux/interestingness/interestingness.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,273 +21,324 @@ import numpy as np from pandas.api.types import is_datetime64_any_dtype as is_datetime from scipy.spatial.distance import euclidean -def interestingness(vis:Vis ,ldf:LuxDataFrame) -> int: - """ - Compute the interestingness score of the vis. - The interestingness metric is dependent on the vis type. - - Parameters - ---------- - vis : Vis - ldf : LuxDataFrame - - Returns - ------- - int - Interestingness Score - """ - - - if vis.data is None or len(vis.data)==0: - return -1 - # raise Exception("Vis.data needs to be populated before interestingness can be computed. 
Run Executor.execute(vis,ldf).") - - n_dim = 0 - n_msr = 0 - - filter_specs = utils.get_filter_specs(vis._inferred_intent) - vis_attrs_specs = utils.get_attrs_specs(vis._inferred_intent) - - record_attrs = list(filter(lambda x: x.attribute=="Record" and x.data_model=="measure", vis_attrs_specs)) - n_record = len(record_attrs) - for clause in vis_attrs_specs: - if (clause.attribute!="Record"): - if (clause.data_model == 'dimension'): - n_dim += 1 - if (clause.data_model == 'measure'): - n_msr += 1 - n_filter = len(filter_specs) - attr_specs = [clause for clause in vis_attrs_specs if clause.attribute != "Record"] - dimension_lst = vis.get_attr_by_data_model("dimension") - measure_lst = vis.get_attr_by_data_model("measure") - v_size = len(vis.data) - # Line/Bar Chart - #print("r:", n_record, "m:", n_msr, "d:",n_dim) - if (n_dim == 1 and (n_msr==0 or n_msr==1)): - if (v_size<2): return -1 - if (n_filter == 0): - return unevenness(vis, ldf, measure_lst, dimension_lst) - elif(n_filter==1): - return deviation_from_overall(vis, ldf, filter_specs, measure_lst[0].attribute) - # Histogram - elif (n_dim == 0 and n_msr == 1): - if (v_size<2): return -1 - if (n_filter == 0 and "Number of Records" in vis.data): - if "Number of Records" in vis.data: - v = vis.data["Number of Records"] - return skewness(v) - elif (n_filter == 1 and "Number of Records" in vis.data): - return deviation_from_overall(vis, ldf, filter_specs, "Number of Records") - return -1 - # Scatter Plot - elif (n_dim == 0 and n_msr == 2): - if (v_size<10): return -1 - if (vis.mark=="heatmap"): - return weighted_correlation(vis.data["xBinStart"],vis.data["yBinStart"],vis.data["count"]) - if (n_filter==1): - v_filter_size = get_filtered_size(filter_specs, vis.data) - sig = v_filter_size/v_size - else: - sig = 1 - return sig * monotonicity(vis,attr_specs) - # Scatterplot colored by Dimension - elif (n_dim == 1 and n_msr == 2): - if (v_size<10): return -1 - color_attr = vis.get_attr_by_channel("color")[0].attribute - - C = ldf.cardinality[color_attr] - if (C<40): - return 1/C - else: - return -1 - # Scatterplot colored by dimension - elif (n_dim== 1 and n_msr == 2): - return 0.2 - # Scatterplot colored by measure - elif (n_msr == 3): - return 0.1 - # colored line and barchart cases - elif (vis.mark == "line" and n_dim == 2): - return 0.15 - #for colored bar chart, scoring based on Chi-square test for independence score. 
- #gives higher scores to colored bar charts with fewer total categories as these charts are easier to read and thus more useful for users - elif (vis.mark == "bar" and n_dim == 2): - from scipy.stats import chi2_contingency - measure_column = vis.get_attr_by_data_model("measure")[0].attribute - dimension_columns = vis.get_attr_by_data_model("dimension") - - groupby_column = dimension_columns[0].attribute - color_column = dimension_columns[1].attribute - - contingency_table = [] - groupby_cardinality = ldf.cardinality[groupby_column] - groupby_unique_vals = ldf.unique_values[groupby_column] - for c in range(0, groupby_cardinality): - contingency_table.append(vis.data[vis.data[groupby_column] == groupby_unique_vals[c]][measure_column]) - score = 0.12 - #ValueError results if an entire column of the contingency table is 0, can happen if an applied filter results in - #a category having no counts - - try: - color_cardinality = ldf.cardinality[color_column] - #scale down score based on number of categories - chi2_score = chi2_contingency(contingency_table)[0]*0.9**(color_cardinality+groupby_cardinality) - score = min(0.10, chi2_score) - except ValueError: - pass - return(score) - # Default - else: - return -1 + + +def interestingness(vis: Vis, ldf: LuxDataFrame) -> int: + """ + Compute the interestingness score of the vis. + The interestingness metric is dependent on the vis type. + + Parameters + ---------- + vis : Vis + ldf : LuxDataFrame + + Returns + ------- + int + Interestingness Score + """ + + if vis.data is None or len(vis.data) == 0: + return -1 + # raise Exception("Vis.data needs to be populated before interestingness can be computed. Run Executor.execute(vis,ldf).") + + n_dim = 0 + n_msr = 0 + + filter_specs = utils.get_filter_specs(vis._inferred_intent) + vis_attrs_specs = utils.get_attrs_specs(vis._inferred_intent) + + record_attrs = list( + filter( + lambda x: x.attribute == "Record" and x.data_model == "measure", + vis_attrs_specs, + ) + ) + n_record = len(record_attrs) + for clause in vis_attrs_specs: + if clause.attribute != "Record": + if clause.data_model == "dimension": + n_dim += 1 + if clause.data_model == "measure": + n_msr += 1 + n_filter = len(filter_specs) + attr_specs = [clause for clause in vis_attrs_specs if clause.attribute != "Record"] + dimension_lst = vis.get_attr_by_data_model("dimension") + measure_lst = vis.get_attr_by_data_model("measure") + v_size = len(vis.data) + # Line/Bar Chart + # print("r:", n_record, "m:", n_msr, "d:",n_dim) + if n_dim == 1 and (n_msr == 0 or n_msr == 1): + if v_size < 2: + return -1 + if n_filter == 0: + return unevenness(vis, ldf, measure_lst, dimension_lst) + elif n_filter == 1: + return deviation_from_overall( + vis, ldf, filter_specs, measure_lst[0].attribute + ) + # Histogram + elif n_dim == 0 and n_msr == 1: + if v_size < 2: + return -1 + if n_filter == 0 and "Number of Records" in vis.data: + if "Number of Records" in vis.data: + v = vis.data["Number of Records"] + return skewness(v) + elif n_filter == 1 and "Number of Records" in vis.data: + return deviation_from_overall(vis, ldf, filter_specs, "Number of Records") + return -1 + # Scatter Plot + elif n_dim == 0 and n_msr == 2: + if v_size < 10: + return -1 + if vis.mark == "heatmap": + return weighted_correlation( + vis.data["xBinStart"], vis.data["yBinStart"], vis.data["count"] + ) + if n_filter == 1: + v_filter_size = get_filtered_size(filter_specs, vis.data) + sig = v_filter_size / v_size + else: + sig = 1 + return sig * monotonicity(vis, attr_specs) + # Scatterplot 
colored by Dimension + elif n_dim == 1 and n_msr == 2: + if v_size < 10: + return -1 + color_attr = vis.get_attr_by_channel("color")[0].attribute + + C = ldf.cardinality[color_attr] + if C < 40: + return 1 / C + else: + return -1 + # Scatterplot colored by dimension + elif n_dim == 1 and n_msr == 2: + return 0.2 + # Scatterplot colored by measure + elif n_msr == 3: + return 0.1 + # colored line and barchart cases + elif vis.mark == "line" and n_dim == 2: + return 0.15 + # for colored bar chart, scoring based on Chi-square test for independence score. + # gives higher scores to colored bar charts with fewer total categories as these charts are easier to read and thus more useful for users + elif vis.mark == "bar" and n_dim == 2: + from scipy.stats import chi2_contingency + + measure_column = vis.get_attr_by_data_model("measure")[0].attribute + dimension_columns = vis.get_attr_by_data_model("dimension") + + groupby_column = dimension_columns[0].attribute + color_column = dimension_columns[1].attribute + + contingency_table = [] + groupby_cardinality = ldf.cardinality[groupby_column] + groupby_unique_vals = ldf.unique_values[groupby_column] + for c in range(0, groupby_cardinality): + contingency_table.append( + vis.data[vis.data[groupby_column] == groupby_unique_vals[c]][ + measure_column + ] + ) + score = 0.12 + # ValueError results if an entire column of the contingency table is 0, can happen if an applied filter results in + # a category having no counts + + try: + color_cardinality = ldf.cardinality[color_column] + # scale down score based on number of categories + chi2_score = chi2_contingency(contingency_table)[0] * 0.9 ** ( + color_cardinality + groupby_cardinality + ) + score = min(0.10, chi2_score) + except ValueError: + pass + return score + # Default + else: + return -1 + + def get_filtered_size(filter_specs, ldf): - filter_intents = filter_specs[0] - result = PandasExecutor.apply_filter(ldf, filter_intents.attribute, filter_intents.filter_op, filter_intents.value) - return len(result) + filter_intents = filter_specs[0] + result = PandasExecutor.apply_filter( + ldf, filter_intents.attribute, filter_intents.filter_op, filter_intents.value + ) + return len(result) + + def skewness(v): - from scipy.stats import skew - return skew(v) + from scipy.stats import skew + + return skew(v) + def weighted_avg(x, w): - return np.average(x,weights=w) + return np.average(x, weights=w) + def weighted_cov(x, y, w): return np.sum(w * (x - weighted_avg(x, w)) * (y - weighted_avg(y, w))) / np.sum(w) + def weighted_correlation(x, y, w): # Based on https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#Weighted_correlation_coefficient - return weighted_cov(x, y, w) / np.sqrt(weighted_cov(x, x, w) * weighted_cov(y, y, w)) - -def deviation_from_overall(vis:Vis, ldf:LuxDataFrame, filter_specs:list, msr_attribute:str) -> int: - """ - Difference in bar chart/histogram shape from overall chart - Note: this function assumes that the filtered vis.data is operating on the same range as the unfiltered vis.data. 
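
The `weighted_correlation` helper above scores binned heatmaps by weighting each 2D bin by its point count. A small worked sketch with hypothetical bin data (the helper functions mirror the ones defined above):

```python
import numpy as np


def weighted_avg(x, w):
    return np.average(x, weights=w)


def weighted_cov(x, y, w):
    return np.sum(w * (x - weighted_avg(x, w)) * (y - weighted_avg(y, w))) / np.sum(w)


def weighted_correlation(x, y, w):
    # Weighted Pearson correlation: weighted covariance normalized by weighted stdevs
    return weighted_cov(x, y, w) / np.sqrt(weighted_cov(x, x, w) * weighted_cov(y, y, w))


x = np.array([0.0, 1.0, 2.0, 3.0])  # hypothetical xBinStart values
y = np.array([0.1, 0.9, 2.2, 2.8])  # hypothetical yBinStart values
w = np.array([10, 50, 50, 10])  # counts: heavily-populated bins dominate the score
print(weighted_correlation(x, y, w))  # close to 1.0 for this near-monotone data
```
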
- - Parameters - ---------- - vis : Vis - ldf : LuxDataFrame - filter_specs : list - List of filters from the Vis - msr_attribute : str - The attribute name of the measure value of the chart - - Returns - ------- - int - Score describing how different the vis is from the overall vis - """ - v_filter_size = get_filtered_size(filter_specs, ldf) - v_size = len(vis.data) - v_filter = vis.data[msr_attribute] - total = v_filter.sum() - v_filter = v_filter/total # normalize by total to get ratio - if (total==0): return 0 - # Generate an "Overall" Vis (TODO: This is computed multiple times for every vis, alternative is to directly access df.current_vis but we do not have guaruntee that will always be unfiltered vis (in the non-Filter action scenario)) - import copy - unfiltered_vis = copy.copy(vis) - unfiltered_vis._inferred_intent = utils.get_attrs_specs(vis._inferred_intent) # Remove filters, keep only attribute intent - ldf.executor.execute([unfiltered_vis],ldf) - - v = unfiltered_vis.data[msr_attribute] - v = v/v.sum() - assert len(v) == len(v_filter), "Data for filtered and unfiltered vis have unequal length." - sig = v_filter_size/v_size #significance factor - # Euclidean distance as L2 function - - rankSig = 1 #category measure value ranking significance factor - #if the vis is a barchart, count how many categories' rank, based on measure value, changes after the filter is applied - if vis.mark == "bar": - dimList = vis.get_attr_by_data_model("dimension") - - #use Pandas rank function to calculate rank positions for each category - v_rank = unfiltered_vis.data.rank() - v_filter_rank = vis.data.rank() - #go through and count the number of ranking changes between the filtered and unfiltered data - numCategories = ldf.cardinality[dimList[0].attribute] - for r in range(0, numCategories-1): - if v_rank[msr_attribute][r] != v_filter_rank[msr_attribute][r]: - rankSig += 1 - #normalize ranking significance factor - rankSig = rankSig/numCategories - - from scipy.spatial.distance import euclidean - return sig*rankSig* euclidean(v, v_filter) - -def unevenness(vis:Vis, ldf:LuxDataFrame, measure_lst:list, dimension_lst:list) -> int: - """ - Measure the unevenness of a bar chart vis. - If a bar chart is highly uneven across the possible values, then it may be interesting. (e.g., USA produces lots of cars compared to Japan and Europe) - Likewise, if a bar chart shows that the measure is the same for any possible values the dimension attribute could take on, then it may not very informative. - (e.g., The cars produced across all Origins (Europe, Japan, and USA) has approximately the same average Acceleration.) - - Parameters - ---------- - vis : Vis - ldf : LuxDataFrame - measure_lst : list - List of measures - dimension_lst : list - List of dimensions - Returns - ------- - int - Score describing how uneven the bar chart is. 
- """ - v = vis.data[measure_lst[0].attribute] - v = v/v.sum() # normalize by total to get ratio - C = ldf.cardinality[dimension_lst[0].attribute] - D = (0.9) ** C # cardinality-based discounting factor - v_flat = pd.Series([1 / C] * len(v)) - if (is_datetime(v)): - v = v.astype('int') - return D * euclidean(v, v_flat) - -def mutual_information(v_x:list , v_y:list) -> int: - #Interestingness metric for two measure attributes - #Calculate maximal information coefficient (see Murphy pg 61) or Pearson's correlation - from sklearn.metrics import mutual_info_score - return mutual_info_score(v_x, v_y) - -def monotonicity(vis:Vis, attr_specs:list, ignore_identity:bool=True) ->int: - """ - Monotonicity measures there is a monotonic trend in the scatterplot, whether linear or not. - This score is computed as the square of the Spearman correlation coefficient, which is the Pearson correlation on the ranks of x and y. - See "Graph-Theoretic Scagnostics", Wilkinson et al 2005: https://research.tableau.com/sites/default/files/Wilkinson_Infovis-05.pdf - Parameters - ---------- - vis : Vis - attr_spec: list - List of attribute Clause objects - - ignore_identity: bool - Boolean flag to ignore items with the same x and y attribute (score as -1) - - Returns - ------- - int - Score describing the strength of monotonic relationship in vis - """ - from scipy.stats import spearmanr - msr1 = attr_specs[0].attribute - msr2 = attr_specs[1].attribute - - if(ignore_identity and msr1 == msr2): #remove if measures are the same - return -1 - v_x = vis.data[msr1] - v_y = vis.data[msr2] - - import warnings - with warnings.catch_warnings(): - warnings.filterwarnings('error') - try: - score = (spearmanr(v_x, v_y)[0]) ** 2 - except(RuntimeWarning): - # RuntimeWarning: invalid value encountered in true_divide (occurs when v_x and v_y are uniform, stdev in denominator is zero, leading to spearman's correlation as nan), ignore these cases. - score = -1 - - if pd.isnull(score): - return -1 - else: - return score - # import scipy.stats - # return abs(scipy.stats.pearsonr(v_x,v_y)[0]) + return weighted_cov(x, y, w) / np.sqrt( + weighted_cov(x, x, w) * weighted_cov(y, y, w) + ) + + +def deviation_from_overall( + vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str +) -> int: + """ + Difference in bar chart/histogram shape from overall chart + Note: this function assumes that the filtered vis.data is operating on the same range as the unfiltered vis.data. 
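
At its core, `deviation_from_overall` normalizes the filtered and unfiltered distributions to ratios and takes their Euclidean distance, scaled by a significance factor. A minimal numeric sketch with hypothetical values (the real function also applies a rank-change factor for bar charts and computes significance from row counts):

```python
import numpy as np
from scipy.spatial.distance import euclidean

overall = np.array([20.0, 30.0, 50.0])  # unfiltered bar heights (hypothetical)
filtered = np.array([5.0, 5.0, 40.0])  # the same bars after applying a filter

v = overall / overall.sum()
v_filter = filtered / filtered.sum()
sig = filtered.sum() / overall.sum()  # rough share of the data kept by the filter
print(sig * euclidean(v, v_filter))  # higher score = the filter changes the shape more
```
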
+
+    Parameters
+    ----------
+    vis : Vis
+    ldf : LuxDataFrame
+    filter_specs : list
+        List of filters from the Vis
+    msr_attribute : str
+        The attribute name of the measure value of the chart
+
+    Returns
+    -------
+    int
+        Score describing how different the vis is from the overall vis
+    """
+    v_filter_size = get_filtered_size(filter_specs, ldf)
+    v_size = len(vis.data)
+    v_filter = vis.data[msr_attribute]
+    total = v_filter.sum()
+    v_filter = v_filter / total  # normalize by total to get ratio
+    if total == 0:
+        return 0
+    # Generate an "Overall" Vis (TODO: This is computed multiple times for every vis, alternative is to directly access df.current_vis but we do not have a guarantee that it will always be the unfiltered vis (in the non-Filter action scenario))
+    import copy
+
+    unfiltered_vis = copy.copy(vis)
+    unfiltered_vis._inferred_intent = utils.get_attrs_specs(
+        vis._inferred_intent
+    )  # Remove filters, keep only attribute intent
+    ldf.executor.execute([unfiltered_vis], ldf)
+
+    v = unfiltered_vis.data[msr_attribute]
+    v = v / v.sum()
+    assert len(v) == len(
+        v_filter
+    ), "Data for filtered and unfiltered vis have unequal length."
+    sig = v_filter_size / v_size  # significance factor
+    # Euclidean distance as L2 function
+
+    rankSig = 1  # category measure value ranking significance factor
+    # if the vis is a barchart, count how many categories' rank, based on measure value, changes after the filter is applied
+    if vis.mark == "bar":
+        dimList = vis.get_attr_by_data_model("dimension")
+
+        # use Pandas rank function to calculate rank positions for each category
+        v_rank = unfiltered_vis.data.rank()
+        v_filter_rank = vis.data.rank()
+        # go through and count the number of ranking changes between the filtered and unfiltered data
+        numCategories = ldf.cardinality[dimList[0].attribute]
+        for r in range(0, numCategories - 1):
+            if v_rank[msr_attribute][r] != v_filter_rank[msr_attribute][r]:
+                rankSig += 1
+        # normalize ranking significance factor
+        rankSig = rankSig / numCategories
+
+    from scipy.spatial.distance import euclidean
+
+    return sig * rankSig * euclidean(v, v_filter)
+
+
+def unevenness(
+    vis: Vis, ldf: LuxDataFrame, measure_lst: list, dimension_lst: list
+) -> int:
+    """
+    Measure the unevenness of a bar chart vis.
+    If a bar chart is highly uneven across the possible values, then it may be interesting. (e.g., USA produces lots of cars compared to Japan and Europe)
+    Likewise, if a bar chart shows that the measure is the same for any possible values the dimension attribute could take on, then it may not be very informative.
+    (e.g., The cars produced across all Origins (Europe, Japan, and USA) have approximately the same average Acceleration.)
+
+    Parameters
+    ----------
+    vis : Vis
+    ldf : LuxDataFrame
+    measure_lst : list
+        List of measures
+    dimension_lst : list
+        List of dimensions
+    Returns
+    -------
+    int
+        Score describing how uneven the bar chart is.
+    """
+    v = vis.data[measure_lst[0].attribute]
+    v = v / v.sum()  # normalize by total to get ratio
+    C = ldf.cardinality[dimension_lst[0].attribute]
+    D = (0.9) ** C  # cardinality-based discounting factor
+    v_flat = pd.Series([1 / C] * len(v))
+    if is_datetime(v):
+        v = v.astype("int")
+    return D * euclidean(v, v_flat)
+
+
+def mutual_information(v_x: list, v_y: list) -> int:
+    # Interestingness metric for two measure attributes
+    # Calculate maximal information coefficient (see Murphy pg 61) or Pearson's correlation
+    from sklearn.metrics import mutual_info_score
+
+    return mutual_info_score(v_x, v_y)
+
+
+def monotonicity(vis: Vis, attr_specs: list, ignore_identity: bool = True) -> int:
+    """
+    Monotonicity measures whether there is a monotonic trend in the scatterplot, whether linear or not.
+    This score is computed as the square of the Spearman correlation coefficient, which is the Pearson correlation on the ranks of x and y.
+    See "Graph-Theoretic Scagnostics", Wilkinson et al 2005: https://research.tableau.com/sites/default/files/Wilkinson_Infovis-05.pdf
+    Parameters
+    ----------
+    vis : Vis
+    attr_specs : list
+        List of attribute Clause objects
+
+    ignore_identity : bool
+        Boolean flag to ignore items with the same x and y attribute (score as -1)
+
+    Returns
+    -------
+    int
+        Score describing the strength of monotonic relationship in vis
+    """
+    from scipy.stats import spearmanr
+
+    msr1 = attr_specs[0].attribute
+    msr2 = attr_specs[1].attribute
+
+    if ignore_identity and msr1 == msr2:  # remove if measures are the same
+        return -1
+    v_x = vis.data[msr1]
+    v_y = vis.data[msr2]
+
+    import warnings
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("error")
+        try:
+            score = (spearmanr(v_x, v_y)[0]) ** 2
+        except RuntimeWarning:
+            # RuntimeWarning: invalid value encountered in true_divide (occurs when v_x and v_y are uniform, stdev in denominator is zero, leading to Spearman's correlation as nan), ignore these cases.
+            score = -1
+
+    if pd.isnull(score):
+        return -1
+    else:
+        return score
+    # import scipy.stats
+    # return abs(scipy.stats.pearsonr(v_x,v_y)[0])
diff --git a/lux/processor/Compiler.py b/lux/processor/Compiler.py
index 17edb97a..0635f2de 100644
--- a/lux/processor/Compiler.py
+++ b/lux/processor/Compiler.py
@@ -1,5 +1,5 @@
 # Copyright 2019-2020 The Lux Authors.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -22,407 +22,467 @@
 import numpy as np
 import warnings
 
+
 class Compiler:
-    '''
-    Given a intent with underspecified inputs, compile the intent into fully specified visualizations for visualization.
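# A compact check, on made-up data, of the squared-Spearman monotonicity score
# defined above, including the constant-column edge case that motivates
# turning RuntimeWarning into an exception.
import warnings
from scipy.stats import spearmanr


def monotonicity_score(v_x, v_y):
    with warnings.catch_warnings():
        warnings.filterwarnings("error")
        try:
            return spearmanr(v_x, v_y)[0] ** 2
        except RuntimeWarning:
            # zero stdev -> nan correlation; score such cases as -1
            return -1


print(monotonicity_score([1, 2, 3, 4], [1, 4, 9, 16]))  # 1.0 (monotone but nonlinear)
print(monotonicity_score([1, 2, 3, 4], [5, 5, 5, 5]))  # -1 (constant column)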
- 1) Enumerate a collection of visualizations interested by the user to generate a vis list - 2) Expand underspecified specifications(lux.Clause) for each of the generated visualizations. - 3) Determine encoding properties for each vis - - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. - vis_collection : list[lux.vis.Vis] - empty list that will be populated with specified lux.Vis objects. - - Returns - ------- - vis_collection: list[lux.Vis] - vis list with compiled lux.Vis objects. - """ - if (_inferred_intent): - vis_collection = Compiler.enumerate_collection(_inferred_intent,ldf) - vis_collection = Compiler.populate_data_type_model(ldf, vis_collection) # autofill data type/model information - if len(vis_collection)>=1: - vis_collection = Compiler.remove_all_invalid(vis_collection) # remove invalid visualizations from collection - for vis in vis_collection: - Compiler.determine_encoding(ldf, vis) # autofill viz related information - ldf._compiled=True - return vis_collection - - @staticmethod - def enumerate_collection(_inferred_intent:List[Clause],ldf: LuxDataFrame) -> VisList: - """ - Given specifications that have been expanded thorught populateOptions, - recursively iterate over the resulting list combinations to generate a vis list. - - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. - - Returns - ------- - VisList: list[lux.Vis] - vis list with compiled lux.Vis objects. - """ - import copy - intent = Compiler.populate_wildcard_options(_inferred_intent, ldf) - attributes = intent['attributes'] - filters = intent['filters'] - if len(attributes) == 0 and len(filters) > 0: - return [] - - collection = [] - - # TODO: generate combinations of column attributes recursively by continuing to accumulate attributes for len(colAtrr) times - def combine(col_attrs, accum): - last = (len(col_attrs) == 1) - n = len(col_attrs[0]) - for i in range(n): - column_list = copy.deepcopy(accum + [col_attrs[0][i]]) - if last: - if len(filters) > 0: # if we have filters, generate combinations for each row. - for row in filters: - _inferred_intent = copy.deepcopy(column_list + [row]) - vis = Vis(_inferred_intent) - collection.append(vis) - else: - vis = Vis(column_list) - collection.append(vis) - else: - combine(col_attrs[1:], column_list) - combine(attributes, []) - return VisList(collection) - - @staticmethod - def populate_data_type_model(ldf, vis_collection) -> VisList: - """ - Given a underspecified Clause, populate the data_type and data_model information accordingly - - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent - - vis_collection : list[lux.vis.Vis] - List of lux.Vis objects that will have their underspecified Clause details filled out. - Returns - ------- - vlist: VisList - vis list with compiled lux.Vis objects. 
- """ - # TODO: copy might not be neccesary - from lux.utils.date_utils import is_datetime_string - import copy - vlist = copy.deepcopy(vis_collection) # Preserve the original dobj - for vis in vlist: - for clause in vis._inferred_intent: - if (clause.description == "?"): - clause.description = "" - # TODO: Note that "and not is_datetime_string(clause.attribute))" is a temporary hack and breaks the `test_row_column_group` example - if (clause.attribute!="" and clause.attribute!="Record"):# and not is_datetime_string(clause.attribute): - if (clause.data_type == ""): - clause.data_type = ldf.data_type_lookup[clause.attribute] - if (clause.data_type=="id"): - clause.data_type = "nominal" - if (clause.data_model == ""): - clause.data_model = ldf.data_model_lookup[clause.attribute] - if (clause.value!=""): - if (vis.title == ""): #If user provided title for Vis, then don't override. - if(isinstance(clause.value,np.datetime64)): - chart_title = date_utils.date_formatter(clause.value,ldf) - else: - chart_title = clause.value - vis.title = f"{clause.attribute} {clause.filter_op} {chart_title}" - return vlist - - @staticmethod - def remove_all_invalid(vis_collection:VisList) -> VisList: - """ - Given an expanded vis list, remove all visualizations that are invalid. - Currently, the invalid visualizations are ones that contain two of the same attribute, no more than two temporal attributes, or overlapping attributes (same filter attribute and visualized attribute). - Parameters - ---------- - vis_collection : list[lux.vis.Vis] - empty list that will be populated with specified lux.Vis objects. - Returns - ------- - lux.vis.VisList - vis list with compiled lux.Vis objects. - """ - new_vc = [] - for vis in vis_collection: - num_temporal_specs = 0 - attribute_set = set() - for clause in vis._inferred_intent: - attribute_set.add(clause.attribute) - if clause.data_type == "temporal": - num_temporal_specs += 1 - all_distinct_specs = 0 == len(vis._inferred_intent) - len(attribute_set) - if num_temporal_specs < 2 and all_distinct_specs: - new_vc.append(vis) - # else: - # warnings.warn("\nThere is more than one duplicate attribute specified in the intent.\nPlease check your intent specification again.") - - return VisList(new_vc) - - @staticmethod - def determine_encoding(ldf: LuxDataFrame, vis: Vis): - ''' - Populates Vis with the appropriate mark type and channel information based on ShowMe logic - Currently support up to 3 dimensions or measures - - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent - vis : lux.vis.Vis - - Returns - ------- - None - - Notes - ----- - Implementing automatic encoding from Tableau's VizQL - Mackinlay, J. D., Hanrahan, P., & Stolte, C. (2007). - Show Me: Automatic presentation for visual analysis. - IEEE Transactions on Visualization and Computer Graphics, 13(6), 1137–1144. 
- https://doi.org/10.1109/TVCG.2007.70594 - ''' - # Count number of measures and dimensions - ndim = 0 - nmsr = 0 - filters = [] - for clause in vis._inferred_intent: - if (clause.value==""): - if (clause.data_model == "dimension"): - ndim += 1 - elif (clause.data_model == "measure" and clause.attribute!="Record"): - nmsr += 1 - else: # preserve to add back to _inferred_intent later - filters.append(clause) - # Helper function (TODO: Move this into utils) - def line_or_bar(ldf, dimension:Clause, measure:Clause): - dim_type = dimension.data_type - # If no aggregation function is specified, then default as average - if (measure.aggregation==""): - measure.set_aggregation("mean") - if (dim_type == "temporal" or dim_type == "oridinal"): - return "line", {"x": dimension, "y": measure} - else: # unordered categorical - # if cardinality large than 5 then sort bars - if ldf.cardinality[dimension.attribute]>5: - dimension.sort = "ascending" - return "bar", {"x": measure, "y": dimension} - # ShowMe logic + additional heuristics - #count_col = Clause( attribute="count()", data_model="measure") - count_col = Clause( attribute="Record", aggregation="count", data_model="measure", data_type="quantitative") - auto_channel={} - if (ndim == 0 and nmsr == 1): - # Histogram with Count - measure = vis.get_attr_by_data_model("measure",exclude_record=True)[0] - if (len(vis.get_attr_by_attr_name("Record"))<0): - vis._inferred_intent.append(count_col) - # If no bin specified, then default as 10 - if (measure.bin_size == 0): - measure.bin_size = 10 - auto_channel = {"x": measure, "y": count_col} - vis._mark = "histogram" - elif (ndim == 1 and (nmsr == 0 or nmsr == 1)): - # Line or Bar Chart - if (nmsr == 0): - vis._inferred_intent.append(count_col) - dimension = vis.get_attr_by_data_model("dimension")[0] - measure = vis.get_attr_by_data_model("measure")[0] - vis._mark, auto_channel = line_or_bar(ldf, dimension, measure) - elif (ndim == 2 and (nmsr == 0 or nmsr == 1)): - # Line or Bar chart broken down by the dimension - dimensions = vis.get_attr_by_data_model("dimension") - d1 = dimensions[0] - d2 = dimensions[1] - if (ldf.cardinality[d1.attribute] < ldf.cardinality[d2.attribute]): - # d1.channel = "color" - vis.remove_column_from_spec(d1.attribute) - dimension = d2 - color_attr = d1 - else: - if (d1.attribute == d2.attribute): - vis._inferred_intent.pop(0) # if same attribute then remove_column_from_spec will remove both dims, we only want to remove one - else: - vis.remove_column_from_spec(d2.attribute) - dimension = d1 - color_attr = d2 - # Colored Bar/Line chart with Count as default measure - if not ldf.pre_aggregated: - if (nmsr == 0 and not ldf.pre_aggregated): - vis._inferred_intent.append(count_col) - measure = vis.get_attr_by_data_model("measure")[0] - vis._mark, auto_channel = line_or_bar(ldf, dimension, measure) - auto_channel["color"] = color_attr - elif (ndim == 0 and nmsr == 2): - # Scatterplot - vis._mark = "scatter" - vis._inferred_intent[0].set_aggregation(None) - vis._inferred_intent[1].set_aggregation(None) - auto_channel = {"x": vis._inferred_intent[0], - "y": vis._inferred_intent[1]} - elif (ndim == 1 and nmsr == 2): - # Scatterplot broken down by the dimension - measure = vis.get_attr_by_data_model("measure") - m1 = measure[0] - m2 = measure[1] - - vis._inferred_intent[0].set_aggregation(None) - vis._inferred_intent[1].set_aggregation(None) - - color_attr = vis.get_attr_by_data_model("dimension")[0] - vis.remove_column_from_spec(color_attr) - vis._mark = "scatter" - auto_channel = {"x": m1, - 
"y": m2, - "color": color_attr} - elif (ndim == 0 and nmsr == 3): - # Scatterplot with color - vis._mark = "scatter" - auto_channel = {"x": vis._inferred_intent[0], - "y": vis._inferred_intent[1], - "color": vis._inferred_intent[2]} - relevant_attributes = [auto_channel[channel].attribute for channel in auto_channel] - relevant_min_max = dict((attr, ldf._min_max[attr]) for attr in relevant_attributes if attr != "Record" and attr in ldf._min_max) - vis._min_max = relevant_min_max - if (auto_channel!={}): - vis = Compiler.enforce_specified_channel(vis, auto_channel) - vis._inferred_intent.extend(filters) # add back the preserved filters - @staticmethod - def enforce_specified_channel(vis: Vis, auto_channel: Dict[str, str]): - """ - Enforces that the channels specified in the Vis by users overrides the showMe autoChannels. - - Parameters - ---------- - vis : lux.vis.Vis - Input Vis without channel specification. - auto_channel : Dict[str,str] - Key-value pair in the form [channel: attributeName] specifying the showMe recommended channel location. - - Returns - ------- - vis : lux.vis.Vis - Vis with channel specification combining both original and auto_channel specification. - - Raises - ------ - ValueError - Ensures no more than one attribute is placed in the same channel. - """ - result_dict = {} # result of enforcing specified channel will be stored in result_dict - specified_dict = {} # specified_dict={"x":[],"y":[list of Dobj with y specified as channel]} - # create a dictionary of specified channels in the given dobj - for val in auto_channel.keys(): - specified_dict[val] = vis.get_attr_by_channel(val) - result_dict[val] = "" - # for every element, replace with what's in specified_dict if specified - for sVal, sAttr in specified_dict.items(): - if (len(sAttr) == 1): # if specified in dobj - # remove the specified channel from auto_channel (matching by value, since channel key may not be same) - for i in list(auto_channel.keys()): - if ((auto_channel[i].attribute == sAttr[0].attribute) - and (auto_channel[i].channel == sVal)): # need to ensure that the channel is the same (edge case when duplicate Cols with same attribute name) - auto_channel.pop(i) - break - sAttr[0].channel = sVal - result_dict[sVal] = sAttr[0] - elif (len(sAttr) > 1): - raise ValueError("There should not be more than one attribute specified in the same channel.") - # For the leftover channels that are still unspecified in result_dict, - # and the leftovers in the auto_channel specification, - # step through them together and fill it automatically. - leftover_channels = list(filter(lambda x: result_dict[x] == '', result_dict)) - for leftover_channel, leftover_encoding in zip(leftover_channels, auto_channel.values()): - leftover_encoding.channel = leftover_channel - result_dict[leftover_channel] = leftover_encoding - vis._inferred_intent = list(result_dict.values()) - return vis - - @staticmethod - # def populate_wildcard_options(ldf: LuxDataFrame) -> dict: - def populate_wildcard_options(_inferred_intent:List[Clause], ldf: LuxDataFrame) -> dict: - """ - Given wildcards and constraints in the LuxDataFrame's intent, - return the list of available values that satisfies the data_type or data_model constraints. - - Parameters - ---------- - ldf : LuxDataFrame - LuxDataFrame with row or attributes populated with available wildcard options. - - Returns - ------- - intent: Dict[str,list] - a dictionary that holds the attributes and filters generated from wildcards and constraints. 
- """ - import copy - from lux.utils.utils import convert_to_list - - intent = {"attributes": [], "filters": []} - for clause in _inferred_intent: - spec_options = [] - if clause.value == "": # attribute - if clause.attribute == "?": - options = set(list(ldf.columns)) # all attributes - if (clause.data_type != ""): - options = options.intersection(set(ldf.data_type[clause.data_type])) - if (clause.data_model != ""): - options = options.intersection(set(ldf.data_model[clause.data_model])) - options = list(options) - else: - options = convert_to_list(clause.attribute) - for optStr in options: - if str(optStr) not in clause.exclude: - spec_copy = copy.copy(clause) - spec_copy.attribute = optStr - spec_options.append(spec_copy) - intent["attributes"].append(spec_options) - else: # filters - attr_lst = convert_to_list(clause.attribute) - for attr in attr_lst: - options = [] - if clause.value == "?": - options = ldf.unique_values[attr] - specInd = _inferred_intent.index(clause) - _inferred_intent[specInd] = Clause(attribute=clause.attribute, filter_op="=", value=list(options)) - else: - options.extend(convert_to_list(clause.value)) - for optStr in options: - if str(optStr) not in clause.exclude: - spec_copy = copy.copy(clause) - spec_copy.attribute = attr - spec_copy.value = optStr - spec_options.append(spec_copy) - intent["filters"].extend(spec_options) - - return intent + """ + Given a intent with underspecified inputs, compile the intent into fully specified visualizations for visualization. + """ + + def __init__(self): + self.name = "Compiler" + + def __repr__(self): + return f"" + + @staticmethod + def compile_vis(ldf: LuxDataFrame, vis: Vis) -> VisList: + if vis: + vis_collection = Compiler.populate_data_type_model( + ldf, [vis] + ) # autofill data type/model information + vis_collection = Compiler.remove_all_invalid( + vis_collection + ) # remove invalid visualizations from collection + for vis in vis_collection: + Compiler.determine_encoding( + ldf, vis + ) # autofill viz related information + ldf._compiled = True + return vis_collection + + @staticmethod + def compile_intent(ldf: LuxDataFrame, _inferred_intent: List[Clause]) -> VisList: + """ + Compiles input specifications in the intent of the ldf into a collection of lux.vis objects for visualization. + 1) Enumerate a collection of visualizations interested by the user to generate a vis list + 2) Expand underspecified specifications(lux.Clause) for each of the generated visualizations. + 3) Determine encoding properties for each vis + + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent. + vis_collection : list[lux.vis.Vis] + empty list that will be populated with specified lux.Vis objects. + + Returns + ------- + vis_collection: list[lux.Vis] + vis list with compiled lux.Vis objects. 
+        """
+        if _inferred_intent:
+            vis_collection = Compiler.enumerate_collection(_inferred_intent, ldf)
+            vis_collection = Compiler.populate_data_type_model(
+                ldf, vis_collection
+            )  # autofill data type/model information
+            if len(vis_collection) >= 1:
+                vis_collection = Compiler.remove_all_invalid(
+                    vis_collection
+                )  # remove invalid visualizations from collection
+            for vis in vis_collection:
+                Compiler.determine_encoding(
+                    ldf, vis
+                )  # autofill viz related information
+            ldf._compiled = True
+            return vis_collection
+
+    @staticmethod
+    def enumerate_collection(
+        _inferred_intent: List[Clause], ldf: LuxDataFrame
+    ) -> VisList:
+        """
+        Given specifications that have been expanded through populate_wildcard_options,
+        recursively iterate over the resulting list combinations to generate a vis list.
+
+        Parameters
+        ----------
+        ldf : lux.core.frame
+            LuxDataFrame with underspecified intent.
+
+        Returns
+        -------
+        VisList: list[lux.Vis]
+            vis list with compiled lux.Vis objects.
+        """
+        import copy
+
+        intent = Compiler.populate_wildcard_options(_inferred_intent, ldf)
+        attributes = intent["attributes"]
+        filters = intent["filters"]
+        if len(attributes) == 0 and len(filters) > 0:
+            return []
+
+        collection = []
+
+        # TODO: generate combinations of column attributes recursively by continuing to accumulate attributes for len(col_attrs) times
+        def combine(col_attrs, accum):
+            last = len(col_attrs) == 1
+            n = len(col_attrs[0])
+            for i in range(n):
+                column_list = copy.deepcopy(accum + [col_attrs[0][i]])
+                if last:
+                    if (
+                        len(filters) > 0
+                    ):  # if we have filters, generate combinations for each row.
+                        for row in filters:
+                            _inferred_intent = copy.deepcopy(column_list + [row])
+                            vis = Vis(_inferred_intent)
+                            collection.append(vis)
+                    else:
+                        vis = Vis(column_list)
+                        collection.append(vis)
+                else:
+                    combine(col_attrs[1:], column_list)
+
+        combine(attributes, [])
+        return VisList(collection)
+
+    @staticmethod
+    def populate_data_type_model(ldf, vis_collection) -> VisList:
+        """
+        Given an underspecified Clause, populate the data_type and data_model information accordingly
+
+        Parameters
+        ----------
+        ldf : lux.core.frame
+            LuxDataFrame with underspecified intent
+
+        vis_collection : list[lux.vis.Vis]
+            List of lux.Vis objects that will have their underspecified Clause details filled out.
+        Returns
+        -------
+        vlist: VisList
+            vis list with compiled lux.Vis objects.
+        """
+        # TODO: copy might not be necessary
+        from lux.utils.date_utils import is_datetime_string
+        import copy
+
+        vlist = copy.deepcopy(vis_collection)  # Preserve the original dobj
+        for vis in vlist:
+            for clause in vis._inferred_intent:
+                if clause.description == "?":
+                    clause.description = ""
+                # TODO: Note that "and not is_datetime_string(clause.attribute))" is a temporary hack and breaks the `test_row_column_group` example
+                if (
+                    clause.attribute != "" and clause.attribute != "Record"
+                ):  # and not is_datetime_string(clause.attribute):
+                    if clause.data_type == "":
+                        clause.data_type = ldf.data_type_lookup[clause.attribute]
+                    if clause.data_type == "id":
+                        clause.data_type = "nominal"
+                    if clause.data_model == "":
+                        clause.data_model = ldf.data_model_lookup[clause.attribute]
+                if clause.value != "":
+                    if (
+                        vis.title == ""
+                    ):  # If user provided title for Vis, then don't override.
+ if isinstance(clause.value, np.datetime64): + chart_title = date_utils.date_formatter(clause.value, ldf) + else: + chart_title = clause.value + vis.title = ( + f"{clause.attribute} {clause.filter_op} {chart_title}" + ) + return vlist + + @staticmethod + def remove_all_invalid(vis_collection: VisList) -> VisList: + """ + Given an expanded vis list, remove all visualizations that are invalid. + Currently, the invalid visualizations are ones that contain two of the same attribute, no more than two temporal attributes, or overlapping attributes (same filter attribute and visualized attribute). + Parameters + ---------- + vis_collection : list[lux.vis.Vis] + empty list that will be populated with specified lux.Vis objects. + Returns + ------- + lux.vis.VisList + vis list with compiled lux.Vis objects. + """ + new_vc = [] + for vis in vis_collection: + num_temporal_specs = 0 + attribute_set = set() + for clause in vis._inferred_intent: + attribute_set.add(clause.attribute) + if clause.data_type == "temporal": + num_temporal_specs += 1 + all_distinct_specs = 0 == len(vis._inferred_intent) - len(attribute_set) + if num_temporal_specs < 2 and all_distinct_specs: + new_vc.append(vis) + # else: + # warnings.warn("\nThere is more than one duplicate attribute specified in the intent.\nPlease check your intent specification again.") + + return VisList(new_vc) + + @staticmethod + def determine_encoding(ldf: LuxDataFrame, vis: Vis): + """ + Populates Vis with the appropriate mark type and channel information based on ShowMe logic + Currently support up to 3 dimensions or measures + + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent + vis : lux.vis.Vis + + Returns + ------- + None + + Notes + ----- + Implementing automatic encoding from Tableau's VizQL + Mackinlay, J. D., Hanrahan, P., & Stolte, C. (2007). + Show Me: Automatic presentation for visual analysis. + IEEE Transactions on Visualization and Computer Graphics, 13(6), 1137–1144. 
+    https://doi.org/10.1109/TVCG.2007.70594
+    """
+    # Count number of measures and dimensions
+    ndim = 0
+    nmsr = 0
+    filters = []
+    for clause in vis._inferred_intent:
+        if clause.value == "":
+            if clause.data_model == "dimension":
+                ndim += 1
+            elif clause.data_model == "measure" and clause.attribute != "Record":
+                nmsr += 1
+        else:  # preserve to add back to _inferred_intent later
+            filters.append(clause)
+    # Helper function (TODO: Move this into utils)
+    def line_or_bar(ldf, dimension: Clause, measure: Clause):
+        dim_type = dimension.data_type
+        # If no aggregation function is specified, then default as average
+        if measure.aggregation == "":
+            measure.set_aggregation("mean")
+        if dim_type == "temporal" or dim_type == "ordinal":
+            return "line", {"x": dimension, "y": measure}
+        else:  # unordered categorical
+            # if cardinality is larger than 5, then sort bars
+            if ldf.cardinality[dimension.attribute] > 5:
+                dimension.sort = "ascending"
+            return "bar", {"x": measure, "y": dimension}
+
+    # ShowMe logic + additional heuristics
+    # count_col = Clause( attribute="count()", data_model="measure")
+    count_col = Clause(
+        attribute="Record",
+        aggregation="count",
+        data_model="measure",
+        data_type="quantitative",
+    )
+    auto_channel = {}
+    if ndim == 0 and nmsr == 1:
+        # Histogram with Count
+        measure = vis.get_attr_by_data_model("measure", exclude_record=True)[0]
+        if len(vis.get_attr_by_attr_name("Record")) == 0:
+            vis._inferred_intent.append(count_col)
+        # If no bin specified, then default as 10
+        if measure.bin_size == 0:
+            measure.bin_size = 10
+        auto_channel = {"x": measure, "y": count_col}
+        vis._mark = "histogram"
+    elif ndim == 1 and (nmsr == 0 or nmsr == 1):
+        # Line or Bar Chart
+        if nmsr == 0:
+            vis._inferred_intent.append(count_col)
+        dimension = vis.get_attr_by_data_model("dimension")[0]
+        measure = vis.get_attr_by_data_model("measure")[0]
+        vis._mark, auto_channel = line_or_bar(ldf, dimension, measure)
+    elif ndim == 2 and (nmsr == 0 or nmsr == 1):
+        # Line or Bar chart broken down by the dimension
+        dimensions = vis.get_attr_by_data_model("dimension")
+        d1 = dimensions[0]
+        d2 = dimensions[1]
+        if ldf.cardinality[d1.attribute] < ldf.cardinality[d2.attribute]:
+            # d1.channel = "color"
+            vis.remove_column_from_spec(d1.attribute)
+            dimension = d2
+            color_attr = d1
+        else:
+            if d1.attribute == d2.attribute:
+                vis._inferred_intent.pop(
+                    0
+                )  # if same attribute then remove_column_from_spec will remove both dims, we only want to remove one
+            else:
+                vis.remove_column_from_spec(d2.attribute)
+            dimension = d1
+            color_attr = d2
+        # Colored Bar/Line chart with Count as default measure
+        if not ldf.pre_aggregated:
+            if nmsr == 0 and not ldf.pre_aggregated:
+                vis._inferred_intent.append(count_col)
+        measure = vis.get_attr_by_data_model("measure")[0]
+        vis._mark, auto_channel = line_or_bar(ldf, dimension, measure)
+        auto_channel["color"] = color_attr
+    elif ndim == 0 and nmsr == 2:
+        # Scatterplot
+        vis._mark = "scatter"
+        vis._inferred_intent[0].set_aggregation(None)
+        vis._inferred_intent[1].set_aggregation(None)
+        auto_channel = {"x": vis._inferred_intent[0], "y": vis._inferred_intent[1]}
+    elif ndim == 1 and nmsr == 2:
+        # Scatterplot broken down by the dimension
+        measure = vis.get_attr_by_data_model("measure")
+        m1 = measure[0]
+        m2 = measure[1]
+
+        vis._inferred_intent[0].set_aggregation(None)
+        vis._inferred_intent[1].set_aggregation(None)
+
+        color_attr = vis.get_attr_by_data_model("dimension")[0]
+        vis.remove_column_from_spec(color_attr)
+        vis._mark = "scatter"
+        auto_channel = {"x": m1,
"y": m2, "color": color_attr} + elif ndim == 0 and nmsr == 3: + # Scatterplot with color + vis._mark = "scatter" + auto_channel = { + "x": vis._inferred_intent[0], + "y": vis._inferred_intent[1], + "color": vis._inferred_intent[2], + } + relevant_attributes = [ + auto_channel[channel].attribute for channel in auto_channel + ] + relevant_min_max = dict( + (attr, ldf._min_max[attr]) + for attr in relevant_attributes + if attr != "Record" and attr in ldf._min_max + ) + vis._min_max = relevant_min_max + if auto_channel != {}: + vis = Compiler.enforce_specified_channel(vis, auto_channel) + vis._inferred_intent.extend(filters) # add back the preserved filters + + @staticmethod + def enforce_specified_channel(vis: Vis, auto_channel: Dict[str, str]): + """ + Enforces that the channels specified in the Vis by users overrides the showMe autoChannels. + + Parameters + ---------- + vis : lux.vis.Vis + Input Vis without channel specification. + auto_channel : Dict[str,str] + Key-value pair in the form [channel: attributeName] specifying the showMe recommended channel location. + + Returns + ------- + vis : lux.vis.Vis + Vis with channel specification combining both original and auto_channel specification. + + Raises + ------ + ValueError + Ensures no more than one attribute is placed in the same channel. + """ + result_dict = ( + {} + ) # result of enforcing specified channel will be stored in result_dict + specified_dict = ( + {} + ) # specified_dict={"x":[],"y":[list of Dobj with y specified as channel]} + # create a dictionary of specified channels in the given dobj + for val in auto_channel.keys(): + specified_dict[val] = vis.get_attr_by_channel(val) + result_dict[val] = "" + # for every element, replace with what's in specified_dict if specified + for sVal, sAttr in specified_dict.items(): + if len(sAttr) == 1: # if specified in dobj + # remove the specified channel from auto_channel (matching by value, since channel key may not be same) + for i in list(auto_channel.keys()): + if (auto_channel[i].attribute == sAttr[0].attribute) and ( + auto_channel[i].channel == sVal + ): # need to ensure that the channel is the same (edge case when duplicate Cols with same attribute name) + auto_channel.pop(i) + break + sAttr[0].channel = sVal + result_dict[sVal] = sAttr[0] + elif len(sAttr) > 1: + raise ValueError( + "There should not be more than one attribute specified in the same channel." + ) + # For the leftover channels that are still unspecified in result_dict, + # and the leftovers in the auto_channel specification, + # step through them together and fill it automatically. + leftover_channels = list(filter(lambda x: result_dict[x] == "", result_dict)) + for leftover_channel, leftover_encoding in zip( + leftover_channels, auto_channel.values() + ): + leftover_encoding.channel = leftover_channel + result_dict[leftover_channel] = leftover_encoding + vis._inferred_intent = list(result_dict.values()) + return vis + + @staticmethod + # def populate_wildcard_options(ldf: LuxDataFrame) -> dict: + def populate_wildcard_options( + _inferred_intent: List[Clause], ldf: LuxDataFrame + ) -> dict: + """ + Given wildcards and constraints in the LuxDataFrame's intent, + return the list of available values that satisfies the data_type or data_model constraints. + + Parameters + ---------- + ldf : LuxDataFrame + LuxDataFrame with row or attributes populated with available wildcard options. 
+ + Returns + ------- + intent: Dict[str,list] + a dictionary that holds the attributes and filters generated from wildcards and constraints. + """ + import copy + from lux.utils.utils import convert_to_list + + intent = {"attributes": [], "filters": []} + for clause in _inferred_intent: + spec_options = [] + if clause.value == "": # attribute + if clause.attribute == "?": + options = set(list(ldf.columns)) # all attributes + if clause.data_type != "": + options = options.intersection( + set(ldf.data_type[clause.data_type]) + ) + if clause.data_model != "": + options = options.intersection( + set(ldf.data_model[clause.data_model]) + ) + options = list(options) + else: + options = convert_to_list(clause.attribute) + for optStr in options: + if str(optStr) not in clause.exclude: + spec_copy = copy.copy(clause) + spec_copy.attribute = optStr + spec_options.append(spec_copy) + intent["attributes"].append(spec_options) + else: # filters + attr_lst = convert_to_list(clause.attribute) + for attr in attr_lst: + options = [] + if clause.value == "?": + options = ldf.unique_values[attr] + specInd = _inferred_intent.index(clause) + _inferred_intent[specInd] = Clause( + attribute=clause.attribute, + filter_op="=", + value=list(options), + ) + else: + options.extend(convert_to_list(clause.value)) + for optStr in options: + if str(optStr) not in clause.exclude: + spec_copy = copy.copy(clause) + spec_copy.attribute = attr + spec_copy.value = optStr + spec_options.append(spec_copy) + intent["filters"].extend(spec_options) + + return intent diff --git a/lux/processor/Parser.py b/lux/processor/Parser.py index 7127b841..c1852021 100644 --- a/lux/processor/Parser.py +++ b/lux/processor/Parser.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,92 +15,103 @@ from lux.vis.Clause import Clause from lux.core.frame import LuxDataFrame from typing import List, Union + + class Parser: - """ - The parser takes in the user's input specifications (with string `description` fields), - then generates the Lux internal specification through lux.Clause. - """ - @staticmethod - def parse(intent: List[Union[Clause,str]]) -> List[Clause]: - """ - Given the string description from a list of input Clauses (intent), - assign the appropriate clause.attribute, clause.filter_op, and clause.value. - - Parameters - ---------- - intent : List[Clause] - Underspecified list of lux.Clause objects. + """ + The parser takes in the user's input specifications (with string `description` fields), + then generates the Lux internal specification through lux.Clause. + """ + + @staticmethod + def parse(intent: List[Union[Clause, str]]) -> List[Clause]: + """ + Given the string description from a list of input Clauses (intent), + assign the appropriate clause.attribute, clause.filter_op, and clause.value. + + Parameters + ---------- + intent : List[Clause] + Underspecified list of lux.Clause objects. + + Returns + ------- + List[Clause] + Parsed list of lux.Clause objects. + """ + if type(intent) != list: + raise TypeError( + "Input intent must be a list consisting of string descriptions or lux.Clause objects." + "\nSee more at: https://lux-api.readthedocs.io/en/latest/source/guide/intent.html" + ) + import re - Returns - ------- - List[Clause] - Parsed list of lux.Clause objects. 
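# A rough illustration of the shorthand-to-Clause mapping implemented below;
# parse_shorthand is a stand-in helper (the real work happens inside
# Parser.parse), and outputs are shown as (attribute, filter_op, value).
def parse_shorthand(s):
    if "=" in s:  # filter, e.g. "Origin=USA|Japan"
        attr, val = s.split("=", 1)
        return (attr, "=", val.split("|") if "|" in val else val)
    if "|" in s:  # either-or attribute list, e.g. "Horsepower|Weight"
        return (s.split("|"), "=", "")
    return (s, "=", "")  # plain attribute


print(parse_shorthand("Origin=USA|Japan"))  # ('Origin', '=', ['USA', 'Japan'])
print(parse_shorthand("Horsepower|Weight"))  # (['Horsepower', 'Weight'], '=', '')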
- """ - if type(intent)!=list: - raise TypeError("Input intent must be a list consisting of string descriptions or lux.Clause objects." - "\nSee more at: https://lux-api.readthedocs.io/en/latest/source/guide/intent.html" - ) - import re - # intent = ldf.get_context() - new_context = [] - #checks for and converts users' string inputs into lux specifications - for clause in intent: - valid_values = [] - if isinstance(clause,list): - valid_values = [] - for v in clause: - if type(v) is str: # and v in list(ldf.columns): #TODO: Move validation check to Validator - valid_values.append(v) - temp_spec = Clause(attribute = valid_values) - new_context.append(temp_spec) - elif isinstance(clause,str): - #case where user specifies a filter - if "=" in clause: - eqInd = clause.index("=") - var = clause[0:eqInd] - if "|" in clause: - values = clause[eqInd+1:].split("|") - for v in values: - # if v in ldf.unique_values[var]: #TODO: Move validation check to Validator - valid_values.append(v) - else: - valid_values = clause[eqInd+1:] - # if var in list(ldf.columns): #TODO: Move validation check to Validator - temp_spec = Clause(attribute = var, filter_op = "=", value = valid_values) - new_context.append(temp_spec) - #case where user specifies a variable - else: - if "|" in clause: - values = clause.split("|") - for v in values: - # if v in list(ldf.columns): #TODO: Move validation check to Validator - valid_values.append(v) - else: - valid_values = clause - temp_spec = Clause(attribute = valid_values) - new_context.append(temp_spec) - elif type(clause) is Clause: - new_context.append(clause) - intent = new_context - # ldf._intent = new_context + # intent = ldf.get_context() + new_context = [] + # checks for and converts users' string inputs into lux specifications + for clause in intent: + valid_values = [] + if isinstance(clause, list): + valid_values = [] + for v in clause: + if ( + type(v) is str + ): # and v in list(ldf.columns): #TODO: Move validation check to Validator + valid_values.append(v) + temp_spec = Clause(attribute=valid_values) + new_context.append(temp_spec) + elif isinstance(clause, str): + # case where user specifies a filter + if "=" in clause: + eqInd = clause.index("=") + var = clause[0:eqInd] + if "|" in clause: + values = clause[eqInd + 1 :].split("|") + for v in values: + # if v in ldf.unique_values[var]: #TODO: Move validation check to Validator + valid_values.append(v) + else: + valid_values = clause[eqInd + 1 :] + # if var in list(ldf.columns): #TODO: Move validation check to Validator + temp_spec = Clause(attribute=var, filter_op="=", value=valid_values) + new_context.append(temp_spec) + # case where user specifies a variable + else: + if "|" in clause: + values = clause.split("|") + for v in values: + # if v in list(ldf.columns): #TODO: Move validation check to Validator + valid_values.append(v) + else: + valid_values = clause + temp_spec = Clause(attribute=valid_values) + new_context.append(temp_spec) + elif type(clause) is Clause: + new_context.append(clause) + intent = new_context + # ldf._intent = new_context - for clause in intent: - if (clause.description): - #TODO: Move validation check to Validator - #if ((clause.description in list(ldf.columns)) or clause.description == "?"):# if clause.description in the list of attributes - if any(ext in [">","<","=","!="] for ext in clause.description): # clause.description contain ">","<". 
or "=" - # then parse it and assign to clause.attribute, clause.filter_op, clause.values - clause.filter_op = re.findall(r'/.*/|>|=|<|>=|<=|!=', clause.description)[0] - split_description = clause.description.split(clause.filter_op) - clause.attribute = split_description[0] - clause.value = split_description[1] - if re.match(r'^-?\d+(?:\.\d+)?$', clause.value): - clause.value = float(clause.value) - elif (type(clause.description) == str): - clause.attribute = clause.description - elif (type(clause.description)==list): - clause.attribute = clause.description - # else: # then it is probably a value - # clause.values = clause.description - return intent - # ldf._intent = intent \ No newline at end of file + for clause in intent: + if clause.description: + # TODO: Move validation check to Validator + # if ((clause.description in list(ldf.columns)) or clause.description == "?"):# if clause.description in the list of attributes + if any( + ext in [">", "<", "=", "!="] for ext in clause.description + ): # clause.description contain ">","<". or "=" + # then parse it and assign to clause.attribute, clause.filter_op, clause.values + clause.filter_op = re.findall( + r"/.*/|>|=|<|>=|<=|!=", clause.description + )[0] + split_description = clause.description.split(clause.filter_op) + clause.attribute = split_description[0] + clause.value = split_description[1] + if re.match(r"^-?\d+(?:\.\d+)?$", clause.value): + clause.value = float(clause.value) + elif type(clause.description) == str: + clause.attribute = clause.description + elif type(clause.description) == list: + clause.attribute = clause.description + # else: # then it is probably a value + # clause.values = clause.description + return intent + # ldf._intent = intent diff --git a/lux/processor/Validator.py b/lux/processor/Validator.py index a3262a20..688a5f05 100644 --- a/lux/processor/Validator.py +++ b/lux/processor/Validator.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,65 +16,85 @@ from lux.core.frame import LuxDataFrame from lux.vis.Clause import Clause from typing import List -from lux.utils.date_utils import is_datetime_series,is_datetime_string +from lux.utils.date_utils import is_datetime_series, is_datetime_string import warnings + + class Validator: - ''' - Contains methods for validating lux.Clause objects in the intent. - ''' - def __init__(self): - self.name = "Validator" + """ + Contains methods for validating lux.Clause objects in the intent. + """ + + def __init__(self): + self.name = "Validator" + + def __repr__(self): + return f"" - def __repr__(self): - return f"" + @staticmethod + def validate_intent(intent: List[Clause], ldf: LuxDataFrame) -> None: + """ + Validates input specifications from the user to find inconsistencies and errors. - @staticmethod - def validate_intent(intent: List[Clause], ldf:LuxDataFrame) -> None: - """ - Validates input specifications from the user to find inconsistencies and errors. + Parameters + ---------- + ldf : lux.core.frame + LuxDataFrame with underspecified intent. - Parameters - ---------- - ldf : lux.core.frame - LuxDataFrame with underspecified intent. + Returns + ------- + None - Returns - ------- - None + Raises + ------ + ValueError + Ensures input intent are consistent with DataFrame content. - Raises - ------ - ValueError - Ensures input intent are consistent with DataFrame content. 
- - """ + """ - def validate_clause(clause): - if not((clause.attribute and clause.attribute == "?") or (clause.value and clause.value=="?")): - if isinstance(clause.attribute,list): - for attr in clause.attribute: - if attr not in list(ldf.columns): - warnings.warn(f"The input attribute '{attr}' does not exist in the DataFrame.") - else: - if (clause.attribute!="Record"): - #we don't value check datetime since datetime can take filter values that don't exactly match the exact TimeStamp representation - if (clause.attribute and not is_datetime_string(clause.attribute)): - if not clause.attribute in list(ldf.columns): - warnings.warn(f"The input attribute '{clause.attribute}' does not exist in the DataFrame.") - if (clause.value and clause.attribute and clause.filter_op=="="): - series = ldf[clause.attribute] - if (not is_datetime_series(series)): - if isinstance(clause.value, list): - vals = clause.value - else: - vals = [clause.value] - for val in vals: - if (val not in series.values):#(not series.str.contains(val).any()): - warnings.warn(f"The input value '{val}' does not exist for the attribute '{clause.attribute}' for the DataFrame.") + def validate_clause(clause): + if not ( + (clause.attribute and clause.attribute == "?") + or (clause.value and clause.value == "?") + ): + if isinstance(clause.attribute, list): + for attr in clause.attribute: + if attr not in list(ldf.columns): + warnings.warn( + f"The input attribute '{attr}' does not exist in the DataFrame." + ) + else: + if clause.attribute != "Record": + # we don't value check datetime since datetime can take filter values that don't exactly match the exact TimeStamp representation + if clause.attribute and not is_datetime_string( + clause.attribute + ): + if not clause.attribute in list(ldf.columns): + warnings.warn( + f"The input attribute '{clause.attribute}' does not exist in the DataFrame." + ) + if ( + clause.value + and clause.attribute + and clause.filter_op == "=" + ): + series = ldf[clause.attribute] + if not is_datetime_series(series): + if isinstance(clause.value, list): + vals = clause.value + else: + vals = [clause.value] + for val in vals: + if ( + val not in series.values + ): # (not series.str.contains(val).any()): + warnings.warn( + f"The input value '{val}' does not exist for the attribute '{clause.attribute}' for the DataFrame." + ) - for clause in intent: - if type(clause) is list: - for s in clause: - validate_clause(s) - else: - validate_clause(clause) \ No newline at end of file + for clause in intent: + if type(clause) is list: + for s in clause: + validate_clause(s) + else: + validate_clause(clause) diff --git a/lux/processor/__init__.py b/lux/processor/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/processor/__init__.py +++ b/lux/processor/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/lux/utils/__init__.py b/lux/utils/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/utils/__init__.py +++ b/lux/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. 
-# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/lux/utils/date_utils.py b/lux/utils/date_utils.py index 497db71e..eb067ea6 100644 --- a/lux/utils/date_utils.py +++ b/lux/utils/date_utils.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,103 +14,121 @@ import pandas as pd -def date_formatter(time_stamp,ldf): - """ - Given a numpy timestamp and ldf, inspects which date granularity is appropriate and reformats timestamp accordingly - - Example - ---------- - For changing granularity the results differ as so. - days: '2020-01-01' -> '2020-1-1' - months: '2020-01-01' -> '2020-1' - years: '2020-01-01' -> '2020' - - Parameters - ---------- - time_stamp: np.datetime64 - timestamp object holding the date information - ldf : lux.core.frame - LuxDataFrame with a temporal field - - Returns - ------- - date_str: str - A reformatted version of the time_stamp according to granularity - """ - datetime = pd.to_datetime(time_stamp) - if ldf.data_type["temporal"]: - date_column = ldf[ldf.data_type["temporal"][0]] # assumes only one temporal column, may need to change this function to recieve multiple temporal columns in the future - granularity = compute_date_granularity(date_column) - date_str = "" - if granularity == "year": - date_str += str(datetime.year) - elif granularity == "month": - date_str += str(datetime.year)+ "-" + str(datetime.month) - elif granularity == "day": - date_str += str(datetime.year) +"-"+ str(datetime.month) +"-"+ str(datetime.day) - else: - # non supported granularity - return datetime.date() - - return date_str - - -def compute_date_granularity(date_column:pd.core.series.Series): - """ - Given a temporal column (pandas.core.series.Series), finds out the granularity of dates. - - Example - ---------- - ['2018-01-01', '2019-01-02', '2018-01-03'] -> "day" - ['2018-01-01', '2019-02-01', '2018-03-01'] -> "month" - ['2018-01-01', '2019-01-01', '2020-01-01'] -> "year" - - Parameters - ---------- - date_column: pandas.core.series.Series - Column series with datetime type - - Returns - ------- - field: str - A str specifying the granularity of dates for the inspected temporal column - """ - date_fields = ["day", "month", "year"] #supporting a limited set of Vega-Lite TimeUnit (https://vega.github.io/vega-lite/docs/timeunit.html) - date_index = pd.DatetimeIndex(date_column) - for field in date_fields: - if hasattr(date_index,field) and len(getattr(date_index, field).unique()) != 1 : #can be changed to sum(getattr(date_index, field)) != 0 - return field - return "year" #if none, then return year by default -def is_datetime_series(series:pd.Series) -> bool: - - """ - Check if the Series object is of datetime type - - Parameters - ---------- - series : pd.Series - - Returns - ------- - is_date: bool - """ - return pd.api.types.is_datetime64_any_dtype(series) or pd.api.types.is_period_dtype(series) -def is_datetime_string(string:str)-> bool: - """ - Check if the string is date-like. 
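# A quick demonstration of the date-like check implemented in this hunk;
# dateutil's parser accepts anything it can interpret as a date and raises
# ValueError otherwise.
from dateutil.parser import parse


def is_datetime_string(string):
    try:
        parse(string)
        return True
    except ValueError:
        return False


print(is_datetime_string("2020-01-01"))  # True
print(is_datetime_string("Horsepower"))  # False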
- - Parameters - ---------- - string : str - - Returns - ------- - is_date: bool - """ - from dateutil.parser import parse - try: - parse(string) - return True - - except ValueError: - return False \ No newline at end of file + +def date_formatter(time_stamp, ldf): + """ + Given a numpy timestamp and ldf, inspects which date granularity is appropriate and reformats timestamp accordingly + + Example + ---------- + For changing granularity the results differ as so. + days: '2020-01-01' -> '2020-1-1' + months: '2020-01-01' -> '2020-1' + years: '2020-01-01' -> '2020' + + Parameters + ---------- + time_stamp: np.datetime64 + timestamp object holding the date information + ldf : lux.core.frame + LuxDataFrame with a temporal field + + Returns + ------- + date_str: str + A reformatted version of the time_stamp according to granularity + """ + datetime = pd.to_datetime(time_stamp) + if ldf.data_type["temporal"]: + date_column = ldf[ + ldf.data_type["temporal"][0] + ] # assumes only one temporal column, may need to change this function to recieve multiple temporal columns in the future + granularity = compute_date_granularity(date_column) + date_str = "" + if granularity == "year": + date_str += str(datetime.year) + elif granularity == "month": + date_str += str(datetime.year) + "-" + str(datetime.month) + elif granularity == "day": + date_str += ( + str(datetime.year) + "-" + str(datetime.month) + "-" + str(datetime.day) + ) + else: + # non supported granularity + return datetime.date() + + return date_str + + +def compute_date_granularity(date_column: pd.core.series.Series): + """ + Given a temporal column (pandas.core.series.Series), finds out the granularity of dates. + + Example + ---------- + ['2018-01-01', '2019-01-02', '2018-01-03'] -> "day" + ['2018-01-01', '2019-02-01', '2018-03-01'] -> "month" + ['2018-01-01', '2019-01-01', '2020-01-01'] -> "year" + + Parameters + ---------- + date_column: pandas.core.series.Series + Column series with datetime type + + Returns + ------- + field: str + A str specifying the granularity of dates for the inspected temporal column + """ + date_fields = [ + "day", + "month", + "year", + ] # supporting a limited set of Vega-Lite TimeUnit (https://vega.github.io/vega-lite/docs/timeunit.html) + date_index = pd.DatetimeIndex(date_column) + for field in date_fields: + if ( + hasattr(date_index, field) and len(getattr(date_index, field).unique()) != 1 + ): # can be changed to sum(getattr(date_index, field)) != 0 + return field + return "year" # if none, then return year by default + + +def is_datetime_series(series: pd.Series) -> bool: + + """ + Check if the Series object is of datetime type + + Parameters + ---------- + series : pd.Series + + Returns + ------- + is_date: bool + """ + return pd.api.types.is_datetime64_any_dtype(series) or pd.api.types.is_period_dtype( + series + ) + + +def is_datetime_string(string: str) -> bool: + """ + Check if the string is date-like. + + Parameters + ---------- + string : str + + Returns + ------- + is_date: bool + """ + from dateutil.parser import parse + + try: + parse(string) + return True + + except ValueError: + return False diff --git a/lux/utils/message.py b/lux/utils/message.py index 57bf5dde..638fd581 100644 --- a/lux/utils/message.py +++ b/lux/utils/message.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -12,23 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. + class Message: def __init__(self): self.messages = [] - def add_unique(self,item,priority=-1): - msg = {"text":item,"priority":priority} - if (msg not in self.messages): + + def add_unique(self, item, priority=-1): + msg = {"text": item, "priority": priority} + if msg not in self.messages: self.messages.append(msg) - def add(self,item,priority=-1): - self.messages.append({"text":item,"priority":priority}) + + def add(self, item, priority=-1): + self.messages.append({"text": item, "priority": priority}) + def to_html(self): - if (len(self.messages)==0): + if len(self.messages) == 0: return "" else: - sorted_msgs = sorted(self.messages, key = lambda i: i['priority'],reverse=True) + sorted_msgs = sorted( + self.messages, key=lambda i: i["priority"], reverse=True + ) html = "
    " for msg in sorted_msgs: msgTxt = msg["text"] - html+=f"
  • {msgTxt}
  • " + html += f"
  • {msgTxt}
  • " html += "
" - return html \ No newline at end of file + return html diff --git a/lux/utils/utils.py b/lux/utils/utils.py index e28d931b..148509db 100644 --- a/lux/utils/utils.py +++ b/lux/utils/utils.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,59 +12,83 @@ # See the License for the specific language governing permissions and # limitations under the License. import pandas as pd + + def convert_to_list(x): - ''' - "a" --> ["a"] - ["a","b"] --> ["a","b"] - ''' - if type(x) != list: - return [x] - else: - return x + """ + "a" --> ["a"] + ["a","b"] --> ["a","b"] + """ + if type(x) != list: + return [x] + else: + return x + def pandas_to_lux(df): - from lux.core.frame import LuxDataFrame - values = df.values.tolist() - ldf = LuxDataFrame(values, columns = df.columns) - return(ldf) + from lux.core.frame import LuxDataFrame + + values = df.values.tolist() + ldf = LuxDataFrame(values, columns=df.columns) + return ldf + def get_attrs_specs(intent): - if (intent is None): return [] - spec_obj = list(filter(lambda x: x.value=="", intent)) - return spec_obj + if intent is None: + return [] + spec_obj = list(filter(lambda x: x.value == "", intent)) + return spec_obj + def get_filter_specs(intent): - if (intent is None): return [] - spec_obj = list(filter(lambda x: x.value!="", intent)) - return spec_obj + if intent is None: + return [] + spec_obj = list(filter(lambda x: x.value != "", intent)) + return spec_obj + def check_import_lux_widget(): - import pkgutil - if (pkgutil.find_loader("luxwidget") is None): - raise Exception("luxwidget is not installed. Run `pip install luxwidget' to install the Jupyter widget.\nSee more at: https://github.com/lux-org/lux-widget") + import pkgutil + + if pkgutil.find_loader("luxwidget") is None: + raise Exception( + "luxwidget is not installed. Run `pip install luxwidget' to install the Jupyter widget.\nSee more at: https://github.com/lux-org/lux-widget" + ) + def get_agg_title(clause): - if (clause.aggregation is None): - return f'{clause.attribute}' - elif (clause.attribute=="Record"): - return f'Number of Records' - else: - return f'{clause._aggregation_name.capitalize()} of {clause.attribute}' -def check_if_id_like(df,attribute): - import re - # Strong signals - high_cardinality = df.cardinality[attribute]>500 # so that aggregated reset_index fields don't get misclassified - attribute_contain_id = re.search(r'id',str(attribute)) is not None - almost_all_vals_unique = df.cardinality[attribute] >=0.98* len(df) - is_string = pd.api.types.is_string_dtype(df[attribute]) - if (is_string): - # For string IDs, usually serial numbers or codes with alphanumerics have a consistent length (eg., CG-39405) with little deviation. For a high cardinality string field but not ID field (like Name or Brand), there is less uniformity across the string lengths. 
- if (len(df)>50): - sampled = df[attribute].sample(50,random_state=99) - else: - sampled = df[attribute] - str_length_uniformity = sampled.apply(lambda x: type(x)==str and len(x)).std() < 3 - return high_cardinality and (attribute_contain_id or almost_all_vals_unique) and str_length_uniformity - else: - # TODO: Could probably add some type of entropy measure (since the binned id fields are usually very even) - return high_cardinality and (attribute_contain_id or almost_all_vals_unique) \ No newline at end of file + if clause.aggregation is None: + return f"{clause.attribute}" + elif clause.attribute == "Record": + return f"Number of Records" + else: + return f"{clause._aggregation_name.capitalize()} of {clause.attribute}" + + +def check_if_id_like(df, attribute): + import re + + # Strong signals + high_cardinality = ( + df.cardinality[attribute] > 500 + ) # so that aggregated reset_index fields don't get misclassified + attribute_contain_id = re.search(r"id", str(attribute)) is not None + almost_all_vals_unique = df.cardinality[attribute] >= 0.98 * len(df) + is_string = pd.api.types.is_string_dtype(df[attribute]) + if is_string: + # For string IDs, usually serial numbers or codes with alphanumerics have a consistent length (eg., CG-39405) with little deviation. For a high cardinality string field but not ID field (like Name or Brand), there is less uniformity across the string lengths. + if len(df) > 50: + sampled = df[attribute].sample(50, random_state=99) + else: + sampled = df[attribute] + str_length_uniformity = ( + sampled.apply(lambda x: type(x) == str and len(x)).std() < 3 + ) + return ( + high_cardinality + and (attribute_contain_id or almost_all_vals_unique) + and str_length_uniformity + ) + else: + # TODO: Could probably add some type of entropy measure (since the binned id fields are usually very even) + return high_cardinality and (attribute_contain_id or almost_all_vals_unique) diff --git a/lux/vis/Clause.py b/lux/vis/Clause.py index 2d313de3..b4faff52 100644 --- a/lux/vis/Clause.py +++ b/lux/vis/Clause.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,120 +13,136 @@ # limitations under the License. import typing + + class Clause: - """ - Clause is the object representation of a single unit of the specification. - """ + """ + Clause is the object representation of a single unit of the specification. + """ + + def __init__( + self, + description: typing.Union[str, list] = "", + attribute: typing.Union[str, list] = "", + value: typing.Union[str, list] = "", + filter_op: str = "=", + channel: str = "", + data_type: str = "", + data_model: str = "", + aggregation: typing.Union[str, callable] = "", + bin_size: int = 0, + weight: float = 1, + sort: str = "", + exclude: typing.Union[str, list] = "", + ): + """ + + Parameters + ---------- + description : typing.Union[str,list], optional + Convenient shorthand description of specification, parser parses description into other properties (attribute, value, filter_op), by default "" + attribute : typing.Union[str,list], optional + Specified attribute(s) of interest, by default "" + By providing a list of attributes (e.g., [Origin,Brand]), user is interested in either one of the attribute (i.e., Origin or Brand). 
+        value : typing.Union[str,list], optional
+            Specified value(s) of interest, by default ""
+            By providing a list of values (e.g., ["USA","Europe"]), the user is interested in any one of the values (i.e., USA or Europe).
+        filter_op : str, optional
+            Filter operation of interest.
+            Possible values: '=', '<', '>', '<=', '>=', '!=', by default "="
+        channel : str, optional
+            Encoding channel where the specified attribute should be placed.
+            Possible values: 'x','y','color', by default ""
+        data_type : str, optional
+            Data type for the specified attribute.
+            Possible values: 'nominal', 'quantitative','temporal', by default ""
+        data_model : str, optional
+            Data model for the specified attribute
+            Possible values: 'dimension', 'measure', by default ""
+        aggregation : typing.Union[str,callable], optional
+            Aggregation function for specified attribute, by default "" set as 'mean'
+            Possible values: 'sum','mean', and other string shorthands or functions supported by Pandas.aggregate (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.aggregate.html), including numpy aggregation functions (e.g., np.ptp), by default ""
+            Input `None` means no aggregation should be applied (e.g., data has been pre-aggregated)
+        bin_size : int, optional
+            Number of bins for histograms, by default 0
+        weight : float, optional
+            A number between 0 and 1 indicating the importance of this Clause, by default 1
+        sort : str, optional
+            Specifies whether and how the bar chart should be sorted
+            Possible values: 'ascending', 'descending', by default ""
+        """
+        # Descriptor
+        self.description = description
+        # Description gets compiled to attribute, value, filter_op
+        self.attribute = attribute
+        self.value = value
+        self.filter_op = filter_op
+        # self.parseDescription()
+        # Properties
+        self.channel = channel
+        self.data_type = data_type
+        self.data_model = data_model
+        self.set_aggregation(aggregation)
+        self.bin_size = bin_size
+        self.weight = weight
+        self.sort = sort
+        self.exclude = exclude
+
+    def get_attr(self):
+        return self.attribute
 
-    def __init__(self, description:typing.Union[str,list] ="",attribute: typing.Union[str,list] ="",value: typing.Union[str,list]="",
-                    filter_op:str ="=", channel:str ="", data_type:str="",data_model:str="",
-                    aggregation:typing.Union[str,callable] = "", bin_size:int=0, weight:float=1,sort:str="", exclude: typing.Union[str,list] =""):
-        """
+    def copy_clause(self):
+        copied_clause = Clause()
+        copied_clause.__dict__ = self.__dict__.copy()  # just a shallow copy
+        return copied_clause
 
-        Parameters
-        ----------
-        description : typing.Union[str,list], optional
-            Convenient shorthand description of specification, parser parses description into other properties (attribute, value, filter_op), by default ""
-        attribute : typing.Union[str,list], optional
-            Specified attribute(s) of interest, by default ""
-            By providing a list of attributes (e.g., [Origin,Brand]), user is interested in either one of the attribute (i.e., Origin or Brand).
-        value : typing.Union[str,list], optional
-            Specified value(s) of interest, by default ""
-            By providing a list of values (e.g., ["USA","Europe"]), user is interested in either one of the attribute (i.e., USA or Europe).
-        filter_op : str, optional
-            Filter operation of interest.
-            Possible values: '=', '<', '>', '<=', '>=', '!=', by default "="
-        channel : str, optional
-            Encoding channel where the specified attribute should be placed.
- Possible values: 'x','y','color', by default "" - data_type : str, optional - Data type for the specified attribute. - Possible values: 'nominal', 'quantitative','temporal', by default "" - data_model : str, optional - Data model for the specified attribute - Possible values: 'dimension', 'measure', by default "" - aggregation : typing.Union[str,callable], optional - Aggregation function for specified attribute, by default "" set as 'mean' - Possible values: 'sum','mean', and others string shorthand or functions supported by Pandas.aggregate (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.aggregate.html), including numpy aggregation functions (e.g., np.ptp), by default "" - Input `None` means no aggregation should be applied (e.g., data has been pre-aggregated) - bin_size : int, optional - Number of bins for histograms, by default 0 - weight : float, optional - A number between 0 and 1 indicating the importance of this Clause, by default 1 - sort : str, optional - Specifying whether and how the bar chart should be sorted - Possible values: 'ascending', 'descending', by default "" - """ - # Descriptor - self.description = description - # Description gets compiled to attribute, value, filter_op - self.attribute = attribute - self.value = value - self.filter_op = filter_op - # self.parseDescription() - # Properties - self.channel = channel - self.data_type = data_type - self.data_model = data_model - self.set_aggregation(aggregation) - self.bin_size = bin_size - self.weight = weight - self.sort = sort - self.exclude = exclude + def set_aggregation(self, aggregation: typing.Union[str, callable]): + """ + Sets the aggregation function of Clause, + while updating _aggregation_name internally - def get_attr(self): - return self.attribute - - def copy_clause(self): - copied_clause = Clause() - copied_clause.__dict__ = self.__dict__.copy() # just a shallow copy - return(copied_clause) + Parameters + ---------- + aggregation : typing.Union[str,callable] + """ + self.aggregation = aggregation + # If aggregation input is a function (e.g., np.std), get the string name of the function for plotting + if hasattr(self.aggregation, "__name__"): + self._aggregation_name = self.aggregation.__name__ + else: + self._aggregation_name = self.aggregation - def set_aggregation(self,aggregation:typing.Union[str,callable]): - """ - Sets the aggregation function of Clause, - while updating _aggregation_name internally + def to_string(self): + if isinstance(self.attribute, list): + clauseStr = "|".join(self.attribute) + elif self.value == "": + clauseStr = self.attribute + else: + clauseStr = f"{self.attribute}{self.filter_op}{self.value}" + return clauseStr - Parameters - ---------- - aggregation : typing.Union[str,callable] - """ - self.aggregation = aggregation - # If aggregation input is a function (e.g., np.std), get the string name of the function for plotting - if hasattr(self.aggregation,'__name__'): - self._aggregation_name = self.aggregation.__name__ - else: - self._aggregation_name = self.aggregation - def to_string(self): - if isinstance(self.attribute,list): - clauseStr = '|'.join(self.attribute) - elif (self.value==""): - clauseStr = self.attribute - else: - clauseStr = f"{self.attribute}{self.filter_op}{self.value}" - return clauseStr - def __repr__(self): - attributes = [] - if self.description != "": - attributes.append(" description: " + self.description) - if self.channel != "": - attributes.append(" channel: " + self.channel) - if len(self.attribute) != 0: - 
attributes.append(" attribute: " + str(self.attribute)) - if self.filter_op != "=": - attributes.append(f" filter_op: {str(self.filter_op)}" ) - if self.aggregation != "" and self.aggregation is not None: - attributes.append(" aggregation: " + self._aggregation_name) - if self.value!="" or len(self.value) != 0 : - attributes.append(" value: " + str(self.value)) - if self.data_model != "": - attributes.append(" data_model: " + self.data_model) - if len(self.data_type) != 0: - attributes.append(" data_type: " + str(self.data_type)) - if self.bin_size != None: - attributes.append(" bin_size: " + str(self.bin_size)) - if len(self.exclude) != 0: - attributes.append(" exclude: " + str(self.exclude)) - attributes[0] = "" - filter_intents = None - channels, additional_channels = [], [] - for clause in self._inferred_intent: - - if hasattr(clause,"value"): - if clause.value != "": - filter_intents = clause - if hasattr(clause,"attribute"): - if clause.attribute != "": - if clause.aggregation != "" and clause.aggregation is not None: - attribute = clause._aggregation_name.upper() + "(" + clause.attribute + ")" - elif clause.bin_size > 0: - attribute = "BIN(" + clause.attribute + ")" - else: - attribute = clause.attribute - if clause.channel == "x": - channels.insert(0, [clause.channel, attribute]) - elif clause.channel == "y": - channels.insert(1, [clause.channel, attribute]) - elif clause.channel != "": - additional_channels.append([clause.channel, attribute]) - - channels.extend(additional_channels) - str_channels = "" - for channel in channels: - str_channels += channel[0] + ": " + channel[1] + ", " - - if filter_intents: - return f"" - else: - return f"" - @property - def data(self): - return self._vis_data - @property - def code(self): - return self._code - @property - def mark(self): - return self._mark - @property - def min_max(self): - return self._min_max - @property - def intent(self): - return self._intent - @intent.setter - def intent(self, intent:List[Clause]) -> None: - self.set_intent(intent) - def set_intent(self, intent:List[Clause]) -> None: - """ - Sets the intent of the Vis and refresh the source based on the new intent - - Parameters - ---------- - intent : List[Clause] - Query specifying the desired VisList - """ - self._intent = intent - self.refresh_source(self._source) - @property - def plot_config(self): - return self._plot_config - @plot_config.setter - def plot_config(self,config_func:Callable): - """ - Modify plot aesthetic settings to the Vis - Currently only supported for Altair visualizations - - Parameters - ---------- - config_func : typing.Callable - A function that takes in an AltairChart (https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) as input and returns an AltairChart as output - """ - self._plot_config = config_func - def clear_plot_config(self): - self._plot_config = None - def _repr_html_(self): - from IPython.display import display - check_import_lux_widget() - import luxwidget - if (self.data is None): - raise Exception("No data is populated in Vis. 
In order to generate data required for the vis, use the 'refresh_source' function to populate the Vis with a data source (e.g., vis.refresh_source(df)).") - else: - from lux.core.frame import LuxDataFrame - widget = luxwidget.LuxWidget( - currentVis= LuxDataFrame.current_vis_to_JSON([self]), - recommendations=[], - intent="", - message = "" - ) - display(widget) - def get_attr_by_attr_name(self,attr_name): - return list(filter(lambda x: x.attribute == attr_name, self._inferred_intent)) - - def get_attr_by_channel(self, channel): - spec_obj = list(filter(lambda x: x.channel == channel and x.value=='' if hasattr(x, "channel") else False, self._inferred_intent)) - return spec_obj - - def get_attr_by_data_model(self, dmodel, exclude_record=False): - if (exclude_record): - return list(filter(lambda x: x.data_model == dmodel and x.value=='' if x.attribute!="Record" and hasattr(x, "data_model") else False, self._inferred_intent)) - else: - return list(filter(lambda x: x.data_model == dmodel and x.value=='' if hasattr(x, "data_model") else False, self._inferred_intent)) - - def get_attr_by_data_type(self, dtype): - return list(filter(lambda x: x.data_type == dtype and x.value=='' if hasattr(x, "data_type") else False, self._inferred_intent)) - - def remove_filter_from_spec(self, value): - new_intent = list(filter(lambda x: x.value != value, self._inferred_intent)) - self.set_intent(new_intent) - - def remove_column_from_spec(self, attribute, remove_first:bool=False): - """ - Removes an attribute from the Vis's clause - - Parameters - ---------- - attribute : str - attribute to be removed - remove_first : bool, optional - Boolean flag to determine whether to remove all instances of the attribute or only one (first) instance, by default False - """ - if (not remove_first): - new_inferred = list(filter(lambda x: x.attribute != attribute, self._inferred_intent)) - self._inferred_intent = new_inferred - self._intent = new_inferred - elif (remove_first): - new_inferred = [] - skip_check = False - for i in range(0, len(self._inferred_intent)): - if self._inferred_intent[i].value=="": # clause is type attribute - column_spec = [] - column_names = self._inferred_intent[i].attribute - # if only one variable in a column, columnName results in a string and not a list so - # you need to differentiate the cases - if isinstance(column_names, list): - for column in column_names: - if (column != attribute) or skip_check: - column_spec.append(column) - elif (remove_first): - remove_first = True - new_inferred.append(Clause(column_spec)) - else: - if column_names != attribute or skip_check: - new_inferred.append(Clause(attribute = column_names)) - elif (remove_first): - skip_check = True - else: - new_inferred.append(self._inferred_intent[i]) - self._intent = new_inferred - self._inferred_intent = new_inferred - - def to_Altair(self, standalone = False) -> str: - """ - Generate minimal Altair code to visualize the Vis - - Parameters - ---------- - standalone : bool, optional - Flag to determine if outputted code uses user-defined variable names or can be run independently, by default False - - Returns - ------- - str - String version of the Altair code. Need to print out the string to apply formatting. 
-        """
-        from lux.vislib.altair.AltairRenderer import AltairRenderer
-        renderer = AltairRenderer(output_type="Altair")
-        self._code= renderer.create_vis(self, standalone)
-        return self._code
-
-    def to_VegaLite(self, prettyOutput = True) -> Union[dict,str]:
-        """
-        Generate minimal Vega-Lite code to visualize the Vis
-
-        Returns
-        -------
-        Union[dict,str]
-            String or Dictionary of the VegaLite JSON specification
-        """
-        import json
-        from lux.vislib.altair.AltairRenderer import AltairRenderer
-        renderer = AltairRenderer(output_type="VegaLite")
-        self._code = renderer.create_vis(self)
-        if (prettyOutput):
-            return "** Remove this comment -- Copy Text Below to Vega Editor(vega.github.io/editor) to visualize and edit **\n"+json.dumps(self._code, indent=2)
-        else:
-            return self._code
-
-    def render_VSpec(self, renderer="altair"):
-        if (renderer == "altair"):
-            return self.to_VegaLite(prettyOutput=False)
-
-    def refresh_source(self, ldf):# -> Vis:
-        """
-        Loading the source data into the Vis by instantiating the specification and
-        populating the Vis based on the source data, effectively "materializing" the Vis.
-
-        Parameters
-        ----------
-        ldf : LuxDataframe
-            Input Dataframe to be attached to the Vis
-
-        Returns
-        -------
-        Vis
-            Complete Vis with fully-specified fields
-
-        See Also
-        --------
-        lux.Vis.VisList.refresh_source
-
-        Note
-        ----
-        Function derives a new _inferred_intent by instantiating the intent specification on the new data
-        """
-        if (ldf is not None):
-            from lux.processor.Parser import Parser
-            from lux.processor.Validator import Validator
-            from lux.processor.Compiler import Compiler
-            from lux.executor.PandasExecutor import PandasExecutor #TODO: temporary (generalize to executor)
-            ldf.maintain_metadata()
-            self._source = ldf
-            self._inferred_intent = Parser.parse(self._intent)
-            Validator.validate_intent(self._inferred_intent,ldf)
-            vlist = Compiler.compile_vis(ldf,self)
-            ldf.executor.execute(vlist,ldf)
-            # Copying properties over since we can not redefine `self` within class function
-            if (len(vlist)>0):
-                vis = vlist[0]
-                self.title = vis.title
-                self._mark = vis._mark
-                self._inferred_intent = vis._inferred_intent
-                self._vis_data = vis.data
-                self._min_max = vis._min_max
+    """
+
+    def __init__(self, intent, source=None, title="", score=0.0):
+        self._intent = intent  # This is the user's original intent to Vis
+        self._inferred_intent = intent  # This is the re-written, expanded version of user's original intent (include inferred vis info)
+        self._source = source  # This is the original data that is attached to the Vis
+        self._vis_data = None  # This is the data that represents the Vis (e.g., selected, aggregated, binned)
+        self._code = None
+        self._mark = ""
+        self._min_max = {}
+        self._plot_config = None
+        self._postbin = None
+        self.title = title
+        self.score = score
+        self.refresh_source(self._source)
+
+    def __repr__(self):
+        if self._source is None:
+            return (
+                f"<Vis  ({str(self._intent)}) mark: {self._mark}, score: {self.score} >"
+            )
+        filter_intents = None
+        channels, additional_channels = [], []
+        for clause in self._inferred_intent:
+
+            if hasattr(clause, "value"):
+                if clause.value != "":
+                    filter_intents = clause
+            if hasattr(clause, "attribute"):
+                if clause.attribute != "":
+                    if clause.aggregation != "" and clause.aggregation is not None:
+                        attribute = (
+                            clause._aggregation_name.upper()
+                            + "("
+                            + clause.attribute
+                            + ")"
+                        )
+                    elif clause.bin_size > 0:
+                        attribute = "BIN(" + clause.attribute + ")"
+                    else:
+                        attribute = clause.attribute
+                    if clause.channel == "x":
+                        channels.insert(0, [clause.channel, attribute])
+                    elif clause.channel == "y":
+                        channels.insert(1, [clause.channel, attribute])
+                    elif clause.channel != "":
+                        additional_channels.append([clause.channel, attribute])
+
+        channels.extend(additional_channels)
+        str_channels = ""
+        for channel in channels:
+            str_channels += channel[0] + ": " + channel[1] + ", "
+
+        if filter_intents:
+            return f"<Vis  ({str_channels[:-2]} -- [{filter_intents.attribute}{filter_intents.filter_op}{filter_intents.value}]) mark: {self._mark}, score: {self.score} >"
+        else:
+            return (
+                f"<Vis  ({str_channels[:-2]}) mark: {self._mark}, score: {self.score} >"
+            )
+
+    @property
+    def data(self):
+        return self._vis_data
+
+    @property
+    def code(self):
+        return self._code
+
+    @property
+    def mark(self):
+        return self._mark
+
+    @property
+    def min_max(self):
+        return self._min_max
+
+    @property
+    def intent(self):
+        return self._intent
+
+    @intent.setter
+    def intent(self, intent: List[Clause]) -> None:
+        self.set_intent(intent)
+
+    def set_intent(self, intent: List[Clause]) -> None:
+        """
+        Sets the intent of the Vis and refreshes the source based on the new intent
+
+        Parameters
+        ----------
+        intent : List[Clause]
+            Query specifying the desired Vis
+        """
+        self._intent = intent
+        self.refresh_source(self._source)
+
+    @property
+    def plot_config(self):
+        return self._plot_config
+
+    @plot_config.setter
+    def plot_config(self, config_func: Callable):
+        """
+        Modify plot aesthetic settings of the Vis
+        Currently only supported for Altair visualizations
+
+        Parameters
+        ----------
+        config_func : typing.Callable
+            A function that takes in an AltairChart (https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) as input and returns an AltairChart as output
+        """
+        self._plot_config = config_func
+
+    def clear_plot_config(self):
+        self._plot_config = None
+
+    def _repr_html_(self):
+        from IPython.display import display
+
+        check_import_lux_widget()
+        import luxwidget
+
+        if self.data is None:
+            raise Exception(
+                "No data is populated in Vis. In order to generate data required for the vis, use the 'refresh_source' function to populate the Vis with a data source (e.g., vis.refresh_source(df))."
+ ) + else: + from lux.core.frame import LuxDataFrame + + widget = luxwidget.LuxWidget( + currentVis=LuxDataFrame.current_vis_to_JSON([self]), + recommendations=[], + intent="", + message="", + ) + display(widget) + + def get_attr_by_attr_name(self, attr_name): + return list(filter(lambda x: x.attribute == attr_name, self._inferred_intent)) + + def get_attr_by_channel(self, channel): + spec_obj = list( + filter( + lambda x: x.channel == channel and x.value == "" + if hasattr(x, "channel") + else False, + self._inferred_intent, + ) + ) + return spec_obj + + def get_attr_by_data_model(self, dmodel, exclude_record=False): + if exclude_record: + return list( + filter( + lambda x: x.data_model == dmodel and x.value == "" + if x.attribute != "Record" and hasattr(x, "data_model") + else False, + self._inferred_intent, + ) + ) + else: + return list( + filter( + lambda x: x.data_model == dmodel and x.value == "" + if hasattr(x, "data_model") + else False, + self._inferred_intent, + ) + ) + + def get_attr_by_data_type(self, dtype): + return list( + filter( + lambda x: x.data_type == dtype and x.value == "" + if hasattr(x, "data_type") + else False, + self._inferred_intent, + ) + ) + + def remove_filter_from_spec(self, value): + new_intent = list(filter(lambda x: x.value != value, self._inferred_intent)) + self.set_intent(new_intent) + + def remove_column_from_spec(self, attribute, remove_first: bool = False): + """ + Removes an attribute from the Vis's clause + + Parameters + ---------- + attribute : str + attribute to be removed + remove_first : bool, optional + Boolean flag to determine whether to remove all instances of the attribute or only one (first) instance, by default False + """ + if not remove_first: + new_inferred = list( + filter(lambda x: x.attribute != attribute, self._inferred_intent) + ) + self._inferred_intent = new_inferred + self._intent = new_inferred + elif remove_first: + new_inferred = [] + skip_check = False + for i in range(0, len(self._inferred_intent)): + if self._inferred_intent[i].value == "": # clause is type attribute + column_spec = [] + column_names = self._inferred_intent[i].attribute + # if only one variable in a column, columnName results in a string and not a list so + # you need to differentiate the cases + if isinstance(column_names, list): + for column in column_names: + if (column != attribute) or skip_check: + column_spec.append(column) + elif remove_first: + remove_first = True + new_inferred.append(Clause(column_spec)) + else: + if column_names != attribute or skip_check: + new_inferred.append(Clause(attribute=column_names)) + elif remove_first: + skip_check = True + else: + new_inferred.append(self._inferred_intent[i]) + self._intent = new_inferred + self._inferred_intent = new_inferred + + def to_Altair(self, standalone=False) -> str: + """ + Generate minimal Altair code to visualize the Vis + + Parameters + ---------- + standalone : bool, optional + Flag to determine if outputted code uses user-defined variable names or can be run independently, by default False + + Returns + ------- + str + String version of the Altair code. Need to print out the string to apply formatting. 
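+
+        Examples
+        --------
+        A minimal usage sketch (assuming `vis` has already been populated with a data source):
+
+        >>> print(vis.to_Altair())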
+ """ + from lux.vislib.altair.AltairRenderer import AltairRenderer + + renderer = AltairRenderer(output_type="Altair") + self._code = renderer.create_vis(self, standalone) + return self._code + + def to_VegaLite(self, prettyOutput=True) -> Union[dict, str]: + """ + Generate minimal Vega-Lite code to visualize the Vis + + Returns + ------- + Union[dict,str] + String or Dictionary of the VegaLite JSON specification + """ + import json + from lux.vislib.altair.AltairRenderer import AltairRenderer + + renderer = AltairRenderer(output_type="VegaLite") + self._code = renderer.create_vis(self) + if prettyOutput: + return ( + "** Remove this comment -- Copy Text Below to Vega Editor(vega.github.io/editor) to visualize and edit **\n" + + json.dumps(self._code, indent=2) + ) + else: + return self._code + + def render_VSpec(self, renderer="altair"): + if renderer == "altair": + return self.to_VegaLite(prettyOutput=False) + + def refresh_source(self, ldf): # -> Vis: + """ + Loading the source data into the Vis by instantiating the specification and + populating the Vis based on the source data, effectively "materializing" the Vis. + + Parameters + ---------- + ldf : LuxDataframe + Input Dataframe to be attached to the Vis + + Returns + ------- + Vis + Complete Vis with fully-specified fields + + See Also + -------- + lux.Vis.VisList.refresh_source + + Note + ---- + Function derives a new _inferred_intent by instantiating the intent specification on the new data + """ + if ldf is not None: + from lux.processor.Parser import Parser + from lux.processor.Validator import Validator + from lux.processor.Compiler import Compiler + from lux.executor.PandasExecutor import ( + PandasExecutor, + ) # TODO: temporary (generalize to executor) + + ldf.maintain_metadata() + self._source = ldf + self._inferred_intent = Parser.parse(self._intent) + Validator.validate_intent(self._inferred_intent, ldf) + vlist = Compiler.compile_vis(ldf, self) + ldf.executor.execute(vlist, ldf) + # Copying properties over since we can not redefine `self` within class function + if len(vlist) > 0: + vis = vlist[0] + self.title = vis.title + self._mark = vis._mark + self._inferred_intent = vis._inferred_intent + self._vis_data = vis.data + self._min_max = vis._min_max diff --git a/lux/vis/VisList.py b/lux/vis/VisList.py index 1aec2f62..86ccf1a1 100644 --- a/lux/vis/VisList.py +++ b/lux/vis/VisList.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,278 +19,331 @@ from lux.vis.Vis import Vis from lux.vis.Clause import Clause import warnings -class VisList(): - """VisList is a list of Vis objects. 
- """ - def __init__(self,input_lst:Union[List[Vis],List[Clause]],source=None): - # Overloaded Constructor - self._source = source - self._input_lst = input_lst - if len(input_lst)>0: - if (self._is_vis_input()): - self._collection = input_lst - self._intent = [] - else: - self._intent = input_lst - self._collection = [] - else: - self._collection = [] - self._intent = [] - self._widget = None - self.refresh_source(self._source) - @property - def intent(self): - return self._intent - @intent.setter - def intent(self, intent:List[Clause]) -> None: - self.set_intent(intent) - def set_intent(self, intent:List[Clause]) -> None: - """ - Sets the intent of the VisList and refresh the source based on the new clause - - Parameters - ---------- - intent : List[Clause] - Query specifying the desired VisList - """ - self._intent = intent - self.refresh_source(self._source) - @property - def exported(self) -> VisList: - """ - Get selected visualizations as exported Vis List - - Notes + + +class VisList: + """VisList is a list of Vis objects.""" + + def __init__(self, input_lst: Union[List[Vis], List[Clause]], source=None): + # Overloaded Constructor + self._source = source + self._input_lst = input_lst + if len(input_lst) > 0: + if self._is_vis_input(): + self._collection = input_lst + self._intent = [] + else: + self._intent = input_lst + self._collection = [] + else: + self._collection = [] + self._intent = [] + self._widget = None + self.refresh_source(self._source) + + @property + def intent(self): + return self._intent + + @intent.setter + def intent(self, intent: List[Clause]) -> None: + self.set_intent(intent) + + def set_intent(self, intent: List[Clause]) -> None: + """ + Sets the intent of the VisList and refresh the source based on the new clause + Parameters + ---------- + intent : List[Clause] + Query specifying the desired VisList + """ + self._intent = intent + self.refresh_source(self._source) + + @property + def exported(self) -> VisList: + """ + Get selected visualizations as exported Vis List + Notes ----- - Convert the _selectedVisIdxs dictionary into a programmable VisList - Example _selectedVisIdxs : - {'Vis List': [0, 2]} - - Returns - ------- - VisList - return a VisList of selected visualizations. -> VisList(v1, v2...) - """ - if not hasattr(self,"widget"): - warnings.warn( - "\nNo widget attached to the VisList." 
- "Please assign VisList to an output variable.\n" - "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips" - , stacklevel=2) - return [] - exported_vis_lst =self._widget._selectedVisIdxs - if (exported_vis_lst=={}): - warnings.warn( - "\nNo visualization selected to export.\n" - "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips" - ,stacklevel=2) - return [] - else: - exported_vis = VisList(list(map(self.__getitem__, exported_vis_lst["Vis List"]))) - return exported_vis - def remove_duplicates(self) -> None: - """ - Removes duplicate visualizations in Vis List - """ - self._collection = list(set(self._collection)) - - def remove_index(self, index): - self._collection.pop(index) - - def _is_vis_input(self): - if (type(self._input_lst[0])==Vis): - return True - elif (type(self._input_lst[0])==Clause): - return False - def __getitem__(self, key): - return self._collection[key] - def __setitem__(self, key, value): - self._collection[key] = value - def __len__(self): - return len(self._collection) - def __repr__(self): - if len(self._collection) == 0: - return str(self._input_lst) - x_channel = "" - y_channel = "" - largest_mark = 0 - largest_filter = 0 - for vis in self._collection: #finds longest x attribute among all visualizations - filter_intents = None - for clause in vis._inferred_intent: - if clause.value != "": - filter_intents = clause - - if (clause.aggregation != "" and clause.aggregation is not None): - attribute = clause._aggregation_name.upper() + "(" + clause.attribute + ")" - elif clause.bin_size > 0: - attribute = "BIN(" + clause.attribute + ")" - else: - attribute = clause.attribute - - if clause.channel == "x" and len(x_channel) < len(attribute): - x_channel = attribute - if clause.channel == "y" and len(y_channel) < len(attribute): - y_channel = attribute - if len(vis.mark) > largest_mark: - largest_mark = len(vis.mark) - if filter_intents and len(str(filter_intents.value)) + len(filter_intents.attribute) > largest_filter: - largest_filter = len(str(filter_intents.value)) + len(filter_intents.attribute) - vis_repr = [] - largest_x_length = len(x_channel) - largest_y_length = len(y_channel) - for vis in self._collection: #pads the shorter visualizations with spaces before the y attribute - filter_intents = None - x_channel = "" - y_channel = "" - additional_channels = [] - for clause in vis._inferred_intent: - if clause.value != "": - filter_intents = clause - - if (clause.aggregation != "" and clause.aggregation is not None and vis.mark!='scatter'): - attribute = clause._aggregation_name.upper() + "(" + clause.attribute + ")" - elif clause.bin_size > 0: - attribute = "BIN(" + clause.attribute + ")" - else: - attribute = clause.attribute - - if clause.channel == "x": - x_channel = attribute.ljust(largest_x_length) - elif clause.channel == "y": - y_channel = attribute - elif clause.channel != "": - additional_channels.append([clause.channel, attribute]) - if filter_intents: - y_channel = y_channel.ljust(largest_y_length) - elif largest_filter != 0: - y_channel = y_channel.ljust(largest_y_length + largest_filter + 9) - else: - y_channel = y_channel.ljust(largest_y_length + largest_filter) - if x_channel != "": - x_channel = "x: " + x_channel + ", " - if y_channel != "": - y_channel = "y: " + y_channel - aligned_mark = vis.mark.ljust(largest_mark) - str_additional_channels = "" - for channel in additional_channels: - str_additional_channels += ", " + channel[0] + ": " + channel[1] - if 
filter_intents: - aligned_filter = " -- [" + filter_intents.attribute + filter_intents.filter_op + str(filter_intents.value) + "]" - aligned_filter = aligned_filter.ljust(largest_filter + 8) - vis_repr.append(f" ") - else: - vis_repr.append(f" ") - return '['+',\n'.join(vis_repr)[1:]+']' - def map(self,function): - # generalized way of applying a function to each element - return map(function, self._collection) - - def get(self,field_name): - # Get the value of the field for all objects in the collection - def get_field(d_obj): - field_val = getattr(d_obj,field_name) - # Might want to write catch error if key not in field - return field_val - return self.map(get_field) - - def set(self,field_name,field_val): - return NotImplemented - def set_plot_config(self,config_func:Callable): - """ - Modify plot aesthetic settings to the Vis List - Currently only supported for Altair visualizations - - Parameters - ---------- - config_func : typing.Callable - A function that takes in an AltairChart (https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) as input and returns an AltairChart as output - """ - for vis in self._collection: - vis.plot_config = config_func - def clear_plot_config(self): - for vis in self._collection: - vis.plot_config = None - def sort(self, remove_invalid=True, descending = True): - # remove the items that have invalid (-1) score - if (remove_invalid): self._collection = list(filter(lambda x: x.score!=-1,self._collection)) - # sort in-place by “score” by default if available, otherwise user-specified field to sort by - self._collection.sort(key=lambda x: x.score, reverse=descending) - - def topK(self,k): - #sort and truncate list to first K items - self.sort(remove_invalid=True) - return VisList(self._collection[:k]) - def bottomK(self,k): - #sort and truncate list to first K items - self.sort(descending=False,remove_invalid=True) - return VisList(self._collection[:k]) - def normalize_score(self, invert_order = False): - max_score = max(list(self.get("score"))) - for dobj in self._collection: - dobj.score = dobj.score/max_score - if (invert_order): dobj.score = 1 - dobj.score - def _repr_html_(self): - self._widget = None - from IPython.display import display - from lux.core.frame import LuxDataFrame - recommendation = {"action": "Vis List", - "description": "Shows a vis list defined by the intent"} - recommendation["collection"] = self._collection - - check_import_lux_widget() - import luxwidget - recJSON = LuxDataFrame.rec_to_JSON([recommendation]) - self._widget = luxwidget.LuxWidget( - currentVis={}, - recommendations=recJSON, - intent="", - message = "" - ) - display(self._widget) - - def refresh_source(self, ldf) : - """ - Loading the source into the visualizations in the VisList, then populating each visualization - based on the new source data, effectively "materializing" the visualization collection. 
- - Parameters - ---------- - ldf : LuxDataframe - Input Dataframe to be attached to the VisList - - Returns - ------- - VisList - Complete VisList with fully-specified fields - - See Also - -------- - lux.vis.Vis.refresh_source - - Note - ---- - Function derives a new _inferred_intent by instantiating the intent specification on the new data - """ - if (ldf is not None): - from lux.processor.Parser import Parser - from lux.processor.Validator import Validator - from lux.processor.Compiler import Compiler - self._source = ldf - self._source.maintain_metadata() - if len(self._input_lst)>0: - if (self._is_vis_input()): - compiled_collection = [] - for vis in self._collection: - vis._inferred_intent = Parser.parse(vis._intent) - Validator.validate_intent(vis._inferred_intent,ldf) - vislist = Compiler.compile_vis(ldf,vis) - if (len(vislist)>0): - vis = vislist[0] - compiled_collection.append(vis) - self._collection = compiled_collection - else: - self._inferred_intent = Parser.parse(self._intent) - Validator.validate_intent(self._inferred_intent,ldf) - self._collection = Compiler.compile_intent(ldf,self._inferred_intent) - ldf.executor.execute(self._collection,ldf) + Convert the _selectedVisIdxs dictionary into a programmable VisList + Example _selectedVisIdxs : + {'Vis List': [0, 2]} + + Returns + ------- + VisList + return a VisList of selected visualizations. -> VisList(v1, v2...) + """ + if not hasattr(self, "widget"): + warnings.warn( + "\nNo widget attached to the VisList." + "Please assign VisList to an output variable.\n" + "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips", + stacklevel=2, + ) + return [] + exported_vis_lst = self._widget._selectedVisIdxs + if exported_vis_lst == {}: + warnings.warn( + "\nNo visualization selected to export.\n" + "See more: https://lux-api.readthedocs.io/en/latest/source/guide/FAQ.html#troubleshooting-tips", + stacklevel=2, + ) + return [] + else: + exported_vis = VisList( + list(map(self.__getitem__, exported_vis_lst["Vis List"])) + ) + return exported_vis + + def remove_duplicates(self) -> None: + """ + Removes duplicate visualizations in Vis List + """ + self._collection = list(set(self._collection)) + + def remove_index(self, index): + self._collection.pop(index) + + def _is_vis_input(self): + if type(self._input_lst[0]) == Vis: + return True + elif type(self._input_lst[0]) == Clause: + return False + + def __getitem__(self, key): + return self._collection[key] + + def __setitem__(self, key, value): + self._collection[key] = value + + def __len__(self): + return len(self._collection) + + def __repr__(self): + if len(self._collection) == 0: + return str(self._input_lst) + x_channel = "" + y_channel = "" + largest_mark = 0 + largest_filter = 0 + for ( + vis + ) in self._collection: # finds longest x attribute among all visualizations + filter_intents = None + for clause in vis._inferred_intent: + if clause.value != "": + filter_intents = clause + + if clause.aggregation != "" and clause.aggregation is not None: + attribute = ( + clause._aggregation_name.upper() + "(" + clause.attribute + ")" + ) + elif clause.bin_size > 0: + attribute = "BIN(" + clause.attribute + ")" + else: + attribute = clause.attribute + + if clause.channel == "x" and len(x_channel) < len(attribute): + x_channel = attribute + if clause.channel == "y" and len(y_channel) < len(attribute): + y_channel = attribute + if len(vis.mark) > largest_mark: + largest_mark = len(vis.mark) + if ( + filter_intents + and len(str(filter_intents.value)) 
+                + len(filter_intents.attribute)
+                > largest_filter
+            ):
+                largest_filter = len(str(filter_intents.value)) + len(
+                    filter_intents.attribute
+                )
+        vis_repr = []
+        largest_x_length = len(x_channel)
+        largest_y_length = len(y_channel)
+        for (
+            vis
+        ) in (
+            self._collection
+        ):  # pads the shorter visualizations with spaces before the y attribute
+            filter_intents = None
+            x_channel = ""
+            y_channel = ""
+            additional_channels = []
+            for clause in vis._inferred_intent:
+                if clause.value != "":
+                    filter_intents = clause
+
+                if (
+                    clause.aggregation != ""
+                    and clause.aggregation is not None
+                    and vis.mark != "scatter"
+                ):
+                    attribute = (
+                        clause._aggregation_name.upper() + "(" + clause.attribute + ")"
+                    )
+                elif clause.bin_size > 0:
+                    attribute = "BIN(" + clause.attribute + ")"
+                else:
+                    attribute = clause.attribute
+
+                if clause.channel == "x":
+                    x_channel = attribute.ljust(largest_x_length)
+                elif clause.channel == "y":
+                    y_channel = attribute
+                elif clause.channel != "":
+                    additional_channels.append([clause.channel, attribute])
+            if filter_intents:
+                y_channel = y_channel.ljust(largest_y_length)
+            elif largest_filter != 0:
+                y_channel = y_channel.ljust(largest_y_length + largest_filter + 9)
+            else:
+                y_channel = y_channel.ljust(largest_y_length + largest_filter)
+            if x_channel != "":
+                x_channel = "x: " + x_channel + ", "
+            if y_channel != "":
+                y_channel = "y: " + y_channel
+            aligned_mark = vis.mark.ljust(largest_mark)
+            str_additional_channels = ""
+            for channel in additional_channels:
+                str_additional_channels += ", " + channel[0] + ": " + channel[1]
+            if filter_intents:
+                aligned_filter = (
+                    " -- ["
+                    + filter_intents.attribute
+                    + filter_intents.filter_op
+                    + str(filter_intents.value)
+                    + "]"
+                )
+                aligned_filter = aligned_filter.ljust(largest_filter + 8)
+                vis_repr.append(
+                    f" <Vis  ({x_channel}{y_channel}{str_additional_channels} {aligned_filter}) mark: {aligned_mark}, score: {vis.score:.2f} >"
+                )
+            else:
+                vis_repr.append(
+                    f" <Vis  ({x_channel}{y_channel}{str_additional_channels}) mark: {aligned_mark}, score: {vis.score:.2f} >"
+                )
+        return "[" + ",\n".join(vis_repr)[1:] + "]"
+
+    def map(self, function):
+        # generalized way of applying a function to each element
+        return map(function, self._collection)
+
+    def get(self, field_name):
+        # Get the value of the field for all objects in the collection
+        def get_field(d_obj):
+            field_val = getattr(d_obj, field_name)
+            # Might want to write catch error if key not in field
+            return field_val
+
+        return self.map(get_field)
+
+    def set(self, field_name, field_val):
+        return NotImplemented
+
+    def set_plot_config(self, config_func: Callable):
+        """
+        Modify plot aesthetic settings of the Vis List
+        Currently only supported for Altair visualizations
+
+        Parameters
+        ----------
+        config_func : typing.Callable
+            A function that takes in an AltairChart (https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html) as input and returns an AltairChart as output
+        """
+        for vis in self._collection:
+            vis.plot_config = config_func
+
+    def clear_plot_config(self):
+        for vis in self._collection:
+            vis.plot_config = None
+
+    def sort(self, remove_invalid=True, descending=True):
+        # remove the items that have invalid (-1) score
+        if remove_invalid:
+            self._collection = list(filter(lambda x: x.score != -1, self._collection))
+        # sort in-place by “score” by default if available, otherwise user-specified field to sort by
+        self._collection.sort(key=lambda x: x.score, reverse=descending)
+
+    def topK(self, k):
+        # sort and truncate list to first K items
+        self.sort(remove_invalid=True)
+        return VisList(self._collection[:k])
+
+    def bottomK(self, k):
+        # sort and truncate list to first K items
+        self.sort(descending=False, remove_invalid=True)
+        return
VisList(self._collection[:k]) + + def normalize_score(self, invert_order=False): + max_score = max(list(self.get("score"))) + for dobj in self._collection: + dobj.score = dobj.score / max_score + if invert_order: + dobj.score = 1 - dobj.score + + def _repr_html_(self): + self._widget = None + from IPython.display import display + from lux.core.frame import LuxDataFrame + + recommendation = { + "action": "Vis List", + "description": "Shows a vis list defined by the intent", + } + recommendation["collection"] = self._collection + + check_import_lux_widget() + import luxwidget + + recJSON = LuxDataFrame.rec_to_JSON([recommendation]) + self._widget = luxwidget.LuxWidget( + currentVis={}, recommendations=recJSON, intent="", message="" + ) + display(self._widget) + + def refresh_source(self, ldf): + """ + Loading the source into the visualizations in the VisList, then populating each visualization + based on the new source data, effectively "materializing" the visualization collection. + Parameters + ---------- + ldf : LuxDataframe + Input Dataframe to be attached to the VisList + Returns + ------- + VisList + Complete VisList with fully-specified fields + + See Also + -------- + lux.vis.Vis.refresh_source + Note + ---- + Function derives a new _inferred_intent by instantiating the intent specification on the new data + """ + if ldf is not None: + from lux.processor.Parser import Parser + from lux.processor.Validator import Validator + from lux.processor.Compiler import Compiler + + self._source = ldf + self._source.maintain_metadata() + if len(self._input_lst) > 0: + if self._is_vis_input(): + compiled_collection = [] + for vis in self._collection: + vis._inferred_intent = Parser.parse(vis._intent) + Validator.validate_intent(vis._inferred_intent, ldf) + vislist = Compiler.compile_vis(ldf, vis) + if len(vislist) > 0: + vis = vislist[0] + compiled_collection.append(vis) + self._collection = compiled_collection + else: + self._inferred_intent = Parser.parse(self._intent) + Validator.validate_intent(self._inferred_intent, ldf) + self._collection = Compiler.compile_intent( + ldf, self._inferred_intent + ) + ldf.executor.execute(self._collection, ldf) diff --git a/lux/vis/__init__.py b/lux/vis/__init__.py index 33956083..0f58f612 100644 --- a/lux/vis/__init__.py +++ b/lux/vis/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .Clause import Clause \ No newline at end of file +from .Clause import Clause diff --git a/lux/vislib/__init__.py b/lux/vislib/__init__.py index cbfa9f5b..948becf5 100644 --- a/lux/vislib/__init__.py +++ b/lux/vislib/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
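Taken together, the reformatted Vis and VisList APIs above keep the same public workflow: build an intent from Clause objects or shorthand strings, attach a data source, then rank or export the results. A minimal sketch of that workflow, assuming `lux` is installed and a CSV with the columns used below exists (the file name and column names are illustrative, not part of this patch):

    import pandas as pd
    import lux  # registers Lux's DataFrame extensions
    from lux.vis.Vis import Vis
    from lux.vis.VisList import VisList

    df = pd.read_csv("cars.csv")  # hypothetical dataset

    # A single visualization: one attribute clause plus one filter clause
    vis = Vis(["Horsepower", "Origin=USA"], source=df)
    print(vis.to_Altair())  # prints the Altair code emitted by AltairRenderer

    # A collection of visualizations, one per Origin value ("?" enumerates values)
    vlist = VisList(["Horsepower", "Origin=?"], source=df)
    top3 = vlist.topK(3)  # drops invalid (-1) scores, sorts by score, keeps 3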
- diff --git a/lux/vislib/altair/AltairChart.py b/lux/vislib/altair/AltairChart.py index 0bf0c33e..09a01013 100644 --- a/lux/vislib/altair/AltairChart.py +++ b/lux/vislib/altair/AltairChart.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,70 +15,105 @@ import pandas as pd import altair as alt from lux.utils.date_utils import compute_date_granularity + + class AltairChart: - """ - AltairChart is a representation of a chart. - Common utilities for charts that is independent of chart types should go here. + """ + AltairChart is a representation of a chart. + Common utilities for charts that is independent of chart types should go here. + + See Also + -------- + altair-viz.github.io + + """ + + def __init__(self, vis): + self.vis = vis + self.data = vis.data + self.tooltip = True + # ----- START self.code modification ----- + self.code = "" + self.chart = self.initialize_chart() + # self.add_tooltip() + self.encode_color() + self.add_title() + self.apply_default_config() + + # ----- END self.code modification ----- + + def __repr__(self): + return f"AltairChart <{str(self.vis)}>" + + def add_tooltip(self): + if self.tooltip: + self.chart = self.chart.encode(tooltip=list(self.vis.data.columns)) - See Also - -------- - altair-viz.github.io + def apply_default_config(self): + self.chart = self.chart.configure_title( + fontWeight=500, fontSize=13, font="Helvetica Neue" + ) + self.chart = self.chart.configure_axis( + titleFontWeight=500, + titleFontSize=11, + titleFont="Helvetica Neue", + labelFontWeight=400, + labelFontSize=9, + labelFont="Helvetica Neue", + labelColor="#505050", + ) + self.chart = self.chart.configure_legend( + titleFontWeight=500, + titleFontSize=10, + titleFont="Helvetica Neue", + labelFontWeight=400, + labelFontSize=9, + labelFont="Helvetica Neue", + ) + self.chart = self.chart.properties(width=160, height=150) + self.code += "\nchart = chart.configure_title(fontWeight=500,fontSize=13,font='Helvetica Neue')\n" + self.code += "chart = chart.configure_axis(titleFontWeight=500,titleFontSize=11,titleFont='Helvetica Neue',\n" + self.code += " labelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue',labelColor='#505050')\n" + self.code += "chart = chart.configure_legend(titleFontWeight=500,titleFontSize=10,titleFont='Helvetica Neue',\n" + self.code += ( + " labelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue')\n" + ) + self.code += "chart = chart.properties(width=160,height=150)\n" - """ - def __init__(self, vis): - self.vis = vis - self.data = vis.data - self.tooltip = True - # ----- START self.code modification ----- - self.code = "" - self.chart = self.initialize_chart() - # self.add_tooltip() - self.encode_color() - self.add_title() - self.apply_default_config() + def encode_color(self): + color_attr = self.vis.get_attr_by_channel("color") + if len(color_attr) == 1: + color_attr_name = color_attr[0].attribute + color_attr_type = color_attr[0].data_type + if color_attr_type == "temporal": + timeUnit = compute_date_granularity(self.vis.data[color_attr_name]) + self.chart = self.chart.encode( + color=alt.Color( + color_attr_name, + type=color_attr_type, + timeUnit=timeUnit, + title=color_attr_name, + ) + ) + self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}',timeUnit='{timeUnit}',title='{color_attr_name}'))" + else: + 
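+                # Non-temporal color attribute (nominal/quantitative): encode directly, no timeUnit needed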
self.chart = self.chart.encode( + color=alt.Color(color_attr_name, type=color_attr_type) + ) + self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}'))\n" + elif len(color_attr) > 1: + raise ValueError( + "There should not be more than one attribute specified in the same channel." + ) - # ----- END self.code modification ----- - def __repr__(self): - return f"AltairChart <{str(self.vis)}>" - def add_tooltip(self): - if (self.tooltip): - self.chart = self.chart.encode(tooltip=list(self.vis.data.columns)) - def apply_default_config(self): - self.chart = self.chart.configure_title(fontWeight=500,fontSize=13,font="Helvetica Neue") - self.chart = self.chart.configure_axis(titleFontWeight=500,titleFontSize=11,titleFont="Helvetica Neue", - labelFontWeight=400,labelFontSize=9,labelFont="Helvetica Neue",labelColor="#505050") - self.chart = self.chart.configure_legend(titleFontWeight=500,titleFontSize=10,titleFont="Helvetica Neue", - labelFontWeight=400,labelFontSize=9,labelFont="Helvetica Neue") - self.chart = self.chart.properties(width=160,height=150) - self.code+= "\nchart = chart.configure_title(fontWeight=500,fontSize=13,font='Helvetica Neue')\n" - self.code+= "chart = chart.configure_axis(titleFontWeight=500,titleFontSize=11,titleFont='Helvetica Neue',\n" - self.code+= " labelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue',labelColor='#505050')\n" - self.code+= "chart = chart.configure_legend(titleFontWeight=500,titleFontSize=10,titleFont='Helvetica Neue',\n" - self.code+= " labelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue')\n" - self.code+= "chart = chart.properties(width=160,height=150)\n" + def add_title(self): + chart_title = self.vis.title + if chart_title: + self.chart = self.chart.encode().properties(title=chart_title) + if self.code != "": + self.code += ( + f"chart = chart.encode().properties(title = '{chart_title}')" + ) - def encode_color(self): - color_attr = self.vis.get_attr_by_channel("color") - if (len(color_attr)==1): - color_attr_name = color_attr[0].attribute - color_attr_type = color_attr[0].data_type - if (color_attr_type=="temporal"): - timeUnit = compute_date_granularity(self.vis.data[color_attr_name]) - self.chart = self.chart.encode(color=alt.Color(color_attr_name,type=color_attr_type,timeUnit=timeUnit,title=color_attr_name)) - self.code+=f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}',timeUnit='{timeUnit}',title='{color_attr_name}'))" - else: - self.chart = self.chart.encode(color=alt.Color(color_attr_name,type=color_attr_type)) - self.code+=f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}'))\n" - elif (len(color_attr)>1): - raise ValueError("There should not be more than one attribute specified in the same channel.") - - def add_title(self): - chart_title = self.vis.title - if chart_title: - self.chart = self.chart.encode().properties( - title = chart_title - ) - if (self.code!=""): - self.code+=f"chart = chart.encode().properties(title = '{chart_title}')" - def initialize_chart(self): - return NotImplemented + def initialize_chart(self): + return NotImplemented diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py index 0e9ebe54..2692f72e 100644 --- a/lux/vislib/altair/AltairRenderer.py +++ b/lux/vislib/altair/AltairRenderer.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. 
-# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,84 +21,111 @@ from lux.vislib.altair.Histogram import Histogram from lux.vislib.altair.Heatmap import Heatmap + class AltairRenderer: - """ - Renderer for Charts based on Altair (https://altair-viz.github.io/) - """ - def __init__(self,output_type="VegaLite"): - self.output_type = output_type - def __repr__(self): - return f"AltairRenderer" - def create_vis(self,vis, standalone=True): - """ - Input DataObject and return a visualization specification - - Parameters - ---------- - vis: lux.vis.Vis - Input Vis (with data) - standalone: bool - Flag to determine if outputted code uses user-defined variable names or can be run independently - Returns - ------- - chart : altair.Chart - Output Altair Chart Object - """ - # Lazy Evaluation for 2D Binning - if (vis.mark == "scatter" and vis._postbin): - vis._mark = "heatmap" - from lux.executor.PandasExecutor import PandasExecutor - PandasExecutor.execute_2D_binning(vis) - # If a column has a Period dtype, or contains Period objects, convert it back to Datetime - if vis.data is not None: - for attr in list(vis.data.columns): - if pd.api.types.is_period_dtype(vis.data.dtypes[attr]) or isinstance(vis.data[attr].iloc[0], pd.Period): - dateColumn = vis.data[attr] - vis.data[attr] = pd.PeriodIndex(dateColumn.values).to_timestamp() - if pd.api.types.is_interval_dtype(vis.data.dtypes[attr]) or isinstance(vis.data[attr].iloc[0], pd.Interval): - vis.data[attr] = vis.data[attr].astype(str) - if (vis.mark =="histogram"): - chart = Histogram(vis) - elif (vis.mark =="bar"): - chart = BarChart(vis) - elif (vis.mark =="scatter"): - chart = ScatterChart(vis) - elif (vis.mark =="line"): - chart = LineChart(vis) - elif (vis.mark =="heatmap"): - chart = Heatmap(vis) - else: - chart = None + """ + Renderer for Charts based on Altair (https://altair-viz.github.io/) + """ + + def __init__(self, output_type="VegaLite"): + self.output_type = output_type + + def __repr__(self): + return f"AltairRenderer" + + def create_vis(self, vis, standalone=True): + """ + Input DataObject and return a visualization specification + + Parameters + ---------- + vis: lux.vis.Vis + Input Vis (with data) + standalone: bool + Flag to determine if outputted code uses user-defined variable names or can be run independently + Returns + ------- + chart : altair.Chart + Output Altair Chart Object + """ + # Lazy Evaluation for 2D Binning + if vis.mark == "scatter" and vis._postbin: + vis._mark = "heatmap" + from lux.executor.PandasExecutor import PandasExecutor + + PandasExecutor.execute_2D_binning(vis) + # If a column has a Period dtype, or contains Period objects, convert it back to Datetime + if vis.data is not None: + for attr in list(vis.data.columns): + if pd.api.types.is_period_dtype(vis.data.dtypes[attr]) or isinstance( + vis.data[attr].iloc[0], pd.Period + ): + dateColumn = vis.data[attr] + vis.data[attr] = pd.PeriodIndex(dateColumn.values).to_timestamp() + if pd.api.types.is_interval_dtype(vis.data.dtypes[attr]) or isinstance( + vis.data[attr].iloc[0], pd.Interval + ): + vis.data[attr] = vis.data[attr].astype(str) + if vis.mark == "histogram": + chart = Histogram(vis) + elif vis.mark == "bar": + chart = BarChart(vis) + elif vis.mark == "scatter": + chart = ScatterChart(vis) + elif vis.mark == "line": + chart = LineChart(vis) + elif vis.mark == "heatmap": + chart = Heatmap(vis) + else: + chart = None + + 
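+        # Each supported mark has been mapped to its chart class above; unrecognized
+        # marks fall through with chart = None and produce no specification below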
if chart: + if vis.plot_config: + chart.chart = vis.plot_config(chart.chart) + if self.output_type == "VegaLite": + chart_dict = chart.chart.to_dict() + # this is a bit of a work around because altair must take a pandas dataframe and we can only generate a luxDataFrame + # chart["data"] = { "values": vis.data.to_dict(orient='records') } + # chart_dict["width"] = 160 + # chart_dict["height"] = 150 + return chart_dict + elif self.output_type == "Altair": + import inspect - if (chart): - if (vis.plot_config): chart.chart = vis.plot_config(chart.chart) - if (self.output_type=="VegaLite"): - chart_dict = chart.chart.to_dict() - # this is a bit of a work around because altair must take a pandas dataframe and we can only generate a luxDataFrame - # chart["data"] = { "values": vis.data.to_dict(orient='records') } - # chart_dict["width"] = 160 - # chart_dict["height"] = 150 - return chart_dict - elif (self.output_type=="Altair"): - import inspect - if (vis.plot_config): chart.code +='\n'.join(inspect.getsource(vis.plot_config).split('\n ')[1:-1]) - chart.code +="\nchart" - chart.code = chart.code.replace('\n\t\t','\n') + if vis.plot_config: + chart.code += "\n".join( + inspect.getsource(vis.plot_config).split("\n ")[1:-1] + ) + chart.code += "\nchart" + chart.code = chart.code.replace("\n\t\t", "\n") - var = vis._source - if var is not None: - all_vars = [] - for f_info in inspect.getouterframes(inspect.currentframe()): - local_vars = f_info.frame.f_back - if local_vars: - callers_local_vars = local_vars.f_locals.items() - possible_vars = [var_name for var_name, var_val in callers_local_vars if var_val is var] - all_vars.extend(possible_vars) - found_variable = [possible_var for possible_var in all_vars if possible_var[0] != '_'][0] - else: # if vis._source was not set when the Vis was created - found_variable = "df" - if standalone: - chart.code = chart.code.replace("placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})") - else: - chart.code = chart.code.replace("placeholder_variable", found_variable) # TODO: Placeholder (need to read dynamically via locals()) - return chart.code \ No newline at end of file + var = vis._source + if var is not None: + all_vars = [] + for f_info in inspect.getouterframes(inspect.currentframe()): + local_vars = f_info.frame.f_back + if local_vars: + callers_local_vars = local_vars.f_locals.items() + possible_vars = [ + var_name + for var_name, var_val in callers_local_vars + if var_val is var + ] + all_vars.extend(possible_vars) + found_variable = [ + possible_var + for possible_var in all_vars + if possible_var[0] != "_" + ][0] + else: # if vis._source was not set when the Vis was created + found_variable = "df" + if standalone: + chart.code = chart.code.replace( + "placeholder_variable", + f"pd.DataFrame({str(vis.data.to_dict())})", + ) + else: + chart.code = chart.code.replace( + "placeholder_variable", found_variable + ) # TODO: Placeholder (need to read dynamically via locals()) + return chart.code diff --git a/lux/vislib/altair/BarChart.py b/lux/vislib/altair/BarChart.py index 8cbe9ab7..561a23d4 100644 --- a/lux/vislib/altair/BarChart.py +++ b/lux/vislib/altair/BarChart.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
@@ -14,66 +14,83 @@
 from lux.vislib.altair.AltairChart import AltairChart
 import altair as alt
+
 alt.data_transformers.disable_max_rows()
-from lux.utils.utils import get_agg_title
+from lux.utils.utils import get_agg_title
+
+
 class BarChart(AltairChart):
-    """
-    BarChart is a subclass of AltairChart that render as a bar charts.
-    All rendering properties for bar charts are set here.
+    """
+    BarChart is a subclass of AltairChart that renders as a bar chart.
+    All rendering properties for bar charts are set here.
+
+    See Also
+    --------
+    altair-viz.github.io
+    """
+
+    def __init__(self, dobj):
+        super().__init__(dobj)

-    See Also
-    --------
-    altair-viz.github.io
-    """
+    def __repr__(self):
+        return f"Bar Chart <{str(self.vis)}>"

-    def __init__(self,dobj):
-        super().__init__(dobj)
-    def __repr__(self):
-        return f"Bar Chart <{str(self.vis)}>"
-    def initialize_chart(self):
-        self.tooltip = False
-        x_attr = self.vis.get_attr_by_channel("x")[0]
-        y_attr = self.vis.get_attr_by_channel("y")[0]
-
-        if (x_attr.data_model == "measure"):
-            agg_title = get_agg_title(x_attr)
-            measure_attr = x_attr.attribute
-            bar_attr = y_attr.attribute
-            y_attr_field = alt.Y(y_attr.attribute, type= y_attr.data_type, axis=alt.Axis(labelOverlap=True))
-            x_attr_field = alt.X(x_attr.attribute, type= x_attr.data_type, title=agg_title)
-            y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(labelOverlap=True))"
-            x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', title='{agg_title}')"
+    def initialize_chart(self):
+        self.tooltip = False
+        x_attr = self.vis.get_attr_by_channel("x")[0]
+        y_attr = self.vis.get_attr_by_channel("y")[0]

-        if (y_attr.sort=="ascending"):
-            y_attr_field.sort="-x"
-            y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(labelOverlap=True), sort ='-x')"
-        else:
-            agg_title = get_agg_title(y_attr)
-            measure_attr = y_attr.attribute
-            bar_attr = x_attr.attribute
-            x_attr_field = alt.X(x_attr.attribute, type = x_attr.data_type,axis=alt.Axis(labelOverlap=True))
-            y_attr_field = alt.Y(y_attr.attribute,type=y_attr.data_type,title=agg_title)
-            x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', axis=alt.Axis(labelOverlap=True))"
-            y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', title='{agg_title}')"
-            if (x_attr.sort=="ascending"):
-                x_attr_field.sort="-y"
-                x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', axis=alt.Axis(labelOverlap=True),sort='-y')"
-        k=10
-        self._topkcode = ""
-        n_bars = len(self.data[bar_attr].unique())
-        if n_bars>k: # Truncating to only top k
-            remaining_bars = n_bars-k
-            self.data = self.data.nlargest(k,measure_attr)
-            self.text = alt.Chart(self.data).mark_text(
-                x=155,
-                y=142,
-                align="right",
-                color = "#ff8e04",
-                fontSize = 11,
-                text=f"+ {remaining_bars} more ..."
- ) + if x_attr.data_model == "measure": + agg_title = get_agg_title(x_attr) + measure_attr = x_attr.attribute + bar_attr = y_attr.attribute + y_attr_field = alt.Y( + y_attr.attribute, + type=y_attr.data_type, + axis=alt.Axis(labelOverlap=True), + ) + x_attr_field = alt.X( + x_attr.attribute, type=x_attr.data_type, title=agg_title + ) + y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(labelOverlap=True))" + x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', title='{agg_title}')" - self._topkcode = f'''text = alt.Chart(visData).mark_text( + if y_attr.sort == "ascending": + y_attr_field.sort = "-x" + y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(labelOverlap=True), sort ='-x')" + else: + agg_title = get_agg_title(y_attr) + measure_attr = y_attr.attribute + bar_attr = x_attr.attribute + x_attr_field = alt.X( + x_attr.attribute, + type=x_attr.data_type, + axis=alt.Axis(labelOverlap=True), + ) + y_attr_field = alt.Y( + y_attr.attribute, type=y_attr.data_type, title=agg_title + ) + x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', axis=alt.Axis(labelOverlap=True))" + y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', title='{agg_title}')" + if x_attr.sort == "ascending": + x_attr_field.sort = "-y" + x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', axis=alt.Axis(labelOverlap=True),sort='-y')" + k = 10 + self._topkcode = "" + n_bars = len(self.data[bar_attr].unique()) + if n_bars > k: # Truncating to only top k + remaining_bars = n_bars - k + self.data = self.data.nlargest(k, measure_attr) + self.text = alt.Chart(self.data).mark_text( + x=155, + y=142, + align="right", + color="#ff8e04", + fontSize=11, + text=f"+ {remaining_bars} more ...", + ) + + self._topkcode = f"""text = alt.Chart(visData).mark_text( x=155, y=142, align="right", @@ -81,33 +98,36 @@ def initialize_chart(self): fontSize = 11, text=f"+ {remaining_bars} more ..." 
) - chart = chart + text\n''' - - chart = alt.Chart(self.data).mark_bar().encode( - y = y_attr_field, - x = x_attr_field - ) - # TODO: tooltip messes up the count() bar charts - # Can not do interactive whenever you have default count measure otherwise output strange error (Javascript Error: Cannot read property 'length' of undefined) - #chart = chart.interactive() # If you want to enable Zooming and Panning - - self.code += "import altair as alt\n" - # self.code += f"visData = pd.DataFrame({str(self.data.to_dict(orient='records'))})\n" - self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" - self.code += f''' + chart = chart + text\n""" + + chart = alt.Chart(self.data).mark_bar().encode(y=y_attr_field, x=x_attr_field) + # TODO: tooltip messes up the count() bar charts + # Can not do interactive whenever you have default count measure otherwise output strange error (Javascript Error: Cannot read property 'length' of undefined) + # chart = chart.interactive() # If you want to enable Zooming and Panning + + self.code += "import altair as alt\n" + # self.code += f"visData = pd.DataFrame({str(self.data.to_dict(orient='records'))})\n" + self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" + self.code += f""" chart = alt.Chart(visData).mark_bar().encode( y = {y_attr_field_code}, x = {x_attr_field_code}, - )\n''' - return chart - - def add_text(self): - if (self._topkcode!=""): - self.chart = self.chart + self.text - self.code += self._topkcode - - def encode_color(self): # override encode_color in AltairChart to enforce add_text occurs afterwards - AltairChart.encode_color(self) - self.add_text() - self.chart = self.chart.configure_mark(tooltip=alt.TooltipContent('encoding')) # Setting tooltip as non-null - self.code += f'''chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding'))''' \ No newline at end of file + )\n""" + return chart + + def add_text(self): + if self._topkcode != "": + self.chart = self.chart + self.text + self.code += self._topkcode + + def encode_color( + self, + ): # override encode_color in AltairChart to enforce add_text occurs afterwards + AltairChart.encode_color(self) + self.add_text() + self.chart = self.chart.configure_mark( + tooltip=alt.TooltipContent("encoding") + ) # Setting tooltip as non-null + self.code += ( + f"""chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding'))""" + ) diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py index a5d3c106..56ae7276 100644 --- a/lux/vislib/altair/Heatmap.py +++ b/lux/vislib/altair/Heatmap.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,45 +13,72 @@ # limitations under the License. from lux.vislib.altair.AltairChart import AltairChart -import altair as alt +import altair as alt + alt.data_transformers.disable_max_rows() + + class Heatmap(AltairChart): - """ - Heatmap is a subclass of AltairChart that render as a heatmap. - All rendering properties for heatmap are set here. + """ + Heatmap is a subclass of AltairChart that render as a heatmap. + All rendering properties for heatmap are set here. 
- See Also - -------- - altair-viz.github.io - """ - def __init__(self,vis): - super().__init__(vis) - def __repr__(self): - return f"Heatmap <{str(self.vis)}>" - def initialize_chart(self): - # return NotImplemented - x_attr = self.vis.get_attr_by_channel("x")[0] - y_attr = self.vis.get_attr_by_channel("y")[0] + See Also + -------- + altair-viz.github.io + """ - chart = alt.Chart(self.data).mark_rect().encode( - x=alt.X('xBinStart', type='quantitative', axis=alt.Axis(title=x_attr.attribute), bin = alt.BinParams(binned=True)), - x2=alt.X2('xBinEnd'), - y=alt.Y('yBinStart', type='quantitative', axis=alt.Axis(title=y_attr.attribute), bin = alt.BinParams(binned=True)), - y2=alt.Y2('yBinEnd'), - opacity = alt.Opacity('count',type='quantitative',scale=alt.Scale(type="log"),legend=None) - ) - chart = chart.configure_scale(minOpacity=0.1,maxOpacity=1) - chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding')) # Setting tooltip as non-null - chart = chart.interactive() # Enable Zooming and Panning + def __init__(self, vis): + super().__init__(vis) + + def __repr__(self): + return f"Heatmap <{str(self.vis)}>" + + def initialize_chart(self): + # return NotImplemented + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + + chart = ( + alt.Chart(self.data) + .mark_rect() + .encode( + x=alt.X( + "xBinStart", + type="quantitative", + axis=alt.Axis(title=x_attr.attribute), + bin=alt.BinParams(binned=True), + ), + x2=alt.X2("xBinEnd"), + y=alt.Y( + "yBinStart", + type="quantitative", + axis=alt.Axis(title=y_attr.attribute), + bin=alt.BinParams(binned=True), + ), + y2=alt.Y2("yBinEnd"), + opacity=alt.Opacity( + "count", + type="quantitative", + scale=alt.Scale(type="log"), + legend=None, + ), + ) + ) + chart = chart.configure_scale(minOpacity=0.1, maxOpacity=1) + chart = chart.configure_mark( + tooltip=alt.TooltipContent("encoding") + ) # Setting tooltip as non-null + chart = chart.interactive() # Enable Zooming and Panning + + #################################### + # Constructing Altair Code String ## + #################################### - #################################### - # Constructing Altair Code String ## - #################################### - - self.code += "import altair as alt\n" - # self.code += f"visData = pd.DataFrame({str(self.data.to_dict(orient='records'))})\n" - self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" - self.code += f''' + self.code += "import altair as alt\n" + # self.code += f"visData = pd.DataFrame({str(self.data.to_dict(orient='records'))})\n" + self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" + self.code += f""" chart = alt.Chart(visData).mark_rect().encode( x=alt.X('xBinStart', type='quantitative', axis=alt.Axis(title='{x_attr.attribute}'), bin = alt.BinParams(binned=True)), x2=alt.X2('xBinEnd'), @@ -60,5 +87,5 @@ def initialize_chart(self): opacity = alt.Opacity('count',type='quantitative',scale=alt.Scale(type="log"),legend=None) ) chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding')) # Setting tooltip as non-null - ''' - return chart \ No newline at end of file + """ + return chart diff --git a/lux/vislib/altair/Histogram.py b/lux/vislib/altair/Histogram.py index f6113ba9..b9d1da4a 100644 --- a/lux/vislib/altair/Histogram.py +++ b/lux/vislib/altair/Histogram.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
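# For reference, a self-contained sketch of the pre-binned rect encoding used by
# the Heatmap above; the xBinStart/xBinEnd columns follow the convention produced
# by PandasExecutor.execute_2D_binning (toy data, illustrative only):
import altair as alt
import pandas as pd

bins = pd.DataFrame(
    {
        "xBinStart": [0.0, 1.0],
        "xBinEnd": [1.0, 2.0],
        "yBinStart": [0.0, 0.0],
        "yBinEnd": [1.0, 1.0],
        "count": [5, 50],
    }
)
heat = (
    alt.Chart(bins)
    .mark_rect()
    .encode(
        x=alt.X("xBinStart", type="quantitative", bin=alt.BinParams(binned=True)),
        x2=alt.X2("xBinEnd"),
        y=alt.Y("yBinStart", type="quantitative", bin=alt.BinParams(binned=True)),
        y2=alt.Y2("yBinEnd"),
        # log-scaled opacity keeps sparse cells visible next to dense ones
        opacity=alt.Opacity("count", type="quantitative", scale=alt.Scale(type="log")),
    )
)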
# You may obtain a copy of the License at @@ -14,61 +14,90 @@ from lux.vislib.altair.AltairChart import AltairChart import altair as alt + alt.data_transformers.disable_max_rows() + + class Histogram(AltairChart): - """ - Histogram is a subclass of AltairChart that render as a histograms. - All rendering properties for histograms are set here. + """ + Histogram is a subclass of AltairChart that render as a histograms. + All rendering properties for histograms are set here. + + See Also + -------- + altair-viz.github.io + """ + + def __init__(self, vis): + super().__init__(vis) + + def __repr__(self): + return f"Histogram <{str(self.vis)}>" + + def initialize_chart(self): + self.tooltip = False + measure = self.vis.get_attr_by_data_model("measure", exclude_record=True)[0] + msr_attr = self.vis.get_attr_by_channel(measure.channel)[0] + x_min = self.vis.min_max[msr_attr.attribute][0] + x_max = self.vis.min_max[msr_attr.attribute][1] - See Also - -------- - altair-viz.github.io - """ - def __init__(self,vis): - super().__init__(vis) - def __repr__(self): - return f"Histogram <{str(self.vis)}>" - def initialize_chart(self): - self.tooltip = False - measure = self.vis.get_attr_by_data_model("measure",exclude_record=True)[0] - msr_attr = self.vis.get_attr_by_channel(measure.channel)[0] - x_min = self.vis.min_max[msr_attr.attribute][0] - x_max = self.vis.min_max[msr_attr.attribute][1] + x_range = abs( + max(self.vis.data[msr_attr.attribute]) + - min(self.vis.data[msr_attr.attribute]) + ) + plot_range = abs(x_max - x_min) + markbar = x_range / plot_range * 12 - x_range = abs(max(self.vis.data[msr_attr.attribute]) - - min(self.vis.data[msr_attr.attribute])) - plot_range = abs(x_max - x_min) - markbar = x_range / plot_range * 12 + if measure.channel == "x": + chart = ( + alt.Chart(self.data) + .mark_bar(size=markbar) + .encode( + alt.X( + msr_attr.attribute, + title=f"{msr_attr.attribute} (binned)", + bin=alt.Bin(binned=True), + type=msr_attr.data_type, + axis=alt.Axis(labelOverlap=True), + scale=alt.Scale(domain=(x_min, x_max)), + ), + alt.Y("Number of Records", type="quantitative"), + ) + ) + elif measure.channel == "y": + chart = ( + alt.Chart(self.data) + .mark_bar(size=markbar) + .encode( + x=alt.X("Number of Records", type="quantitative"), + y=alt.Y( + msr_attr.attribute, + title=f"{msr_attr.attribute} (binned)", + bin=alt.Bin(binned=True), + axis=alt.Axis(labelOverlap=True), + scale=alt.Scale(domain=(x_min, x_max)), + ), + ) + ) + ##################################### + ## Constructing Altair Code String ## + ##################################### - if (measure.channel=="x"): - chart = alt.Chart(self.data).mark_bar(size=markbar).encode( - alt.X(msr_attr.attribute, title=f'{msr_attr.attribute} (binned)',bin=alt.Bin(binned=True), type=msr_attr.data_type, axis=alt.Axis(labelOverlap=True), scale=alt.Scale(domain=(x_min, x_max))), - alt.Y("Number of Records", type="quantitative") - ) - elif (measure.channel=="y"): - chart = alt.Chart(self.data).mark_bar(size=markbar).encode( - x = alt.X("Number of Records", type="quantitative"), - y = alt.Y(msr_attr.attribute, title=f'{msr_attr.attribute} (binned)', bin=alt.Bin(binned=True), axis=alt.Axis(labelOverlap=True), scale=alt.Scale(domain=(x_min, x_max))) - ) - ##################################### - ## Constructing Altair Code String ## - ##################################### - - self.code += "import altair as alt\n" - # self.code += f"visData = pd.DataFrame({str(self.data.to_dict(orient='records'))})\n" - self.code += f"visData = 
pd.DataFrame({str(self.data.to_dict())})\n" - if (measure.channel=="x"): - self.code += f''' + self.code += "import altair as alt\n" + # self.code += f"visData = pd.DataFrame({str(self.data.to_dict(orient='records'))})\n" + self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" + if measure.channel == "x": + self.code += f""" chart = alt.Chart(visData).mark_bar(size={markbar}).encode( alt.X('{msr_attr.attribute}', title='{msr_attr.attribute} (binned)',bin=alt.Bin(binned=True), type='{msr_attr.data_type}', axis=alt.Axis(labelOverlap=True), scale=alt.Scale(domain=({x_min}, {x_max}))), alt.Y("Number of Records", type="quantitative") ) - ''' - elif (measure.channel=="y"): - self.code += f''' + """ + elif measure.channel == "y": + self.code += f""" chart = alt.Chart(visData).mark_bar(size={markbar}).encode( alt.Y('{msr_attr.attribute}', title='{msr_attr.attribute} (binned)',bin=alt.Bin(binned=True), type='{msr_attr.data_type}', axis=alt.Axis(labelOverlap=True), scale=alt.Scale(domain=({x_min}, {x_max}))), alt.X("Number of Records", type="quantitative") ) - ''' - return chart \ No newline at end of file + """ + return chart diff --git a/lux/vislib/altair/LineChart.py b/lux/vislib/altair/LineChart.py index fc94a1e7..1e01eabf 100644 --- a/lux/vislib/altair/LineChart.py +++ b/lux/vislib/altair/LineChart.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,56 +14,65 @@ from lux.vislib.altair.AltairChart import AltairChart import altair as alt + alt.data_transformers.disable_max_rows() -from lux.utils.utils import get_agg_title +from lux.utils.utils import get_agg_title + + class LineChart(AltairChart): - """ - LineChart is a subclass of AltairChart that render as a line charts. - All rendering properties for line charts are set here. + """ + LineChart is a subclass of AltairChart that render as a line charts. + All rendering properties for line charts are set here. 
- See Also - -------- - altair-viz.github.io - """ - def __init__(self,dobj): - super().__init__(dobj) - def __repr__(self): - return f"Line Chart <{str(self.vis)}>" - def initialize_chart(self): - self.tooltip = False # tooltip looks weird for line chart - x_attr = self.vis.get_attr_by_channel("x")[0] - y_attr = self.vis.get_attr_by_channel("y")[0] + See Also + -------- + altair-viz.github.io + """ + def __init__(self, dobj): + super().__init__(dobj) - self.code += "import altair as alt\n" - self.code += "import pandas._libs.tslibs.timestamps\n" - self.code += "from pandas._libs.tslibs.timestamps import Timestamp\n" - self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" - - if (y_attr.data_model == "measure"): - agg_title = get_agg_title(y_attr) - x_attr_spec = alt.X(x_attr.attribute, type = x_attr.data_type) - y_attr_spec = alt.Y(y_attr.attribute, type= y_attr.data_type, title=agg_title) - x_attr_field_code = f"alt.X('{x_attr.attribute}', type = '{x_attr.data_type}')" - y_attr_fieldCode = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', title='{agg_title}')" - else: - agg_title = get_agg_title(x_attr) - x_attr_spec = alt.X(x_attr.attribute,type= x_attr.data_type, title=agg_title) - y_attr_spec = alt.Y(y_attr.attribute, type = y_attr.data_type) - x_attr_field_code = f"alt.X('{x_attr.attribute}', type = '{x_attr.data_type}', title='{agg_title}')" - y_attr_fieldCode = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}')" + def __repr__(self): + return f"Line Chart <{str(self.vis)}>" - chart = alt.Chart(self.data).mark_line().encode( - x = x_attr_spec, - y = y_attr_spec - ) - chart = chart.interactive() # Enable Zooming and Panning - self.code += f''' + def initialize_chart(self): + self.tooltip = False # tooltip looks weird for line chart + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + + self.code += "import altair as alt\n" + self.code += "import pandas._libs.tslibs.timestamps\n" + self.code += "from pandas._libs.tslibs.timestamps import Timestamp\n" + self.code += f"visData = pd.DataFrame({str(self.data.to_dict())})\n" + + if y_attr.data_model == "measure": + agg_title = get_agg_title(y_attr) + x_attr_spec = alt.X(x_attr.attribute, type=x_attr.data_type) + y_attr_spec = alt.Y( + y_attr.attribute, type=y_attr.data_type, title=agg_title + ) + x_attr_field_code = ( + f"alt.X('{x_attr.attribute}', type = '{x_attr.data_type}')" + ) + y_attr_fieldCode = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', title='{agg_title}')" + else: + agg_title = get_agg_title(x_attr) + x_attr_spec = alt.X( + x_attr.attribute, type=x_attr.data_type, title=agg_title + ) + y_attr_spec = alt.Y(y_attr.attribute, type=y_attr.data_type) + x_attr_field_code = f"alt.X('{x_attr.attribute}', type = '{x_attr.data_type}', title='{agg_title}')" + y_attr_fieldCode = ( + f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}')" + ) + + chart = alt.Chart(self.data).mark_line().encode(x=x_attr_spec, y=y_attr_spec) + chart = chart.interactive() # Enable Zooming and Panning + self.code += f""" chart = alt.Chart(visData).mark_line().encode( y = {y_attr_fieldCode}, x = {x_attr_field_code}, ) chart = chart.interactive() # Enable Zooming and Panning - ''' - return chart - \ No newline at end of file + """ + return chart diff --git a/lux/vislib/altair/ScatterChart.py b/lux/vislib/altair/ScatterChart.py index d9370b23..a6463041 100644 --- a/lux/vislib/altair/ScatterChart.py +++ b/lux/vislib/altair/ScatterChart.py @@ -1,5 +1,5 @@ # Copyright 
2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,49 +13,69 @@ # limitations under the License. from lux.vislib.altair.AltairChart import AltairChart -import altair as alt +import altair as alt + alt.data_transformers.disable_max_rows() + + class ScatterChart(AltairChart): - """ - ScatterChart is a subclass of AltairChart that render as a scatter charts. - All rendering properties for scatter charts are set here. - - See Also - -------- - altair-viz.github.io - """ - def __init__(self,vis): - super().__init__(vis) - def __repr__(self): - return f"ScatterChart <{str(self.vis)}>" - def initialize_chart(self): - x_attr = self.vis.get_attr_by_channel("x")[0] - y_attr = self.vis.get_attr_by_channel("y")[0] - x_min = self.vis.min_max[x_attr.attribute][0] - x_max = self.vis.min_max[x_attr.attribute][1] - - y_min = self.vis.min_max[y_attr.attribute][0] - y_max = self.vis.min_max[y_attr.attribute][1] - - chart = alt.Chart(self.data).mark_circle().encode( - x=alt.X(x_attr.attribute,scale=alt.Scale(domain=(x_min, x_max)),type=x_attr.data_type), - y=alt.Y(y_attr.attribute,scale=alt.Scale(domain=(y_min, y_max)),type=y_attr.data_type) - ) - chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding')) # Setting tooltip as non-null - chart = chart.interactive() # Enable Zooming and Panning + """ + ScatterChart is a subclass of AltairChart that render as a scatter charts. + All rendering properties for scatter charts are set here. + + See Also + -------- + altair-viz.github.io + """ + + def __init__(self, vis): + super().__init__(vis) + + def __repr__(self): + return f"ScatterChart <{str(self.vis)}>" + + def initialize_chart(self): + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + x_min = self.vis.min_max[x_attr.attribute][0] + x_max = self.vis.min_max[x_attr.attribute][1] + + y_min = self.vis.min_max[y_attr.attribute][0] + y_max = self.vis.min_max[y_attr.attribute][1] + + chart = ( + alt.Chart(self.data) + .mark_circle() + .encode( + x=alt.X( + x_attr.attribute, + scale=alt.Scale(domain=(x_min, x_max)), + type=x_attr.data_type, + ), + y=alt.Y( + y_attr.attribute, + scale=alt.Scale(domain=(y_min, y_max)), + type=y_attr.data_type, + ), + ) + ) + chart = chart.configure_mark( + tooltip=alt.TooltipContent("encoding") + ) # Setting tooltip as non-null + chart = chart.interactive() # Enable Zooming and Panning + + ##################################### + ## Constructing Altair Code String ## + ##################################### - ##################################### - ## Constructing Altair Code String ## - ##################################### - - self.code += "import altair as alt\n" - dfname = "placeholder_variable" - self.code += f''' + self.code += "import altair as alt\n" + dfname = "placeholder_variable" + self.code += f""" chart = alt.Chart({dfname}).mark_circle().encode( x=alt.X('{x_attr.attribute}',scale=alt.Scale(domain=({x_min}, {x_max})),type='{x_attr.data_type}'), y=alt.Y('{y_attr.attribute}',scale=alt.Scale(domain=({y_min}, {y_max})),type='{y_attr.data_type}') ) chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding')) # Setting tooltip as non-null chart = chart.interactive() # Enable Zooming and Panning - ''' - return chart \ No newline at end of file + """ + return chart diff --git a/lux/vislib/altair/__init__.py b/lux/vislib/altair/__init__.py index 
cbfa9f5b..948becf5 100644 --- a/lux/vislib/altair/__init__.py +++ b/lux/vislib/altair/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/requirements-dev.txt b/requirements-dev.txt index 3d919655..0365d5bd 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,4 @@ pytest>=5.3.1 pytest-cov>=2.8.1 Sphinx>=3.0.2 sphinx-rtd-theme>=0.4.3 +black diff --git a/requirements.txt b/requirements.txt index 0b31c7a4..b23c6009 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,7 @@ scipy>=1.3.3 altair>=4.0.0 pandas>=1.1.0 scikit-learn>=0.22 -lux-widget>=0.1.0 \ No newline at end of file +# Install only to use SQLExecutor +# psycopg2>=2.8.5 +# psycopg2-binary>=2.8.5 +lux-widget>=0.1.0 diff --git a/setup.py b/setup.py index 43b731a6..bca149b8 100644 --- a/setup.py +++ b/setup.py @@ -4,43 +4,41 @@ HERE = path.abspath(path.dirname(__file__)) # Get the long description from the README file -with open(path.join(HERE, 'README.md'), encoding='utf-8') as f: +with open(path.join(HERE, "README.md"), encoding="utf-8") as f: long_description = f.read() -with open(path.join(HERE, 'requirements.txt')) as fp: +with open(path.join(HERE, "requirements.txt")) as fp: install_requires = fp.read() version_dict = {} -with open(path.join(HERE, 'lux/_version.py')) as fp: +with open(path.join(HERE, "lux/_version.py")) as fp: exec(fp.read(), {}, version_dict) version = version_dict["__version__"] setup( - name='lux-api', # PyPI Name (pip install [name]) + name="lux-api", # PyPI Name (pip install [name]) version=version, # Required - description='A Python API for Intelligent Data Discovery', - long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/lux-org/lux', - author='Doris Jung-Lin Lee', - author_email='dorisjunglinlee@gmail.com', - license = 'Apache-2.0 License', - classifiers=[ - 'Development Status :: 1 - Planning', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'Intended Audience :: Other Audience', - 'Topic :: Scientific/Engineering :: Information Analysis', - 'Topic :: Scientific/Engineering :: Visualization', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3' + description="A Python API for Intelligent Data Discovery", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/lux-org/lux", + author="Doris Jung-Lin Lee", + author_email="dorisjunglinlee@gmail.com", + license="Apache-2.0 License", + classifiers=[ + "Development Status :: 1 - Planning", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Other Audience", + "Topic :: Scientific/Engineering :: Information Analysis", + "Topic :: Scientific/Engineering :: Visualization", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", ], - keywords= ['Visualization','Analytics','Data Science','Data Analysis'], + keywords=["Visualization", "Analytics", "Data Science", "Data Analysis"], include_data_package=True, packages=find_packages(), # Required - 
python_requires='>=3.5', + python_requires=">=3.5", install_requires=install_requires, - extras_require={ - 'test': ['pytest'] - } -) \ No newline at end of file + extras_require={"test": ["pytest"]}, +) diff --git a/tests/__init__.py b/tests/__init__.py index cbfa9f5b..948becf5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/context.py b/tests/context.py index 00894665..42b3eb15 100644 --- a/tests/context.py +++ b/tests/context.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,6 +14,7 @@ import os import sys -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -import lux \ No newline at end of file +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import lux diff --git a/tests/test_action.py b/tests/test_action.py index 2b80a82a..3b3097ad 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,92 +17,183 @@ import pandas as pd from lux.vis.Vis import Vis + + def test_vary_filter_val(): - url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true' - df = pd.read_csv(url) - vis = Vis(["Height","SportType=Ball"],df) - df.set_intent_as_vis(vis) - df._repr_html_() - assert len(df.recommendation["Filter"]) == len(df["SportType"].unique())-1 + url = ( + "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true" + ) + df = pd.read_csv(url) + vis = Vis(["Height", "SportType=Ball"], df) + df.set_intent_as_vis(vis) + df._repr_html_() + assert len(df.recommendation["Filter"]) == len(df["SportType"].unique()) - 1 + + def test_filter_inequality(): - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + + df.set_intent( + [ + lux.Clause(attribute="Horsepower"), + lux.Clause(attribute="MilesPerGal"), + lux.Clause(attribute="Acceleration", filter_op=">", value=10), + ] + ) + df._repr_html_() + + from lux.utils.utils import get_filter_specs + + complement_vis = df.recommendation["Filter"][0] + fltr_clause = get_filter_specs(complement_vis._intent)[0] + assert fltr_clause.filter_op == "<=" + assert fltr_clause.value == 10 - df.set_intent([lux.Clause(attribute = "Horsepower"),lux.Clause(attribute = "MilesPerGal"),lux.Clause(attribute = "Acceleration", filter_op=">",value = 10)]) - df._repr_html_() - from lux.utils.utils import get_filter_specs - complement_vis = df.recommendation["Filter"][0] - fltr_clause = get_filter_specs(complement_vis._intent)[0] - assert fltr_clause.filter_op =="<=" - assert fltr_clause.value ==10 def test_generalize_action(): - #test that generalize 
action creates all unique visualizations - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - df.set_intent(["Acceleration", "MilesPerGal", "Cylinders", "Origin=USA"]) - df._repr_html_() - assert(len(df.recommendation['Generalize']) == 4) - v1 = df.recommendation['Generalize'][0] - v2 = df.recommendation['Generalize'][1] - v3 = df.recommendation['Generalize'][2] - v4 = df.recommendation['Generalize'][3] - - for clause in v4._inferred_intent: - assert clause.value=="" #No filter value - assert v4.title =='Overall' - - check1 = v1 != v2 and v1 != v3 and v1 != v4 - check2 = v2 != v3 and v2 != v4 - check3 = v3 != v4 - assert(check1 and check2 and check3) + # test that generalize action creates all unique visualizations + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + df.set_intent(["Acceleration", "MilesPerGal", "Cylinders", "Origin=USA"]) + df._repr_html_() + assert len(df.recommendation["Generalize"]) == 4 + v1 = df.recommendation["Generalize"][0] + v2 = df.recommendation["Generalize"][1] + v3 = df.recommendation["Generalize"][2] + v4 = df.recommendation["Generalize"][3] + + for clause in v4._inferred_intent: + assert clause.value == "" # No filter value + assert v4.title == "Overall" + + check1 = v1 != v2 and v1 != v3 and v1 != v4 + check2 = v2 != v3 and v2 != v4 + check3 = v3 != v4 + assert check1 and check2 and check3 + def test_row_column_group(): - url = 'https://github.com/lux-org/lux-datasets/blob/master/data/state_timeseries.csv?raw=true' - df = pd.read_csv(url) - df["Date"] = pd.to_datetime(df["Date"]) - tseries = df.pivot(index="State",columns="Date",values="Value") - # Interpolating missing values - tseries[tseries.columns.min()] = tseries[tseries.columns.min()].fillna(0) - tseries[tseries.columns.max()] = tseries[tseries.columns.max()].fillna(tseries.max(axis=1)) - tseries = tseries.interpolate('zero',axis=1) - tseries._repr_html_() - assert list(tseries.recommendation.keys() ) == ['Row Groups','Column Groups'] + url = "https://github.com/lux-org/lux-datasets/blob/master/data/state_timeseries.csv?raw=true" + df = pd.read_csv(url) + df["Date"] = pd.to_datetime(df["Date"]) + tseries = df.pivot(index="State", columns="Date", values="Value") + # Interpolating missing values + tseries[tseries.columns.min()] = tseries[tseries.columns.min()].fillna(0) + tseries[tseries.columns.max()] = tseries[tseries.columns.max()].fillna( + tseries.max(axis=1) + ) + tseries = tseries.interpolate("zero", axis=1) + tseries._repr_html_() + assert list(tseries.recommendation.keys()) == ["Row Groups", "Column Groups"] + def test_groupby(): - df = pd.read_csv("lux/data/college.csv") - groupbyResult = df.groupby("Region").sum() - groupbyResult._repr_html_() - assert list(groupbyResult.recommendation.keys() ) == ['Column Groups'] + df = pd.read_csv("lux/data/college.csv") + groupbyResult = df.groupby("Region").sum() + groupbyResult._repr_html_() + assert list(groupbyResult.recommendation.keys()) == ["Column Groups"] + def test_crosstab(): - # Example from http://www.datasciencemadesimple.com/cross-tab-cross-table-python-pandas/ - d = { - 'Name':['Alisa','Bobby','Cathrine','Alisa','Bobby','Cathrine', - 'Alisa','Bobby','Cathrine','Alisa','Bobby','Cathrine'], - 'Exam':['Semester 1','Semester 1','Semester 1','Semester 1','Semester 1','Semester 1', - 'Semester 2','Semester 2','Semester 2','Semester 
2','Semester 2','Semester 2'], - - 'Subject':['Mathematics','Mathematics','Mathematics','Science','Science','Science', - 'Mathematics','Mathematics','Mathematics','Science','Science','Science'], - 'Result':['Pass','Pass','Fail','Pass','Fail','Pass','Pass','Fail','Fail','Pass','Pass','Fail']} - - df = pd.DataFrame(d,columns=['Name','Exam','Subject','Result']) - result = pd.crosstab([df.Exam],df.Result) - result._repr_html_() - assert list(result.recommendation.keys() ) == ['Row Groups','Column Groups'] + # Example from http://www.datasciencemadesimple.com/cross-tab-cross-table-python-pandas/ + d = { + "Name": [ + "Alisa", + "Bobby", + "Cathrine", + "Alisa", + "Bobby", + "Cathrine", + "Alisa", + "Bobby", + "Cathrine", + "Alisa", + "Bobby", + "Cathrine", + ], + "Exam": [ + "Semester 1", + "Semester 1", + "Semester 1", + "Semester 1", + "Semester 1", + "Semester 1", + "Semester 2", + "Semester 2", + "Semester 2", + "Semester 2", + "Semester 2", + "Semester 2", + ], + "Subject": [ + "Mathematics", + "Mathematics", + "Mathematics", + "Science", + "Science", + "Science", + "Mathematics", + "Mathematics", + "Mathematics", + "Science", + "Science", + "Science", + ], + "Result": [ + "Pass", + "Pass", + "Fail", + "Pass", + "Fail", + "Pass", + "Pass", + "Fail", + "Fail", + "Pass", + "Pass", + "Fail", + ], + } + + df = pd.DataFrame(d, columns=["Name", "Exam", "Subject", "Result"]) + result = pd.crosstab([df.Exam], df.Result) + result._repr_html_() + assert list(result.recommendation.keys()) == ["Row Groups", "Column Groups"] + def test_custom_aggregation(): - import numpy as np - df = pd.read_csv("lux/data/college.csv") - df.set_intent(["HighestDegree",lux.Clause("AverageCost",aggregation=np.ptp)]) - df._repr_html_() - assert list(df.recommendation.keys()) ==['Enhance', 'Filter', 'Generalize'] + import numpy as np + + df = pd.read_csv("lux/data/college.csv") + df.set_intent(["HighestDegree", lux.Clause("AverageCost", aggregation=np.ptp)]) + df._repr_html_() + assert list(df.recommendation.keys()) == ["Enhance", "Filter", "Generalize"] + + def test_year_filter_value(): - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df.set_intent(["Acceleration","Horsepower"]) - df._repr_html_() - list_of_vis_with_year_filter = list(filter(lambda vis: len(list(filter(lambda clause: clause.value!='' and clause.attribute=="Year",vis._intent)))!=0, df.recommendation["Filter"])) - vis = list_of_vis_with_year_filter[0] - assert "T00:00:00.000000000" not in vis.to_Altair(), "Year filter title contains extraneous string, not displayed as summarized string" + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df.set_intent(["Acceleration", "Horsepower"]) + df._repr_html_() + list_of_vis_with_year_filter = list( + filter( + lambda vis: len( + list( + filter( + lambda clause: clause.value != "" + and clause.attribute == "Year", + vis._intent, + ) + ) + ) + != 0, + df.recommendation["Filter"], + ) + ) + vis = list_of_vis_with_year_filter[0] + assert ( + "T00:00:00.000000000" not in vis.to_Altair() + ), "Year filter title contains extraneous string, not displayed as summarized string" diff --git a/tests/test_compiler.py b/tests/test_compiler.py index cda948d1..be3c32e8 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
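# test_custom_aggregation above hands numpy's peak-to-peak (max - min) function
# to a Clause as its aggregation. A short sketch of that usage pattern (assumes
# the same college.csv columns the test uses):
import numpy as np
import pandas as pd
import lux

df = pd.read_csv("lux/data/college.csv")
# np.ptp is applied to AverageCost within each HighestDegree group
df.set_intent(["HighestDegree", lux.Clause("AverageCost", aggregation=np.ptp)])
df._repr_html_()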
# You may obtain a copy of the License at @@ -18,27 +18,31 @@ from lux.vis.Vis import Vis from lux.vis.VisList import VisList + def test_underspecified_no_vis(test_recs): - no_vis_actions = ["Correlation", "Distribution", "Occurrence","Temporal"] - df = pd.read_csv("lux/data/car.csv") - test_recs(df, no_vis_actions) - assert len(df.current_vis) == 0 + no_vis_actions = ["Correlation", "Distribution", "Occurrence", "Temporal"] + df = pd.read_csv("lux/data/car.csv") + test_recs(df, no_vis_actions) + assert len(df.current_vis) == 0 + + # test only one filter context case. + df.set_intent([lux.Clause(attribute="Origin", filter_op="=", value="USA")]) + test_recs(df, no_vis_actions) + assert len(df.current_vis) == 0 - # test only one filter context case. - df.set_intent([lux.Clause(attribute ="Origin", filter_op="=", value="USA")]) - test_recs(df, no_vis_actions) - assert len(df.current_vis) == 0 def test_underspecified_single_vis(test_recs): - one_vis_actions = ["Enhance", "Filter", "Generalize"] - df = pd.read_csv("lux/data/car.csv") - df.set_intent([lux.Clause(attribute ="MilesPerGal"), lux.Clause(attribute ="Weight")]) - test_recs(df, one_vis_actions) - assert len(df.current_vis) == 1 - assert df.current_vis[0].mark == "scatter" - for attr in df.current_vis[0]._inferred_intent: assert attr.data_model == "measure" - for attr in df.current_vis[0]._inferred_intent: assert attr.data_type == "quantitative" - + one_vis_actions = ["Enhance", "Filter", "Generalize"] + df = pd.read_csv("lux/data/car.csv") + df.set_intent([lux.Clause(attribute="MilesPerGal"), lux.Clause(attribute="Weight")]) + test_recs(df, one_vis_actions) + assert len(df.current_vis) == 1 + assert df.current_vis[0].mark == "scatter" + for attr in df.current_vis[0]._inferred_intent: + assert attr.data_model == "measure" + for attr in df.current_vis[0]._inferred_intent: + assert attr.data_type == "quantitative" + # def test_underspecified_vis_collection(test_recs): # multiple_vis_actions = ["Current viss"] @@ -72,182 +76,316 @@ def test_underspecified_single_vis(test_recs): # assert len(df.current_vis) == len([vis.get_attr_by_data_model("measure") for vis in df.current_vis]) #should be 25 # test_recs(df, multiple_vis_actions) def test_set_intent_as_vis(test_recs): - df = pd.read_csv("lux/data/car.csv") - df._repr_html_() - vis = df.recommendation["Correlation"][0] - df.intent = vis - df._repr_html_() - test_recs(df,["Enhance","Filter","Generalize"]) + df = pd.read_csv("lux/data/car.csv") + df._repr_html_() + vis = df.recommendation["Correlation"][0] + df.intent = vis + df._repr_html_() + test_recs(df, ["Enhance", "Filter", "Generalize"]) + @pytest.fixture def test_recs(): - def test_recs_function(df, actions): - df._repr_html_() - assert (len(df.recommendation) > 0) - recKeys = list(df.recommendation.keys()) - list_equal(recKeys,actions) - return test_recs_function + def test_recs_function(df, actions): + df._repr_html_() + assert len(df.recommendation) > 0 + recKeys = list(df.recommendation.keys()) + list_equal(recKeys, actions) + + return test_recs_function + def test_parse(): - df = pd.read_csv("lux/data/car.csv") - vlst = VisList([lux.Clause("Origin=?"), lux.Clause(attribute ="MilesPerGal")],df) - assert len(vlst) == 3 + df = pd.read_csv("lux/data/car.csv") + vlst = VisList([lux.Clause("Origin=?"), lux.Clause(attribute="MilesPerGal")], df) + assert len(vlst) == 3 + + df = pd.read_csv("lux/data/car.csv") + vlst = VisList([lux.Clause("Origin=?"), lux.Clause("MilesPerGal")], df) + assert len(vlst) == 3 + - df = 
pd.read_csv("lux/data/car.csv") - vlst = VisList([lux.Clause("Origin=?"), lux.Clause("MilesPerGal")],df) - assert len(vlst) == 3 def test_underspecified_vis_collection_zval(): - # check if the number of charts is correct - df = pd.read_csv("lux/data/car.csv") - vlst = VisList([lux.Clause(attribute ="Origin", filter_op="=", value="?"), lux.Clause(attribute ="MilesPerGal")],df) - assert len(vlst) == 3 + # check if the number of charts is correct + df = pd.read_csv("lux/data/car.csv") + vlst = VisList( + [ + lux.Clause(attribute="Origin", filter_op="=", value="?"), + lux.Clause(attribute="MilesPerGal"), + ], + df, + ) + assert len(vlst) == 3 + + # does not work + # df = pd.read_csv("lux/data/car.csv") + # vlst = VisList([lux.Clause(attribute = ["Origin","Cylinders"], filter_op="=",value="?"),lux.Clause(attribute = ["Horsepower"]),lux.Clause(attribute = "Weight")],df) + # assert len(vlst) == 8 - #does not work - # df = pd.read_csv("lux/data/car.csv") - # vlst = VisList([lux.Clause(attribute = ["Origin","Cylinders"], filter_op="=",value="?"),lux.Clause(attribute = ["Horsepower"]),lux.Clause(attribute = "Weight")],df) - # assert len(vlst) == 8 def test_sort_bar(): - from lux.processor.Compiler import Compiler - from lux.vis.Vis import Vis - df = pd.read_csv("lux/data/car.csv") - vis = Vis([lux.Clause(attribute="Acceleration",data_model="measure",data_type="quantitative"), - lux.Clause(attribute="Origin",data_model="dimension",data_type="nominal")],df) - assert vis.mark == "bar" - assert vis._inferred_intent[1].sort == '' - - df = pd.read_csv("lux/data/car.csv") - vis = Vis([lux.Clause(attribute="Acceleration",data_model="measure",data_type="quantitative"), - lux.Clause(attribute="Name",data_model="dimension",data_type="nominal")],df) - assert vis.mark == "bar" - assert vis._inferred_intent[1].sort == 'ascending' + from lux.processor.Compiler import Compiler + from lux.vis.Vis import Vis + + df = pd.read_csv("lux/data/car.csv") + vis = Vis( + [ + lux.Clause( + attribute="Acceleration", data_model="measure", data_type="quantitative" + ), + lux.Clause(attribute="Origin", data_model="dimension", data_type="nominal"), + ], + df, + ) + assert vis.mark == "bar" + assert vis._inferred_intent[1].sort == "" + + df = pd.read_csv("lux/data/car.csv") + vis = Vis( + [ + lux.Clause( + attribute="Acceleration", data_model="measure", data_type="quantitative" + ), + lux.Clause(attribute="Name", data_model="dimension", data_type="nominal"), + ], + df, + ) + assert vis.mark == "bar" + assert vis._inferred_intent[1].sort == "ascending" + def test_specified_vis_collection(): - url = 'https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true' - df = pd.read_csv(url) - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype + url = "https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true" + df = pd.read_csv(url) + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + + vlst = VisList( + [ + lux.Clause(attribute="Horsepower"), + lux.Clause(attribute="Brand"), + lux.Clause(attribute="Origin", value=["Japan", "USA"]), + ], + df, + ) + assert len(vlst) == 2 + + vlst = VisList( + [ + lux.Clause(attribute=["Horsepower", "Weight"]), + lux.Clause(attribute="Brand"), + lux.Clause(attribute="Origin", value=["Japan", "USA"]), + ], + df, + ) + assert len(vlst) == 4 + + # test if z axis has been filtered correctly + chart_titles = [vis.title for vis in vlst] + assert "Origin = USA" 
and "Origin = Japan" in chart_titles + assert "Origin = Europe" not in chart_titles - vlst = VisList([lux.Clause(attribute="Horsepower"),lux.Clause(attribute="Brand"), lux.Clause(attribute = "Origin",value=["Japan","USA"])],df) - assert len(vlst) == 2 - vlst = VisList([lux.Clause(attribute=["Horsepower","Weight"]),lux.Clause(attribute="Brand"), lux.Clause(attribute = "Origin",value=["Japan","USA"])],df) - assert len(vlst) == 4 +def test_specified_channel_enforced_vis_collection(): + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + visList = VisList( + [lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], + df, + ) + for vis in visList: + check_attribute_on_channel(vis, "MilesPerGal", "x") - # test if z axis has been filtered correctly - chart_titles = [vis.title for vis in vlst] - assert "Origin = USA" and "Origin = Japan" in chart_titles - assert "Origin = Europe" not in chart_titles +def test_autoencoding_scatter(): + # No channel specified + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + vis = Vis([lux.Clause(attribute="MilesPerGal"), lux.Clause(attribute="Weight")], df) + check_attribute_on_channel(vis, "MilesPerGal", "x") + check_attribute_on_channel(vis, "Weight", "y") + + # Partial channel specified + vis = Vis( + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight"), + ], + df, + ) + check_attribute_on_channel(vis, "MilesPerGal", "y") + check_attribute_on_channel(vis, "Weight", "x") + + # Full channel specified + vis = Vis( + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight", channel="x"), + ], + df, + ) + check_attribute_on_channel(vis, "MilesPerGal", "y") + check_attribute_on_channel(vis, "Weight", "x") + # Duplicate channel specified + with pytest.raises(ValueError): + # Should throw error because there should not be columns with the same channel specified + df.set_intent( + [ + lux.Clause(attribute="MilesPerGal", channel="x"), + lux.Clause(attribute="Weight", channel="x"), + ] + ) -def test_specified_channel_enforced_vis_collection(): - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - visList = VisList([lux.Clause(attribute="?"),lux.Clause(attribute="MilesPerGal",channel="x")],df) - for vis in visList: - check_attribute_on_channel(vis, "MilesPerGal", "x") -def test_autoencoding_scatter(): - # No channel specified - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - vis = Vis([lux.Clause(attribute="MilesPerGal"), lux.Clause(attribute="Weight")],df) - check_attribute_on_channel(vis, "MilesPerGal", "x") - check_attribute_on_channel(vis, "Weight", "y") - - # Partial channel specified - vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight")],df) - check_attribute_on_channel(vis, "MilesPerGal", "y") - check_attribute_on_channel(vis, "Weight", "x") - - # Full channel specified - vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x")],df) - check_attribute_on_channel(vis, "MilesPerGal", "y") - check_attribute_on_channel(vis, "Weight", "x") - # Duplicate channel specified - with 
pytest.raises(ValueError): - # Should throw error because there should not be columns with the same channel specified - df.set_intent([lux.Clause(attribute="MilesPerGal", channel="x"), lux.Clause(attribute="Weight", channel="x")]) - def test_autoencoding_histogram(): - # No channel specified - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y")],df) - check_attribute_on_channel(vis, "MilesPerGal", "y") + # No channel specified + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y")], df) + check_attribute_on_channel(vis, "MilesPerGal", "y") + + vis = Vis([lux.Clause(attribute="MilesPerGal", channel="x")], df) + assert vis.get_attr_by_channel("x")[0].attribute == "MilesPerGal" + assert vis.get_attr_by_channel("y")[0].attribute == "Record" - vis = Vis([lux.Clause(attribute="MilesPerGal",channel="x")],df) - assert vis.get_attr_by_channel("x")[0].attribute == "MilesPerGal" - assert vis.get_attr_by_channel("y")[0].attribute == "Record" def test_autoencoding_line_chart(): - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - vis = Vis([lux.Clause(attribute="Year"), lux.Clause(attribute="Acceleration")],df) - check_attribute_on_channel(vis, "Year", "x") - check_attribute_on_channel(vis, "Acceleration", "y") - - # Partial channel specified - vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration")],df) - check_attribute_on_channel(vis, "Year", "y") - check_attribute_on_channel(vis, "Acceleration", "x") - - # Full channel specified - vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x")],df) - check_attribute_on_channel(vis, "Year", "y") - check_attribute_on_channel(vis, "Acceleration", "x") - - with pytest.raises(ValueError): - # Should throw error because there should not be columns with the same channel specified - df.set_intent([lux.Clause(attribute="Year", channel="x"), lux.Clause(attribute="Acceleration", channel="x")]) + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + vis = Vis([lux.Clause(attribute="Year"), lux.Clause(attribute="Acceleration")], df) + check_attribute_on_channel(vis, "Year", "x") + check_attribute_on_channel(vis, "Acceleration", "y") + + # Partial channel specified + vis = Vis( + [ + lux.Clause(attribute="Year", channel="y"), + lux.Clause(attribute="Acceleration"), + ], + df, + ) + check_attribute_on_channel(vis, "Year", "y") + check_attribute_on_channel(vis, "Acceleration", "x") + + # Full channel specified + vis = Vis( + [ + lux.Clause(attribute="Year", channel="y"), + lux.Clause(attribute="Acceleration", channel="x"), + ], + df, + ) + check_attribute_on_channel(vis, "Year", "y") + check_attribute_on_channel(vis, "Acceleration", "x") + + with pytest.raises(ValueError): + # Should throw error because there should not be columns with the same channel specified + df.set_intent( + [ + lux.Clause(attribute="Year", channel="x"), + lux.Clause(attribute="Acceleration", channel="x"), + ] + ) + def test_autoencoding_color_line_chart(): - df = pd.read_csv("lux/data/car.csv") - 
df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - intent = [lux.Clause(attribute="Year"), lux.Clause(attribute="Acceleration"), lux.Clause(attribute="Origin")] - vis = Vis(intent,df) - check_attribute_on_channel(vis, "Year", "x") - check_attribute_on_channel(vis, "Acceleration", "y") - check_attribute_on_channel(vis, "Origin", "color") + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + intent = [ + lux.Clause(attribute="Year"), + lux.Clause(attribute="Acceleration"), + lux.Clause(attribute="Origin"), + ] + vis = Vis(intent, df) + check_attribute_on_channel(vis, "Year", "x") + check_attribute_on_channel(vis, "Acceleration", "y") + check_attribute_on_channel(vis, "Origin", "color") + def test_autoencoding_color_scatter_chart(): - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - vis = Vis([lux.Clause(attribute="Horsepower"), lux.Clause(attribute="Acceleration"), lux.Clause(attribute="Origin")],df) - check_attribute_on_channel(vis, "Origin", "color") + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + vis = Vis( + [ + lux.Clause(attribute="Horsepower"), + lux.Clause(attribute="Acceleration"), + lux.Clause(attribute="Origin"), + ], + df, + ) + check_attribute_on_channel(vis, "Origin", "color") + + vis = Vis( + [ + lux.Clause(attribute="Horsepower"), + lux.Clause(attribute="Acceleration", channel="color"), + lux.Clause(attribute="Origin"), + ], + df, + ) + check_attribute_on_channel(vis, "Acceleration", "color") - vis = Vis([lux.Clause(attribute="Horsepower"), lux.Clause(attribute="Acceleration", channel="color"), lux.Clause(attribute="Origin")],df) - check_attribute_on_channel(vis, "Acceleration", "color") def test_populate_options(): - from lux.processor.Compiler import Compiler - df = pd.read_csv("lux/data/car.csv") - df.set_intent([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal")]) - col_set = set() - for specOptions in Compiler.populate_wildcard_options(df._intent, df)["attributes"]: - for clause in specOptions: - col_set.add(clause.attribute) - assert list_equal(list(col_set), list(df.columns)) - - df.set_intent([lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal")]) - df._repr_html_() - col_set = set() - for specOptions in Compiler.populate_wildcard_options(df._intent, df)["attributes"]: - for clause in specOptions: - col_set.add(clause.attribute) - assert list_equal(list(col_set), ['Acceleration', 'Weight', 'Horsepower', 'MilesPerGal', 'Displacement']) + from lux.processor.Compiler import Compiler + + df = pd.read_csv("lux/data/car.csv") + df.set_intent([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal")]) + col_set = set() + for specOptions in Compiler.populate_wildcard_options(df._intent, df)["attributes"]: + for clause in specOptions: + col_set.add(clause.attribute) + assert list_equal(list(col_set), list(df.columns)) + + df.set_intent( + [ + lux.Clause(attribute="?", data_model="measure"), + lux.Clause(attribute="MilesPerGal"), + ] + ) + df._repr_html_() + col_set = set() + for specOptions in Compiler.populate_wildcard_options(df._intent, df)["attributes"]: + for clause in specOptions: + col_set.add(clause.attribute) + assert list_equal( + list(col_set), + 
["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"], + ) + def test_remove_all_invalid(): - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): - df.set_intent([lux.Clause(attribute = "Origin", filter_op="=",value="USA"),lux.Clause(attribute = "Origin")]) - df._repr_html_() - assert len(df.current_vis)==0 + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): + df.set_intent( + [ + lux.Clause(attribute="Origin", filter_op="=", value="USA"), + lux.Clause(attribute="Origin"), + ] + ) + df._repr_html_() + assert len(df.current_vis) == 0 + def list_equal(l1, l2): l1.sort() l2.sort() - return l1==l2 + return l1 == l2 + def check_attribute_on_channel(vis, attr_name, channelName): - assert vis.get_attr_by_channel(channelName)[0].attribute == attr_name + assert vis.get_attr_by_channel(channelName)[0].attribute == attr_name diff --git a/tests/test_config.py b/tests/test_config.py index 3b9be963..adfd2655 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -19,130 +19,168 @@ from lux.vis.VisList import VisList import lux -def register_new_action(validator: bool=True): - df = pd.read_csv("lux/data/car.csv") - def random_categorical(ldf): - intent = [lux.Clause("?",data_type="nominal")] - vlist = VisList(intent,ldf) - for vis in vlist: - vis.score = 10 - vlist = vlist.topK(15) - return {"action":"bars", "description": "Random list of Bar charts", "collection": vlist} - def contain_horsepower(df): - for clause in df.intent: - if clause.get_attr() == "Horsepower": - return True - return False - if validator: - lux.register_action("bars", random_categorical, contain_horsepower) - else: - lux.register_action("bars", random_categorical) - return df + +def register_new_action(validator: bool = True): + df = pd.read_csv("lux/data/car.csv") + + def random_categorical(ldf): + intent = [lux.Clause("?", data_type="nominal")] + vlist = VisList(intent, ldf) + for vis in vlist: + vis.score = 10 + vlist = vlist.topK(15) + return { + "action": "bars", + "description": "Random list of Bar charts", + "collection": vlist, + } + + def contain_horsepower(df): + for clause in df.intent: + if clause.get_attr() == "Horsepower": + return True + return False + + if validator: + lux.register_action("bars", random_categorical, contain_horsepower) + else: + lux.register_action("bars", random_categorical) + return df + def test_default_actions_registered(): - df = pd.read_csv("lux/data/car.csv") - df._repr_html_() - assert("Distribution" in df.recommendation) - assert(len(df.recommendation["Distribution"]) > 0) + df = pd.read_csv("lux/data/car.csv") + df._repr_html_() + assert "Distribution" in df.recommendation + assert len(df.recommendation["Distribution"]) > 0 + + assert "Occurrence" in df.recommendation + assert len(df.recommendation["Occurrence"]) > 0 - assert("Occurrence" in df.recommendation) - assert(len(df.recommendation["Occurrence"]) > 0) + assert "Temporal" in df.recommendation + assert len(df.recommendation["Temporal"]) > 0 - assert("Temporal" in df.recommendation) - 
assert(len(df.recommendation["Temporal"]) > 0) + assert "Correlation" in df.recommendation + assert len(df.recommendation["Correlation"]) > 0 - assert("Correlation" in df.recommendation) - assert(len(df.recommendation["Correlation"]) > 0) def test_fail_validator(): - df = register_new_action() - df._repr_html_() - assert("bars" not in df.recommendation, - "Bars should not be rendered when there is no intent 'horsepower' specified.") + df = register_new_action() + df._repr_html_() + assert ( + "bars" not in df.recommendation + ), "Bars should not be rendered when there is no intent 'horsepower' specified." + def test_pass_validator(): - df = register_new_action() - df.set_intent(["Acceleration", "Horsepower"]) - df._repr_html_() - assert(len(df.recommendation["bars"]) > 0) - assert("bars" in df.recommendation, - "Bars should be rendered when intent 'horsepower' is specified.") + df = register_new_action() + df.set_intent(["Acceleration", "Horsepower"]) + df._repr_html_() + assert len(df.recommendation["bars"]) > 0 + assert ( + "bars" in df.recommendation + ), "Bars should be rendered when intent 'horsepower' is specified." + def test_no_validator(): - df = register_new_action(False) - df._repr_html_() - assert(len(df.recommendation["bars"]) > 0) - assert("bars" in df.recommendation) + df = register_new_action(False) + df._repr_html_() + assert len(df.recommendation["bars"]) > 0 + assert "bars" in df.recommendation + def test_invalid_function(): - df = pd.read_csv("lux/data/car.csv") - with pytest.raises(ValueError,match="Value must be a callable"): - lux.register_action("bars", "not a Callable") + df = pd.read_csv("lux/data/car.csv") + with pytest.raises(ValueError, match="Value must be a callable"): + lux.register_action("bars", "not a Callable") + def test_invalid_validator(): - df = pd.read_csv("lux/data/car.csv") - def random_categorical(ldf): - intent = [lux.Clause("?",data_type="nominal")] - vlist = VisList(intent,ldf) - for vis in vlist: - vis.score = 10 - vlist = vlist.topK(15) - return {"action":"bars", "description": "Random list of Bar charts", "collection": vlist} - with pytest.raises(ValueError,match="Value must be a callable"): - lux.register_action("bars", random_categorical, "not a Callable") + df = pd.read_csv("lux/data/car.csv") + + def random_categorical(ldf): + intent = [lux.Clause("?", data_type="nominal")] + vlist = VisList(intent, ldf) + for vis in vlist: + vis.score = 10 + vlist = vlist.topK(15) + return { + "action": "bars", + "description": "Random list of Bar charts", + "collection": vlist, + } + + with pytest.raises(ValueError, match="Value must be a callable"): + lux.register_action("bars", random_categorical, "not a Callable") + def test_remove_action(): - df = register_new_action() - df.set_intent(["Acceleration", "Horsepower"]) - df._repr_html_() - assert("bars" in df.recommendation, - "Bars should be rendered after it has been registered with correct intent.") - assert(len(df.recommendation["bars"]) > 0, - "Bars should be rendered after it has been registered with correct intent.") - lux.remove_action("bars") - df._repr_html_() - assert("bars" not in df.recommendation, - "Bars should not be rendered after it has been removed.") + df = register_new_action() + df.set_intent(["Acceleration", "Horsepower"]) + df._repr_html_() + assert ( + "bars" in df.recommendation + ), "Bars should be rendered after it has been registered with correct intent." + assert ( + len(df.recommendation["bars"]) > 0 + ), "Bars should be rendered after it has been registered with correct intent." + lux.remove_action("bars") + df._repr_html_() + assert ( + "bars" not in df.recommendation + ), "Bars should not be rendered after it has been removed."
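A note on the assert form used above: in Python, assert (condition, "message") asserts on a two-element tuple, which is always truthy, so such a check can never fail (CPython even emits a SyntaxWarning for it). The failure message must follow the condition after a comma. A minimal sketch of the pitfall, using names from the tests above:

    # Always passes, even when "bars" is present: a non-empty tuple is truthy.
    assert ("bars" not in df.recommendation, "Bars should not be rendered.")

    # Correct: the message is attached after a comma and is shown on failure.
    assert "bars" not in df.recommendation, "Bars should not be rendered."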
correct intent.", + ) + lux.remove_action("bars") + df._repr_html_() + assert ( + "bars" not in df.recommendation, + "Bars should not be rendered after it has been removed.", + ) + def test_remove_invalid_action(): - df = pd.read_csv("lux/data/car.csv") - with pytest.raises(ValueError,match="Option 'bars' has not been registered"): - lux.remove_action("bars") + df = pd.read_csv("lux/data/car.csv") + with pytest.raises(ValueError, match="Option 'bars' has not been registered"): + lux.remove_action("bars") + def test_remove_default_actions(): - df = pd.read_csv("lux/data/car.csv") - df._repr_html_() + df = pd.read_csv("lux/data/car.csv") + df._repr_html_() + + lux.remove_action("Distribution") + df._repr_html_() + assert "Distribution" not in df.recommendation - lux.remove_action("Distribution") - df._repr_html_() - assert("Distribution" not in df.recommendation) + lux.remove_action("Occurrence") + df._repr_html_() + assert "Occurrence" not in df.recommendation - lux.remove_action("Occurrence") - df._repr_html_() - assert("Occurrence" not in df.recommendation) + lux.remove_action("Temporal") + df._repr_html_() + assert "Temporal" not in df.recommendation - lux.remove_action("Temporal") - df._repr_html_() - assert("Temporal" not in df.recommendation) + lux.remove_action("Correlation") + df._repr_html_() + assert "Correlation" not in df.recommendation - lux.remove_action("Correlation") - df._repr_html_() - assert("Correlation" not in df.recommendation) + assert ( + len(df.recommendation) == 0, + "Default actions should not be rendered after it has been removed.", + ) - assert(len(df.recommendation) == 0, - "Default actions should not be rendered after it has been removed.") + df = register_new_action() + df.set_intent(["Acceleration", "Horsepower"]) + df._repr_html_() + assert ( + "bars" in df.recommendation, + "Bars should be rendered after it has been registered with correct intent.", + ) + assert len(df.recommendation["bars"]) > 0 - df = register_new_action() - df.set_intent(["Acceleration", "Horsepower"]) - df._repr_html_() - assert("bars" in df.recommendation, - "Bars should be rendered after it has been registered with correct intent.") - assert(len(df.recommendation["bars"]) > 0) -# TODO: This test does not pass in pytest but is working in Jupyter notebook. +# TODO: This test does not pass in pytest but is working in Jupyter notebook. # def test_plot_setting(): # df = pd.read_csv("lux/data/car.csv") -# df["Year"] = pd.to_datetime(df["Year"], format='%Y') +# df["Year"] = pd.to_datetime(df["Year"], format='%Y') # def change_color_add_title(chart): # chart = chart.configure_mark(color="green") # change mark color to green # chart.title = "Custom Title" # add title to chart @@ -154,4 +192,4 @@ def test_remove_default_actions(): # vis_code = df.recommendation["Correlation"][0].to_Altair() # print (vis_code) -# assert 'chart = chart.configure_mark(color="green")' in vis_code, "Exported chart does not have additional plot style setting." \ No newline at end of file +# assert 'chart = chart.configure_mark(color="green")' in vis_code, "Exported chart does not have additional plot style setting." diff --git a/tests/test_dates.py b/tests/test_dates.py index 2ffef558..140a2450 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -19,78 +19,105 @@ from lux.utils import date_utils from lux.executor.PandasExecutor import PandasExecutor + def test_dateformatter(): - ldf = pd.read_csv("lux/data/car.csv") - ldf["Year"] = pd.to_datetime(ldf["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - timestamp = np.datetime64('2019-08-26') - ldf.maintain_metadata() - assert(date_utils.date_formatter(timestamp,ldf) == '2019') + ldf = pd.read_csv("lux/data/car.csv") + ldf["Year"] = pd.to_datetime( + ldf["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + timestamp = np.datetime64("2019-08-26") + ldf.maintain_metadata() + assert date_utils.date_formatter(timestamp, ldf) == "2019" - ldf["Year"][0] = np.datetime64('1970-03-01') # make month non unique + ldf["Year"][0] = np.datetime64("1970-03-01") # make month non unique - assert (date_utils.date_formatter(timestamp, ldf) == '2019-8') + assert date_utils.date_formatter(timestamp, ldf) == "2019-8" + ldf["Year"][0] = np.datetime64("1970-03-03") # make day non unique - ldf["Year"][0] = np.datetime64('1970-03-03') # make day non unique + assert date_utils.date_formatter(timestamp, ldf) == "2019-8-26" - assert (date_utils.date_formatter(timestamp, ldf) == '2019-8-26') def test_period_selection(): - ldf = pd.read_csv("lux/data/car.csv") - ldf["Year"] = pd.to_datetime(ldf["Year"], format='%Y') + ldf = pd.read_csv("lux/data/car.csv") + ldf["Year"] = pd.to_datetime(ldf["Year"], format="%Y") + + ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A") - ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq='A') + ldf.set_intent( + [ + lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), + lux.Clause(attribute="Year"), + ] + ) - ldf.set_intent([lux.Clause(attribute = ["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute ="Year")]) + PandasExecutor.execute(ldf.current_vis, ldf) - PandasExecutor.execute(ldf.current_vis, ldf) + assert all( + [type(vlist.data) == lux.core.frame.LuxDataFrame for vlist in ldf.current_vis] + ) + assert all(ldf.current_vis[2].data.columns == ["Year", "Acceleration"]) - assert all([type(vlist.data) == lux.core.frame.LuxDataFrame for vlist in ldf.current_vis]) - assert all(ldf.current_vis[2].data.columns == ["Year", 'Acceleration']) def test_period_filter(): - ldf = pd.read_csv("lux/data/car.csv") - ldf["Year"] = pd.to_datetime(ldf["Year"], format='%Y') + ldf = pd.read_csv("lux/data/car.csv") + ldf["Year"] = pd.to_datetime(ldf["Year"], format="%Y") - ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq='A') + ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A") - ldf.set_intent([lux.Clause(attribute ="Acceleration"), lux.Clause(attribute ="Horsepower")]) + ldf.set_intent( + [lux.Clause(attribute="Acceleration"), lux.Clause(attribute="Horsepower")] + ) - PandasExecutor.execute(ldf.current_vis, ldf) - ldf._repr_html_() + PandasExecutor.execute(ldf.current_vis, ldf) + ldf._repr_html_() + + assert isinstance( + ldf.recommendation["Filter"][2]._inferred_intent[2].value, pd.Period + ) - assert isinstance(ldf.recommendation['Filter'][2]._inferred_intent[2].value, pd.Period) def test_period_to_altair(): - chart = None - df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + chart = None + df = pd.read_csv("lux/data/car.csv") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + + df["Year"] = pd.DatetimeIndex(df["Year"]).to_period(freq="A") + + df.set_intent( + 
[lux.Clause(attribute="Acceleration"), lux.Clause(attribute="Horsepower")] + ) - df["Year"] = pd.DatetimeIndex(df["Year"]).to_period(freq='A') + PandasExecutor.execute(df.current_vis, df) + df._repr_html_() - df.set_intent([lux.Clause(attribute ="Acceleration"), lux.Clause(attribute ="Horsepower")]) + exported_code = df.recommendation["Filter"][2].to_Altair() - PandasExecutor.execute(df.current_vis, df) - df._repr_html_() + assert "Year = 1971" in exported_code - exported_code = df.recommendation['Filter'][2].to_Altair() - - assert 'Year = 1971' in exported_code def test_refresh_inplace(): - df = pd.DataFrame({'date': ['2020-01-01', '2020-02-01', '2020-03-01', '2020-04-01'], 'value': [10.5,15.2,20.3,25.2]}) - with pytest.warns(UserWarning,match="Lux detects that the attribute 'date' may be temporal."): - df._repr_html_() - assert df.data_type_lookup["date"]=="temporal" - - from lux.vis.Vis import Vis - vis = Vis(["date","value"],df) - - df['date'] = pd.to_datetime(df['date'],format="%Y-%m-%d") - df.maintain_metadata() - assert df.data_type['temporal'][0] == 'date' - - vis.refresh_source(df) - assert vis.mark == "line" - assert vis.get_attr_by_channel("x")[0].attribute == "date" - assert vis.get_attr_by_channel("y")[0].attribute == "value" \ No newline at end of file + df = pd.DataFrame( + { + "date": ["2020-01-01", "2020-02-01", "2020-03-01", "2020-04-01"], + "value": [10.5, 15.2, 20.3, 25.2], + } + ) + with pytest.warns( + UserWarning, match="Lux detects that the attribute 'date' may be temporal." + ): + df._repr_html_() + assert df.data_type_lookup["date"] == "temporal" + + from lux.vis.Vis import Vis + + vis = Vis(["date", "value"], df) + + df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d") + df.maintain_metadata() + assert df.data_type["temporal"][0] == "date" + + vis.refresh_source(df) + assert vis.mark == "line" + assert vis.get_attr_by_channel("x")[0].attribute == "date" + assert vis.get_attr_by_channel("y")[0].attribute == "value" diff --git a/tests/test_display.py b/tests/test_display.py index ef570d8d..54da6ca4 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,20 +17,25 @@ import pandas as pd from lux.vis.Vis import Vis from lux.vis.VisList import VisList + + def test_to_pandas(): df = pd.read_csv("lux/data/car.csv") df.to_pandas() + def test_display_LuxDataframe(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - + + def test_display_Vis(): df = pd.read_csv("lux/data/car.csv") - vis = Vis(["Horsepower","Acceleration"],df) + vis = Vis(["Horsepower", "Acceleration"], df) vis._repr_html_() - + + def test_display_VisList(): df = pd.read_csv("lux/data/car.csv") - vislist = VisList(["?","Acceleration"],df) - vislist._repr_html_() \ No newline at end of file + vislist = VisList(["?", "Acceleration"], df) + vislist._repr_html_() diff --git a/tests/test_error_warning.py b/tests/test_error_warning.py index b745a751..e4e23c11 100644 --- a/tests/test_error_warning.py +++ b/tests/test_error_warning.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -19,35 +19,41 @@ # Test suite for checking if the expected errors and warnings are showing up correctly def test_context_str_error(): df = pd.read_csv("lux/data/college.csv") - with pytest.raises(TypeError,match="Input intent must be a list"): + with pytest.raises(TypeError, match="Input intent must be a list"): df.set_intent("bad string input") + + def test_export_b4_widget_created(): df = pd.read_csv("lux/data/college.csv") - with pytest.warns(UserWarning,match="No widget attached to the dataframe"): + with pytest.warns(UserWarning, match="No widget attached to the dataframe"): df.exported + + def test_bad_filter(): df = pd.read_csv("lux/data/college.csv") - with pytest.warns(UserWarning,match="Lux can not operate on an empty dataframe"): - df[df["Region"]=="asdfgh"]._repr_html_() + with pytest.warns(UserWarning, match="Lux can not operate on an empty dataframe"): + df[df["Region"] == "asdfgh"]._repr_html_() + + # Test Properties with Private Variables Readable but not Writable def test_vis_private_properties(): from lux.vis.Vis import Vis + df = pd.read_csv("lux/data/car.csv") - vis = Vis(["Horsepower","Weight"],df) + vis = Vis(["Horsepower", "Weight"], df) vis._repr_html_() - assert isinstance(vis.data,lux.core.frame.LuxDataFrame) - with pytest.raises(AttributeError,match="can't set attribute"): + assert isinstance(vis.data, lux.core.frame.LuxDataFrame) + with pytest.raises(AttributeError, match="can't set attribute"): vis.data = "some val" - assert isinstance(vis.code,dict) - with pytest.raises(AttributeError,match="can't set attribute"): + assert isinstance(vis.code, dict) + with pytest.raises(AttributeError, match="can't set attribute"): vis.code = "some val" - - assert isinstance(vis.min_max,dict) - with pytest.raises(AttributeError,match="can't set attribute"): + + assert isinstance(vis.min_max, dict) + with pytest.raises(AttributeError, match="can't set attribute"): vis.min_max = "some val" - assert vis.mark =="scatter" - with pytest.raises(AttributeError,match="can't set attribute"): + assert vis.mark == "scatter" + with pytest.raises(AttributeError, match="can't set attribute"): vis.mark = "some val" - diff --git a/tests/test_executor.py b/tests/test_executor.py index f8ee2613..2dababb0 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -18,133 +18,203 @@ from lux.executor.PandasExecutor import PandasExecutor from lux.vis.Vis import Vis from lux.vis.VisList import VisList + + def test_lazy_execution(): df = pd.read_csv("lux/data/car.csv") - intent = [lux.Clause(attribute ="Horsepower", aggregation="mean"), lux.Clause(attribute ="Origin")] + intent = [ + lux.Clause(attribute="Horsepower", aggregation="mean"), + lux.Clause(attribute="Origin"), + ] vis = Vis(intent) # Check data field in vis is empty before calling executor assert vis.data is None PandasExecutor.execute([vis], df) assert type(vis.data) == lux.core.frame.LuxDataFrame - + + def test_selection(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype - intent = [lux.Clause(attribute = ["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute ="Year")] - vislist = VisList(intent,df) + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + intent = [ + lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), + lux.Clause(attribute="Year"), + ] + vislist = VisList(intent, df) assert all([type(vis.data) == lux.core.frame.LuxDataFrame for vis in vislist]) - assert all(vislist[2].data.columns == ["Year", 'Acceleration']) + assert all(vislist[2].data.columns == ["Year", "Acceleration"]) + def test_aggregation(): df = pd.read_csv("lux/data/car.csv") - intent = [lux.Clause(attribute ="Horsepower", aggregation="mean"), lux.Clause(attribute ="Origin")] - vis = Vis(intent,df) + intent = [ + lux.Clause(attribute="Horsepower", aggregation="mean"), + lux.Clause(attribute="Origin"), + ] + vis = Vis(intent, df) result_df = vis.data - assert int(result_df[result_df["Origin"]=="USA"]["Horsepower"])==119 + assert int(result_df[result_df["Origin"] == "USA"]["Horsepower"]) == 119 - intent = [lux.Clause(attribute ="Horsepower", aggregation="sum"), lux.Clause(attribute ="Origin")] - vis = Vis(intent,df) + intent = [ + lux.Clause(attribute="Horsepower", aggregation="sum"), + lux.Clause(attribute="Origin"), + ] + vis = Vis(intent, df) result_df = vis.data - assert int(result_df[result_df["Origin"]=="Japan"]["Horsepower"])==6307 + assert int(result_df[result_df["Origin"] == "Japan"]["Horsepower"]) == 6307 - intent = [lux.Clause(attribute ="Horsepower", aggregation="max"), lux.Clause(attribute ="Origin")] - vis = Vis(intent,df) + intent = [ + lux.Clause(attribute="Horsepower", aggregation="max"), + lux.Clause(attribute="Origin"), + ] + vis = Vis(intent, df) result_df = vis.data - assert int(result_df[result_df["Origin"]=="Europe"]["Horsepower"])==133 + assert int(result_df[result_df["Origin"] == "Europe"]["Horsepower"]) == 133 + def test_colored_bar_chart(): from lux.vis.Vis import Vis from lux.vis.Vis import Clause + df = pd.read_csv("lux/data/car.csv") - x_clause = Clause(attribute = "MilesPerGal", channel = "x") - y_clause = Clause(attribute = "Origin", channel = "y") - color_clause = Clause(attribute = 'Cylinders', channel = "color") + x_clause = Clause(attribute="MilesPerGal", channel="x") + y_clause = Clause(attribute="Origin", channel="y") + color_clause = Clause(attribute="Cylinders", channel="color") - new_vis = Vis([x_clause, y_clause, color_clause],df) - #make sure dimention of the data is correct - color_cardinality = len(df.unique_values['Cylinders']) - group_by_cardinality = len(df.unique_values['Origin']) - assert (len(new_vis.data.columns)==3) - 
assert(len(new_vis.data)==15 > group_by_cardinality < color_cardinality*group_by_cardinality) # Not color_cardinality*group_by_cardinality since some combinations have 0 values + new_vis = Vis([x_clause, y_clause, color_clause], df) + # make sure dimension of the data is correct + color_cardinality = len(df.unique_values["Cylinders"]) + group_by_cardinality = len(df.unique_values["Origin"]) + assert len(new_vis.data.columns) == 3 + assert ( + len(new_vis.data) + == 15 + > group_by_cardinality + < color_cardinality * group_by_cardinality + ) # Not color_cardinality*group_by_cardinality since some combinations have 0 values - def test_colored_line_chart(): from lux.vis.Vis import Vis from lux.vis.Vis import Clause + df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype + df["Year"] = pd.to_datetime( + df["Year"], format="%Y" + ) # change pandas dtype for the column "Year" to datetype + + x_clause = Clause(attribute="Year", channel="x") + y_clause = Clause(attribute="MilesPerGal", channel="y") + color_clause = Clause(attribute="Cylinders", channel="color") - x_clause = Clause(attribute = "Year", channel = "x") - y_clause = Clause(attribute = "MilesPerGal", channel = "y") - color_clause = Clause(attribute = 'Cylinders', channel = "color") + new_vis = Vis([x_clause, y_clause, color_clause], df) - new_vis = Vis([x_clause, y_clause, color_clause],df) + + # make sure dimension of the data is correct + color_cardinality = len(df.unique_values["Cylinders"]) + group_by_cardinality = len(df.unique_values["Year"]) + assert len(new_vis.data.columns) == 3 + assert ( + len(new_vis.data) + == 60 + > group_by_cardinality + < color_cardinality * group_by_cardinality + ) # Not color_cardinality*group_by_cardinality since some combinations have 0 values - #make sure dimention of the data is correct - color_cardinality = len(df.unique_values['Cylinders']) - group_by_cardinality = len(df.unique_values['Year']) - assert (len(new_vis.data.columns)==3) - assert(len(new_vis.data)==60 > group_by_cardinality < color_cardinality*group_by_cardinality) # Not color_cardinality*group_by_cardinality since some combinations have 0 values
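A note on the chained comparisons in the two tests above: Python evaluates a == b > c < d as a == b and b > c and c < d. A short sketch of the line-chart case in long form, using the names from the test above:

    n = len(new_vis.data)
    assert n == 60                    # one row per (Year, Cylinders) combination that occurs
    assert 60 > group_by_cardinality  # more rows than groups, since color splits each group
    assert group_by_cardinality < color_cardinality * group_by_cardinality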
="MilesPerGal")] - vis = Vis(intent,df) + assert len(vis.data) == 386 + + intent = [ + lux.Clause(attribute="Horsepower", filter_op="<=", value=100), + lux.Clause(attribute="MilesPerGal"), + ] + vis = Vis(intent, df) vis._vis_data = df PandasExecutor.execute_filter(vis) - assert len(vis.data) == len(df[df["Horsepower"]<=100]) == 242 + assert len(vis.data) == len(df[df["Horsepower"] <= 100]) == 242 # Test end-to-end PandasExecutor.execute([vis], df) - Nbins =list(filter(lambda x: x.bin_size!=0, vis._inferred_intent))[0].bin_size + Nbins = list(filter(lambda x: x.bin_size != 0, vis._inferred_intent))[0].bin_size assert len(vis.data) == Nbins - + + def test_binning(): df = pd.read_csv("lux/data/car.csv") - vis = Vis([lux.Clause(attribute ="Horsepower")],df) - nbins =list(filter(lambda x: x.bin_size!=0, vis._inferred_intent))[0].bin_size + vis = Vis([lux.Clause(attribute="Horsepower")], df) + nbins = list(filter(lambda x: x.bin_size != 0, vis._inferred_intent))[0].bin_size assert len(vis.data) == nbins + def test_record(): df = pd.read_csv("lux/data/car.csv") - vis = Vis([lux.Clause(attribute ="Cylinders")],df) + vis = Vis([lux.Clause(attribute="Cylinders")], df) assert len(vis.data) == len(df["Cylinders"].unique()) - + + def test_filter_aggregation_fillzero_aligned(): df = pd.read_csv("lux/data/car.csv") - intent = [lux.Clause(attribute="Cylinders"), lux.Clause(attribute="MilesPerGal"), lux.Clause("Origin=Japan")] - vis = Vis(intent,df) + intent = [ + lux.Clause(attribute="Cylinders"), + lux.Clause(attribute="MilesPerGal"), + lux.Clause("Origin=Japan"), + ] + vis = Vis(intent, df) result = vis.data - externalValidation = df[df["Origin"]=="Japan"].groupby("Cylinders").mean()["MilesPerGal"] - assert result[result["Cylinders"]==5]["MilesPerGal"].values[0]==0 - assert result[result["Cylinders"]==8]["MilesPerGal"].values[0]==0 - assert result[result["Cylinders"]==3]["MilesPerGal"].values[0]==externalValidation[3] - assert result[result["Cylinders"]==4]["MilesPerGal"].values[0]==externalValidation[4] - assert result[result["Cylinders"]==6]["MilesPerGal"].values[0]==externalValidation[6] + externalValidation = ( + df[df["Origin"] == "Japan"].groupby("Cylinders").mean()["MilesPerGal"] + ) + assert result[result["Cylinders"] == 5]["MilesPerGal"].values[0] == 0 + assert result[result["Cylinders"] == 8]["MilesPerGal"].values[0] == 0 + assert ( + result[result["Cylinders"] == 3]["MilesPerGal"].values[0] + == externalValidation[3] + ) + assert ( + result[result["Cylinders"] == 4]["MilesPerGal"].values[0] + == externalValidation[4] + ) + assert ( + result[result["Cylinders"] == 6]["MilesPerGal"].values[0] + == externalValidation[6] + ) + def test_exclude_attribute(): df = pd.read_csv("lux/data/car.csv") intent = [lux.Clause("?", exclude=["Name", "Year"]), lux.Clause("Horsepower")] - vislist = VisList(intent,df) + vislist = VisList(intent, df) for vis in vislist: - assert (vis.get_attr_by_channel("x")[0].attribute != "Year") - assert (vis.get_attr_by_channel("x")[0].attribute != "Name") - assert (vis.get_attr_by_channel("y")[0].attribute != "Year") - assert (vis.get_attr_by_channel("y")[0].attribute != "Year") + assert vis.get_attr_by_channel("x")[0].attribute != "Year" + assert vis.get_attr_by_channel("x")[0].attribute != "Name" + assert vis.get_attr_by_channel("y")[0].attribute != "Year" + assert vis.get_attr_by_channel("y")[0].attribute != "Year" diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py index 7943e95f..7fa7fcd0 100644 --- a/tests/test_interestingness.py +++ 
b/tests/test_interestingness.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,143 +21,187 @@ # The following test cases are labelled for vis with def test_interestingness_1_0_0(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - - df.set_intent([lux.Clause(attribute = "Origin")]) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + + df.set_intent([lux.Clause(attribute="Origin")]) df._repr_html_() - #check that top recommended enhance graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Enhance'][0],df) != None + # check that top recommended enhance graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Enhance"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Enhance'])): - vis = df.recommendation['Enhance'][f] - if vis.get_attr_by_channel("x")[0].attribute == 'Displacement': + for f in range(0, len(df.recommendation["Enhance"])): + vis = df.recommendation["Enhance"][f] + if vis.get_attr_by_channel("x")[0].attribute == "Displacement": rank1 = f - if vis.get_attr_by_channel("x")[0].attribute == 'Weight': + if vis.get_attr_by_channel("x")[0].attribute == "Weight": rank2 = f - if vis.get_attr_by_channel("x")[0].attribute == 'Acceleration': + if vis.get_attr_by_channel("x")[0].attribute == "Acceleration": rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 - #check that top recommended filter graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Filter'][0],df) != None + # check that top recommended filter graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Filter"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Filter'])): - vis = df.recommendation['Filter'][f] - if len(vis.get_attr_by_attr_name("Cylinders"))>0: + for f in range(0, len(df.recommendation["Filter"])): + vis = df.recommendation["Filter"][f] + if len(vis.get_attr_by_attr_name("Cylinders")) > 0: if int(vis._inferred_intent[2].value) == 8: rank1 = f if int(vis._inferred_intent[2].value) == 6: rank2 = f - if '1972' in str(df.recommendation['Filter'][f]._inferred_intent[2].value): + if "1972" in str(df.recommendation["Filter"][f]._inferred_intent[2].value): rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 + + def test_interestingness_1_0_1(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") - df.set_intent([lux.Clause(attribute = "Origin", filter_op="=",value="USA"),lux.Clause(attribute = "Cylinders")]) + df.set_intent( + [ + lux.Clause(attribute="Origin", filter_op="=", value="USA"), + lux.Clause(attribute="Cylinders"), + ] + ) df._repr_html_() assert df.current_vis[0].score == 0 + def test_interestingness_0_1_0(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") - df.set_intent([lux.Clause(attribute = "Horsepower")]) + df.set_intent([lux.Clause(attribute="Horsepower")]) df._repr_html_() - #check that top recommended enhance graph score is not none and that 
ordering makes intuitive sense - assert interestingness(df.recommendation['Enhance'][0],df) != None + # check that top recommended enhance graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Enhance"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Enhance'])): - if df.recommendation['Enhance'][f].mark == 'scatter' and df.recommendation['Enhance'][f]._inferred_intent[1].attribute == 'Weight': + for f in range(0, len(df.recommendation["Enhance"])): + if ( + df.recommendation["Enhance"][f].mark == "scatter" + and df.recommendation["Enhance"][f]._inferred_intent[1].attribute + == "Weight" + ): rank1 = f - if df.recommendation['Enhance'][f].mark == 'scatter' and df.recommendation['Enhance'][f]._inferred_intent[1].attribute == 'Acceleration': + if ( + df.recommendation["Enhance"][f].mark == "scatter" + and df.recommendation["Enhance"][f]._inferred_intent[1].attribute + == "Acceleration" + ): rank2 = f - if df.recommendation['Enhance'][f].mark == 'line' and df.recommendation['Enhance'][f]._inferred_intent[0].attribute == 'Year': + if ( + df.recommendation["Enhance"][f].mark == "line" + and df.recommendation["Enhance"][f]._inferred_intent[0].attribute == "Year" + ): rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 - #check that top recommended filter graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Filter'][0],df) != None + # check that top recommended filter graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Filter"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Filter'])): - if df.recommendation['Filter'][f]._inferred_intent[2].value == 4: + for f in range(0, len(df.recommendation["Filter"])): + if df.recommendation["Filter"][f]._inferred_intent[2].value == 4: rank1 = f - if str(df.recommendation['Filter'][f]._inferred_intent[2].value) == "Europe": + if str(df.recommendation["Filter"][f]._inferred_intent[2].value) == "Europe": rank2 = f - if '1971' in str(df.recommendation['Filter'][f]._inferred_intent[2].value): + if "1971" in str(df.recommendation["Filter"][f]._inferred_intent[2].value): rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 def test_interestingness_0_1_1(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - - df.set_intent([lux.Clause(attribute = "Origin", filter_op="=",value="?"),lux.Clause(attribute = "MilesPerGal")]) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + + df.set_intent( + [ + lux.Clause(attribute="Origin", filter_op="=", value="?"), + lux.Clause(attribute="MilesPerGal"), + ] + ) df._repr_html_() - assert interestingness(df.recommendation['Current Vis'][0],df) != None - assert str(df.recommendation['Current Vis'][0]._inferred_intent[2].value) == 'USA' + assert interestingness(df.recommendation["Current Vis"][0], df) != None + assert str(df.recommendation["Current Vis"][0]._inferred_intent[2].value) == "USA" + def test_interestingness_1_1_0(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") - df.set_intent([lux.Clause(attribute = "Horsepower"),lux.Clause(attribute = "Year")]) + df.set_intent([lux.Clause(attribute="Horsepower"), lux.Clause(attribute="Year")]) df._repr_html_() - #check that top recommended Enhance graph score is 
not none (all graphs here have same score) - assert interestingness(df.recommendation['Enhance'][0],df) != None + # check that top recommended Enhance graph score is not none (all graphs here have same score) + assert interestingness(df.recommendation["Enhance"][0], df) != None - #check that top recommended filter graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Filter'][0],df) != None + # check that top recommended filter graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Filter"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Filter'])): - vis = df.recommendation['Filter'][f] - if len(vis.get_attr_by_attr_name("Cylinders"))>0: + for f in range(0, len(df.recommendation["Filter"])): + vis = df.recommendation["Filter"][f] + if len(vis.get_attr_by_attr_name("Cylinders")) > 0: if int(vis._inferred_intent[2].value) == 6: rank1 = f if int(vis._inferred_intent[2].value) == 5: rank3 = f - if len(vis.get_attr_by_attr_name("Origin"))>0: + if len(vis.get_attr_by_attr_name("Origin")) > 0: if str(vis._inferred_intent[2].value) == "Europe": rank2 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 - #check that top recommended generalize graph score is not none - assert interestingness(df.recommendation['Filter'][0],df) != None + # check that top recommended generalize graph score is not none + assert interestingness(df.recommendation["Filter"][0], df) != None + def test_interestingness_1_1_1(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") - df.set_intent([lux.Clause(attribute = "Horsepower"), lux.Clause(attribute = "Origin", filter_op="=",value = "USA", bin_size=20)]) + df.set_intent( + [ + lux.Clause(attribute="Horsepower"), + lux.Clause(attribute="Origin", filter_op="=", value="USA", bin_size=20), + ] + ) df._repr_html_() - #check that top recommended Enhance graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Enhance'][0],df) != None + # check that top recommended Enhance graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Enhance"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Enhance'])): - if str(df.recommendation['Enhance'][f]._inferred_intent[2].value) == "USA" and str(df.recommendation['Enhance'][f]._inferred_intent[1].attribute) == 'Cylinders': + for f in range(0, len(df.recommendation["Enhance"])): + if ( + str(df.recommendation["Enhance"][f]._inferred_intent[2].value) == "USA" + and str(df.recommendation["Enhance"][f]._inferred_intent[1].attribute) + == "Cylinders" + ): rank1 = f - if str(df.recommendation['Enhance'][f]._inferred_intent[2].value) == "USA" and str(df.recommendation['Enhance'][f]._inferred_intent[1].attribute) == 'Weight': + if ( + str(df.recommendation["Enhance"][f]._inferred_intent[2].value) == "USA" + and str(df.recommendation["Enhance"][f]._inferred_intent[1].attribute) + == "Weight" + ): rank2 = f - if str(df.recommendation['Enhance'][f]._inferred_intent[2].value) == "USA" and str(df.recommendation['Enhance'][f]._inferred_intent[1].attribute) == 'Horsepower': + if ( + str(df.recommendation["Enhance"][f]._inferred_intent[2].value) == "USA" + and str(df.recommendation["Enhance"][f]._inferred_intent[1].attribute) + == "Horsepower" + ): rank3 = f 
assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 - #check for top recommended Filter graph score is not none - assert interestingness(df.recommendation['Filter'][0],df) != None + # check for top recommended Filter graph score is not none + assert interestingness(df.recommendation["Filter"][0], df) != None + def test_interestingness_1_2_0(): from lux.vis.Vis import Vis @@ -165,59 +209,79 @@ def test_interestingness_1_2_0(): from lux.interestingness.interestingness import interestingness df = pd.read_csv("lux/data/car.csv") - y_clause = Clause(attribute = "Name", channel = "y") - color_clause = Clause(attribute = 'Cylinders', channel = "color") + y_clause = Clause(attribute="Name", channel="y") + color_clause = Clause(attribute="Cylinders", channel="color") new_vis = Vis([y_clause, color_clause]) new_vis.refresh_source(df) new_vis - #assert(len(new_vis.data)==color_cardinality*group_by_cardinality) + # assert(len(new_vis.data)==color_cardinality*group_by_cardinality) + + assert interestingness(new_vis, df) < 0.01 - assert(interestingness(new_vis, df)<0.01) def test_interestingness_0_2_0(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") - df.set_intent([lux.Clause(attribute = "Horsepower"),lux.Clause(attribute = "Acceleration")]) + df.set_intent( + [lux.Clause(attribute="Horsepower"), lux.Clause(attribute="Acceleration")] + ) df._repr_html_() - #check that top recommended enhance graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Enhance'][0],df) != None + # check that top recommended enhance graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Enhance"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Enhance'])): - if str(df.recommendation['Enhance'][f]._inferred_intent[2].attribute) == "Origin" and str(df.recommendation['Enhance'][f].mark) == 'scatter': + for f in range(0, len(df.recommendation["Enhance"])): + if ( + str(df.recommendation["Enhance"][f]._inferred_intent[2].attribute) + == "Origin" + and str(df.recommendation["Enhance"][f].mark) == "scatter" + ): rank1 = f - if str(df.recommendation['Enhance'][f]._inferred_intent[2].attribute) == "Displacement" and str(df.recommendation['Enhance'][f].mark) == 'scatter': + if ( + str(df.recommendation["Enhance"][f]._inferred_intent[2].attribute) + == "Displacement" + and str(df.recommendation["Enhance"][f].mark) == "scatter" + ): rank2 = f - if str(df.recommendation['Enhance'][f]._inferred_intent[2].attribute) == "Year" and str(df.recommendation['Enhance'][f].mark) == 'scatter': + if ( + str(df.recommendation["Enhance"][f]._inferred_intent[2].attribute) == "Year" + and str(df.recommendation["Enhance"][f].mark) == "scatter" + ): rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 - #check that top recommended filter graph score is not none and that ordering makes intuitive sense - assert interestingness(df.recommendation['Filter'][0],df) != None + # check that top recommended filter graph score is not none and that ordering makes intuitive sense + assert interestingness(df.recommendation["Filter"][0], df) != None rank1 = -1 rank2 = -1 rank3 = -1 - for f in range(0, len(df.recommendation['Filter'])): - if '1973' in str(df.recommendation['Filter'][f]._inferred_intent[2].value): + for f in range(0, len(df.recommendation["Filter"])): + if "1973" in 
str(df.recommendation["Filter"][f]._inferred_intent[2].value): rank1 = f - if '1976' in str(df.recommendation['Filter'][f]._inferred_intent[2].value): + if "1976" in str(df.recommendation["Filter"][f]._inferred_intent[2].value): rank2 = f - if str(df.recommendation['Filter'][f]._inferred_intent[2].value) == "Europe": + if str(df.recommendation["Filter"][f]._inferred_intent[2].value) == "Europe": rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 - #check that top recommended Generalize graph score is not none - assert interestingness(df.recommendation['Generalize'][0],df) != None + # check that top recommended Generalize graph score is not none + assert interestingness(df.recommendation["Generalize"][0], df) != None def test_interestingness_0_2_1(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") - df.set_intent([lux.Clause(attribute = "Horsepower"),lux.Clause(attribute = "MilesPerGal"),lux.Clause(attribute = "Acceleration", filter_op=">",value = 10)]) + df.set_intent( + [ + lux.Clause(attribute="Horsepower"), + lux.Clause(attribute="MilesPerGal"), + lux.Clause(attribute="Acceleration", filter_op=">", value=10), + ] + ) df._repr_html_() - #check that top recommended Generalize graph score is not none - assert interestingness(df.recommendation['Generalize'][0],df) != None \ No newline at end of file + # check that top recommended Generalize graph score is not none + assert interestingness(df.recommendation["Generalize"][0], df) != None diff --git a/tests/test_maintainence.py b/tests/test_maintainence.py index 797d3462..20ee227e 100644 --- a/tests/test_maintainence.py +++ b/tests/test_maintainence.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -18,58 +18,71 @@ import pandas as pd from lux.vis.Vis import Vis + def test_metadata_subsequent_display(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" df._repr_html_() - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" + def test_metadata_subsequent_vis(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" - vis = Vis(["Acceleration","Horsepower"],df) - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" + vis = Vis(["Acceleration", "Horsepower"], df) + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" + def test_metadata_inplace_operation(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" df.dropna(inplace=True) - assert df._metadata_fresh==False, "Failed to expire metadata after in-place Pandas operation" + assert ( + df._metadata_fresh == False + ), "Failed to expire metadata after in-place Pandas operation" + def test_metadata_new_df_operation(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" - df[["MilesPerGal","Acceleration"]] - assert df._metadata_fresh==True, "Failed to maintain metadata after display df" - df2 = df[["MilesPerGal","Acceleration"]] - assert not hasattr(df2,"_metadata_fresh") + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" + df[["MilesPerGal", "Acceleration"]] + assert df._metadata_fresh == True, "Failed to maintain metadata after display df" + df2 = df[["MilesPerGal", "Acceleration"]] + assert not hasattr(df2, "_metadata_fresh") + def test_metadata_column_group_reset_df(): df = pd.read_csv("lux/data/car.csv") - assert not hasattr(df,"_metadata_fresh") - df['Year'] = pd.to_datetime(df['Year'], format='%Y') - assert hasattr(df,"_metadata_fresh") + assert not hasattr(df, "_metadata_fresh") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + assert hasattr(df, "_metadata_fresh") result = df.groupby("Cylinders").mean() - assert not hasattr(result,"_metadata_fresh") - result._repr_html_() # Note that this should trigger two compute metadata (one for df, and one for an intermediate df.reset_index used to feed inside created Vis) - assert result._metadata_fresh==True, "Failed to maintain metadata after display df" + assert not hasattr(result, "_metadata_fresh") + result._repr_html_() # Note that this should trigger two compute metadata (one for df, and one for an intermediate df.reset_index used to feed inside created Vis) + assert ( + result._metadata_fresh == True + ), "Failed to maintain metadata after display df" colgroup_recs = result.recommendation["Column Groups"] - assert len(colgroup_recs) == 5 - for rec in colgroup_recs: assert rec.mark=="bar", "Column Group not displaying bar charts" - + assert len(colgroup_recs) == 5 + for rec in colgroup_recs: + assert rec.mark == "bar", "Column Group not displaying bar charts" 
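The maintenance tests above pin down Lux's metadata cache contract: displaying a frame computes metadata and marks it fresh, an in-place pandas operation expires it, and a newly derived frame starts without the flag. A minimal sketch of that lifecycle, mirroring the assertions in these tests:

    df = pd.read_csv("lux/data/car.csv")
    df._repr_html_()           # display computes metadata: df._metadata_fresh == True
    df.dropna(inplace=True)    # in-place mutation expires it: df._metadata_fresh == False
    df2 = df[["MilesPerGal"]]  # derived frame has no flag yet: not hasattr(df2, "_metadata_fresh")
    df2._repr_html_()          # next display computes metadata for df2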
+ + def test_recs_inplace_operation(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - assert df._recs_fresh==True, "Failed to maintain recommendation after display df" - assert len(df.recommendation["Occurrence"])==3 - df.drop(columns=["Name"],inplace=True) - assert 'Name' not in df.columns, "Failed to perform `drop` operation in-place" - assert df._recs_fresh==False, "Failed to maintain recommendation after in-place Pandas operation" + assert df._recs_fresh == True, "Failed to maintain recommendation after display df" + assert len(df.recommendation["Occurrence"]) == 3 + df.drop(columns=["Name"], inplace=True) + assert "Name" not in df.columns, "Failed to perform `drop` operation in-place" + assert ( + df._recs_fresh == False + ), "Failed to maintain recommendation after in-place Pandas operation" df._repr_html_() - assert len(df.recommendation["Occurrence"])==2 - assert df._recs_fresh==True, "Failed to maintain recommendation after display df" \ No newline at end of file + assert len(df.recommendation["Occurrence"]) == 2 + assert df._recs_fresh == True, "Failed to maintain recommendation after display df" diff --git a/tests/test_nan.py b/tests/test_nan.py index c92a9097..df2b8e9f 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,10 +18,12 @@ import numpy as np from lux.vis.Vis import Vis + + def test_nan_column(): - df = pd.read_csv("lux/data/college.csv") - df["Geography"] = np.nan - df._repr_html_() - for visList in df.recommendation.keys(): - for vis in df.recommendation[visList]: - assert vis.get_attr_by_attr_name("Geography")==[] \ No newline at end of file + df = pd.read_csv("lux/data/college.csv") + df["Geography"] = np.nan + df._repr_html_() + for visList in df.recommendation.keys(): + for vis in df.recommendation[visList]: + assert vis.get_attr_by_attr_name("Geography") == [] diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 114206e5..60f6a91c 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -27,13 +27,20 @@ # assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." # assert series.name == "Weight", "Pandas Series original `name` property not retained." + def test_head_tail(): df = pd.read_csv("lux/data/car.csv") df._repr_html_() - assert df._message.to_html()=="" + assert df._message.to_html() == "" df.head()._repr_html_() - assert "Lux is visualizing the previous version of the dataframe before you applied head."in df._message.to_html() + assert ( + "Lux is visualizing the previous version of the dataframe before you applied head." + in df._message.to_html() + ) df._repr_html_() - assert df._message.to_html()=="" + assert df._message.to_html() == "" df.tail()._repr_html_() - assert "Lux is visualizing the previous version of the dataframe before you applied tail." in df._message.to_html() \ No newline at end of file + assert ( + "Lux is visualizing the previous version of the dataframe before you applied tail." 
+ in df._message.to_html() + ) diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py index c4d47616..d88badf5 100644 --- a/tests/test_pandas_coverage.py +++ b/tests/test_pandas_coverage.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -20,24 +20,29 @@ # DataFrame Tests # ################### + def test_deepcopy(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df._repr_html_(); + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df._repr_html_() saved_df = df.copy(deep=True) - saved_df._repr_html_(); + saved_df._repr_html_() check_metadata_equal(df, saved_df) + def test_rename_inplace(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df._repr_html_(); + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df._repr_html_() new_df = df.copy(deep=True) - df.rename(columns={"Name": "Car Name"}, inplace = True) - df._repr_html_(); - new_df._repr_html_(); - new_df, df = df, new_df # new_df is the old dataframe (df) with the new column name changed inplace - + df.rename(columns={"Name": "Car Name"}, inplace=True) + df._repr_html_() + new_df._repr_html_() + new_df, df = ( + df, + new_df, + ) # new_df is the old dataframe (df) with the new column name changed inplace + assert df.data_type_lookup != new_df.data_type_lookup assert df.data_type_lookup["Name"] == new_df.data_type_lookup["Car Name"] @@ -60,12 +65,14 @@ def test_rename_inplace(): assert list(df.cardinality.values()) == list(new_df.cardinality.values()) assert df._min_max == new_df._min_max assert df.pre_aggregated == new_df.pre_aggregated + + def test_rename(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df._repr_html_(); - new_df = df.rename(columns={"Name": "Car Name"}, inplace = False) - new_df._repr_html_(); + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df._repr_html_() + new_df = df.rename(columns={"Name": "Car Name"}, inplace=False) + new_df._repr_html_() assert df.data_type_lookup != new_df.data_type_lookup assert df.data_type_lookup["Name"] == new_df.data_type_lookup["Car Name"] @@ -88,52 +95,84 @@ def test_rename(): assert list(df.cardinality.values()) == list(new_df.cardinality.values()) assert df._min_max == new_df._min_max assert df.pre_aggregated == new_df.pre_aggregated + + def test_rename3(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df.columns = ["col1", "col2", "col3", "col4","col5", "col6", "col7", "col8", "col9"] + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df.columns = [ + "col1", + "col2", + "col3", + "col4", + "col5", + "col6", + "col7", + "col8", + "col9", + ] df._repr_html_() - assert list(df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(df.cardinality) == 9 assert "col2" in list(df.cardinality.keys()) + def test_concat(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = pd.concat([df.loc[:, "Name":"Cylinders"], df.loc[:, "Year":"Origin"]], axis = "columns") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = pd.concat( + [df.loc[:, "Name":"Cylinders"], 
df.loc[:, "Year":"Origin"]], axis="columns" + ) new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df.cardinality) == 5 + def test_groupby_agg(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") new_df = df.groupby("Year").agg(sum) new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Column Groups'] + assert list(new_df.recommendation.keys()) == ["Column Groups"] assert len(new_df.cardinality) == 7 + def test_qcut(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df["Weight"] = pd.qcut(df["Weight"], q = 3) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df["Weight"] = pd.qcut(df["Weight"], q=3) df._repr_html_() + def test_cut(): df = pd.read_csv("lux/data/car.csv") - df["Weight"] = pd.cut(df["Weight"], bins = [0, 2500, 7500, 10000], labels = ["small", "medium", "large"]) + df["Weight"] = pd.cut( + df["Weight"], bins=[0, 2500, 7500, 10000], labels=["small", "medium", "large"] + ) df._repr_html_() + + def test_groupby_agg_very_small(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") new_df = df.groupby("Origin").agg(sum).reset_index() new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Column Groups'] + assert list(new_df.recommendation.keys()) == ["Column Groups"] assert len(new_df.cardinality) == 7 + # def test_groupby_multi_index(): # url = 'https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true' # df = pd.read_csv(url) @@ -143,151 +182,244 @@ def test_groupby_agg_very_small(): # assert list(new_df.recommendation.keys() ) == ['Column Groups'] # TODO # assert len(new_df.cardinality) == 7 # TODO + def test_query(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") new_df = df.query("Weight > 3000") new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df.cardinality) == 9 + def test_pop(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") df.pop("Weight") df._repr_html_() - assert list(df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(df.cardinality) == 8 + def test_transform(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = df.iloc[:,1:].groupby("Origin").transform(sum) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = df.iloc[:, 1:].groupby("Origin").transform(sum) new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Occurrence'] + assert list(new_df.recommendation.keys()) == ["Correlation", "Occurrence"] assert len(new_df.cardinality) == 6 + def test_get_group(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + 
df["Year"] = pd.to_datetime(df["Year"], format="%Y") gbobj = df.groupby("Origin") new_df = gbobj.get_group("Japan") new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df.cardinality) == 9 + def test_applymap(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - mapping = {"USA": 0, "Europe": 1, "Japan":2} + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + mapping = {"USA": 0, "Europe": 1, "Japan": 2} df["Origin"] = df[["Origin"]].applymap(mapping.get) df._repr_html_() - assert list(df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(df.cardinality) == 9 + def test_strcat(): - df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true') - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df["combined"] = df["Origin"].str.cat(df["Brand"], sep = ", ") + df = pd.read_csv( + "https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true" + ) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df["combined"] = df["Origin"].str.cat(df["Brand"], sep=", ") df._repr_html_() - assert list(df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(df.cardinality) == 11 + def test_named_agg(): - df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true') - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = df.groupby("Brand").agg(avg_weight = ("Weight", "mean"), max_weight = ("Weight", "max"), mean_displacement = ("Displacement", "mean")) + df = pd.read_csv( + "https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true" + ) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = df.groupby("Brand").agg( + avg_weight=("Weight", "mean"), + max_weight=("Weight", "max"), + mean_displacement=("Displacement", "mean"), + ) new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Column Groups'] + assert list(new_df.recommendation.keys()) == ["Column Groups"] assert len(new_df.cardinality) == 4 + def test_change_dtype(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - df["Cylinders"] = pd.Series(df["Cylinders"], dtype = "Int64") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + df["Cylinders"] = pd.Series(df["Cylinders"], dtype="Int64") df._repr_html_() - assert list(df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(df.data_type_lookup) == 9 + def test_get_dummies(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") new_df = pd.get_dummies(df) new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert 
len(new_df.data_type_lookup) == 310 + def test_drop(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = df.drop([0, 1, 2], axis = "rows") - new_df2 = new_df.drop(["Name", "MilesPerGal", "Cylinders"], axis = "columns") + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = df.drop([0, 1, 2], axis="rows") + new_df2 = new_df.drop(["Name", "MilesPerGal", "Cylinders"], axis="columns") new_df2._repr_html_() - assert list(new_df2.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df2.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df2.cardinality) == 6 + def test_merge(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = df.drop([0, 1, 2], axis = "rows") - new_df2 = pd.merge(df, new_df, how = "left", indicator = True) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = df.drop([0, 1, 2], axis="rows") + new_df2 = pd.merge(df, new_df, how="left", indicator=True) new_df2._repr_html_() - assert list(new_df2.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] # TODO once bug is fixed + assert list(new_df2.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] # TODO once bug is fixed assert len(new_df2.cardinality) == 10 + def test_prefix(): df = pd.read_csv("lux/data/car.csv") - df["Year"] = pd.to_datetime(df["Year"], format='%Y') + df["Year"] = pd.to_datetime(df["Year"], format="%Y") new_df = df.add_prefix("1_") new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df.cardinality) == 9 assert new_df.cardinality["1_Name"] == 300 + def test_loc(): - df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true') - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = df.loc[:,"Displacement":"Origin"] + df = pd.read_csv( + "https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true" + ) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = df.loc[:, "Displacement":"Origin"] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df.cardinality) == 6 - new_df = df.loc[0:10,"Displacement":"Origin"] + new_df = df.loc[0:10, "Displacement":"Origin"] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution'] + assert list(new_df.recommendation.keys()) == ["Correlation", "Distribution"] assert len(new_df.cardinality) == 6 - new_df = df.loc[0:10,"Displacement":"Horsepower"] + new_df = df.loc[0:10, "Displacement":"Horsepower"] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution'] + assert list(new_df.recommendation.keys()) == ["Correlation", "Distribution"] assert len(new_df.cardinality) == 2 import numpy as np - inter_df = df.groupby("Brand")[["Acceleration", "Weight", "Horsepower"]].agg(np.mean) + + inter_df = df.groupby("Brand")[["Acceleration", "Weight", "Horsepower"]].agg( + np.mean + ) new_df = inter_df.loc["chevrolet":"fiat", 
"Acceleration":"Weight"] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Column Groups'] + assert list(new_df.recommendation.keys()) == ["Column Groups"] assert len(new_df.cardinality) == 3 + def test_iloc(): - df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true') - df["Year"] = pd.to_datetime(df["Year"], format='%Y') - new_df = df.iloc[:,3:9] + df = pd.read_csv( + "https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true" + ) + df["Year"] = pd.to_datetime(df["Year"], format="%Y") + new_df = df.iloc[:, 3:9] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution', 'Occurrence', 'Temporal'] + assert list(new_df.recommendation.keys()) == [ + "Correlation", + "Distribution", + "Occurrence", + "Temporal", + ] assert len(new_df.cardinality) == 6 - new_df = df.iloc[0:11,3:9] + new_df = df.iloc[0:11, 3:9] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution'] + assert list(new_df.recommendation.keys()) == ["Correlation", "Distribution"] assert len(new_df.cardinality) == 6 - new_df = df.iloc[0:11,3:5] + new_df = df.iloc[0:11, 3:5] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Correlation', 'Distribution'] + assert list(new_df.recommendation.keys()) == ["Correlation", "Distribution"] assert len(new_df.cardinality) == 2 import numpy as np - inter_df = df.groupby("Brand")[["Acceleration", "Weight", "Horsepower"]].agg(np.mean) + + inter_df = df.groupby("Brand")[["Acceleration", "Weight", "Horsepower"]].agg( + np.mean + ) new_df = inter_df.iloc[5:10, 0:2] new_df._repr_html_() - assert list(new_df.recommendation.keys() ) == ['Column Groups'] + assert list(new_df.recommendation.keys()) == ["Column Groups"] assert len(new_df.cardinality) == 3 + def check_metadata_equal(df1, df2): # Checks to make sure metadata for df1 and df2 are equal. for attr in df1._metadata: @@ -335,6 +467,7 @@ def compare_clauses(clause1, clause2): assert clause1.sort == clause2.sort assert clause1.exclude == clause2.exclude + def compare_vis(vis1, vis2): assert len(vis1._intent) == len(vis2._intent) for j in range(len(vis1._intent)): @@ -342,7 +475,7 @@ def compare_vis(vis1, vis2): assert len(vis1._inferred_intent) == len(vis2._inferred_intent) for j in range(len(vis1._inferred_intent)): compare_clauses(vis1._inferred_intent[j], vis2._inferred_intent[j]) - assert vis1._source == vis2._source + assert vis1._source == vis2._source assert vis1._code == vis2._code assert vis1._mark == vis2._mark assert vis1._min_max == vis2._min_max @@ -350,41 +483,116 @@ def compare_vis(vis1, vis2): assert vis1.title == vis2.title assert vis1.score == vis2.score + ################ # Series Tests # ################ + def test_df_to_series(): # Ensure metadata is kept when going from df to series df = pd.read_csv("lux/data/car.csv") - df._repr_html_() # compute metadata + df._repr_html_() # compute metadata assert df.cardinality is not None series = df["Weight"] - assert isinstance(series,lux.core.series.LuxSeries), "Derived series is type LuxSeries." + assert isinstance( + series, lux.core.series.LuxSeries + ), "Derived series is type LuxSeries." 
df["Weight"]._metadata - assert df["Weight"]._metadata == ['_intent','data_type_lookup','data_type','data_model_lookup','data_model','unique_values','cardinality','_rec_info','_pandas_only','_min_max','plot_config','_current_vis','_widget','_recommendation','_prev','_history','_saved_export'], "Metadata is lost when going from Dataframe to Series." - assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." - assert series.name == "Weight", "Pandas Series original `name` property not retained." + assert df["Weight"]._metadata == [ + "_intent", + "data_type_lookup", + "data_type", + "data_model_lookup", + "data_model", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + ], "Metadata is lost when going from Dataframe to Series." + assert ( + df.cardinality is not None + ), "Metadata is lost when going from Dataframe to Series." + assert ( + series.name == "Weight" + ), "Pandas Series original `name` property not retained." + def test_value_counts(): df = pd.read_csv("lux/data/car.csv") - df._repr_html_() # compute metadata + df._repr_html_() # compute metadata assert df.cardinality is not None series = df["Weight"] series.value_counts() - assert isinstance(series,lux.core.series.LuxSeries), "Derived series is type LuxSeries." - assert df["Weight"]._metadata == ['_intent','data_type_lookup','data_type','data_model_lookup','data_model','unique_values','cardinality','_rec_info','_pandas_only','_min_max','plot_config','_current_vis','_widget','_recommendation','_prev','_history','_saved_export'], "Metadata is lost when going from Dataframe to Series." - assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." - assert series.name == "Weight", "Pandas Series original `name` property not retained." + assert isinstance( + series, lux.core.series.LuxSeries + ), "Derived series is type LuxSeries." + assert df["Weight"]._metadata == [ + "_intent", + "data_type_lookup", + "data_type", + "data_model_lookup", + "data_model", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + ], "Metadata is lost when going from Dataframe to Series." + assert ( + df.cardinality is not None + ), "Metadata is lost when going from Dataframe to Series." + assert ( + series.name == "Weight" + ), "Pandas Series original `name` property not retained." + def test_str_replace(): - url = 'https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true' + url = "https://github.com/lux-org/lux-datasets/blob/master/data/cars.csv?raw=true" df = pd.read_csv(url) - df._repr_html_() # compute metadata + df._repr_html_() # compute metadata assert df.cardinality is not None series = df["Brand"].str.replace("chevrolet", "chevy") - assert isinstance(series,lux.core.series.LuxSeries), "Derived series is type LuxSeries." - assert df["Brand"]._metadata == ['_intent','data_type_lookup','data_type','data_model_lookup','data_model','unique_values','cardinality','_rec_info','_pandas_only','_min_max','plot_config','_current_vis','_widget','_recommendation','_prev','_history','_saved_export'], "Metadata is lost when going from Dataframe to Series." - assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." 
- assert series.name == "Brand", "Pandas Series original `name` property not retained." - + assert isinstance( + series, lux.core.series.LuxSeries + ), "Derived series is type LuxSeries." + assert df["Brand"]._metadata == [ + "_intent", + "data_type_lookup", + "data_type", + "data_model_lookup", + "data_model", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + ], "Metadata is lost when going from Dataframe to Series." + assert ( + df.cardinality is not None + ), "Metadata is lost when going from Dataframe to Series." + assert ( + series.name == "Brand" + ), "Pandas Series original `name` property not retained." diff --git a/tests/test_parser.py b/tests/test_parser.py index 9b400554..67021583 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,67 +15,74 @@ import pandas as pd import lux + def test_case1(): - ldf = pd.read_csv("lux/data/car.csv") - ldf.set_intent(["Horsepower"]) - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == "Horsepower") + ldf = pd.read_csv("lux/data/car.csv") + ldf.set_intent(["Horsepower"]) + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == "Horsepower" + def test_case2(): - ldf = pd.read_csv("lux/data/car.csv") - ldf.set_intent(["Horsepower", lux.Clause("MilesPerGal", channel="x")]) - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == "Horsepower") - assert(type(ldf._intent[1]) is lux.Clause) - assert(ldf._intent[1].attribute == "MilesPerGal") + ldf = pd.read_csv("lux/data/car.csv") + ldf.set_intent(["Horsepower", lux.Clause("MilesPerGal", channel="x")]) + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == "Horsepower" + assert type(ldf._intent[1]) is lux.Clause + assert ldf._intent[1].attribute == "MilesPerGal" + def test_case3(): - ldf = pd.read_csv("lux/data/car.csv") - ldf.set_intent(["Horsepower", "Origin=USA"]) - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == "Horsepower") - assert(type(ldf._intent[1]) is lux.Clause) - assert(ldf._intent[1].attribute == "Origin") - assert(ldf._intent[1].value == "USA") + ldf = pd.read_csv("lux/data/car.csv") + ldf.set_intent(["Horsepower", "Origin=USA"]) + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == "Horsepower" + assert type(ldf._intent[1]) is lux.Clause + assert ldf._intent[1].attribute == "Origin" + assert ldf._intent[1].value == "USA" + def test_case4(): - ldf = pd.read_csv("lux/data/car.csv") - ldf.set_intent(["Horsepower", "Origin=USA|Japan"]) - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == "Horsepower") - assert(type(ldf._intent[1]) is lux.Clause) - assert(ldf._intent[1].attribute == "Origin") - assert(ldf._intent[1].value == ["USA","Japan"]) + ldf = pd.read_csv("lux/data/car.csv") + ldf.set_intent(["Horsepower", "Origin=USA|Japan"]) + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == "Horsepower" + assert type(ldf._intent[1]) is lux.Clause + assert ldf._intent[1].attribute == "Origin" + assert ldf._intent[1].value == ["USA", "Japan"] + def test_case5(): - ldf = 
pd.read_csv("lux/data/car.csv") - ldf.set_intent([["Horsepower", "MilesPerGal", "Weight"], "Origin=USA"]) - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == ["Horsepower", "MilesPerGal", "Weight"]) - assert(type(ldf._intent[1]) is lux.Clause) - assert(ldf._intent[1].attribute == "Origin") - assert(ldf._intent[1].value == "USA") - - ldf.set_intent(["Horsepower|MilesPerGal|Weight", "Origin=USA"]) - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == ["Horsepower", "MilesPerGal", "Weight"]) - assert(type(ldf._intent[1]) is lux.Clause) - assert(ldf._intent[1].attribute == "Origin") - assert(ldf._intent[1].value == "USA") + ldf = pd.read_csv("lux/data/car.csv") + ldf.set_intent([["Horsepower", "MilesPerGal", "Weight"], "Origin=USA"]) + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == ["Horsepower", "MilesPerGal", "Weight"] + assert type(ldf._intent[1]) is lux.Clause + assert ldf._intent[1].attribute == "Origin" + assert ldf._intent[1].value == "USA" + + ldf.set_intent(["Horsepower|MilesPerGal|Weight", "Origin=USA"]) + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == ["Horsepower", "MilesPerGal", "Weight"] + assert type(ldf._intent[1]) is lux.Clause + assert ldf._intent[1].attribute == "Origin" + assert ldf._intent[1].value == "USA" + def test_case6(): - ldf = pd.read_csv("lux/data/car.csv") - ldf.set_intent(["Horsepower", "Origin=?"]) - ldf._repr_html_() - assert(type(ldf._intent[0]) is lux.Clause) - assert(ldf._intent[0].attribute == "Horsepower") - assert(type(ldf._intent[1]) is lux.Clause) - assert(ldf._intent[1].attribute == "Origin") - assert(ldf._intent[1].value == ["USA","Japan","Europe"]) + ldf = pd.read_csv("lux/data/car.csv") + ldf.set_intent(["Horsepower", "Origin=?"]) + ldf._repr_html_() + assert type(ldf._intent[0]) is lux.Clause + assert ldf._intent[0].attribute == "Horsepower" + assert type(ldf._intent[1]) is lux.Clause + assert ldf._intent[1].attribute == "Origin" + assert ldf._intent[1].value == ["USA", "Japan", "Europe"] + # TODO: Need to support this case -''' +""" lux.set_intent(["Horsepower","MPG","Acceleration"],"Origin") lux.set_intent("Horsepower/MPG/Acceleration", "Origin") --> [Clause(attr= ["Horsepower","MPG","Acceleration"], type= "attributeGroup")] -''' \ No newline at end of file +""" diff --git a/tests/test_performance.py b/tests/test_performance.py index 588aea55..a30b4cd2 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -1,5 +1,5 @@ # Copyright 2019-2020 The Lux Authors. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,35 +16,42 @@ import pytest import pandas as pd import time + # To run the script and see the printed result, run: # python -m pytest -s tests/test_performance.py def test_q1_performance_census(): - url = 'https://github.com/lux-org/lux-datasets/blob/master/data/census.csv?raw=true' - df = pd.read_csv(url) - tic = time.perf_counter() - df._repr_html_() - toc = time.perf_counter() - delta = toc - tic - df._repr_html_() - toc2 = time.perf_counter() - delta2 = toc2 - toc - print(f"1st display Performance: {delta:0.4f} seconds") - print(f"2nd display Performance: {delta2:0.4f} seconds") - assert delta < 4.5, "The recommendations on Census dataset took a total of {delta:0.4f} seconds, longer than expected." 
-    assert delta2 < 0.15
diff --git a/tests/test_type.py b/tests/test_type.py
-    assert "order_id, product_id, user_id is not visualized since it resembles an ID field." in df._message.to_html()
+    df = pd.read_csv(
+        "https://github.com/lux-org/lux-datasets/blob/master/data/instacart_sample.csv?raw=true"
+    )
+    df._repr_html_()
+    assert len(df.data_type["id"]) == 3
+    assert (
+        "order_id, product_id, user_id is not visualized since it resembles an ID field."
+        in df._message.to_html()
+    )
+
 def test_check_str_id():
-    df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/churn.csv?raw=true')
-    df._repr_html_()
-    assert "customerID is not visualized since it resembles an ID field." in df._message.to_html()
+    df = pd.read_csv(
+        "https://github.com/lux-org/lux-datasets/blob/master/data/churn.csv?raw=true"
+    )
+    df._repr_html_()
+    assert (
+        "customerID is not visualized since it resembles an ID field."
+        in df._message.to_html()
+    )
+
 def test_check_hpi():
-    df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/hpi.csv?raw=true').head(10)
+    df = pd.read_csv(
+        "https://github.com/lux-org/lux-datasets/blob/master/data/hpi.csv?raw=true"
+    ).head(10)
+
+    df.maintain_metadata()
-    df.maintain_metadata()
+    assert df.data_type_lookup == {
+        "HPIRank": "quantitative",
+        "Country": "nominal",
+        "SubRegion": "nominal",
+        "AverageLifeExpectancy": "quantitative",
+        "AverageWellBeing": "quantitative",
+        "HappyLifeYears": "quantitative",
+        "Footprint": "quantitative",
+        "InequalityOfOutcomes": "quantitative",
+        "InequalityAdjustedLifeExpectancy": "quantitative",
+        "InequalityAdjustedWellbeing": "quantitative",
+        "HappyPlanetIndex": "quantitative",
+        "GDPPerCapita": "quantitative",
+        "Population": "quantitative",
+    }
-    assert df.data_type_lookup == {'HPIRank': 'quantitative',
-                                   'Country': 'nominal',
-                                   'SubRegion': 'nominal',
-                                   'AverageLifeExpectancy': 'quantitative',
-                                   'AverageWellBeing': 'quantitative',
-                                   'HappyLifeYears': 'quantitative',
-                                   'Footprint': 'quantitative',
-                                   'InequalityOfOutcomes': 'quantitative',
-                                   'InequalityAdjustedLifeExpectancy': 'quantitative',
-                                   'InequalityAdjustedWellbeing': 'quantitative',
-                                   'HappyPlanetIndex': 'quantitative',
-                                   'GDPPerCapita': 'quantitative',
-                                   'Population': 'quantitative'}
 def test_check_airbnb():
-    df = pd.read_csv('https://github.com/lux-org/lux-datasets/blob/master/data/airbnb_nyc.csv?raw=true')
-    df.maintain_metadata()
-    assert df.data_type_lookup == {'id': 'id',
-                                   'name': 'nominal',
-                                   'host_id': 'id',
-                                   'host_name': 'nominal',
-                                   'neighbourhood_group': 'nominal',
-                                   'neighbourhood': 'nominal',
-                                   'latitude': 'quantitative',
-                                   'longitude': 'quantitative',
-                                   'room_type': 'nominal',
-                                   'price': 'quantitative',
-                                   'minimum_nights': 'quantitative',
-                                   'number_of_reviews': 'quantitative',
-                                   'last_review': 'nominal',
-                                   'reviews_per_month': 'quantitative',
-                                   'calculated_host_listings_count': 'quantitative',
-                                   'availability_365': 'quantitative'}
+    df = pd.read_csv(
+        "https://github.com/lux-org/lux-datasets/blob/master/data/airbnb_nyc.csv?raw=true"
+    )
+    df.maintain_metadata()
+    assert df.data_type_lookup == {
+        "id": "id",
+        "name": "nominal",
+        "host_id": "id",
+        "host_name": "nominal",
+        "neighbourhood_group": "nominal",
+        "neighbourhood": "nominal",
+        "latitude": "quantitative",
+        "longitude": "quantitative",
+        "room_type": "nominal",
+        "price": "quantitative",
+        "minimum_nights": "quantitative",
+        "number_of_reviews": "quantitative",
+        "last_review": "nominal",
+        "reviews_per_month": "quantitative",
+        "calculated_host_listings_count": "quantitative",
+        "availability_365": "quantitative",
+    }
+
 def test_check_college():
-    df = pd.read_csv('lux/data/college.csv')
-    df.maintain_metadata()
-    assert df.data_type_lookup == {'Name': 'nominal',
-                                   'PredominantDegree': 'nominal',
-                                   'HighestDegree': 'nominal',
-                                   'FundingModel': 'nominal',
-                                   'Region': 'nominal',
-                                   'Geography': 'nominal',
-                                   'AdmissionRate': 'quantitative',
-                                   'ACTMedian': 'quantitative',
-                                   'SATAverage': 'quantitative',
-                                   'AverageCost': 'quantitative',
-                                   'Expenditure': 'quantitative',
-                                   'AverageFacultySalary': 'quantitative',
-                                   'MedianDebt': 'quantitative',
-                                   'AverageAgeofEntry': 'quantitative',
-                                   'MedianFamilyIncome': 'quantitative',
-                                   'MedianEarnings': 'quantitative'}
\ No newline at end of file
+    df = pd.read_csv("lux/data/college.csv")
+    df.maintain_metadata()
+    assert df.data_type_lookup == {
+        "Name": "nominal",
+        "PredominantDegree": "nominal",
+        "HighestDegree": "nominal",
+        "FundingModel": "nominal",
+        "Region": "nominal",
+        "Geography": "nominal",
+        "AdmissionRate": "quantitative",
+        "ACTMedian": "quantitative",
+        "SATAverage": "quantitative",
+        "AverageCost": "quantitative",
+        "Expenditure": "quantitative",
+        "AverageFacultySalary": "quantitative",
+        "MedianDebt": "quantitative",
+        "AverageAgeofEntry": "quantitative",
+        "MedianFamilyIncome": "quantitative",
+        "MedianEarnings": "quantitative",
+    }
diff --git a/tests/test_vis.py b/tests/test_vis.py
index e9ffa33c..0f6f9eec 100644
--- a/tests/test_vis.py
+++ b/tests/test_vis.py
@@ -1,5 +1,5 @@
 # Copyright 2019-2020 The Lux Authors.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -17,152 +17,223 @@
 import pandas as pd
 from lux.vis.VisList import VisList
 from lux.vis.Vis import Vis
+
+
 def test_vis():
-    url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true'
+    url = (
+        "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true"
+    )
     df = pd.read_csv(url)
-    vis = Vis(["Height","SportType=Ball"],df)
-    assert vis.get_attr_by_attr_name("Height")[0].bin_size!=0
-    assert vis.get_attr_by_attr_name("Record")[0].aggregation == 'count'
-
+    vis = Vis(["Height", "SportType=Ball"], df)
+    assert vis.get_attr_by_attr_name("Height")[0].bin_size != 0
+    assert vis.get_attr_by_attr_name("Record")[0].aggregation == "count"
+
+
 def test_vis_set_specs():
-    url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true'
+    url = (
+        "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true"
+    )
     df = pd.read_csv(url)
-    vis = Vis(["Height","SportType=Ball"],df)
-    vis.set_intent(["Height","SportType=Ice"])
-    assert vis.get_attr_by_attr_name("SportType")[0].value =="Ice"
+    vis = Vis(["Height", "SportType=Ball"], df)
+    vis.set_intent(["Height", "SportType=Ice"])
+    assert vis.get_attr_by_attr_name("SportType")[0].value == "Ice"
+
 def test_vis_collection():
-    url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true'
+    url = (
+        "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true"
+    )
     df = pd.read_csv(url)
-    vlist = VisList(["Height","SportType=Ball","?"],df)
-    vis_with_year = list(filter(lambda x: x.get_attr_by_attr_name("Year")!=[],vlist))[0]
-    assert vis_with_year.get_attr_by_channel("x")[0].attribute=="Year"
-    assert len(vlist) == len(df.columns) -1 -1 #remove 1 for vis with same filter attribute and remove 1 vis with for same attribute
-    vlist = VisList(["Height","?"],df)
-    assert len(vlist) == len(df.columns) -1 #remove 1 for vis with for same attribute
+    vlist = VisList(["Height", "SportType=Ball", "?"], df)
+    vis_with_year = list(
+        filter(lambda x: x.get_attr_by_attr_name("Year") != [], vlist)
+    )[0]
+    assert vis_with_year.get_attr_by_channel("x")[0].attribute == "Year"
+    assert (
+        len(vlist) == len(df.columns) - 1 - 1
+    )  # remove 1 for vis with same filter attribute and remove 1 vis with for same attribute
+    vlist = VisList(["Height", "?"], df)
+    assert len(vlist) == len(df.columns) - 1  # remove 1 for vis with for same attribute
+
 def test_vis_collection_set_intent():
-    url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true'
+    url = (
+        "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true"
+    )
     df = pd.read_csv(url)
-    vlist = VisList(["Height","SportType=Ice","?"],df)
-    vlist.set_intent(["Height","SportType=Boat","?"])
-    for v in vlist._collection:
-        filter_vspec = list(filter(lambda x: x.channel=="",v._inferred_intent))[0]
-        assert filter_vspec.value =="Boat"
+    vlist = VisList(["Height", "SportType=Ice", "?"], df)
+    vlist.set_intent(["Height", "SportType=Boat", "?"])
+    for v in vlist._collection:
+        filter_vspec = list(filter(lambda x: x.channel == "", v._inferred_intent))[0]
+        assert filter_vspec.value == "Boat"
+
+
 def test_custom_plot_setting():
     def change_color_make_transparent_add_title(chart):
-        chart = chart.configure_mark(color="green",opacity=0.2)
+        chart = chart.configure_mark(color="green", opacity=0.2)
         chart.title = "Test Title"
         return chart
+
     df = pd.read_csv("lux/data/car.csv")
     df.plot_config = change_color_make_transparent_add_title
     df._repr_html_()
-    config_mark_addition = 'chart = chart.configure_mark(color="green",opacity=0.2)'
-    title_addition ='chart.title = "Test Title"'
+    config_mark_addition = 'chart = chart.configure_mark(color="green", opacity=0.2)'
+    title_addition = 'chart.title = "Test Title"'
     exported_code_str = df.recommendation["Correlation"][0].to_Altair()
     assert config_mark_addition in exported_code_str
     assert title_addition in exported_code_str
+
 def test_remove():
     df = pd.read_csv("lux/data/car.csv")
-    vis = Vis([lux.Clause("Horsepower"),lux.Clause("Acceleration")],df)
-    vis.remove_column_from_spec("Horsepower",remove_first=False)
+    vis = Vis([lux.Clause("Horsepower"), lux.Clause("Acceleration")], df)
+    vis.remove_column_from_spec("Horsepower", remove_first=False)
     assert vis._inferred_intent[0].attribute == "Acceleration"
+
+
 def test_remove_identity():
     df = pd.read_csv("lux/data/car.csv")
-    vis = Vis(["Horsepower","Horsepower"],df)
+    vis = Vis(["Horsepower", "Horsepower"], df)
     vis.remove_column_from_spec("Horsepower")
-    assert (vis._inferred_intent == []),"Remove all instances of Horsepower"
+    assert vis._inferred_intent == [], "Remove all instances of Horsepower"
     df = pd.read_csv("lux/data/car.csv")
-    vis = Vis(["Horsepower","Horsepower"],df)
-    vis.remove_column_from_spec("Horsepower",remove_first=True)
-    assert (len(vis._inferred_intent)==1),"Remove only 1 instances of Horsepower"
-    assert (vis._inferred_intent[0].attribute=="Horsepower"),"Remove only 1 instances of Horsepower"
+    vis = Vis(["Horsepower", "Horsepower"], df)
+    vis.remove_column_from_spec("Horsepower", remove_first=True)
+    assert len(vis._inferred_intent) == 1, "Remove only 1 instances of Horsepower"
+    assert (
+        vis._inferred_intent[0].attribute == "Horsepower"
+    ), "Remove only 1 instances of Horsepower"
+
+
 def test_refresh_collection():
     df = pd.read_csv("lux/data/car.csv")
-    df["Year"] = pd.to_datetime(df["Year"], format='%Y')
-    df.set_intent([lux.Clause(attribute = "Acceleration"),lux.Clause(attribute = "Horsepower")])
+    df["Year"] = pd.to_datetime(df["Year"], format="%Y")
+    df.set_intent(
+        [lux.Clause(attribute="Acceleration"), lux.Clause(attribute="Horsepower")]
+    )
     df._repr_html_()
     enhanceCollection = df.recommendation["Enhance"]
-    enhanceCollection.refresh_source(df[df["Origin"]=="USA"])
+    enhanceCollection.refresh_source(df[df["Origin"] == "USA"])
+
 def test_vis_custom_aggregation_as_str():
     df = pd.read_csv("lux/data/college.csv")
     import numpy as np
-    vis = Vis(["HighestDegree",lux.Clause("AverageCost",aggregation="max")],df)
+
+    vis = Vis(["HighestDegree", lux.Clause("AverageCost", aggregation="max")], df)
     assert vis.get_attr_by_data_model("measure")[0].aggregation == "max"
-    assert vis.get_attr_by_data_model("measure")[0]._aggregation_name =='max'
-
+    assert vis.get_attr_by_data_model("measure")[0]._aggregation_name == "max"
+
+
 def test_vis_custom_aggregation_as_numpy_func():
     df = pd.read_csv("lux/data/college.csv")
     from lux.vis.Vis import Vis
     import numpy as np
-    vis = Vis(["HighestDegree",lux.Clause("AverageCost",aggregation=np.ptp)],df)
+
+    vis = Vis(["HighestDegree", lux.Clause("AverageCost", aggregation=np.ptp)], df)
     assert vis.get_attr_by_data_model("measure")[0].aggregation == np.ptp
-    assert vis.get_attr_by_data_model("measure")[0]._aggregation_name =='ptp'
+    assert vis.get_attr_by_data_model("measure")[0]._aggregation_name == "ptp"
+
+
 def test_vis_collection_via_list_of_vis():
-    url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true'
+    url = (
+        "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true"
+    )
     df = pd.read_csv(url)
-    df["Year"] = pd.to_datetime(df["Year"], format='%Y') # change pandas dtype for the column "Year" to datetype
+    df["Year"] = pd.to_datetime(
+        df["Year"], format="%Y"
+    )  # change pandas dtype for the column "Year" to datetype
     from lux.vis.VisList import VisList
     from lux.vis.Vis import Vis
+
     vcLst = []
-    for attribute in ['Sport','Year','Height','HostRegion','SportType']:
+    for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]:
         vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)])
         vcLst.append(vis)
-    vlist = VisList(vcLst,df)
+    vlist = VisList(vcLst, df)
     assert len(vlist) == 5
+
+
 def test_vis_to_Altair_basic_df():
     df = pd.read_csv("lux/data/car.csv")
-    vis = Vis(['Weight','Horsepower'],df)
+    vis = Vis(["Weight", "Horsepower"], df)
     code = vis.to_Altair()
-    assert "alt.Chart(df)" in code , "Unable to export to Altair"
+    assert "alt.Chart(df)" in code, "Unable to export to Altair"
+
+
 def test_vis_to_Altair_custom_named_df():
     df = pd.read_csv("lux/data/car.csv")
     some_weirdly_named_df = df.dropna()
-    vis = Vis(['Weight','Horsepower'],some_weirdly_named_df)
+    vis = Vis(["Weight", "Horsepower"], some_weirdly_named_df)
     code = vis.to_Altair()
-    assert "alt.Chart(some_weirdly_named_df)" in code , "Unable to export to Altair and detect custom df name"
+    assert (
+        "alt.Chart(some_weirdly_named_df)" in code
+    ), "Unable to export to Altair and detect custom df name"
+
+
 def test_vis_to_Altair_standalone():
     df = pd.read_csv("lux/data/car.csv")
-    vis = Vis(['Weight','Horsepower'],df)
+    vis = Vis(["Weight", "Horsepower"], df)
     code = vis.to_Altair(standalone=True)
-    assert "chart = alt.Chart(pd.DataFrame({'Weight': {0: 3504, 1: 3693, 2: 3436, 3: 3433, 4: 3449, 5: 43" in code or "alt.Chart(pd.DataFrame({'Horsepower': {0: 130, 1: 165, 2: 150, 3: 150, 4: 140," in code
+    assert (
+        "chart = alt.Chart(pd.DataFrame({'Weight': {0: 3504, 1: 3693, 2: 3436, 3: 3433, 4: 3449, 5: 43"
+        in code
+        or "alt.Chart(pd.DataFrame({'Horsepower': {0: 130, 1: 165, 2: 150, 3: 150, 4: 140,"
+        in code
+    )
+
+
 def test_vis_list_custom_title_override():
-    url = 'https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true'
+    url = (
+        "https://github.com/lux-org/lux-datasets/blob/master/data/olympic.csv?raw=true"
+    )
     df = pd.read_csv(url)
-    df["Year"] = pd.to_datetime(df["Year"], format='%Y')
-
+    df["Year"] = pd.to_datetime(df["Year"], format="%Y")
+
     vcLst = []
-    for attribute in ['Sport','Year','Height','HostRegion','SportType']:
-        vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)],title="overriding dummy title")
+    for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]:
+        vis = Vis(
+            [lux.Clause("Weight"), lux.Clause(attribute)],
+            title="overriding dummy title",
+        )
         vcLst.append(vis)
-    vlist = VisList(vcLst,df)
-    for v in vlist:
-        assert v.title=="overriding dummy title"
+    vlist = VisList(vcLst, df)
+    for v in vlist:
+        assert v.title == "overriding dummy title"
+
+
 def test_vis_set_intent():
     from lux.vis.Vis import Vis
+
     df = pd.read_csv("lux/data/car.csv")
-    vis = Vis(["Weight","Horsepower"],df)
+    vis = Vis(["Weight", "Horsepower"], df)
     vis._repr_html_()
     assert "Horsepower" in str(vis._code)
-    vis.intent = ["Weight","MilesPerGal"]
+    vis.intent = ["Weight", "MilesPerGal"]
     vis._repr_html_()
     assert "MilesPerGal" in str(vis._code)
+
+
 def test_vis_list_set_intent():
     from lux.vis.VisList import VisList
+
     df = pd.read_csv("lux/data/car.csv")
-    vislist = VisList(["Horsepower","?"],df)
+    vislist = VisList(["Horsepower", "?"], df)
     vislist._repr_html_()
-    for vis in vislist: assert vis.get_attr_by_attr_name("Horsepower")!=[]
-    vislist.intent = ["Weight","?"]
+    for vis in vislist:
+        assert vis.get_attr_by_attr_name("Horsepower") != []
+    vislist.intent = ["Weight", "?"]
     vislist._repr_html_()
-    for vis in vislist: assert vis.get_attr_by_attr_name("Weight")!=[]
+    for vis in vislist:
+        assert vis.get_attr_by_attr_name("Weight") != []
+
+
 def test_text_not_overridden():
     from lux.vis.Vis import Vis
+
     df = pd.read_csv("lux/data/college.csv")
     vis = Vis(["Region", "Geography"], df)
     vis._repr_html_()
     code = vis.to_Altair()
-    assert "color = \"#ff8e04\"" in code
+    assert 'color = "#ff8e04"' in code