-
-
Notifications
You must be signed in to change notification settings - Fork 7.4k
/
category.py
233 lines (194 loc) · 7.14 KB
/
category.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Plotting of string "category" data: ``plot(['d', 'f', 'a'], [1, 2, 3])`` will
plot three points with x-axis values of 'd', 'f', 'a'.
See :doc:`/gallery/lines_bars_and_markers/categorical_variables` for an
example.
The module uses Matplotlib's `matplotlib.units` mechanism to convert from
strings to integers and provides a tick locator, a tick formatter, and the
`.UnitData` class that creates and stores the string-to-integer mapping.
"""
from collections import OrderedDict
import dateutil.parser
import itertools
import logging
import numpy as np
from matplotlib import _api, ticker, units
# Module-level logger named after this module; UnitData.update uses it to emit
# an informational message when every plotted string parses as a float or date.
_log = logging.getLogger(__name__)
class StrCategoryConverter(units.ConversionInterface):
    """Unit converter that maps category strings to numbers via `.UnitData`."""

    @staticmethod
    def convert(value, unit, axis):
        """
        Translate the strings in *value* into floats by looking them up in
        the mapping held by *unit*.

        Parameters
        ----------
        value : str or iterable
            A single value or a collection of values to translate.
        unit : `.UnitData`
            Object holding the string-to-integer mapping.
        axis : `~matplotlib.axis.Axis`
            The axis the converted values are drawn on.

            .. note:: *axis* is unused.

        Returns
        -------
        float or `~numpy.ndarray` of float

        Raises
        ------
        ValueError
            If *unit* is None or has no ``_mapping`` attribute.
        """
        if unit is None:
            raise ValueError(
                'Missing category information for StrCategoryConverter; '
                'this might be caused by unintendedly mixing categorical and '
                'numeric data')
        StrCategoryConverter._validate_unit(unit)

        # An object array keeps every element as the Python object it was
        # given (numerical pass throughs are preserved).
        categories = np.atleast_1d(np.array(value, dtype=object))

        # Updating first registers unseen categories and, as a side effect,
        # type checks the input.
        unit.update(categories)

        lookup = np.vectorize(unit._mapping.__getitem__, otypes=[float])
        return lookup(categories)

    @staticmethod
    def axisinfo(unit, axis):
        """
        Build the default tick locator and formatter for a category axis.

        Parameters
        ----------
        unit : `.UnitData`
            Object holding the string-to-integer mapping.
        axis : `~matplotlib.axis.Axis`
            Axis the information is requested for.

            .. note:: *axis* is not used

        Returns
        -------
        `~matplotlib.units.AxisInfo`
            Carries the major locator and major formatter.

        Raises
        ------
        ValueError
            If *unit* has no ``_mapping`` attribute.
        """
        StrCategoryConverter._validate_unit(unit)
        # The mapping dict itself is handed over (not a copy) so that the
        # locator and formatter observe later additions to it.
        return units.AxisInfo(
            majloc=StrCategoryLocator(unit._mapping),
            majfmt=StrCategoryFormatter(unit._mapping),
        )

    @staticmethod
    def default_units(data, axis):
        """
        Create the `~matplotlib.axis.Axis` units, or extend existing ones.

        Parameters
        ----------
        data : str or iterable of str
        axis : `~matplotlib.axis.Axis`
            Axis on which the data is plotted.

        Returns
        -------
        `.UnitData`
            The units stored on *axis* after the call.
        """
        # the conversion call stack is default_units -> axis_info -> convert
        if axis.units is not None:
            axis.units.update(data)
        else:
            axis.set_units(UnitData(data))
        return axis.units

    @staticmethod
    def _validate_unit(unit):
        """Raise ValueError unless *unit* exposes a ``_mapping`` attribute."""
        if hasattr(unit, '_mapping'):
            return
        raise ValueError(
            f'Provided unit "{unit}" is not valid for a categorical '
            'converter, as it does not have a _mapping attribute.')
class StrCategoryLocator(ticker.Locator):
    """Place a tick at every integer the string data is mapped to."""

    def __init__(self, units_mapping):
        """
        Parameters
        ----------
        units_mapping : dict
            Category names (str) mapped to their indices (int).
        """
        self._units = units_mapping

    def __call__(self):
        # docstring inherited
        positions = self._units.values()
        return list(positions)

    def tick_values(self, vmin, vmax):
        # docstring inherited
        # The view limits are ignored: every mapped index is a tick.
        return self()
class StrCategoryFormatter(ticker.Formatter):
    """Label every tick with the string its position stands for."""

    def __init__(self, units_mapping):
        """
        Parameters
        ----------
        units_mapping : dict
            Category names (str) mapped to their indices (int).
        """
        self._units = units_mapping

    def __call__(self, x, pos=None):
        # docstring inherited
        labels = self.format_ticks([x])
        return labels[0]

    def format_ticks(self, values):
        # docstring inherited
        # Invert the mapping: index -> printable category name.
        label_by_index = {}
        for category, index in self._units.items():
            label_by_index[index] = self._text(category)

        # Positions are rounded to the nearest index; positions that do not
        # correspond to any category get an empty label.
        labels = []
        for position in values:
            labels.append(label_by_index.get(round(position), ''))
        return labels

    @staticmethod
    def _text(value):
        """
        Return *value* as a `str`: bytes are decoded as UTF-8 and any other
        non-str object is passed through ``str()``.
        """
        if isinstance(value, bytes):
            return value.decode(encoding='utf-8')
        if isinstance(value, str):
            return value
        return str(value)
class UnitData:
    """
    Mapping between unique categorical values and integer ids.

    Ids are handed out by a counter starting at 0, in the order in which
    values are first seen, and stored in the ``_mapping`` attribute.
    """

    def __init__(self, data=None):
        """
        Create mapping between unique categorical values and integer ids.

        Parameters
        ----------
        data : iterable, optional
            Sequence of string values. If omitted, the mapping starts empty.
        """
        self._mapping = OrderedDict()
        self._counter = itertools.count()
        if data is not None:
            self.update(data)

    @staticmethod
    def _str_is_convertible(val):
        """
        Helper method to check whether a string can be parsed as float or date.

        Returns
        -------
        bool
            True if ``float(val)`` succeeds or, failing that, if
            ``dateutil.parser.parse(val)`` succeeds; False otherwise.
        """
        try:
            float(val)
        except ValueError:
            try:
                dateutil.parser.parse(val)
            except (ValueError, TypeError, OverflowError):
                # TypeError if dateutil >= 2.8.1 else ValueError.
                # OverflowError is documented by dateutil for dates that
                # exceed the largest valid C integer; such a string is not
                # a usable date, so report it as not convertible instead of
                # letting the error escape.
                return False
        return True

    def update(self, data):
        """
        Map new values to integer identifiers.

        Values already present keep their identifier; unseen values receive
        the next value of the internal counter.

        Parameters
        ----------
        data : iterable of str or bytes

        Raises
        ------
        TypeError
            If elements in *data* are neither str nor bytes.
        """
        data = np.atleast_1d(np.array(data, dtype=object))
        # check if convertible to number:
        convertible = True
        for val in OrderedDict.fromkeys(data):
            # OrderedDict just iterates over unique values in data.
            _api.check_isinstance((str, bytes), value=val)
            if convertible:
                # this will only be called so long as convertible is True.
                convertible = self._str_is_convertible(val)
            if val not in self._mapping:
                self._mapping[val] = next(self._counter)
        # Only hint at a possible mistake when there was data and every
        # unique value looked like a number or a date.
        if data.size and convertible:
            _log.info('Using categorical units to plot a list of strings '
                      'that are all parsable as floats or dates. If these '
                      'strings should be plotted as numbers, cast to the '
                      'appropriate data type before plotting.')
# Register the converter with Matplotlib's unit framework.
# The built-in ``str`` and ``bytes`` types and their NumPy scalar counterparts
# are each keyed to their own, separately constructed StrCategoryConverter.
units.registry[str] = StrCategoryConverter()
units.registry[np.str_] = StrCategoryConverter()
units.registry[bytes] = StrCategoryConverter()
units.registry[np.bytes_] = StrCategoryConverter()