Skip to content

Commit

Permalink
[Fixbug] Allow one backend fail in benchmark script (#170)
Browse files Browse the repository at this point in the history
.
  • Loading branch information
yaoyaoding committed Apr 12, 2023
1 parent 634a3a2 commit d2465c2
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 24 deletions.
42 changes: 24 additions & 18 deletions python/hidet/cli/bench/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,30 @@ def report_table(cls, table_str):
click.echo(table_str)

def bench_with_backend(self, backend: str, mode=None, warmup=3, number=10, repeat=10):
import torch.backends.cudnn # pylint: disable=redefined-outer-name
import torch.backends.cuda

if not hidet.torch.dynamo_available():
raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
import torch._dynamo as dynamo

torch.backends.cudnn.allow_tf32 = not self.disable_torch_cudnn_tf32
torch.backends.cuda.matmul.allow_tf32 = self.enable_torch_cublas_tf32

model, (args, kwargs) = self.converted_model(), self.converted_inputs()
dynamo.reset()
with torch.no_grad():
model_opt = torch.compile(model, backend=backend, mode=mode)
latency = benchmark_func(
run_func=lambda: model_opt(*args, **kwargs), warmup=warmup, number=number, repeat=repeat
)
return latency
try:
import torch.backends.cudnn # pylint: disable=redefined-outer-name
import torch.backends.cuda

if not hidet.torch.dynamo_available():
raise RuntimeError('Torch Dynamo is not available, please install pytorch 2.0 or higher.')
import torch._dynamo as dynamo

torch.backends.cudnn.allow_tf32 = not self.disable_torch_cudnn_tf32
torch.backends.cuda.matmul.allow_tf32 = self.enable_torch_cublas_tf32

model, (args, kwargs) = self.converted_model(), self.converted_inputs()
dynamo.reset()
with torch.no_grad():
model_opt = torch.compile(model, backend=backend, mode=mode)
latency = benchmark_func(
run_func=lambda: model_opt(*args, **kwargs), warmup=warmup, number=number, repeat=repeat
)
return latency
except Exception as e: # pylint: disable=broad-except
from traceback import format_exc

print('Failed to benchmark {} with {}: {}\nTraceback:\n{}'.format(self, backend, e, format_exc()))
return float('NaN')

def bench_eager(self) -> float:
print('Benchmarking {} with backend {}...'.format(self, 'eager'))
Expand Down
10 changes: 6 additions & 4 deletions python/hidet/graph/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,15 +1008,15 @@ def symbol(shape: Sequence[int], dtype='float32', device='cpu', layout=None) ->
return Tensor(shape=shape, dtype=dtype, device=device, storage=None, layout=layout)


def zeros(shape: Sequence[int], dtype='float32', device='cpu') -> Tensor:
def zeros(shape: Sequence[int], dtype: Union[DataType, str] = 'float32', device='cpu') -> Tensor:
"""Create a tensor initialized with zero.
Parameters
----------
shape: Sequence[int]
The shape of new tensor.
dtype: str
dtype: str or DataType
The data type of element of the tensor.
device: Device or str, default 'cpu'
Expand Down Expand Up @@ -1114,9 +1114,11 @@ def randn(shape, dtype='float32', mean=0.0, stddev=1.0, device='cpu') -> Tensor:
[[ 0.10720467 -1.6906018 0.06347568]
[-0.37061226 0.562728 1.857547 ]]
"""
np_tensor = np.random.randn(*shape) * stddev + mean

if isinstance(np_tensor, float): # shape = []
np_tensor = np.array(np_tensor)

np_tensor = np.random.randn(*shape).astype(np.float32)
np_tensor = np_tensor * stddev + mean
hidet_tensor = from_numpy(np_tensor)
return hidet_tensor.to(device=device, dtype=dtype)

Expand Down
7 changes: 6 additions & 1 deletion python/hidet/ir/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def dummy_arguments(self, device: str):
if isinstance(param, Var):
arguments.append(10)
elif isinstance(param, TensorNode):
arguments.append(hidet.randn(param.const_shape(), param.type.dtype, device=device))
if param.type.dtype.is_integer():
arguments.append(hidet.zeros(param.const_shape(), dtype=param.type.dtype, device=device))
elif param.type.dtype.is_float():
arguments.append(hidet.randn(param.const_shape(), dtype=param.type.dtype, device=device))
else:
raise ValueError('Unknown dtype: {}'.format(param.type.dtype))
else:
raise ValueError('Unknown parameter type: {}'.format(type(param)))
return arguments
Expand Down
9 changes: 8 additions & 1 deletion scripts/bench/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import hidet

hidet.option.cache_dir(os.path.join(hidet.option.get_cache_dir(), 'benchmark'))
hidet.utils.hidet_clear_op_cache()
cache_dir = hidet.option.get_cache_dir()


parser = argparse.ArgumentParser('Benchmark hidet performance.')
parser.add_argument('--git-prev-commit', default=None, type=str, help='Previous git commit hash.')
parser.add_argument('--git-commit', type=str, help='Git commit hash.')
parser.add_argument('--keep-cache', default=False, action='store_true', help='Keep cache.')
parser.add_argument('--space', default=2, type=int, help='Search space of hidet.')
parser.add_argument('--report', default='./report.txt', type=str, help='Report file path.')

Expand Down Expand Up @@ -54,11 +54,18 @@ def info(args) -> str:

def main():
args = parser.parse_args()

if not args.keep_cache:
print('Clearing hidet operator cache...')
hidet.utils.hidet_clear_op_cache()

commands = [
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float32 --report resnet50_f32.txt --tensor-core resnet --models resnet50',
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float16 --report resnet50_f16.txt --tensor-core resnet --models resnet50',
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float32 --report bert-seq128-f32.txt --tensor-core nlp --seq-length 128 --models bert-base-uncased',
f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float16 --report bert-seq128-f16.txt --tensor-core nlp --seq-length 128 --models bert-base-uncased',
# f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float32 --report gpt2-seq128-f32.txt --tensor-core nlp --seq-length 128 --models gpt2',
# f'hidet bench --cache-dir {cache_dir} --space {args.space} --dtype float16 --report gpt2-seq128-f16.txt --tensor-core nlp --seq-length 128 --models gpt2',
]
with open(args.report, 'w') as f:
t1 = time.time()
Expand Down

0 comments on commit d2465c2

Please sign in to comment.