-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
llvm_pass_timings.py
409 lines (334 loc) · 11.4 KB
/
llvm_pass_timings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
import re
import operator
import heapq
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
from functools import cached_property
from numba.core import config
import llvmlite.binding as llvm
class RecordLLVMPassTimings:
    """Context manager that captures the LLVM pass-timing report.

    Entering the ``with`` block switches LLVM pass timing on.  Leaving it
    fetches (and resets) the accumulated report, keeps the raw report on the
    instance and switches pass timing off again.  Call ``get()`` after the
    block to obtain the report wrapped for processing.
    """

    __slots__ = ["_data"]

    def __enter__(self):
        """Switch on pass timing in LLVM and return this recorder.
        """
        llvm.set_time_passes(True)
        return self

    def __exit__(self, exc_val, exc_type, exc_tb):
        """Save the timing report internally, then switch pass timing off.

        The three exception arguments are ignored.  Note that the context
        manager protocol passes them positionally as (type, value,
        traceback), so the first two parameter names here are swapped
        relative to what they receive.  Nothing is returned, so an exception
        raised inside the ``with`` block is never suppressed.
        """
        # The report is fetched before timing is disabled.
        report = llvm.report_and_reset_timings()
        self._data = report
        llvm.set_time_passes(False)

    def get(self):
        """Wrap the saved report for lazy processing.

        Only meaningful after the ``with`` block has exited; before that no
        report has been stored on the instance.

        Returns
        -------
        timings: ProcessedPassTimings
        """
        raw_report = self._data
        return ProcessedPassTimings(raw_report)
# One row of the parsed LLVM pass-timing report.  Each timing column
# contributes a ``*_time`` / ``*_percent`` pair; the field order below is part
# of the tuple's interface and must not change.
PassTimingRecord = namedtuple(
    "PassTimingRecord",
    (
        # "User Time" column
        "user_time", "user_percent",
        # "System Time" column
        "system_time", "system_percent",
        # "User+System" column
        "user_system_time", "user_system_percent",
        # "Wall Time" column
        "wall_time", "wall_percent",
        # "Name" column
        "pass_name",
        # "Instr" column
        "instruction",
    ),
)
def _adjust_timings(records):
    """Recompute each time value from its percentage of the total.

    Details: the time values in the report are truncated, so for every
    record and every timing kind the time is replaced by
    ``total_time * percent * 0.01``, where ``total_time`` is read from the
    last record (which must be the "Total" row).

    Parameters
    ----------
    records: List[PassTimingRecord]
        Parsed records; the final one must have ``pass_name == "Total"``.

    Returns
    -------
    res: List[PassTimingRecord]
        New records, in the same order, with adjusted ``*_time`` fields.
    """
    total_rec = records[-1]
    assert total_rec.pass_name == "Total"  # guard for implementation error

    kinds = ("user", "system", "user_system", "wall")
    # Totals always come from the unmodified "Total" row, so look them up
    # once instead of once per record.
    totals = {kind: getattr(total_rec, f"{kind}_time") for kind in kinds}

    adjusted_records = []
    for rec in records:
        fields = rec._asdict()
        for kind in kinds:
            # percent x total_time = adjusted
            fields[f"{kind}_time"] = (
                totals[kind] * fields[f"{kind}_percent"] * 0.01
            )
        # Rebuild the namedtuple from the updated field values.
        adjusted_records.append(PassTimingRecord(**fields))
    return adjusted_records
