-
Notifications
You must be signed in to change notification settings - Fork 3
/
plot.py
346 lines (298 loc) · 14.3 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import logging
from typing import Sequence, Callable, TypeVar, Tuple, Optional, List, Any
import matplotlib.figure
import matplotlib.ticker as plticker
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# module-level logger, named after this module
log = logging.getLogger(__name__)
# NOTE(review): presumably mirrors matplotlib's default figure size as (width, height) in inches - confirm against
# rcParams["figure.figsize"]
MATPLOTLIB_DEFAULT_FIGURE_SIZE = (6.4, 4.8)
class Color:
    """
    Represents a colour as an RGBA tuple (member `rgba`) and provides methods for deriving modified variants.
    All modifying methods return a new instance and leave the instance they are called on unchanged.
    """
    def __init__(self, c: Any):
        """
        :param c: any color specification that is understood by matplotlib
        """
        self.rgba = matplotlib.colors.to_rgba(c)

    def _with_lightness(self, transform) -> "Color":
        """
        Creates a new colour by transforming the lightness component in HLS space; hue, saturation and alpha are retained.

        :param transform: a function mapping the current lightness to the new lightness
        :return: the new colour
        """
        import colorsys
        rgb = matplotlib.colors.to_rgb(self.rgba)
        hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
        rgb = colorsys.hls_to_rgb(hue, transform(lightness), saturation)
        return Color((*rgb, self.rgba[3]))

    def darken(self, amount: float) -> "Color":
        """
        :param amount: amount to darken in [0,1], where 1 results in black and 0 leaves the color unchanged
        :return: the darkened color
        """
        # move the lightness towards 0 (black) by the given fraction, mirroring what `lighten` does towards 1 (white)
        return self._with_lightness(lambda lightness: lightness * (1 - amount))

    def lighten(self, amount: float) -> "Color":
        """
        :param amount: amount to lighten in [0,1], where 1 results in white and 0 leaves the color unchanged
        :return: the lightened color
        """
        # move the lightness towards 1 (white) by the given fraction of the remaining distance
        return self._with_lightness(lambda lightness: lightness + (1 - lightness) * amount)

    def alpha(self, opacity: float) -> "Color":
        """
        Returns a new color with modified alpha channel (opacity)

        :param opacity: the opacity between 0 (transparent) and 1 (fully opaque)
        :return: the modified color
        :raises ValueError: if the opacity is outside of [0, 1]
        """
        if not (0 <= opacity <= 1):
            raise ValueError(f"Opacity must be between 0 and 1, got {opacity}")
        return Color((*self.rgba[:3], opacity))

    def to_hex(self, keep_alpha=True) -> str:
        """
        :param keep_alpha: whether to include the alpha channel in the result
        :return: the hex string representation of the colour as generated by matplotlib
        """
        return matplotlib.colors.to_hex(self.rgba, keep_alpha)
class LinearColorMap:
    """
    Facilitates usage of linear segmented colour maps by combining a colour map (member `cmap`), which transforms
    normalised values in [0,1] into colours, with a normaliser (member `norm`) that transforms the original values
    into [0,1]. The member `scalarMapper` combines the two, i.e. it maps original values directly to colours.
    """
    def __init__(self, norm_min, norm_max, cmap_points: List[Tuple[float, Any]], cmap_points_normalised=False):
        """
        :param norm_min: the value that shall be mapped to 0 in the normalised representation (any smaller values are also clipped to 0)
        :param norm_max: the value that shall be mapped to 1 in the normalised representation (any larger values are also clipped to 1)
        :param cmap_points: a list (of at least two) tuples (v, c) where v is the value and c is the colour associated with the value;
            any colour specification supported by matplotlib is admissible
        :param cmap_points_normalised: whether the values in `cmap_points` are already normalised
        """
        normaliser = matplotlib.colors.Normalize(vmin=norm_min, vmax=norm_max, clip=True)
        self.norm = normaliser
        if cmap_points_normalised:
            colour_stops = cmap_points
        else:
            # bring the anchor values into [0,1], as required by the colour map factory
            colour_stops = [(normaliser(value), colour) for value, colour in cmap_points]
        self.cmap = LinearSegmentedColormap.from_list(f"cmap{id(self)}", colour_stops)
        self.scalarMapper = matplotlib.cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_color(self, value):
        """
        :param value: the (non-normalised) value for which to obtain the colour
        :return: the colour as a hex string of the form "#rrggbbaa" (channel values are truncated, not rounded)
        """
        channels = [int(channel * 255) for channel in self.scalarMapper.to_rgba(value)]
        return '#%02x%02x%02x%02x' % tuple(channels)
