e2e_testing/torchscript/main.py (2 changes: 1 addition & 1 deletion)
@@ -30,7 +30,7 @@
from . import vision_models
from . import mlp
from . import conv
-from . import batchnorm
+from . import norm_like
from . import quantized_models
from . import elementwise
from . import type_promotion
@@ -9,8 +9,8 @@
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


# ==============================================================================

class BatchNorm1DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -35,8 +35,8 @@ def forward(self, x):
def BatchNorm1DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 3))


# ==============================================================================

class BatchNorm2DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -60,8 +60,8 @@ def forward(self, x):
def BatchNorm2DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 2, 3, 3))


# ==============================================================================

class BatchNorm3DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -89,6 +89,107 @@ def BatchNorm3DModule_basic(module, tu: TestUtils):

# ==============================================================================

class NativeBatchNorm1DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)


@register_test_case(module_factory=lambda: NativeBatchNorm1DModule())
def NativeBatchNorm1DModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))

# ==============================================================================

class NativeBatchNorm2DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)


@register_test_case(module_factory=lambda: NativeBatchNorm2DModule())
def NativeBatchNorm2DModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))

# ==============================================================================

class NativeBatchNorm3DModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)


@register_test_case(module_factory=lambda: NativeBatchNorm3DModule())
def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))

# ==============================================================================

class NativeBatchNormNoneWeightModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, None, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)


@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))

# ==============================================================================

class NativeLayerNormModule(torch.nn.Module):
def __init__(self):
@@ -113,7 +214,6 @@ def NativeLayerNormModule_basic(module, tu: TestUtils):

# ==============================================================================


class NativeLayerNormModule4D(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -135,10 +235,8 @@ def forward(self, x, weight, bias):
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))


# ==============================================================================


class LayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -164,8 +262,8 @@ def forward(self, x):
def LayerNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3))


# ==============================================================================

class LayerNormLastDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -189,7 +287,6 @@ def LayerNormLastDimModule_basic(module, tu: TestUtils):

# ==============================================================================


class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
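The new NativeBatchNorm*Module tests above call torch.ops.aten.native_batch_norm directly with training=False. For orientation, a small standalone sketch of that call outside the e2e harness (shapes borrowed from NativeBatchNorm1DModule_basic; the script is illustrative and not part of this patch):

import torch

x = torch.rand(2, 5, 3)  # (N, C, L) input
weight, bias = torch.rand(5), torch.rand(5)
running_mean, running_var = torch.rand(5), torch.rand(5)

# aten::native_batch_norm returns three tensors (output, save_mean, save_invstd);
# with training=False only the first result carries meaningful data.
out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var,
    training=False, momentum=0.1, eps=1e-5)
print(out.shape)  # torch.Size([2, 5, 3])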
include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td (23 changes: 23 additions & 0 deletions)
@@ -1489,6 +1489,29 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
let assemblyFormat = "$input `,` $weight `,` $bias `,` $stride `,` $padding `,` $dilation `,` $groups attr-dict `:` qualified(type($input)) `,` qualified(type($weight)) `,` qualified(type($bias)) `,` qualified(type($stride)) `,` qualified(type($padding)) `,` qualified(type($dilation)) `,` qualified(type($groups)) `->` qualified(type($result))";
}

def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$training,
Torch_FloatType:$momentum,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let assemblyFormat = "$input `,` $weight `,` $bias `,` $running_mean `,` $running_var `,` $training `,` $momentum `,` $eps attr-dict `:` qualified(type($input)) `,` qualified(type($weight)) `,` qualified(type($bias)) `,` qualified(type($running_mean)) `,` qualified(type($running_var)) `,` qualified(type($training)) `,` qualified(type($momentum)) `,` qualified(type($eps)) `->` qualified(type($result0)) `,` qualified(type($result1)) `,` qualified(type($result2))";
}

def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
AllowsTypeRefinement,
HasValueSemantics
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (123 changes: 123 additions & 0 deletions)
@@ -822,6 +822,127 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
};
} // namespace

namespace {
class DecomposeAtenNativeBatchNormOp
: public OpRewritePattern<AtenNativeBatchNormOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeBatchNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value input = op.input();
Value weight = op.weight();
Value bias = op.bias();
Value runningMean = op.running_mean();
Value runningVar = op.running_var();
Value eps = op.eps();

// TODO: Add support for `training` mode.
bool training = false;
if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
training)
return rewriter.notifyMatchFailure(
op, "unimplemented: training mode is not supported");

// Rank of the input tensor must be greater than or equal to 2. The shape of
// the `input` is supposed to be (N, C, D?, H?, W?).
int64_t inputRank = getTensorRank(input);
if (inputRank < 2)
return rewriter.notifyMatchFailure(
op, "input must have rank greater than or equal to 2");

// In inference mode, `runningMean` and `runningVar` must not be None.
if (runningMean.getType().isa<Torch::NoneType>() ||
runningVar.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "running stats must not be None in inference mode");

// Rank of `runningMean` and `runningVar` must be exactly 1.
if (getTensorRank(runningMean) != 1 || getTensorRank(runningVar) != 1)
return rewriter.notifyMatchFailure(
op, "expected running_mean and running_var to be rank 1");

Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
// TODO: Add Runtime Asserts to check the shape of weight, bias,
// running_mean and running_var to be (numFeatures).

// The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
// to make it broadcast-compatible with (N, C, D?, H?, W?).
// 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
// 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
SmallVector<Value> runningStatsShape(inputRank, one);
runningStatsShape[1] = numFeatures;
Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), runningStatsShape);

SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] = ShapedType::kDynamicSize;
Type dtype = input.getType().cast<ValueTensorType>().getDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);

runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
runningStatsSizeList);
runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
runningStatsSizeList);

// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, input.getType(), input, runningMean, /*alpha=*/one);
Value varEps = rewriter.create<AtenAddScalarOp>(
loc, runningVar.getType(), runningVar, eps, /*alpha=*/one);
Value invStd = rewriter.create<AtenRsqrtOp>(loc, varEps.getType(), varEps);
Value normalizedInput = rewriter.create<AtenMulTensorOp>(
loc, inputSubMean.getType(), inputSubMean, invStd);

// The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make them
// broadcast-compatible with (N, C, D?, H?, W?).
// 1. weight = weight.view(1, C, 1?, 1?, 1?)
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
// 3. output = normalizedInput * weight + bias
Value batchNormOutput = normalizedInput;
if (!weight.getType().isa<Torch::NoneType>()) {
// Rank of `weight` must be exactly 1.
if (getTensorRank(weight) != 1)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
runningStatsSizeList);
batchNormOutput = rewriter.create<AtenMulTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
// Rank of `bias` must be exactly 1.
if (getTensorRank(bias) != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
runningStatsSizeList);
batchNormOutput = rewriter.create<AtenAddTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
}

// The `mean` and `invstd` outputs are empty tensors in inference mode.
Value zeroList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(zero.getType()), zero);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value emptyMeanTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
Value emptyInvStdTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);

rewriter.replaceOp(op,
{batchNormOutput, emptyMeanTensor, emptyInvStdTensor});
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -879,6 +1000,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);
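Taken together, the DecomposeAtenNativeBatchNormOp pattern above rewrites the inference-mode op into view, sub, rsqrt, mul, and add ops. A Python sketch of the equivalent computation, kept close to the pattern's variable names (a hypothetical helper, not part of the patch):

import torch

def decomposed_native_batch_norm(x, weight, bias, running_mean, running_var, eps=1e-5):
    # Reshape the (C,) running stats to (1, C, 1, ..., 1) so they broadcast
    # with the (N, C, D?, H?, W?) input, mirroring the aten.view steps.
    stats_shape = [1] * x.dim()
    stats_shape[1] = x.size(1)
    running_mean = running_mean.view(stats_shape)
    running_var = running_var.view(stats_shape)

    # normalizedInput = (input - runningMean) * rsqrt(runningVar + eps)
    normalized = (x - running_mean) * torch.rsqrt(running_var + eps)

    # Optional affine step, mirroring the None checks on weight and bias.
    out = normalized
    if weight is not None:
        out = out * weight.view(stats_shape)
    if bias is not None:
        out = out + bias.view(stats_shape)

    # The mean and invstd results are empty tensors in inference mode.
    return out, torch.empty(0), torch.empty(0)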