27 changes: 27 additions & 0 deletions e2e_testing/torchscript/batchnorm.py
@@ -87,8 +87,33 @@ def forward(self, x):
def BatchNorm3DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 3, 6, 4))

# ==============================================================================


class NativeLayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([2, 5, 2, 2, 3], torch.float32, True),
([2, 2, 3], torch.float32, True),
([2, 2, 3], torch.float32, True),
])
def forward(self, x, weight, bias):
list = [2, 2, 3]
return torch.ops.aten.native_layer_norm(
x, list, weight, bias, eps=0.5)[0]


@register_test_case(module_factory=lambda: NativeLayerNormModule())
def NativeLayerNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))

# ==============================================================================


class LayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -138,6 +163,8 @@ def LayerNormLastDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3))

# ==============================================================================


class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
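Aside (not part of the diff): a quick sanity sketch of what the new e2e test exercises. `aten::native_layer_norm` returns three tensors; the test keeps only index `[0]`, the normalized output, which should agree with `torch.nn.functional.layer_norm` for the same arguments. The sketch below is an illustration of that expectation, not code from this patch.

```python
import torch
import torch.nn.functional as F

x, w, b = torch.rand(2, 5, 2, 2, 3), torch.rand(2, 2, 3), torch.rand(2, 2, 3)
# Index [0] selects the normalized output; the remaining results are the
# statistics computed during normalization.
out = torch.ops.aten.native_layer_norm(x, [2, 2, 3], w, b, 0.5)[0]
assert torch.allclose(out, F.layer_norm(x, [2, 2, 3], w, b, eps=0.5))
```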
20 changes: 20 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1336,6 +1336,26 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
}

def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
TorchIntListType:$normalized_shape,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$layer_norm,
AnyTorchTensorType:$mean,
AnyTorchTensorType:$variance
);
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics
20 changes: 13 additions & 7 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -693,11 +693,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
// Step 4. Get var.
// Step 5. Get layernorm.
namespace {
-class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
+class ConvertAtenNativeLayerNormOp
+    : public OpConversionPattern<AtenNativeLayerNormOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
-  matchAndRewrite(AtenLayerNormOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
@@ -889,9 +890,14 @@ class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);

-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
+    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
+    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
+    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Value layerNorm_ =
+        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
+    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
+    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
return success();
}
};
@@ -3511,8 +3517,8 @@ class ConvertTorchToLinalg
patterns.add<ConvertAtenCatOp>(typeConverter, context);
target.addIllegalOp<AtenGatherOp>();
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
-    target.addIllegalOp<AtenLayerNormOp>();
-    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
+    target.addIllegalOp<AtenNativeLayerNormOp>();
+    patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
target.addIllegalOp<AtenArgmaxOp>();
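For orientation, here is a NumPy-style analogue of the math that the lowering's step comments describe ("get var", "get layernorm"): reduce over the normalized dimensions to obtain the mean and variance, then normalize, scale, and shift. This is an assumption about the computation only, not the emitted linalg IR, and the shapes the conversion materializes for the mean/variance results are not modeled here.

```python
import numpy as np

def native_layer_norm_sketch(x, normalized_shape, weight, bias, eps):
    # Reduce over the trailing dims covered by normalized_shape.
    axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    mean = x.mean(axis=axes, keepdims=True)                  # get mean
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)   # Step 4: get var
    out = (x - mean) / np.sqrt(var + eps) * weight + bias    # Step 5: get layernorm
    return out, mean, var
```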
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -477,6 +477,33 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
return success();
}
};

class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLayerNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

auto input = op.input().getType().cast<BaseTensorType>();
if (!input.hasSizes())
return rewriter.notifyMatchFailure(
op, "input tensor should have known sizes.");
int64_t inputRank = input.getSizes().size();
Value normalizedShape = op.normalized_shape();
SmallVector<Value> normalizedShapeSizesTorchInt;
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
std::vector<int64_t> meanVarSizes;
for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
meanVarSizes.push_back(input.getSizes()[i]);
auto meanVarType = input.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarSizes), input.getDtype());
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
loc, op.getType(), meanVarType, meanVarType, op.input(),
op.normalized_shape(), op.weight(), op.bias(), op.eps());
rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
return success();
}
};
} // namespace

namespace {
@@ -522,6 +549,9 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenAddcmulOp>();
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<DecomposeAtenLayerNormOp>(context);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
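In other words, the new `DecomposeAtenLayerNormOp` pattern rewrites `aten.layer_norm` into `aten.native_layer_norm` and keeps only the first result; `cudnn_enable` has no counterpart on the native op and is dropped. A rough Python equivalent of that rewrite (an illustration, not the pattern itself):

```python
import torch

def decomposed_layer_norm(x, normalized_shape, weight, bias, eps, cudnn_enable):
    # cudnn_enable is dropped: native_layer_norm takes no such flag.
    return torch.ops.aten.native_layer_norm(
        x, normalized_shape, weight, bias, eps)[0]

x, w, b = torch.rand(2, 5, 2, 2, 3), torch.rand(2, 2, 3), torch.rand(2, 2, 3)
ref = torch.ops.aten.layer_norm(x, [2, 2, 3], w, b, 0.5, False)
assert torch.allclose(decomposed_layer_norm(x, [2, 2, 3], w, b, 0.5, False), ref)
```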
44 changes: 44 additions & 0 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -471,6 +471,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
return visitBinaryScalarOp(scalarOp);
} else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
return visitAtenNllLossForwardOp(nllForwardOp, operands);
} else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
}

// Otherwise, this is an unknown operation. Just mark all results as
@@ -604,6 +606,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
ChangeResult
visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitAtenNativeLayerNormOp(
AtenNativeLayerNormOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace

@@ -1572,6 +1577,45 @@ ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenNativeLayerNormOp(
AtenNativeLayerNormOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();

auto layerNormKnowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto meanKnowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto varKnowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());

layerNormKnowledge.hasSizes = input.hasSizes;
layerNormKnowledge.sizes = input.sizes;
layerNormKnowledge.dtype = input.dtype;

int64_t layerNormSize = input.sizes.size();
Value normalizedShape = op.normalized_shape();
SmallVector<Value> normalizedShapeSizesTorchInt;
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
std::vector<int64_t> meanVarSizes;
if (input.hasSizes) {
for (int i = normalizedShapeSizesTorchInt.size(); i < layerNormSize; i++)
meanVarSizes.push_back(input.sizes[i]);
}
meanKnowledge.hasSizes = input.hasSizes;
meanKnowledge.sizes = meanVarSizes;
meanKnowledge.dtype = input.dtype;
varKnowledge.hasSizes = input.hasSizes;
varKnowledge.sizes = meanVarSizes;
varKnowledge.dtype = input.dtype;

auto resultLattice =
getLatticeElement(op.getResult(0)).join(layerNormKnowledge);
resultLattice |= getLatticeElement(op.getResult(1)).join(meanKnowledge);
resultLattice |= getLatticeElement(op.getResult(2)).join(varKnowledge);

return resultLattice;
}
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------
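The shape propagation mirrors the decomposition: the first result inherits the input's sizes and dtype, while the mean and variance results get the sizes that the loop above collects, i.e. the input dimensions at indices `[len(normalized_shape), rank)`. A small sketch of that size computation (mirroring the C++ loop, for illustration only):

```python
def mean_var_sizes(input_sizes, normalized_shape):
    # Mirrors the loop above: keep input dims from index
    # len(normalized_shape) up to the input rank.
    return [input_sizes[i] for i in range(len(normalized_shape), len(input_sizes))]

# For the e2e test's shapes: input [2, 5, 2, 2, 3], normalized_shape [2, 2, 3]
print(mean_var_sizes([2, 5, 2, 2, 3], [2, 2, 3]))  # -> [2, 3]
```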
@@ -502,6 +502,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit(
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)