1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
@@ -963,6 +963,7 @@
    "ElementwiseMaximumIntModule_basic",
    "ElementwiseMaxOtherIntModule_basic",
    "ElementwiseMaxOtherModule_basic",
    "GluStaticModule_basic",
    "ViewDoubleMergeStaticModule_basic",
    "ViewCollapseOnesMiddleModule_basic",
    "ViewFiveTestStaticModule_basic",
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4189,6 +4189,30 @@ def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
  }];
}

def Torch_AtenGluOp : Torch_Op<"aten.glu", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::glu : (Tensor, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenGluOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenGluOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
    AllowsTypeRefinement,
    HasValueSemantics,
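For context, `aten::glu` splits the input in half along `dim` and gates the first half with the sigmoid of the second. A minimal eager-PyTorch sketch of the registered op's semantics (the shapes here are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(3, 24, 5)
a, b = x.chunk(2, dim=1)  # two halves of extent 12 along dim 1
assert torch.allclose(F.glu(x, dim=1), a * torch.sigmoid(b))
assert torch.allclose(torch.ops.aten.glu(x, 1), F.glu(x, dim=1))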
37 changes: 37 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6382,6 +6382,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %13 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %14 = torch.aten.add.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %14 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.remainder.int %2, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.slice.t %arg0, %none, %1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list<int>\n"
" %9 = torch.aten.add.t %5, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %10 = torch.aten.add.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %11 = torch.aten.slice.t %arg0, %10, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %12 = torch.aten.add.t %9, %11 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %12 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8863,6 +8896,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
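Note that `AbstractInterpLibrary.cpp` is a generated artifact, not hand-written: the MLIR above is serialized from the Python shape and dtype functions added further down in this diff (`aten〇glu〡shape` and `aten〇glu〡dtype`), so the two halves of this change have to be regenerated together rather than edited independently.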
43 changes: 43 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -361,6 +361,48 @@ class DecomposeAtenNarrowTensorOp
};
} // namespace

namespace {
class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenGluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    Value dim = op.getDim();

    auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>();
    if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
      return rewriter.notifyMatchFailure(
          op, "expected output type to have sizes and dtype");
    }

    Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
    Value two =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));

    // glu requires an even extent along the gated dimension; since the size
    // may be dynamic, check it at runtime.
    Value remainder = rewriter.create<AtenRemainderIntOp>(loc, dimSize, two);
    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, zero);
    rewriter.create<RuntimeAssertOp>(
        loc, eqOrNot,
        rewriter.getStringAttr("AtenGluOp's dim size must be a multiple of 2"));

    // Split the input in half along dim and gate the first half with the
    // sigmoid of the second: glu(x) = a * sigmoid(b).
    Value splitLength = rewriter.create<AtenFloordivIntOp>(loc, dimSize, two);
    Value a = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim, zero,
                                            splitLength);
    Value b = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim,
                                            splitLength, splitLength);
    Value sigmoidB = rewriter.create<AtenSigmoidOp>(loc, outputTy, b);
    Value result = rewriter.create<AtenMulTensorOp>(loc, outputTy, a, sigmoidB);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenZeroOp
    : public OpRewritePattern<AtenZeroOp> {
@@ -5289,6 +5331,7 @@ class DecomposeComplexOpsPass
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
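The pattern above rewrites `aten.glu` into narrow/sigmoid/mul ops the backends already handle. A hedged eager-PyTorch rendering of the same rewrite (the helper name `glu_decomposed` is illustrative, not part of the change):

import torch

def glu_decomposed(x: torch.Tensor, dim: int) -> torch.Tensor:
    n = x.size(dim)
    assert n % 2 == 0, "AtenGluOp's dim size must be a multiple of 2"
    half = n // 2
    a = x.narrow(dim, 0, half)      # mirrors the first AtenNarrowOp
    b = x.narrow(dim, half, half)   # mirrors the second AtenNarrowOp
    return a * torch.sigmoid(b)     # AtenSigmoidOp followed by AtenMulTensorOp

x = torch.randn(3, 24, 5)
assert torch.allclose(glu_decomposed(x, 1), torch.ops.aten.glu(x, 1))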
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenHardsigmoidOp>();
  target.addIllegalOp<AtenRelu6Op>();
  target.addIllegalOp<AtenEluOp>();
  target.addIllegalOp<AtenGluOp>();
  target.addIllegalOp<AtenHardswishOp>();
  target.addIllegalOp<AtenSoftplusOp>();
  target.addIllegalOp<AtenSiluOp>();
11 changes: 11 additions & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -167,6 +167,12 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
def aten〇round〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
    if dim < 0:
        dim += len(self)
    assert self[dim] % 2 == 0, "glu's dim size must be a multiple of 2"
    return self[:dim] + [self[dim] // 2] + self[dim + 1:]

def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
    return upstream_shape_functions.unary(self)

@@ -1932,6 +1938,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(
    [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int:
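A worked example of the shape rule above (a re-statement of `aten〇glu〡shape`'s list arithmetic; the helper name `glu_shape` is illustrative):

from typing import List

def glu_shape(self: List[int], dim: int = -1) -> List[int]:
    if dim < 0:
        dim += len(self)
    assert self[dim] % 2 == 0, "glu's dim size must be a multiple of 2"
    return self[:dim] + [self[dim] // 2] + self[dim + 1:]

assert glu_shape([3, 24, 5], dim=1) == [3, 12, 5]  # matches GluStaticModule below
assert glu_shape([4, 6]) == [4, 3]                 # default dim=-1 normalizes to the last dim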
1 change: 1 addition & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -354,6 +354,7 @@ def emit_with_mutating_variants(key, **kwargs):
    emit("aten::view_as_complex : (Tensor) -> (Tensor)")
    emit("aten::view_as_real : (Tensor) -> (Tensor)")
    emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
    emit("aten::glu : (Tensor, int) -> (Tensor)")

    # Ops with dynamic number of outputs
    emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
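This `emit(...)` line is the source of truth for the `Torch_AtenGluOp` TableGen definition shown earlier: `GeneratedTorchOps.td` is regenerated from `torch_ods_gen.py`, which keeps the op registry and the ODS in sync.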
18 changes: 18 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -3685,3 +3685,21 @@ def forward(self, x):
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module())
def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8))

# ==============================================================================

class GluStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 24, 5], torch.float32, True)
    ])
    def forward(self, x):
        return torch.ops.aten.glu(x, dim=1)

@register_test_case(module_factory=lambda: GluStaticModule())
def GluStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 24, 5))
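As a quick eager-mode cross-check of what the e2e harness asserts for this test (run outside the test framework; illustrative only):

import torch
import torch.nn.functional as F

m = GluStaticModule()
x = torch.rand(3, 24, 5)
assert torch.allclose(m.forward(x), F.glu(x, dim=1))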