Skip to content

Commit

Permalink
OnnxToTorch lowering group normalization op
Browse files Browse the repository at this point in the history
  • Loading branch information
aldesilv committed Apr 15, 2024
1 parent f97cd48 commit ede3805
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
25 changes: 25 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1545,4 +1545,29 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, input);
return success();
});
patterns.onOp("GroupNormalization", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
float eps;
int64_t num_groups;
if (binder.tensorOperands(operands, 3) ||
binder.tensorResultType(resultType) || operands.size() != 3 ||
binder.f32FloatAttr(eps, "epsilon", 1e-05f) ||
binder.s64IntegerAttr(num_groups, "num_groups")) {
return failure();
}
Value boolFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

Value numGroupsValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(num_groups));
auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(eps));

rewriter.replaceOpWithNewOp<Torch::AtenGroupNormOp>(
binder.op, resultType, operands[0], numGroupsValue, operands[1], operands[2],
epsValue, boolFalse);
return success();
});
}
9 changes: 9 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -933,3 +933,12 @@ func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<
%0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_group_normalization_epsilon
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32>
}

0 comments on commit ede3805

Please sign in to comment.