Skip to content

Commit

Permalink
COPYBARA SYNC:
Browse files Browse the repository at this point in the history
  - 03bedd4 [compiler] support splat value in CompareFolder (#294)
  - c34abdc [compiler] fix bug in RemoveCopy (#289)
  - b1914a7 [dynamo] byteir dynamo backend (#291)

GitOrigin-RevId: 03bedd4
  • Loading branch information
Vremold committed May 30, 2024
1 parent 9903070 commit 51ccfe9
Show file tree
Hide file tree
Showing 19 changed files with 2,528 additions and 45 deletions.
15 changes: 11 additions & 4 deletions compiler/include/byteir/Dialect/MemRef/Utils/MemEffect.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,19 @@ void getAllAlias(Operation *op,
llvm::SmallVectorImpl<SmallVector<Value>> &aliases,
bool skipNonOverlapedSubviews = false);

// Returns true if the operation owning `opOpernad` may write the operand's
// value. Ops implementing MemoryEffectOpInterface are queried for a Write
// effect on that value; any other op is conservatively treated as a writer.
bool maybeOpOperandWrite(OpOperand &opOpernad);

// Returns true if the operation owning `opOpernad` may read the operand's
// value. Ops implementing MemoryEffectOpInterface are queried for a Read
// effect on that value; any other op is conservatively treated as a reader.
bool maybeOpOperandRead(OpOperand &opOpernad);

// Note: this method collects all **potential** read/write uses of the given
// aliases.
void getMemEffects(llvm::SmallVectorImpl<OpMemEffectOrder> &memEffects,
llvm::ArrayRef<SmallVector<Value>> aliases,
llvm::DenseMap<Operation *, unsigned> &opToIdx,
unsigned pivot);
using OperandMemEffectFn = std::function<bool(OpOperand &)>;
void getMemEffects(
llvm::SmallVectorImpl<OpMemEffectOrder> &memEffects,
llvm::ArrayRef<SmallVector<Value>> aliases,
llvm::DenseMap<Operation *, unsigned> &opToIdx, unsigned pivot,
const OperandMemEffectFn &hasReadEffect = maybeOpOperandRead,
const OperandMemEffectFn &hasWriteEffect = maybeOpOperandWrite);

} // namespace mlir

Expand Down
18 changes: 17 additions & 1 deletion compiler/lib/Dialect/MemRef/Transforms/RemoveCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,32 @@ class RemoveCopyPattern : public OpRewritePattern<memref::CopyOp> {
}
}

auto srcMemSpace = src.getType().cast<MemRefType>().getMemorySpace();
auto dstMemSpace = target.getType().cast<MemRefType>().getMemorySpace();
if (srcMemSpace && dstMemSpace && srcMemSpace != dstMemSpace) {
return failure();
}

SmallVector<SmallVector<Value>, 2> aliases(2);
getAllAlias(copyOp, aliases, /*skipNonOverlapedSubviews*/ true);
aliases[0].push_back(copyOp.getSource());

llvm::DenseMap<Operation *, unsigned> opToIdx;
unsigned idx = 0;
copyOp->getBlock()->walk<WalkOrder::PreOrder>(
[&](Operation *inner) { opToIdx[inner] = idx++; });

SmallVector<OpMemEffectOrder, 2> memEffects(2);
getMemEffects(memEffects, aliases, opToIdx, opToIdx[copyOp]);
auto hasReadEffectFn = [](OpOperand &opOpernad) -> bool {
if (maybeOpOperandRead(opOpernad) ||
llvm::isa<func::ReturnOp>(opOpernad.getOwner())) {
return true;
}
return false;
};

getMemEffects(memEffects, aliases, opToIdx, opToIdx[copyOp],
hasReadEffectFn);

auto hasReadAfterWrite = [&](ArrayRef<Operation *> reads,
ArrayRef<Operation *> writes) {
Expand Down
18 changes: 9 additions & 9 deletions compiler/lib/Dialect/MemRef/Utils/MemEffect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
using namespace mlir;
using namespace llvm;

namespace {
static bool maybeOpOperandWrite(OpOperand &opOpernad) {
bool mlir::maybeOpOperandWrite(OpOperand &opOpernad) {
if (auto memEffect =
dyn_cast<MemoryEffectOpInterface>(opOpernad.getOwner())) {
return memEffect.getEffectOnValue<MemoryEffects::Write>(opOpernad.get())
Expand All @@ -34,15 +33,14 @@ static bool maybeOpOperandWrite(OpOperand &opOpernad) {
return true;
}

static bool maybeOpOperandRead(OpOperand &opOpernad) {
// Reports whether the owner of `opOpernad` may read the operand's value.
// Ops without MemoryEffectOpInterface are conservatively assumed to read.
bool mlir::maybeOpOperandRead(OpOperand &opOpernad) {
  auto effectIface = dyn_cast<MemoryEffectOpInterface>(opOpernad.getOwner());
  if (!effectIface) {
    // No effect information available: assume the operand is read.
    return true;
  }
  auto readEffect =
      effectIface.getEffectOnValue<MemoryEffects::Read>(opOpernad.get());
  return readEffect.has_value();
}
} // namespace

void mlir::getAllAlias(Operation *op,
SmallVectorImpl<SmallVector<Value>> &aliases,
Expand Down Expand Up @@ -75,23 +73,25 @@ void mlir::getAllAlias(Operation *op,
void mlir::getMemEffects(SmallVectorImpl<OpMemEffectOrder> &memEffects,
ArrayRef<SmallVector<Value>> aliases,
llvm::DenseMap<Operation *, unsigned> &opToIdx,
unsigned pivot) {
unsigned pivot,
const OperandMemEffectFn &hasReadEffect,
const OperandMemEffectFn &hasWriteEffect) {
for (const auto &en : llvm::enumerate(aliases)) {
for (auto val : en.value()) {
for (auto &use : val.getUses()) {
auto user = use.getOwner();
if (opToIdx[user] < pivot) {
if (maybeOpOperandRead(use)) {
if (hasReadEffect(use)) {
memEffects[en.index()].before.reads.push_back(user);
}
if (maybeOpOperandWrite(use)) {
if (hasWriteEffect(use)) {
memEffects[en.index()].before.writes.push_back(user);
}
} else if (opToIdx[user] > pivot) {
if (maybeOpOperandRead(use)) {
if (hasReadEffect(use)) {
memEffects[en.index()].after.reads.push_back(user);
}
if (maybeOpOperandWrite(use)) {
if (hasWriteEffect(use)) {
memEffects[en.index()].after.writes.push_back(user);
}
}
Expand Down
13 changes: 8 additions & 5 deletions compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,14 @@ static Attribute CompareFolder(mhlo::CompareOp op, ArrayRef<Attribute> attrs) {
return {};
}

auto resultTy = op.getType().cast<ShapedType>();
if (lhs.isSplat() && rhs.isSplat()) {
bool value =
Convert()(addSign(lhs.getSplatValue<SrcType>(), lhs.getElementType()),
addSign(rhs.getSplatValue<SrcType>(), rhs.getElementType()));
return DenseElementsAttr::get(resultTy, value);
}

SmallVector<bool, 6> values;
values.reserve(lhs.getNumElements());
for (const auto zip :
Expand All @@ -1032,7 +1040,6 @@ static Attribute CompareFolder(mhlo::CompareOp op, ArrayRef<Attribute> attrs) {
addSign(std::get<1>(zip), rhs.getElementType())));
}

auto resultTy = op.getType().cast<ShapedType>();
return DenseElementsAttr::get(resultTy, values);
}

Expand Down Expand Up @@ -1208,10 +1215,6 @@ struct FoldLargeCompareOp : public OpRewritePattern<mhlo::CompareOp> {
if (elementType.isa<ComplexType>()) {
return failure();
}
// upstream handled splat value
if (lhsOp.getValue().isSplat() || rhsOp.getValue().isSplat()) {
return failure();
}

Attribute folded = nullptr;
#define COMPARE_FOLDER(comparison, Func) \
Expand Down
22 changes: 20 additions & 2 deletions compiler/numerical/hlo/canonicalize_ext.mlir

Large diffs are not rendered by default.

88 changes: 64 additions & 24 deletions compiler/test/Dialect/MemRef/removeCopy.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -489,14 +489,14 @@ module attributes {byre.container_module} {
// CHECK-LABEL: func.func @src_alloc_shape_transform_0(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x224x224x3xf16, "gpuhost"> {byre.argname = "input_tensor@Cast", byre.argtype = 1 : i32},
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x1001xf16, "gpuhost"> {byre.argname = "softmax_tensor@Cast", byre.argtype = 2 : i32}) attributes {byre.entry_point, byteir.entry_point = {inputs = ["input_tensor@Cast"], outputs = ["softmax_tensor@Cast"]}} {
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<2x224x224x3xf16, "gpuhost"> into memref<301056xf16, "gpuhost">
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<2x1001xf16, "gpuhost"> into memref<2002xf16, "gpuhost">
// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<2002xf16, "gpu">
// CHECK: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<301056xf16, "gpuhost"> into memref<2x224x1x672xf16, "gpuhost">
// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<2x224x1x672xf16, "gpu">
// CHECK: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<2x224x1x672xf16, "gpuhost"> to memref<2x224x1x672xf16, "gpu">
// CHECK: byre.compute @foo(%[[VAL_6]], %[[VAL_4]]) {device = "gpu", kernel_name = "main_gpu", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "gpu">, memref<2002xf16, "gpu">
// CHECK: memref.copy %[[VAL_4]], %[[VAL_3]] : memref<2002xf16, "gpu"> to memref<2002xf16, "gpuhost">
// CHECK-DAG: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<2x224x224x3xf16, "gpuhost"> into memref<301056xf16, "gpuhost">
// CHECK-DAG: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<2x1001xf16, "gpuhost"> into memref<2002xf16, "gpuhost">
// CHECK-DAG: %[[VAL_4:.*]] = memref.alloc() : memref<2002xf16, "gpu">
// CHECK-DAG: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<301056xf16, "gpuhost"> into memref<2x224x1x672xf16, "gpuhost">
// CHECK-DAG: %[[VAL_6:.*]] = memref.alloc() : memref<2x224x1x672xf16, "gpu">
// CHECK-DAG: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<2x224x1x672xf16, "gpuhost"> to memref<2x224x1x672xf16, "gpu">
// CHECK-DAG: byre.compute @foo(%[[VAL_6]], %[[VAL_4]]) {device = "gpu", kernel_name = "main_gpu", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "gpu">, memref<2002xf16, "gpu">
// CHECK-DAG: memref.copy %[[VAL_4]], %[[VAL_3]] : memref<2002xf16, "gpu"> to memref<2002xf16, "gpuhost">
// CHECK: return
// CHECK: }

Expand All @@ -522,14 +522,14 @@ module attributes {byre.container_module} {
// CHECK-LABEL: func.func @src_alloc_shape_transform_1(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x224x224x3xf16, "gpuhost"> {byre.argname = "input_tensor@Cast", byre.argtype = 1 : i32},
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x1001xf16, "gpuhost"> {byre.argname = "softmax_tensor@Cast", byre.argtype = 2 : i32}) attributes {byre.entry_point, byteir.entry_point = {inputs = ["input_tensor@Cast"], outputs = ["softmax_tensor@Cast"]}} {
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<2x224x224x3xf16, "gpuhost"> into memref<301056xf16, "gpuhost">
// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2]] : memref<2x1001xf16, "gpuhost"> into memref<2x1x1001xf16, "gpuhost">
// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<2x1x1001xf16, "gpu">
// CHECK: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<301056xf16, "gpuhost"> into memref<2x224x1x672xf16, "gpuhost">
// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<2x224x1x672xf16, "gpu">
// CHECK: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<2x224x1x672xf16, "gpuhost"> to memref<2x224x1x672xf16, "gpu">
// CHECK: byre.compute @foo(%[[VAL_6]], %[[VAL_4]]) {device = "gpu", kernel_name = "main_gpu", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "gpu">, memref<2x1x1001xf16, "gpu">
// CHECK: memref.copy %[[VAL_4]], %[[VAL_3]] : memref<2x1x1001xf16, "gpu"> to memref<2x1x1001xf16, "gpuhost">
// CHECK-DAG: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<2x224x224x3xf16, "gpuhost"> into memref<301056xf16, "gpuhost">
// CHECK-DAG: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2]] : memref<2x1001xf16, "gpuhost"> into memref<2x1x1001xf16, "gpuhost">
// CHECK-DAG: %[[VAL_4:.*]] = memref.alloc() : memref<2x1x1001xf16, "gpu">
// CHECK-DAG: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<301056xf16, "gpuhost"> into memref<2x224x1x672xf16, "gpuhost">
// CHECK-DAG: %[[VAL_6:.*]] = memref.alloc() : memref<2x224x1x672xf16, "gpu">
// CHECK-DAG: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<2x224x1x672xf16, "gpuhost"> to memref<2x224x1x672xf16, "gpu">
// CHECK-DAG: byre.compute @foo(%[[VAL_6]], %[[VAL_4]]) {device = "gpu", kernel_name = "main_gpu", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "gpu">, memref<2x1x1001xf16, "gpu">
// CHECK-DAG: memref.copy %[[VAL_4]], %[[VAL_3]] : memref<2x1x1001xf16, "gpu"> to memref<2x1x1001xf16, "gpuhost">
// CHECK: return
// CHECK: }

Expand All @@ -554,14 +554,14 @@ module attributes {byre.container_module} {
// CHECK-LABEL: func.func @src_alloc_shape_transform_2(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x224x224x3xf16, "gpuhost"> {byre.argname = "input_tensor@Cast", byre.argtype = 1 : i32},
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x1001xf16, "gpuhost"> {byre.argname = "softmax_tensor@Cast", byre.argtype = 2 : i32}) attributes {byre.entry_point, byteir.entry_point = {inputs = ["input_tensor@Cast"], outputs = ["softmax_tensor@Cast"]}} {
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<2x224x224x3xf16, "gpuhost"> into memref<301056xf16, "gpuhost">
// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2]] : memref<2x1001xf16, "gpuhost"> into memref<2x1x1001xf16, "gpuhost">
// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<2x1x1001xf16, "gpu">
// CHECK: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<301056xf16, "gpuhost"> into memref<2x224x1x672xf16, "gpuhost">
// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<2x224x1x672xf16, "gpu">
// CHECK: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<2x224x1x672xf16, "gpuhost"> to memref<2x224x1x672xf16, "gpu">
// CHECK: byre.compute @foo(%[[VAL_6]], %[[VAL_4]]) {device = "gpu", kernel_name = "main_gpu", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "gpu">, memref<2x1x1001xf16, "gpu">
// CHECK: memref.copy %[[VAL_4]], %[[VAL_3]] : memref<2x1x1001xf16, "gpu"> to memref<2x1x1001xf16, "gpuhost">
// CHECK-DAG: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<2x224x224x3xf16, "gpuhost"> into memref<301056xf16, "gpuhost">
// CHECK-DAG: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2]] : memref<2x1001xf16, "gpuhost"> into memref<2x1x1001xf16, "gpuhost">
// CHECK-DAG: %[[VAL_4:.*]] = memref.alloc() : memref<2x1x1001xf16, "gpu">
// CHECK-DAG: %[[VAL_5:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1, 2, 3]] {device = "gpuhost"} : memref<301056xf16, "gpuhost"> into memref<2x224x1x672xf16, "gpuhost">
// CHECK-DAG: %[[VAL_6:.*]] = memref.alloc() : memref<2x224x1x672xf16, "gpu">
// CHECK-DAG: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<2x224x1x672xf16, "gpuhost"> to memref<2x224x1x672xf16, "gpu">
// CHECK-DAG: byre.compute @foo(%[[VAL_6]], %[[VAL_4]]) {device = "gpu", kernel_name = "main_gpu", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "gpu">, memref<2x1x1001xf16, "gpu">
// CHECK-DAG: memref.copy %[[VAL_4]], %[[VAL_3]] : memref<2x1x1001xf16, "gpu"> to memref<2x1x1001xf16, "gpuhost">
// CHECK: return
// CHECK: }

Expand Down Expand Up @@ -599,3 +599,43 @@ func.func @src_alloc_shape_transform_3(%arg0: memref<2x1x1001xf16>, %arg1: memre
// CHECK: }
// CHECK: return
// CHECK: }

// -----

// %alloc is first filled from %arg1 and then partially overwritten through
// %subview_0 with a slice of %arg0; both %alloc and %arg1 are returned.
// The checks after this function expect both memref.copy ops to be kept.
func.func @insert_slice(%arg0: memref<1024x9xf32>, %arg1: memref<1024x9xf32>) -> (memref<1024x9xf32>, memref<1024x9xf32>) attributes {__byteir_elementwise_fusion__} {
%subview = memref.subview %arg0[0, 0] [1024, 4] [1, 1] : memref<1024x9xf32> to memref<1024x4xf32, strided<[9, 1]>>
%alloc = memref.alloc() : memref<1024x9xf32>
memref.copy %arg1, %alloc : memref<1024x9xf32> to memref<1024x9xf32>
%subview_0 = memref.subview %alloc[0, 0] [1024, 4] [1, 1] : memref<1024x9xf32> to memref<1024x4xf32, strided<[9, 1]>>
memref.copy %subview, %subview_0 : memref<1024x4xf32, strided<[9, 1]>> to memref<1024x4xf32, strided<[9, 1]>>
return %alloc, %arg1 : memref<1024x9xf32>, memref<1024x9xf32>
}

// CHECK-LABEL: func.func @insert_slice
// CHECK: memref.copy
// CHECK: memref.copy

// -----

// Each of the three memref.copy ops below moves data between a "cpu" buffer
// and a "cuda" buffer, i.e. between different memory spaces. The checks after
// this module expect all three copies to be kept.
module attributes {byre.container_module} {
func.func @h2dCopy(%arg0: memref<2x224x224x3xf16, "cpu"> {byre.argname = "input_tensor@Cast", byre.argtype = 1 : i32}, %arg1: memref<2x1001xf16, "cuda"> {byre.argname = "softmax_tensor@Cast", byre.argtype = 2 : i32}) attributes {byre.entry_point} {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] {device = "cpu"} : memref<2x224x224x3xf16, "cpu"> into memref<301056xf16, "cpu">
%expand_shape = memref.expand_shape %collapse_shape [[0, 1, 2, 3]] {device = "cpu"} : memref<301056xf16, "cpu"> into memref<2x224x1x672xf16, "cpu">
%alloc = memref.alloc() : memref<2x1x1x1001xf16, "cuda">
%alloc_0 = memref.alloc() : memref<2x224x1x672xf16, "cuda">
memref.copy %expand_shape, %alloc_0 : memref<2x224x1x672xf16, "cpu"> to memref<2x224x1x672xf16, "cuda">
byre.compute @cudaComputeOp(%alloc_0, %alloc) {device = "cuda", kernel_name = "main_cuda", memory_effects = [1 : i32, 2 : i32]} : memref<2x224x1x672xf16, "cuda">, memref<2x1x1x1001xf16, "cuda">
%alloc_1 = memref.alloc() : memref<2x1x1x1001xf16, "cpu">
memref.copy %alloc, %alloc_1 : memref<2x1x1x1001xf16, "cuda"> to memref<2x1x1x1001xf16, "cpu">
%collapse_shape_2 = memref.collapse_shape %alloc_1 [[0], [1, 2, 3]] {device = "cpu"} : memref<2x1x1x1001xf16, "cpu"> into memref<2x1001xf16, "cpu">
memref.copy %collapse_shape_2, %arg1 : memref<2x1001xf16, "cpu"> to memref<2x1001xf16, "cuda">
return
}
}

// CHECK-LABEL: func.func @h2dCopy
// CHECK: memref.copy
// CHECK: memref.copy
// CHECK: memref.copy


12 changes: 12 additions & 0 deletions frontends/torch-frontend/torch-frontend/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel
fx_utils.py
ts_utils.py

byteir_backend/__init__.py
byteir_backend/compilation_cache.py
byteir_backend/compiled_function.py
byteir_backend/compiler.py
byteir_backend/config.py
byteir_backend/inner_compile.py
byteir_backend/utils.py
byteir_backend/byteir_fusible_pattern.py
byteir_backend/fx_match_utils.py
byteir_backend/fx_utils.py
byteir_backend/partitioners.py

tools/compiler.py
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
*Requires torch version 2.4 or higher*

- Usage

```python
import torch

import torch_frontend
from torch_frontend import byteir_backend as byteir_backend
from torch_frontend.byteir_backend.utils import *

class NaiveModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0, x1, x2):
r0 = torch.ops.aten.mul(x0, x1)
r1 = torch.ops.aten.div(r0, x2)
x0 = torch.ops.aten.mul(r1, r1) - x0
r2 = torch.ops.aten.slice(x0, 1, 1, 3, 1)
return r1, r2

model = NaiveModel()
opt_mod = torch.compile(model, backend="byteir")

x0 = torch.rand(32, 64).to('cuda')
x1 = torch.rand(32, 64).to('cuda')
x2 = torch.rand(32, 64).to('cuda')

x0 = x0.as_strided(size=(32,16), stride=(64,2), storage_offset=16)
x1 = x1.as_strided(size=(32,16), stride=(64,1), storage_offset=8)
x2 = x2.as_strided(size=(32,16), stride=(32,1), storage_offset=32)

golden = model(x0, x1, x2)
outs = opt_mod(x0, x1, x2)
torch.cuda.synchronize()

torch.testing.assert_close(golden, outs)
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch._dynamo import register_backend


@register_backend
def byteir(*args, **kwargs):
    """Dynamo backend entry point, registered via ``register_backend``.

    All positional and keyword arguments are forwarded unchanged to
    ``byteir_compiler``, and its result is returned as-is.
    """
    # Deferred import: the compiler module is only loaded when the backend
    # is actually invoked, not when this package is imported.
    from .compiler import byteir_compiler as _compile

    compiled = _compile(*args, **kwargs)
    return compiled

def set_cache_dir(path: str):
    """Set ``ByteIRFxGraphCache.base_cache_dir`` to ``path``.

    The value is stored on the class itself, so it applies to every user of
    ``ByteIRFxGraphCache`` in the current process.
    """
    # Deferred import: the cache module is only loaded when this is called.
    from .compilation_cache import ByteIRFxGraphCache as _graph_cache

    _graph_cache.base_cache_dir = path
Loading

0 comments on commit 51ccfe9

Please sign in to comment.