[CI] Remove unnecessary synchronization #420

Merged · 1 commit · Jan 20, 2024
14 changes: 7 additions & 7 deletions .github/scripts/bench/bench_op.py
@@ -15,7 +15,7 @@ def bench_matmul_f16(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(c, inputs=[a, b])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_batch_matmul(params: str, *args, **kwargs) -> float:
     # Default to benchmarking f32 for now, though this op can run other dtypes
@@ -28,7 +28,7 @@ def bench_batch_matmul(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(c, inputs=[a, b])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_conv2d(params: str, *args, **kwargs) -> float:
     x_shape, w_shape = params.split(',')
@@ -40,7 +40,7 @@ def bench_conv2d(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[x, w])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
     x_shape, w_shape = params.split(',')
@@ -52,7 +52,7 @@ def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[x, w])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_attn(params: str, *args, **kwargs) -> float:
     bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')]
@@ -66,7 +66,7 @@ def bench_attn(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[q, k, v])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
     bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')]
@@ -82,7 +82,7 @@ def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[q, k, v, mask])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 def bench_reduce(params: str, *args, **kwargs) -> float:
     x_shape, axis = params.split(',', maxsplit=1)
@@ -95,7 +95,7 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
     g = hidet.trace_from(o, inputs=[x])
     g = hidet.graph.optimize(g)
     g = g.cuda_graph()
-    return bench_torch_model(g, [])
+    return bench_torch_model(lambda: g.run_async(), [])
 
 bench_func_map = {
     'matmul_f16': bench_matmul_f16,
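Every hunk above makes the same one-line substitution: instead of handing the compiled CUDA graph object `g` to `bench_torch_model`, whose per-iteration call of the model presumably blocks until the graph has finished, each benchmark now passes a zero-argument lambda that only launches `g.run_async()`. Because `bench_torch_model` is called with an empty input list, the lambda needs no arguments, and the timed loop can keep every launch asynchronous, synchronizing only around the measured region. The helper's body is not part of this diff, so the following is only a minimal sketch of such a timing loop; the name `bench_callable`, the CUDA-event timing, and the warmup structure are illustrative assumptions, not the repository's actual implementation.

```python
import torch

def bench_callable(run, bench_iters=100, warmup_iters=10):
    # Hypothetical sketch of a timing loop in the spirit of bench_torch_model.
    # `run` is a zero-argument callable such as `lambda: g.run_async()`.
    for _ in range(warmup_iters):
        run()                       # warm-up launches (kernel caches, autotuning, ...)
    torch.cuda.synchronize()        # drain the stream once before timing

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(bench_iters):
        run()                       # launches stay asynchronous inside the timed loop
    end.record()
    torch.cuda.synchronize()        # a single synchronization after the loop
    return start.elapsed_time(end) / bench_iters   # average latency in milliseconds
```

Under a structure like this, the old `bench_torch_model(g, [])` call would add a blocking step to every iteration of the timed loop, which is the unnecessary synchronization the PR title refers to.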
13 changes: 7 additions & 6 deletions .github/scripts/bench/bench_utils.py
@@ -35,9 +35,10 @@ def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
     return latency
 
 def enable_compile_server(enable=True):
-    hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
-    hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
-    hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
-    hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
-    hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
-    hidet.option.compile_server.enable(flag=enable)
+    if os.environ.get('CI_CS_HOSTNAME'):
+        hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
+        hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
+        hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
+        hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
+        hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
+        hidet.option.compile_server.enable(flag=enable)
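The `bench_utils.py` hunk wraps the compile-server configuration in a check that `CI_CS_HOSTNAME` is set. Without the guard, running the benchmark scripts outside CI, where the `CI_CS_*` and `REPO_*` variables are undefined, fails before any benchmark executes: `os.environ.get()` returns `None` for a missing variable, and the subsequent `int(...)` and `.strip()` calls raise on `None`. The snippet below only demonstrates that standard-library behavior; the variable names are taken from the diff.

```python
import os

# os.environ.get() returns None when a variable is unset, so the unguarded
# configuration calls fail on machines that do not define the CI secrets.
port = os.environ.get('CI_CS_PORT')    # -> None outside CI
try:
    int(port)                          # TypeError: int() argument must be a string, ...
except TypeError as e:
    print('port:', e)

repo = os.environ.get('REPO_NAME')     # -> None outside CI
try:
    repo.strip()                       # AttributeError: 'NoneType' object has no attribute 'strip'
except AttributeError as e:
    print('repo:', e)
```

With the guard in place, a run without the CI environment simply skips the compile-server setup and, presumably, falls back to hidet's default local compilation.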