-
Notifications
You must be signed in to change notification settings - Fork 6
/
plot.py
150 lines (112 loc) · 4.77 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python
# coding: utf8
"""
plotable rich object display on IPython/Jupyter notebooks
"""
__author__ = "Philippe Guglielmetti"
__copyright__ = "Copyright 2015, Philippe Guglielmetti"
__credits__ = []
__license__ = "LGPL"
# import matplotlib and set backend once for all
import os, io, sys, logging, base64
import matplotlib
# Select the matplotlib backend once, at import time.
# 'Agg' is forced in two situations:
#   * the TRAVIS environment variable is set (https://travis-ci.org/ automated
#     tests), so that no Xwindows backend is used;
#   * a trace function is installed (code is being run through a debugger, see
#     http://stackoverflow.com/questions/333995/how-to-detect-that-python-code-is-being-executed-through-the-debugger),
#     because 'QtAgg' crashes python while debugging.
# In every other case matplotlib's own default backend is left untouched.
# (the 'pdf' backend gives high quality pdf, but doesn't work for png, svg ...)
if os.getenv('TRAVIS') or sys.gettrace():
    matplotlib.use('Agg')

logging.info('matplotlib backend is %s' % matplotlib.get_backend())
from . import itertools2
class Plot(object):
    """base class for plotable rich object display on IPython notebooks

    Subclasses only have to implement :meth:`_plot`; the other methods obtain
    their image data through :meth:`render`.

    inspired from http://nbviewer.ipython.org/github/ipython/ipython/blob/3607712653c66d63e0d7f13f073bde8c0f209ba8/docs/examples/notebooks/display_protocol.ipynb
    """

    def _plot(self, ax, **kwargs):
        """abstract method, must be overriden

        :param ax: `matplotlib.axis`
        :param kwargs: drawing options; the module-level ``render`` passes
            ``label`` and ``offset`` plus any option it did not consume itself
        :return ax: `matplotlib.axis` after plot
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError('objects derived from plot.Plot must define a _plot method')

    def render(self, fmt='svg', **kwargs):
        """renders this single object through the module-level ``render``

        :param fmt: output format name (default 'svg')
        :param kwargs: forwarded unchanged to the module-level ``render``
        :return: result of the module-level ``render``
        """
        return render([self], fmt, **kwargs)  # call global function

    def save(self, filename, **kwargs):
        """saves this single object through the module-level ``save``

        :param filename: destination file name
        :param kwargs: forwarded unchanged to the module-level ``save``
        :return: result of the module-level ``save``
        """
        return save([self], filename, **kwargs)  # call global function

    # for IPython notebooks

    def _repr_html_(self, **kwargs):
        """default rich format is svg plot

        When the svg path raises ``NotImplementedError``, the object is
        rendered as png and embedded in an ``<img>`` tag as base64 data.

        :param kwargs: rendering options forwarded to :meth:`render`
            (accepted here so that :meth:`html` can pass them along)
        :return: str of html markup
        """
        try:
            return self._repr_svg_(**kwargs)
        except NotImplementedError:
            pass
        # this returns the same as _repr_png_, but is Table compatible
        data = self.render('png', **kwargs)
        encoded = base64.b64encode(data).decode('utf-8')
        return '<img src="data:image/png;base64,%s">' % encoded

    def html(self, **kwargs):
        """:return: `IPython.display.HTML` wrapping :meth:`_repr_html_`"""
        from IPython.display import HTML
        return HTML(self._repr_html_(**kwargs))

    def svg(self, **kwargs):
        """:return: `IPython.display.SVG` wrapping :meth:`_repr_svg_`"""
        from IPython.display import SVG
        return SVG(self._repr_svg_(**kwargs))

    def _repr_svg_(self, **kwargs):
        """:return: str, the 'svg' rendering decoded as utf-8"""
        return self.render(fmt='svg', **kwargs).decode('utf-8')

    def png(self, **kwargs):
        """:return: `IPython.display.Image` embedding :meth:`_repr_png_`"""
        from IPython.display import Image
        return Image(self._repr_png_(**kwargs), embed=True)

    def _repr_png_(self, **kwargs):
        """:return: the 'png' rendering, exactly as given by :meth:`render`"""
        return self.render(fmt='png', **kwargs)

    def plot(self, **kwargs):
        """ renders on IPython Notebook
        (alias to make usage more straightforward)
        """
        return self.svg(**kwargs)
def render(plotables, fmt='svg', **kwargs):
    """renders several Plot objects on one shared axis

    :param plotables: sequence of objects having a ``_plot(ax, **kwargs)`` method
    :param fmt: output format given to ``fig.savefig`` (default 'svg')
    :param kwargs: the options below are consumed here, all remaining ones
        are forwarded to every ``_plot`` call:

        * ``figsize``: size of the figure, given to ``plt.subplots``
        * ``dpi``, ``transparent``, ``facecolor``, ``background``: given to ``fig.savefig``
        * ``xlim``, ``ylim``: axis limits, applied when truthy
        * ``title``: figure title; when missing it is derived from the first
          object (its ``_repr_latex_()`` if matplotlib can parse it, else its label)
        * ``labels``: one label per object; missing or ``None`` entries are
          replaced by ``str(obj)``. The caller's sequence is not modified
        * ``offset``: object number ``i`` receives ``offset=i * offset``

    :return: bytes of the figure saved in format ``fmt``
    """
    import matplotlib.pyplot as plt

    # figsize is a property of the figure, not an option of savefig,
    # so it is taken out before the savefig options are split off
    figsize = kwargs.pop('figsize', None)
    # extract optional arguments used for rasterization
    # NOTE(review): assumes dictsplit returns (entries whose key is listed,
    # remaining entries) -- confirm against itertools2
    # NOTE(review): 'background' is not a documented savefig parameter either;
    # left untouched because its intended meaning is not visible here
    printargs, kwargs = itertools2.dictsplit(
        kwargs,
        ['dpi', 'transparent', 'facecolor', 'background']
    )
    ylim = kwargs.pop('ylim', None)
    xlim = kwargs.pop('xlim', None)
    title = kwargs.pop('title', None)
    # work on a private, padded copy: labels are filled in below
    labels = list(kwargs.pop('labels', None) or ())
    labels += [None] * (len(plotables) - len(labels))
    offset = kwargs.pop('offset', 0)  # slightly shift the points to make superimposed curves more visible

    fig, ax = plt.subplots(figsize=figsize)
    try:
        for i, obj in enumerate(plotables):
            if labels[i] is None:
                labels[i] = str(obj)
            if not title:
                # best effort: any failure falls back to the plain label
                try:
                    title = obj._repr_latex_()
                    # check that title can be used in matplotlib
                    from matplotlib.mathtext import MathTextParser
                    MathTextParser('path').parse(title)
                except Exception:
                    title = labels[i]
            drawn = obj._plot(ax, label=labels[i], offset=i * offset, **kwargs)
            if drawn is not None:  # keep the shared axis if _plot forgot to return it
                ax = drawn

        if ylim:
            ax.set_ylim(ylim)
        if xlim:
            ax.set_xlim(xlim)
        ax.set_title(title)
        if len(labels) > 1:
            ax.legend()

        output = io.BytesIO()
        fig.savefig(output, format=fmt, **printargs)
        return output.getvalue()
    finally:
        # release the figure even when drawing or saving failed
        plt.close(fig)
def png(plotables, **kwargs):
    """renders several Plot objects as a single png for notebook display

    :param plotables: Plot objects, handed to ``render``
    :param kwargs: forwarded unchanged to ``render``
    :return: `IPython.display.Image` embedding the rendered data
    """
    from IPython.display import Image
    raster = render(plotables, 'png', **kwargs)
    return Image(raster, embed=True)
def svg(plotables, **kwargs):
    """renders several Plot objects as a single svg for notebook display

    :param plotables: Plot objects, handed to ``render``
    :param kwargs: forwarded unchanged to ``render``
    :return: `IPython.display.SVG` built from the rendered data
    """
    from IPython.display import SVG
    vector = render(plotables, 'svg', **kwargs)
    return SVG(vector)


# module-level alias: plotting defaults to the svg output
plot = svg
def save(plotables, filename, **kwargs):
    """renders several Plot objects and writes the result to a file

    :param plotables: Plot objects, handed to ``render``
    :param filename: str, destination file name; the text after its last '.'
        (lower-cased) is used as output format
    :param kwargs: forwarded to ``render``; ``dpi`` defaults to 600
    :return: int, number of bytes written
    """
    ext = filename.split('.')[-1].lower()
    kwargs.setdefault('dpi', 600)  # force good quality
    # render before opening the file, so that a failing render
    # does not leave an empty or truncated file behind
    data = render(plotables, ext, **kwargs)
    with open(filename, 'wb') as f:  # the handle is closed even if write fails
        return f.write(data)