Skip to content

Commit

Permalink
Add Norm move class (#2827)
Browse files Browse the repository at this point in the history
* Add Norm move class

* Add percent option
  • Loading branch information
mwaskom committed Jun 2, 2022
1 parent fefd940 commit 1e8e843
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 14 deletions.
54 changes: 47 additions & 7 deletions seaborn/_core/moves.py
@@ -1,19 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar, Callable, Optional, Union

import numpy as np
from pandas import DataFrame

from seaborn._core.groupby import GroupBy

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional
from pandas import DataFrame


@dataclass
class Move:
    """
    Base class for moves; subclasses implement the transform in ``__call__``.
    """

    # Read by the plotting code: when true, the orient variable is prepended
    # to the grouping variables used to build the GroupBy passed to the move.
    group_by_orient: ClassVar[bool] = True

    def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
        """
        Apply the move to ``data``; the base class provides no implementation.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

Expand Down Expand Up @@ -61,10 +60,12 @@ class Dodge(Move):
"""
Displacement and narrowing of overlapping marks along orientation axis.
"""
empty: str = "keep" # keep, drop, fill
empty: str = "keep" # Options: keep, drop, fill
gap: float = 0

# TODO accept just a str here?
# TODO should this always be present?
# TODO should the default be an "all" singleton?
by: Optional[list[str]] = None

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
Expand Down Expand Up @@ -117,7 +118,7 @@ class Stack(Move):
"""
Displacement of overlapping bar or area marks along the value axis.
"""
# TODO center? (or should this be a different move?)
# TODO center? (or should this be a different move, e.g. Stream())

def _stack(self, df, orient):

Expand All @@ -140,6 +141,7 @@ def _stack(self, df, orient):
def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
    """
    Apply ``self._stack`` within each combination of "col", "row" and orient.

    The ``groupby`` argument is not used; a new GroupBy is built here.
    """
    # TODO where to ensure that other semantic variables are sorted properly?
    # TODO why are we not using the passed in groupby here?
    stack_groupby = GroupBy(["col", "row", orient])
    return stack_groupby.apply(data, self._stack, orient)

Expand All @@ -158,3 +160,41 @@ def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
data["x"] = data["x"] + self.x
data["y"] = data["y"] + self.y
return data


@dataclass
class Norm(Move):
    """
    Divisive scaling on the value axis after aggregating within groups.

    Parameters
    ----------
    func : callable or str
        Aggregation passed to ``Series.agg`` to compute the denominator.
    where : str, optional
        Expression passed to ``DataFrame.query`` selecting the rows that are
        used to compute the denominator. All rows are used when ``None``.
    by : list of str, optional
        Variables defining the groups to normalize within. The plotting code
        reads this attribute and falls back to its own grouping properties
        when it is ``None``.
    percent : bool
        If true, multiply the normalized values by 100.
    """

    func: Union[Callable, str] = "max"
    where: Optional[str] = None
    by: Optional[list[str]] = None
    percent: bool = False

    # Tells the plotting code not to add the orient variable to the groupers.
    group_by_orient: ClassVar[bool] = False

    def _norm(self, df, var):
        """Return a copy of ``df`` with column ``var`` divided by its aggregate."""
        values = df[var]
        if self.where is None:
            denom_data = values
        else:
            denom_data = df.query(self.where)[var]

        scaled = values / denom_data.agg(self.func)
        if self.percent:
            scaled = scaled * 100

        # Return a new frame rather than assigning into ``df``: the frame may
        # be the caller's own data or a groupby slice, and mutating it would
        # leak changes to the caller or trigger pandas chained-assignment
        # warnings.
        return df.assign(**{var: scaled})

    def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
        """Normalize the variable opposite ``orient`` within each group."""
        other = {"x": "y", "y": "x"}[orient]
        return groupby.apply(data, self._norm, other)


# TODO
# @dataclass
# class Ridge(Move):
# ...
11 changes: 6 additions & 5 deletions seaborn/_core/plot.py
Expand Up @@ -1117,11 +1117,12 @@ def get_order(var):
if move is not None:
moves = move if isinstance(move, list) else [move]
for move_step in moves:
move_groupers = [
orient,
*(getattr(move_step, "by", None) or grouping_properties),
*default_grouping_vars,
]
move_by = getattr(move_step, "by", None)
if move_by is None:
move_by = grouping_properties
move_groupers = [*move_by, *default_grouping_vars]
if move_step.group_by_orient:
move_groupers.insert(0, orient)
order = {var: get_order(var) for var in move_groupers}
groupby = GroupBy(order)
df = move_step(df, groupby, orient)
Expand Down
2 changes: 1 addition & 1 deletion seaborn/objects.py
Expand Up @@ -14,6 +14,6 @@
from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401
from seaborn._stats.histograms import Hist # noqa: F401

from seaborn._core.moves import Dodge, Jitter, Shift, Stack # noqa: F401
from seaborn._core.moves import Dodge, Jitter, Norm, Shift, Stack # noqa: F401

from seaborn._core.scales import Nominal, Continuous, Temporal # noqa: F401
40 changes: 39 additions & 1 deletion seaborn/tests/_core/test_moves.py
Expand Up @@ -6,7 +6,7 @@
from pandas.testing import assert_series_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._core.moves import Dodge, Jitter, Shift, Stack
from seaborn._core.moves import Dodge, Jitter, Shift, Stack, Norm
from seaborn._core.rules import categorical_order
from seaborn._core.groupby import GroupBy

Expand Down Expand Up @@ -318,3 +318,41 @@ def test_moves(self, toy_df, x, y):
res = Shift(x=x, y=y)(toy_df, gb, "x")
assert_array_equal(res["x"], toy_df["x"] + x)
assert_array_equal(res["y"], toy_df["y"] + y)


class TestNorm(MoveFixtures):
    """Tests for the Norm move."""

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_no_groups(self, df, orient):
        """With default settings the value variable has a maximum of 1."""
        value_var = {"x": "y", "y": "x"}[orient]
        normed = Norm()(df, GroupBy(["null"]), orient)
        assert normed[value_var].max() == pytest.approx(1)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_groups(self, df, orient):
        """When grouping by "grp2", each group has a maximum of 1."""
        value_var = {"x": "y", "y": "x"}[orient]
        normed = Norm()(df, GroupBy(["grp2"]), orient)
        for _, group_df in normed.groupby("grp2"):
            assert group_df[value_var].max() == pytest.approx(1)

    def test_sum(self, df):
        """Normalizing by "sum" makes the values add up to 1."""
        normed = Norm("sum")(df, GroupBy(["null"]), "x")
        assert normed["y"].sum() == pytest.approx(1)

    def test_where(self, df):
        """The rows selected by ``where`` have a maximum of 1."""
        normed = Norm(where="x == 2")(df, GroupBy(["null"]), "x")
        selected = normed.loc[normed["x"] == 2, "y"]
        assert selected.max() == pytest.approx(1)

    def test_percent(self, df):
        """With ``percent=True`` the maximum is 100."""
        normed = Norm(percent=True)(df, GroupBy(["null"]), "x")
        assert normed["y"].max() == pytest.approx(100)

0 comments on commit 1e8e843

Please sign in to comment.