/
variable_statistics_plot.py
350 lines (291 loc) · 12.6 KB
/
variable_statistics_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
from __future__ import division
import os
import warnings
import numpy
import six
import chainer
from chainer.backends import cuda
from chainer.training import extension
from chainer.training import trigger as trigger_module
# Probe for matplotlib once at import time; plotting is silently disabled
# (``_available`` is False) when it cannot be imported.
try:
    import matplotlib
except ImportError:
    _available = False
else:
    _available = True


if _available:
    # ``matplotlib.colors.to_rgba`` only exists in newer releases; fall back
    # to the converter object that matplotlib 1.x provides.
    _to_rgba = getattr(matplotlib.colors, 'to_rgba', None)
    if _to_rgba is None:
        # For matplotlib 1.x
        _to_rgba = matplotlib.colors.ColorConverter().to_rgba

    # Base line color (the default "C0" blue) as an RGBA tuple.
    _plot_color = _to_rgba('#1f77b4')
    # Same RGB components with the alpha channel replaced by 0.2.
    _plot_color_trans = _plot_color[:3] + (0.2,)
    # Keyword arguments shared by the filled percentile bands.
    _plot_common_kwargs = dict(
        alpha=0.2, linewidth=0, color=_plot_color_trans)
def _check_available():
    """Emit a warning when matplotlib could not be imported.

    Does nothing when the module-level ``_available`` flag is true.
    """
    if _available:
        return

    message = ('matplotlib is not installed on your environment, '
               'so nothing will be plotted at this time. '
               'Please install matplotlib to plot figures.\n\n'
               '  $ pip install matplotlib\n')
    warnings.warn(message)
def _unpack_variables(x, memo=None):
    """Flatten ``x`` into a tuple of variables.

    Args:
        x: A ``chainer.Variable``, a ``chainer.Link`` or a (possibly nested)
            list/tuple of those. Anything else contributes nothing.
        memo: Optional sequence of already collected variables that the
            result is appended to. Defaults to an empty tuple.

    Returns:
        ``memo`` extended with the variable itself, with all parameters of
        the link (uninitialized ones included), or with the variables found
        recursively in each element of the list/tuple.
    """
    collected = () if memo is None else memo

    if isinstance(x, chainer.Variable):
        collected += (x,)
        return collected

    if isinstance(x, chainer.Link):
        collected += tuple(x.params(include_uninit=True))
        return collected

    if isinstance(x, (list, tuple)):
        for element in x:
            # Each element is unpacked from scratch and then appended.
            collected += _unpack_variables(element)

    return collected
class Reservoir(object):

    """Reservoir sample with a fixed sized buffer.

    Items are stored in a preallocated buffer. Once the buffer is full, each
    new item replaces a randomly chosen stored item with probability
    ``size / (number of items seen so far, including the new one)``.

    Args:
        size (int): Maximum number of items kept in the buffer.
        data_shape (tuple of int): Shape of a single item.
        dtype: NumPy dtype of the data buffer.
    """

    def __init__(self, size, data_shape, dtype='f'):
        self.size = size
        self.data = numpy.zeros((size,) + data_shape, dtype=dtype)
        self.idxs = numpy.zeros((size,), dtype='i')
        # Total number of items offered to ``add`` so far.
        self.counter = 0

    def add(self, x, idx=None):
        """Offer an item to the reservoir.

        Args:
            x: Item to store; must be assignable to one buffer slot.
            idx (int or None): Index associated with the item. When ``None``
                the number of previously added items is used.
        """
        # Compare against ``None`` explicitly: ``idx or self.counter`` would
        # discard a legitimate index of 0.
        if idx is None:
            idx = self.counter

        if self.counter < self.size:
            # Buffer not full yet; fill the next free slot.
            self.data[self.counter] = x
            self.idxs[self.counter] = idx
        elif numpy.random.random() < self.size / float(self.counter + 1):
            # Buffer full; overwrite a uniformly chosen slot.
            i = numpy.random.randint(self.size)
            self.data[i] = x
            self.idxs[i] = idx

        self.counter += 1

    def get_data(self):
        """Return the stored indices and items, sorted by index.

        Returns:
            tuple: ``(idxs, data)`` restricted to the filled part of the
            buffer and ordered by ascending index.
        """
        n_filled = min(self.counter, self.size)
        idxs = self.idxs[:n_filled]
        sorted_args = numpy.argsort(idxs)
        return idxs[sorted_args], self.data[sorted_args]
