From 5fbb5c46d17d50c5f4d1d9ff56ebe8af95ff4423 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Fri, 15 May 2020 18:45:57 -0400 Subject: [PATCH] Refactor variable processing (#2071) * Add initial common function for processing long-form inputs * Attempt to use new core variable processing for longform relational plots * Create and ignore a directory for notes * Move relational plots to use common variable processing * Refactor establish_variables method into core * Allow relational plots to use wide lists of lists * Change base class to _VectorPlotter * Add initial attempt at generalized wide data processing * Modify tests for new intermeidate wide-form data representation * Remove relational-specific wide data processing * Fold relplot tests under TestRelationalPlotter * Move tests for core variable processing * Revert test reorganization; PEP8 and clean up names * Add tests for wide dict inputs * Pandas compat * Modernize numpy random usage in test fixtures * Fix docstring and comments * Use containment checks rather than KeyError handling * Improve test coverage for long data and messy wide data * Test variables from dataframe index * Flesh out wide data docstring * Return variables dict along with plot_data df * First attempt at generalizing relplot inputs * Refactor and parametrize flat variables tests * Test at base class level * Test relplot from wide data and long vectors * Fix test --- .gitignore | 1 + seaborn/conftest.py | 168 +++- seaborn/core.py | 230 +++++ seaborn/relational.py | 283 ++---- seaborn/tests/test_core.py | 35 + seaborn/tests/test_relational.py | 1610 ++++++++++++++++++------------ 6 files changed, 1497 insertions(+), 830 deletions(-) create mode 100644 seaborn/core.py create mode 100644 seaborn/tests/test_core.py diff --git a/.gitignore b/.gitignore index 24d06cba5d..41fd2281c8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ cover/ .idea/ .vscode/ .pytest_cache/ +notes/ diff --git a/seaborn/conftest.py b/seaborn/conftest.py index 8a0425d678..764f530841 100644 --- a/seaborn/conftest.py +++ b/seaborn/conftest.py @@ -1,5 +1,7 @@ import numpy as np +import pandas as pd import matplotlib.pyplot as plt + import pytest @@ -11,4 +13,168 @@ def close_figs(): @pytest.fixture(autouse=True) def random_seed(): - np.random.seed(47) + seed = sum(map(ord, "seaborn random global")) + np.random.seed(seed) + + +@pytest.fixture() +def rng(): + seed = sum(map(ord, "seaborn random object")) + return np.random.RandomState(seed) + + +@pytest.fixture +def wide_df(rng): + + columns = list("abc") + index = pd.Int64Index(np.arange(10, 50, 2), name="wide_index") + values = rng.normal(size=(len(index), len(columns))) + return pd.DataFrame(values, index=index, columns=columns) + + +@pytest.fixture +def wide_array(wide_df): + + # Requires panads >= 0.24 + # return wide_df.to_numpy() + return np.asarray(wide_df) + + +@pytest.fixture +def flat_series(rng): + + index = pd.Int64Index(np.arange(10, 30), name="t") + return pd.Series(rng.normal(size=20), index, name="s") + + +@pytest.fixture +def flat_array(flat_series): + + # Requires panads >= 0.24 + # return flat_series.to_numpy() + return np.asarray(flat_series) + + +@pytest.fixture +def flat_list(flat_series): + + # Requires panads >= 0.24 + # return flat_series.to_list() + return flat_series.tolist() + + +@pytest.fixture(params=["series", "array", "list"]) +def flat_data(rng, request): + + index = pd.Int64Index(np.arange(10, 30), name="t") + series = pd.Series(rng.normal(size=20), index, name="s") + if request.param == "series": + 
data = series + elif request.param == "array": + try: + data = series.to_numpy() # Requires pandas >= 0.24 + except AttributeError: + data = np.asarray(series) + elif request.param == "list": + try: + data = series.to_list() # Requires pandas >= 0.24 + except AttributeError: + data = series.tolist() + return data + + +@pytest.fixture +def wide_list_of_series(rng): + + return [pd.Series(rng.normal(size=20), np.arange(20), name="a"), + pd.Series(rng.normal(size=10), np.arange(5, 15), name="b")] + + +@pytest.fixture +def wide_list_of_arrays(wide_list_of_series): + + # Requires pandas >= 0.24 + # return [s.to_numpy() for s in wide_list_of_series] + return [np.asarray(s) for s in wide_list_of_series] + + +@pytest.fixture +def wide_list_of_lists(wide_list_of_series): + + # Requires pandas >= 0.24 + # return [s.to_list() for s in wide_list_of_series] + return [s.tolist() for s in wide_list_of_series] + + +@pytest.fixture +def wide_dict_of_series(wide_list_of_series): + + return {s.name: s for s in wide_list_of_series} + + +@pytest.fixture +def wide_dict_of_arrays(wide_list_of_series): + + # Requires pandas >= 0.24 + # return {s.name: s.to_numpy() for s in wide_list_of_series} + return {s.name: np.asarray(s) for s in wide_list_of_series} + + +@pytest.fixture +def wide_dict_of_lists(wide_list_of_series): + + # Requires pandas >= 0.24 + # return {s.name: s.to_list() for s in wide_list_of_series} + return {s.name: s.tolist() for s in wide_list_of_series} + + +@pytest.fixture +def long_df(rng): + + n = 100 + df = pd.DataFrame(dict( + x=rng.uniform(0, 20, n).round().astype("int"), + y=rng.normal(size=n), + a=rng.choice(list("abc"), n), + b=rng.choice(list("mnop"), n), + c=rng.choice([0, 1], n), + t=np.repeat(np.datetime64('2005-02-25'), n), + s=rng.choice([2, 4, 8], n), + f=rng.choice([0.2, 0.3], n), + )) + df["s_cat"] = df["s"].astype("category") + return df + + +@pytest.fixture +def long_dict(long_df): + + return long_df.to_dict() + + +@pytest.fixture +def repeated_df(rng): + + n = 100 + return pd.DataFrame(dict( + x=np.tile(np.arange(n // 2), 2), + y=rng.normal(size=n), + a=rng.choice(list("abc"), n), + u=np.repeat(np.arange(2), n // 2), + )) + + +@pytest.fixture +def missing_df(rng, long_df): + + df = long_df.copy() + for col in df: + idx = rng.permutation(df.index)[:10] + df.loc[idx, col] = np.nan + return df + + +@pytest.fixture +def null_series(): + + return pd.Series(index=np.arange(20), dtype='float64') diff --git a/seaborn/core.py b/seaborn/core.py new file mode 100644 index 0000000000..fb432f7507 --- /dev/null +++ b/seaborn/core.py @@ -0,0 +1,230 @@ +from collections.abc import Iterable, Sequence, Mapping +import numpy as np +import pandas as pd + + +class _VectorPlotter: + """Base class for objects underlying *plot functions.""" + + semantics = ["x", "y"] + + def establish_variables(self, data=None, **kwargs): + """Define plot variables.""" + x = kwargs.get("x", None) + y = kwargs.get("y", None) + + if x is None and y is None: + self.input_format = "wide" + plot_data, variables = self.establish_variables_wideform( + data, **kwargs + ) + else: + self.input_format = "long" + plot_data, variables = self.establish_variables_longform( + data, **kwargs + ) + + self.plot_data = plot_data + self.variables = variables + + return plot_data, variables + + def establish_variables_wideform(self, data=None, **kwargs): + """Define plot variables given wide-form data. 
+ + Parameters + ---------- + data : flat vector or collection of vectors + Data can be a vector or mapping that is coerceable to a Series + or a sequence- or mapping-based collection of such vectors, or a + rectangular numpy array, or a Pandas DataFrame. + kwargs : variable -> data mappings + Behavior with keyword arguments is currently undefined. + + Returns + ------- + plot_data : :class:`pandas.DataFrame` + Long-form data object mapping seaborn variables (x, y, hue, ...) + to data vectors. + variables : dict + Keys are defined seaborn variables; values are names inferred from + the inputs (or None when no name can be determined). + + """ + # TODO raise here if any kwarg values are not None, + # # if we decide for "structure-only" wide API + + # First, determine if the data object actually has any data in it + empty = not len(data) + + # Then, determine if we have "flat" data (a single vector) + # TODO extract this into a separate function? + if isinstance(data, dict): + values = data.values() + else: + values = np.atleast_1d(data) + flat = not any( + isinstance(v, Iterable) and not isinstance(v, (str, bytes)) + for v in values + ) + + if empty: + + # Make an object with the structure of plot_data, but empty + plot_data = pd.DataFrame(columns=self.semantics) + variables = {} + + elif flat: + + # Coerce the data into a pandas Series such that the values + # become the y variable and the index becomes the x variable + # No other semantics are defined. + # (Could be accomplished with a more general to_series() interface) + flat_data = pd.Series(data, name="y").copy() + flat_data.index.name = "x" + plot_data = flat_data.reset_index().reindex(columns=self.semantics) + + orig_index = getattr(data, "index", None) + variables = { + "x": getattr(orig_index, "name", None), + "y": getattr(data, "name", None) + } + + else: + + # Otherwise assume we have some collection of vectors. + + # Handle Python sequences such that entries end up in the columns, + # not in the rows, of the intermediate wide DataFrame. + # One way to accomplish this is to convert to a dict of Series. + if isinstance(data, Sequence): + data_dict = {} + for i, var in enumerate(data): + key = getattr(var, "name", i) + # TODO is there a safer/more generic way to ensure Series? + # sort of like np.asarray, but for pandas? + data_dict[key] = pd.Series(var) + + data = data_dict + + # Pandas requires that dict values either be Series objects + # or all have the same length, but we want to allow "ragged" inputs + if isinstance(data, Mapping): + data = {key: pd.Series(val) for key, val in data.items()} + + # Otherwise, delegate to the pandas DataFrame constructor + # This is where we'd prefer to use a general interface that says + # "give me this data as a pandas DataFrame", so we can accept + # DataFrame objects from other libraries + wide_data = pd.DataFrame(data, copy=True) + + # At this point we should reduce the dataframe to numeric cols + # TODO do we want any control over this? 
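+            # Illustrative sketch with a hypothetical ragged input (an
+            # assumed example, not taken from the surrounding code): given
+            # {"a": [1, 2, 3], "b": [4, 5]}, the Series coercion above aligns
+            # the values on a shared index, and the numeric filtering and
+            # melt below produce long-form rows roughly like
+            #
+            #   index  columns  values
+            #   0      a        1.0
+            #   1      a        2.0
+            #   2      a        3.0
+            #   0      b        4.0
+            #   1      b        5.0
+            #   2      b        NaN
+            #
+            # which ``wide_structure`` then maps onto the x/y/hue/style columns.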
+ wide_data = wide_data.select_dtypes("number") + + # Now melt the data to long form + melt_kws = {"var_name": "columns", "value_name": "values"} + if "index" in self.wide_structure.values(): + melt_kws["id_vars"] = "index" + wide_data["index"] = wide_data.index.to_series() + plot_data = wide_data.melt(**melt_kws) + + # Assign names corresponding to plot semantics + for var, attr in self.wide_structure.items(): + plot_data[var] = plot_data[attr] + plot_data = plot_data.reindex(columns=self.semantics) + + # Define the variable names + variables = {} + for var, attr in self.wide_structure.items(): + obj = getattr(wide_data, attr) + variables[var] = getattr(obj, "name", None) + + return plot_data, variables + + def establish_variables_longform(self, data=None, **kwargs): + """Define plot variables given long-form data and/or vector inputs. + + Parameters + ---------- + data : dict-like collection of vectors + Input data where variable names map to vector values. + kwargs : variable -> data mappings + Keys are seaborn variables (x, y, hue, ...) and values are vectors + in any format that can construct a :class:`pandas.DataFrame` or + names of columns or index levels in ``data``. + + Returns + ------- + plot_data : :class:`pandas.DataFrame` + Long-form data object mapping seaborn variables (x, y, hue, ...) + to data vectors. + variables : dict + Keys are defined seaborn variables; values are names inferred from + the inputs (or None when no name can be determined). + + Raises + ------ + ValueError + When variables are strings that don't appear in ``data``. + + """ + plot_data = {} + variables = {} + + # Data is optional; all variables can be defined as vectors + if data is None: + data = {} + + # TODO should we try a data.to_dict() or similar here to more + # generally accept objects with that interface? + # Note that dict(df) also works for pandas, and gives us what we + # want, whereas DataFrame.to_dict() gives a nested dict instead of + # a dict of series. + + # Variables can also be extraced from the index attribute + # TODO is this the most general way to enable it? + # There is no index.to_dict on multiindex, unfortunately + try: + index = data.index.to_frame() + except AttributeError: + index = {} + + # The caller will determine the order of variables in plot_data + for key, val in kwargs.items(): + + if isinstance(val, (str, bytes)): + # String inputs trigger __getitem__ + if val in data: + # First try to get an entry in the data object + plot_data[key] = data[val] + variables[key] = val + elif val in index: + # Failing that, try to get an entry in the index object + plot_data[key] = index[val] + variables[key] = val + else: + # We don't know what this name means + err = f"Could not interpret input '{val}'" + raise ValueError(err) + + else: + + # Otherwise, assume the value is itself a vector of data + # TODO check for 1D here or let pd.DataFrame raise? + plot_data[key] = val + # Try to infer the name of the variable + variables[key] = getattr(val, "name", None) + + # Construct a tidy plot DataFrame. 
This will convert a number of + # types automatically, aligning on index in case of pandas objects + plot_data = pd.DataFrame(plot_data, columns=self.semantics) + + # Reduce the variables dictionary to fields with valid data + variables = { + var: name + for var, name in variables.items() + if plot_data[var].notnull().any() + } + + return plot_data, variables diff --git a/seaborn/relational.py b/seaborn/relational.py index 7530785c7c..09f68137ba 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -7,6 +7,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt +from .core import _VectorPlotter from . import utils from .utils import (categorical_order, get_color_cycle, ci_to_errsize, remove_na, locator_to_legend_entries) @@ -20,170 +21,26 @@ __all__ = ["relplot", "scatterplot", "lineplot"] -class _RelationalPlotter(object): +class _RelationalPlotter(_VectorPlotter): - default_markers = ["o", "X", "s", "P", "D", "^", "v", "p"] - default_dashes = ["", (4, 1.5), (1, 1), - (3, 1, 1.5, 1), (5, 1, 1, 1), - (5, 1, 2, 1, 2, 1)] - - def establish_variables(self, x=None, y=None, - hue=None, size=None, style=None, - units=None, data=None): - """Parse the inputs to define data for plotting.""" - # Initialize label variables - x_label = y_label = hue_label = size_label = style_label = None - - # Option 1: - # We have a wide-form datast - # -------------------------- - - if x is None and y is None: - - self.input_format = "wide" - - # Option 1a: - # The input data is a Pandas DataFrame - # ------------------------------------ - # We will assign the index to x, the values to y, - # and the columns names to both hue and style - - # TODO accept a dict and try to coerce to a dataframe? - - if isinstance(data, pd.DataFrame): - - # Enforce numeric values - try: - data.astype(np.float) - except ValueError: - err = "A wide-form input must have only numeric values." - raise ValueError(err) - - plot_data = data.copy() - plot_data.loc[:, "x"] = data.index - plot_data = pd.melt(plot_data, "x", - var_name="hue", value_name="y") - plot_data["style"] = plot_data["hue"] - - x_label = getattr(data.index, "name", None) - hue_label = style_label = getattr(plot_data.columns, - "name", None) - - # Option 1b: - # The input data is an array or list - # ---------------------------------- - - else: - - if not len(data): - - plot_data = pd.DataFrame(columns=["x", "y"]) - - elif np.isscalar(np.asarray(data)[0]): - - # The input data is a flat list(like): - # We assign a numeric index for x and use the values for y - - x = getattr(data, "index", np.arange(len(data))) - plot_data = pd.DataFrame(dict(x=x, y=data)) - - elif hasattr(data, "shape"): + semantics = _VectorPlotter.semantics + ["hue", "size", "style", "units"] - # The input data is an array(like): - # We either use the index or assign a numeric index to x, - # the values to y, and id keys to both hue and style + wide_structure = { + "x": "index", "y": "values", "hue": "columns", "style": "columns", + } - plot_data = pd.DataFrame(data) - plot_data.loc[:, "x"] = plot_data.index - plot_data = pd.melt(plot_data, "x", - var_name="hue", - value_name="y") - plot_data["style"] = plot_data["hue"] + # TODO where best to define default parameters? + sort = True - else: - - # The input data is a nested list: We will either use the - # index or assign a numeric index for x, use the values - # for y, and use numeric hue/style identifiers. 
- - plot_data = [] - for i, data_i in enumerate(data): - x = getattr(data_i, "index", np.arange(len(data_i))) - n = getattr(data_i, "name", i) - data_i = dict(x=x, y=data_i, hue=n, style=n, size=None) - plot_data.append(pd.DataFrame(data_i)) - plot_data = pd.concat(plot_data) - - # Option 2: - # We have long-form data - # ---------------------- - - elif x is not None and y is not None: - - self.input_format = "long" - - # Use variables as from the dataframe if specified - if data is not None: - x = data.get(x, x) - y = data.get(y, y) - hue = data.get(hue, hue) - size = data.get(size, size) - style = data.get(style, style) - units = data.get(units, units) - - # Validate the inputs - for var in [x, y, hue, size, style, units]: - if isinstance(var, str): - err = "Could not interpret input '{}'".format(var) - raise ValueError(err) - - # Extract variable names - x_label = getattr(x, "name", None) - y_label = getattr(y, "name", None) - hue_label = getattr(hue, "name", None) - size_label = getattr(size, "name", None) - style_label = getattr(style, "name", None) - - # Reassemble into a DataFrame - plot_data = dict( - x=x, y=y, - hue=hue, style=style, size=size, - units=units - ) - plot_data = pd.DataFrame(plot_data) - - # Option 3: - # Only one variable argument - # -------------------------- - - else: - err = ("Either both or neither of `x` and `y` must be specified " - "(but try passing to `data`, which is more flexible).") - raise ValueError(err) - - # ---- Post-processing - - # Assign default values for missing attribute variables - for attr in ["hue", "style", "size", "units"]: - if attr not in plot_data: - plot_data[attr] = None - - # Determine which semantics have (some) data - plot_valid = plot_data.notnull().any() - semantics = ["x", "y"] + [ - name for name in ["hue", "size", "style"] - if plot_valid[name] - ] - - self.x_label = x_label - self.y_label = y_label - self.hue_label = hue_label - self.size_label = size_label - self.style_label = style_label - self.plot_data = plot_data - self.semantics = semantics + # Defaults for size semantic + # TODO this should match style of other defaults + _default_size_range = 0, 1 - return plot_data + # Defaults for style semantic + default_markers = ["o", "X", "s", "P", "D", "^", "v", "p"] + default_dashes = ["", (4, 1.5), (1, 1), + (3, 1, 1.5, 1), (5, 1, 1, 1), + (5, 1, 2, 1, 2, 1)] def categorical_to_palette(self, data, order, palette): """Determine colors when the hue variable is qualitative.""" @@ -331,12 +188,12 @@ def subset_data(self): if self.sort: subset_data = subset_data.sort_values(["units", "x", "y"]) - if self.units is None: + if "units" not in self.variables: subset_data = subset_data.drop("units", axis=1) yield (hue, size, style), subset_data - def parse_hue(self, data, palette, order, norm): + def parse_hue(self, data, palette=None, order=None, norm=None): """Determine what colors to use given data characteristics.""" if self._empty_data(data): @@ -395,7 +252,7 @@ def parse_hue(self, data, palette, order, norm): # Update data as it may have changed dtype self.plot_data["hue"] = data - def parse_size(self, data, sizes, order, norm): + def parse_size(self, data, sizes=None, order=None, norm=None): """Determine the linewidths given data characteristics.""" # TODO could break out two options like parse_hue does for clarity @@ -496,7 +353,7 @@ def parse_size(self, data, sizes, order, norm): # Update data as it may have changed dtype self.plot_data["size"] = data - def parse_style(self, data, markers, dashes, order): + def 
parse_style(self, data, markers=None, dashes=None, order=None): """Determine the markers and line dashes.""" if self._empty_data(data): @@ -568,12 +425,12 @@ def _semantic_type(self, data): def label_axes(self, ax): """Set x and y labels with visibility that matches the ticklabels.""" - if self.x_label is not None: + if "x" in self.variables and self.variables["x"] is not None: x_visible = any(t.get_visible() for t in ax.get_xticklabels()) - ax.set_xlabel(self.x_label, visible=x_visible) - if self.y_label is not None: + ax.set_xlabel(self.variables["x"], visible=x_visible) + if "y" in self.variables and self.variables["y"] is not None: y_visible = any(t.get_visible() for t in ax.get_yticklabels()) - ax.set_ylabel(self.y_label, visible=y_visible) + ax.set_ylabel(self.variables["y"], visible=y_visible) def add_legend_data(self, ax): """Add labeled artists to represent the different plot semantics.""" @@ -610,14 +467,15 @@ def update(var_name, val_name, **kws): hue_levels = hue_formatted_levels = self.hue_levels # Add the hue semantic subtitle - if self.hue_label is not None: - update((self.hue_label, "title"), self.hue_label, **title_kws) + if "hue" in self.variables and self.variables["hue"] is not None: + update((self.variables["hue"], "title"), + self.variables["hue"], **title_kws) # Add the hue semantic labels for level, formatted_level in zip(hue_levels, hue_formatted_levels): if level is not None: color = self.color_lookup(level) - update(self.hue_label, formatted_level, color=color) + update(self.variables["hue"], formatted_level, color=color) # -- Add a legend for size semantics @@ -632,26 +490,28 @@ def update(var_name, val_name, **kws): size_levels = size_formatted_levels = self.size_levels # Add the size semantic subtitle - if self.size_label is not None: - update((self.size_label, "title"), self.size_label, **title_kws) + if "size" in self.variables and self.variables["size"] is not None: + update((self.variables["size"], "title"), + self.variables["size"], **title_kws) # Add the size semantic labels for level, formatted_level in zip(size_levels, size_formatted_levels): if level is not None: size = self.size_lookup(level) - update( - self.size_label, formatted_level, linewidth=size, s=size) + update(self.variables["size"], + formatted_level, linewidth=size, s=size) # -- Add a legend for style semantics # Add the style semantic title - if self.style_label is not None: - update((self.style_label, "title"), self.style_label, **title_kws) + if "style" in self.variables and self.variables["style"] is not None: + update((self.variables["style"], "title"), + self.variables["style"], **title_kws) # Add the style semantic labels for level in self.style_levels: if level is not None: - update(self.style_label, level, + update(self.variables["style"], level, marker=self.markers.get(level, ""), dashes=self.dashes.get(level, "")) @@ -692,8 +552,8 @@ def __init__(self, units=None, estimator=None, ci=None, n_boot=None, seed=None, sort=True, err_style=None, err_kws=None, legend=None): - plot_data = self.establish_variables( - x, y, hue, size, style, units, data + plot_data, variables = self.establish_variables( + data, x=x, y=y, hue=hue, size=size, style=style, units=units, ) self._default_size_range = ( @@ -883,8 +743,8 @@ def __init__(self, alpha=None, x_jitter=None, y_jitter=None, legend=None): - plot_data = self.establish_variables( - x, y, hue, size, style, units, data + plot_data, variables = self.establish_variables( + data, x=x, y=y, hue=hue, size=size, style=style, units=units, ) 
self._default_size_range = ( @@ -934,12 +794,12 @@ def plot(self, ax, kws): # Assign arguments for plt.scatter and draw the plot - data = self.plot_data[self.semantics].dropna() + data = self.plot_data[list(self.variables)].dropna() if not data.size: return - x = data["x"] - y = data["y"] + x = data.get(["x"], np.full(len(data), np.nan)) + y = data.get(["y"], np.full(len(data), np.nan)) if self.palette: c = [self.palette.get(val) for val in data["hue"]] @@ -1662,8 +1522,10 @@ def relplot( # Check for attempt to plot onto specific axes and warn if "ax" in kwargs: - msg = ("relplot is a figure-level function and does not accept " - "target axes. You may wish to try {}".format(kind + "plot")) + msg = ( + "relplot is a figure-level function and does not accept " + "the ax= paramter. You may wish to try {}".format(kind + "plot") + ) warnings.warn(msg, UserWarning) kwargs.pop("ax") @@ -1676,6 +1538,7 @@ def relplot( legend=legend, ) + # Extract the semantic mappings palette = p.palette if p.palette else None hue_order = p.hue_levels if any(p.hue_levels) else None hue_norm = p.hue_norm if p.hue_norm is not None else None @@ -1688,6 +1551,12 @@ def relplot( dashes = p.dashes if p.dashes else None style_order = p.style_levels if any(p.style_levels) else None + # Now extract the data that would be used to draw a single plot + variables = p.variables + plot_data = p.plot_data + plot_semantics = p.semantics + + # Define the common plotting parameters plot_kws = dict( palette=palette, hue_order=hue_order, hue_norm=p.hue_norm, sizes=sizes, size_order=size_order, size_norm=p.size_norm, @@ -1698,22 +1567,52 @@ def relplot( if kind == "scatter": plot_kws.pop("dashes") + # Define the named variables for plotting on each facet + plot_variables = {key: key for key in p.variables} + plot_kws.update(plot_variables) + + # Define grid_data with row/col semantics + grid_semantics = ["row", "col"] # TODO define on FacetGrid? + p.semantics = plot_semantics + grid_semantics + full_data, full_variables = p.establish_variables( + data, + x=x, y=y, + hue=hue, size=size, style=style, + row=row, col=col, + ) + + # Assemble a data object with the plot_data from the original + # plotter and the row/col variables with their external names. + # This is so FacetGrid labels the subplots correctly. + # We can't use just full_data because the hue/size type inference can + # change the data type of variables with object type but numeric behavior. 
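+    # Illustrative sketch with hypothetical arguments: for a call like
+    # relplot(data=df, x="x", y="y", col="b"), full_variables maps
+    # "col" -> "b", so grid_kws below becomes {"row": None, "col": "b"}
+    # and grid_data carries the facet values under their external name
+    # "b", which lets FacetGrid title each subplot as "b = ...".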
+ grid_kws = {v: full_variables.get(v, None) for v in grid_semantics} + grid_data = full_data[grid_semantics].rename(columns=grid_kws) + plot_data = pd.concat([plot_data, grid_data], axis=1) + # Set up the FacetGrid object - facet_kws = {} if facet_kws is None else facet_kws + facet_kws = {} if facet_kws is None else facet_kws.copy() + facet_kws.update(grid_kws) g = FacetGrid( - data=data, row=row, col=col, col_wrap=col_wrap, - row_order=row_order, col_order=col_order, + data=plot_data, + col_wrap=col_wrap, row_order=row_order, col_order=col_order, height=height, aspect=aspect, dropna=False, **facet_kws ) # Draw the plot - g.map_dataframe(func, x, y, - hue=hue, size=size, style=style, - **plot_kws) + g.map_dataframe(func, **plot_kws) + + # Label the axes + g.set_axis_labels( + variables.get("x", None), variables.get("y", None) + ) # Show the legend if legend: + # Replace the original plot data so the legend uses + # numeric data with the correct type + p.plot_data = plot_data p.add_legend_data(g.axes.flat[0]) if p.legend_data: g.add_legend(legend_data=p.legend_data, diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py new file mode 100644 index 0000000000..913abe958b --- /dev/null +++ b/seaborn/tests/test_core.py @@ -0,0 +1,35 @@ +import numpy as np + +from numpy.testing import assert_array_equal + +from ..core import _VectorPlotter + + +class TestVectorPlotter: + + def test_flat_variables(self, flat_data): + + p = _VectorPlotter() + p.establish_variables(data=flat_data) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y"] + assert len(p.plot_data) == len(flat_data) + + try: + expected_x = flat_data.index + expected_x_name = flat_data.index.name + except AttributeError: + expected_x = np.arange(len(flat_data)) + expected_x_name = None + + x = p.plot_data["x"] + assert_array_equal(x, expected_x) + + expected_y = flat_data + expected_y_name = getattr(flat_data, "name", None) + + y = p.plot_data["y"] + assert_array_equal(y, expected_y) + + assert p.variables["x"] == expected_x_name + assert p.variables["y"] == expected_y_name diff --git a/seaborn/tests/test_relational.py b/seaborn/tests/test_relational.py index 102381708f..63ae5d1d85 100644 --- a/seaborn/tests/test_relational.py +++ b/seaborn/tests/test_relational.py @@ -4,13 +4,44 @@ import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt + import pytest -from .. import relational as rel +from numpy.testing import assert_array_equal + from ..palettes import color_palette from ..utils import categorical_order - -class TestRelationalPlotter(object): +from ..relational import ( + _RelationalPlotter, + _LinePlotter, + _ScatterPlotter, + relplot, + lineplot, + scatterplot +) + + +@pytest.fixture(params=[ + dict(x="x", y="y"), + dict(x="t", y="y"), + dict(x="a", y="y"), + dict(x="x", y="y", hue="y"), + dict(x="x", y="y", hue="a"), + dict(x="x", y="y", size="a"), + dict(x="x", y="y", style="a"), + dict(x="x", y="y", hue="s"), + dict(x="x", y="y", size="s"), + dict(x="x", y="y", style="s"), + dict(x="x", y="y", hue="a", style="a"), + dict(x="x", y="y", hue="a", size="b", style="b"), +]) +def long_semantics(request): + return request.param + + +class Helpers: + + # TODO Better place for these? 
def scatter_rgbs(self, collections): rgbs = [] @@ -36,348 +67,476 @@ def paths_equal(self, *args): equal &= np.array_equal(p1.codes, p2.codes) return equal - @pytest.fixture - def wide_df(self): - - columns = list("abc") - index = pd.Int64Index(np.arange(10, 50, 2), name="wide_index") - values = np.random.randn(len(index), len(columns)) - return pd.DataFrame(values, index=index, columns=columns) - - @pytest.fixture - def wide_array(self): - - return np.random.randn(20, 3) - - @pytest.fixture - def flat_array(self): - - return np.random.randn(20) - - @pytest.fixture - def flat_series(self): - - index = pd.Int64Index(np.arange(10, 30), name="t") - return pd.Series(np.random.randn(20), index, name="s") - - @pytest.fixture - def wide_list(self): - - return [np.random.randn(20), np.random.randn(10)] - - @pytest.fixture - def wide_list_of_series(self): - - return [pd.Series(np.random.randn(20), np.arange(20), name="a"), - pd.Series(np.random.randn(10), np.arange(5, 15), name="b")] - - @pytest.fixture - def long_df(self): - - n = 100 - rs = np.random.RandomState() - df = pd.DataFrame(dict( - x=rs.randint(0, 20, n), - y=rs.randn(n), - a=np.take(list("abc"), rs.randint(0, 3, n)), - b=np.take(list("mnop"), rs.randint(0, 4, n)), - c=np.take(list([0, 1]), rs.randint(0, 2, n)), - d=np.repeat(np.datetime64('2005-02-25'), n), - s=np.take([2, 4, 8], rs.randint(0, 3, n)), - f=np.take(list([0.2, 0.3]), rs.randint(0, 2, n)), - )) - df["s_cat"] = df["s"].astype("category") - return df - - @pytest.fixture - def repeated_df(self): - - n = 100 - rs = np.random.RandomState() - return pd.DataFrame(dict( - x=np.tile(np.arange(n // 2), 2), - y=rs.randn(n), - a=np.take(list("abc"), rs.randint(0, 3, n)), - u=np.repeat(np.arange(2), n // 2), - )) - - @pytest.fixture - def missing_df(self): - - n = 100 - rs = np.random.RandomState() - df = pd.DataFrame(dict( - x=rs.randint(0, 20, n), - y=rs.randn(n), - a=np.take(list("abc"), rs.randint(0, 3, n)), - b=np.take(list("mnop"), rs.randint(0, 4, n)), - s=np.take([2, 4, 8], rs.randint(0, 3, n)), - )) - for col in df: - idx = rs.permutation(df.index)[:10] - df.loc[idx, col] = np.nan - return df - - @pytest.fixture - def null_column(self): - - return pd.Series(index=np.arange(20), dtype='float64') + +class TestRelationalPlotter(Helpers): def test_wide_df_variables(self, wide_df): - p = rel._RelationalPlotter() + p = _RelationalPlotter() p.establish_variables(data=wide_df) assert p.input_format == "wide" - assert p.semantics == ["x", "y", "hue", "style"] + assert list(p.variables) == ["x", "y", "hue", "style"] assert len(p.plot_data) == np.product(wide_df.shape) x = p.plot_data["x"] expected_x = np.tile(wide_df.index, wide_df.shape[1]) - assert np.array_equal(x, expected_x) + assert_array_equal(x, expected_x) y = p.plot_data["y"] expected_y = wide_df.values.ravel(order="f") - assert np.array_equal(y, expected_y) + assert_array_equal(y, expected_y) hue = p.plot_data["hue"] expected_hue = np.repeat(wide_df.columns.values, wide_df.shape[0]) - assert np.array_equal(hue, expected_hue) + assert_array_equal(hue, expected_hue) style = p.plot_data["style"] expected_style = expected_hue - assert np.array_equal(style, expected_style) + assert_array_equal(style, expected_style) assert p.plot_data["size"].isnull().all() - assert p.x_label == wide_df.index.name - assert p.y_label is None - assert p.hue_label == wide_df.columns.name - assert p.size_label is None - assert p.style_label == wide_df.columns.name + assert p.variables["x"] == wide_df.index.name + assert p.variables["y"] is None + 
assert p.variables["hue"] == wide_df.columns.name + assert p.variables["style"] == wide_df.columns.name - def test_wide_df_variables_check(self, wide_df): + def test_wide_df_with_nonnumeric_variables(self, long_df): - p = rel._RelationalPlotter() - wide_df = wide_df.copy() - wide_df.loc[:, "not_numeric"] = "a" - with pytest.raises(ValueError): - p.establish_variables(data=wide_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y", "hue", "style"] + + numeric_df = long_df.select_dtypes("number") + + assert len(p.plot_data) == np.product(numeric_df.shape) + + x = p.plot_data["x"] + expected_x = np.tile(numeric_df.index, numeric_df.shape[1]) + assert_array_equal(x, expected_x) + + y = p.plot_data["y"] + expected_y = numeric_df.values.ravel(order="f") + assert_array_equal(y, expected_y) + + hue = p.plot_data["hue"] + expected_hue = np.repeat( + numeric_df.columns.values, numeric_df.shape[0] + ) + assert_array_equal(hue, expected_hue) + + style = p.plot_data["style"] + expected_style = expected_hue + assert_array_equal(style, expected_style) + + assert p.plot_data["size"].isnull().all() + + assert p.variables["x"] == numeric_df.index.name + assert p.variables["y"] is None + assert p.variables["hue"] == numeric_df.columns.name + assert p.variables["style"] == numeric_df.columns.name def test_wide_array_variables(self, wide_array): - p = rel._RelationalPlotter() + p = _RelationalPlotter() p.establish_variables(data=wide_array) assert p.input_format == "wide" - assert p.semantics == ["x", "y", "hue", "style"] + assert list(p.variables) == ["x", "y", "hue", "style"] assert len(p.plot_data) == np.product(wide_array.shape) nrow, ncol = wide_array.shape x = p.plot_data["x"] expected_x = np.tile(np.arange(nrow), ncol) - assert np.array_equal(x, expected_x) + assert_array_equal(x, expected_x) y = p.plot_data["y"] expected_y = wide_array.ravel(order="f") - assert np.array_equal(y, expected_y) + assert_array_equal(y, expected_y) hue = p.plot_data["hue"] expected_hue = np.repeat(np.arange(ncol), nrow) - assert np.array_equal(hue, expected_hue) + assert_array_equal(hue, expected_hue) style = p.plot_data["style"] expected_style = expected_hue - assert np.array_equal(style, expected_style) + assert_array_equal(style, expected_style) assert p.plot_data["size"].isnull().all() - assert p.x_label is None - assert p.y_label is None - assert p.hue_label is None - assert p.size_label is None - assert p.style_label is None + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None def test_flat_array_variables(self, flat_array): - p = rel._RelationalPlotter() + p = _RelationalPlotter() p.establish_variables(data=flat_array) assert p.input_format == "wide" - assert p.semantics == ["x", "y"] + assert list(p.variables) == ["x", "y"] assert len(p.plot_data) == np.product(flat_array.shape) x = p.plot_data["x"] expected_x = np.arange(flat_array.shape[0]) - assert np.array_equal(x, expected_x) + assert_array_equal(x, expected_x) y = p.plot_data["y"] expected_y = flat_array - assert np.array_equal(y, expected_y) + assert_array_equal(y, expected_y) assert p.plot_data["hue"].isnull().all() assert p.plot_data["style"].isnull().all() assert p.plot_data["size"].isnull().all() - assert p.x_label is None - assert p.y_label is None - assert p.hue_label is None - assert p.size_label is None - assert p.style_label is None + assert p.variables["x"] is None + assert 
p.variables["y"] is None + + def test_flat_list_variables(self, flat_list): + + p = _RelationalPlotter() + p.establish_variables(data=flat_list) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y"] + assert len(p.plot_data) == len(flat_list) + + x = p.plot_data["x"] + expected_x = np.arange(len(flat_list)) + assert_array_equal(x, expected_x) + + y = p.plot_data["y"] + expected_y = flat_list + assert_array_equal(y, expected_y) + + assert p.plot_data["hue"].isnull().all() + assert p.plot_data["style"].isnull().all() + assert p.plot_data["size"].isnull().all() + + assert p.variables["x"] is None + assert p.variables["y"] is None def test_flat_series_variables(self, flat_series): - p = rel._RelationalPlotter() + p = _RelationalPlotter() p.establish_variables(data=flat_series) assert p.input_format == "wide" - assert p.semantics == ["x", "y"] + assert list(p.variables) == ["x", "y"] assert len(p.plot_data) == len(flat_series) x = p.plot_data["x"] expected_x = flat_series.index - assert np.array_equal(x, expected_x) + assert_array_equal(x, expected_x) y = p.plot_data["y"] expected_y = flat_series - assert np.array_equal(y, expected_y) + assert_array_equal(y, expected_y) - assert p.x_label is None - assert p.y_label is None - assert p.hue_label is None - assert p.size_label is None - assert p.style_label is None + assert p.variables["x"] is flat_series.index.name + assert p.variables["y"] is flat_series.name - def test_wide_list_variables(self, wide_list): + def test_wide_list_of_series_variables(self, wide_list_of_series): - p = rel._RelationalPlotter() - p.establish_variables(data=wide_list) + p = _RelationalPlotter() + p.establish_variables(data=wide_list_of_series) assert p.input_format == "wide" - assert p.semantics == ["x", "y", "hue", "style"] - assert len(p.plot_data) == sum(len(l) for l in wide_list) + assert list(p.variables) == ["x", "y", "hue", "style"] + + chunks = len(wide_list_of_series) + chunk_size = max(len(l) for l in wide_list_of_series) + + assert len(p.plot_data) == chunks * chunk_size + + index_union = np.unique( + np.concatenate([s.index for s in wide_list_of_series]) + ) x = p.plot_data["x"] - expected_x = np.concatenate([np.arange(len(l)) for l in wide_list]) - assert np.array_equal(x, expected_x) + expected_x = np.tile(index_union, chunks) + assert_array_equal(x, expected_x) y = p.plot_data["y"] - expected_y = np.concatenate(wide_list) - assert np.array_equal(y, expected_y) + expected_y = np.concatenate([ + s.reindex(index_union) for s in wide_list_of_series + ]) + assert_array_equal(y, expected_y) hue = p.plot_data["hue"] - expected_hue = np.concatenate([ - np.ones_like(l) * i for i, l in enumerate(wide_list) - ]) - assert np.array_equal(hue, expected_hue) + series_names = [s.name for s in wide_list_of_series] + expected_hue = np.repeat(series_names, chunk_size) + assert_array_equal(hue, expected_hue) style = p.plot_data["style"] expected_style = expected_hue - assert np.array_equal(style, expected_style) + assert_array_equal(style, expected_style) assert p.plot_data["size"].isnull().all() - assert p.x_label is None - assert p.y_label is None - assert p.hue_label is None - assert p.size_label is None - assert p.style_label is None + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None - def test_wide_list_of_series_variables(self, wide_list_of_series): + def test_wide_list_of_arrays_variables(self, wide_list_of_arrays): - p = rel._RelationalPlotter() - 
p.establish_variables(data=wide_list_of_series) + p = _RelationalPlotter() + p.establish_variables(data=wide_list_of_arrays) assert p.input_format == "wide" - assert p.semantics == ["x", "y", "hue", "style"] - assert len(p.plot_data) == sum(len(l) for l in wide_list_of_series) + assert list(p.variables) == ["x", "y", "hue", "style"] + + chunks = len(wide_list_of_arrays) + chunk_size = max(len(l) for l in wide_list_of_arrays) + + assert len(p.plot_data) == chunks * chunk_size x = p.plot_data["x"] - expected_x = np.concatenate([s.index for s in wide_list_of_series]) - assert np.array_equal(x, expected_x) + expected_x = np.tile(np.arange(chunk_size), chunks) + assert_array_equal(x, expected_x) - y = p.plot_data["y"] - expected_y = np.concatenate(wide_list_of_series) - assert np.array_equal(y, expected_y) + y = p.plot_data["y"].dropna() + expected_y = np.concatenate(wide_list_of_arrays) + assert_array_equal(y, expected_y) hue = p.plot_data["hue"] - expected_hue = np.concatenate([ - np.full(len(s), s.name, object) for s in wide_list_of_series - ]) - assert np.array_equal(hue, expected_hue) + expected_hue = np.repeat(np.arange(chunks), chunk_size) + assert_array_equal(hue, expected_hue) style = p.plot_data["style"] expected_style = expected_hue - assert np.array_equal(style, expected_style) + assert_array_equal(style, expected_style) assert p.plot_data["size"].isnull().all() - assert p.x_label is None - assert p.y_label is None - assert p.hue_label is None - assert p.size_label is None - assert p.style_label is None + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None + + def test_wide_list_of_list_variables(self, wide_list_of_lists): + + p = _RelationalPlotter() + p.establish_variables(data=wide_list_of_lists) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y", "hue", "style"] + + chunks = len(wide_list_of_lists) + chunk_size = max(len(l) for l in wide_list_of_lists) + + assert len(p.plot_data) == chunks * chunk_size + + x = p.plot_data["x"] + expected_x = np.tile(np.arange(chunk_size), chunks) + assert_array_equal(x, expected_x) + + y = p.plot_data["y"].dropna() + expected_y = np.concatenate(wide_list_of_lists) + assert_array_equal(y, expected_y) + + hue = p.plot_data["hue"] + expected_hue = np.repeat(np.arange(chunks), chunk_size) + assert_array_equal(hue, expected_hue) + + style = p.plot_data["style"] + expected_style = expected_hue + assert_array_equal(style, expected_style) + + assert p.plot_data["size"].isnull().all() + + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None + + def test_wide_dict_of_series_variables(self, wide_dict_of_series): + + p = _RelationalPlotter() + p.establish_variables(data=wide_dict_of_series) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y", "hue", "style"] + + chunks = len(wide_dict_of_series) + chunk_size = max(len(l) for l in wide_dict_of_series.values()) + + assert len(p.plot_data) == chunks * chunk_size + + x = p.plot_data["x"] + expected_x = np.tile(np.arange(chunk_size), chunks) + assert_array_equal(x, expected_x) + + y = p.plot_data["y"].dropna() + expected_y = np.concatenate(list(wide_dict_of_series.values())) + assert_array_equal(y, expected_y) + + hue = p.plot_data["hue"] + expected_hue = np.repeat(list(wide_dict_of_series), chunk_size) + assert_array_equal(hue, expected_hue) + + style = p.plot_data["style"] + expected_style = 
expected_hue + assert_array_equal(style, expected_style) + + assert p.plot_data["size"].isnull().all() + + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None + + def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays): + + p = _RelationalPlotter() + p.establish_variables(data=wide_dict_of_arrays) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y", "hue", "style"] + + chunks = len(wide_dict_of_arrays) + chunk_size = max(len(l) for l in wide_dict_of_arrays.values()) + + assert len(p.plot_data) == chunks * chunk_size + + x = p.plot_data["x"] + expected_x = np.tile(np.arange(chunk_size), chunks) + assert_array_equal(x, expected_x) + + y = p.plot_data["y"].dropna() + expected_y = np.concatenate(list(wide_dict_of_arrays.values())) + assert_array_equal(y, expected_y) + + hue = p.plot_data["hue"] + expected_hue = np.repeat(list(wide_dict_of_arrays), chunk_size) + assert_array_equal(hue, expected_hue) + + style = p.plot_data["style"] + expected_style = expected_hue + assert_array_equal(style, expected_style) + + assert p.plot_data["size"].isnull().all() + + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None + + def test_wide_dict_of_lists_variables(self, wide_dict_of_lists): + + p = _RelationalPlotter() + p.establish_variables(data=wide_dict_of_lists) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y", "hue", "style"] + + chunks = len(wide_dict_of_lists) + chunk_size = max(len(l) for l in wide_dict_of_lists.values()) - def test_long_df(self, long_df): + assert len(p.plot_data) == chunks * chunk_size - p = rel._RelationalPlotter() - p.establish_variables(x="x", y="y", data=long_df) + x = p.plot_data["x"] + expected_x = np.tile(np.arange(chunk_size), chunks) + assert_array_equal(x, expected_x) + + y = p.plot_data["y"].dropna() + expected_y = np.concatenate(list(wide_dict_of_lists.values())) + assert_array_equal(y, expected_y) + + hue = p.plot_data["hue"] + expected_hue = np.repeat(list(wide_dict_of_lists), chunk_size) + assert_array_equal(hue, expected_hue) + + style = p.plot_data["style"] + expected_style = expected_hue + assert_array_equal(style, expected_style) + + assert p.plot_data["size"].isnull().all() + + assert p.variables["x"] is None + assert p.variables["y"] is None + assert p.variables["hue"] is None + assert p.variables["style"] is None + + def test_long_df(self, long_df, long_semantics): + + p = _RelationalPlotter() + p.establish_variables(long_df, **long_semantics) assert p.input_format == "long" - assert p.semantics == ["x", "y"] + assert p.variables == long_semantics - assert np.array_equal(p.plot_data["x"], long_df["x"]) - assert np.array_equal(p.plot_data["y"], long_df["y"]) - for col in ["hue", "style", "size"]: + for key, val in long_semantics.items(): + assert_array_equal(p.plot_data[key], long_df[val]) + + for col in set(p.semantics) - set(long_semantics): assert p.plot_data[col].isnull().all() - assert (p.x_label, p.y_label) == ("x", "y") - assert p.hue_label is None - assert p.size_label is None - assert p.style_label is None - - p.establish_variables(x=long_df.x, y="y", data=long_df) - assert p.semantics == ["x", "y"] - assert np.array_equal(p.plot_data["x"], long_df["x"]) - assert np.array_equal(p.plot_data["y"], long_df["y"]) - assert (p.x_label, p.y_label) == ("x", "y") - - p.establish_variables(x="x", y=long_df.y, data=long_df) - assert 
p.semantics == ["x", "y"] - assert np.array_equal(p.plot_data["x"], long_df["x"]) - assert np.array_equal(p.plot_data["y"], long_df["y"]) - assert (p.x_label, p.y_label) == ("x", "y") - - p.establish_variables(x="x", y="y", hue="a", data=long_df) - assert p.semantics == ["x", "y", "hue"] - assert np.array_equal(p.plot_data["hue"], long_df["a"]) - for col in ["style", "size"]: + + def test_long_df_with_index(self, long_df, long_semantics): + + p = _RelationalPlotter() + p.establish_variables(long_df.set_index("a"), **long_semantics) + assert p.input_format == "long" + assert p.variables == long_semantics + + for key, val in long_semantics.items(): + assert_array_equal(p.plot_data[key], long_df[val]) + + for col in set(p.semantics) - set(long_semantics): assert p.plot_data[col].isnull().all() - assert p.hue_label == "a" - assert p.size_label is None and p.style_label is None - p.establish_variables(x="x", y="y", hue="a", style="a", data=long_df) - assert p.semantics == ["x", "y", "hue", "style"] - assert np.array_equal(p.plot_data["hue"], long_df["a"]) - assert np.array_equal(p.plot_data["style"], long_df["a"]) - assert p.plot_data["size"].isnull().all() - assert p.hue_label == p.style_label == "a" - assert p.size_label is None + def test_long_df_with_multiindex(self, long_df, long_semantics): - p.establish_variables(x="x", y="y", hue="a", style="b", data=long_df) - assert p.semantics == ["x", "y", "hue", "style"] - assert np.array_equal(p.plot_data["hue"], long_df["a"]) - assert np.array_equal(p.plot_data["style"], long_df["b"]) - assert p.plot_data["size"].isnull().all() + p = _RelationalPlotter() + p.establish_variables(long_df.set_index(["a", "x"]), **long_semantics) + assert p.input_format == "long" + assert p.variables == long_semantics - p.establish_variables(x="x", y="y", size="y", data=long_df) - assert p.semantics == ["x", "y", "size"] - assert np.array_equal(p.plot_data["size"], long_df["y"]) - assert p.size_label == "y" - assert p.hue_label is None and p.style_label is None + for key, val in long_semantics.items(): + assert_array_equal(p.plot_data[key], long_df[val]) - def test_bad_input(self, long_df): + for col in set(p.semantics) - set(long_semantics): + assert p.plot_data[col].isnull().all() - p = rel._RelationalPlotter() + def test_long_dict(self, long_dict, long_semantics): - with pytest.raises(ValueError): - p.establish_variables(x=long_df.x) + p = _RelationalPlotter() + p.establish_variables(long_dict, **long_semantics) + assert p.input_format == "long" + assert p.variables == long_semantics - with pytest.raises(ValueError): - p.establish_variables(y=long_df.y) + for key, val in long_semantics.items(): + assert_array_equal(p.plot_data[key], pd.Series(long_dict[val])) + + for col in set(p.semantics) - set(long_semantics): + assert p.plot_data[col].isnull().all() + + @pytest.mark.parametrize( + "vector_type", + ["series", "numpy", "list"], + ) + def test_long_vectors(self, long_df, long_semantics, vector_type): + + kws = {key: long_df[val] for key, val in long_semantics.items()} + if vector_type == "numpy": + # Requires pandas >= 0.24 + # kws = {key: val.to_numpy() for key, val in kws.items()} + kws = {key: np.asarray(val) for key, val in kws.items()} + elif vector_type == "list": + # Requires pandas >= 0.24 + # kws = {key: val.to_list() for key, val in kws.items()} + kws = {key: val.tolist() for key, val in kws.items()} + + p = _RelationalPlotter() + p.establish_variables(**kws) + assert p.input_format == "long" + + assert list(p.variables) == list(long_semantics) + if 
vector_type == "series": + assert p.variables == long_semantics + + for key, val in long_semantics.items(): + assert_array_equal(p.plot_data[key], long_df[val]) + + for col in set(p.semantics) - set(long_semantics): + assert p.plot_data[col].isnull().all() + + def test_long_undefined_variables(self, long_df): + + p = _RelationalPlotter() with pytest.raises(ValueError): p.establish_variables(x="not_in_df", data=long_df) @@ -386,11 +545,11 @@ def test_bad_input(self, long_df): p.establish_variables(x="x", y="not_in_df", data=long_df) with pytest.raises(ValueError): - p.establish_variables(x="x", y="not_in_df", data=long_df) + p.establish_variables(x="x", y="y", hue="not_in_df", data=long_df) def test_empty_input(self): - p = rel._RelationalPlotter() + p = _RelationalPlotter() p.establish_variables(data=[]) p.establish_variables(data=np.array([])) @@ -399,14 +558,15 @@ def test_empty_input(self): def test_units(self, repeated_df): - p = rel._RelationalPlotter() + p = _RelationalPlotter() p.establish_variables(x="x", y="y", units="u", data=repeated_df) - assert np.array_equal(p.plot_data["units"], repeated_df["u"]) + assert_array_equal(p.plot_data["units"], repeated_df["u"]) - def test_parse_hue_null(self, wide_df, null_column): + def test_parse_hue_null(self, wide_df, null_series): - p = rel._LinePlotter(data=wide_df) - p.parse_hue(null_column, "Blues", None, None) + p = _RelationalPlotter() + p.establish_variables(wide_df) + p.parse_hue(null_series, "Blues", None, None) assert p.hue_levels == [None] assert p.palette == {} assert p.hue_type is None @@ -414,7 +574,9 @@ def test_parse_hue_null(self, wide_df, null_column): def test_parse_hue_categorical(self, wide_df, long_df): - p = rel._LinePlotter(data=wide_df) + p = _RelationalPlotter() + p.establish_variables(data=wide_df) + p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == wide_df.columns.tolist() assert p.hue_type == "categorical" assert p.cmap is None @@ -423,44 +585,46 @@ def test_parse_hue_categorical(self, wide_df, long_df): palette = "Blues" expected_colors = color_palette(palette, wide_df.shape[1]) expected_palette = dict(zip(wide_df.columns, expected_colors)) - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert p.palette == expected_palette # Test list palette palette = color_palette("Reds", wide_df.shape[1]) - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) expected_palette = dict(zip(wide_df.columns, palette)) assert p.palette == expected_palette # Test dict palette colors = color_palette("Set1", 8) palette = dict(zip(wide_df.columns, colors)) - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert p.palette == palette # Test dict with missing keys palette = dict(zip(wide_df.columns[:-1], colors)) with pytest.raises(ValueError): - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) # Test list with wrong number of colors palette = colors[:-1] with pytest.raises(ValueError): - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) # Test hue order hue_order = ["a", "c", "d"] - p.parse_hue(p.plot_data.hue, None, hue_order, None) + p.parse_hue(p.plot_data["hue"], order=hue_order) assert p.hue_levels == hue_order # Test long data - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df, x="x", y="y", hue="a") 
+ p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == categorical_order(long_df.a) assert p.hue_type == "categorical" assert p.cmap is None # Test default palette - p.parse_hue(p.plot_data.hue, None, None, None) + p.parse_hue(p.plot_data["hue"]) hue_levels = categorical_order(long_df.a) expected_colors = color_palette(n_colors=len(hue_levels)) expected_palette = dict(zip(hue_levels, expected_colors)) @@ -468,41 +632,52 @@ def test_parse_hue_categorical(self, wide_df, long_df): # Test default palette with many levels levels = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) - p.parse_hue(levels, None, None, None) + p.parse_hue(levels) expected_colors = color_palette("husl", n_colors=len(levels)) expected_palette = dict(zip(levels, expected_colors)) assert p.palette == expected_palette # Test binary data - p = rel._LinePlotter(x="x", y="y", hue="c", data=long_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df, x="x", y="y", hue="c") + p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == [0, 1] assert p.hue_type == "categorical" df = long_df[long_df["c"] == 0] - p = rel._LinePlotter(x="x", y="y", hue="c", data=df) + p = _RelationalPlotter() + p.establish_variables(data=df, x="x", y="y", hue="c") + p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == [0] assert p.hue_type == "categorical" df = long_df[long_df["c"] == 1] - p = rel._LinePlotter(x="x", y="y", hue="c", data=df) + p = _RelationalPlotter() + p.establish_variables(data=df, x="x", y="y", hue="c") + p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == [1] assert p.hue_type == "categorical" # Test Timestamp data - p = rel._LinePlotter(x="x", y="y", hue="d", data=long_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df, x="x", y="y", hue="t") + p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == [pd.Timestamp('2005-02-25')] assert p.hue_type == "categorical" # Test numeric data with category type - p = rel._LinePlotter(x="x", y="y", hue="s_cat", data=long_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df, x="x", y="y", hue="s_cat") + p.parse_hue(p.plot_data["hue"]) assert p.hue_levels == categorical_order(long_df.s_cat) assert p.hue_type == "categorical" assert p.cmap is None # Test categorical palette specified for numeric data palette = "deep" - p = rel._LinePlotter(x="x", y="y", hue="s", - palette=palette, data=long_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df, x="x", y="y", hue="s") + p.parse_hue(p.plot_data["hue"], palette=palette) expected_colors = color_palette(palette, n_colors=len(levels)) hue_levels = categorical_order(long_df["s"]) expected_palette = dict(zip(hue_levels, expected_colors)) @@ -511,7 +686,9 @@ def test_parse_hue_categorical(self, wide_df, long_df): def test_parse_hue_numeric(self, long_df): - p = rel._LinePlotter(x="x", y="y", hue="s", data=long_df) + p = _RelationalPlotter() + p.establish_variables(data=long_df, x="x", y="y", hue="s") + p.parse_hue(p.plot_data["hue"]) hue_levels = list(np.sort(long_df.s.unique())) assert p.hue_levels == hue_levels assert p.hue_type == "numeric" @@ -519,26 +696,27 @@ def test_parse_hue_numeric(self, long_df): # Test named colormap palette = "Purples" - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert p.cmap is mpl.cm.get_cmap(palette) # Test colormap object palette = mpl.cm.get_cmap("Greens") - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert p.cmap is palette # Test cubehelix 
shorthand palette = "ch:2,0,light=.2" - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert isinstance(p.cmap, mpl.colors.ListedColormap) # Test default hue limits - p.parse_hue(p.plot_data.hue, None, None, None) - assert p.hue_limits == (p.plot_data.hue.min(), p.plot_data.hue.max()) + p.parse_hue(p.plot_data["hue"]) + data_range = p.plot_data["hue"].min(), p.plot_data["hue"].max() + assert p.hue_limits == data_range # Test specified hue limits hue_norm = 1, 4 - p.parse_hue(p.plot_data.hue, None, None, hue_norm) + p.parse_hue(p.plot_data["hue"], norm=hue_norm) assert p.hue_limits == hue_norm assert isinstance(p.hue_norm, mpl.colors.Normalize) assert p.hue_norm.vmin == hue_norm[0] @@ -546,19 +724,19 @@ def test_parse_hue_numeric(self, long_df): # Test Normalize object hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10) - p.parse_hue(p.plot_data.hue, None, None, hue_norm) + p.parse_hue(p.plot_data["hue"], norm=hue_norm) assert p.hue_limits == (hue_norm.vmin, hue_norm.vmax) assert p.hue_norm is hue_norm # Test default colormap values - hmin, hmax = p.plot_data.hue.min(), p.plot_data.hue.max() - p.parse_hue(p.plot_data.hue, None, None, None) + hmin, hmax = p.plot_data["hue"].min(), p.plot_data["hue"].max() + p.parse_hue(p.plot_data["hue"]) assert p.palette[hmin] == pytest.approx(p.cmap(0.0)) assert p.palette[hmax] == pytest.approx(p.cmap(1.0)) # Test specified colormap values hue_norm = hmin - 1, hmax - 1 - p.parse_hue(p.plot_data.hue, None, None, hue_norm) + p.parse_hue(p.plot_data["hue"], norm=hue_norm) norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0]) assert p.palette[hmin] == pytest.approx(p.cmap(norm_min)) assert p.palette[hmax] == pytest.approx(p.cmap(1.0)) @@ -566,59 +744,59 @@ def test_parse_hue_numeric(self, long_df): # Test list of colors hue_levels = list(np.sort(long_df.s.unique())) palette = color_palette("Blues", len(hue_levels)) - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert p.palette == dict(zip(hue_levels, palette)) palette = color_palette("Blues", len(hue_levels) + 1) with pytest.raises(ValueError): - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) # Test dictionary of colors palette = dict(zip(hue_levels, color_palette("Reds"))) - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) assert p.palette == palette palette.pop(hue_levels[0]) with pytest.raises(ValueError): - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) # Test invalid palette palette = "not_a_valid_palette" with pytest.raises(ValueError): - p.parse_hue(p.plot_data.hue, palette, None, None) + p.parse_hue(p.plot_data["hue"], palette=palette) # Test bad norm argument hue_norm = "not a norm" with pytest.raises(ValueError): - p.parse_hue(p.plot_data.hue, None, None, hue_norm) + p.parse_hue(p.plot_data["hue"], norm=hue_norm) def test_parse_size(self, long_df): - p = rel._LinePlotter(x="x", y="y", size="s", data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", size="s") # Test default size limits and range - default_linewidth = mpl.rcParams["lines.linewidth"] default_limits = p.plot_data["size"].min(), p.plot_data["size"].max() - default_range = .5 * default_linewidth, 2 * default_linewidth - p.parse_size(p.plot_data["size"], None, None, None) + default_range = p._default_size_range + 
p.parse_size(p.plot_data["size"]) assert p.size_limits == default_limits size_range = min(p.sizes.values()), max(p.sizes.values()) assert size_range == default_range # Test specified size limits size_limits = (1, 5) - p.parse_size(p.plot_data["size"], None, None, size_limits) + p.parse_size(p.plot_data["size"], norm=size_limits) assert p.size_limits == size_limits # Test specified size range sizes = (.1, .5) - p.parse_size(p.plot_data["size"], sizes, None, None) + p.parse_size(p.plot_data["size"], sizes=sizes) assert p.size_limits == default_limits # Test size values with normalization range sizes = (1, 5) size_norm = (1, 10) - p.parse_size(p.plot_data["size"], sizes, None, size_norm) + p.parse_size(p.plot_data["size"], sizes=sizes, norm=size_norm) normalize = mpl.colors.Normalize(*size_norm, clip=True) for level, width in p.sizes.items(): assert width == sizes[0] + (sizes[1] - sizes[0]) * normalize(level) @@ -626,77 +804,75 @@ def test_parse_size(self, long_df): # Test size values with normalization object sizes = (1, 5) size_norm = mpl.colors.LogNorm(1, 10, clip=False) - p.parse_size(p.plot_data["size"], sizes, None, size_norm) + p.parse_size(p.plot_data["size"], sizes=sizes, norm=size_norm) assert p.size_norm.clip for level, width in p.sizes.items(): assert width == sizes[0] + (sizes[1] - sizes[0]) * size_norm(level) - # Test specified size order + # Use a categorical variable var = "a" + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", size=var) + + # Test specified size order levels = long_df[var].unique() sizes = [1, 4, 6] size_order = [levels[1], levels[2], levels[0]] - p = rel._LinePlotter(x="x", y="y", size=var, data=long_df) - p.parse_size(p.plot_data["size"], sizes, size_order, None) + p.parse_size(p.plot_data["size"], sizes=sizes, order=size_order) assert p.sizes == dict(zip(size_order, sizes)) # Test list of sizes - var = "a" levels = categorical_order(long_df[var]) sizes = list(np.random.rand(len(levels))) - p = rel._LinePlotter(x="x", y="y", size=var, data=long_df) - p.parse_size(p.plot_data["size"], sizes, None, None) + p.parse_size(p.plot_data["size"], sizes=sizes) assert p.sizes == dict(zip(levels, sizes)) # Test dict of sizes - var = "a" - levels = categorical_order(long_df[var]) sizes = dict(zip(levels, np.random.rand(len(levels)))) - p = rel._LinePlotter(x="x", y="y", size=var, data=long_df) - p.parse_size(p.plot_data["size"], sizes, None, None) + p.parse_size(p.plot_data["size"], sizes=sizes) assert p.sizes == sizes # Test sizes list with wrong length sizes = list(np.random.rand(len(levels) + 1)) with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], sizes, None, None) + p.parse_size(p.plot_data["size"], sizes=sizes) # Test sizes dict with missing levels sizes = dict(zip(levels, np.random.rand(len(levels) - 1))) with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], sizes, None, None) + p.parse_size(p.plot_data["size"], sizes=sizes) # Test bad sizes argument sizes = "bad_size" with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], sizes, None, None) + p.parse_size(p.plot_data["size"], sizes=sizes) # Test bad norm argument size_norm = "not a norm" - p = rel._LinePlotter(x="x", y="y", size="s", data=long_df) with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], None, None, size_norm) + p.parse_size(p.plot_data["size"], norm=size_norm) def test_parse_style(self, long_df): - p = rel._LinePlotter(x="x", y="y", style="a", data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, 
x="x", y="y", style="a") # Test defaults markers, dashes = True, True - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) assert p.markers == dict(zip(p.style_levels, p.default_markers)) assert p.dashes == dict(zip(p.style_levels, p.default_dashes)) # Test lists markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)] - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) assert p.markers == dict(zip(p.style_levels, markers)) assert p.dashes == dict(zip(p.style_levels, dashes)) # Test dicts markers = dict(zip(p.style_levels, markers)) dashes = dict(zip(p.style_levels, dashes)) - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) assert p.markers == markers assert p.dashes == dashes @@ -710,29 +886,33 @@ def test_parse_style(self, long_df): # Test too many levels with style lists markers, dashes = ["o", "s"], False with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) markers, dashes = False, [(2, 1)] with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) # Test too many levels with style dicts markers, dashes = {"a": "o", "b": "s"}, False with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) markers, dashes = False, {"a": (1, 0), "b": (2, 1)} with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) # Test mixture of filled and unfilled markers markers, dashes = ["o", "x", "s"], None with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes, None) + p.parse_style(p.plot_data["style"], markers, dashes) def test_subset_data_quantities(self, long_df): - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["size"]) assert len(list(p.subset_data())) == 1 # -- @@ -740,23 +920,26 @@ def test_subset_data_quantities(self, long_df): var = "a" n_subsets = len(long_df[var].unique()) - p = rel._LinePlotter(x="x", y="y", hue=var, data=long_df) - assert len(list(p.subset_data())) == n_subsets - - p = rel._LinePlotter(x="x", y="y", style=var, data=long_df) - assert len(list(p.subset_data())) == n_subsets + for semantic in ["hue", "size", "style"]: - n_subsets = len(long_df[var].unique()) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", **{semantic: var}) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) - p = rel._LinePlotter(x="x", y="y", size=var, data=long_df) - assert len(list(p.subset_data())) == n_subsets + assert len(list(p.subset_data())) == n_subsets # -- var = "a" n_subsets = len(long_df[var].unique()) - p = rel._LinePlotter(x="x", y="y", hue=var, style=var, data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue=var, style=var) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) assert len(list(p.subset_data())) == n_subsets # -- @@ -764,12 +947,20 @@ def 
test_subset_data_quantities(self, long_df): var1, var2 = "a", "s" n_subsets = len(set(list(map(tuple, long_df[[var1, var2]].values)))) - p = rel._LinePlotter(x="x", y="y", hue=var1, style=var2, - data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue=var1, style=var2) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) assert len(list(p.subset_data())) == n_subsets - p = rel._LinePlotter(x="x", y="y", hue=var1, size=var2, style=var1, - data=long_df) + p = _RelationalPlotter() + p.establish_variables( + long_df, x="x", y="y", hue=var1, size=var2, style=var1, + ) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) assert len(list(p.subset_data())) == n_subsets # -- @@ -778,13 +969,22 @@ def test_subset_data_quantities(self, long_df): cols = [var1, var2, var3] n_subsets = len(set(list(map(tuple, long_df[cols].values)))) - p = rel._LinePlotter(x="x", y="y", hue=var1, size=var2, style=var3, - data=long_df) + p = _RelationalPlotter() + p.establish_variables( + long_df, x="x", y="y", hue=var1, size=var2, style=var3, + ) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) assert len(list(p.subset_data())) == n_subsets def test_subset_data_keys(self, long_df): - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) for (hue, size, style), _ in p.subset_data(): assert hue is None assert size is None @@ -794,35 +994,55 @@ def test_subset_data_keys(self, long_df): var = "a" - p = rel._LinePlotter(x="x", y="y", hue=var, data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue=var) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) for (hue, size, style), _ in p.subset_data(): assert hue in long_df[var].values assert size is None assert style is None - p = rel._LinePlotter(x="x", y="y", style=var, data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", size=var) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + for (hue, size, style), _ in p.subset_data(): + assert hue is None + assert size in long_df[var].values + assert style is None + + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", style=var) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) for (hue, size, style), _ in p.subset_data(): assert hue is None assert size is None assert style in long_df[var].values - p = rel._LinePlotter(x="x", y="y", hue=var, style=var, data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue=var, style=var) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) for (hue, size, style), _ in p.subset_data(): assert hue in long_df[var].values assert size is None assert style in long_df[var].values - p = rel._LinePlotter(x="x", y="y", size=var, data=long_df) - for (hue, size, style), _ in p.subset_data(): - assert hue is None - assert size in long_df[var].values - assert style is None - # -- var1, var2 = "a", "s" - p = rel._LinePlotter(x="x", y="y", hue=var1, size=var2, data=long_df) + p = 
_RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue=var1, size=var2) + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) for (hue, size, style), _ in p.subset_data(): assert hue in long_df[var1].values assert size in long_df[var2].values @@ -830,50 +1050,260 @@ def test_subset_data_keys(self, long_df): def test_subset_data_values(self, long_df): - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + p.sort = True _, data = next(p.subset_data()) expected = p.plot_data.loc[:, ["x", "y"]].sort_values(["x", "y"]) - assert np.array_equal(data.values, expected) + assert_array_equal(data.values, expected) - p = rel._LinePlotter(x="x", y="y", data=long_df, sort=False) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + p.sort = False _, data = next(p.subset_data()) expected = p.plot_data.loc[:, ["x", "y"]] - assert np.array_equal(data.values, expected) - - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df) + assert_array_equal(data.values, expected) + + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue="a") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + p.sort = True for (hue, _, _), data in p.subset_data(): rows = p.plot_data["hue"] == hue cols = ["x", "y"] expected = p.plot_data.loc[rows, cols].sort_values(cols) - assert np.array_equal(data.values, expected.values) + assert_array_equal(data.values, expected.values) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, sort=False) + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue="a") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + p.sort = False for (hue, _, _), data in p.subset_data(): rows = p.plot_data["hue"] == hue cols = ["x", "y"] expected = p.plot_data.loc[rows, cols] - assert np.array_equal(data.values, expected.values) - - p = rel._LinePlotter(x="x", y="y", hue="a", style="a", data=long_df) - for (hue, _, _), data in p.subset_data(): - rows = p.plot_data["hue"] == hue + assert_array_equal(data.values, expected.values) + + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", style="a") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + p.sort = True + for (_, _, style), data in p.subset_data(): + rows = p.plot_data["style"] == style cols = ["x", "y"] expected = p.plot_data.loc[rows, cols].sort_values(cols) - assert np.array_equal(data.values, expected.values) - - p = rel._LinePlotter(x="x", y="y", hue="a", size="s", data=long_df) + assert_array_equal(data.values, expected.values) + + p = _RelationalPlotter() + p.establish_variables(long_df, x="x", y="y", hue="a", size="s") + p.parse_hue(p.plot_data["hue"]) + p.parse_size(p.plot_data["size"]) + p.parse_style(p.plot_data["style"]) + p.sort = True for (hue, size, _), data in p.subset_data(): rows = (p.plot_data["hue"] == hue) & (p.plot_data["size"] == size) cols = ["x", "y"] expected = p.plot_data.loc[rows, cols].sort_values(cols) - assert np.array_equal(data.values, expected.values) + assert_array_equal(data.values, expected.values) + + def 
test_relplot_simple(self, long_df): + + g = relplot(x="x", y="y", kind="scatter", data=long_df) + x, y = g.ax.collections[0].get_offsets().T + assert_array_equal(x, long_df["x"]) + assert_array_equal(y, long_df["y"]) + + g = relplot(x="x", y="y", kind="line", data=long_df) + x, y = g.ax.lines[0].get_xydata().T + expected = long_df.groupby("x").y.mean() + assert_array_equal(x, expected.index) + assert y == pytest.approx(expected.values) + + with pytest.raises(ValueError): + g = relplot(x="x", y="y", kind="not_a_kind", data=long_df) + + def test_relplot_complex(self, long_df): + + for sem in ["hue", "size", "style"]: + g = relplot(x="x", y="y", data=long_df, **{sem: "a"}) + x, y = g.ax.collections[0].get_offsets().T + assert_array_equal(x, long_df["x"]) + assert_array_equal(y, long_df["y"]) + + for sem in ["hue", "size", "style"]: + g = relplot( + x="x", y="y", col="c", data=long_df, **{sem: "a"} + ) + grouped = long_df.groupby("c") + for (_, grp_df), ax in zip(grouped, g.axes.flat): + x, y = ax.collections[0].get_offsets().T + assert_array_equal(x, grp_df["x"]) + assert_array_equal(y, grp_df["y"]) + + for sem in ["size", "style"]: + g = relplot( + x="x", y="y", hue="b", col="c", data=long_df, **{sem: "a"} + ) + grouped = long_df.groupby("c") + for (_, grp_df), ax in zip(grouped, g.axes.flat): + x, y = ax.collections[0].get_offsets().T + assert_array_equal(x, grp_df["x"]) + assert_array_equal(y, grp_df["y"]) + for sem in ["hue", "size", "style"]: + g = relplot( + x="x", y="y", col="b", row="c", + data=long_df.sort_values(["c", "b"]), **{sem: "a"} + ) + grouped = long_df.groupby(["c", "b"]) + for (_, grp_df), ax in zip(grouped, g.axes.flat): + x, y = ax.collections[0].get_offsets().T + assert_array_equal(x, grp_df["x"]) + assert_array_equal(y, grp_df["y"]) + + @pytest.mark.parametrize( + "vector_type", + ["series", "numpy", "list"], + ) + def test_relplot_vectors(self, long_df, vector_type): + + semantics = dict(x="x", y="y", hue="f", col="c") + kws = {key: long_df[val] for key, val in semantics.items()} + g = relplot(data=long_df, **kws) + grouped = long_df.groupby("c") + for (_, grp_df), ax in zip(grouped, g.axes.flat): + x, y = ax.collections[0].get_offsets().T + assert_array_equal(x, grp_df["x"]) + assert_array_equal(y, grp_df["y"]) -class TestLinePlotter(TestRelationalPlotter): + def test_relplot_wide(self, wide_df): + + g = relplot(data=wide_df) + x, y = g.ax.collections[0].get_offsets().T + assert_array_equal(y, wide_df.values.T.ravel()) + + def test_relplot_hues(self, long_df): + + palette = ["r", "b", "g"] + g = relplot( + x="x", y="y", hue="a", style="b", col="c", + palette=palette, data=long_df + ) + + palette = dict(zip(long_df["a"].unique(), palette)) + grouped = long_df.groupby("c") + for (_, grp_df), ax in zip(grouped, g.axes.flat): + points = ax.collections[0] + expected_hues = [palette[val] for val in grp_df["a"]] + assert self.colors_equal(points.get_facecolors(), expected_hues) + + def test_relplot_sizes(self, long_df): + + sizes = [5, 12, 7] + g = relplot( + x="x", y="y", size="a", hue="b", col="c", + sizes=sizes, data=long_df + ) + + sizes = dict(zip(long_df["a"].unique(), sizes)) + grouped = long_df.groupby("c") + for (_, grp_df), ax in zip(grouped, g.axes.flat): + points = ax.collections[0] + expected_sizes = [sizes[val] for val in grp_df["a"]] + assert_array_equal(points.get_sizes(), expected_sizes) + + def test_relplot_styles(self, long_df): + + markers = ["o", "d", "s"] + g = relplot( + x="x", y="y", style="a", hue="b", col="c", + markers=markers, data=long_df + ) + 
+ paths = [] + for m in markers: + m = mpl.markers.MarkerStyle(m) + paths.append(m.get_path().transformed(m.get_transform())) + paths = dict(zip(long_df["a"].unique(), paths)) + + grouped = long_df.groupby("c") + for (_, grp_df), ax in zip(grouped, g.axes.flat): + points = ax.collections[0] + expected_paths = [paths[val] for val in grp_df["a"]] + assert self.paths_equal(points.get_paths(), expected_paths) + + def test_relplot_stringy_numerics(self, long_df): + + long_df["x_str"] = long_df["x"].astype(str) + + g = relplot(x="x", y="y", hue="x_str", data=long_df) + points = g.ax.collections[0] + xys = points.get_offsets() + mask = np.ma.getmask(xys) + assert not mask.any() + assert_array_equal(xys, long_df[["x", "y"]]) + + g = relplot(x="x", y="y", size="x_str", data=long_df) + points = g.ax.collections[0] + xys = points.get_offsets() + mask = np.ma.getmask(xys) + assert not mask.any() + assert_array_equal(xys, long_df[["x", "y"]]) + + def test_relplot_legend(self, long_df): + + g = relplot(x="x", y="y", data=long_df) + assert g._legend is None + + g = relplot(x="x", y="y", hue="a", data=long_df) + texts = [t.get_text() for t in g._legend.texts] + expected_texts = np.append(["a"], long_df["a"].unique()) + assert_array_equal(texts, expected_texts) + + g = relplot(x="x", y="y", hue="s", size="s", data=long_df) + texts = [t.get_text() for t in g._legend.texts] + assert_array_equal(texts[1:], np.sort(texts[1:])) + + g = relplot(x="x", y="y", hue="a", legend=False, data=long_df) + assert g._legend is None + + palette = color_palette("deep", len(long_df["b"].unique())) + a_like_b = dict(zip(long_df["a"].unique(), long_df["b"].unique())) + long_df["a_like_b"] = long_df["a"].map(a_like_b) + g = relplot( + x="x", y="y", hue="b", style="a_like_b", + palette=palette, kind="line", estimator=None, data=long_df + ) + lines = g._legend.get_lines()[1:] # Chop off title dummy + for line, color in zip(lines, palette): + assert line.get_color() == color + + def test_ax_kwarg_removal(self, long_df): + + f, ax = plt.subplots() + with pytest.warns(UserWarning): + g = relplot("x", "y", data=long_df, ax=ax) + assert len(ax.collections) == 0 + assert len(g.ax.collections) > 0 + + +class TestLinePlotter(Helpers): def test_aggregate(self, long_df): - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter(x="x", y="y", data=long_df) p.n_boot = 10000 p.sort = False @@ -892,7 +1322,7 @@ def sem(x): p.ci = 68 p.estimator = "mean" index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x.unique()) + assert_array_equal(index.values, x.unique()) assert est.index.equals(index) assert est.values == pytest.approx(y_mean.values) assert cis.values == pytest.approx(y_cis.values, 4) @@ -900,7 +1330,7 @@ def sem(x): p.estimator = np.mean index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x.unique()) + assert_array_equal(index.values, x.unique()) assert est.index.equals(index) assert est.values == pytest.approx(y_mean.values) assert cis.values == pytest.approx(y_cis.values, 4) @@ -909,7 +1339,7 @@ def sem(x): p.seed = 0 _, _, ci1 = p.aggregate(y, x) _, _, ci2 = p.aggregate(y, x) - assert np.array_equal(ci1, ci2) + assert_array_equal(ci1, ci2) y_std = y.groupby(x).std() y_cis = pd.DataFrame(dict(low=y_mean - y_std, @@ -918,7 +1348,7 @@ def sem(x): p.ci = "sd" index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x.unique()) + assert_array_equal(index.values, x.unique()) assert est.index.equals(index) assert est.values == pytest.approx(y_mean.values) assert 
cis.values == pytest.approx(y_cis.values) @@ -931,15 +1361,15 @@ def sem(x): p.ci = 68 x, y = pd.Series([1, 2, 3]), pd.Series([4, 3, 2]) index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x) - assert np.array_equal(est.values, y) + assert_array_equal(index.values, x) + assert_array_equal(est.values, y) assert cis is None x, y = pd.Series([1, 1, 2]), pd.Series([2, 3, 4]) index, est, cis = p.aggregate(y, x) assert cis.loc[2].isnull().all() - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter(x="x", y="y", data=long_df) p.estimator = "mean" p.n_boot = 100 p.ci = 95 @@ -954,7 +1384,7 @@ def test_legend_data(self, long_df): f, ax = plt.subplots() - p = rel._LinePlotter(x="x", y="y", data=long_df, legend="full") + p = _LinePlotter(x="x", y="y", data=long_df, legend="full") p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert handles == [] @@ -962,8 +1392,9 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - legend="full") + p = _LinePlotter( + x="x", y="y", hue="a", data=long_df, legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] @@ -973,8 +1404,10 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", style="a", - markers=True, legend="full", data=long_df) + p = _LinePlotter( + x="x", y="y", hue="a", style="a", + markers=True, legend="full", data=long_df + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] @@ -986,8 +1419,10 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", style="b", - markers=True, legend="full", data=long_df) + p = _LinePlotter( + x="x", y="y", hue="a", style="b", + markers=True, legend="full", data=long_df + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] @@ -1003,8 +1438,9 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", size="a", data=long_df, - legend="full") + p = _LinePlotter( + x="x", y="y", hue="a", size="a", data=long_df, legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] @@ -1018,7 +1454,7 @@ def test_legend_data(self, long_df): x, y = np.random.randn(2, 40) z = np.tile(np.arange(20), 2) - p = rel._LinePlotter(x=x, y=y, hue=z) + p = _LinePlotter(x=x, y=y, hue=z) ax.clear() p.legend = "full" @@ -1032,7 +1468,7 @@ def test_legend_data(self, long_df): handles, labels = ax.get_legend_handles_labels() assert len(labels) == 4 - p = rel._LinePlotter(x=x, y=y, size=z) + p = _LinePlotter(x=x, y=y, size=z) ax.clear() p.legend = "full" @@ -1052,32 +1488,36 @@ def test_legend_data(self, long_df): p.add_legend_data(ax) ax.clear() - p = rel._LinePlotter(x=x, y=y, hue=z, - hue_norm=mpl.colors.LogNorm(), - legend="brief") + p = _LinePlotter( + x=x, y=y, hue=z, + hue_norm=mpl.colors.LogNorm(), legend="brief" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert float(labels[2]) / float(labels[1]) == 10 ax.clear() - p = rel._LinePlotter(x=x, y=y, size=z, - size_norm=mpl.colors.LogNorm(), - legend="brief") + p = _LinePlotter( + x=x, y=y, size=z, + size_norm=mpl.colors.LogNorm(), legend="brief" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert float(labels[2]) / 
float(labels[1]) == 10 ax.clear() - p = rel._LinePlotter( - x="x", y="y", hue="f", legend="brief", data=long_df) + p = _LinePlotter( + x="x", y="y", hue="f", legend="brief", data=long_df + ) p.add_legend_data(ax) expected_levels = ['0.20', '0.24', '0.28', '0.32'] handles, labels = ax.get_legend_handles_labels() assert labels == ["f"] + expected_levels ax.clear() - p = rel._LinePlotter( - x="x", y="y", size="f", legend="brief", data=long_df) + p = _LinePlotter( + x="x", y="y", size="f", legend="brief", data=long_df + ) p.add_legend_data(ax) expected_levels = ['0.20', '0.24', '0.28', '0.32'] handles, labels = ax.get_legend_handles_labels() @@ -1087,12 +1527,13 @@ def test_plot(self, long_df, repeated_df): f, ax = plt.subplots() - p = rel._LinePlotter(x="x", y="y", data=long_df, - sort=False, estimator=None) + p = _LinePlotter( + x="x", y="y", data=long_df, sort=False, estimator=None + ) p.plot(ax, {}) line, = ax.lines - assert np.array_equal(line.get_xdata(), long_df.x.values) - assert np.array_equal(line.get_ydata(), long_df.y.values) + assert_array_equal(line.get_xdata(), long_df.x.values) + assert_array_equal(line.get_ydata(), long_df.y.values) ax.clear() p.plot(ax, {"color": "k", "label": "test"}) @@ -1100,17 +1541,18 @@ def test_plot(self, long_df, repeated_df): assert line.get_color() == "k" assert line.get_label() == "test" - p = rel._LinePlotter(x="x", y="y", data=long_df, - sort=True, estimator=None) + p = _LinePlotter( + x="x", y="y", data=long_df, sort=True, estimator=None + ) ax.clear() p.plot(ax, {}) line, = ax.lines sorted_data = long_df.sort_values(["x", "y"]) - assert np.array_equal(line.get_xdata(), sorted_data.x.values) - assert np.array_equal(line.get_ydata(), sorted_data.y.values) + assert_array_equal(line.get_xdata(), sorted_data.x.values) + assert_array_equal(line.get_ydata(), sorted_data.y.values) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df) + p = _LinePlotter(x="x", y="y", hue="a", data=long_df) ax.clear() p.plot(ax, {}) @@ -1118,7 +1560,7 @@ def test_plot(self, long_df, repeated_df): for line, level in zip(ax.lines, p.hue_levels): assert line.get_color() == p.palette[level] - p = rel._LinePlotter(x="x", y="y", size="a", data=long_df) + p = _LinePlotter(x="x", y="y", size="a", data=long_df) ax.clear() p.plot(ax, {}) @@ -1126,8 +1568,9 @@ def test_plot(self, long_df, repeated_df): for line, level in zip(ax.lines, p.size_levels): assert line.get_linewidth() == p.sizes[level] - p = rel._LinePlotter(x="x", y="y", hue="a", style="a", - markers=True, data=long_df) + p = _LinePlotter( + x="x", y="y", hue="a", style="a", markers=True, data=long_df + ) ax.clear() p.plot(ax, {}) @@ -1136,8 +1579,9 @@ def test_plot(self, long_df, repeated_df): assert line.get_color() == p.palette[level] assert line.get_marker() == p.markers[level] - p = rel._LinePlotter(x="x", y="y", hue="a", style="b", - markers=True, data=long_df) + p = _LinePlotter( + x="x", y="y", hue="a", style="b", markers=True, data=long_df + ) ax.clear() p.plot(ax, {}) @@ -1147,20 +1591,23 @@ def test_plot(self, long_df, repeated_df): assert line.get_color() == p.palette[hue] assert line.get_marker() == p.markers[style] - p = rel._LinePlotter(x="x", y="y", data=long_df, - estimator="mean", err_style="band", ci="sd", - sort=True) + p = _LinePlotter( + x="x", y="y", data=long_df, + estimator="mean", err_style="band", ci="sd", sort=True + ) ax.clear() p.plot(ax, {}) line, = ax.lines expected_data = long_df.groupby("x").y.mean() - assert np.array_equal(line.get_xdata(), expected_data.index.values) + 
assert_array_equal(line.get_xdata(), expected_data.index.values) assert np.allclose(line.get_ydata(), expected_data.values) assert len(ax.collections) == 1 - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - estimator="mean", err_style="band", ci="sd") + p = _LinePlotter( + x="x", y="y", hue="a", data=long_df, + estimator="mean", err_style="band", ci="sd" + ) ax.clear() p.plot(ax, {}) @@ -1168,8 +1615,10 @@ def test_plot(self, long_df, repeated_df): for c in ax.collections: assert isinstance(c, mpl.collections.PolyCollection) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - estimator="mean", err_style="bars", ci="sd") + p = _LinePlotter( + x="x", y="y", hue="a", data=long_df, + estimator="mean", err_style="bars", ci="sd" + ) ax.clear() p.plot(ax, {}) @@ -1179,16 +1628,18 @@ def test_plot(self, long_df, repeated_df): for c in ax.collections: assert isinstance(c, mpl.collections.LineCollection) - p = rel._LinePlotter(x="x", y="y", data=repeated_df, - units="u", estimator=None) + p = _LinePlotter( + x="x", y="y", data=repeated_df, units="u", estimator=None + ) ax.clear() p.plot(ax, {}) n_units = len(repeated_df["u"].unique()) assert len(ax.lines) == n_units - p = rel._LinePlotter(x="x", y="y", hue="a", data=repeated_df, - units="u", estimator=None) + p = _LinePlotter( + x="x", y="y", hue="a", data=repeated_df, units="u", estimator=None + ) ax.clear() p.plot(ax, {}) @@ -1199,16 +1650,20 @@ def test_plot(self, long_df, repeated_df): with pytest.raises(ValueError): p.plot(ax, {}) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - err_style="band", err_kws={"alpha": .5}) + p = _LinePlotter( + x="x", y="y", hue="a", data=long_df, + err_style="band", err_kws={"alpha": .5} + ) ax.clear() p.plot(ax, {}) for band in ax.collections: assert band.get_alpha() == .5 - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - err_style="bars", err_kws={"elinewidth": 2}) + p = _LinePlotter( + x="x", y="y", hue="a", data=long_df, + err_style="bars", err_kws={"elinewidth": 2} + ) ax.clear() p.plot(ax, {}) @@ -1220,11 +1675,11 @@ def test_plot(self, long_df, repeated_df): p.plot(ax, {}) x_str = long_df["x"].astype(str) - p = rel._LinePlotter(x="x", y="y", hue=x_str, data=long_df) + p = _LinePlotter(x="x", y="y", hue=x_str, data=long_df) ax.clear() p.plot(ax, {}) - p = rel._LinePlotter(x="x", y="y", size=x_str, data=long_df) + p = _LinePlotter(x="x", y="y", size=x_str, data=long_df) ax.clear() p.plot(ax, {}) @@ -1232,7 +1687,7 @@ def test_axis_labels(self, long_df): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter(x="x", y="y", data=long_df) p.plot(ax1, {}) assert ax1.get_xlabel() == "x" @@ -1248,80 +1703,90 @@ def test_lineplot_axes(self, wide_df): f1, ax1 = plt.subplots() f2, ax2 = plt.subplots() - ax = rel.lineplot(data=wide_df) + ax = lineplot(data=wide_df) assert ax is ax2 - ax = rel.lineplot(data=wide_df, ax=ax1) + ax = lineplot(data=wide_df, ax=ax1) assert ax is ax1 - def test_lineplot_smoke(self, flat_array, flat_series, - wide_array, wide_list, wide_list_of_series, - wide_df, long_df, missing_df): + def test_lineplot_smoke( + self, + wide_df, wide_array, + wide_list_of_series, wide_list_of_arrays, wide_list_of_lists, + flat_array, flat_series, flat_list, + long_df, missing_df + ): f, ax = plt.subplots() - rel.lineplot([], []) + lineplot([], []) + ax.clear() + + lineplot(data=wide_df) + ax.clear() + + lineplot(data=wide_array) ax.clear() - rel.lineplot(data=flat_array) + 
lineplot(data=wide_list_of_series) ax.clear() - rel.lineplot(data=flat_series) + lineplot(data=wide_list_of_arrays) ax.clear() - rel.lineplot(data=wide_array) + lineplot(data=wide_list_of_lists) ax.clear() - rel.lineplot(data=wide_list) + lineplot(data=flat_series) ax.clear() - rel.lineplot(data=wide_list_of_series) + lineplot(data=flat_array) ax.clear() - rel.lineplot(data=wide_df) + lineplot(data=flat_list) ax.clear() - rel.lineplot(x="x", y="y", data=long_df) + lineplot(x="x", y="y", data=long_df) ax.clear() - rel.lineplot(x=long_df.x, y=long_df.y) + lineplot(x=long_df.x, y=long_df.y) ax.clear() - rel.lineplot(x=long_df.x, y="y", data=long_df) + lineplot(x=long_df.x, y="y", data=long_df) ax.clear() - rel.lineplot(x="x", y=long_df.y.values, data=long_df) + lineplot(x="x", y=long_df.y.values, data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", data=long_df) + lineplot(x="x", y="y", hue="a", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="a", data=long_df) + lineplot(x="x", y="y", hue="a", style="a", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="b", data=long_df) + lineplot(x="x", y="y", hue="a", style="b", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="a", data=missing_df) + lineplot(x="x", y="y", hue="a", style="a", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="b", data=missing_df) + lineplot(x="x", y="y", hue="a", style="b", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="a", data=long_df) + lineplot(x="x", y="y", hue="a", size="a", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="s", data=long_df) + lineplot(x="x", y="y", hue="a", size="s", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="a", data=missing_df) + lineplot(x="x", y="y", hue="a", size="a", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="s", data=missing_df) + lineplot(x="x", y="y", hue="a", size="s", data=missing_df) ax.clear() -class TestScatterPlotter(TestRelationalPlotter): +class TestScatterPlotter(Helpers): def test_legend_data(self, long_df): @@ -1333,7 +1798,7 @@ def test_legend_data(self, long_df): f, ax = plt.subplots() - p = rel._ScatterPlotter(x="x", y="y", data=long_df, legend="full") + p = _ScatterPlotter(x="x", y="y", data=long_df, legend="full") p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert handles == [] @@ -1341,8 +1806,9 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", data=long_df, - legend="full") + p = _ScatterPlotter( + x="x", y="y", hue="a", data=long_df, legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] @@ -1353,8 +1819,10 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="a", - markers=True, legend="full", data=long_df) + p = _ScatterPlotter( + x="x", y="y", hue="a", style="a", + markers=True, legend="full", data=long_df + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] @@ -1368,8 +1836,10 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="b", - markers=True, legend="full", data=long_df) + p = _ScatterPlotter( + x="x", y="y", hue="a", style="b", + markers=True, legend="full", data=long_df + ) p.add_legend_data(ax) handles, labels = 
ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] @@ -1385,8 +1855,9 @@ def test_legend_data(self, long_df): # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", size="a", - data=long_df, legend="full") + p = _ScatterPlotter( + x="x", y="y", hue="a", size="a", data=long_df, legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] @@ -1401,8 +1872,10 @@ def test_legend_data(self, long_df): ax.clear() sizes_list = [10, 100, 200] - p = rel._ScatterPlotter(x="x", y="y", size="s", sizes=sizes_list, - data=long_df, legend="full") + p = _ScatterPlotter( + x="x", y="y", size="s", data=long_df, + legend="full", sizes=sizes_list, + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() sizes = [h.get_sizes()[0] for h in handles] @@ -1414,8 +1887,10 @@ def test_legend_data(self, long_df): ax.clear() sizes_dict = {2: 10, 4: 100, 8: 200} - p = rel._ScatterPlotter(x="x", y="y", size="s", sizes=sizes_dict, - data=long_df, legend="full") + p = _ScatterPlotter( + x="x", y="y", size="s", sizes=sizes_dict, + data=long_df, legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() sizes = [h.get_sizes()[0] for h in handles] @@ -1428,7 +1903,7 @@ def test_legend_data(self, long_df): x, y = np.random.randn(2, 40) z = np.tile(np.arange(20), 2) - p = rel._ScatterPlotter(x=x, y=y, hue=z) + p = _ScatterPlotter(x=x, y=y, hue=z) ax.clear() p.legend = "full" @@ -1442,7 +1917,7 @@ def test_legend_data(self, long_df): handles, labels = ax.get_legend_handles_labels() assert len(labels) == 4 - p = rel._ScatterPlotter(x=x, y=y, size=z) + p = _ScatterPlotter(x=x, y=y, size=z) ax.clear() p.legend = "full" @@ -1465,11 +1940,11 @@ def test_plot(self, long_df, repeated_df): f, ax = plt.subplots() - p = rel._ScatterPlotter(x="x", y="y", data=long_df) + p = _ScatterPlotter(x="x", y="y", data=long_df) p.plot(ax, {}) points = ax.collections[0] - assert np.array_equal(points.get_offsets(), long_df[["x", "y"]].values) + assert_array_equal(points.get_offsets(), long_df[["x", "y"]].values) ax.clear() p.plot(ax, {"color": "k", "label": "test"}) @@ -1477,7 +1952,7 @@ def test_plot(self, long_df, repeated_df): assert self.colors_equal(points.get_facecolor(), "k") assert points.get_label() == "test" - p = rel._ScatterPlotter(x="x", y="y", hue="a", data=long_df) + p = _ScatterPlotter(x="x", y="y", hue="a", data=long_df) ax.clear() p.plot(ax, {}) @@ -1485,8 +1960,9 @@ def test_plot(self, long_df, repeated_df): expected_colors = [p.palette[k] for k in p.plot_data["hue"]] assert self.colors_equal(points.get_facecolors(), expected_colors) - p = rel._ScatterPlotter(x="x", y="y", style="c", - markers=["+", "x"], data=long_df) + p = _ScatterPlotter( + x="x", y="y", style="c", markers=["+", "x"], data=long_df + ) ax.clear() color = (1, .3, .8) @@ -1494,16 +1970,17 @@ def test_plot(self, long_df, repeated_df): points = ax.collections[0] assert self.colors_equal(points.get_edgecolors(), [color]) - p = rel._ScatterPlotter(x="x", y="y", size="a", data=long_df) + p = _ScatterPlotter(x="x", y="y", size="a", data=long_df) ax.clear() p.plot(ax, {}) points = ax.collections[0] expected_sizes = [p.size_lookup(k) for k in p.plot_data["size"]] - assert np.array_equal(points.get_sizes(), expected_sizes) + assert_array_equal(points.get_sizes(), expected_sizes) - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="a", - markers=True, data=long_df) + p = _ScatterPlotter( + x="x", y="y", 
hue="a", style="a", markers=True, data=long_df + ) ax.clear() p.plot(ax, {}) @@ -1512,8 +1989,9 @@ def test_plot(self, long_df, repeated_df): assert self.colors_equal(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="b", - markers=True, data=long_df) + p = _ScatterPlotter( + x="x", y="y", hue="a", style="b", markers=True, data=long_df + ) ax.clear() p.plot(ax, {}) @@ -1523,11 +2001,11 @@ def test_plot(self, long_df, repeated_df): assert self.paths_equal(points.get_paths(), expected_paths) x_str = long_df["x"].astype(str) - p = rel._ScatterPlotter(x="x", y="y", hue=x_str, data=long_df) + p = _ScatterPlotter(x="x", y="y", hue=x_str, data=long_df) ax.clear() p.plot(ax, {}) - p = rel._ScatterPlotter(x="x", y="y", size=x_str, data=long_df) + p = _ScatterPlotter(x="x", y="y", size=x_str, data=long_df) ax.clear() p.plot(ax, {}) @@ -1535,7 +2013,7 @@ def test_axis_labels(self, long_df): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) - p = rel._ScatterPlotter(x="x", y="y", data=long_df) + p = _ScatterPlotter(x="x", y="y", data=long_df) p.plot(ax1, {}) assert ax1.get_xlabel() == "x" @@ -1551,226 +2029,84 @@ def test_scatterplot_axes(self, wide_df): f1, ax1 = plt.subplots() f2, ax2 = plt.subplots() - ax = rel.scatterplot(data=wide_df) + ax = scatterplot(data=wide_df) assert ax is ax2 - ax = rel.scatterplot(data=wide_df, ax=ax1) + ax = scatterplot(data=wide_df, ax=ax1) assert ax is ax1 - def test_scatterplot_smoke(self, flat_array, flat_series, - wide_array, wide_list, wide_list_of_series, - wide_df, long_df, missing_df): + def test_scatterplot_smoke( + self, + wide_df, wide_array, + flat_series, flat_array, flat_list, + wide_list_of_series, wide_list_of_arrays, wide_list_of_lists, + long_df, missing_df + ): f, ax = plt.subplots() - rel.scatterplot([], []) + scatterplot([], []) ax.clear() - rel.scatterplot(data=flat_array) + scatterplot(data=wide_df) ax.clear() - rel.scatterplot(data=flat_series) + scatterplot(data=wide_array) ax.clear() - rel.scatterplot(data=wide_array) + scatterplot(data=wide_list_of_series) ax.clear() - rel.scatterplot(data=wide_list) + scatterplot(data=wide_list_of_arrays) ax.clear() - rel.scatterplot(data=wide_list_of_series) + scatterplot(data=wide_list_of_lists) ax.clear() - rel.scatterplot(data=wide_df) + scatterplot(data=flat_series) ax.clear() - rel.scatterplot(x="x", y="y", data=long_df) + scatterplot(data=flat_array) ax.clear() - rel.scatterplot(x=long_df.x, y=long_df.y) + scatterplot(data=flat_list) ax.clear() - rel.scatterplot(x=long_df.x, y="y", data=long_df) + scatterplot(x="x", y="y", data=long_df) ax.clear() - rel.scatterplot(x="x", y=long_df.y.values, data=long_df) + scatterplot(x=long_df.x, y=long_df.y) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", data=long_df) + scatterplot(x=long_df.x, y="y", data=long_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", style="a", data=long_df) + scatterplot(x="x", y=long_df.y.values, data=long_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", style="b", data=long_df) + scatterplot(x="x", y="y", hue="a", data=long_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", style="a", data=missing_df) + scatterplot(x="x", y="y", hue="a", style="a", data=long_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", style="b", data=missing_df) + scatterplot(x="x", y="y", hue="a", style="b", data=long_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", size="a", data=long_df) + scatterplot(x="x", y="y", 
hue="a", style="a", data=missing_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", size="s", data=long_df) + scatterplot(x="x", y="y", hue="a", style="b", data=missing_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", size="a", data=missing_df) + scatterplot(x="x", y="y", hue="a", size="a", data=long_df) ax.clear() - rel.scatterplot(x="x", y="y", hue="a", size="s", data=missing_df) + scatterplot(x="x", y="y", hue="a", size="s", data=long_df) ax.clear() + scatterplot(x="x", y="y", hue="a", size="a", data=missing_df) + ax.clear() -class TestRelPlotter(TestRelationalPlotter): - - def test_relplot_simple(self, long_df): - - g = rel.relplot(x="x", y="y", kind="scatter", data=long_df) - x, y = g.ax.collections[0].get_offsets().T - assert np.array_equal(x, long_df["x"]) - assert np.array_equal(y, long_df["y"]) - - g = rel.relplot(x="x", y="y", kind="line", data=long_df) - x, y = g.ax.lines[0].get_xydata().T - expected = long_df.groupby("x").y.mean() - assert np.array_equal(x, expected.index) - assert y == pytest.approx(expected.values) - - with pytest.raises(ValueError): - g = rel.relplot(x="x", y="y", kind="not_a_kind", data=long_df) - - def test_relplot_complex(self, long_df): - - for sem in ["hue", "size", "style"]: - g = rel.relplot(x="x", y="y", data=long_df, **{sem: "a"}) - x, y = g.ax.collections[0].get_offsets().T - assert np.array_equal(x, long_df["x"]) - assert np.array_equal(y, long_df["y"]) - - for sem in ["hue", "size", "style"]: - g = rel.relplot(x="x", y="y", col="c", data=long_df, - **{sem: "a"}) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - x, y = ax.collections[0].get_offsets().T - assert np.array_equal(x, grp_df["x"]) - assert np.array_equal(y, grp_df["y"]) - - for sem in ["size", "style"]: - g = rel.relplot(x="x", y="y", hue="b", col="c", data=long_df, - **{sem: "a"}) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - x, y = ax.collections[0].get_offsets().T - assert np.array_equal(x, grp_df["x"]) - assert np.array_equal(y, grp_df["y"]) - - for sem in ["hue", "size", "style"]: - g = rel.relplot(x="x", y="y", col="b", row="c", - data=long_df.sort_values(["c", "b"]), - **{sem: "a"}) - grouped = long_df.groupby(["c", "b"]) - for (_, grp_df), ax in zip(grouped, g.axes.flat): - x, y = ax.collections[0].get_offsets().T - assert np.array_equal(x, grp_df["x"]) - assert np.array_equal(y, grp_df["y"]) - - def test_relplot_hues(self, long_df): - - palette = ["r", "b", "g"] - g = rel.relplot(x="x", y="y", hue="a", style="b", col="c", - palette=palette, data=long_df) - - palette = dict(zip(long_df["a"].unique(), palette)) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - points = ax.collections[0] - expected_hues = [palette[val] for val in grp_df["a"]] - assert self.colors_equal(points.get_facecolors(), expected_hues) - - def test_relplot_sizes(self, long_df): - - sizes = [5, 12, 7] - g = rel.relplot(x="x", y="y", size="a", hue="b", col="c", - sizes=sizes, data=long_df) - - sizes = dict(zip(long_df["a"].unique(), sizes)) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - points = ax.collections[0] - expected_sizes = [sizes[val] for val in grp_df["a"]] - assert np.array_equal(points.get_sizes(), expected_sizes) - - def test_relplot_styles(self, long_df): - - markers = ["o", "d", "s"] - g = rel.relplot(x="x", y="y", style="a", hue="b", col="c", - markers=markers, data=long_df) - - paths = [] - for m in markers: - m = 
mpl.markers.MarkerStyle(m) - paths.append(m.get_path().transformed(m.get_transform())) - paths = dict(zip(long_df["a"].unique(), paths)) - - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - points = ax.collections[0] - expected_paths = [paths[val] for val in grp_df["a"]] - assert self.paths_equal(points.get_paths(), expected_paths) - - def test_relplot_stringy_numerics(self, long_df): - - long_df["x_str"] = long_df["x"].astype(str) - - g = rel.relplot(x="x", y="y", hue="x_str", data=long_df) - points = g.ax.collections[0] - xys = points.get_offsets() - mask = np.ma.getmask(xys) - assert not mask.any() - assert np.array_equal(xys, long_df[["x", "y"]]) - - g = rel.relplot(x="x", y="y", size="x_str", data=long_df) - points = g.ax.collections[0] - xys = points.get_offsets() - mask = np.ma.getmask(xys) - assert not mask.any() - assert np.array_equal(xys, long_df[["x", "y"]]) - - def test_relplot_legend(self, long_df): - - g = rel.relplot(x="x", y="y", data=long_df) - assert g._legend is None - - g = rel.relplot(x="x", y="y", hue="a", data=long_df) - texts = [t.get_text() for t in g._legend.texts] - expected_texts = np.append(["a"], long_df["a"].unique()) - assert np.array_equal(texts, expected_texts) - - g = rel.relplot(x="x", y="y", hue="s", size="s", data=long_df) - texts = [t.get_text() for t in g._legend.texts] - assert np.array_equal(texts[1:], np.sort(texts[1:])) - - g = rel.relplot(x="x", y="y", hue="a", legend=False, data=long_df) - assert g._legend is None - - palette = color_palette("deep", len(long_df["b"].unique())) - a_like_b = dict(zip(long_df["a"].unique(), long_df["b"].unique())) - long_df["a_like_b"] = long_df["a"].map(a_like_b) - g = rel.relplot(x="x", y="y", hue="b", style="a_like_b", - palette=palette, kind="line", estimator=None, - data=long_df) - lines = g._legend.get_lines()[1:] # Chop off title dummy - for line, color in zip(lines, palette): - assert line.get_color() == color - - def test_ax_kwarg_removal(self, long_df): - - f, ax = plt.subplots() - with pytest.warns(UserWarning): - g = rel.relplot("x", "y", data=long_df, ax=ax) - assert len(ax.collections) == 0 - assert len(g.ax.collections) > 0 + scatterplot(x="x", y="y", hue="a", size="s", data=missing_df) + ax.clear()
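The refactored tests above all exercise the same two-step calling convention: establish the plot variables first, then parse each semantic from the resulting plot_data column and iterate over subsets. The snippet below is a minimal illustrative sketch of that workflow, not part of the patch; the import path and the toy DataFrame are assumptions made for the example, while the method names and keyword arguments mirror those used in the tests.

import numpy as np
import pandas as pd
from seaborn.relational import _RelationalPlotter  # assumed import path for this sketch

# A small long-form frame standing in for the long_df fixture used by the tests.
df = pd.DataFrame({
    "x": np.arange(20),
    "y": np.random.normal(size=20),
    "a": np.repeat(list("ab"), 10),
    "s": np.tile([2, 4], 10),
})

p = _RelationalPlotter()

# Step 1: establish the variables, producing the intermediate long-form
# plot_data frame with one column per semantic ("x", "y", "hue", "size", ...).
p.establish_variables(df, x="x", y="y", hue="a", size="s", style="a")

# Step 2: parse each semantic from its plot_data column; palette, sizes, order,
# and norm are now passed as keyword arguments rather than positionally.
p.parse_hue(p.plot_data["hue"], palette="deep")
p.parse_size(p.plot_data["size"], sizes=(1, 5))
p.parse_style(p.plot_data["style"], True, True)

# subset_data() yields one ((hue, size, style), data) pair per level
# combination, which the concrete plotters consume in their plot() methods.
p.sort = True  # set explicitly, as in test_subset_data_values
for (hue, size, style), subset in p.subset_data():
    print(hue, size, style, len(subset))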