Permalink
Browse files

Merge pull request #973 from cgohlke/patch-9

Fix sankey.py pep8 and py3 compatibility
  • Loading branch information...
2 parents 7a6897e + 829e253 commit 11567bbabb8c569e789e150684b0cc278ef35e4e @WeatherGod WeatherGod committed Jul 18, 2012
Showing with 334 additions and 317 deletions.
  1. +13 −9 examples/api/sankey_demo_links.py
  2. +99 −101 examples/api/sankey_demo_old.py
  3. +222 −207 lib/matplotlib/sankey.py
@@ -1,31 +1,35 @@
"""Demonstrate/test the Sankey class by producing a long chain of connections.
"""
-import numpy as np
-import matplotlib.pyplot as plt
-from matplotlib.sankey import Sankey
from itertools import cycle
+import matplotlib.pyplot as plt
+from matplotlib.sankey import Sankey
+
links_per_side = 6
+
+
def side(sankey, n=1):
- """Generate a side chain.
- """
+ """Generate a side chain."""
prior = len(sankey.diagrams)
colors = cycle(['orange', 'b', 'g', 'r', 'c', 'm', 'y'])
for i in range(0, 2*n, 2):
sankey.add(flows=[1, -1], orientations=[-1, -1],
- patchlabel=str(prior+i), facecolor=colors.next(),
+ patchlabel=str(prior+i), facecolor=next(colors),
prior=prior+i-1, connect=(1, 0), alpha=0.5)
sankey.add(flows=[1, -1], orientations=[1, 1],
- patchlabel=str(prior+i+1), facecolor=colors.next(),
+ patchlabel=str(prior+i+1), facecolor=next(colors),
prior=prior+i, connect=(1, 0), alpha=0.5)
+
+
def corner(sankey):
- """Generate a corner link.
- """
+ """Generate a corner link."""
prior = len(sankey.diagrams)
sankey.add(flows=[1, -1], orientations=[0, 1],
patchlabel=str(prior), facecolor='k',
prior=prior-1, connect=(1, 0), alpha=0.5)
+
+
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[],
title="Why would you want to do this?\n(But you could.)")
@@ -7,149 +7,149 @@
import numpy as np
+
def sankey(ax,
outputs=[100.], outlabels=None,
inputs=[100.], inlabels='',
dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
"""Draw a Sankey diagram.
-outputs: array of outputs, should sum up to 100%
-outlabels: output labels (same length as outputs),
-or None (use default labels) or '' (no labels)
-inputs and inlabels: similar for inputs
-dx: horizontal elongation
-dy: vertical elongation
-outangle: output arrow angle [deg]
-w: output arrow shoulder
-inangle: input dip angle
-offset: text offset
-**kwargs: propagated to Patch (e.g. fill=False)
-
-Return (patch,[intexts,outtexts])."""
-
+ outputs: array of outputs, should sum up to 100%
+ outlabels: output labels (same length as outputs),
+ or None (use default labels) or '' (no labels)
+ inputs and inlabels: similar for inputs
+ dx: horizontal elongation
+ dy: vertical elongation
+ outangle: output arrow angle [deg]
+ w: output arrow shoulder
+ inangle: input dip angle
+ offset: text offset
+ **kwargs: propagated to Patch (e.g. fill=False)
+
+ Return (patch,[intexts,outtexts]).
+ """
import matplotlib.patches as mpatches
from matplotlib.path import Path
outs = np.absolute(outputs)
outsigns = np.sign(outputs)
- outsigns[-1] = 0 # Last output
+ outsigns[-1] = 0 # Last output
ins = np.absolute(inputs)
insigns = np.sign(inputs)
- insigns[0] = 0 # First input
+ insigns[0] = 0 # First input
- assert sum(outs)==100, "Outputs don't sum up to 100%"
- assert sum(ins)==100, "Inputs don't sum up to 100%"
+ assert sum(outs) == 100, "Outputs don't sum up to 100%"
+ assert sum(ins) == 100, "Inputs don't sum up to 100%"
def add_output(path, loss, sign=1):
- h = (loss/2+w)*np.tan(outangle/180.*np.pi) # Arrow tip height
- move,(x,y) = path[-1] # Use last point as reference
- if sign==0: # Final loss (horizontal)
- path.extend([(Path.LINETO,[x+dx,y]),
- (Path.LINETO,[x+dx,y+w]),
- (Path.LINETO,[x+dx+h,y-loss/2]), # Tip
- (Path.LINETO,[x+dx,y-loss-w]),
- (Path.LINETO,[x+dx,y-loss])])
- outtips.append((sign,path[-3][1]))
- else: # Intermediate loss (vertical)
- path.extend([(Path.CURVE4,[x+dx/2,y]),
- (Path.CURVE4,[x+dx,y]),
- (Path.CURVE4,[x+dx,y+sign*dy]),
- (Path.LINETO,[x+dx-w,y+sign*dy]),
- (Path.LINETO,[x+dx+loss/2,y+sign*(dy+h)]), # Tip
- (Path.LINETO,[x+dx+loss+w,y+sign*dy]),
- (Path.LINETO,[x+dx+loss,y+sign*dy]),
- (Path.CURVE3,[x+dx+loss,y-sign*loss]),
- (Path.CURVE3,[x+dx/2+loss,y-sign*loss])])
- outtips.append((sign,path[-5][1]))
+ h = (loss/2 + w)*np.tan(outangle/180. * np.pi) # Arrow tip height
+ move, (x, y) = path[-1] # Use last point as reference
+ if sign == 0: # Final loss (horizontal)
+ path.extend([(Path.LINETO, [x+dx, y]),
+ (Path.LINETO, [x+dx, y+w]),
+ (Path.LINETO, [x+dx+h, y-loss/2]), # Tip
+ (Path.LINETO, [x+dx, y-loss-w]),
+ (Path.LINETO, [x+dx, y-loss])])
+ outtips.append((sign, path[-3][1]))
+ else: # Intermediate loss (vertical)
+ path.extend([(Path.CURVE4, [x+dx/2, y]),
+ (Path.CURVE4, [x+dx, y]),
+ (Path.CURVE4, [x+dx, y+sign*dy]),
+ (Path.LINETO, [x+dx-w, y+sign*dy]),
+ (Path.LINETO, [x+dx+loss/2, y+sign*(dy+h)]), # Tip
+ (Path.LINETO, [x+dx+loss+w, y+sign*dy]),
+ (Path.LINETO, [x+dx+loss, y+sign*dy]),
+ (Path.CURVE3, [x+dx+loss, y-sign*loss]),
+ (Path.CURVE3, [x+dx/2+loss, y-sign*loss])])
+ outtips.append((sign, path[-5][1]))
def add_input(path, gain, sign=1):
- h = (gain/2)*np.tan(inangle/180.*np.pi) # Dip depth
- move,(x,y) = path[-1] # Use last point as reference
- if sign==0: # First gain (horizontal)
- path.extend([(Path.LINETO,[x-dx,y]),
- (Path.LINETO,[x-dx+h,y+gain/2]), # Dip
- (Path.LINETO,[x-dx,y+gain])])
- xd,yd = path[-2][1] # Dip position
- indips.append((sign,[xd-h,yd]))
- else: # Intermediate gain (vertical)
- path.extend([(Path.CURVE4,[x-dx/2,y]),
- (Path.CURVE4,[x-dx,y]),
- (Path.CURVE4,[x-dx,y+sign*dy]),
- (Path.LINETO,[x-dx-gain/2,y+sign*(dy-h)]), # Dip
- (Path.LINETO,[x-dx-gain,y+sign*dy]),
- (Path.CURVE3,[x-dx-gain,y-sign*gain]),
- (Path.CURVE3,[x-dx/2-gain,y-sign*gain])])
- xd,yd = path[-4][1] # Dip position
- indips.append((sign,[xd,yd+sign*h]))
-
- outtips = [] # Output arrow tip dir. and positions
- urpath = [(Path.MOVETO,[0,100])] # 1st point of upper right path
- lrpath = [(Path.LINETO,[0,0])] # 1st point of lower right path
- for loss,sign in zip(outs,outsigns):
+ h = (gain/2)*np.tan(inangle/180. * np.pi) # Dip depth
+ move, (x, y) = path[-1] # Use last point as reference
+ if sign == 0: # First gain (horizontal)
+ path.extend([(Path.LINETO, [x-dx, y]),
+ (Path.LINETO, [x-dx+h, y+gain/2]), # Dip
+ (Path.LINETO, [x-dx, y+gain])])
+ xd, yd = path[-2][1] # Dip position
+ indips.append((sign, [xd-h, yd]))
+ else: # Intermediate gain (vertical)
+ path.extend([(Path.CURVE4, [x-dx/2, y]),
+ (Path.CURVE4, [x-dx, y]),
+ (Path.CURVE4, [x-dx, y+sign*dy]),
+ (Path.LINETO, [x-dx-gain/2, y+sign*(dy-h)]), # Dip
+ (Path.LINETO, [x-dx-gain, y+sign*dy]),
+ (Path.CURVE3, [x-dx-gain, y-sign*gain]),
+ (Path.CURVE3, [x-dx/2-gain, y-sign*gain])])
+ xd, yd = path[-4][1] # Dip position
+ indips.append((sign, [xd, yd+sign*h]))
+
+ outtips = [] # Output arrow tip dir. and positions
+ urpath = [(Path.MOVETO, [0, 100])] # 1st point of upper right path
+ lrpath = [(Path.LINETO, [0, 0])] # 1st point of lower right path
+ for loss, sign in zip(outs, outsigns):
add_output(sign>=0 and urpath or lrpath, loss, sign=sign)
- indips = [] # Input arrow tip dir. and positions
- llpath = [(Path.LINETO,[0,0])] # 1st point of lower left path
- ulpath = [(Path.MOVETO,[0,100])] # 1st point of upper left path
- for gain,sign in zip(ins,insigns)[::-1]:
+ indips = [] # Input arrow tip dir. and positions
+ llpath = [(Path.LINETO, [0, 0])] # 1st point of lower left path
+ ulpath = [(Path.MOVETO, [0, 100])] # 1st point of upper left path
+ for gain, sign in reversed(list(zip(ins, insigns))):
add_input(sign<=0 and llpath or ulpath, gain, sign=sign)
def revert(path):
"""A path is not just revertable by path[::-1] because of Bezier
-curves."""
+ curves."""
rpath = []
nextmove = Path.LINETO
- for move,pos in path[::-1]:
- rpath.append((nextmove,pos))
+ for move, pos in path[::-1]:
+ rpath.append((nextmove, pos))
nextmove = move
return rpath
# Concatenate subpathes in correct order
path = urpath + revert(lrpath) + llpath + revert(ulpath)
- codes,verts = zip(*path)
+ codes, verts = zip(*path)
verts = np.array(verts)
# Path patch
- path = Path(verts,codes)
+ path = Path(verts, codes)
patch = mpatches.PathPatch(path, **kwargs)
ax.add_patch(patch)
- if False: # DEBUG
+ if False: # DEBUG
print("urpath", urpath)
print("lrpath", revert(lrpath))
print("llpath", llpath)
print("ulpath", revert(ulpath))
-
- xs,ys = zip(*verts)
- ax.plot(xs,ys,'go-')
+ xs, ys = zip(*verts)
+ ax.plot(xs, ys, 'go-')
# Labels
- def set_labels(labels,values):
+ def set_labels(labels, values):
"""Set or check labels according to values."""
- if labels=='': # No labels
+ if labels == '': # No labels
return labels
- elif labels is None: # Default labels
- return [ '%2d%%' % val for val in values ]
+ elif labels is None: # Default labels
+ return ['%2d%%' % val for val in values]
else:
- assert len(labels)==len(values)
+ assert len(labels) == len(values)
return labels
- def put_labels(labels,positions,output=True):
+ def put_labels(labels, positions, output=True):
"""Put labels to positions."""
texts = []
lbls = output and labels or labels[::-1]
- for i,label in enumerate(lbls):
- s,(x,y) = positions[i] # Label direction and position
- if s==0:
- t = ax.text(x+offset,y,label,
+ for i, label in enumerate(lbls):
+ s, (x, y) = positions[i] # Label direction and position
+ if s == 0:
+ t = ax.text(x+offset, y, label,
ha=output and 'left' or 'right', va='center')
- elif s>0:
- t = ax.text(x,y+offset,label, ha='center', va='bottom')
+ elif s > 0:
+ t = ax.text(x, y+offset, label, ha='center', va='bottom')
else:
- t = ax.text(x,y-offset,label, ha='center', va='top')
+ t = ax.text(x, y-offset, label, ha='center', va='top')
texts.append(t)
return texts
@@ -160,32 +160,30 @@ def put_labels(labels,positions,output=True):
intexts = put_labels(inlabels, indips, output=False)
# Axes management
- ax.set_xlim(verts[:,0].min()-dx, verts[:,0].max()+dx)
- ax.set_ylim(verts[:,1].min()-dy, verts[:,1].max()+dy)
+ ax.set_xlim(verts[:, 0].min()-dx, verts[:, 0].max()+dx)
+ ax.set_ylim(verts[:, 1].min()-dy, verts[:, 1].max()+dy)
ax.set_aspect('equal', adjustable='datalim')
- return patch,[intexts,outtexts]
+ return patch, [intexts, outtexts]
+
if __name__=='__main__':
import matplotlib.pyplot as plt
- outputs = [10.,-20.,5.,15.,-10.,40.]
- outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!']
- outlabels = [ s+'\n%d%%' % abs(l) for l,s in zip(outputs,outlabels) ]
+ outputs = [10., -20., 5., 15., -10., 40.]
+ outlabels = ['First', 'Second', 'Third', 'Fourth', 'Fifth', 'Hurray!']
+ outlabels = [s+'\n%d%%' % abs(l) for l, s in zip(outputs, outlabels)]
- inputs = [60.,-25.,15.]
+ inputs = [60., -25., 15.]
fig = plt.figure()
- ax = fig.add_subplot(1,1,1, xticks=[],yticks=[],
- title="Sankey diagram"
- )
+ ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[], title="Sankey diagram")
- patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels,
- inputs=inputs, inlabels=None,
- fc='g', alpha=0.2)
+ patch, (intexts, outtexts) = sankey(ax, outputs=outputs,
+ outlabels=outlabels, inputs=inputs,
+ inlabels=None, fc='g', alpha=0.2)
outtexts[1].set_color('r')
outtexts[-1].set_fontweight('bold')
plt.show()
-
Oops, something went wrong.

0 comments on commit 11567bb

Please sign in to comment.