class Statistician(object):

    """Helper to compute basic NumPy-like statistics.

    Args:
        collect_mean (bool): Whether ``collect`` reports the mean.
        collect_std (bool): Whether ``collect`` reports the standard
            deviation.
        percentile_sigmas: Percentiles passed to ``numpy.percentile``. A
            falsy value disables percentile computation.
    """

    def __init__(self, collect_mean, collect_std, percentile_sigmas):
        self.collect_mean = collect_mean
        self.collect_std = collect_std
        self.percentile_sigmas = percentile_sigmas

    def __call__(self, x, axis=0, dtype=None, xp=None):
        """Normalize ``axis`` and delegate to :meth:`collect`.

        ``dtype`` and ``xp`` are accepted but not used.
        """
        if axis is None:
            # Reduce over every axis of the input.
            reduce_axes = tuple(range(x.ndim))
        elif isinstance(axis, (tuple, list)):
            reduce_axes = axis
        else:
            reduce_axes = (axis,)
        return self.collect(x, reduce_axes)

    def collect(self, x, axis):
        """Compute the enabled statistics of ``x`` along ``axis``.

        Returns:
            dict: Contains ``'mean'``, ``'std'`` and ``'percentile'`` entries
            for the statistics that are enabled.
        """
        stats = {}

        if self.collect_mean:
            stats['mean'] = x.mean(axis=axis)

        if self.collect_std:
            stats['std'] = x.std(axis=axis)

        if not self.percentile_sigmas:
            return stats

        on_host = cuda.get_array_module(x) is numpy
        # TODO(hvy): Use percentile from CuPy once it is supported
        # Until then, non-NumPy arrays are copied to the host, reduced with
        # NumPy and the result is sent back to the device.
        host_x = x if on_host else cuda.to_cpu(x)
        percentiles = numpy.percentile(
            host_x, self.percentile_sigmas, axis=axis)
        stats['percentile'] = (
            percentiles if on_host else cuda.to_gpu(percentiles))
        return stats