def plot_matrix(matrix: np.ndarray, title: str, xtick_labels: Sequence[str], ytick_labels: Sequence[str], xlabel: str,
        ylabel: str, normalize=True, figsize: Tuple[int, int] = (9, 9), title_add: str = None) -> matplotlib.figure.Figure:
    """
    Renders a matrix as an annotated heat map (blue colour scale with colour bar), writing each cell's value into the cell.

    :param matrix: matrix whose data to plot, where matrix[i, j] will be rendered at x=i, y=j
    :param title: the plot's title
    :param xtick_labels: the labels for the x-axis ticks
    :param ytick_labels: the labels for the y-axis ticks
    :param xlabel: the label for the x-axis
    :param ylabel: the label for the y-axis
    :param normalize: whether to normalise the matrix before plotting it (dividing each entry by the sum of all entries)
    :param figsize: an optional size of the figure to be created
    :param title_add: an optional second line to add to the title
    :return: the figure object
    """
    # the image is rendered row by row along the y-axis, so transpose in order for matrix[i, j] to appear at x=i, y=j
    cells = np.transpose(matrix)
    full_title = title if title_add is None else title + f"\n {title_add} "
    if normalize:
        cells = cells.astype('float') / cells.sum()

    fig, ax = plt.subplots(figsize=figsize)
    fig.canvas.manager.set_window_title(full_title.replace("\n", " "))

    n_rows, n_cols = cells.shape[0], cells.shape[1]
    # show a tick for every row/column and label the ticks with the respective list entries
    ax.set(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        xticklabels=xtick_labels,
        yticklabels=ytick_labels,
        title=full_title,
        xlabel=xlabel,
        ylabel=ylabel)
    image = ax.imshow(cells, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(image, ax=ax)

    # rotate the x tick labels such that longer labels do not overlap
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # determine the number format of the cell annotations
    if normalize:
        fmt = '.4f'
    elif cells.dtype.kind == 'f':
        fmt = '.2f'
    else:
        fmt = 'd'

    # annotate every cell, using white text on the dark cells (upper half of the value range)
    thresh = cells.max() / 2.
    for row, col in np.ndindex(n_rows, n_cols):
        text_color = "white" if cells[row, col] > thresh else "black"
        ax.text(col, row, format(cells[row, col], fmt), ha="center", va="center", color=text_color)

    fig.tight_layout()
    return fig
TPlot = TypeVar("TPlot", bound="Plot")


class Plot:
    """
    Base class for plots which wraps a newly created matplotlib figure (member `fig`) and its axes (member `ax`).
    Most configuration methods return the instance itself, such that calls can be chained.
    """
    def __init__(self, draw: Optional[Callable[[], None]] = None, name=None):
        """
        :param draw: an optional function which draws the plot's content; it is called without arguments after the
            figure and axes have been created (such that they are the current figure/axes), and its return value is ignored
        :param name: name/number of the figure, which determines the window caption; it should be unique, as any plot
            with the same name will have its contents rendered in the same window. By default, figures are number
            sequentially.
        """
        fig, ax = plt.subplots(num=name)
        self.fig: plt.Figure = fig
        self.ax: plt.Axes = ax
        if draw is not None:
            draw()

    def xlabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the x-axis
        :return: self
        """
        self.ax.set_xlabel(label)
        return self

    def ylabel(self: TPlot, label) -> TPlot:
        """
        :param label: the label for the y-axis
        :return: self
        """
        self.ax.set_ylabel(label)
        return self

    def title(self: TPlot, title: str) -> TPlot:
        """
        :param title: the title of the axes
        :return: self
        """
        self.ax.set_title(title)
        return self

    def xlim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the x-axis
        :param max_value: the upper limit of the x-axis
        :return: self
        """
        self.ax.set_xlim(min_value, max_value)
        return self

    def ylim(self: TPlot, min_value, max_value) -> TPlot:
        """
        :param min_value: the lower limit of the y-axis
        :param max_value: the upper limit of the y-axis
        :return: self
        """
        self.ax.set_ylim(min_value, max_value)
        return self

    def save(self, path):
        """
        Saves the figure to a file.

        :param path: the path of the file to write (passed on to the figure's `savefig` method)
        """
        log.info("Saving figure in %s", path)
        self.fig.savefig(path)

    def xtick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on every integer multiple of the given base values.
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.xtick_major(major)
        if minor is not None:
            self.xtick_minor(minor)
        return self

    def xtick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a major x-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.xaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def xtick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a minor x-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.xaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick(self: TPlot, major=None, minor=None) -> TPlot:
        """
        Sets a tick on the y-axis on every integer multiple of the given base values (y-axis counterpart of `xtick`).
        The major ticks are labelled, the minor ticks are not.

        :param major: the major tick base value
        :param minor: the minor tick base value
        :return: self
        """
        if major is not None:
            self.ytick_major(major)
        if minor is not None:
            self.ytick_minor(minor)
        return self

    def ytick_major(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a major y-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.yaxis.set_major_locator(plticker.MultipleLocator(base=base))
        return self

    def ytick_minor(self: TPlot, base) -> TPlot:
        """
        :param base: the base value; a minor y-axis tick is placed on every integer multiple of it
        :return: self
        """
        self.ax.yaxis.set_minor_locator(plticker.MultipleLocator(base=base))
        return self
class ScatterPlot(Plot):
    """
    A scatter plot whose point opacity is, by default, chosen based on the number of data points, such that
    plots with many points use more transparent points.
    """
    # number of data points from which on the minimum opacity is applied
    N_MAX_TRANSPARENCY = 1000
    # number of data points up to which the maximum opacity is applied
    N_MIN_TRANSPARENCY = 100
    MAX_OPACITY = 0.5
    MIN_OPACITY = 0.05

    def __init__(self, x, y, c=None, c_base: Optional[Tuple[float, float, float]] = (0, 0, 1), c_opacity=None, x_label=None,
            y_label=None, **kwargs):
        """
        :param x: the x values; if has name (e.g. pd.Series), will be used as axis label
        :param y: the y values; if has name (e.g. pd.Series), will be used as axis label
        :param c: the colour specification; if None, compose from ``c_base`` and ``c_opacity``
        :param c_base: the base colour as (R, G, B) floats; if None, use blue
        :param c_opacity: the opacity; if None, automatically determine from number of data points
        :param x_label: the x-axis label; if None, use the name of ``x`` if it has one
        :param y_label: the y-axis label; if None, use the name of ``y`` if it has one
        :param kwargs: parameters to pass on to plt.scatter
        """
        if c is None:
            if c_base is None:
                c_base = (0, 0, 1)
            if c_opacity is None:
                c_opacity = self._auto_opacity(len(x))
            # a single RGBA colour, wrapped in a sequence, which is applied to all points
            c = ((*c_base, c_opacity),)

        assert len(x) == len(y)
        if x_label is None and hasattr(x, "name"):
            x_label = x.name
        if y_label is None and hasattr(y, "name"):
            y_label = y.name

        def draw():
            if x_label is not None:
                plt.xlabel(x_label)
            # the y label must be applied depending on its own presence (not on the presence of the x label)
            if y_label is not None:
                plt.ylabel(y_label)
            plt.scatter(x, y, c=c, **kwargs)

        super().__init__(draw)

    @classmethod
    def _auto_opacity(cls, n: int) -> float:
        """
        Determines the point opacity from the number of data points by interpolating linearly between
        MAX_OPACITY (for at most N_MIN_TRANSPARENCY points) and MIN_OPACITY (for at least N_MAX_TRANSPARENCY points).

        :param n: the number of data points
        :return: the opacity
        """
        if n > cls.N_MAX_TRANSPARENCY:
            transparency = 1
        elif n < cls.N_MIN_TRANSPARENCY:
            transparency = 0
        else:
            transparency = (n - cls.N_MIN_TRANSPARENCY) / (cls.N_MAX_TRANSPARENCY - cls.N_MIN_TRANSPARENCY)
        return cls.MIN_OPACITY + (cls.MAX_OPACITY - cls.MIN_OPACITY) * (1 - transparency)
class HeatMapPlot(Plot):
    """
    A two-dimensional histogram of (x, y) value pairs rendered as a heat map.
    """
    # creates the default colour map (white to red) with the given number of colour levels, where the second anchor
    # point at 1/num_points is already slightly tinted
    DEFAULT_CMAP_FACTORY = lambda num_points: LinearSegmentedColormap.from_list(
        "whiteToRed",
        ((0, (1, 1, 1)), (1 / num_points, (1, 0.96, 0.96)), (1, (0.7, 0, 0))),
        num_points)

    def __init__(self, x, y, x_label=None, y_label=None, bins=60, cmap=None, common_range=True, diagonal=False,
            diagonal_color="green", **kwargs):
        """
        :param x: the x values
        :param y: the y values
        :param x_label: the x-axis label
        :param y_label: the y-axis label
        :param bins: the number of bins to use in each dimension
        :param cmap: the colour map to use for heat values (if None, use default)
        :param common_range: whether the heat map is to use a common range for the x- and y-axes (set to False if x and y are
            completely different quantities; set to True for use cases such as the evaluation of regression model quality)
        :param diagonal: whether to draw the diagonal line (useful for regression evaluation)
        :param diagonal_color: the colour to use for the diagonal line
        :param kwargs: parameters to pass on to plt.imshow
        """
        assert len(x) == len(y)
        # fall back to the names of the data objects (if any) for axis labels that were not given
        if x_label is None:
            x_label = getattr(x, "name", None)
        if y_label is None:
            y_label = getattr(y, "name", None)

        def draw():
            x_lo, x_hi = min(x), max(x)
            y_lo, y_hi = min(y), max(y)
            # the range spanning both the x and the y values
            joint_bounds = [min(x_lo, y_lo), max(x_hi, y_hi)]
            if common_range:
                x_bounds = y_bounds = joint_bounds
            else:
                x_bounds, y_bounds = [x_lo, x_hi], [y_lo, y_hi]

            if diagonal:
                plt.plot(joint_bounds, joint_bounds, '-', lw=0.75, label="_not in legend", color=diagonal_color, zorder=2)

            counts, _, _ = np.histogram2d(x, y, range=[x_bounds, y_bounds], bins=bins, density=False)
            colour_map = cmap if cmap is not None else HeatMapPlot.DEFAULT_CMAP_FACTORY(len(x))

            if x_label is not None:
                plt.xlabel(x_label)
            if y_label is not None:
                plt.ylabel(y_label)
            # the histogram's first dimension is x, whereas the image's first dimension is y, so transpose
            plt.imshow(counts.T, extent=[*x_bounds, *y_bounds], origin='lower', interpolation="none", cmap=colour_map,
                zorder=1, aspect="auto", **kwargs)

        super().__init__(draw)
class HistogramPlot(Plot):
    """
    A histogram of values, which can optionally be combined with a kernel density estimate and a plot of the
    (complementary) cumulative distribution function.
    """
    def __init__(self, values, bins="auto", kde=False, cdf=False, cdf_complementary=False, cdf_secondary_axis=True,
            binwidth=None, stat="probability", xlabel=None,
            **kwargs):
        """
        :param values: the values to plot
        :param bins: a bin specification as understood by sns.histplot
        :param kde: whether to add a kernel density estimator
        :param cdf: whether to add a plot of the cumulative distribution function (cdf)
        :param cdf_complementary: whether to plot, if cdf is enabled, the complementary values
        :param cdf_secondary_axis: whether to use, if cdf is enabled, a secondary y-axis for the cdf.
            A secondary axis is always used if `stat` is not one of "count", "proportion" and "probability",
            because the cdf is then plotted as a proportion, which does not share the histogram's scale.
        :param binwidth: the bin width; if None, inferred
        :param stat: the statistic to plot (as understood by sns.histplot)
        :param xlabel: the label for the x-axis
        :param kwargs: arguments to pass on to sns.histplot
        """
        def draw():
            sns.histplot(values, bins=bins, kde=kde, binwidth=binwidth, stat=stat, **kwargs)
            plt.ylabel(stat)

            if cdf:
                use_secondary_axis = cdf_secondary_axis
                cdf_stat = stat
                if cdf_stat not in ("count", "proportion", "probability"):
                    # for any other statistic (e.g. "density"), fall back to a proportion-based cdf, which requires its own axis
                    cdf_stat = "proportion"
                    use_secondary_axis = True

                cdf_ax: Optional[plt.Axes] = None
                if use_secondary_axis:
                    cdf_ax = plt.twinx()
                    if cdf_stat != "count":
                        # the cdf ranges over [0, 1], so use a tick for every tenth
                        cdf_ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.1))

                # always pass the sanitised statistic on (never the raw `stat`, which ecdfplot may not support);
                # "probability" has the same semantics as "proportion" but is not understood by ecdfplot
                ecdf_stat = "proportion" if cdf_stat == "probability" else cdf_stat
                # if cdf_ax is None, seaborn draws on the current axes, i.e. the cdf shares the histogram's axes
                sns.ecdfplot(values, stat=ecdf_stat, complementary=cdf_complementary, color="orange", ax=cdf_ax)

                if cdf_ax is not None:
                    cdf_ax.set_ylabel(f"{cdf_stat} (cdf)")

            if xlabel is not None:
                self.xlabel(xlabel)

        super().__init__(draw)