Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 76 additions & 82 deletions codeflash/verification/parse_line_profile_test_output.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,82 @@
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
import linecache
import inspect
from codeflash.code_utils.tabulate import tabulate
import os
import dill as pickle
from pathlib import Path
from typing import Optional
from rich.console import Console
from rich.table import Table

def show_func(filename, start_lineno, func_name, timings, unit):
total_hits = sum(t[1] for t in timings)
total_time = sum(t[2] for t in timings)
out_table = ""
table_rows = []
if total_hits == 0:
return ''
scalar = 1
if os.path.exists(filename):
out_table+=f'## Function: {func_name}\n'
# Clear the cache to ensure that we get up-to-date results.
linecache.clearcache()
all_lines = linecache.getlines(filename)
sublines = inspect.getblock(all_lines[start_lineno - 1:])
out_table+='## Total time: %g s\n' % (total_time * unit)
# Define minimum column sizes so text fits and usually looks consistent
default_column_sizes = {
'hits': 9,
'time': 12,
'perhit': 8,
'percent': 8,
}
display = {}
# Loop over each line to determine better column formatting.
# Fallback to scientific notation if columns are larger than a threshold.
for lineno, nhits, time in timings:
if total_time == 0: # Happens rarely on empty function
percent = ''
else:
percent = '%5.1f' % (100 * time / total_time)

time_disp = '%5.1f' % (time * scalar)
if len(time_disp) > default_column_sizes['time']:
time_disp = '%5.1g' % (time * scalar)
perhit_disp = '%5.1f' % (float(time) * scalar / nhits)
if len(perhit_disp) > default_column_sizes['perhit']:
perhit_disp = '%5.1g' % (float(time) * scalar / nhits)
nhits_disp = "%d" % nhits
if len(nhits_disp) > default_column_sizes['hits']:
nhits_disp = '%g' % nhits
display[lineno] = (nhits_disp, time_disp, perhit_disp, percent)
linenos = range(start_lineno, start_lineno + len(sublines))
empty = ('', '', '', '')
table_cols = ('Hits', 'Time', 'Per Hit', '% Time', 'Line Contents')
for lineno, line in zip(linenos, sublines):
nhits, time, per_hit, percent = display.get(lineno, empty)
line_ = line.rstrip('\n').rstrip('\r')
if 'def' in line_ or nhits!='':
table_rows.append((nhits, time, per_hit, percent, line_))
pass
out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table+='\n'
return out_table
class LineProfileFormatter:
def __init__(self, unit: float = 1.0):
self.unit = unit
self.console = Console(record=True)

def show_text(stats: dict) -> str:
""" Show text for the given timings.
"""
out_table = ""
out_table+='# Timer unit: %g s\n' % stats['unit']
stats_order = sorted(stats['timings'].items())
# Show detailed per-line information for each function.
for (fn, lineno, name), timings in stats_order:
table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table+=table_md
return out_table
def format_time(self, time: float) -> str:
return f"{time * self.unit:5.1f}"

def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
line_profiler_output_file = line_profiler_output_file.with_suffix(".lprof")
stats_dict = {}
if not line_profiler_output_file.exists():
return {'timings':{},'unit':0, 'str_out':''}, None
else:
with open(line_profiler_output_file,'rb') as f:
stats = pickle.load(f)
stats_dict['timings'] = stats.timings
stats_dict['unit'] = stats.unit
str_out=show_text(stats_dict)
stats_dict['str_out']=str_out
return stats_dict, None
def format_per_hit(self, time: float, hits: int) -> str:
return f"{(time * self.unit) / hits:5.1f}" if hits else "0.0"

def format_percent(self, part: float, total: float) -> str:
return f"{100 * part / total:5.1f}" if total else ""

def show_func(
self,
filename: str,
start_lineno: int,
func_name: str,
timings: list[tuple[int, int, float]],
) -> str:
total_time = sum(t[2] for t in timings)
if not total_time:
return ""

table = self._create_timing_table(filename, start_lineno, timings, total_time)
self.console.print(table)
return f"## Function: {func_name}\n## Total time: {total_time * self.unit:.6g} s\n{self.console.export_text()}\n"

def _create_timing_table(
self,
filename: str,
start_lineno: int,
timings: list[tuple[int, int, float]],
total_time: float,
) -> Table:
table = Table(show_header=True, header_style="bold")
table.add_column("Hits")
table.add_column("Time")
table.add_column("Per Hit")
table.add_column("% Time")
table.add_column("Line Contents", no_wrap=True)

source_lines = self._get_source_lines(filename, start_lineno)
if isinstance(source_lines, str):
return Table(title=source_lines)

for lineno, nhits, time in timings:
line = source_lines[lineno - start_lineno].rstrip("\n").rstrip("\r")
table.add_row(
str(nhits),
self.format_time(time),
self.format_per_hit(time, nhits),
self.format_percent(time, total_time),
line,
)
return table

def _get_source_lines(self, filename: str, start_lineno: int) -> list[str] | str:
try:
return inspect.getblock(
linecache.getlines(str(filename))[start_lineno - 1 :]
)
except Exception:
return f"File not found: {filename}"


def show_func(
filename: str,
start_lineno: int,
func_name: str,
timings: list[tuple[int, int, float]],
unit: float = 1.0,
) -> str:
formatter = LineProfileFormatter(unit)
return formatter.show_func(filename, start_lineno, func_name, timings)
Loading