Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Docs][Opeartor] Add more pytorch operator bindings and docs #50

Merged
merged 2 commits into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/source/python_api/data_types.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
hidet.dtypes
============

.. data:: hidet.uint8
.. data:: hidet.uint16
.. data:: hidet.uint32
.. data:: hidet.uint64
.. data:: hidet.int8
.. data:: hidet.int16
.. data:: hidet.int32
.. data:: hidet.int64
.. data:: hidet.float16
.. data:: hidet.float32
.. data:: hidet.float64
.. data:: hidet.bfloat16
.. data:: hidet.tfloat32
1 change: 1 addition & 0 deletions docs/source/python_api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Python API
driver
cuda
tensor
data_types
ops/index
ir/index
graph/index
Expand Down
6 changes: 3 additions & 3 deletions docs/source/python_api/ir/type.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
hidet.ir.type
=============

.. automodule:: hidet.ir.type
:members:
:autosummary:
.. autoclass:: hidet.ir.type.DataType

.. autofunction:: hidet.ir.type.data_type
2 changes: 1 addition & 1 deletion python/hidet/cli/bench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hidet.graph.frontend.torch import availability as torch_availability
from .root import bench_group
from .bench import bench_group

if not torch_availability.dynamo_available():
raise RuntimeError(
Expand Down
42 changes: 42 additions & 0 deletions python/hidet/cli/bench/bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import click
from hidet.utils import initialize
from . import vision
from . import nlp
from .model import BenchModel
from .bench_common import bench_common
from .bench_all import bench_all


@click.group(name='bench', help='Benchmark models.')
@click.option(
'--space',
default='0',
show_default=True,
type=click.Choice(['0', '1', '2']),
help='Schedule space. 0: default schedule. 1: small schedule space. 2: large schedule space.',
)
@click.option(
'--torch-tf32',
default=False,
show_default=True,
type=bool,
help='Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32.',
)
def bench_group(space: str, torch_tf32: bool):
BenchModel.search_space = int(space)
BenchModel.allow_tf32 = torch_tf32


@initialize()
def register_commands():
for command in [
bench_common,
bench_all,
vision.bench_resnet,
vision.bench_resnext,
vision.bench_inception_v3,
vision.bench_mobilenet_v2,
nlp.bench_nlp,
]:
assert isinstance(command, click.Command)
bench_group.add_command(command)
12 changes: 12 additions & 0 deletions python/hidet/cli/bench/bench_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import click
from tabulate import tabulate
from hidet.cli.bench.model import BenchModel, all_registered_models


@click.command(name='all')
def bench_all():
header = BenchModel.headers()
result = [bench_model.benchmark() for bench_model in all_registered_models]

click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
12 changes: 12 additions & 0 deletions python/hidet/cli/bench/bench_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import click
from tabulate import tabulate
from hidet.cli.bench.model import BenchModel, commonly_used_models


@click.command(name='common')
def bench_common():
header = BenchModel.headers()
result = [bench_model.benchmark() for bench_model in commonly_used_models]

click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from hidet.testing import benchmark_func
import hidet

if hidet.torch.dynamo_available():
import torch._dynamo as dynamo
else:
dynamo = None


class BenchModel:
search_space = 0
allow_tf32 = False

def __str__(self):
raise NotImplementedError()

Expand Down Expand Up @@ -71,7 +69,16 @@ def inputs_str(self) -> str:
return ', '.join(items)

def bench_with_backend(self, backend: str, mode=None, passes=None, warmup=3, number=10, repeat=10):
import torch
import torch.backends.cudnn
import torch.backends.cuda

if not hidet.torch.dynamo_available():
raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
import torch._dynamo as dynamo

hidet.torch.register_dynamo_backends()
torch.backends.cudnn.allow_tf32 = self.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = self.allow_tf32

model, (args, kwargs) = self.model(), self.example_inputs()
model = model.cuda().eval()
Expand All @@ -86,31 +93,41 @@ def bench_with_backend(self, backend: str, mode=None, passes=None, warmup=3, num
return latency

def bench_eager(self) -> float:
print('Benchmarking {} with backend {}...'.format(self, 'eager'))
return self.bench_with_backend('eager')

def bench_inductor(self) -> float:
return self.bench_with_backend('inductor', mode='max-autotune')
def bench_inductor(self, mode: str) -> float:
print('Benchmarking {} with backend {}...'.format(self, 'inductor(mode={})'.format(mode)))
return self.bench_with_backend('inductor', mode=mode)

def bench_hidet(self, use_cuda_graph=True, use_fp16=False, use_fp16_reduction=False, space=2) -> float:
def bench_hidet(self, use_cuda_graph=True, use_fp16=False, use_fp16_reduction=False) -> float:
print('Benchmarking {} with backend {}...'.format(self, 'hidet(space={})'.format(self.search_space)))
config = hidet.torch.dynamo_config
config.search_space(space)
config.search_space(self.search_space)
config.use_cuda_graph(use_cuda_graph)
config.use_fp16(use_fp16)
config.use_fp16_reduction(use_fp16_reduction)
return self.bench_with_backend('hidet')

@staticmethod
def headers() -> List[str]:
return ['model', 'inputs', 'eager', 'inductor', 'hidet', 'hidet_f16']
return [
'model',
'inputs',
'eager',
'inductor(mode=reduce-overhead)',
# 'inductor(mode=max-autotune)'
'hidet(space={})'.format(BenchModel.search_space),
]

def benchmark(self) -> List[Any]:
return [
str(self),
self.inputs_str(),
self.bench_eager(),
self.bench_inductor(),
self.bench_inductor('reduce-overhead'),
# self.bench_inductor('max-autotune'),
self.bench_hidet(),
self.bench_hidet(use_fp16=True),
]


Expand Down
1 change: 0 additions & 1 deletion python/hidet/cli/bench/models/__init__.py

This file was deleted.

1 change: 1 addition & 0 deletions python/hidet/cli/bench/nlp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .models import bench_nlp
51 changes: 51 additions & 0 deletions python/hidet/cli/bench/nlp/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import click
from tabulate import tabulate
from hidet.cli.bench.model import commonly_used_models, all_registered_models
from hidet.utils import initialize
from .nlp_model import NLPModel, BenchModel


available_models = ['bert-base-uncased', 'bert-large-uncased', 'bert-base-cased', 'bert-large-cased', 'gpt2']

commonly_used_models.extend(
[
NLPModel('huggingface/pytorch-transformers', 'model', 'bert-base-uncased', 1, 128),
NLPModel('huggingface/pytorch-transformers', 'model', 'gpt2', 1, 128),
]
)


@initialize()
def initialize_models():
for model in available_models:
for batch_size in [1, 8]:
for seq_length in [128, 512]:
all_registered_models.append(
NLPModel('huggingface/pytorch-transformers', 'model', model, batch_size, seq_length)
)


@click.command(name='nlp')
@click.option(
'--models',
type=str,
default='bert-base-uncased',
show_default=True,
help='Comma seperated models to benchmark. Available models: {}'.format(', '.join(available_models)),
)
@click.option('-n', '--batch-size', default=1, show_default=True, help='Batch size')
@click.option('-q', '--seq-length', default=128, show_default=True, help='Sequence length')
def bench_nlp(models: str, batch_size: int, seq_length: int):
models = [model.strip() for model in models.split(',')]
for model in models:
if model not in available_models:
raise ValueError('Unknown model: {}, candidates: {}'.format(model, list(available_models)))

bench_models = [
NLPModel('huggingface/pytorch-transformers', 'model', model, batch_size, seq_length) for model in models
]
header = BenchModel.headers()
result = [bench_model.benchmark() for bench_model in bench_models]

click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
27 changes: 27 additions & 0 deletions python/hidet/cli/bench/nlp/nlp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from hidet.cli.bench.model import BenchModel


class NLPModel(BenchModel):
def __init__(self, repo_name, model_name, label, batch_size: int, sequence_length: int):
self.repo_name = repo_name
self.model_name = model_name
self.label = label
self.batch_size = batch_size
self.sequence_length = sequence_length

def __str__(self):
return '{}/{}'.format(self.model_name, self.label)

def model(self):
import torch

return torch.hub.load(self.repo_name, self.model_name, self.label)

def example_inputs(self):
import torch

tokens_tensor = torch.zeros((self.batch_size, self.sequence_length), dtype=torch.long, device='cuda')
segments_tensors = torch.zeros((self.batch_size, self.sequence_length), dtype=torch.long, device='cuda')
args = (tokens_tensor,)
kwargs = {'token_type_ids': segments_tensors}
return args, kwargs
15 changes: 0 additions & 15 deletions python/hidet/cli/bench/root.py

This file was deleted.

4 changes: 4 additions & 0 deletions python/hidet/cli/bench/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .inception_v3 import bench_inception_v3
from .mobilenet_v2 import bench_mobilenet_v2
from .resnet import bench_resnet
from .resnext import bench_resnext
25 changes: 25 additions & 0 deletions python/hidet/cli/bench/vision/inception_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import click
from tabulate import tabulate
from hidet.cli.bench.model import BenchModel, all_registered_models, commonly_used_models
from .vision_model import VisionModel


inception_models = {'inception_v3': VisionModel('inception_v3', 1, 3, 299, 299)}


all_registered_models.extend(inception_models.values())
commonly_used_models.append(inception_models['inception_v3'])


@click.command(name='inception-v3')
@click.option('-n', '--batch-size', default=1, show_default=True, help='Batch size')
@click.option('-c', '--channels', default=3, show_default=True, help='Input channels')
@click.option('-h', '--height', default=224, show_default=True, help='Input image height')
@click.option('-w', '--width', default=224, show_default=True, help='Input image width')
def bench_inception_v3(batch_size: int, channels: int, height: int, width: int):
bench_models = [VisionModel('inception_v3', batch_size, channels, height, width)]
header = BenchModel.headers()
result = [bench_model.benchmark() for bench_model in bench_models]

click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))
25 changes: 25 additions & 0 deletions python/hidet/cli/bench/vision/mobilenet_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import click
from tabulate import tabulate
from hidet.cli.bench.model import BenchModel, all_registered_models, commonly_used_models
from .vision_model import VisionModel


inception_models = {'mobilenet_v2': VisionModel('mobilenet_v2', 1, 3, 224, 224)}


all_registered_models.extend(inception_models.values())
commonly_used_models.append(inception_models['mobilenet_v2'])


@click.command(name='mobilenet-v2')
@click.option('-n', '--batch-size', default=1, show_default=True, help='Batch size')
@click.option('-c', '--channels', default=3, show_default=True, help='Input channels')
@click.option('-h', '--height', default=224, show_default=True, help='Input image height')
@click.option('-w', '--width', default=224, show_default=True, help='Input image width')
def bench_mobilenet_v2(batch_size: int, channels: int, height: int, width: int):
bench_models = [VisionModel('mobilenet_v2', batch_size, channels, height, width)]
header = BenchModel.headers()
result = [bench_model.benchmark() for bench_model in bench_models]

click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left'))
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32))