[dynamo] byteir dynamo backend #291

Merged 5 commits on May 29, 2024
12 changes: 12 additions & 0 deletions frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
@@ -21,6 +21,18 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel
fx_utils.py
ts_utils.py

byteir_backend/__init__.py
byteir_backend/compilation_cache.py
byteir_backend/compiled_function.py
byteir_backend/compiler.py
byteir_backend/config.py
byteir_backend/inner_compile.py
byteir_backend/utils.py
byteir_backend/byteir_fusible_pattern.py
byteir_backend/fx_match_utils.py
byteir_backend/fx_utils.py
byteir_backend/partitioners.py

tools/compiler.py
)

@@ -0,0 +1,38 @@
*Requires torch version 2.4 or higher*

- Usage

```python
import torch

import torch_frontend
from torch_frontend import byteir_backend
from torch_frontend.byteir_backend.utils import *

class NaiveModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0, x1, x2):
r0 = torch.ops.aten.mul(x0, x1)
r1 = torch.ops.aten.div(r0, x2)
x0 = torch.ops.aten.mul(r1, r1) - x0
r2 = torch.ops.aten.slice(x0, 1, 1, 3, 1)
return r1, r2

model = NaiveModel()
opt_mod = torch.compile(model, backend="byteir")

x0 = torch.rand(32, 64).to('cuda')
x1 = torch.rand(32, 64).to('cuda')
x2 = torch.rand(32, 64).to('cuda')

x0 = x0.as_strided(size=(32,16), stride=(64,2), storage_offset=16)
x1 = x1.as_strided(size=(32,16), stride=(64,1), storage_offset=8)
x2 = x2.as_strided(size=(32,16), stride=(32,1), storage_offset=32)

golden = model(x0, x1, x2)
outs = opt_mod(x0, x1, x2)
torch.cuda.synchronize()

torch.testing.assert_close(golden, outs)
```
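
- Cache directory

The backend caches compiled FX graphs on disk (see `ByteIRFxGraphCache`). A minimal sketch of redirecting the cache root before compiling; the path below is only illustrative:

```python
from torch_frontend import byteir_backend

# Redirect the on-disk compilation cache (illustrative path).
byteir_backend.set_cache_dir("/tmp/byteir_cache")
```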
@@ -0,0 +1,13 @@
from torch._dynamo import register_backend


@register_backend
def byteir(*args, **kwargs):
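    # Lazy import so that registering the backend does not pull in the
    # compiler stack until the backend is actually invoked.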
from .compiler import byteir_compiler

return byteir_compiler(*args, **kwargs)

def set_cache_dir(path: str):
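    # Point the ByteIR FX-graph compilation cache at a user-chosen directory.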
from .compilation_cache import ByteIRFxGraphCache

ByteIRFxGraphCache.base_cache_dir = path
@@ -0,0 +1,212 @@
import torch
import torch.fx as fx

from .fx_utils import get_aten_target
from .fx_match_utils import get_node_consumer, match_chain

byteir_fusible_patterns = {}
aten = torch.ops.aten


def register_byteir_pattern(name):

def register(pattern):
        if name in byteir_fusible_patterns:
            raise ValueError("Pattern " + name +
                             " has already been registered.")
byteir_fusible_patterns[name] = pattern
return pattern

return register
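
# Illustrative only: a new pattern registers itself with the decorator above
# and implements the ByteIRFusiblePattern interface defined next, e.g.
#
#   @register_byteir_pattern("my_pattern")
#   class MyPattern(ByteIRFusiblePattern):
#       ...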


class ByteIRFusiblePattern:
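    """Interface for ByteIR fusible patterns.

    match() reports whether `node` anchors this pattern in the joint graph;
    get_pattern_recompute_nodes() returns the nodes that should be recomputed
    in the backward graph rather than saved from the forward.
    """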

@classmethod
def match(cls, node, required_fw_nodes) -> bool:
raise NotImplementedError

@classmethod
def get_pattern_recompute_nodes(cls, node, required_fw_nodes):
raise NotImplementedError


@register_byteir_pattern("transpose_dot")
class TransposeDotPattern(ByteIRFusiblePattern):

@classmethod
def match(cls, node, required_fw_nodes) -> bool:
post_fusible_ops = [aten.mm, aten.bmm]
if get_aten_target(node) in [aten.t, aten.transpose]:
can_fuse = all(
get_aten_target(user) in post_fusible_ops
for user in node.users)
all_fw_node = all(user in required_fw_nodes for user in node.users)
return (not all_fw_node) and can_fuse
return False

@classmethod
def get_pattern_recompute_nodes(cls, node, required_fw_nodes):
if cls.match(node, required_fw_nodes):
return [node]
return []


@register_byteir_pattern("transpose_reshape_transpose_dot")
class TransposeReshapeTransposeDotPattern(ByteIRFusiblePattern):

@classmethod
def match(cls, node, required_fw_nodes) -> bool:
post_fusible_ops = [aten.mm, aten.bmm, aten.transpose]
if get_aten_target(node) not in [aten.transpose]:
return False
if match_chain(node,
target_chain=(aten.transpose, aten.expand, aten.clone,
aten._unsafe_view)):
expand_node = get_node_consumer(node, 0)
clone_node = get_node_consumer(expand_node, 0)
view_node = get_node_consumer(clone_node, 0)
all_fw_node = all(user in required_fw_nodes
for user in view_node.users)
can_fuse = all(
get_aten_target(user) in post_fusible_ops
for user in view_node.users)
return (not all_fw_node) and can_fuse
return False

@classmethod
def get_pattern_recompute_nodes(cls, node, required_fw_nodes):
if cls.match(node, required_fw_nodes):
expand_node = get_node_consumer(node, 0)
clone_node = get_node_consumer(expand_node, 0)
view_node = get_node_consumer(clone_node, 0)
recompute_nodes = [node, expand_node, clone_node, view_node]
for user in view_node.users:
if user not in required_fw_nodes:
recompute_nodes.append(user)
return recompute_nodes
return []


@register_byteir_pattern("transpose_transpose")
class TransposeTransposePattern(ByteIRFusiblePattern):

@classmethod
def match(cls, node, required_fw_nodes) -> bool:
if get_aten_target(node) in [aten.t, aten.transpose]:
for user in node.users:
if get_aten_target(user) in [aten.t, aten.transpose]:
all_fw_node = all(n in required_fw_nodes
for n in user.users)
if not all_fw_node:
return True
return False

@classmethod
def get_pattern_recompute_nodes(cls, node, required_fw_nodes):
if cls.match(node, required_fw_nodes):
recompute_nodes = [node]
for user in node.users:
if get_aten_target(user) == aten.t:
recompute_nodes.append(user)
return recompute_nodes
return []


@register_byteir_pattern("full_bitwise_not_expand")
class FullBitwiseNotExpandPattern(ByteIRFusiblePattern):

@classmethod
def match(cls, node, required_fw_nodes) -> bool:
if match_chain(node,
target_chain=(aten.full, aten.bitwise_not,
aten.expand)):
return True
return False

@classmethod
def get_pattern_recompute_nodes(cls, node, required_fw_nodes):
if cls.match(node, required_fw_nodes):
bitwise_node = get_node_consumer(node, 0)
expand_node = get_node_consumer(bitwise_node, 0)
recompute_nodes = [node, bitwise_node, expand_node]
return recompute_nodes
return []


# Note: this pattern is temporary. It works around the full op (with bool
# dtype) not yet being supported in ByteIR.
@register_byteir_pattern("copy_bitwise_not_expand")
class CopyBitwiseNotExpandPattern(ByteIRFusiblePattern):

@classmethod
def match(cls, node, required_fw_nodes) -> bool:
if match_chain(node,
target_chain=(aten._to_copy, aten.bitwise_not,
aten.expand, aten.bitwise_or)):
return True
return False

@classmethod
def get_pattern_recompute_nodes(cls, node, required_fw_nodes):
if cls.match(node, required_fw_nodes):
bitwise_not = get_node_consumer(node, 0)
expand = get_node_consumer(bitwise_not, 0)
bitwise_or = get_node_consumer(expand, 0)
recompute_nodes = [node, bitwise_not, expand, bitwise_or]
return recompute_nodes
return []


def greedy_transpose_fusion(joint_graph, required_fw_nodes):
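    # Walk the joint graph in reverse: tag mm/bmm nodes outside the forward as
    # fusion sinks, propagate the tag through "transparent" ops
    # (clone/_to_copy/expand), and mark transposes feeding a tagged consumer
    # for recompute so they can fuse into the downstream matmul.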
recompute_nodes = []
post_fuse_ops = [aten.bmm, aten.mm]
transparent_ops = [aten.clone, aten._to_copy, aten.expand]
view_ops = [aten.view, aten._unsafe_view]
transpose_ops = [aten.t, aten.transpose]
fusible_tag = {}

INIT_TAG = 0
POST_FUSION_TAG = 1
TRANSPOSE_TAG = 2

for node in reversed(joint_graph.nodes):
fusible_tag[node] = INIT_TAG

for node in reversed(joint_graph.nodes):
        if (get_aten_target(node) in post_fuse_ops
                and node not in required_fw_nodes):
            fusible_tag[node] = POST_FUSION_TAG

        if get_aten_target(node) in transparent_ops:
            for user in node.users:
                if user in fusible_tag and fusible_tag[user] >= POST_FUSION_TAG:
                    fusible_tag[node] = POST_FUSION_TAG
                    recompute_nodes.append(node)

        if get_aten_target(node) in transpose_ops:
            for user in node.users:
                if user in fusible_tag and fusible_tag[user] >= POST_FUSION_TAG:
                    recompute_nodes.append(node)
                    fusible_tag[node] = INIT_TAG

return recompute_nodes


def get_byteir_recompute_nodes(joint_graph, required_fw_nodes):
recompute_nodes = []
recompute_nodes.extend(
greedy_transpose_fusion(joint_graph, required_fw_nodes))
    for pattern in byteir_fusible_patterns.values():
for node in joint_graph.nodes:
if node.op == 'output':
continue
recompute_nodes.extend(
pattern.get_pattern_recompute_nodes(node, required_fw_nodes))
recompute_nodes = list(set(recompute_nodes))
return recompute_nodes
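

# Illustrative call site (the partitioner lives elsewhere in this PR; the
# names here are assumptions): the custom partitioner gathers these nodes
# from the joint forward/backward graph and forces them to be recomputed
# in the backward pass, e.g.
#
#   recompute = get_byteir_recompute_nodes(joint_graph, required_fw_nodes)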