Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Docs] Add the documentation for 'hidet' backend of PyTorch dynamo #42

Merged
merged 4 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
- name: Install dependencies via pip
run: |
python -m pip install --upgrade pip
pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116
pip install -r requirements.txt
pip install -r requirements-dev.txt

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Format with black
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
- name: Install dependencies via pip
run: |
python -m pip install --upgrade pip
pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116
pip install -r requirements.txt
pip install -r requirements-dev.txt

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Hidet is an open-source DNN inference framework, it features
:maxdepth: 1
:caption: Tutorials

gallery/tutorials/optimize-pytorch-model
gallery/tutorials/run-onnx-model


Expand Down
71 changes: 70 additions & 1 deletion gallery/getting-started/quick-start.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,81 @@

This guide walks through the key functionality of Hidet for tensor computation.
"""

# %%
# We should first import hidet.
import hidet

# %%
# Optimize PyTorch model with Hidet
# ---------------------------------
# .. note::
# :class: margin
#
# Torch dynamo is a feature introduced in PyTorch 2.0, which has not been officially released yet. Please install the
# nightly build of PyTorch to use this feature.
#
# The easiest way to use Hidet is to use the :func:`torch.compile` function with 'hidet' as the backend, such as
#
# .. code-block:: python
#
# model_opt = torch.compile(model, backend='hidet')
#
# Next, we use resnet18 model as an example to show how to optimize a PyTorch model with Hidet.

# disable tf32 to make the result of torch more accurate
import torch.backends.cudnn
torch.backends.cudnn.allow_tf32 = False

# take resnet18 as an example
x = torch.randn(1, 3, 224, 224).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
model = model.cuda().eval()

# currently, hidet only support inference
with torch.no_grad():
# optimize the model with 'hidet' backend
model_opt = torch.compile(model, backend='hidet')

# run the optimized model
y1 = model_opt(x)
y2 = model(x)

# check the correctness (when tf32 is used, the error tolerance would go to 1e-3)
torch.testing.assert_close(actual=y1, expected=y2, rtol=1e-5, atol=1e-5)


# benchmark the performance
for name, model in [('eager', model), ('hidet', model_opt)]:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(100):
y = model(x)
end_event.record()
torch.cuda.synchronize()
print('{:>10}: {:.3f} ms'.format(name, start_event.elapsed_time(end_event) / 100.0))


# %%
# Hidet provides some configurations to control the optimization of hidet backend. such as
#
# - **Search Space**: you can choose the search space of operator kernel tuning. A larger schedule space usually
# achieves the better performance, but takes longer time to optimize.
# - **Correctness Checking**: print the correctness checking report. You can know the numerical difference between the
# hidet generated operator and the original pytorch operator.
# - **Other Configurations**: you can also configure the other optimizations of hidet backend, such as using a lower
# precision of data type automatically (e.g., float16), or control the behavior of parallelization of the reduction
# dimension of the matrix multiplication and convolution operators.
#
# .. seealso::
#
# You can learn more about the configuration of hidet as a backend in torch dynamo in the tutorial
# :doc:`/gallery/tutorials/optimize-pytorch-model`.
#
# In the remaining parts, we will show you the key components of Hidet.
#
#
# Define tensors
# --------------
#
Expand Down
2 changes: 1 addition & 1 deletion gallery/how-to-guides/visualize-flow-graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor):
#
# :download:`Download 1_FoldConstantPass.json <../../../../gallery/how-to-guides/outs/1_FoldConstantPass.json>`
#
# :download:`Download 2_PatternTransformPass.json <../../../../gallery/how-to-guides/outs/2_PatternTransformPass.json>`
# :download:`Download 2_PatternTransformPass.json <../../../../gallery/how-to-guides/outs/2_SubgraphRewritePass.json>`
#
# :download:`Download 4_ResolveVariantPass.json <../../../../gallery/how-to-guides/outs/4_ResolveVariantPass.json>`
#
Expand Down
160 changes: 160 additions & 0 deletions gallery/tutorials/optimize-pytorch-model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""
Optimize PyTorch Model
======================

Hidet provides a backend to pytorch dynamo to optimize PyTorch models. To use this backend, you need to specify 'hidet'
as the backend when calling :func:`torch.compile` such as

.. code-block:: python

# optimize the model with hidet provided backend 'hidet'
model_hidet = torch.compile(model, backend='hidet')

.. note::
:class: margin

Currently, all the operators in hidet are generated by hidet itself and
there is no dependency on kernel libraries such as cuDNN or cuBLAS. In the future, we might support to lower some
operators to these libraries if they perform better.

Under the hood, hidet will convert the PyTorch model to hidet's graph representation and optimize the computation graph
(such as sub-graph rewrite and fusion, constant folding, etc.). After that, each operator will be lowered to hidet's
scheduling system to generate the final kernel.


Hidet provides some configurations to control the hidet backend of torch dynamo.

Search in a larger search space
-------------------------------
There are some operators that are compute-intensive and their scheduling is critical to the performance. We usually need
to search in a schedule space to find the best schedule for them to achieve the best performance on given input shapes.
However, searching in a larger schedule space usually takes longer time to optimize the model. By default, hidet will
use their default schedule to generate the kernel for all input shapes. To search in a larger schedule space to get
better performance, you can configure the search space via

.. code-block:: python

# There are three search spaces:
# 0 - use default schedule, no search [Default]
# 1 - search in a small schedule space (usually 1~30 schedules)
# 2 - search in a large schedule space (usually more than 30 schedules)
hidet.torch.dynamo_config.set_search_space(2)

# After configure the search space, you can optimize the model
model_opt = torch.compile(model, backend='hidet')

# The actual searching happens when you first run the model to know the input shapes
outputs = model_opt(inputs)

Please note that the search space we set through :func:`~hidet.torch.dynamo_config.set_search_space` will be read and
used when we first run the model, instead of when we call :func:`torch.compile`.

Check the correctness
---------------------
It is important to make sure the optimized model is correct. Hidet provides a configuration to print the numerical
difference between the hidet generated operator and the original pytorch operator. You can configure it via

.. code-block:: python

# enable the correctness checking
hidet.torch.dynamo_config.correctness_report()

After enabling the correctness report, every time a new graph is received to compile, hidet will print the numerical
difference using the dummy inputs (for now, torch dynamo does not expose the actual inputs to backends, thus we can
not use the actual inputs). Let's take the resnet18 model as an example:
"""
import torch.backends.cudnn
import hidet

