3535from collections import LineCollection
3636from text import Text
3737from transforms import Bbox , Point , Value , get_bbox_transform , bbox_all ,\
38- unit_bbox , inverse_transform_bbox
38+ unit_bbox , inverse_transform_bbox , lbwh_to_bbox
3939
4040class Legend (Artist ):
4141 """
@@ -113,7 +113,7 @@ def __init__(self, parent, handles, labels, loc,
113113 Artist .__init__ (self )
114114 if is_string_like (loc ) and not self .codes .has_key (loc ):
115115 warnings .warn ('Unrecognized location %s. Falling back on upper right; valid locations are\n %s\t ' % (loc , '\n \t ' .join (self .codes .keys ())))
116- if is_string_like (loc ): loc = self .codes .get (loc , 1 )
116+ if is_string_like (loc ): loc = self .codes .get (loc , 0 )
117117
118118 self .numpoints = numpoints
119119 self .prop = prop
@@ -125,7 +125,8 @@ def __init__(self, parent, handles, labels, loc,
125125 self .handletextsep = handletextsep
126126 self .axespad = axespad
127127 self .shadow = shadow
128-
128+
129+ self .isaxes = isaxes
129130 if isaxes : # parent is an Axes
130131 self .set_figure (parent .figure )
131132 else : # parent is a Figure
@@ -144,7 +145,7 @@ def __init__(self, parent, handles, labels, loc,
144145 self ._xdata = linspace (left , left + self .handlelen , self .numpoints )
145146 textleft = left + self .handlelen + self .handletextsep
146147 self .texts = self ._get_texts (labels , textleft , upper )
147- self .handles = self ._get_handles (handles , self .texts )
148+ self .legendHandles = self ._get_handles (handles , self .texts )
148149
149150 left , top = self .texts [- 1 ].get_position ()
150151 HEIGHT = self ._approx_text_height ()
@@ -176,7 +177,7 @@ def draw(self, renderer):
176177 self .legendPatch .draw (renderer )
177178
178179
179- for h in self .handles :
180+ for h in self .legendHandles :
180181 if h is not None :
181182 h .draw (renderer )
182183 if 0 : bbox_artist (h , renderer )
@@ -192,7 +193,7 @@ def _get_handle_text_bbox(self, renderer):
192193 'Get a bbox for the text and lines in axes coords'
193194
194195 bboxesText = [t .get_window_extent (renderer ) for t in self .texts ]
195- bboxesHandles = [h .get_window_extent (renderer ) for h in self .handles if h is not None ]
196+ bboxesHandles = [h .get_window_extent (renderer ) for h in self .legendHandles if h is not None ]
196197
197198
198199 bboxesAll = bboxesText
@@ -249,6 +250,68 @@ def _get_handles(self, handles, texts):
249250
250251 return ret
251252
253+ def _auto_legend_data (self ):
254+ """ Returns list of vertices and extents covered by the plot.
255+
256+ Returns a two long list.
257+
258+ First element is a list of (x, y) vertices (in
259+ axes-coordinates) covered by all the lines and line
260+ collections, in the legend's handles.
261+
262+ Second element is a list of bounding boxes for all the patches in
263+ the legend's handles.
264+ """
265+
266+ if not self .isaxes :
267+ raise Exception , 'Auto legends not available for figure legends.'
268+
269+ def get_handles (ax ):
270+ handles = ax .lines
271+ handles .extend (ax .patches )
272+ handles .extend ([c for c in ax .collections if isinstance (c , LineCollection )])
273+ return handles
274+
275+ ax = self .parent
276+ handles = get_handles (ax )
277+ vertices = []
278+ bboxes = []
279+
280+ inv = ax .transAxes .inverse_xy_tup
281+ for handle in handles :
282+
283+ if isinstance (handle , Line2D ):
284+
285+ xdata = handle .get_xdata ()
286+ ydata = handle .get_ydata ()
287+ trans = handle .get_transform ()
288+ xt , yt = trans .numerix_x_y (xdata , ydata )
289+
290+ # XXX need a special method in transform to do a list of verts
291+ averts = [inv (v ) for v in zip (xt , yt )]
292+ vertices .extend (averts )
293+
294+ elif isinstance (handle , Patch ):
295+
296+ verts = handle .get_verts ()
297+ trans = handle .get_transform ()
298+ tverts = trans .seq_xy_tups (verts )
299+
300+ averts = [inv (v ) for v in tverts ]
301+
302+ bbox = unit_bbox ()
303+ bbox .update (averts , True )
304+ bboxes .append (bbox )
305+
306+ elif isinstance (handle , LineCollection ):
307+ verts = handle .get_verts ()
308+ trans = handle .get_transform ()
309+ tverts = trans .seq_xy_tups (verts )
310+ averts = [inv (v ) for v in tverts ]
311+ vertices .extend (averts )
312+
313+ return [vertices , bboxes ]
314+
252315 def draw_frame (self , b ):
253316 'b is a boolean. Set draw frame to b'
254317 self ._drawFrame = b
@@ -259,11 +322,11 @@ def get_frame(self):
259322
260323 def get_lines (self ):
261324 'return a list of lines.Line2D instances in the legend'
262- return [h for h in self .handles if isinstance (h , Line2D )]
325+ return [h for h in self .legendHandles if isinstance (h , Line2D )]
263326
264327 def get_patches (self ):
265328 'return a list of patch instances in the legend'
266- return silent_list ('Patch' , [h for h in self .handles if isinstance (h , Patch )])
329+ return silent_list ('Patch' , [h for h in self .legendHandles if isinstance (h , Patch )])
267330
268331 def get_texts (self ):
269332 'return a list of text.Text instance in the legend'
@@ -302,7 +365,7 @@ def _offset(self, ox, oy):
302365 x ,y = t .get_position ()
303366 t .set_position ( (x + ox , y + oy ) )
304367
305- for h in self .handles :
368+ for h in self .legendHandles :
306369 if isinstance (h , Line2D ):
307370 x ,y = h .get_xdata (), h .get_ydata ()
308371 h .set_data ( x + ox , y + oy )
@@ -314,6 +377,85 @@ def _offset(self, ox, oy):
314377 self .legendPatch .set_x (x + ox )
315378 self .legendPatch .set_y (y + oy )
316379
380+ def _find_best_position (self , width , height , consider = None ):
381+ """Determine the best location to place the legend.
382+
383+ `consider` is a list of (x, y) pairs to consider as a potential
384+ lower-left corner of the legend. All are axes coords.
385+ """
386+
387+ verts , bboxes = self ._auto_legend_data ()
388+
389+ consider = [self ._loc_to_axes_coords (x , width , height ) for x in range (1 , len (self .codes ))]
390+
391+ tx , ty = self .legendPatch .xy
392+
393+ candidates = []
394+ for l , b in consider :
395+ legendBox = lbwh_to_bbox (l , b , width , height )
396+ badness = 0
397+ badness = legendBox .count_contains (verts )
398+ ox , oy = l - tx , b - ty
399+ for bbox in bboxes :
400+ if legendBox .overlaps (bbox ):
401+ badness += 1
402+
403+ if badness == 0 :
404+ return ox , oy
405+
406+ candidates .append ((badness , (ox , oy )))
407+
408+ # rather than use min() or list.sort(), do this so that we are assured
409+ # that in the case of two equal badnesses, the one first considered is
410+ # returned.
411+ minCandidate = candidates [0 ]
412+ for candidate in candidates :
413+ if candidate [0 ] < minCandidate [0 ]:
414+ minCandidate = candidate
415+
416+ ox , oy = minCandidate [1 ]
417+
418+ return ox , oy
419+
420+
421+ def _loc_to_axes_coords (self , loc , width , height ):
422+ """Convert a location code to axes coordinates.
423+
424+ - loc: a location code, which may be a pair of literal axes coords, or
425+ in range(1, 11). This coresponds to the possible values for
426+ self._loc, excluding "best".
427+
428+ - width, height: the final size of the legend, axes units.
429+ """
430+ BEST , UR , UL , LL , LR , R , CL , CR , LC , UC , C = range (11 )
431+
432+ left = self .axespad
433+ right = 1.0 - (self .axespad + width )
434+ upper = 1.0 - (self .axespad + height )
435+ lower = self .axespad
436+ centerx = 0.5 - (width / 2.0 )
437+ centery = 0.5 - (height / 2.0 )
438+
439+ if loc == UR :
440+ return right , upper
441+ if loc == UL :
442+ return left , upper
443+ if loc == LL :
444+ return left , lower
445+ if loc == LR :
446+ return right , lower
447+ if loc == CL :
448+ return left , centery
449+ if loc in (CR , R ):
450+ return right , centery
451+ if loc == LC :
452+ return centerx , lower
453+ if loc == UC :
454+ return centerx , upper
455+ if loc == C :
456+ return centerx , centery
457+ raise TypeError , "%r isn't an understood type code." % (loc ,)
458+
317459 def _update_positions (self , renderer ):
318460 # called from renderer to allow more precise estimates of
319461 # widths and heights with get_window_extent
@@ -338,7 +480,7 @@ def get_tbounds(text): #get text bounds in axes coords
338480 h += 2 * self .labelsep
339481 hpos .append ( (b ,h ) )
340482
341- for handle , tup in zip (self .handles , hpos ):
483+ for handle , tup in zip (self .legendHandles , hpos ):
342484 y ,h = tup
343485 if isinstance (handle , Line2D ):
344486 ydata = y * ones (self ._xdata .shape , Float )
@@ -366,11 +508,13 @@ def get_tbounds(text): #get text bounds in axes coords
366508 oy = y - yo
367509 self ._offset (ox , oy )
368510 else :
511+ if self ._loc in (BEST ,):
512+ ox , oy = self ._find_best_position (w , h )
369513 if self ._loc in (UL , LL , CL ): # left
370514 ox = self .axespad - l
371- if self ._loc in (BEST , UR , LR , R , CR ): # right
515+ if self ._loc in (UR , LR , R , CR ): # right
372516 ox = 1 - (l + w + self .axespad )
373- if self ._loc in (BEST , UR , UL , UC ): # upper
517+ if self ._loc in (UR , UL , UC ): # upper
374518 oy = 1 - (b + h + self .axespad )
375519 if self ._loc in (LL , LR , LC ): # lower
376520 oy = self .axespad - b
0 commit comments