Skip to content

Commit

Permalink
Add jitterplot to base_plot_1d framework.
Browse files Browse the repository at this point in the history
  • Loading branch information
corranwebster committed Aug 19, 2014
1 parent 4fd6f9f commit 06aebf6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 226 deletions.
252 changes: 28 additions & 224 deletions chaco/jitterplot.py
@@ -1,5 +1,5 @@

from __future__ import with_statement
from __future__ import absolute_import

from itertools import izip
from math import sqrt
Expand All @@ -9,98 +9,27 @@
from traits.api import (Any, Bool, Callable, Enum, Float,
Instance, Int, Property, Str, Trait, on_trait_change)

from abstract_plot_renderer import AbstractPlotRenderer
from abstract_mapper import AbstractMapper
from array_data_source import ArrayDataSource
from base import reverse_map_1d
from scatterplot import render_markers
from .scatterplot_1d import ScatterPlot1D
from .abstract_mapper import AbstractMapper
from .array_data_source import ArrayDataSource
from .base import reverse_map_1d
from .scatterplot import render_markers


class JitterPlot(AbstractPlotRenderer):
class JitterPlot(ScatterPlot1D):
"""A renderer for a jitter plot, a 1D plot with some width in the
dimension perpendicular to the primary axis. Useful for understanding
dense collections of points.
"""

# The data source of values
index = Instance(ArrayDataSource)

# The single mapper that this plot uses
mapper = Instance(AbstractMapper)

# Just an alias for "mapper"
index_mapper = Property(lambda obj,attr: getattr(obj, "mapper"),
lambda obj,attr,val: setattr(obj, "mapper", val))

x_mapper = Property()
y_mapper = Property()

orientation = Enum("h", "v")

# The size, in pixels, of the area over which to spread the data points
# along the dimension orthogonal to the index direction.
jitter_width = Int(50)

# How the plot should center itself along the orthogonal dimension if the
# component's width is greater than the jitter_width
#align = Enum("center", "left", "right", "top", "bottom")

# The type of marker to use. This is a mapped trait using strings as the
# keys.
marker = MarkerTrait

# The pixel size of the marker, not including the thickness of the outline.
marker_size = Float(4.0)

# The CompiledPath to use if **marker** is set to "custom". This attribute
# must be a compiled path for the Kiva context onto which this plot will
# be rendered. Usually, importing kiva.GraphicsContext will do
# the right thing.
custom_symbol = Any

# The function which actually renders the markers
render_markers_func = Callable(render_markers)

# The thickness, in pixels, of the outline to draw around the marker. If
# this is 0, no outline is drawn.
line_width = Float(1.0)

# The fill color of the marker.
color = black_color_trait

# The color of the outline to draw around the marker.
outline_color = black_color_trait

#------------------------------------------------------------------------
# Built-in selection handling
#------------------------------------------------------------------------

# The name of the metadata attribute to look for on the datasource for
# determine which points are selected and which are not. The metadata
# value returned should be a *list* of numpy arrays suitable for masking
# the values returned by index.get_data().
selection_metadata_name = Str("selections")

# The color to use to render selected points
selected_color = black_color_trait

# Alpha value to apply to points that are not in the set of "selected"
# points
unselected_alpha = Float(0.3)
unselected_line_width = Float(0.0)

#------------------------------------------------------------------------
# Private traits
#------------------------------------------------------------------------

_cache_valid = Bool(False)

_cached_data_pts = Any()
_cached_data_pts_sorted = Any()
_cached_data_argsort = Any()

_screen_cache_valid = Bool(False)
_cached_screen_pts = Any()
_cached_screen_map = Any() # dict mapping index to value points

# The random number seed used to generate the jitter. We store this
Expand All @@ -127,13 +56,13 @@ def map_screen(self, data_array):
if new_x:
new_y = self._make_jitter_vals(len(new_x))
sm.update(dict((new_x[i], new_y[i]) for i in range(len(new_x))))
xs = self.mapper.map_screen(data_array)
xs = self.index_mapper.map_screen(data_array)
ys = [sm[x] for x in xs]

