-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Docs][Opeartor] Add more pytorch operator bindings and docs (#50)
* . * more torch operator bindings
- Loading branch information
1 parent
91abff6
commit 57551cc
Showing
29 changed files
with
555 additions
and
527 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
hidet.dtypes | ||
============ | ||
|
||
.. data:: hidet.uint8 | ||
.. data:: hidet.uint16 | ||
.. data:: hidet.uint32 | ||
.. data:: hidet.uint64 | ||
.. data:: hidet.int8 | ||
.. data:: hidet.int16 | ||
.. data:: hidet.int32 | ||
.. data:: hidet.int64 | ||
.. data:: hidet.float16 | ||
.. data:: hidet.float32 | ||
.. data:: hidet.float64 | ||
.. data:: hidet.bfloat16 | ||
.. data:: hidet.tfloat32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ Python API | |
driver | ||
cuda | ||
tensor | ||
data_types | ||
ops/index | ||
ir/index | ||
graph/index | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
hidet.ir.type | ||
============= | ||
|
||
.. automodule:: hidet.ir.type | ||
:members: | ||
:autosummary: | ||
.. autoclass:: hidet.ir.type.DataType | ||
|
||
.. autofunction:: hidet.ir.type.data_type |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import click | ||
from hidet.utils import initialize | ||
from . import vision | ||
from . import nlp | ||
from .model import BenchModel | ||
from .bench_common import bench_common | ||
from .bench_all import bench_all | ||
|
||
|
||
@click.group(name='bench', help='Benchmark models.') | ||
@click.option( | ||
'--space', | ||
default='0', | ||
show_default=True, | ||
type=click.Choice(['0', '1', '2']), | ||
help='Schedule space. 0: default schedule. 1: small schedule space. 2: large schedule space.', | ||
) | ||
@click.option( | ||
'--torch-tf32', | ||
default=False, | ||
show_default=True, | ||
type=bool, | ||
help='Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32.', | ||
) | ||
def bench_group(space: str, torch_tf32: bool): | ||
BenchModel.search_space = int(space) | ||
BenchModel.allow_tf32 = torch_tf32 | ||
|
||
|
||
@initialize() | ||
def register_commands(): | ||
for command in [ | ||
bench_common, | ||
bench_all, | ||
vision.bench_resnet, | ||
vision.bench_resnext, | ||
vision.bench_inception_v3, | ||
vision.bench_mobilenet_v2, | ||
nlp.bench_nlp, | ||
]: | ||
assert isinstance(command, click.Command) | ||
bench_group.add_command(command) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import click | ||
from tabulate import tabulate | ||
from hidet.cli.bench.model import BenchModel, all_registered_models | ||
|
||
|
||
@click.command(name='all') | ||
def bench_all(): | ||
header = BenchModel.headers() | ||
result = [bench_model.benchmark() for bench_model in all_registered_models] | ||
|
||
click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')) | ||
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import click | ||
from tabulate import tabulate | ||
from hidet.cli.bench.model import BenchModel, commonly_used_models | ||
|
||
|
||
@click.command(name='common') | ||
def bench_common(): | ||
header = BenchModel.headers() | ||
result = [bench_model.benchmark() for bench_model in commonly_used_models] | ||
|
||
click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')) | ||
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .models import bench_nlp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import click | ||
from tabulate import tabulate | ||
from hidet.cli.bench.model import commonly_used_models, all_registered_models | ||
from hidet.utils import initialize | ||
from .nlp_model import NLPModel, BenchModel | ||
|
||
|
||
available_models = ['bert-base-uncased', 'bert-large-uncased', 'bert-base-cased', 'bert-large-cased', 'gpt2'] | ||
|
||
commonly_used_models.extend( | ||
[ | ||
NLPModel('huggingface/pytorch-transformers', 'model', 'bert-base-uncased', 1, 128), | ||
NLPModel('huggingface/pytorch-transformers', 'model', 'gpt2', 1, 128), | ||
] | ||
) | ||
|
||
|
||
@initialize() | ||
def initialize_models(): | ||
for model in available_models: | ||
for batch_size in [1, 8]: | ||
for seq_length in [128, 512]: | ||
all_registered_models.append( | ||
NLPModel('huggingface/pytorch-transformers', 'model', model, batch_size, seq_length) | ||
) | ||
|
||
|
||
@click.command(name='nlp') | ||
@click.option( | ||
'--models', | ||
type=str, | ||
default='bert-base-uncased', | ||
show_default=True, | ||
help='Comma seperated models to benchmark. Available models: {}'.format(', '.join(available_models)), | ||
) | ||
@click.option('-n', '--batch-size', default=1, show_default=True, help='Batch size') | ||
@click.option('-q', '--seq-length', default=128, show_default=True, help='Sequence length') | ||
def bench_nlp(models: str, batch_size: int, seq_length: int): | ||
models = [model.strip() for model in models.split(',')] | ||
for model in models: | ||
if model not in available_models: | ||
raise ValueError('Unknown model: {}, candidates: {}'.format(model, list(available_models))) | ||
|
||
bench_models = [ | ||
NLPModel('huggingface/pytorch-transformers', 'model', model, batch_size, seq_length) for model in models | ||
] | ||
header = BenchModel.headers() | ||
result = [bench_model.benchmark() for bench_model in bench_models] | ||
|
||
click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')) | ||
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from hidet.cli.bench.model import BenchModel | ||
|
||
|
||
class NLPModel(BenchModel): | ||
def __init__(self, repo_name, model_name, label, batch_size: int, sequence_length: int): | ||
self.repo_name = repo_name | ||
self.model_name = model_name | ||
self.label = label | ||
self.batch_size = batch_size | ||
self.sequence_length = sequence_length | ||
|
||
def __str__(self): | ||
return '{}/{}'.format(self.model_name, self.label) | ||
|
||
def model(self): | ||
import torch | ||
|
||
return torch.hub.load(self.repo_name, self.model_name, self.label) | ||
|
||
def example_inputs(self): | ||
import torch | ||
|
||
tokens_tensor = torch.zeros((self.batch_size, self.sequence_length), dtype=torch.long, device='cuda') | ||
segments_tensors = torch.zeros((self.batch_size, self.sequence_length), dtype=torch.long, device='cuda') | ||
args = (tokens_tensor,) | ||
kwargs = {'token_type_ids': segments_tensors} | ||
return args, kwargs |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .inception_v3 import bench_inception_v3 | ||
from .mobilenet_v2 import bench_mobilenet_v2 | ||
from .resnet import bench_resnet | ||
from .resnext import bench_resnext |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import click | ||
from tabulate import tabulate | ||
from hidet.cli.bench.model import BenchModel, all_registered_models, commonly_used_models | ||
from .vision_model import VisionModel | ||
|
||
|
||
inception_models = {'inception_v3': VisionModel('inception_v3', 1, 3, 299, 299)} | ||
|
||
|
||
all_registered_models.extend(inception_models.values()) | ||
commonly_used_models.append(inception_models['inception_v3']) | ||
|
||
|
||
@click.command(name='inception-v3') | ||
@click.option('-n', '--batch-size', default=1, show_default=True, help='Batch size') | ||
@click.option('-c', '--channels', default=3, show_default=True, help='Input channels') | ||
@click.option('-h', '--height', default=224, show_default=True, help='Input image height') | ||
@click.option('-w', '--width', default=224, show_default=True, help='Input image width') | ||
def bench_inception_v3(batch_size: int, channels: int, height: int, width: int): | ||
bench_models = [VisionModel('inception_v3', batch_size, channels, height, width)] | ||
header = BenchModel.headers() | ||
result = [bench_model.benchmark() for bench_model in bench_models] | ||
|
||
click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')) | ||
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import click | ||
from tabulate import tabulate | ||
from hidet.cli.bench.model import BenchModel, all_registered_models, commonly_used_models | ||
from .vision_model import VisionModel | ||
|
||
|
||
inception_models = {'mobilenet_v2': VisionModel('mobilenet_v2', 1, 3, 224, 224)} | ||
|
||
|
||
all_registered_models.extend(inception_models.values()) | ||
commonly_used_models.append(inception_models['mobilenet_v2']) | ||
|
||
|
||
@click.command(name='mobilenet-v2') | ||
@click.option('-n', '--batch-size', default=1, show_default=True, help='Batch size') | ||
@click.option('-c', '--channels', default=3, show_default=True, help='Input channels') | ||
@click.option('-h', '--height', default=224, show_default=True, help='Input image height') | ||
@click.option('-w', '--width', default=224, show_default=True, help='Input image width') | ||
def bench_mobilenet_v2(batch_size: int, channels: int, height: int, width: int): | ||
bench_models = [VisionModel('mobilenet_v2', batch_size, channels, height, width)] | ||
header = BenchModel.headers() | ||
result = [bench_model.benchmark() for bench_model in bench_models] | ||
|
||
click.echo(tabulate(result, headers=header, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')) | ||
click.echo('(PyTorch backend: allow_tf32={})'.format(BenchModel.allow_tf32)) |
Oops, something went wrong.