Skip to content

Commit

Permalink
Logit scale
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabio Zanini committed Mar 2, 2015
1 parent ee086de commit 76840ea
Show file tree
Hide file tree
Showing 9 changed files with 334 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
('color', 'Color'),
('text_labels_and_annotations', 'Text, labels, and annotations'),
('ticks_and_spines', 'Ticks and spines'),
('scales', 'Axis scales'),
('subplots_axes_and_figures', 'Subplots, axes, and figures'),
('style_sheets', 'Style sheets'),
('specialty_plots', 'Specialty plots'),
Expand Down
43 changes: 43 additions & 0 deletions doc/pyplots/pyplot_scales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
import matplotlib.pyplot as plt

# make up some data in the interval ]0, 1[
y = np.random.normal(loc=0.5, scale=0.4, size=1000)
y = y[(y > 0) & (y < 1)]
y.sort()
x = np.arange(len(y))

# plot with various axes scales
plt.figure(1)

# linear
plt.subplot(221)
plt.plot(x, y)
plt.yscale('linear')
plt.title('linear')
plt.grid(True)


# log
plt.subplot(222)
plt.plot(x, y)
plt.yscale('log')
plt.title('log')
plt.grid(True)


# symmetric log
plt.subplot(223)
plt.plot(x, y - y.mean())
plt.yscale('symlog', linthreshy=0.05)
plt.title('symlog')
plt.grid(True)

# logit
plt.subplot(223)
plt.plot(x, y)
plt.yscale('logit')
plt.title('logit')
plt.grid(True)

plt.show()
19 changes: 19 additions & 0 deletions doc/users/pyplot_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,22 @@ variety of other coordinate systems one can choose -- see
:ref:`annotations-tutorial` and :ref:`plotting-guide-annotation` for
details. More examples can be found in
:ref:`pylab_examples-annotation_demo`.


Logarithmic and other nonlinear axis
====================================

:mod:`matplotlib.pyplot` supports not only linear axis scales, but also
logarithmic and logit scales. This is commonly used if data spans many orders
of magnitude. Changing the scale of an axis is easy:

plt.xscale('log')

An example of four plots with the same data and different scales for the y axis
is shown below.

.. plot:: pyplots/pyplot_scales.py
:include-source:

