diff --git a/plotnine/scales/scale.py b/plotnine/scales/scale.py index 8a3128715f..802b308cbe 100644 --- a/plotnine/scales/scale.py +++ b/plotnine/scales/scale.py @@ -718,8 +718,26 @@ def get_labels(self, breaks=None): return labels +class scale_continuous_fixed_transformation(scale_continuous): + """Base class/mixin for scales that have one fixed transformation, + which may not be overwritten by the user + e.g. scale_x_log10 + """ + + def __init__(self, *args, **kwargs): + if 'trans' in kwargs and kwargs['trans'] != self._trans: + name = self.__class__.__name__ + xy = 'x' if '_x_' in name else 'y' + raise ValueError(("%s is fixed to transformation %s " + "- use scale_%s_continuous instead") % + (name, + self._trans, + xy)) + super().__init__(*args, **kwargs) + + @document -class scale_datetime(scale_continuous): +class scale_datetime(scale_continuous_fixed_transformation): """ Base class for all date/datetime scales @@ -770,4 +788,4 @@ def __init__(self, **kwargs): minor_breaks_fmt = kwargs.pop('date_minor_breaks') kwargs['minor_breaks'] = date_breaks(minor_breaks_fmt) - scale_continuous.__init__(self, **kwargs) + super().__init__(**kwargs) diff --git a/plotnine/scales/scale_xy.py b/plotnine/scales/scale_xy.py index eb38084a89..d9af94667a 100644 --- a/plotnine/scales/scale_xy.py +++ b/plotnine/scales/scale_xy.py @@ -6,7 +6,8 @@ from ..utils import identity, match, alias, array_kind from ..exceptions import PlotnineError from .range import RangeContinuous -from .scale import scale_discrete, scale_continuous, scale_datetime +from .scale import (scale_discrete, scale_continuous, scale_datetime, + scale_continuous_fixed_transformation) # positions scales have a couple of differences (quirks) that @@ -225,7 +226,9 @@ class scale_y_datetime(scale_datetime, scale_y_continuous): @document -class scale_x_timedelta(scale_x_continuous): +class scale_x_timedelta( + scale_continuous_fixed_transformation, + scale_x_continuous): """ Continuous x position for timedelta data points @@ -237,7 +240,9 @@ class scale_x_timedelta(scale_x_continuous): @document -class scale_y_timedelta(scale_y_continuous): +class scale_y_timedelta( + scale_continuous_fixed_transformation, + scale_y_continuous): """ Continuous y position for timedelta data points @@ -249,7 +254,9 @@ class scale_y_timedelta(scale_y_continuous): @document -class scale_x_sqrt(scale_x_continuous): +class scale_x_sqrt( + scale_continuous_fixed_transformation, + scale_x_continuous): """ Continuous x position sqrt transformed scale @@ -261,7 +268,9 @@ class scale_x_sqrt(scale_x_continuous): @document -class scale_y_sqrt(scale_y_continuous): +class scale_y_sqrt( + scale_continuous_fixed_transformation, + scale_y_continuous): """ Continuous y position sqrt transformed scale @@ -273,7 +282,9 @@ class scale_y_sqrt(scale_y_continuous): @document -class scale_x_log10(scale_x_continuous): +class scale_x_log10( + scale_continuous_fixed_transformation, + scale_x_continuous): """ Continuous x position log10 transformed scale @@ -285,7 +296,9 @@ class scale_x_log10(scale_x_continuous): @document -class scale_y_log10(scale_y_continuous): +class scale_y_log10( + scale_continuous_fixed_transformation, + scale_y_continuous): """ Continuous y position log10 transformed scale @@ -297,7 +310,9 @@ class scale_y_log10(scale_y_continuous): @document -class scale_x_reverse(scale_x_continuous): +class scale_x_reverse( + scale_continuous_fixed_transformation, + scale_x_continuous): """ Continuous x position reverse transformed scale @@ -309,7 +324,9 @@ class scale_x_reverse(scale_x_continuous): @document -class scale_y_reverse(scale_y_continuous): +class scale_y_reverse( + scale_continuous_fixed_transformation, + scale_y_continuous): """ Continuous y position reverse transformed scale diff --git a/plotnine/tests/test_scale_internals.py b/plotnine/tests/test_scale_internals.py index 03a6f4deb6..d2e5890923 100644 --- a/plotnine/tests/test_scale_internals.py +++ b/plotnine/tests/test_scale_internals.py @@ -536,3 +536,34 @@ def test_legend_ordering_added_scales(): ) assert p + _theme == 'legend_ordering_added_scales' + + +def test_trans_scales_raise_on_passing_trans(): + # superfluous, but ok + scale_xy.scale_x_log10(trans='log10') + scale_xy.scale_x_reverse(trans='reverse') + scale_xy.scale_x_datetime(trans='datetime') + scale_xy.scale_x_timedelta(trans='pd_timedelta') + + with pytest.raises(ValueError): + scale_xy.scale_x_log10(trans='identity') + with pytest.raises(ValueError): + scale_xy.scale_x_reverse(trans='identity') + with pytest.raises(ValueError): + scale_xy.scale_x_datetime(trans='identity') + with pytest.raises(ValueError): + scale_xy.scale_x_timedelta(trans='identity') + + scale_xy.scale_y_log10(trans='log10') + scale_xy.scale_y_reverse(trans='reverse') + scale_xy.scale_y_datetime(trans='datetime') + scale_xy.scale_y_timedelta(trans='pd_timedelta') + + with pytest.raises(ValueError): + scale_xy.scale_y_log10(trans='identity') + with pytest.raises(ValueError): + scale_xy.scale_y_reverse(trans='identity') + with pytest.raises(ValueError): + scale_xy.scale_y_datetime(trans='identity') + with pytest.raises(ValueError): + scale_xy.scale_y_timedelta(trans='identity')