import re
import operator
import heapq
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager

from numba.core.utils import cached_property
from numba.core import config

import llvmlite.binding as llvm
class RecordLLVMPassTimings:
    """A helper context manager to track LLVM pass timings.

    Entering the context switches LLVM pass timing on; leaving it captures
    the accumulated report and switches timing off again. Call ``get()``
    after the ``with`` block to obtain the processed data.
    """

    __slots__ = ["_data"]

    def __enter__(self):
        """Enables the pass timing in LLVM.
        """
        llvm.set_time_passes(True)
        return self

    def __exit__(self, exc_val, exc_type, exc_tb):
        """Reset timings and save report internally.

        Returns None, so any exception raised inside the ``with`` body
        propagates unchanged.
        """
        # NOTE: the parameter names do not follow the usual
        # (exc_type, exc_value, traceback) order; they are unused here and
        # the ``with`` statement always passes them positionally.
        try:
            self._data = llvm.report_and_reset_timings()
        finally:
            # Switch timing off even if producing the report fails, so LLVM
            # is not left with pass timing enabled.
            llvm.set_time_passes(False)

    def get(self):
        """Retrieve timing data for processing.

        Returns
        -------
        timings: ProcessedPassTimings
        """
        return ProcessedPassTimings(self._data)
# One row of an LLVM pass-timing report. Each measured kind (user, system,
# user+system, wall) has an absolute time in seconds and a percentage of the
# report's total; ``pass_name`` is the name column ("Total" for the last row).
PassTimingRecord = namedtuple(
    "PassTimingRecord",
    [
        "user_time",
        "user_percent",
        "system_time",
        "system_percent",
        "user_system_time",
        "user_system_percent",
        "wall_time",
        "wall_percent",
        "pass_name",
    ],
)
def _adjust_timings(records):
"""Adjust timing records because of truncated information.
Details: The percent information can be used to improve the timing
res: List[PassTimingRecord]
total_rec = records[-1]
assert total_rec.pass_name == "Total" # guard for implementation error
def make_adjuster(attr):
time_attr = f"{attr}_time"
percent_attr = f"{attr}_percent"
time_getter = operator.attrgetter(time_attr)
def adjust(d):
"""Compute percent x total_time = adjusted"""
total = time_getter(total_rec)
adjusted = total * d[percent_attr] * 0.01
d[time_attr] = adjusted
return d
return adjust
# Make adjustment functions for each field
adj_fns = [
make_adjuster(x) for x in ["user", "system", "user_system", "wall"]
# Extract dictionaries from the namedtuples
dicts = map(lambda x: x._asdict(), records)
def chained(d):
# Chain the adjustment functions
for fn in adj_fns:
d = fn(d)
# Reconstruct the namedtuple
return PassTimingRecord(**d)
return list(map(chained, dicts))
class ProcessedPassTimings:
    """A class for processing raw timing report from LLVM.

    The processing is done lazily so we don't waste time processing unused
    timing information.
    """

    def __init__(self, raw_data):
        self._raw_data = raw_data

    def __bool__(self):
        return bool(self._raw_data)

    def get_raw_data(self):
        """Returns the raw string data.

        Returns
        -------
        res: str
        """
        return self._raw_data

    def get_total_time(self):
        """Compute the total time spent in all passes.

        Returns
        -------
        res: float
            The wall time of the final "Total" record.
        """
        return self.list_records()[-1].wall_time

    def list_records(self):
        """Get the processed data for the timing report.

        Returns
        -------
        res: List[PassTimingRecord]
        """
        return self._processed

    def list_top(self, n):
        """Returns the top(n) most time-consuming (by wall-time) passes.

        Parameters
        ----------
        n: int
            This limits the maximum number of items to show.
            This function will show the ``n`` most time-consuming passes.

        Returns
        -------
        res: List[PassTimingRecord]
            Returns the top(n) most time-consuming passes in descending order.
        """
        records = self.list_records()
        key = operator.attrgetter("wall_time")
        # Exclude the trailing "Total" record; it is not an individual pass.
        return heapq.nlargest(n, records[:-1], key)

    def summary(self, topn=5, indent=0):
        """Return a string summarizing the timing information.

        Parameters
        ----------
        topn: int; optional
            This limits the maximum number of items to show.
            This function will show the ``topn`` most time-consuming passes.
        indent: int; optional
            Set the indentation level. Defaults to 0 for no indentation.

        Returns
        -------
        res: str
        """
        buf = []
        prefix = " " * indent

        def ap(arg):
            # Every summary line carries the requested indentation.
            buf.append(f"{prefix}{arg}")

        ap(f"Total {self.get_total_time():.4f}s")
        ap("Top timings:")
        for p in self.list_top(topn):
            ap(f" {p.wall_time:.4f}s ({p.wall_percent:5}%) {p.pass_name}")
        return "\n".join(buf)

    @cached_property
    def _processed(self):
        """A cached property for lazily processing the data and returning it.

        See ``_process()`` for details.
        """
        return self._process()

    def _process(self):
        """Parses the raw string data from LLVM timing report and attempts
        to improve the data by recomputing the times
        (See ``_adjust_timings()``).

        Raises
        ------
        ValueError
            If the report has no column-header line, its last column is not
            "Name", or non-blank text follows the "Total" row.
        """

        def parse(raw_data):
            """A generator that parses the raw_data line-by-line to extract
            timing information for each pass.
            """
            lines = raw_data.splitlines()
            colheader = r"[a-zA-Z+ ]+"
            # Take at least one column header.
            multicolheaders = fr"(?:\s*-+{colheader}-+)+"

            line_iter = iter(lines)
            # find column headers
            header_map = {
                "User Time": "user",
                "System Time": "system",
                "User+System": "user_system",
                "Wall Time": "wall",
                "Name": "pass_name",
            }
            headers = None
            for ln in line_iter:
                if re.match(multicolheaders, ln):
                    # Get all the column headers
                    raw_headers = re.findall(r"[a-zA-Z][a-zA-Z+ ]+", ln)
                    headers = [header_map[k.strip()] for k in raw_headers]
                    # Stop at the header line; ``line_iter`` now continues
                    # with the data rows.
                    break
            if headers is None:
                raise ValueError(
                    "no column headers found in timing report"
                )
            if headers[-1] != 'pass_name':
                raise ValueError(
                    f"unexpected last column in timing report: {headers[-1]!r}"
                )

            # compute the list of available attributes from the column headers
            attrs = []
            for k in headers[:-1]:
                attrs.append(f"{k}_time")
                attrs.append(f"{k}_percent")
            # put default value 0.0 to all missing attributes
            missing = {}
            for k in PassTimingRecord._fields:
                if k not in attrs and k != 'pass_name':
                    missing[k] = 0.0
            # parse timings: each timing column is "<time> (<percent>%)" or a
            # run of dashes; the rest of the line is the pass name.
            n = r"\s*((?:[0-9]+\.)?[0-9]+)"
            pat = f"\\s+(?:{n}\\s*\\({n}%\\)|-+)" * (len(headers) - 1)
            pat += r"\s*(.*)"
            for ln in line_iter:
                m = re.match(pat, ln)
                if m is not None:
                    raw_data = list(m.groups())
                    # zip() stops at len(attrs), leaving out the name group.
                    data = {k: float(v) if v is not None else 0.0
                            for k, v in zip(attrs, raw_data)}
                    data.update(missing)
                    pass_name = raw_data[-1]
                    rec = PassTimingRecord(
                        pass_name=pass_name, **data,
                    )
                    yield rec
                    if rec.pass_name == "Total":
                        # "Total" means the report has ended
                        break
            # Check that we have reached the end of the report
            remaining = '\n'.join(line_iter)
            if remaining.strip():
                raise ValueError(
                    f"unexpected text after parser finished:\n{remaining}"
                )

        # Parse raw data
        records = list(parse(self._raw_data))
        return _adjust_timings(records)
# A labelled entry of ``PassTimingsCollection``: ``name`` is the label given
# when recording and ``timings`` is the processed timing data.
NamedTimings = namedtuple("NamedTimings", "name timings")
class PassTimingsCollection(Sequence):
    """A collection of pass timings.

    This class implements the ``Sequence`` protocol for accessing the
    individual timing records.
    """

    def __init__(self, name):
        self._name = name
        self._records = []

    @contextmanager
    def record(self, name):
        """Record new timings and append to this collection.

        Note: this is mainly for internal use inside the compiler pipeline.

        See also ``RecordLLVMPassTimings``

        Parameters
        ----------
        name: str
            Name for the records.
        """
        if config.LLVM_PASS_TIMINGS:
            # Recording of pass timings is enabled
            with RecordLLVMPassTimings() as timings:
                yield
            rec = timings.get()
            # Only keep non-empty records
            if rec:
                self._append(name, rec)
        else:
            # Do nothing. Recording of pass timings is disabled.
            yield

    def _append(self, name, timings):
        """Append timing records

        Parameters
        ----------
        name: str
            Name for the records.
        timings: ProcessedPassTimings
            the timing records.
        """
        self._records.append(NamedTimings(name, timings))

    def get_total_time(self):
        """Computes the sum of the total time across all contained timings.

        Returns
        -------
        res: float or None
            Returns the total number of seconds or None if no timings were
            recorded
        """
        if self._records:
            return sum(r.timings.get_total_time() for r in self._records)
        return None

    def list_longest_first(self):
        """Returns the timings in descending order of total time duration.

        Returns
        -------
        res: List[NamedTimings]
        """
        return sorted(self._records,
                      key=lambda x: x.timings.get_total_time(),
                      reverse=True)

    @property
    def is_empty(self):
        """True when no timings have been recorded."""
        return not self._records

    def summary(self, topn=5):
        """Return a string representing the summary of the timings.

        Parameters
        ----------
        topn: int; optional, default=5.
            This limits the maximum number of items to show.
            This function will show the ``topn`` most time-consuming passes.

        Returns
        -------
        res: str

        See also ``ProcessedPassTimings.summary()``
        """
        if self.is_empty:
            return "No pass timings were recorded"
        buf = []
        ap = buf.append
        ap(f"Printing pass timings for {self._name}")
        overall_time = self.get_total_time()
        ap(f"Total time: {overall_time:.4f}")
        for i, r in enumerate(self._records):
            ap(f"== #{i} {r.name}")
            # Avoid dividing by zero when every record reports 0 seconds.
            if overall_time:
                percent = r.timings.get_total_time() / overall_time * 100
            else:
                percent = 0.0
            ap(f" Percent: {percent:.1f}%")
            ap(r.timings.summary(topn=topn, indent=1))
        return "\n".join(buf)

    def __getitem__(self, i):
        """Get the i-th timing record.

        Returns
        -------
        res: (name, timings)
            A named tuple with two fields:

            - name: str
            - timings: ProcessedPassTimings
        """
        return self._records[i]

    def __len__(self):
        """Length of this collection.
        """
        return len(self._records)

    def __str__(self):
        return self.summary()