else:
if self._jitter_seed is None:
self._set_seed(data_array)
xs = self.mapper.map_screen(data_array)
xs = self.index_mapper.map_screen(data_array)
ys = self._make_jitter_vals(len(data_array))

if self.orientation == "h":
Expand All @@ -143,24 +72,9 @@ def map_screen(self, data_array):

def _make_jitter_vals(self, numpts):
vals = np.random.uniform(0, self.jitter_width, numpts)
if self.orientation == "h":
ymin = self.y
height = self.height
vals += ymin + height/2 - self.jitter_width/2
else:
xmin = self.x
width = self.width
vals += xmin + width/2 - self.jitter_width/2
vals += self._marker_position
return vals

def map_data(self, screen_pt):
""" Maps a screen space point into the index space of the plot.
"""
x, y = screen_pt
if self.orientation == "v":
x, y = y, x
return self.mapper.map_data(x)

def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True, \
index_only = True):
""" Maps a screen space point to an index into the plot's index array(s).
Expand All @@ -171,13 +85,13 @@ def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True, \
return None

data_pt = self.map_data(screen_pt)
if ((data_pt < self.mapper.range.low) or \
(data_pt > self.mapper.range.high)) and outside_returns_none:
if ((data_pt < self.index_mapper.range.low) or \
(data_pt > self.index_mapper.range.high)) and outside_returns_none:
return None

if self._cached_data_pts_sorted is None:
self._cached_data_argsort = np.argsort(self._cached_data_pts)
self._cached_data_pts_sorted = self._cached_data_pts[self._cached_data_argsort]
self._cached_data_argsort = np.argsort(self._cached_data)
self._cached_data_pts_sorted = self._cached_data[self._cached_data_argsort]

data = self._cached_data_pts_sorted
try:
Expand Down Expand Up @@ -215,7 +129,7 @@ def _draw_plot(self, gc, view_bounds=None, mode="normal"):
def get_screen_points(self):
if not self._screen_cache_valid:
self._gather_points()
pts = self.map_screen(self._cached_data_pts)
pts = self.map_screen(self._cached_data)
if self.orientation == "h":
self._cached_screen_map = dict((x,y) for x,y in izip(pts[:,0], pts[:,1]))
else:
Expand All @@ -226,61 +140,6 @@ def get_screen_points(self):
self._cached_data_argsort = None
return self._cached_screen_pts

def _gather_points(self):
if self._cache_valid:
return

if not self.index:
return

index, index_mask = self.index.get_data_mask()
if len(index) == 0:
self._cached_data_pts = []
self._cache_valid = True
return

# For the jitter plot, we do not mask or compress the data in any
# way, because if we do, we have no way of transforming from screen
# points back into dataspace. (Tools will be able to find an index
# into the screen points array, but won't be able to go from that
# back into the original data points array.)

#index_range_mask = self.mapper.range.mask_data(index)
#self._cached_data_pts = np.compress(index_mask & index_range_mask, index)
self._cached_data_pts = index
self._cache_valid = True
self._cached_screen_pts = None
self._screen_cache_valid = False

def _render(self, gc, pts):
with gc:
gc.clip_to_rect(self.x, self.y, self.width, self.height)
if not self.index:
return
name = self.selection_metadata_name
md = self.index.metadata
if name in md and md[name] is not None and len(md[name]) > 0:
# FIXME: when will we ever encounter multiple masks in the list?
sel_mask = md[name][0]
sel_pts = np.compress(sel_mask, pts, axis=0)
unsel_pts = np.compress(~sel_mask, pts, axis=0)
color = list(self.color_)
color[3] *= self.unselected_alpha
outline_color = list(self.outline_color_)
outline_color[3] *= self.unselected_alpha
if unsel_pts.size > 0:
self.render_markers_func(gc, unsel_pts, self.marker, self.marker_size,
tuple(color), self.unselected_line_width, tuple(outline_color),
self.custom_symbol)
if sel_pts.size > 0:
self.render_markers_func(gc, sel_pts, self.marker, self.marker_size,
self.selected_color_, self.line_width, self.outline_color_,
self.custom_symbol)
else:
self.render_markers_func(gc, pts, self.marker, self.marker_size,
self.color_, self.line_width, self.outline_color_,
self.custom_symbol)

def _set_seed(self, data_array):
""" Sets the internal random seed based on some input data """
if isinstance(data_array, np.ndarray):
Expand All @@ -290,74 +149,19 @@ def _set_seed(self, data_array):

self._jitter_seed = seed

@on_trait_change("index.data_changed")
def _invalidate(self):
self._cache_valid = False
self._screen_cache_valid = False

@on_trait_change("mapper.updated")
def _invalidate_screen(self):
self._screen_cache_valid = False

#------------------------------------------------------------------------
# Event handlers
#------------------------------------------------------------------------

def _get_x_mapper(self):
if self.orientation == "h":
return self.mapper
else:
return None

def _set_x_mapper(self, val):
if self.orientation == "h":
self.mapper = val
else:
raise ValueError("x_mapper is not defined for a vertical jitter plot")
def _get_marker_position(self):
x, y = self.position
w, h = self.bounds

def _get_y_mapper(self):
if self.orientation == "v":
return self.mapper
else:
return None
if self.orientation == 'v':
y, h = x, w

def _set_y_mapper(self, val):
if self.orientation == "v":
self.mapper = val
else:
raise ValueError("y_mapper is not defined for a horizontal jitter plot")

def _update_mappers(self):
mapper = self.mapper
if mapper is None:
return

x = self.x
x2 = self.x2
y = self.y
y2 = self.y2

if "left" in self.origin and self.orientation == 'h':
mapper.screen_bounds = (x, x2)
elif "right" in self.origin and self.orientation == 'h':
mapper.screen_bounds = (x2, x)
elif "bottom" in self.origin and self.orientation == 'v':
mapper.screen_bounds = (y, y2)
elif "top" in self.origin and self.orientation == 'v':
mapper.screen_bounds = (y2, y)

self.invalidate_draw()
self._cache_valid = False
self._screen_cache_valid = False

def _bounds_changed(self, old, new):
super(JitterPlot, self)._bounds_changed(old, new)
self._update_mappers()

def _bounds_items_changed(self, event):
super(JitterPlot, self)._bounds_items_changed(event)
self._update_mappers()

def _orientation_changed(self):
self._update_mappers()
if self.alignment == 'center':
position = y + h/2.0 - self.jitter_width/2.0
elif self.alignment in ['left', 'bottom']:
position = y
elif self.alignment in ['right', 'top']:
position = y + h - self.jitter_width/2.0

position += self.marker_offset
return position
7 changes: 5 additions & 2 deletions chaco/plot.py
Expand Up @@ -41,6 +41,7 @@
from text_plot_1d import TextPlot1D
from filled_line_plot import FilledLinePlot
from quiverplot import QuiverPlot
from jitterplot import JitterPlot



Expand Down Expand Up @@ -123,7 +124,8 @@ class Plot(DataView):
quiver = QuiverPlot,
scatter_1d = ScatterPlot1D,
textplot_1d = TextPlot1D,
line_scatter_1d = LineScatterPlot1D))
line_scatter_1d = LineScatterPlot1D,
jitterplot = JitterPlot))

#------------------------------------------------------------------------
# Annotations and decorations
Expand Down Expand Up @@ -974,7 +976,8 @@ def plot_1d(self, data, type='scatter_1d', name=None, orientation=None,
direction = 'flipped'

plots = []
if plot_type in ("scatter_1d", "textplot_1d", "line_scatter_1d"):
if plot_type in ("scatter_1d", "textplot_1d", "line_scatter_1d",
"jitterplot"):
# Tie data to the index range
index = self._get_or_create_datasource(data[0])
if self.default_index is None:
Expand Down

0 comments on commit 06aebf6

Please sign in to comment.