x = torch.randn(1, 3, 224, 224).cuda()
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False).cuda()
model.eval()

torch.backends.cudnn.allow_tf32 = False # tf32 would harm the effective precision of torch's results

with torch.no_grad():
hidet.torch.dynamo_config.correctness_report()
model_opt = torch.compile(model, backend='hidet')
model_opt(x)

# %%
#
# .. tip::
# :class: margin
#
# Usually, we can expect:
#
# - for float32: :math:`e_h \leq 10^{-5}`, and
# - for float16: :math:`e_h \leq 10^{-2}`.
#
# The correctness report will print the harmonic mean of the absolute error and relative error for each operator:
#
# .. math::
# e_h = \frac{|actual - expected|}{|expected| + 1} \quad (\frac{1}{e_h} = \frac{1}{e_a} + \frac{1}{e_r})
#
#
# where :math:`actual`, :math:`expected` are the actual and expected results of the operator, respectively.
# The :math:`e_a` and :math:`e_r` are the absolute error and relative error, respectively. The harmonic mean error is
# printed for each operator.
#

# %%
# Operator configurations
# -----------------------
#
# Use CUDA Graph to dispatch kernels
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Hidet provides a configuration to use CUDA Graph to dispatch kernels. CUDA Graph is a new feature in CUDA 11.0
# that allows us to record the kernel dispatches and replay them later. This feature is useful when we want to
# dispatch the same kernels multiple times. Hidet will enable CUDA Graph by default. You can disable it via
#
# .. code-block:: python
#
# # disable CUDA Graph
# hidet.torch.dynamo_config.use_cuda_graph(False)
#
# in case you want to use PyTorch's CUDA Graph feature.
#
# Use low-precision data type
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Hidet provides a configuration to use low-precision data type. By default, hidet will use the same data type as
# the original PyTorch model. You can configure it via
#
# .. code-block:: python
#
# # automatically transform the model to use float16 data type
# hidet.torch.dynamo_config.use_fp16(True)
#
# # use float16 data type as the accumulate data type in operators with reduction
# hidet.torch.dynamo_config.use_fp16_reduction(True)
#
# You do not need to change the inputs feed to the model, as hidet will automatically cast the inputs to the
# configured data type automatically in the optimized model.
#
#
# Print the input graph
# ~~~~~~~~~~~~~~~~~~~~~
#
# If you are interested in the graph that PyTorch dynamo dispatches to hidet backend, you can configure hidet to
# print the graph via
#
# .. code-block:: python
#
# # print the input graph
# hidet.torch.dynamo_config.print_input_graph(True)
#
# Because ResNet18 is a neat model without control flow, we can print the input graph to see how PyTorch dynamo
# dispatches the model to hidet backend:

