From f9a523dbdf5bbcdec7ebfdd302bbdc6c071a6faf Mon Sep 17 00:00:00 2001
From: Prashant Kumar <prashant@nod-labs.com>
Date: Thu, 3 Feb 2022 16:53:17 +0530
Subject: [PATCH] Add aten::nll_loss_backward op

The lowering of the aten::nll_loss_backward op from the torch dialect to the
linalg dialect has been added. The change is made as part of the
-convert-torch-to-linalg pass.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
---
 e2e_testing/torchscript/nll_loss.py           |  58 ++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      |  20 ++++
 .../Dialect/Torch/Utils/TorchUpstream.h       |   9 ++
 .../TorchToLinalg/TorchToLinalg.cpp           | 106 ++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  25 ++++-
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 6 files changed, 217 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/torchscript/nll_loss.py b/e2e_testing/torchscript/nll_loss.py
index 5955184b6f58..0722d657ef4c 100644
--- a/e2e_testing/torchscript/nll_loss.py
+++ b/e2e_testing/torchscript/nll_loss.py
@@ -60,3 +60,61 @@ def forward(self, x, y):
 @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
 def NllLossModule_ignore_index(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
+
+class NllLossModule_backward(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
+                                                self=input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward())
+def NllLossModuleBackward_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward_ignore_index(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
+                                                self=input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(
+    module_factory=lambda: NllLossModule_backward_ignore_index())
+def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
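
Note: the golden values for the two e2e cases above come from running the same op in eager PyTorch. The short standalone sketch below (not part of the patch; variable names are illustrative) rebuilds the operands of NllLossModuleBackward_basic and passes them positionally in the aten schema order, so the expected gradient can be inspected directly:

import torch

# Operands as in NllLossModuleBackward_basic; the schema order is
# (grad_output, self, target, weight, reduction, ignore_index, total_weight).
grad_output = torch.rand(3)
input = torch.rand(3, 4)
target = torch.tensor([2, 3, 0])
total_weight = torch.tensor(3.)

grad_input = torch.ops.aten.nll_loss_backward(grad_output, input, target,
                                              None, 0, 10, total_weight)
# With reduction=0 (no reduction) and weight=None, row i of grad_input holds
# -grad_output[i] at column target[i] and zeros elsewhere.
print(grad_input)
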
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index fa4af730bbc2..aefe9563bb6d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1852,6 +1852,26 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `->` qualified(type($output)) `,` qualified(type($total_weight))";
 }
 
+def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$reduction,
+    Torch_IntType:$ignore_index,
+    AnyTorchTensorType:$total_weight
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `,` qualified(type($total_weight)) `->` qualified(type($result))";
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
   AllowsTypeRefinement,
   HasValueSemantics
diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 5544300e9ca7..9df9042112a9 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -70,6 +70,15 @@ struct ResultTypeState {
 ScalarType result_type(const ResultTypeState &in_state);
 ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
 
+//===----------------------------------------------------------------------===//
+// These constants control the reduction behavior of the loss functions.
+// `None`, `Mean`, and `Sum` correspond to "do not reduce", "mean of losses",
+// and "sum of losses", respectively.
+// Source:
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/Reduction.h
+//===----------------------------------------------------------------------===//
+enum Reduction { None, Mean, Sum, END };
+
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
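
Note: the Reduction enum mirrors the integer `reduction` argument of the aten loss ops (none = 0, mean = 1, sum = 2). A small eager-mode sketch of that correspondence (not part of the patch; it assumes torch.ops.aten.nll_loss_forward is callable from Python, which is the op the existing forward lowering targets):

import torch
import torch.nn.functional as F

logprobs = torch.log_softmax(torch.rand(3, 4), dim=1)
target = torch.tensor([2, 3, 0])

# Reduction::None == 0, Reduction::Mean == 1, Reduction::Sum == 2.
for code, name in [(0, "none"), (1, "mean"), (2, "sum")]:
    output, total_weight = torch.ops.aten.nll_loss_forward(
        logprobs, target, None, code, -100)
    assert torch.allclose(output, F.nll_loss(logprobs, target, reduction=name))
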
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 6f4c262b99b6..bc8ae281344f 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -28,6 +29,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 using namespace mlir::torch::TorchConversion;
+using namespace mlir::torch::torch_upstream; // For ScalarType and type
 // -----------------------------------------------------------------------------
 // Patterns (as this grows, it should be organized into multiple files)
@@ -1323,6 +1325,108 @@ class ConvertAtenNllLossForwardOp
 };
 } // namespace
 
+// Given `grad_output`, `input`, and `target`, `nll_loss_backward` is computed
+// as:
+//   for i in range(0, len(input)):
+//     for j in range(0, len(input[0])):
+//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
+// TODO: `weight` and `reduction` operands are still to be taken care of.
+namespace {
+class ConvertAtenNllLossBackwardOp
+    : public OpConversionPattern<AtenNllLossBackwardOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value input = adaptor.self();
+    Value target = adaptor.target();
+    Value weight = adaptor.weight();
+    Value gradOutput = adaptor.grad_output();
+
+    int64_t reduction;
+    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
+      return rewriter.notifyMatchFailure(op, "reduction must be constant");
+
+    // TODO: Handle reduction.
+    if (reduction != Reduction::None)
+      return rewriter.notifyMatchFailure(
+          op, "reduction along dimensions is not supported.");
+
+    // TODO: Incorporate the weight argument.
+    if (!weight.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented, the weight operand is not incorporated.");
+
+    Value ignoreIndex = adaptor.ignore_index();
+    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
+
+    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
+    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
+
+    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
+    // required.
+    if (inputRank != 2 || targetRank != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input and target to be rank 2 and 1 respectively");
+    }
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+
+    Type elementType = resultType.getElementType();
+
+    // Given there is no reduction, the `grad_input` size is equal to the
+    // `input` size.
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value initTensor0 =
+        createZeroInitTensor(rewriter, loc, outputSize, elementType);
+    Value zeroVal = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+
+    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
+    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
+                                       rewriter.getAffineDimExpr(1)};
+    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
+                                         getParallelIteratorTypeName()};
+    auto indexingMaps =
+        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultType, ValueRange{target, gradOutput}, initTensor0,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value indTarget = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
+
+                  // The final result is given by:
+                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
+                  Value cmpEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
+
+                  // The target index shouldn't be equal to `ignoreIndex`.
+                  Value cmpNe = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
+                  Value finalPredicate =
+                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
+                  Value negate =
+                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
+                  Value selectFinal = rewriter.create<SelectOp>(
+                      loc, finalPredicate, negate, zeroVal);
+                  b.create<linalg::YieldOp>(loc, selectFinal);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // See comments at in convertMmOp and the heading for this section for general
 // considerations. This function needs to be auto-generated.
@@ -4525,6 +4629,8 @@ class ConvertTorchToLinalg
     patterns.add(typeConverter, context);
     target.addIllegalOp<AtenNllLossForwardOp>();
     patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
+    target.addIllegalOp<AtenNllLossBackwardOp>();
+    patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
     patterns.add(typeConverter, context);
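
Note: the linalg.generic region above implements the per-element formula from the comment. A reference implementation in plain Python (not part of the patch; written for the only case handled so far, i.e. reduction == none and weight == None) can be used to eyeball the lowering:

import torch

def nll_loss_backward_ref(grad_output, input, target, ignore_index):
    # grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0,
    # with rows whose target equals ignore_index left at zero.
    grad_input = torch.zeros_like(input)
    for i in range(input.shape[0]):
        if target[i] != ignore_index:
            grad_input[i][target[i]] = -grad_output[i]
    return grad_input

grad_output, input = torch.rand(3), torch.rand(3, 4)
target = torch.tensor([2, 3, 0])
expected = torch.ops.aten.nll_loss_backward(grad_output, input, target, None,
                                            0, 10, torch.tensor(3.))
assert torch.allclose(nll_loss_backward_ref(grad_output, input, target, 10),
                      expected)
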
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index eb6f75149791..17cdbc50a1e4 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -489,6 +489,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitBinaryScalarOp(op, operands);
     } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
+    } else if (auto nllBackwardOp = dyn_cast<AtenNllLossBackwardOp>(op)) {
+      return visitAtenNllLossBackwardOp(nllBackwardOp, operands);
     } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
       return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
     } else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
@@ -647,6 +649,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult visitAtenNllLossForwardOp(
       AtenNllLossForwardOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenNllLossBackwardOp(
+      AtenNllLossBackwardOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitAtenNativeLayerNormOp(
       AtenNativeLayerNormOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1188,8 +1193,8 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
 
   if (self.hasSizes &&
       matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
-    // reduction == 1 means reduce 1st dim.
-    resultRank = reduction == 1 ? resultRank - 1 : resultRank;
+    if (reduction != Reduction::None)
+      resultRank -= 1;
   }
   outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
   outputKnowledge.hasSizes = true;
@@ -1199,6 +1204,22 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
   return resultLattice;
 }
 
+ChangeResult TypeAnalyzer::visitAtenNllLossBackwardOp(
+    AtenNllLossBackwardOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto self = operands[1]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+
+  knowledge.dtype = self.dtype;
+  if (self.hasSizes) {
+    unsigned resultRank = self.sizes.size();
+    knowledge.sizes.resize(resultRank, kUnknownSize);
+    knowledge.hasSizes = true;
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
     AtenSqueezeDimOp op,
     ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto operand = operands[0]->getValue();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 7b4b71317978..ebf6056026be 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -548,6 +548,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::std : (Tensor, bool) -> (Tensor)")
     emit("aten::var : (Tensor, bool) -> (Tensor)")
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
+    emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
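
Note: visitAtenNllLossBackwardOp only propagates the dtype and the rank of `self` to the result. A quick eager-mode check of that contract for the unreduced case (not part of the patch; shapes are arbitrary):

import torch

input = torch.rand(5, 7)
target = torch.randint(0, 7, (5,))
grad_output = torch.rand(5)

grad_input = torch.ops.aten.nll_loss_backward(grad_output, input, target,
                                              None, 0, -100, torch.tensor(5.))
assert grad_input.shape == input.shape
assert grad_input.dtype == input.dtype
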