Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Graph][Benchmark] Update benchmark function #363

Merged
merged 11 commits into from
Oct 12, 2023
1 change: 1 addition & 0 deletions python/hidet/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.
from .capability import capability
from .device import available, device_count, synchronize, compute_capability, properties, profiler_start, profiler_stop
from .device import create_event, event_record, event_elapsed_time
from .device import cudaDeviceProp, set_device, current_device, device
from .stream import Stream, ExternalStream, stream, default_stream, current_stream
from .memory import malloc, free, malloc_async, free_async, malloc_host, free_host, memcpy_peer, memcpy_peer_async
Expand Down
45 changes: 45 additions & 0 deletions python/hidet/cuda/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,48 @@ def profiler_stop():
"""
(err,) = cudart.cudaProfilerStop()
assert err == 0, err


def create_event():
    """
    Allocate a new CUDA event via the CUDA runtime.

    Returns
    -------
    event
        The event handle returned by ``cudart.cudaEventCreate``.
    """
    status, handle = cudart.cudaEventCreate()
    # A non-zero status code means the runtime call failed.
    assert status == 0, status
    return handle


def event_record(event, stream=0):
    """
    Record an event on a stream.

    Parameters
    ----------
    event: int
        The event to record.
    stream: int
        The stream on which the event is recorded. Defaults to ``0``.
    """
    # cudaEventRecord returns a one-element tuple holding only the status code.
    (status,) = cudart.cudaEventRecord(event, stream)
    assert status == 0, status


def event_elapsed_time(start, end):
    """
    Return the time elapsed between two recorded events.

    Parameters
    ----------
    start: int
        The event marking the beginning of the interval.
    end: int
        The event marking the end of the interval.

    Returns
    -------
    elapsed: float
        The elapsed time in milliseconds.
    """
    status, elapsed_ms = cudart.cudaEventElapsedTime(start, end)
    # A non-zero status code means the runtime call failed.
    assert status == 0, status
    return elapsed_ms
Aalanli marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 5 additions & 24 deletions python/hidet/graph/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from hidet.ir.task import Task
from hidet.graph.tensor import Tensor, zeros_like, randn_like
from hidet.graph.operator import Operator, SymbolVar
from hidet.utils.benchmark import do_bench

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -395,7 +396,7 @@ def f_run(inputs: List[Tensor]) -> List[Tensor]:
return CudaGraph(f_create_inputs, f_run, ref_objs=[self])

def latency(
self, warmup=1, number=3, repeat=3, median=True, dummy_inputs: Optional[Sequence[Tensor]] = None
self, warmup=25, repeat=100, dummy_inputs: Optional[Sequence[Tensor]] = None
) -> Union[float, List[float]]:
"""Measure the latency of the flow graph.

Expand All @@ -404,15 +405,9 @@ def latency(
warmup: int
The number of warmup runs.

number: int
The number of runs to measure the latency.

repeat: int
The number of times to repeat the measurement.

median: bool
Whether to return the median latency.

dummy_inputs: Optional[Sequence[Tensor]]
The dummy inputs to run the flow graph. If not given, automatic generated dummy inputs would be used.

Expand All @@ -421,26 +416,12 @@ def latency(
ret: Union[float, List[float]]
The measured latency in milliseconds.
"""
import time
import numpy as np

if dummy_inputs is None:
dummy_inputs = self.dummy_inputs()
for _ in range(warmup):
self.forward(dummy_inputs)
results = []
for _ in range(repeat):
hidet.cuda.synchronize()
t1 = time.time()
for _ in range(number):
self.forward(dummy_inputs)
hidet.cuda.synchronize()
t2 = time.time()
results.append((t2 - t1) * 1000 / number)
if median:
return float(np.median(results))
else:
return results

# return the median
return do_bench(lambda: self.forward(dummy_inputs), warmup=warmup, rep=repeat)[1]

@staticmethod
def _analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]:
Expand Down
34 changes: 16 additions & 18 deletions python/hidet/utils/benchmark/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,47 +33,45 @@ def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
"""

# Estimate the runtime of the function
import torch
import hidet

fn()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
hidet.cuda.synchronize()
start_event = hidet.cuda.create_event()
end_event = hidet.cuda.create_event()
hidet.cuda.event_record(start_event)
for _ in range(5):
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
hidet.cuda.event_record(end_event)
hidet.cuda.synchronize()
estimate_ms = hidet.cuda.event_elapsed_time(start_event, end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
start_event = [hidet.cuda.create_event() for i in range(n_repeat)]
end_event = [hidet.cuda.create_event() for i in range(n_repeat)]

cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
for i in range(n_repeat):
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
start_event[i].record()
hidet.cuda.event_record(start_event[i])
fn()
end_event[i].record()
hidet.cuda.event_record(end_event[i])
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
hidet.cuda.synchronize()
times = np.array([hidet.cuda.event_elapsed_time(s, e) for s, e in zip(start_event, end_event)])
if percentiles:
percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
percentiles = np.quantile(times, percentiles)
return tuple(percentiles)
else:
return torch.mean(times).item()
return np.mean(times).item()


def benchmark_func(run_func, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
Expand Down
4 changes: 2 additions & 2 deletions scripts/regression/op_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def bench_matmul(m, n, k, dtype):
c = hidet.ops.matmul(a, b)
g = hidet.trace_from(c, [a, b])
g = hidet.graph.optimize(g)
return g.latency(warmup=10, number=5, repeat=100)
return g.latency(warmup=25, repeat=100)

def bench_fmha(sq, skv, d):
hidet.option.search_space(2)
Expand All @@ -25,7 +25,7 @@ def bench_fmha(sq, skv, d):
o = hidet.ops.attention(q, k, v)
g = hidet.trace_from(o, [q, k, v])
g = hidet.graph.optimize(g)
return g.latency(warmup=10, number=5, repeat=100)
return g.latency(warmup=25, repeat=100)

def matmul_regression() -> ResultGroup:
regression_data = load_regression_data()
Expand Down