class ProcessedPassTimings:
    """A class for processing the raw timing report from LLVM.

    The processing is done lazily (on first access to the records) so we
    don't waste time processing unused timing information.
    """

    def __init__(self, raw_data):
        """
        Parameters
        ----------
        raw_data: str
            The raw text of the timing report.
        """
        self._raw_data = raw_data

    def __bool__(self):
        """True when there is any raw report text at all."""
        return bool(self._raw_data)

    def get_raw_data(self):
        """Returns the raw string data.

        Returns
        -------
        res: str
        """
        return self._raw_data

    def get_total_time(self):
        """Compute the total time spent in all passes.

        This is the wall time of the last record, i.e. the "Total" row.

        Returns
        -------
        res: float
        """
        return self.list_records()[-1].wall_time

    def list_records(self):
        """Get the processed data for the timing report.

        The report is parsed on the first call and cached afterwards.

        Returns
        -------
        res: List[PassTimingRecord]
            One record per pass; the last record is the "Total" row.

        Raises
        ------
        ValueError
            If the raw data contains no column-header line, or contains
            extra text after the "Total" row.
        """
        return self._processed

    def list_top(self, n):
        """Returns the top(n) most time-consuming (by wall-time) passes.

        Parameters
        ----------
        n: int
            This limits the maximum number of items to show.
            This function will show the ``n`` most time-consuming passes.

        Returns
        -------
        res: List[PassTimingRecord]
            Returns the top(n) most time-consuming passes in descending order.
        """
        records = self.list_records()
        key = operator.attrgetter("wall_time")
        # The last record is the "Total" row; it is not a pass.
        return heapq.nlargest(n, records[:-1], key)

    def summary(self, topn=5, indent=0):
        """Return a string summarizing the timing information.

        Parameters
        ----------
        topn: int; optional
            This limits the maximum number of items to show.
            This function will show the ``topn`` most time-consuming passes.
        indent: int; optional
            Set the indentation level. Defaults to 0 for no indentation.

        Returns
        -------
        res: str
        """
        buf = []
        prefix = " " * indent

        def ap(arg):
            buf.append(f"{prefix}{arg}")

        ap(f"Total {self.get_total_time():.4f}s")
        ap("Top timings:")
        for p in self.list_top(topn):
            ap(f" {p.wall_time:.4f}s ({p.wall_percent:5}%) {p.pass_name}")
        return "\n".join(buf)

    @cached_property
    def _processed(self):
        """A cached property for lazily processing the data and returning it.

        See ``_process()`` for details.
        """
        return self._process()

    def _process(self):
        """Parses the raw string data from the LLVM timing report and attempts
        to improve the data by recomputing the times
        (See ``_adjust_timings()``).

        Returns
        -------
        res: List[PassTimingRecord]

        Raises
        ------
        ValueError
            If no column-header line is found, or if text remains after the
            "Total" row.
        """

        def parse(raw_data):
            """A generator that parses the raw_data line-by-line to extract
            timing information for each pass.
            """
            lines = raw_data.splitlines()
            colheader = r"[a-zA-Z+ ]+"
            # Take at least one column header.
            multicolheaders = fr"(?:\s*-+{colheader}-+)+"

            line_iter = iter(lines)
            # Map the report's column titles to record attribute prefixes.
            header_map = {
                "User Time": "user",
                "System Time": "system",
                "User+System": "user_system",
                "Wall Time": "wall",
                "Instr": "instruction",
                "Name": "pass_name",
            }
            # Find the column headers. ``headers`` stays None when the text
            # has no header line, so that case can be reported clearly
            # instead of failing on an unbound local.
            headers = None
            for ln in line_iter:
                if re.match(multicolheaders, ln):
                    # Get all the column headers
                    raw_headers = re.findall(r"[a-zA-Z][a-zA-Z+ ]+", ln)
                    headers = [header_map[k.strip()] for k in raw_headers]
                    break
            if headers is None:
                raise ValueError(
                    "cannot find the column headers in the pass timing report"
                )

            assert headers[-1] == 'pass_name'
            # Compute the list of available attributes from the column
            # headers. ``attrs`` must list one name per regex capture group,
            # in order, so the groups can be zipped against it below.
            attrs = []
            n = r"\s*((?:[0-9]+\.)?[0-9]+)"
            pat = ""
            for k in headers[:-1]:
                if k == "instruction":
                    # Single-number column: one capture group, one attribute.
                    attrs.append(k)
                    pat += n
                else:
                    # Timing column: "<time> (<percent>%)" or a run of dashes
                    # (in which case both groups are None).
                    attrs.append(f"{k}_time")
                    attrs.append(f"{k}_percent")
                    pat += rf"\s+(?:{n}\s*\({n}%\)|-+)"

            # put default value 0.0 to all missing attributes
            missing = {
                k: 0.0
                for k in PassTimingRecord._fields
                if k not in attrs and k != 'pass_name'
            }
            # parse timings; the trailing group captures the pass name
            pat += r"\s*(.*)"
            row_matcher = re.compile(pat)
            for ln in line_iter:
                m = row_matcher.match(ln)
                if m is not None:
                    groups = m.groups()
                    data = {k: float(v) if v is not None else 0.0
                            for k, v in zip(attrs, groups)}
                    data.update(missing)
                    rec = PassTimingRecord(pass_name=groups[-1], **data)
                    yield rec
                    if rec.pass_name == "Total":
                        # "Total" means the report has ended
                        break
            # Check that we have reached the end of the report
            remaining = '\n'.join(line_iter)
            if remaining:
                raise ValueError(
                    f"unexpected text after parser finished:\n{remaining}"
                )

        # Parse raw data
        records = list(parse(self._raw_data))
        return _adjust_timings(records)