It is also possible to add your own scale, see :ref:`adding-new-scales` for
details.
4 changes: 4 additions & 0 deletions doc/users/whats_new/updated_scale.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Logit Scale
-----------
Added support for the 'logit' axis scale, a nonlinear transformation
`x -> log10(x / (1-x))` for data between 0 and 1 excluded.
47 changes: 47 additions & 0 deletions examples/scales/scales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Illustrate the scale transformations applied to axes, e.g. log, symlog, logit.
"""
import numpy as np
import matplotlib.pyplot as plt

# make up some data in the interval ]0, 1[
y = np.random.normal(loc=0.5, scale=0.4, size=1000)
y = y[(y > 0) & (y < 1)]
y.sort()
x = np.arange(len(y))

# plot with various axes scales
fig, axs = plt.subplots(2, 2)

# linear
ax = axs[0, 0]
ax.plot(x, y)
ax.set_yscale('linear')
ax.set_title('linear')
ax.grid(True)


# log
ax = axs[0, 1]
ax.plot(x, y)
ax.set_yscale('log')
ax.set_title('log')
ax.grid(True)


# symmetric log
ax = axs[1, 0]
ax.plot(x, y - y.mean())
ax.set_yscale('symlog', linthreshy=0.05)
ax.set_title('symlog')
ax.grid(True)

# logit
ax = axs[1, 1]
ax.plot(x, y)
ax.set_yscale('logit')
ax.set_title('logit')
ax.grid(True)


plt.show()
107 changes: 104 additions & 3 deletions lib/matplotlib/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from matplotlib.cbook import dedent
from matplotlib.ticker import (NullFormatter, ScalarFormatter,
LogFormatterMathtext)
LogFormatterMathtext, LogitFormatter)
from matplotlib.ticker import (NullLocator, LogLocator, AutoLocator,
SymmetricalLogLocator)
SymmetricalLogLocator, LogitLocator)
from matplotlib.transforms import Transform, IdentityTransform
from matplotlib import docstring

Expand Down Expand Up @@ -478,10 +478,111 @@ def get_transform(self):
return self._transform


def _mask_non_logit(a):
"""
Return a Numpy masked array where all values outside ]0, 1[ are
masked. If all values are inside ]0, 1[, the original array is
returned.
"""
a = a.copy()
mask = (a <= 0.0) | (a >= 1.0)
a[mask] = np.nan
return a


def _clip_non_logit(a):
a = a.copy()
a[a <= 0.0] = 1e-300
a[a >= 1.0] = 1 - 1e-300
return a


class LogitTransform(Transform):
input_dims = 1
output_dims = 1
is_separable = True
has_inverse = True

def __init__(self, nonpos):
Transform.__init__(self)
if nonpos == 'mask':
self._handle_nonpos = _mask_non_logit
else:
self._handle_nonpos = _clip_non_logit
self._nonpos = nonpos

def transform_non_affine(self, a):
"""logit transform (base 10), masked or clipped"""
a = self._handle_nonpos(a)
if isinstance(a, ma.MaskedArray):
return ma.log10(1.0 * a / (1.0 - a))
return np.log10(1.0 * a / (1.0 - a))

def inverted(self):
return LogisticTransform(self._nonpos)


class LogisticTransform(Transform):
input_dims = 1
output_dims = 1
is_separable = True
has_inverse = True

def __init__(self, nonpos='mask'):
Transform.__init__(self)
self._nonpos = nonpos

def transform_non_affine(self, a):
"""logistic transform (base 10)"""
return 1.0 / (1 + 10**(-a))

def inverted(self):
return LogitTransform(self._nonpos)


class LogitScale(ScaleBase):
"""
Logit scale for data between zero and one, both excluded.
This scale is similar to a log scale close to zero and to one, and almost
linear around 0.5. It maps the interval ]0, 1[ onto ]-infty, +infty[.
"""
name = 'logit'

def __init__(self, axis, nonpos='mask'):
"""
*nonpos*: ['mask' | 'clip' ]
values beyond ]0, 1[ can be masked as invalid, or clipped to a number
very close to 0 or 1
"""
if nonpos not in ['mask', 'clip']:
raise ValueError("nonposx, nonposy kwarg must be 'mask' or 'clip'")

self._transform = LogitTransform(nonpos)

def get_transform(self):
"""
Return a :class:`LogitTransform` instance.
"""
return self._transform

def set_default_locators_and_formatters(self, axis):
# ..., 0.01, 0.1, 0.5, 0.9, 0.99, ...
axis.set_major_locator(LogitLocator())
axis.set_major_formatter(LogitFormatter())
axis.set_minor_locator(LogitLocator(minor=True))
axis.set_minor_formatter(LogitFormatter())

def limit_range_for_scale(self, vmin, vmax, minpos):
return (vmin <= 0 and minpos or vmin,
vmax >= 1 and (1 - minpos) or vmax)


_scale_mapping = {
'linear': LinearScale,
'log': LogScale,
'symlog': SymmetricalLogScale
'symlog': SymmetricalLogScale,
'logit': LogitScale,
}


Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions lib/matplotlib/tests/test_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ def test_log_scales():
ax.axhline(24.1)


@image_comparison(baseline_images=['logit_scales'], remove_text=True,
extensions=['png'])
def test_logit_scales():
ax = plt.subplot(111, xscale='logit')

# Typical exctinction curve for logit
x = np.array([0.001, 0.003, 0.01, 0.03, 0.1, 0.2, 0.3, 0.4, 0.5,
0.6, 0.7, 0.8, 0.9, 0.97, 0.99, 0.997, 0.999])
y = 1.0 / x

ax.plot(x, y)
ax.grid(True)


@cleanup
def test_log_scatter():
"""Issue #1799"""
Expand Down
102 changes: 102 additions & 0 deletions lib/matplotlib/ticker.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,26 @@ def __call__(self, x, pos=None):
nearest_long(fx))