# sphinx_gallery_start_ignore
import torch._dynamo as dynamo
hidet.torch.dynamo_config.correctness_report(False) # reset
dynamo.reset() # clear the compiled cache
# sphinx_gallery_end_ignore

with torch.no_grad():
hidet.torch.dynamo_config.print_input_graph(True)
model_opt = torch.compile(model, backend='hidet')
model_opt(x)
4 changes: 2 additions & 2 deletions gallery/tutorials/run-onnx-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
.. currentmodule:: hidet
.. _Run ONNX Model with Hidet:

Run ONNX Model with Hidet
=========================
Optimize ONNX Model
===================

This tutorial walks through the steps to run a model in `ONNX format <https://onnx.ai/>`_ with Hidet.
The ResNet50 onnx model exported from PyTorch model zoo would be used as an example.
Expand Down
8 changes: 4 additions & 4 deletions python/hidet/cli/bench/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
else:
dynamo = None

torch.backends.cudnn.allow_tf32 = False # for fair comparison
# torch.backends.cudnn.allow_tf32 = False # for fair comparison


class BenchModel:
Expand All @@ -38,14 +38,14 @@ def inputs_str(self) -> str:
items.append('{}={}'.format(k, self.tensor_str(v)))
return ', '.join(items)

def bench_with_backend(self, backend: str, warmup=3, number=10, repeat=10):
def bench_with_backend(self, backend: str, mode=None, passes=None, warmup=3, number=10, repeat=10):
model, (args, kwargs) = self.model(), self.example_inputs()
model = model.cuda().eval()
args = [arg.cuda() for arg in args]
kwargs = {k: v.cuda() for k, v in kwargs.items()}
dynamo.reset()
with torch.no_grad():
model_opt = torch.compile(model, backend=backend)
model_opt = torch.compile(model, backend=backend, mode=mode, passes=passes)
latency = benchmark_func(
run_func=lambda: model_opt(*args, **kwargs), warmup=warmup, number=number, repeat=repeat
)
Expand All @@ -55,7 +55,7 @@ def bench_eager(self) -> float:
return self.bench_with_backend('eager')

def bench_inductor(self) -> float:
return self.bench_with_backend('inductor')
return self.bench_with_backend('inductor', mode='max-autotune')

def bench_hidet(self, use_cuda_graph=True, use_fp16=False, use_fp16_reduction=False, space=2) -> float:
config = hidet.torch.dynamo_config
Expand Down