Skip to content

Commit

Permalink
Dont add null columns in plot_data for unassigned semantics (#2148)
Browse files Browse the repository at this point in the history
* Dont assign null columns in plot_data for unassigned semantics

* Fix comment wording

[ci skip]
  • Loading branch information
mwaskom committed Jun 28, 2020
1 parent d1bbef1 commit a4a9542
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 60 deletions.
41 changes: 26 additions & 15 deletions seaborn/_core.py
Expand Up @@ -90,7 +90,7 @@ def __init__(
"""
super().__init__(plotter)

data = plotter.plot_data["hue"]
data = plotter.plot_data.get("hue", pd.Series(dtype=float))

if data.notna().any():

Expand Down Expand Up @@ -273,7 +273,7 @@ def __init__(
"""
super().__init__(plotter)

data = plotter.plot_data["size"]
data = plotter.plot_data.get("size", pd.Series(dtype=float))

if data.notna().any():

Expand Down Expand Up @@ -495,7 +495,7 @@ def __init__(
"""
super().__init__(plotter)

data = plotter.plot_data["style"]
data = plotter.plot_data.get("style", pd.Series(dtype=float))

if data.notna().any():

Expand Down Expand Up @@ -602,19 +602,29 @@ def __init__(self, data=None, variables={}):
self.assign_variables(data, variables)

for var, cls in self._semantic_mappings.items():
if var in self.semantics:

# Create the mapping function
map_func = partial(cls.map, plotter=self)
setattr(self, f"map_{var}", map_func)
# Create the mapping function
map_func = partial(cls.map, plotter=self)
setattr(self, f"map_{var}", map_func)

# Call the mapping function to initialize with default values
getattr(self, f"map_{var}")()
# Call the mapping function to initialize with default values
getattr(self, f"map_{var}")()

@classmethod
def get_semantics(cls, kwargs):
def get_semantics(cls, kwargs, semantics=None):
"""Subset a dictionary` arguments with known semantic variables."""
return {k: kwargs[k] for k in cls.semantics}
if semantics is None:
semantics = cls.semantics
variables = {}
for key, val in kwargs.items():
if key in semantics and val is not None:
variables[key] = val
return variables

@property
def has_xy_data(self):
"""Return True at least one of x or y is defined."""
return bool({"x", "y"} & set(self.variables))

def assign_variables(self, data=None, variables={}):
"""Define plot variables, optionally using lookup from `data`."""
Expand Down Expand Up @@ -685,7 +695,7 @@ def _assign_variables_wideform(self, data=None, **kwargs):
if empty:

# Make an object with the structure of plot_data, but empty
plot_data = pd.DataFrame(columns=self.semantics)
plot_data = pd.DataFrame()
variables = {}

elif flat:
Expand All @@ -708,7 +718,7 @@ def _assign_variables_wideform(self, data=None, **kwargs):
plot_data[var] = getattr(flat_data, attr)
variables[var] = names[self.flat_structure[var]]

plot_data = pd.DataFrame(plot_data).reindex(columns=self.semantics)
plot_data = pd.DataFrame(plot_data)

else:

Expand Down Expand Up @@ -752,7 +762,6 @@ def _assign_variables_wideform(self, data=None, **kwargs):
# Assign names corresponding to plot semantics
for var, attr in self.wide_structure.items():
plot_data[var] = plot_data[attr]
plot_data = plot_data.reindex(columns=self.semantics)

# Define the variable names
variables = {}
Expand Down Expand Up @@ -838,7 +847,7 @@ def _assign_variables_longform(self, data=None, **kwargs):

# Construct a tidy plot DataFrame. This will convert a number of
# types automatically, aligning on index in case of pandas objects
plot_data = pd.DataFrame(plot_data, columns=self.semantics)
plot_data = pd.DataFrame(plot_data)

# Reduce the variables dictionary to fields with valid data
variables = {
Expand Down Expand Up @@ -930,6 +939,8 @@ def comp_data(self):

comp_data = self.plot_data.copy(deep=False)
for var in "xy":
if var not in self.variables:
continue
axis = getattr(self.ax, f"{var}axis")
comp_var = axis.convert_units(self.plot_data[var])
if axis.get_scale() == "log":
Expand Down
22 changes: 18 additions & 4 deletions seaborn/relational.py
Expand Up @@ -293,11 +293,19 @@ def plot(self, ax, kws):
):

if self.sort:
sub_data = sub_data.sort_values(["units", "x", "y"])
sort_vars = ["units", "x", "y"]
sort_cols = [var for var in sort_vars if var in self.variables]
sub_data = sub_data.sort_values(sort_cols)

x = sub_data["x"]
y = sub_data["y"]
u = sub_data["units"]
# Due to the original design, code below was written assuming that
# sub_data always has x, y, and units columns, which may be empty.
# Adding this here to avoid otherwise disruptive changes, but it
# could get removed if the rest of the logic is sorted out
null = pd.Series(index=sub_data.index, dtype=float)

x = sub_data.get("x", null)
y = sub_data.get("y", null)
u = sub_data.get("units", null)

if self.estimator is not None:
if "units" in self.variables:
Expand Down Expand Up @@ -643,6 +651,9 @@ def lineplot(
if ax is None:
ax = plt.gca()

if not p.has_xy_data:
return ax

p._attach(ax)

p.plot(ax, kwargs)
Expand Down Expand Up @@ -924,6 +935,9 @@ def scatterplot(
if ax is None:
ax = plt.gca()

if not p.has_xy_data:
return ax

p._attach(ax)

p.plot(ax, kwargs)
Expand Down
41 changes: 0 additions & 41 deletions seaborn/tests/test_relational.py
Expand Up @@ -102,8 +102,6 @@ def test_wide_df_variables(self, wide_df):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] == wide_df.index.name
assert p.variables["y"] is None
assert p.variables["hue"] == wide_df.columns.name
Expand Down Expand Up @@ -138,8 +136,6 @@ def test_wide_df_with_nonnumeric_variables(self, long_df):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] == numeric_df.index.name
assert p.variables["y"] is None
assert p.variables["hue"] == numeric_df.columns.name
Expand Down Expand Up @@ -171,8 +167,6 @@ def test_wide_array_variables(self, wide_array):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand All @@ -194,10 +188,6 @@ def test_flat_array_variables(self, flat_array):
expected_y = flat_array
assert_array_equal(y, expected_y)

assert p.plot_data["hue"].isnull().all()
assert p.plot_data["style"].isnull().all()
assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None

Expand All @@ -217,10 +207,6 @@ def test_flat_list_variables(self, flat_list):
expected_y = flat_list
assert_array_equal(y, expected_y)

assert p.plot_data["hue"].isnull().all()
assert p.plot_data["style"].isnull().all()
assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None

Expand Down Expand Up @@ -278,8 +264,6 @@ def test_wide_list_of_series_variables(self, wide_list_of_series):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand Down Expand Up @@ -313,8 +297,6 @@ def test_wide_list_of_arrays_variables(self, wide_list_of_arrays):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand Down Expand Up @@ -348,8 +330,6 @@ def test_wide_list_of_list_variables(self, wide_list_of_lists):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand Down Expand Up @@ -383,8 +363,6 @@ def test_wide_dict_of_series_variables(self, wide_dict_of_series):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand Down Expand Up @@ -418,8 +396,6 @@ def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand Down Expand Up @@ -453,8 +429,6 @@ def test_wide_dict_of_lists_variables(self, wide_dict_of_lists):
expected_style = expected_hue
assert_array_equal(style, expected_style)

assert p.plot_data["size"].isnull().all()

assert p.variables["x"] is None
assert p.variables["y"] is None
assert p.variables["hue"] is None
Expand All @@ -469,9 +443,6 @@ def test_long_df(self, long_df, long_semantics):
for key, val in long_semantics.items():
assert_array_equal(p.plot_data[key], long_df[val])

for col in set(p.semantics) - set(long_semantics):
assert p.plot_data[col].isnull().all()

def test_long_df_with_index(self, long_df, long_semantics):

p = _RelationalPlotter(
Expand All @@ -484,9 +455,6 @@ def test_long_df_with_index(self, long_df, long_semantics):
for key, val in long_semantics.items():
assert_array_equal(p.plot_data[key], long_df[val])

for col in set(p.semantics) - set(long_semantics):
assert p.plot_data[col].isnull().all()

def test_long_df_with_multiindex(self, long_df, long_semantics):

p = _RelationalPlotter(
Expand All @@ -499,9 +467,6 @@ def test_long_df_with_multiindex(self, long_df, long_semantics):
for key, val in long_semantics.items():
assert_array_equal(p.plot_data[key], long_df[val])

for col in set(p.semantics) - set(long_semantics):
assert p.plot_data[col].isnull().all()

def test_long_dict(self, long_dict, long_semantics):

p = _RelationalPlotter(
Expand All @@ -514,9 +479,6 @@ def test_long_dict(self, long_dict, long_semantics):
for key, val in long_semantics.items():
assert_array_equal(p.plot_data[key], pd.Series(long_dict[val]))

for col in set(p.semantics) - set(long_semantics):
assert p.plot_data[col].isnull().all()

@pytest.mark.parametrize(
"vector_type",
["series", "numpy", "list"],
Expand Down Expand Up @@ -547,9 +509,6 @@ def test_long_vectors(self, long_df, long_semantics, vector_type):
for key, val in long_semantics.items():
assert_array_equal(p.plot_data[key], long_df[val])

for col in set(p.semantics) - set(long_semantics):
assert p.plot_data[col].isnull().all()

def test_long_undefined_variables(self, long_df):

p = _RelationalPlotter()
Expand Down

0 comments on commit a4a9542

Please sign in to comment.