class LogitFormatter(Formatter):
'''Probability formatter (using Math text)'''
def __call__(self, x, pos=None):
s = ''
if 0.01 <= x <= 0.99:
if x in [.01, 0.1, 0.5, 0.9, 0.99]:
s = '{:.2f}'.format(x)
elif x < 0.01:
if is_decade(x):
s = '$10^{%.0f}$' % np.log10(x)
elif x > 0.99:
if is_decade(1-x):
s = '$1-10^{%.0f}$' % np.log10(1-x)
return s

def format_data_short(self, value):
'return a short formatted string representation of a number'
return '%-12g' % value


class EngFormatter(Formatter):
"""
Formats axis values using engineering prefixes to represent powers of 1000,
Expand Down Expand Up @@ -1694,6 +1714,88 @@ def view_limits(self, vmin, vmax):
return result


class LogitLocator(Locator):
"""
Determine the tick locations for logit axes
"""

def __init__(self, minor=False):
"""
place ticks on the logit locations
"""
self.minor = minor

def __call__(self):
'Return the locations of the ticks'
vmin, vmax = self.axis.get_view_interval()
return self.tick_values(vmin, vmax)

def tick_values(self, vmin, vmax):
# dummy axis has no axes attribute
if hasattr(self.axis, 'axes') and self.axis.axes.name == 'polar':
raise NotImplementedError('Polar axis cannot be logit scaled yet')

# what to do if a window beyond ]0, 1[ is chosen
if vmin <= 0.0:
if self.axis is not None:
vmin = self.axis.get_minpos()

if (vmin <= 0.0) or (not np.isfinite(vmin)):
raise ValueError(
"Data has no values in ]0, 1[ and therefore can not be "
"logit-scaled.")

# NOTE: for vmax, we should query a property similar to get_minpos, but
# related to the maximal, less-than-one data point. Unfortunately,
# get_minpos is defined very deep in the BBox and updated with data,
# so for now we use the trick below.
if vmax >= 1.0:
if self.axis is not None:
vmax = 1 - self.axis.get_minpos()

if (vmax >= 1.0) or (not np.isfinite(vmax)):
raise ValueError(
"Data has no values in ]0, 1[ and therefore can not be "
"logit-scaled.")

if vmax < vmin:
vmin, vmax = vmax, vmin

vmin = np.log10(vmin / (1 - vmin))
vmax = np.log10(vmax / (1 - vmax))

decade_min = np.floor(vmin)
decade_max = np.ceil(vmax)

# major ticks
if not self.minor:
ticklocs = []
if (decade_min <= -1):
expo = np.arange(decade_min, min(0, decade_max + 1))
ticklocs.extend(list(10**expo))
if (decade_min <= 0) and (decade_max >= 0):
ticklocs.append(0.5)
if (decade_max >= 1):
expo = -np.arange(max(1, decade_min), decade_max + 1)
ticklocs.extend(list(1 - 10**expo))

# minor ticks
else:
ticklocs = []
if (decade_min <= -2):
expo = np.arange(decade_min, min(-1, decade_max))
newticks = np.outer(np.arange(2, 10), 10**expo).ravel()
ticklocs.extend(list(newticks))
if (decade_min <= 0) and (decade_max >= 0):
ticklocs.extend([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])
if (decade_max >= 2):
expo = -np.arange(max(2, decade_min), decade_max + 1)
newticks = 1 - np.outer(np.arange(2, 10), 10**expo).ravel()
ticklocs.extend(list(newticks))

return self.raise_if_exceeds(np.array(ticklocs))


class AutoLocator(MaxNLocator):
def __init__(self):
MaxNLocator.__init__(self, nbins=9, steps=[1, 2, 5, 10])
Expand Down

0 comments on commit 76840ea

Please sign in to comment.