[CI] Benchmark periodically (#155)
yaoyaoding committed Apr 5, 2023
1 parent 25657c9 commit 3b764ab
Showing 14 changed files with 234 additions and 34 deletions.
67 changes: 67 additions & 0 deletions .github/workflows/benchmark.yaml
@@ -0,0 +1,67 @@
name: Benchmark
on:
  workflow_dispatch:
  pull_request:
  schedule:
    - cron: "0 3 * * *"
jobs:
  benchmark:
    if: github.repository == 'hidet-org/hidet'
    concurrency:
      group: ${{ github.workflow }}-${{ github.head_ref }}
      cancel-in-progress: true
    runs-on: [self-hosted, Linux, X64, gpu]
    container:
      image: nvidia/cuda:11.8.0-devel-ubuntu20.04
      options: --gpus all
    steps:
      - name: Install dependencies via apt
        run: |
          apt update && DEBIAN_FRONTEND=noninteractive apt install -y ccache git graphviz
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"

      - name: Setup cmake
        uses: jwlawson/actions-setup-cmake@v1.13
        with:
          cmake-version: '3.19.x'

      - name: Install dependencies via pip
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio
          pip install -r requirements.txt
          pip install -r requirements-dev.txt
      - name: Build hidet
        run: |
          bash scripts/wheel/build_wheel.sh
          WHEEL=$(find ./scripts/wheel/built_wheel -maxdepth 1 -name '*.whl')
          echo "WHEEL_NAME=$WHEEL" >> $GITHUB_ENV
          echo "Built wheel: ${{ env.WHEEL_NAME }}"
      - name: Install hidet
        run: |
          pip install --no-deps --force-reinstall ${{ env.WHEEL_NAME }}
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
      - name: Run benchmark
        run: |
          python scripts/ci/benchmark.py --report result.txt --space 0 --git-commit ${{ github.sha }}
      - name: Post result
        uses: peter-evans/create-or-update-comment@v2
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          issue-number: 154
          body-file: ./result.txt
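
The workflow above runs on three triggers: a manual workflow_dispatch, every pull request, and a daily cron at 03:00 UTC. As a minimal sketch (not part of the commit), a manual run can be dispatched through the standard GitHub REST API; the token and ref below are placeholders:

import requests

# Hypothetical manual trigger for the Benchmark workflow; TOKEN is a
# placeholder for a personal access token with workflow permissions.
TOKEN = '<personal-access-token>'
resp = requests.post(
    'https://api.github.com/repos/hidet-org/hidet/actions/workflows/benchmark.yaml/dispatches',
    headers={'Authorization': f'token {TOKEN}', 'Accept': 'application/vnd.github+json'},
    json={'ref': 'main'},
)
resp.raise_for_status()  # a successful dispatch returns 204 No Content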
40 changes: 36 additions & 4 deletions python/hidet/cli/bench/bench.py
@@ -9,7 +9,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
 import click
+import torch
 from hidet.utils import initialize
 from . import vision
 from . import nlp
@@ -27,15 +29,45 @@
     help='Schedule space. 0: default schedule. 1: small schedule space. 2: large schedule space.',
 )
 @click.option(
-    '--torch-tf32',
+    '--dtype', default='float32', show_default=True, type=click.Choice(['float32', 'float16']), help='Data type to use.'
+)
+@click.option(
+    '--tensor-core',
     default=False,
     show_default=True,
     is_flag=True,
     type=bool,
+    help='Whether to use tensor core in hidet.',
+)
+@click.option('--report', type=click.Path(exists=False, dir_okay=False, writable=True), help='Report file path.')
+@click.option(
+    '--disable-torch-cudnn-tf32',
+    default=False,
+    is_flag=True,
+    type=bool,
+    help='Set torch.backends.cudnn.allow_tf32=False.',
+)
+@click.option(
+    '--enable-torch-cublas-tf32',
+    default=False,
+    is_flag=True,
+    type=bool,
-    help='Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32.',
+    help='Set torch.backends.cuda.matmul.allow_tf32=True.',
 )
-def bench_group(space: str, torch_tf32: bool):
+def bench_group(
+    space: str,
+    dtype: str,
+    tensor_core: bool,
+    report: Optional[click.Path],
+    disable_torch_cudnn_tf32: bool,
+    enable_torch_cublas_tf32: bool,
+):
     BenchModel.search_space = int(space)
-    BenchModel.allow_tf32 = torch_tf32
+    BenchModel.dtype = getattr(torch, dtype)
+    BenchModel.tensor_core = tensor_core
+    BenchModel.disable_torch_cudnn_tf32 = disable_torch_cudnn_tf32
+    BenchModel.enable_torch_cublas_tf32 = enable_torch_cublas_tf32
+    BenchModel.report_path = report


 @initialize()
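
A note on the flag change: the old --torch-tf32 switch is split into two flags whose defaults match PyTorch's own (TF32 on for cuDNN convolutions, off for cuBLAS matmuls). A minimal sketch of what these flags end up setting, with the CLI option names mirrored as plain variables:

import torch

disable_torch_cudnn_tf32 = False   # --disable-torch-cudnn-tf32
enable_torch_cublas_tf32 = False   # --enable-torch-cublas-tf32

# Defaults preserve PyTorch's behavior: cuDNN TF32 enabled, cuBLAS TF32 disabled.
torch.backends.cudnn.allow_tf32 = not disable_torch_cudnn_tf32
torch.backends.cuda.matmul.allow_tf32 = enable_torch_cublas_tf32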
1 change: 0 additions & 1 deletion python/hidet/cli/bench/bench_all.py
@@ -20,4 +20,3 @@ def bench_all():
     result = [bench_model.benchmark() for bench_model in all_registered_models]

     click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
1 change: 0 additions & 1 deletion python/hidet/cli/bench/bench_common.py
@@ -20,4 +20,3 @@ def bench_common():
     result = [bench_model.benchmark() for bench_model in commonly_used_models]

     click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
57 changes: 43 additions & 14 deletions python/hidet/cli/bench/model.py
@@ -10,14 +10,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=ungrouped-imports, no-name-in-module
-from typing import List, Any
+from typing import List, Any, Optional
+import click
 import torch
 from hidet.testing import benchmark_func
 import hidet


 class BenchModel:
     search_space = 0
-    allow_tf32 = False
+    dtype: torch.dtype = torch.float32
+    tensor_core: bool = False
+    disable_torch_cudnn_tf32 = False
+    enable_torch_cublas_tf32 = False
+    report_path: Optional[str] = None

     def __str__(self):
         raise NotImplementedError()
@@ -44,6 +50,24 @@ def example_inputs(self):
         """
         raise NotImplementedError()

+    def converted_model(self):
+        model = self.model().eval()
+        model = model.to(dtype=BenchModel.dtype)
+        model = model.cuda()
+        return model
+
+    def converted_inputs(self):
+        args, kwargs = self.example_inputs()
+
+        def convert_f32(arg):
+            if arg.dtype == torch.float32:
+                return arg.to(dtype=BenchModel.dtype)
+            return arg
+
+        args = [convert_f32(arg.cuda()) for arg in args]
+        kwargs = {k: convert_f32(v.cuda()) for k, v in kwargs.items()}
+        return args, kwargs
+
     @staticmethod
     def tensor_str(tensor) -> str:
         """
@@ -71,29 +95,33 @@ def inputs_str(self) -> str:
         ret: str
             The string representation of the inputs to the model.
         """
-        args, kwargs = self.example_inputs()
+        args, kwargs = self.converted_inputs()
         items = []
         for arg in args:
             items.append(self.tensor_str(arg))
         for k, v in kwargs.items():
             items.append('{}={}'.format(k, self.tensor_str(v)))
         return ', '.join(items)

+    @classmethod
+    def report_table(cls, table_str):
+        if cls.report_path is not None:
+            with open(cls.report_path, 'w') as f:
+                click.echo(table_str, file=f)
+        click.echo(table_str)
+
     def bench_with_backend(self, backend: str, mode=None, warmup=3, number=10, repeat=10):
-        import torch.backends.cudnn
+        import torch.backends.cudnn  # pylint: disable=redefined-outer-name
         import torch.backends.cuda

         if not hidet.torch.dynamo_available():
             raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
         import torch._dynamo as dynamo

-        torch.backends.cudnn.allow_tf32 = self.allow_tf32
-        torch.backends.cuda.matmul.allow_tf32 = self.allow_tf32
+        torch.backends.cudnn.allow_tf32 = not self.disable_torch_cudnn_tf32
+        torch.backends.cuda.matmul.allow_tf32 = self.enable_torch_cublas_tf32

-        model, (args, kwargs) = self.model(), self.example_inputs()
-        model = model.cuda().eval()
-        args = [arg.cuda() for arg in args]
-        kwargs = {k: v.cuda() for k, v in kwargs.items()}
+        model, (args, kwargs) = self.converted_model(), self.converted_inputs()
         dynamo.reset()
         with torch.no_grad():
             model_opt = torch.compile(model, backend=backend, mode=mode)
@@ -115,6 +143,7 @@ def bench_hidet(self, use_cuda_graph=True, use_fp16=False, use_fp16_reduction=Fa
         config = hidet.torch.dynamo_config
         config.search_space(self.search_space)
         config.use_cuda_graph(use_cuda_graph)
+        config.use_tensor_core(self.tensor_core)
         config.use_fp16(use_fp16)
         config.use_fp16_reduction(use_fp16_reduction)
         return self.bench_with_backend('hidet')
@@ -125,9 +154,9 @@ def headers() -> List[str]:
             'model',
             'inputs',
             'eager',
-            'inductor(mode=reduce-overhead)',
-            # 'inductor(mode=max-autotune)'
-            'hidet(space={})'.format(BenchModel.search_space),
+            'reduce-overhead',
+            'max-autotune',
+            'hidet({})'.format(BenchModel.search_space),
         ]

     def benchmark(self) -> List[Any]:
@@ -136,7 +165,7 @@ def benchmark(self) -> List[Any]:
             self.inputs_str(),
             self.bench_eager(),
             self.bench_inductor('reduce-overhead'),
-            # self.bench_inductor('max-autotune'),
+            self.bench_inductor('max-autotune'),
             self.bench_hidet(),
         ]

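
For context, bench_with_backend compiles the converted model with torch.compile and times it on the GPU. A self-contained sketch of the same warmup/number/repeat timing pattern, assuming a CUDA device, PyTorch >= 2.0, and a registered backend; this is an illustration, not the exact implementation:

import time
import torch

def time_compiled(model, args, backend='inductor', mode=None, warmup=3, number=10, repeat=10):
    # Compile once, warm up, then report the median per-iteration latency (ms).
    model_opt = torch.compile(model.cuda().eval(), backend=backend, mode=mode)
    with torch.no_grad():
        for _ in range(warmup):
            model_opt(*args)
        torch.cuda.synchronize()
        laps = []
        for _ in range(repeat):
            start = time.time()
            for _ in range(number):
                model_opt(*args)
            torch.cuda.synchronize()
            laps.append((time.time() - start) * 1000.0 / number)
    return sorted(laps)[len(laps) // 2]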
5 changes: 3 additions & 2 deletions python/hidet/cli/bench/nlp/models.py
@@ -58,5 +58,6 @@ def bench_nlp(models: str, batch_size: int, seq_length: int):
     header = BenchModel.headers()
     result = [bench_model.benchmark() for bench_model in bench_models]

-    click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
+    BenchModel.report_table(
+        tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
+    )
16 changes: 12 additions & 4 deletions python/hidet/cli/bench/nlp/nlp_model.py
@@ -9,6 +9,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import torch
 from hidet.cli.bench.model import BenchModel


@@ -24,15 +25,22 @@ def __str__(self):
         return '{}/{}'.format(self.model_name, self.label)

     def model(self):
-        import torch
-
         return torch.hub.load(self.repo_name, self.model_name, self.label)

     def example_inputs(self):
-        import torch
-
         tokens_tensor = torch.zeros((self.batch_size, self.sequence_length), dtype=torch.long, device='cuda')
         segments_tensors = torch.zeros((self.batch_size, self.sequence_length), dtype=torch.long, device='cuda')
         args = (tokens_tensor,)
         kwargs = {'token_type_ids': segments_tensors}
         return args, kwargs
+
+    def inputs_str(self) -> str:
+        if self.dtype == torch.float16:
+            dtype = 'f16'
+        elif self.dtype == torch.float32:
+            dtype = 'f32'
+        elif self.dtype == torch.float64:
+            dtype = 'f64'
+        else:
+            raise ValueError('Unknown dtype: {}'.format(self.dtype))
+        return f'{dtype}, bs={self.batch_size}, seq={self.sequence_length}'
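
The model() method above resolves checkpoints through torch.hub. A hedged example with values in the style these NLP benchmarks use (the actual repo_name/model_name/label come from the registered benchmark entries):

import torch

# Hypothetical entry: the 'huggingface/pytorch-transformers' hub repo exposes a
# 'model' entrypoint that accepts a pretrained checkpoint name as the label.
model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
model = model.eval().cuda()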
5 changes: 3 additions & 2 deletions python/hidet/cli/bench/vision/inception_v3.py
@@ -32,5 +32,6 @@ def bench_inception_v3(batch_size: int, channels: int, height: int, width: int):
     header = BenchModel.headers()
     result = [bench_model.benchmark() for bench_model in bench_models]

-    click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
+    BenchModel.report_table(
+        tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
+    )
5 changes: 3 additions & 2 deletions python/hidet/cli/bench/vision/mobilenet_v2.py
@@ -32,5 +32,6 @@ def bench_mobilenet_v2(batch_size: int, channels: int, height: int, width: int):
     header = BenchModel.headers()
     result = [bench_model.benchmark() for bench_model in bench_models]

-    click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
+    BenchModel.report_table(
+        tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
+    )
5 changes: 3 additions & 2 deletions python/hidet/cli/bench/vision/resnet.py
@@ -50,5 +50,6 @@ def bench_resnet(models: str, batch_size: int, channels: int, height: int, width
     header = BenchModel.headers()
     result = [bench_model.benchmark() for bench_model in bench_models]

-    click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
+    BenchModel.report_table(
+        tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
+    )
5 changes: 3 additions & 2 deletions python/hidet/cli/bench/vision/resnext.py
@@ -47,5 +47,6 @@ def bench_resnext(models: str, batch_size: int, channels: int, height: int, widt
     header = BenchModel.headers()
     result = [bench_model.benchmark() for bench_model in bench_models]

-    click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
-    click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
+    BenchModel.report_table(
+        tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
+    )
3 changes: 3 additions & 0 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -32,6 +32,7 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
     use_cuda_graph = dynamo_config['use_cuda_graph']
     search_space = dynamo_config['search_space']
     parallel_k = dynamo_config['parallel_k']
+    tensor_core = dynamo_config['use_tensor_core']
     save_dir = dynamo_config['dump_graph_ir']

     with PassContext() as ctx:
@@ -41,6 +42,8 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
             ctx.set_reduce_precision('float16')
         if save_dir:
             ctx.save_graph_instrument(save_dir)
+        if tensor_core:
+            ctx.set_mma('mma' if tensor_core else 'simt')
         ctx.set_parallel_k(disabled=(parallel_k == 'disabled'), search=(parallel_k == 'search'))
         logger.info('start to optimize the flow graph')
         graph_opt: FlowGraph = optimize(flow_graph)
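
The new 'use_tensor_core' entry is read from the same dynamo config object that bench_hidet sets through the CLI. A minimal sketch of enabling it directly, using only the setters that appear in this commit:

import hidet

config = hidet.torch.dynamo_config
config.search_space(0)        # setter also used by bench_hidet
config.use_cuda_graph(True)
config.use_tensor_core(True)  # selects 'mma' instead of 'simt' in the pass context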
4 changes: 4 additions & 0 deletions requirements-dev.txt
@@ -1,3 +1,5 @@
+pytz
+
 # python test
 pytest==7.2

@@ -11,6 +13,8 @@ pylint==2.13.9
 torch
 torchvision
 transformers
+sentencepiece
+sacremoses

 # check the correctness with onnxruntime
 onnx
