Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fixbug] Allow one backend fail in benchmark script #170

Merged
merged 1 commit into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 24 additions & 18 deletions python/hidet/cli/bench/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,30 @@ def report_table(cls, table_str):
click.echo(table_str)

def bench_with_backend(self, backend: str, mode=None, warmup=3, number=10, repeat=10):
import torch.backends.cudnn # pylint: disable=redefined-outer-name
import torch.backends.cuda

if not hidet.torch.dynamo_available():
raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
import torch._dynamo as dynamo

torch.backends.cudnn.allow_tf32 = not self.disable_torch_cudnn_tf32
torch.backends.cuda.matmul.allow_tf32 = self.enable_torch_cublas_tf32

model, (args, kwargs) = self.converted_model(), self.converted_inputs()
dynamo.reset()
with torch.no_grad():
model_opt = torch.compile(model, backend=backend, mode=mode)
latency = benchmark_func(
run_func=lambda: model_opt(*args, **kwargs), warmup=warmup, number=number, repeat=repeat
)
return latency
try:
import torch.backends.cudnn # pylint: disable=redefined-outer-name
import torch.backends.cuda

if not hidet.torch.dynamo_available():
raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
import torch._dynamo as dynamo

torch.backends.cudnn.allow_tf32 = not self.disable_torch_cudnn_tf32
torch.backends.cuda.matmul.allow_tf32 = self.enable_torch_cublas_tf32

model, (args, kwargs) = self.converted_model(), self.converted_inputs()
dynamo.reset()
with torch.no_grad():
model_opt = torch.compile(model, backend=backend, mode=mode)
latency = benchmark_func(
run_func=lambda: model_opt(*args, **kwargs), warmup=warmup, number=number, repeat=repeat
)
return latency
except Exception as e: # pylint: disable=broad-except
from traceback import format_exc

print('Failed to benchmark {} with {}: {}\nTraceback:\n{}'.format(self, backend, e, format_exc()))
return float('NaN')

def bench_eager(self) -> float:
print('Benchmarking {} with backend {}...'.format(self, 'eager'))
Expand Down
10 changes: 6 additions & 4 deletions python/hidet/graph/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,15 +1008,15 @@ def symbol(shape: Sequence[int], dtype='float32', device='cpu', layout=None) ->
return Tensor(shape=shape, dtype=dtype, device=device, storage=None, layout=layout)


def zeros(shape: Sequence[int], dtype='float32', device='cpu') -> Tensor:
def zeros(shape: Sequence[int], dtype: Union[DataType, str] = 'float32', device='cpu') -> Tensor:
"""Create a tensor initialized with zero.
Parameters
----------
shape: Sequence[int]
The shape of new tensor.
dtype: str
dtype: str or DataType
The data type of element of the tensor.
device: Device or str, default 'cpu'
Expand Down Expand Up @@ -1114,9 +1114,11 @@ def randn(shape, dtype='float32', mean=0.0, stddev=1.0, device='cpu') -> Tensor:
[[ 0.10720467 -1.6906018 0.06347568]
[-0.37061226 0.562728 1.857547 ]]
"""
np_tensor = np.random.randn(*shape) * stddev + mean

if isinstance(np_tensor, float): # shape = []
np_tensor = np.array(np_tensor)

np_tensor = np.random.randn(*shape).astype(np.float32)
np_tensor = np_tensor * stddev + mean
hidet_tensor = from_numpy(np_tensor)
return hidet_tensor.to(device=device, dtype=dtype)

Expand Down
7 changes: 6 additions & 1 deletion python/hidet/ir/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def dummy_arguments(self, device: str):
if isinstance(param, Var):
arguments.append(10)
elif isinstance(param, TensorNode):
arguments.append(hidet.randn(param.const_shape(), param.type.dtype, device=device))
if param.type.dtype.is_integer():
arguments.append(hidet.zeros(param.const_shape(), dtype=param.type.dtype, device=device))
elif param.type.dtype.is_float():
arguments.append(hidet.randn(param.const_shape(), dtype=param.type.dtype, device=device))
else:
raise ValueError('Unknown dtype: {}'.format(param.type.dtype))
else:
raise ValueError('Unknown parameter type: {}'.format(type(param)))
return arguments
Expand Down
9 changes: 8 additions & 1 deletion scripts/bench/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import hidet

hidet.option.cache_dir(os.path.join(hidet.option.get_cache_dir(), 'benchmark'))
hidet.utils.hidet_clear_op_cache()
cache_dir = hidet.option.get_cache_dir()


parser = argparse.ArgumentParser('Benchmark hidet performance.')
parser.add_argument('--git-prev-commit', default=None, type=str, help='Previous git commit hash.')
parser.add_argument('--git-commit', type=str, help='Git commit hash.')
parser.add_argument('--keep-cache', default=False, action='store_true', help='Keep cache.')
parser.add_argument('--space', default=2, type=int, help='Search space of hidet.')
parser.add_argument('--report', default='./report.txt', type=str, help='Report file path.')

Expand Down Expand Up @@ -54,11 +54,18 @@ def info(args) -> str:

def main():
args = parser.parse_args()

if not args.keep_cache:
print('Clearing hidet operator cache...')
hidet.utils.hidet_clear_op_cache()

commands = [
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float32 --report resnet50_f32.txt --tensor-core resnet --models resnet50',
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float16 --report resnet50_f16.txt --tensor-core resnet --models resnet50',
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float32 --report bert-seq128-f32.txt --tensor-core nlp --seq-length 128 --models bert-base-uncased',
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float16 --report bert-seq128-f16.txt --tensor-core nlp --seq-length 128 --models bert-base-uncased',
# f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float32 --report gpt2-seq128-f32.txt --tensor-core nlp --seq-length 128 --models gpt2',
# f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float16 --report gpt2-seq128-f16.txt --tensor-core nlp --seq-length 128 --models gpt2',
]
with open(args.report, 'w') as f:
t1 = time.time()
Expand Down