Skip to content

Commit

Permalink
BladeDISC related patches
Browse files Browse the repository at this point in the history
* fix float width
* fix divide_floor & export promoteTypes api (#9)
* Maintain compatibility with older PyTorch versions
* Add native_dropout_backward & native_layer_norm_backward decomposition (#15)
* add native_dropout and related ops pattern (#1211)
* [MHLO] fix dot general contract
* Fix batch_norm, div.Tensor_mode and folder (#21)
* reimplement linear lowering
* reimplement 2-D rhs for matmul
* add torchdynamo
  • Loading branch information
Tanyo Kwok committed Feb 2, 2023
1 parent 9536174 commit 3528c3a
Show file tree
Hide file tree
Showing 23 changed files with 782 additions and 152 deletions.
28 changes: 27 additions & 1 deletion include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6596,9 +6596,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
}

def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
Pure,
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
ReadOnly,
]> {
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
Expand Down Expand Up @@ -7129,6 +7130,31 @@ def Torch_AtenMaxOp : Torch_Op<"aten.max", [
}];
}

// ODS declaration for `aten.amax`: reduces `self` to its maximum values over
// the optional list of dimensions `dim`; `keepdim` selects whether the reduced
// dimensions are kept (size 1) in the result.
// Traits: AllowsTypeRefinement, HasValueSemantics, ReadOnly — the op is a pure
// value-semantic tensor computation with no side effects.
def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
// NOTE(review): generated summaries elsewhere use the `-> (Tensor)` form
// (see Torch_AtenEmptyMemoryFormatOp above) — confirm this string matches
// the output of update_torch_ods.sh.
let summary = "Generated op for `aten::amax : (Tensor, int[]?, bool) -> Tensor`";
let arguments = (ins
AnyTorchTensorType:$self,
// Optional list of reduction dimensions; absent means reduce over all dims
// per torch.amax semantics — TODO confirm against the registered schema.
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
// NOTE(review): generated ops conventionally name the single result
// `$result` (accessor getResult); `$results` deviates — verify intended.
AnyTorchTensorType:$results
);
let hasCustomAssemblyFormat = 1;
// Default parser/printer for a generated torch op with 3 operands, 1 result.
let extraClassDefinition = [{
ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenAmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand All @@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand All @@ -80,6 +82,7 @@ def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}

def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
Expand All @@ -98,6 +101,7 @@ def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}

def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
Expand Down
6 changes: 5 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);

target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
Expand Down
49 changes: 33 additions & 16 deletions lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;

LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
Expand Down Expand Up @@ -166,16 +167,19 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
if (!selfTy)
return op.emitError("only Tensor types supported in MHLO");

if (selfTy.getElementType().isa<mlir::FloatType>()) {
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType());
if (selfTy != outTy) {
auto out = rewriter.create<MhloOpT>(op.getLoc(), selfTy, self);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, out);
return success();
} else {
rewriter.replaceOpWithNewOp<MhloOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
self);
return success();
} else {
return op.emitError(
"only floating-point datatype legalization supported");
}
}
};
Expand Down Expand Up @@ -345,15 +349,10 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
} else if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);

if (!isa<AtenDivTensorModeOp>(op)) {
rewriter.replaceOp(op, result);
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs, nullptr);
return success();
}

Expand All @@ -365,6 +364,17 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode");

auto computeTy = outType;
if (outElemTy.isIntOrIndex()) {
computeTy =
RankedTensorType::get(outType.getShape(), rewriter.getF32Type());
}
lhs = mhlo::promoteType(rewriter, lhs, computeTy);
rhs = mhlo::promoteType(rewriter, rhs, computeTy);
auto loc = op.getLoc();
auto result =
rewriter.create<ChloOpT>(loc, computeTy, lhs, rhs, nullptr).getResult();

if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
Expand All @@ -378,7 +388,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
// floor division in Python (the // operator)
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
}
rewriter.replaceOp(op, result);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, result);
return success();
}
};
Expand Down Expand Up @@ -836,7 +846,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, outType, lhs, zeroTensor);
return success();
}

Expand All @@ -862,7 +876,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, input, halfMul);
return success();
}

Expand Down Expand Up @@ -1463,7 +1481,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
Expand Down
Loading

0 comments on commit 3528c3a

Please sign in to comment.