Skip to content

Commit

Permalink
[CI] Update benchmark script (#160)
Browse files Browse the repository at this point in the history
.
  • Loading branch information
yaoyaoding committed Apr 5, 2023
1 parent 916feb4 commit 63d75a7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/hidet/cli/bench/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Optional
import click
import torch
import hidet
from hidet.utils import initialize
from . import vision
from . import nlp
Expand Down Expand Up @@ -54,20 +55,29 @@
type=bool,
help='Set torch.backends.cuda.matmul.allow_tf32=True.',
)
@click.option(
'--cache-dir',
default=None,
type=click.Path(dir_okay=True, file_okay=False, writable=True),
help='The cache directory to store the generated kernels.',
)
def bench_group(
space: str,
dtype: str,
tensor_core: bool,
report: Optional[click.Path],
disable_torch_cudnn_tf32: bool,
enable_torch_cublas_tf32: bool,
cache_dir: Optional[click.Path],
):
BenchModel.search_space = int(space)
BenchModel.dtype = getattr(torch, dtype)
BenchModel.tensor_core = tensor_core
BenchModel.disable_torch_cudnn_tf32 = disable_torch_cudnn_tf32
BenchModel.enable_torch_cublas_tf32 = enable_torch_cublas_tf32
BenchModel.report_path = report
if cache_dir:
hidet.option.cache_dir(str(cache_dir))


@initialize()
Expand Down
1 change: 1 addition & 0 deletions scripts/bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reinstall_hidet():
subprocess.run(['mkdir', '-p', 'build'], check=True)
subprocess.run(['rm', '-rf', 'build/*'], check=True)
subprocess.run(['cmake', '-S', '.', '-B', 'build'], check=True)
subprocess.run(['cmake', '--build', 'build'], check=True)
subprocess.run(['pip', 'install', '-e', '.'], check=True)


Expand Down

0 comments on commit 63d75a7

Please sign in to comment.