class VariableStatisticsPlot(extension.Extension):

    r"""Trainer extension to plot statistics for :class:`Variable`\s.

    This extension collects statistics for a single :class:`Variable`, a list
    of :class:`Variable`\s or similarly a single or a list of
    :class:`Link`\s containing one or more :class:`Variable`\s. In case
    multiple :class:`Variable`\s are found, the means are computed. The
    collected statistics are plotted and saved as an image in the directory
    specified by the :class:`Trainer`.

    Statistics include mean, standard deviation and percentiles.

    This extension uses reservoir sampling to preserve memory, using a fixed
    size running sample. This means that collected items in the sample are
    discarded uniformly at random when the number of items becomes larger
    than the maximum sample size, but each item is expected to occur in the
    sample with equal probability.

    Args:
        targets (:class:`Variable`, :class:`Link` or list of either):
            Parameters for which statistics are collected.
        max_sample_size (int):
            Maximum number of running samples.
        report_data (bool):
            If ``True``, data (e.g. weights) statistics are plotted. If
            ``False``, they are neither computed nor plotted.
        report_grad (bool):
            If ``True``, gradient statistics are plotted. If ``False``, they
            are neither computed nor plotted.
        plot_mean (bool):
            If ``True``, means are plotted. If ``False``, they are
            neither computed nor plotted.
        plot_std (bool):
            If ``True``, standard deviations are plotted. If ``False``, they
            are neither computed nor plotted.
        percentile_sigmas (float or tuple of floats):
            Percentiles to plot in the range :math:`[0, 100]`.
        trigger:
            Trigger that decides when to save the plots as an image. This is
            distinct from the trigger of this extension itself. If it is a
            tuple in the form ``<int>, 'epoch'`` or ``<int>, 'iteration'``, it
            is passed to :class:`IntervalTrigger`.
        file_name (str):
            Name of the output image file under the output directory.
        figsize (tuple of int):
            Matplotlib ``figsize`` argument that specifies the size of the
            output image.
        marker (str):
            Matplotlib ``marker`` argument that specifies the marker style of
            the plots.
        grid (bool):
            Matplotlib ``grid`` argument that specifies whether grids are
            rendered in the plots or not.

    Raises:
        ValueError: If ``file_name`` is ``None``, if no variable is found in
            ``targets``, or if mean, standard deviation and percentiles are
            all disabled.

    """

    def __init__(self, targets, max_sample_size=1000,
                 report_data=True, report_grad=True,
                 plot_mean=True, plot_std=True,
                 percentile_sigmas=(
                     0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100),
                 trigger=(1, 'epoch'), file_name='statistics.png',
                 figsize=None, marker=None, grid=True):

        if file_name is None:
            raise ValueError('Missing output file name of statstics plot')

        self._vars = _unpack_variables(targets)
        if len(self._vars) == 0:
            raise ValueError(
                'Need at least one variables for which to collect statistics.'
                '\nActual: 0 <= 0')

        if not any((plot_mean, plot_std, bool(percentile_sigmas))):
            raise ValueError('Nothing to plot')

        # Attribute names read from each variable; one plot column per key.
        self._keys = []
        if report_data:
            self._keys.append('data')
        if report_grad:
            self._keys.append('grad')

        self._report_data = report_data
        self._report_grad = report_grad

        self._statistician = Statistician(
            collect_mean=plot_mean, collect_std=plot_std,
            percentile_sigmas=percentile_sigmas)

        self._plot_mean = plot_mean
        self._plot_std = plot_std
        self._plot_percentile = bool(percentile_sigmas)

        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        self._figsize = figsize
        self._marker = marker
        self._grid = grid

        if not self._plot_percentile:
            n_percentile = 0
        else:
            if not isinstance(percentile_sigmas, (list, tuple)):
                n_percentile = 1  # scalar, single percentile
            else:
                n_percentile = len(percentile_sigmas)
        # One row per key; columns are [mean][std][percentiles...] for the
        # statistics that are enabled, in that order.
        self._data_shape = (
            len(self._keys), int(plot_mean) + int(plot_std) + n_percentile)
        self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)

    @staticmethod
    def available():
        """Return whether matplotlib is importable, warning if it is not."""
        _check_available()
        return _available

    def __call__(self, trainer):
        """Sample the current statistics and save the plot when triggered.

        Does nothing when matplotlib is not available.
        """
        if _available:
            # Dynamically import pyplot to call matplotlib.use()
            # after importing chainer.training.extensions
            import matplotlib.pyplot as plt
        else:
            return

        xp = cuda.get_array_module(self._vars[0].data)
        # Rows for which no variable provides the attribute stay at zero.
        stats = xp.zeros(self._data_shape, dtype=xp.float32)
        for i, k in enumerate(self._keys):
            xs = []
            for var in self._vars:
                x = getattr(var, k, None)
                if x is not None:
                    xs.append(x.ravel())
            if len(xs) > 0:
                # Statistics are taken over all elements of all variables.
                stat_dict = self._statistician(
                    xp.concatenate(xs, axis=0), axis=0, xp=xp)
                stat_list = []
                if self._plot_mean:
                    stat_list.append(xp.atleast_1d(stat_dict['mean']))
                if self._plot_std:
                    stat_list.append(xp.atleast_1d(stat_dict['std']))
                if self._plot_percentile:
                    stat_list.append(xp.atleast_1d(stat_dict['percentile']))
                stats[i] = xp.concatenate(stat_list, axis=0)

        # The reservoir buffer is a NumPy array, so bring the sample to the
        # host first. Modules are compared by identity.
        if xp is not numpy:
            stats = cuda.to_cpu(stats)

        self._samples.add(stats, idx=trainer.updater.iteration)

        if self._trigger(trainer):
            file_path = os.path.join(trainer.out, self._file_name)
            self.save_plot_using_module(file_path, plt)

    def save_plot_using_module(self, file_path, plt):
        """Render the sampled statistics and save them to ``file_path``.

        Args:
            file_path (str): Destination path of the image.
            plt: The ``matplotlib.pyplot`` module (or a compatible object)
                used to create, and afterwards close, the figure.
        """
        # Mean and std share one row; percentiles get a row of their own.
        nrows = int(self._plot_mean or self._plot_std) \
            + int(self._plot_percentile)
        ncols = len(self._keys)

        fig, axes = plt.subplots(
            nrows, ncols, figsize=self._figsize, sharex=True)

        # Normalize ``axes`` to a 2-D array indexed as ``axes[row, col]``.
        if not isinstance(axes, numpy.ndarray):  # single subplot
            axes = numpy.asarray([axes])
        if nrows == 1:
            axes = axes[None, :]
        elif ncols == 1:
            axes = axes[:, None]
        assert axes.ndim == 2

        idxs, data = self._samples.get_data()

        # Offset to access percentile data from `data`
        offset = int(self._plot_mean) + int(self._plot_std)
        n_percentile = data.shape[-1] - offset
        n_percentile_mid_floor = n_percentile // 2
        n_percentile_odd = n_percentile % 2 == 1

        for col in six.moves.range(ncols):
            row = 0
            ax = axes[row, col]
            ax.set_title(self._keys[col])  # `data` or `grad`

            if self._plot_mean or self._plot_std:
                if self._plot_mean and self._plot_std:
                    ax.errorbar(
                        idxs, data[:, col, 0], data[:, col, 1],
                        color=_plot_color, ecolor=_plot_color_trans,
                        label='mean, std', marker=self._marker)
                else:
                    # Only one of the two is stored, at position 0.
                    if self._plot_mean:
                        label = 'mean'
                    elif self._plot_std:
                        label = 'std'
                    ax.plot(
                        idxs, data[:, col, 0], color=_plot_color, label=label,
                        marker=self._marker)
                row += 1

            if self._plot_percentile:
                ax = axes[row, col]
                # Percentiles are paired from the outside in: the i-th one
                # with the i-th from the end, each pair drawn as a band.
                for i in six.moves.range(n_percentile_mid_floor + 1):
                    if n_percentile_odd and i == n_percentile_mid_floor:
                        # Enters at most once per sub-plot, in case there is
                        # only a single percentile to plot or when this
                        # percentile is the mid percentile and the number of
                        # percentiles are odd
                        ax.plot(
                            idxs, data[:, col, offset + i], color=_plot_color,
                            label='percentile', marker=self._marker)
                    else:
                        if i == n_percentile_mid_floor:
                            # Last percentiles and the number of all
                            # percentiles are even
                            label = 'percentile'
                        else:
                            label = '_nolegend_'
                        ax.fill_between(
                            idxs,
                            data[:, col, offset + i],
                            data[:, col, -i - 1],
                            label=label,
                            **_plot_common_kwargs)
                ax.set_xlabel('iteration')

        for ax in axes.ravel():
            ax.legend()
            if self._grid:
                ax.grid()
                ax.set_axisbelow(True)

        fig.savefig(file_path)
        # Close exactly the figure created above rather than whichever
        # figure happens to be current.
        plt.close(fig)