# Entry type stored by ``PassTimingsCollection``: a label together with the
# timings recorded under that label.
NamedTimings = namedtuple("NamedTimings", ("name", "timings"))
class PassTimingsCollection(Sequence):
    """A collection of pass timings.

    This class implements the ``Sequence`` protocol for accessing the
    individual timing records.
    """

    def __init__(self, name):
        """
        Parameters
        ----------
        name: str
            Name of this collection; shown in the summary header.
        """
        self._name = name
        self._records = []

    @contextmanager
    def record(self, name):
        """Record new timings and append to this collection.

        Note: this is mainly for internal use inside the compiler pipeline.

        See also ``RecordLLVMPassTimings``

        Parameters
        ----------
        name: str
            Name for the records.
        """
        if config.LLVM_PASS_TIMINGS:
            # Recording of pass timings is enabled
            with RecordLLVMPassTimings() as timings:
                yield
            rec = timings.get()
            # Only keep non-empty records
            if rec:
                self._append(name, rec)
        else:
            # Do nothing. Recording of pass timings is disabled.
            yield

    def _append(self, name, timings):
        """Append timing records

        Parameters
        ----------
        name: str
            Name for the records.
        timings: ProcessedPassTimings
            the timing records.
        """
        self._records.append(NamedTimings(name, timings))

    def get_total_time(self):
        """Computes the sum of the total time across all contained timings.

        Returns
        -------
        res: float or None
            Returns the total number of seconds or None if no timings were
            recorded
        """
        if self._records:
            return sum(r.timings.get_total_time() for r in self._records)
        else:
            return None

    def list_longest_first(self):
        """Returns the timings in descending order of total time duration.

        Returns
        -------
        res: List[NamedTimings]
            The stored ``(name, timings)`` entries, longest total first.
        """
        return sorted(self._records,
                      key=lambda x: x.timings.get_total_time(),
                      reverse=True)

    @property
    def is_empty(self):
        """True if no timings have been recorded in this collection.
        """
        return not self._records

    def summary(self, topn=5):
        """Return a string representing the summary of the timings.

        Parameters
        ----------
        topn: int; optional, default=5.
            This limits the maximum number of items to show.
            This function will show the ``topn`` most time-consuming passes.

        Returns
        -------
        res: str

        See also ``ProcessedPassTimings.summary()``
        """
        if self.is_empty:
            return "No pass timings were recorded"
        else:
            buf = []
            ap = buf.append
            ap(f"Printing pass timings for {self._name}")
            overall_time = self.get_total_time()
            ap(f"Total time: {overall_time:.4f}")
            for i, r in enumerate(self._records):
                ap(f"== #{i} {r.name}")
                part_time = r.timings.get_total_time()
                # If every recorded total is zero there is nothing to
                # apportion; report 0% rather than dividing by zero.
                if overall_time:
                    percent = part_time / overall_time * 100
                else:
                    percent = 0.0
                ap(f" Percent: {percent:.1f}%")
                ap(r.timings.summary(topn=topn, indent=1))
            return "\n".join(buf)

    def __getitem__(self, i):
        """Get the i-th timing record.

        Returns
        -------
        res: (name, timings)
            A named tuple with two fields:

            - name: str
            - timings: ProcessedPassTimings
        """
        return self._records[i]

    def __len__(self):
        """Length of this collection.
        """
        return len(self._records)

    def __str__(self):
        """Same text as ``summary()`` with its default arguments."""
        return self.summary()