diff --git a/.travis.yml b/.travis.yml index 5197d6ea..538316b6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,14 @@ install: - pip install jupyter-client==6.1.6 - pip install -r requirements.txt - pip install -r requirements-dev.txt +#command to generate postgres database +before_script: + - psql -c "ALTER USER postgres WITH PASSWORD 'lux';" -U postgres + - psql -c "ALTER USER postgres WITH SUPERUSER;" -U postgres + - psql -c "ALTER DATABASE postgres OWNER TO travis;" + - psql -c "DROP schema public cascade;" -U postgres + - psql -c "CREATE schema public;" -U postgres + - psql -c "CREATE DATABASE postgres;" -U postgres # command to run tests script: - python lux/data/upload_car_data.py diff --git a/doc/source/guide/FAQ.rst b/doc/source/guide/FAQ.rst index c99d29ed..5da2e0c3 100644 --- a/doc/source/guide/FAQ.rst +++ b/doc/source/guide/FAQ.rst @@ -35,6 +35,20 @@ How do I set the Lux widgets to show up on default? .. code-block:: python lux.config.default_display = "pandas" + +How do I change the plotting library used for visualization? +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + By default, we make use of `Altair `__ to generate `Vega-Lite `__ visualizations. We can modify the :code:`plotting_backend` config property to use `Matplotlib `__ as the plotting library instead: + + .. code-block:: python + + lux.config.plotting_backend = "matplotlib" + + To switch back to Vega-Lite: + + .. code-block:: python + + lux.config.plotting_backend = "vegalite" I want to change the opacity of my chart, add title, change chart font size, etc. How do I modify chart settings? """"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" diff --git a/doc/source/guide/style.rst b/doc/source/guide/style.rst index 58a0fa71..4f4712b2 100644 --- a/doc/source/guide/style.rst +++ b/doc/source/guide/style.rst @@ -25,7 +25,8 @@ To change the plot configuration in Altair, we need to specify a function that t Let's say that we want to change all the graphical marks of the charts to green and add a custom title. We can define this `change_color_add_title` function, which configures the chart's mark as green and adds a custom title to the chart. .. code-block:: python - + lux.config.plotting_backend = "altair" # or 'vegalite' + def change_color_add_title(chart): chart = chart.configure_mark(color="green") # change mark color to green chart.title = "Custom Title" # add title to chart @@ -43,6 +44,26 @@ We now see that the displayed visualizations adopt these new imported settings. :width: 700 :align: center +Similarly, we can change the plot configurations for Matplotlib charts as well. +The plot_config attribute for Matplotlib charts takes in both the figure and axis as parameters. +.. code-block:: python + + lux.config.plotting_backend = "matplotlib" # or 'matplotlib_code' + + def add_title(fig, ax): + ax.set_title("Test Title") + return fig, ax + +.. code-block:: python + + lux.config.plot_config = add_title + +We now see that the displayed visualizations adopt these new imported settings. + +.. image:: ../img/style-7.png + :width: 700 + :align: center + If we click on the visualization for `Displacement` v.s. `Weight` and export it. We see that the exported chart now contains code with these additional plot settings at the every end. .. code-block:: python diff --git a/doc/source/img/style-7.png b/doc/source/img/style-7.png new file mode 100644 index 00000000..da1cc902 Binary files /dev/null and b/doc/source/img/style-7.png differ diff --git a/doc/source/reference/config.rst b/doc/source/reference/config.rst index 14ac5e48..522f9f53 100644 --- a/doc/source/reference/config.rst +++ b/doc/source/reference/config.rst @@ -62,6 +62,45 @@ If you try to set the default_display to anything other than 'lux' or 'pandas,' :width: 700 :align: center +Change plotting backend for rendering visualizations in Lux +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We can set the :code:`plotting_backend` config to change the plotting library used for rendering the visualizations in Lux. +This is often useful not just for stylizing plot aesthetics, but also to change the code generated when `exporting a visualization `__. +For example, if you are more familiar with `matplotlib `__ , you may want to use a matplotlib plotting backend so that you can make use of the exported visualization code. In the following code, we set the plotting backend to 'matplotlib', and Lux will display the Matplotlib rendered charts. + +.. code-block:: python + + lux.config.plotting_backend = "matplotlib" + df + +.. image:: https://github.com/lux-org/lux-resources/blob/master/doc_img/vislib-1.png?raw=true + :width: 700 + :align: center + +We can set the vislib back to the default 'vegalite,' which uses Vega-Lite to render the displayed chart. + +.. code-block:: python + + lux.config.plotting_backend = "vegalite" + df + +.. image:: https://github.com/lux-org/lux-resources/blob/master/doc_img/display-1.png?raw=true + :width: 700 + :align: center + +Lux currently only support Vega-Lite and matplotlib, and we plan to add support for other plotting libraries in the future. If you try to set the :code:`plotting_backend` to anything other than 'matplotlib' or 'vegalite', a warning will be shown, and the display will default to the previous setting. + +.. code-block:: python + + lux.config.plotting_backend = "notvegalite" # Throw an warning + df + +.. image:: https://github.com/lux-org/lux-resources/blob/master/doc_img/vislib-2.png?raw=true + + :width: 700 + :align: center + Change the sampling parameters of Lux ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -91,20 +130,8 @@ We can disable this feature and revert back to using a scatter plot by running t lux.config.heatmap = False - -Default Renderer -~~~~~~~~~~~~~~~~~ - -Charts in Lux are rendered using `Altair `__. We are working on supporting plotting via `matplotlib `__ and other plotting libraries. - -To change the default renderer, run the following code block: - -.. code-block:: python - - lux.config.renderer = "matplotlib" - -Plot Configurations -~~~~~~~~~~~~~~~~~~~ +Changing the plot styling +~~~~~~~~~~~~~~~~~~~~~~~~~~ Altair supports plot configurations to be applied on top of the generated graphs. To set a default plot configuration, first write a function that can take in a `chart` and returns a `chart`. For example: @@ -129,6 +156,27 @@ The above results in the following changes: See `this page `__ for more details. +Matplotlib also supports plot configurations to be applied on top of the generated graphs. To set a default plot configuration, first write a function that can take in a `fig` and 'ax' and returns a `fig` and 'ax. For example: + +.. code-block:: python + + def add_title(fig, ax): + ax.set_title("Test Title") + return fig, ax + +.. code-block:: python + + lux.config.plot_config = add_title + +The above results in the following changes: + +.. image:: https://github.com/lux-org/lux-resources/blob/master/doc_img/style-7.png?raw=true + :width: 600 + :align: center + +See `this page `__ for more details. + + Modify Sorting and Ranking in Recommendations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/reference/gen/lux._config.config.Config.rst b/doc/source/reference/gen/lux._config.config.Config.rst index 48db70fb..4f0afd64 100644 --- a/doc/source/reference/gen/lux._config.config.Config.rst +++ b/doc/source/reference/gen/lux._config.config.Config.rst @@ -29,6 +29,7 @@ lux.\_config.config.Config ~Config.default_display ~Config.heatmap + ~Config.plotting_backend ~Config.sampling ~Config.sampling_cap ~Config.sampling_start diff --git a/doc/source/reference/gen/lux.vis.Vis.Vis.rst b/doc/source/reference/gen/lux.vis.Vis.Vis.rst index dc79967d..216bc658 100644 --- a/doc/source/reference/gen/lux.vis.Vis.Vis.rst +++ b/doc/source/reference/gen/lux.vis.Vis.Vis.rst @@ -26,6 +26,8 @@ lux.vis.Vis.Vis ~Vis.to_Altair ~Vis.to_VegaLite ~Vis.to_code + ~Vis.to_matplotlib + ~Vis.to_matplotlib_code diff --git a/doc/source/reference/lux.vislib.matplotlib.rst b/doc/source/reference/lux.vislib.matplotlib.rst new file mode 100644 index 00000000..797f054c --- /dev/null +++ b/doc/source/reference/lux.vislib.matplotlib.rst @@ -0,0 +1,70 @@ +lux.vislib.matplotlib package +============================= + +Submodules +---------- + +lux.vislib.matplotlib.BarChart module +------------------------------------- + +.. automodule:: lux.vislib.matplotlib.BarChart + :members: + :undoc-members: + :show-inheritance: + +lux.vislib.matplotlib.Heatmap module +------------------------------------ + +.. automodule:: lux.vislib.matplotlib.Heatmap + :members: + :undoc-members: + :show-inheritance: + +lux.vislib.matplotlib.Histogram module +-------------------------------------- + +.. automodule:: lux.vislib.matplotlib.Histogram + :members: + :undoc-members: + :show-inheritance: + +lux.vislib.matplotlib.LineChart module +-------------------------------------- + +.. automodule:: lux.vislib.matplotlib.LineChart + :members: + :undoc-members: + :show-inheritance: + +lux.vislib.matplotlib.MatplotlibChart module +-------------------------------------------- + +.. automodule:: lux.vislib.matplotlib.MatplotlibChart + :members: + :undoc-members: + :show-inheritance: + +lux.vislib.matplotlib.MatplotlibRenderer module +----------------------------------------------- + +.. automodule:: lux.vislib.matplotlib.MatplotlibRenderer + :members: + :undoc-members: + :show-inheritance: + +lux.vislib.matplotlib.ScatterChart module +----------------------------------------- + +.. automodule:: lux.vislib.matplotlib.ScatterChart + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: lux.vislib.matplotlib + :members: + :undoc-members: + :show-inheritance: diff --git a/lux/_config/config.py b/lux/_config/config.py index e7809edf..d16678ef 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -17,7 +17,6 @@ class Config: def __init__(self): self._default_display = "pandas" - self.renderer = "altair" self.plot_config = None self.SQLconnection = "" self.executor = None @@ -30,6 +29,7 @@ def __init__(self): self._sampling_cap = 30000 self._sampling_flag = True self._heatmap_flag = True + self._plotting_backend = "vegalite" self._topk = 15 self._sort = "descending" self._pandas_fallback = True @@ -261,6 +261,29 @@ def default_display(self, type: str) -> None: stacklevel=2, ) + @property + def plotting_backend(self): + return self._plotting_backend + + @plotting_backend.setter + def plotting_backend(self, type: str) -> None: + """ + Set the widget display to show Vegalite by default or Matplotlib by default + Parameters + ---------- + type : str + Default display type, can take either the string `vegalite` or `matplotlib` (regardless of capitalization) + """ + if type.lower() == "vegalite" or type.lower() == "altair": + self._plotting_backend = "vegalite" + elif type.lower() == "matplotlib": + self._plotting_backend = "matplotlib" + else: + warnings.warn( + "Unsupported plotting backend. Lux currently only support 'altair', 'vegalite', or 'matplotlib'", + stacklevel=2, + ) + def _get_action(self, pat: str, silent: bool = False): return lux.actions[pat] diff --git a/lux/core/frame.py b/lux/core/frame.py index d3a68454..d7ea28ec 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -710,7 +710,7 @@ def current_vis_to_JSON(vlist, input_current_vis=""): current_vis_spec = {} numVC = len(vlist) # number of visualizations in the vis list if numVC == 1: - current_vis_spec = vlist[0].to_code(prettyOutput=False) + current_vis_spec = vlist[0].to_code(language=lux.config.plotting_backend, prettyOutput=False) elif numVC > 1: pass return current_vis_spec @@ -725,7 +725,7 @@ def rec_to_JSON(recs): if len(rec["collection"]) > 0: rec["vspec"] = [] for vis in rec["collection"]: - chart = vis.to_code(prettyOutput=False) + chart = vis.to_code(language=lux.config.plotting_backend, prettyOutput=False) rec["vspec"].append(chart) rec_lst.append(rec) # delete since not JSON serializable diff --git a/lux/utils/utils.py b/lux/utils/utils.py index 3ae4503d..ac909f79 100644 --- a/lux/utils/utils.py +++ b/lux/utils/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pandas as pd +import matplotlib.pyplot as plt def convert_to_list(x): @@ -103,3 +104,13 @@ def like_nan(val): import math return math.isnan(val) + + +def matplotlib_setup(w, h): + plt.ioff() + fig, ax = plt.subplots(figsize=(w, h)) + ax.set_axisbelow(True) + ax.grid(color="#dddddd") + ax.spines["right"].set_color("#dddddd") + ax.spines["top"].set_color("#dddddd") + return fig, ax diff --git a/lux/vis/Vis.py b/lux/vis/Vis.py index ccc3527f..f70abd71 100644 --- a/lux/vis/Vis.py +++ b/lux/vis/Vis.py @@ -16,6 +16,7 @@ from lux.vis.Clause import Clause from lux.utils.utils import check_import_lux_widget import lux +import warnings class Vis: @@ -236,6 +237,36 @@ def to_Altair(self, standalone=False) -> str: self._code = renderer.create_vis(self, standalone) return self._code + def to_matplotlib(self) -> str: + """ + Generate minimal Matplotlib code to visualize the Vis + + Returns + ------- + str + String version of the Matplotlib code. Need to print out the string to apply formatting. + """ + from lux.vislib.matplotlib.MatplotlibRenderer import MatplotlibRenderer + + renderer = MatplotlibRenderer(output_type="matplotlib") + self._code = renderer.create_vis(self) + return self._code + + def to_matplotlib_code(self) -> str: + """ + Generate minimal Matplotlib code to visualize the Vis + + Returns + ------- + str + String version of the Matplotlib code. Need to print out the string to apply formatting. + """ + from lux.vislib.matplotlib.MatplotlibRenderer import MatplotlibRenderer + + renderer = MatplotlibRenderer(output_type="matplotlib_code") + self._code = renderer.create_vis(self) + return self._code + def to_VegaLite(self, prettyOutput=True) -> Union[dict, str]: """ Generate minimal Vega-Lite code to visualize the Vis @@ -276,6 +307,15 @@ def to_code(self, language="vegalite", **kwargs): return self.to_VegaLite(**kwargs) elif language == "altair": return self.to_Altair(**kwargs) + elif language == "matplotlib": + return self.to_matplotlib() + elif language == "matplotlib_code": + return self.to_matplotlib_code() + else: + warnings.warn( + "Unsupported plotting backend. Lux currently only support 'altair', 'vegalite', or 'matplotlib'", + stacklevel=2, + ) def refresh_source(self, ldf): # -> Vis: """ diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py index 63d33c5b..35617d75 100644 --- a/lux/vislib/altair/AltairRenderer.py +++ b/lux/vislib/altair/AltairRenderer.py @@ -85,7 +85,9 @@ def create_vis(self, vis, standalone=True): chart = None if chart: - if lux.config.plot_config: + if lux.config.plot_config and ( + lux.config.plotting_backend == "vegalite" or lux.config.plotting_backend == "altair" + ): chart.chart = lux.config.plot_config(chart.chart) if self.output_type == "VegaLite": chart_dict = chart.chart.to_dict() @@ -93,6 +95,7 @@ def create_vis(self, vis, standalone=True): # chart["data"] = { "values": vis.data.to_dict(orient='records') } # chart_dict["width"] = 160 # chart_dict["height"] = 150 + chart_dict["vislib"] = "vegalite" return chart_dict elif self.output_type == "Altair": import inspect diff --git a/lux/vislib/matplotlib/BarChart.py b/lux/vislib/matplotlib/BarChart.py new file mode 100644 index 00000000..90a6a93b --- /dev/null +++ b/lux/vislib/matplotlib/BarChart.py @@ -0,0 +1,151 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lux.vislib.matplotlib.MatplotlibChart import MatplotlibChart +from lux.utils.utils import get_agg_title +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from lux.utils.utils import matplotlib_setup +from matplotlib.cm import ScalarMappable +from lux.utils.date_utils import compute_date_granularity +from matplotlib.ticker import MaxNLocator + + +class BarChart(MatplotlibChart): + """ + BarChart is a subclass of MatplotlibChart that render as a bar charts. + All rendering properties for bar charts are set here. + + See Also + -------- + matplotlib.org + """ + + def __init__(self, dobj, fig, ax): + super().__init__(dobj, fig, ax) + + def __repr__(self): + return f"Bar Chart <{str(self.vis)}>" + + def initialize_chart(self): + self.tooltip = False + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + + x_attr_abv = x_attr.attribute + y_attr_abv = y_attr.attribute + + if len(x_attr.attribute) > 25: + x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] + if len(y_attr.attribute) > 25: + y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] + + if x_attr.data_model == "measure": + agg_title = get_agg_title(x_attr) + measure_attr = x_attr.attribute + bar_attr = y_attr.attribute + else: + agg_title = get_agg_title(y_attr) + measure_attr = y_attr.attribute + bar_attr = x_attr.attribute + + k = 10 + self._topkcode = "" + n_bars = len(self.data.iloc[:, 0].unique()) + if n_bars > k: # Truncating to only top k + remaining_bars = n_bars - k + self.data = self.data.nlargest(k, measure_attr) + self.ax.text( + 0.95, + 0.01, + f"+ {remaining_bars} more ...", + verticalalignment="bottom", + horizontalalignment="right", + transform=self.ax.transAxes, + fontsize=11, + fontweight="bold", + color="#ff8e04", + ) + + self._topkcode = f"""text = alt.Chart(visData).mark_text( + x=155, + y=142, + align="right", + color = "#ff8e04", + fontSize = 11, + text=f"+ {remaining_bars} more ..." + ) + chart = chart + text\n""" + + df = pd.DataFrame(self.data) + + bars = df[bar_attr].apply(lambda x: str(x)) + measurements = df[measure_attr] + + plot_code = "" + + color_attr = self.vis.get_attr_by_channel("color") + if len(color_attr) == 1: + self.fig, self.ax = matplotlib_setup(6, 4) + color_attr_name = color_attr[0].attribute + color_attr_type = color_attr[0].data_type + colors = df[color_attr_name].values + unique = list(set(colors)) + d_x = {} + d_y = {} + for i in unique: + d_x[i] = [] + d_y[i] = [] + for i in range(len(colors)): + d_x[colors[i]].append(bars[i]) + d_y[colors[i]].append(measurements[i]) + for i in range(len(unique)): + self.ax.barh(d_x[unique[i]], d_y[unique[i]], label=unique[i]) + plot_code += ( + f"ax.barh({d_x}[{unique}[{i}]], {d_y}[{unique}[{i}]], label={unique}[{i}])\n" + ) + self.ax.legend( + title=color_attr_name, bbox_to_anchor=(1.05, 1), loc="upper left", ncol=1, frameon=False + ) + plot_code += f"""ax.legend( + title='{color_attr_name}', + bbox_to_anchor=(1.05, 1), + loc='upper left', + ncol=1, + frameon=False,)\n""" + else: + self.ax.barh(bars, measurements, align="center") + plot_code += f"ax.barh(bars, measurements, align='center')\n" + + y_ticks_abbev = df[bar_attr].apply(lambda x: str(x)[:10] + "..." if len(str(x)) > 10 else str(x)) + self.ax.set_yticks(bars) + self.ax.set_yticklabels(y_ticks_abbev) + + self.ax.set_xlabel(x_attr_abv) + self.ax.set_ylabel(y_attr_abv) + plt.gca().invert_yaxis() + + self.code += "import matplotlib.pyplot as plt\n" + self.code += "import numpy as np\n" + self.code += "from math import nan\n" + + self.code += f"fig, ax = plt.subplots()\n" + self.code += f"bars = df['{bar_attr}']\n" + self.code += f"measurements = df['{measure_attr}']\n" + + self.code += plot_code + + self.code += f"ax.set_xlabel('{x_attr_abv}')\n" + self.code += f"ax.set_ylabel('{y_attr_abv}')\n" diff --git a/lux/vislib/matplotlib/Heatmap.py b/lux/vislib/matplotlib/Heatmap.py new file mode 100644 index 00000000..ad42633f --- /dev/null +++ b/lux/vislib/matplotlib/Heatmap.py @@ -0,0 +1,110 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lux.vislib.matplotlib.MatplotlibChart import MatplotlibChart +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from lux.utils.utils import matplotlib_setup + + +class Heatmap(MatplotlibChart): + """ + Heatmap is a subclass of MatplotlibChart that render as a heatmap. + All rendering properties for heatmap are set here. + + See Also + -------- + matplotlib.org + """ + + def __init__(self, vis, fig, ax): + super().__init__(vis, fig, ax) + + def __repr__(self): + return f"Heatmap <{str(self.vis)}>" + + def initialize_chart(self): + # return NotImplemented + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + + x_attr_abv = x_attr.attribute + y_attr_abv = y_attr.attribute + + if len(x_attr.attribute) > 25: + x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] + if len(y_attr.attribute) > 25: + y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] + + df = pd.DataFrame(self.data) + + plot_code = "" + color_attr = self.vis.get_attr_by_channel("color") + color_attr_name = "" + color_map = "Blues" + if len(color_attr) == 1: + self.fig, self.ax = matplotlib_setup(6, 4) + color_attr_name = color_attr[0].attribute + df = pd.pivot_table(data=df, index="xBinStart", values=color_attr_name, columns="yBinStart") + color_map = "viridis" + plot_code += f"""df = pd.pivot_table( + data=df, + index='xBinStart', + values='{color_attr_name}', + columns='yBinStart')\n""" + else: + df = pd.pivot_table(data=df, index="xBinStart", values="count", columns="yBinStart") + df = df.apply(lambda x: np.log(x), axis=1) + plot_code += f"""df = pd.pivot_table( + df, + index='xBinStart', + values='count', + columns='yBinStart')\n""" + plot_code += f"df = df.apply(lambda x: np.log(x), axis=1)\n" + df = df.values + + plt.imshow(df, cmap=color_map) + self.ax.set_aspect("auto") + plt.gca().invert_yaxis() + + colorbar_code = "" + if len(color_attr) == 1: + cbar = plt.colorbar(label=color_attr_name) + cbar.outline.set_linewidth(0) + colorbar_code += f"cbar = plt.colorbar(label='{color_attr_name}')\n" + colorbar_code += f"cbar.outline.set_linewidth(0)\n" + + self.ax.set_xlabel(x_attr_abv) + self.ax.set_ylabel(y_attr_abv) + self.ax.grid(False) + + self.code += "import matplotlib.pyplot as plt\n" + self.code += "import numpy as np\n" + self.code += "from math import nan\n" + self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n" + + self.code += plot_code + self.code += f"df = df.values\n" + + self.code += f"fig, ax = plt.subplots()\n" + self.code += f"plt.imshow(df, cmap='{color_map}')\n" + self.code += f"ax.set_aspect('auto')\n" + self.code += f"plt.gca().invert_yaxis()\n" + + self.code += colorbar_code + + self.code += f"ax.set_xlabel('{x_attr_abv}')\n" + self.code += f"ax.set_ylabel('{y_attr_abv}')\n" + self.code += f"ax.grid(False)\n" diff --git a/lux/vislib/matplotlib/Histogram.py b/lux/vislib/matplotlib/Histogram.py new file mode 100644 index 00000000..84a7c2ff --- /dev/null +++ b/lux/vislib/matplotlib/Histogram.py @@ -0,0 +1,83 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lux.vislib.matplotlib.MatplotlibChart import MatplotlibChart +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + + +class Histogram(MatplotlibChart): + """ + Histogram is a subclass of AltairChart that render as a histograms. + All rendering properties for histograms are set here. + + See Also + -------- + matplotlib.org + """ + + def __init__(self, vis, fig, ax): + super().__init__(vis, fig, ax) + + def __repr__(self): + return f"Histogram <{str(self.vis)}>" + + def initialize_chart(self): + self.tooltip = False + measure = self.vis.get_attr_by_data_model("measure", exclude_record=True)[0] + msr_attr = self.vis.get_attr_by_channel(measure.channel)[0] + + msr_attr_abv = msr_attr.attribute + + if len(msr_attr.attribute) > 17: + msr_attr_abv = msr_attr.attribute[:10] + "..." + msr_attr.attribute[-7:] + + x_min = self.vis.min_max[msr_attr.attribute][0] + x_max = self.vis.min_max[msr_attr.attribute][1] + + markbar = abs(x_max - x_min) / 12 + + df = pd.DataFrame(self.data) + + bars = df[msr_attr.attribute] + measurements = df["Number of Records"] + + self.ax.bar(bars, measurements, width=markbar) + self.ax.set_xlim(x_min, x_max) + + x_label = "" + y_label = "" + if measure.channel == "x": + x_label = f"{msr_attr.attribute} (binned)" + y_label = "Number of Records" + elif measure.channel == "y": + x_label = "Number of Records" + y_label = f"{msr_attr.attribute} (binned)" + + self.ax.set_xlabel(x_label) + self.ax.set_ylabel(y_label) + + self.code += "import matplotlib.pyplot as plt\n" + self.code += "import numpy as np\n" + self.code += "from math import nan\n" + self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n" + + self.code += f"fig, ax = plt.subplots()\n" + self.code += f"bars = df['{msr_attr.attribute}']\n" + self.code += f"measurements = df['Number of Records']\n" + + self.code += f"ax.bar(bars, measurements, width={markbar})\n" + self.code += f"ax.set_xlabel('{x_label}')\n" + self.code += f"ax.set_ylabel('{y_label}')\n" diff --git a/lux/vislib/matplotlib/LineChart.py b/lux/vislib/matplotlib/LineChart.py new file mode 100644 index 00000000..449d4f56 --- /dev/null +++ b/lux/vislib/matplotlib/LineChart.py @@ -0,0 +1,122 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lux.vislib.matplotlib.MatplotlibChart import MatplotlibChart +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from lux.utils.utils import get_agg_title +import altair as alt +from lux.utils.utils import matplotlib_setup + + +class LineChart(MatplotlibChart): + """ + LineChart is a subclass of MatplotlibChart that render as a line charts. + All rendering properties for line charts are set here. + + See Also + -------- + matplotlib.org + """ + + def __init__(self, dobj, fig, ax): + super().__init__(dobj, fig, ax) + + def __repr__(self): + return f"Line Chart <{str(self.vis)}>" + + def initialize_chart(self): + self.tooltip = False # tooltip looks weird for line chart + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + + x_attr_abv = x_attr.attribute + y_attr_abv = y_attr.attribute + + if len(x_attr.attribute) > 25: + x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] + if len(y_attr.attribute) > 25: + y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] + + self.data = self.data.dropna(subset=[x_attr.attribute, y_attr.attribute]) + + df = pd.DataFrame(self.data) + + x_pts = df[x_attr.attribute] + y_pts = df[y_attr.attribute] + + plot_code = "" + + color_attr = self.vis.get_attr_by_channel("color") + if len(color_attr) == 1: + self.fig, self.ax = matplotlib_setup(6, 4) + color_attr_name = color_attr[0].attribute + color_attr_type = color_attr[0].data_type + colors = df[color_attr_name].values + unique = list(set(colors)) + d_x = {} + d_y = {} + for i in unique: + d_x[i] = [] + d_y[i] = [] + for i in range(len(colors)): + d_x[colors[i]].append(x_pts[i]) + d_y[colors[i]].append(y_pts[i]) + for i in range(len(unique)): + self.ax.plot(d_x[unique[i]], d_y[unique[i]], label=unique[i]) + plot_code += f"""ax.plot( + {d_x}[{unique}[{i}]], + {d_y}[{unique}[{i}]], + label={unique}[{i}])\n""" + self.ax.legend( + title=color_attr_name, bbox_to_anchor=(1.05, 1), loc="upper left", ncol=1, frameon=False + ) + plot_code += f"""ax.legend( + title='{color_attr_name}', + bbox_to_anchor=(1.05, 1), + loc='upper left', + ncol=1, + frameon=False,)\n""" + else: + self.ax.plot(x_pts, y_pts) + plot_code += f"ax.plot(x_pts, y_pts)\n" + + x_label = "" + y_label = "" + if y_attr.data_model == "measure": + agg_title = get_agg_title(y_attr) + self.ax.set_xlabel(x_attr_abv) + self.ax.set_ylabel(agg_title) + x_label = x_attr_abv + y_label = agg_title + else: + agg_title = get_agg_title(x_attr) + self.ax.set_xlabel(agg_title) + self.ax.set_ylabel(y_attr_abv) + x_label = agg_title + y_label = y_attr_abv + + self.code += "import matplotlib.pyplot as plt\n" + self.code += "import numpy as np\n" + self.code += "from math import nan\n" + + self.code += f"fig, ax = plt.subplots()\n" + self.code += f"x_pts = df['{x_attr.attribute}']\n" + self.code += f"y_pts = df['{y_attr.attribute}']\n" + + self.code += plot_code + + self.code += f"ax.set_xlabel('{x_label}')\n" + self.code += f"ax.set_ylabel('{y_label}')\n" diff --git a/lux/vislib/matplotlib/MatplotlibChart.py b/lux/vislib/matplotlib/MatplotlibChart.py new file mode 100644 index 00000000..30bb8a81 --- /dev/null +++ b/lux/vislib/matplotlib/MatplotlibChart.py @@ -0,0 +1,75 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import matplotlib.pyplot as plt + + +class MatplotlibChart: + """ + MatplotlibChart is a representation of a chart. + Common utilities for charts that is independent of chart types should go here. + + See Also + -------- + https://matplotlib.org/ + + """ + + def __init__(self, vis, fig, ax): + self.vis = vis + self.data = vis.data + self.tooltip = True + self.fig = fig + self.ax = ax + # ----- START self.code modification ----- + self.code = "" + self.apply_default_config() + self.chart = self.initialize_chart() + self.add_title() + + # ----- END self.code modification ----- + + def __repr__(self): + return f"MatplotlibChart <{str(self.vis)}>" + + def add_tooltip(self): + return NotImplemented + + def apply_default_config(self): + self.code += """plt.rcParams.update( + { + "axes.titlesize": 20, + "axes.titleweight": "bold", + "axes.labelweight": "bold", + "axes.labelsize": 16, + "legend.fontsize": 14, + "legend.title_fontsize": 15, + # "font.family": "DejaVu Sans", + "xtick.labelsize": 13, + "ytick.labelsize": 13, + } + )\n""" + + def encode_color(self): + return NotImplemented + + def add_title(self): + chart_title = self.vis.title + if chart_title: + self.ax.set_title(chart_title) + self.code += f"ax.set_title('{chart_title}')\n" + + def initialize_chart(self): + return NotImplemented diff --git a/lux/vislib/matplotlib/MatplotlibRenderer.py b/lux/vislib/matplotlib/MatplotlibRenderer.py new file mode 100644 index 00000000..0e19a85b --- /dev/null +++ b/lux/vislib/matplotlib/MatplotlibRenderer.py @@ -0,0 +1,111 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import lux +import pandas as pd +from lux.executor.PandasExecutor import PandasExecutor +from lux.vislib.matplotlib.BarChart import BarChart +from lux.vislib.matplotlib.ScatterChart import ScatterChart +from lux.vislib.matplotlib.LineChart import LineChart +from lux.vislib.matplotlib.Histogram import Histogram +from lux.vislib.matplotlib.Heatmap import Heatmap +import matplotlib.pyplot as plt +from lux.utils.utils import matplotlib_setup + +import base64 +from io import BytesIO + + +class MatplotlibRenderer: + """ + Renderer for Charts based on Matplotlib (https://matplotlib.org/) + """ + + def __init__(self, output_type="matplotlib"): + self.output_type = output_type + + def __repr__(self): + return f"MatplotlibRenderer" + + def create_vis(self, vis, standalone=True): + """ + Input Vis object and return a visualization specification + + Parameters + ---------- + vis: lux.vis.Vis + Input Vis (with data) + standalone: bool + Flag to determine if outputted code uses user-defined variable names or can be run independently + Returns + ------- + chart : altair.Chart + Output Altair Chart Object + """ + # Lazy Evaluation for 2D Binning + if vis.mark == "scatter" and vis._postbin: + vis._mark = "heatmap" + + PandasExecutor.execute_2D_binning(vis) + # If a column has a Period dtype, or contains Period objects, convert it back to Datetime + if vis.data is not None: + for attr in list(vis.data.columns): + if pd.api.types.is_period_dtype(vis.data.dtypes[attr]) or isinstance( + vis.data[attr].iloc[0], pd.Period + ): + dateColumn = vis.data[attr] + vis.data[attr] = pd.PeriodIndex(dateColumn.values).to_timestamp() + if pd.api.types.is_interval_dtype(vis.data.dtypes[attr]) or isinstance( + vis.data[attr].iloc[0], pd.Interval + ): + vis.data[attr] = vis.data[attr].astype(str) + fig, ax = matplotlib_setup(4.5, 4) + if vis.mark == "histogram": + chart = Histogram(vis, fig, ax) + elif vis.mark == "bar": + chart = BarChart(vis, fig, ax) + elif vis.mark == "scatter": + chart = ScatterChart(vis, fig, ax) + elif vis.mark == "line": + chart = LineChart(vis, fig, ax) + elif vis.mark == "heatmap": + chart = Heatmap(vis, fig, ax) + else: + chart = None + return chart + if chart: + plt.tight_layout() + if lux.config.plot_config and ( + lux.config.plotting_backend == "matplotlib" + or lux.config.plotting_backend == "matplotlib_code" + ): + chart.fig, chart.ax = lux.config.plot_config(chart.fig, chart.ax) + plt.tight_layout() + tmpfile = BytesIO() + chart.fig.savefig(tmpfile, format="png") + chart.chart = base64.b64encode(tmpfile.getvalue()).decode("utf-8") + plt.clf() + plt.close("all") + if self.output_type == "matplotlib": + return {"config": chart.chart, "vislib": "matplotlib"} + if self.output_type == "matplotlib_code": + if lux.config.plot_config: + import inspect + + chart.code += "\n".join( + inspect.getsource(lux.config.plot_config).split("\n ")[1:-1] + ) + chart.code += "\nfig" + chart.code = chart.code.replace("\n\t\t", "\n") + return chart.code diff --git a/lux/vislib/matplotlib/ScatterChart.py b/lux/vislib/matplotlib/ScatterChart.py new file mode 100644 index 00000000..ae19f09a --- /dev/null +++ b/lux/vislib/matplotlib/ScatterChart.py @@ -0,0 +1,128 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lux.vislib.matplotlib.MatplotlibChart import MatplotlibChart +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from lux.utils.utils import matplotlib_setup +from matplotlib.cm import ScalarMappable + + +class ScatterChart(MatplotlibChart): + """ + ScatterChart is a subclass of MatplotlibChart that render as a scatter charts. + All rendering properties for scatter charts are set here. + + See Also + -------- + matplotlib.org + """ + + def __init__(self, vis, fig, ax): + super().__init__(vis, fig, ax) + + def __repr__(self): + return f"ScatterChart <{str(self.vis)}>" + + def initialize_chart(self): + x_attr = self.vis.get_attr_by_channel("x")[0] + y_attr = self.vis.get_attr_by_channel("y")[0] + + x_attr_abv = x_attr.attribute + y_attr_abv = y_attr.attribute + + if len(x_attr.attribute) > 25: + x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] + if len(y_attr.attribute) > 25: + y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] + + df = pd.DataFrame(self.data) + + x_pts = df[x_attr.attribute] + y_pts = df[y_attr.attribute] + + plot_code = "" + + color_attr = self.vis.get_attr_by_channel("color") + if len(color_attr) == 1: + self.fig, self.ax = matplotlib_setup(6, 5) + color_attr_name = color_attr[0].attribute + color_attr_type = color_attr[0].data_type + colors = df[color_attr_name].values + plot_code += f"colors = df['{color_attr_name}'].values\n" + unique = list(set(colors)) + vals = [unique.index(i) for i in colors] + if color_attr_type == "quantitative": + self.ax.scatter(x_pts, y_pts, c=vals, cmap="Blues", alpha=0.5) + plot_code += f"ax.scatter(x_pts, y_pts, c={vals}, cmap='Blues', alpha=0.5)\n" + my_cmap = plt.cm.get_cmap("Blues") + max_color = max(colors) + sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(0, max_color)) + sm.set_array([]) + + cbar = plt.colorbar(sm, label=color_attr_name) + cbar.outline.set_linewidth(0) + plot_code += f"my_cmap = plt.cm.get_cmap('Blues')\n" + plot_code += f"""sm = ScalarMappable( + cmap=my_cmap, + norm=plt.Normalize(0, {max_color}))\n""" + + plot_code += f"cbar = plt.colorbar(sm, label='{color_attr_name}')\n" + plot_code += f"cbar.outline.set_linewidth(0)\n" + else: + scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap="Set1") + plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap='Set1')\n" + + unique = [str(i) for i in unique] + leg = self.ax.legend( + handles=scatter.legend_elements(num=range(0, len(unique)))[0], + labels=unique, + title=color_attr_name, + markerscale=2.0, + bbox_to_anchor=(1.05, 1), + loc="upper left", + ncol=1, + frameon=False, + ) + scatter.set_alpha(0.5) + plot_code += f"""ax.legend( + handles=scatter.legend_elements(num=range(0, len({unique})))[0], + labels={unique}, + title='{color_attr_name}', + markerscale=2., + bbox_to_anchor=(1.05, 1), + loc='upper left', + ncol=1, + frameon=False,)\n""" + plot_code += "scatter.set_alpha(0.5)\n" + else: + self.ax.scatter(x_pts, y_pts, alpha=0.5) + plot_code += f"ax.scatter(x_pts, y_pts, alpha=0.5)\n" + + self.ax.set_xlabel(x_attr_abv) + self.ax.set_ylabel(y_attr_abv) + + self.code += "import matplotlib.pyplot as plt\n" + self.code += "import numpy as np\n" + self.code += "from math import nan\n" + self.code += "from matplotlib.cm import ScalarMappable\n" + + self.code += f"fig, ax = plt.subplots()\n" + self.code += f"x_pts = df['{x_attr.attribute}']\n" + self.code += f"y_pts = df['{y_attr.attribute}']\n" + + self.code += plot_code + self.code += f"ax.set_xlabel('{x_attr_abv}')\n" + self.code += f"ax.set_ylabel('{y_attr_abv}')\n" diff --git a/lux/vislib/matplotlib/__init__.py b/lux/vislib/matplotlib/__init__.py new file mode 100644 index 00000000..ebce4f75 --- /dev/null +++ b/lux/vislib/matplotlib/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lux.utils.utils import matplotlib_setup +import matplotlib.pyplot as plt + +plt.rcParams.update({"figure.max_open_warning": 0}) +plt.rcParams.update( + { + "axes.titlesize": 15, + "axes.titleweight": "bold", + "axes.labelweight": "bold", + "axes.labelsize": 13, + "legend.fontsize": 13, + "legend.title_fontsize": 13, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + } +) diff --git a/requirements.txt b/requirements.txt index 74e8349f..563cee5a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ altair>=4.0.0 numpy>=1.16.5 pandas>=1.1.0 scikit-learn>=0.22 +matplotlib>=3.0.0 # Install only to use SQLExecutor psycopg2>=2.8.5 psycopg2-binary>=2.8.5 diff --git a/tests/test_config.py b/tests/test_config.py index 1085fe9a..70f558e6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -187,7 +187,24 @@ def test_remove_default_actions(global_var): register_default_actions() +def test_matplotlib_set_default_plot_config(): + lux.config.plotting_backend = "matplotlib" + + def add_title(fig, ax): + ax.set_title("Test Title") + return fig, ax + + df = pd.read_csv("lux/data/car.csv") + lux.config.plot_config = add_title + df._repr_html_() + title_addition = 'ax.set_title("Test Title")' + exported_code_str = df.recommendation["Correlation"][0].to_Altair() + assert title_addition in exported_code_str + + def test_set_default_plot_config(): + lux.config.plotting_backend = "vegalite" + def change_color_make_transparent_add_title(chart): chart = chart.configure_mark(color="green", opacity=0.2) chart.title = "Test Title" diff --git a/tests/test_vis.py b/tests/test_vis.py index b79d1b45..f4395800 100644 --- a/tests/test_vis.py +++ b/tests/test_vis.py @@ -197,3 +197,284 @@ def test_text_not_overridden(): vis._repr_html_() code = vis.to_Altair() assert 'color = "#ff8e04"' in code + + +def test_bar_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Origin", "Acceleration"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_bar()" in vis_code + assert ( + "y = alt.Y('Origin', type= 'nominal', axis=alt.Axis(labelOverlap=True, title='Origin'))" + in vis_code + ) + assert ( + "x = alt.X('Acceleration', type= 'quantitative', title='Mean of Acceleration', axis=alt.Axis(title='Mean of Acceleration'))" + in vis_code + ) + + lux.config.plot_config = None + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Origin", "Acceleration"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.barh(bars, measurements, align='center')" in vis_code + assert "ax.set_xlabel('Acceleration')" in vis_code + assert "ax.set_ylabel('Origin')" in vis_code + + +def test_colored_bar_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Cylinders", "Acceleration", "Origin"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_bar()" in vis_code + assert ( + "y = alt.Y('Cylinders', type= 'nominal', axis=alt.Axis(labelOverlap=True, title='Cylinders'))" + in vis_code + ) + assert ( + "x = alt.X('Acceleration', type= 'quantitative', title='Mean of Acceleration', axis=alt.Axis(title='Mean of Acceleration')" + in vis_code + ) + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Cylinders", "Acceleration", "Origin"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.barh" in vis_code + assert "title='Origin'" in vis_code + assert "ax.set_xlabel('Acceleration')" in vis_code + assert "ax.set_ylabel('Cylinders')" in vis_code + + +def test_scatter_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Acceleration", "Weight"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(df).mark_circle()" in vis_code + assert ( + "x=alt.X('Acceleration',scale=alt.Scale(domain=(8.0, 24.8)),type='quantitative', axis=alt.Axis(title='Acceleration'))" + in vis_code + ) + assert ( + " y=alt.Y('Weight',scale=alt.Scale(domain=(1613, 5140)),type='quantitative', axis=alt.Axis(title='Weight'))" + in vis_code + ) + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Acceleration", "Weight"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.scatter(x_pts, y_pts, alpha=0.5)" in vis_code + assert "ax.set_xlabel('Acceleration')" in vis_code + assert "ax.set_ylabel('Weight')" in vis_code + + +def test_colored_scatter_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Origin", "Acceleration", "Weight"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(df).mark_circle()" in vis_code + assert ( + "x=alt.X('Acceleration',scale=alt.Scale(domain=(8.0, 24.8)),type='quantitative', axis=alt.Axis(title='Acceleration'))" + in vis_code + ) + assert ( + " y=alt.Y('Weight',scale=alt.Scale(domain=(1613, 5140)),type='quantitative', axis=alt.Axis(title='Weight'))" + in vis_code + ) + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Origin", "Acceleration", "Weight"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.scatter" in vis_code + assert "title='Origin'" in vis_code + assert "ax.set_xlabel('Acceleration')" in vis_code + assert "ax.set_ylabel('Weight')" in vis_code + + +def test_line_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Year", "Acceleration"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_line()" in vis_code + assert ( + "y = alt.Y('Acceleration', type= 'quantitative', title='Mean of Acceleration', axis=alt.Axis(title='Acceleration')" + in vis_code + ) + assert "x = alt.X('Year', type = 'temporal', axis=alt.Axis(title='Year'))" in vis_code + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Year", "Acceleration"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.plot(x_pts, y_pts)" in vis_code + assert "ax.set_xlabel('Year')" in vis_code + assert "ax.set_ylabel('Mean of Acceleration')" in vis_code + + +def test_colored_line_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Year", "Acceleration", "Origin"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_line()" in vis_code + assert ( + "y = alt.Y('Acceleration', type= 'quantitative', title='Mean of Acceleration', axis=alt.Axis(title='Acceleration')" + in vis_code + ) + assert "x = alt.X('Year', type = 'temporal', axis=alt.Axis(title='Year'))" in vis_code + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Year", "Acceleration", "Origin"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.plot" in vis_code + assert "title='Origin'" in vis_code + assert "ax.set_xlabel('Year')" in vis_code + assert "ax.set_ylabel('Mean of Acceleration')" in vis_code + + +def test_histogram_chart(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + vis = Vis(["Displacement"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_bar" in vis_code + assert ( + "alt.X('Displacement', title='Displacement (binned)',bin=alt.Bin(binned=True), type='quantitative', axis=alt.Axis(labelOverlap=True, title='Displacement (binned)'), scale=alt.Scale(domain=(68.0, 455.0)))" + in vis_code + ) + assert 'alt.Y("Number of Records", type="quantitative")' in vis_code + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["Displacement"], df) + vis_code = vis.to_matplotlib_code() + assert "ax.bar(bars, measurements, width=32.25)" in vis_code + assert "ax.set_xlabel('Displacement (binned)')" in vis_code + assert "ax.set_ylabel('Number of Records')" in vis_code + + +def test_heatmap_chart(global_var): + df = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/airbnb_nyc.csv") + lux.config.plotting_backend = "vegalite" + vis = Vis(["price", "longitude"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_rect()" in vis_code + assert ( + "x=alt.X('xBinStart', type='quantitative', axis=alt.Axis(title='price'), bin = alt.BinParams(binned=True))" + in vis_code + ) + assert "x2=alt.X2('xBinEnd')" in vis_code + assert ( + "y=alt.Y('yBinStart', type='quantitative', axis=alt.Axis(title='longitude'), bin = alt.BinParams(binned=True))" + in vis_code + ) + assert "y2=alt.Y2('yBinEnd')" in vis_code + assert 'scale=alt.Scale(type="log")' in vis_code + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["price", "longitude"], df) + vis_code = vis.to_matplotlib_code() + assert "plt.imshow(df, cmap='Blues')" in vis_code + assert "index='xBinStart'" in vis_code + assert "values='count'" in vis_code + assert "columns='yBinStart'" in vis_code + + +def test_colored_heatmap_chart(global_var): + df = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/airbnb_nyc.csv") + lux.config.plotting_backend = "vegalite" + vis = Vis(["price", "longitude", "availability_365"], df) + vis_code = vis.to_Altair() + assert "alt.Chart(visData).mark_rect()" in vis_code + assert ( + "x=alt.X('xBinStart', type='quantitative', axis=alt.Axis(title='price'), bin = alt.BinParams(binned=True))" + in vis_code + ) + assert "x2=alt.X2('xBinEnd')" in vis_code + assert ( + "y=alt.Y('yBinStart', type='quantitative', axis=alt.Axis(title='longitude'), bin = alt.BinParams(binned=True))" + in vis_code + ) + assert "y2=alt.Y2('yBinEnd')" in vis_code + assert 'scale=alt.Scale(type="log")' in vis_code + assert "chart.encode(color=alt.Color('availability_365',type='quantitative'))" in vis_code + + lux.config.plotting_backend = "matplotlib" + vis = Vis(["price", "longitude", "availability_365"], df) + vis_code = vis.to_matplotlib_code() + assert "plt.imshow(df, cmap='viridis')" in vis_code + assert "index='xBinStart'" in vis_code + assert "values='availability_365'" in vis_code + assert "columns='yBinStart'" in vis_code + assert "plt.colorbar(label='availability_365')" in vis_code + + +def test_vegalite_default_actions_registered(global_var): + df = pytest.car_df + lux.config.plotting_backend = "vegalite" + df._repr_html_() + # Histogram Chart + assert "Distribution" in df.recommendation + assert len(df.recommendation["Distribution"]) > 0 + + # Occurrence Chart + assert "Occurrence" in df.recommendation + assert len(df.recommendation["Occurrence"]) > 0 + + # Line Chart + assert "Temporal" in df.recommendation + assert len(df.recommendation["Temporal"]) > 0 + + # Scatter Chart + assert "Correlation" in df.recommendation + assert len(df.recommendation["Correlation"]) > 0 + + +def test_matplotlib_default_actions_registered(global_var): + df = pytest.car_df + lux.config.plotting_backend = "matplotlib" + df._repr_html_() + # Histogram Chart + assert "Distribution" in df.recommendation + assert len(df.recommendation["Distribution"]) > 0 + + # Occurrence Chart + assert "Occurrence" in df.recommendation + assert len(df.recommendation["Occurrence"]) > 0 + + # Line Chart + assert "Temporal" in df.recommendation + assert len(df.recommendation["Temporal"]) > 0 + + # Scatter Chart + assert "Correlation" in df.recommendation + assert len(df.recommendation["Correlation"]) > 0 + + +def test_vegalite_heatmap_flag_config(): + df = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/airbnb_nyc.csv") + lux.config.plotting_backend = "vegalite" + df._repr_html_() + # Heatmap Chart + assert df.recommendation["Correlation"][0]._postbin + lux.config.heatmap = False + df = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/airbnb_nyc.csv") + df = df.copy() + assert not df.recommendation["Correlation"][0]._postbin + lux.config.heatmap = True + + +def test_matplotlib_heatmap_flag_config(): + df = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/airbnb_nyc.csv") + lux.config.plotting_backend = "matplotlib" + df._repr_html_() + # Heatmap Chart + assert df.recommendation["Correlation"][0]._postbin + lux.config.heatmap = False + df = pd.read_csv("https://raw.githubusercontent.com/lux-org/lux-datasets/master/data/airbnb_nyc.csv") + df = df.copy() + assert not df.recommendation["Correlation"][0]._postbin + lux.config.heatmap = True + lux.config.plotting_backend = "vegalite"