[Graph][Benchmark] Update benchmark function (#363)
The old benchmarking function did not clear the L2 cache between runs, so repeated
runs are biased by cache hits.
This is especially prevalent when tuning parallel-k splits, where the search
always selects k_parts=1 due to L2 cache hits, even when it is not the
fastest implementation.

---------

Co-authored-by: Allan Lin <allan.lin@centml.ai>
Aalanli and Allan Lin committed Oct 12, 2023
1 parent 3272fc3 commit 82ddb8c
Showing 3 changed files with 20 additions and 46 deletions.
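
The fix relies on the standard L2-flush pattern: before each timed run, a scratch buffer larger than the L2 cache is zeroed, so any inputs left in L2 by the previous run are evicted and every run starts cold. A minimal sketch of the idea, using the same torch APIs as the pre-change code below (the 256 MB size is illustrative):

    import torch

    # A scratch buffer larger than the GPU's L2 cache; zeroing it evicts
    # whatever data earlier runs left behind in L2.
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')

    def bench_cold(fn, n_repeat=100):
        starts = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
        ends = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
        for i in range(n_repeat):
            cache.zero_()        # flush L2 before each timed run
            starts[i].record()
            fn()
            ends[i].record()
        torch.cuda.synchronize()
        return [s.elapsed_time(e) for s, e in zip(starts, ends)]  # per-run ms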
29 changes: 5 additions & 24 deletions python/hidet/graph/flow_graph.py
@@ -25,6 +25,7 @@
 from hidet.ir.task import Task
 from hidet.graph.tensor import Tensor, zeros_like, randn_like
 from hidet.graph.operator import Operator, SymbolVar
+from hidet.utils.benchmark import do_bench
 
 logger = logging.getLogger(__name__)
@@ -395,7 +396,7 @@ def f_run(inputs: List[Tensor]) -> List[Tensor]:
         return CudaGraph(f_create_inputs, f_run, ref_objs=[self])
 
     def latency(
-        self, warmup=1, number=3, repeat=3, median=True, dummy_inputs: Optional[Sequence[Tensor]] = None
+        self, warmup=25, repeat=100, dummy_inputs: Optional[Sequence[Tensor]] = None
     ) -> Union[float, List[float]]:
         """Measure the latency of the flow graph.
@@ -404,15 +405,9 @@ def latency(
         warmup: int
             The number of warmup runs.
-        number: int
-            The number of runs to measure the latency.
         repeat: int
             The number of times to repeat the measurement.
-        median: bool
-            Whether to return the median latency.
         dummy_inputs: Optional[Sequence[Tensor]]
             The dummy inputs used to run the flow graph. If not given, automatically generated dummy inputs are used.
@@ -421,26 +416,12 @@ def latency(
         ret: Union[float, List[float]]
             The measured latency in milliseconds.
         """
-        import time
-        import numpy as np
 
         if dummy_inputs is None:
             dummy_inputs = self.dummy_inputs()
-        for _ in range(warmup):
-            self.forward(dummy_inputs)
-        results = []
-        for _ in range(repeat):
-            hidet.cuda.synchronize()
-            t1 = time.time()
-            for _ in range(number):
-                self.forward(dummy_inputs)
-            hidet.cuda.synchronize()
-            t2 = time.time()
-            results.append((t2 - t1) * 1000 / number)
-        if median:
-            return float(np.median(results))
-        else:
-            return results
 
+        # return the median
+        return do_bench(lambda: self.forward(dummy_inputs), warmup=warmup, rep=repeat)[1]
 
     @staticmethod
     def _analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]:
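
A usage sketch of the updated method (the graph name is illustrative): warmup and repeat are forwarded to do_bench as its warmup and rep arguments, and the [1] index selects the 0.5 quantile from the default percentiles=(0.2, 0.5, 0.8), i.e. the median.

    # Assuming `g` is a traced and optimized FlowGraph, as in the
    # regression script further below; returns the median latency in ms.
    median_ms = g.latency(warmup=25, repeat=100)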
33 changes: 13 additions & 20 deletions python/hidet/utils/benchmark/bench.py
@@ -33,47 +33,40 @@ def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
     """
 
     # Estimate the runtime of the function
-    import torch
+    import hidet
 
     fn()
-    torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    hidet.cuda.synchronize()
+    start_event = hidet.cuda.Event(enable_timing=True)
+    end_event = hidet.cuda.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
         fn()
     end_event.record()
-    torch.cuda.synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-    # compute number of warmup and repeat
+    hidet.cuda.synchronize()
+    estimate_ms = end_event.elapsed_time(start_event) / 5
     n_warmup = max(1, int(warmup / estimate_ms))
     n_repeat = max(1, int(rep / estimate_ms))
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-
-    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+    start_event = [hidet.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [hidet.cuda.Event(enable_timing=True) for i in range(n_repeat)]
 
     # Warm-up
     for _ in range(n_warmup):
         fn()
     # Benchmark
     for i in range(n_repeat):
         # we clear the L2 cache before each run
         cache.zero_()
         # record time of `fn`
         start_event[i].record()
         fn()
         end_event[i].record()
     # Record clocks
-    torch.cuda.synchronize()
-    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
+    hidet.cuda.synchronize()
+    times = np.array([e.elapsed_time(s) for s, e in zip(start_event, end_event)])
     if percentiles:
-        percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
+        percentiles = np.quantile(times, percentiles)
         return tuple(percentiles)
     else:
-        return torch.mean(times).item()
+        return np.mean(times).item()
 
 
 def benchmark_func(run_func, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
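
One behavioral detail worth noting in the hunk above: warmup and rep are time budgets in milliseconds, not iteration counts. do_bench first times five calls to estimate the per-call latency, then divides each budget by that estimate. A worked example under an assumed estimate:

    # If the five-call estimate comes out at 0.5 ms per call (illustrative):
    estimate_ms = 0.5
    n_warmup = max(1, int(25 / estimate_ms))    # -> 50 warmup iterations
    n_repeat = max(1, int(100 / estimate_ms))   # -> 200 timed iterations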
4 changes: 2 additions & 2 deletions scripts/regression/op_performance.py
@@ -15,7 +15,7 @@ def bench_matmul(m, n, k, dtype):
     c = hidet.ops.matmul(a, b)
     g = hidet.trace_from(c, [a, b])
     g = hidet.graph.optimize(g)
-    return g.latency(warmup=10, number=5, repeat=100)
+    return g.latency(warmup=25, repeat=100)
 
 def bench_fmha(sq, skv, d):
     hidet.option.search_space(2)
@@ -25,7 +25,7 @@ def bench_fmha(sq, skv, d):
     o = hidet.ops.attention(q, k, v)
     g = hidet.trace_from(o, [q, k, v])
     g = hidet.graph.optimize(g)
-    return g.latency(warmup=10, number=5, repeat=100)
+    return g.latency(warmup=25, repeat=100)
 
 def matmul_regression() -> ResultGroup:
     regression_data = load_regression_data()
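
For reference, a self-contained sketch of how these regression helpers exercise the new signature (the shapes and dtype are illustrative, not taken from the script):

    import hidet

    a = hidet.symbol([1024, 1024], dtype='float16', device='cuda')
    b = hidet.symbol([1024, 1024], dtype='float16', device='cuda')
    g = hidet.graph.optimize(hidet.trace_from(hidet.ops.matmul(a, b), [a, b]))
    print(g.latency(warmup=25, repeat=100))  # median latency in milliseconds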
