20 changes: 20 additions & 0 deletions onnxruntime/test/providers/cpu/activation/activation_op_test.cc
@@ -57,7 +57,11 @@ float ReluGrad(float dy, float x) {
float SigmoidGrad(float dy, float y) {
return dy * y * (1 - y);
}

float TanhGrad(float dy, float y) {
return dy * (1 - y * y);
}
} // namespace
#endif
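For reference, the TanhGrad helper added above is the analytic derivative of tanh expressed through the forward output: with $y = \tanh(x)$,

$$\frac{dy}{dx} = 1 - \tanh^2(x) = 1 - y^2, \qquad dX = dY \cdot (1 - Y^2),$$

so the gradient can be computed from Y and dY alone, without the original input X.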

TEST_F(ActivationOpTest, Sigmoid) {
@@ -303,6 +307,22 @@ TEST(SigmoidGradInferenceTest, Basic) {
},
{}, 1, kMSDomain);
}

TEST(TanhGradInferenceTest, Basic) {
const std::vector<float> y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
const std::vector<float> dY(7, 1.0f);

TestElementwiseGradientOp(
"TanhGrad",
{{"dY", dY}, {"Y", y_vals}},
[](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], y = params[1];

return TanhGrad(dy, y);
},
{}, 1, kMSDomain);
}
#endif

} // namespace test
@@ -251,8 +251,12 @@
"SplitTraining com.microsoft CPUExecutionProvider",
12689204749897364688
],
[
"TanhGrad com.microsoft CPUExecutionProvider",
7147744030478490408
],
[
"ZeroGradient com.microsoft CPUExecutionProvider",
3284255990062374928
]
]
]
42 changes: 18 additions & 24 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -38,7 +38,7 @@ static bool SimplifyReshape(const std::vector<Dimension>& target_shape, // the
return false;
}
}
//trim empty strings in the tail of list
// trim empty strings at the tail of the list
while (!dim_params.empty() && dim_params.back().empty()) {
dim_params.pop_back();
}
@@ -90,15 +90,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetLogGradient) {
}

IMPLEMENT_GRADIENT_BUILDER(GetTanhGradient) {
ArgDef Y = O(0);
std::vector<NodeDef> result;
NodeDef one_constant_node = OneConstantNode(OElemType(0));
ArgDef one_arg = one_constant_node.output_args[0];
result.push_back(one_constant_node);
result.push_back(NodeDef("Mul", {Y, Y}, {IA("Squared_Y")}));
result.push_back(NodeDef("Sub", {one_arg, IA("Squared_Y")}, {IA("Sub_Squared_Y")}));
result.push_back(NodeDef("Mul", {GO(0), IA("Sub_Squared_Y")}, {GI(0)}));
return result;
return std::vector<NodeDef>{
NodeDef(OpDef{"TanhGrad", kMSDomain, 1},
{GO(0), O(0)},
{GI(0)})};
}
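The rewritten builder collapses the previous four-node subgraph (a Constant(1) feeding Mul, Sub, Mul) into a single fused TanhGrad node in kMSDomain, eliminating the Squared_Y and Sub_Squared_Y intermediates. A minimal numpy sketch of the equivalence; the function names here are illustrative, not ORT API:

import numpy as np

def tanh_grad_subgraph(dY, Y):
    # Old builder: Const(1), Mul(Y, Y) -> Squared_Y, Sub -> Sub_Squared_Y, Mul with dY
    squared_y = Y * Y
    sub_squared_y = 1.0 - squared_y
    return dY * sub_squared_y

def tanh_grad_fused(dY, Y):
    # New builder: a single TanhGrad(dY, Y) node computing dY * (1 - Y^2)
    return dY * (1.0 - Y * Y)

Y = np.tanh(np.linspace(-3.0, 3.0, 7, dtype=np.float32))
dY = np.ones_like(Y)
assert np.allclose(tanh_grad_subgraph(dY, Y), tanh_grad_fused(dY, Y))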

IMPLEMENT_GRADIENT_BUILDER(GetSqrtGradient) {
@@ -241,7 +236,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
NodeDef(OpDef{"FusedMatMul", kMSDomain, 1},
{GO(0), B},
{matmul_out},
{{"transB", MakeAttribute("transB", int64_t(1))}}));
{{"transB", MakeAttribute("transB", int64_t(1))}}));
if (A_axes.size() > 0) {
AddReduceSumNode(IA("PreReduceGrad0"), IA("ReduceGrad0"), A_axes, true, result);
result.push_back(NodeDef("Shape", {A}, {IA("A_shape")}));
@@ -281,7 +276,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMatMulGradient) {
}
}
} else {
//GetShape failed, build shape-independent gradient graph
// GetShape failed, build shape-independent gradient graph
ArgDef a_axes, b_axes, a_shape, b_shape, ia_shape;
a_shape = IA("Shape_" + A.name);
b_shape = IA("Shape_" + B.name);
@@ -451,7 +446,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
}
}
} else {
//GetShape failed, build shape-independent gradient graph
// GetShape failed, build shape-independent gradient graph
ArgDef c_axes = IA("ReduceAxes_" + C.name);
ArgDef c_shape = IA("Shape_" + C.name);
ArgDef dy_shape = IA("Shape_" + dY.name);
@@ -617,7 +612,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetTransposeGradient) {
std::vector<AttributeProto> new_attributes;
if (attributes.empty()) {
const TensorShapeProto& input_shape = I(0).type_proto->tensor_type().shape();
if (input_shape.dim_size() > 0) { //input_shape is available
if (input_shape.dim_size() > 0) { // input_shape is available
int n = input_shape.dim_size() - 1;
bw_perm.resize(n + 1);
std::generate(bw_perm.begin(), bw_perm.end(), [&n] { return n--; });
@@ -694,7 +689,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetConvGradient) {
}

IMPLEMENT_GRADIENT_BUILDER(GetSigmoidGradient) {
auto const_one = OneConstantNode(OElemType(0));
return std::vector<NodeDef>{
NodeDef(OpDef{"SigmoidGrad", kMSDomain, 1},
{GO(0), O(0)},
@@ -860,7 +854,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetAddSubGradient) {
}
}
} else {
//GetShape failed, build shape-independent gradient graph
// GetShape failed, build shape-independent gradient graph
ArgDef a_axes = IA("ReduceAxes_" + a.name);
ArgDef b_axes = IA("ReduceAxes_" + b.name);
ArgDef A_shape = IA("Shape_" + a.name);
@@ -944,7 +938,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetMulGradient) {
}
}
} else {
//GetShape failed, build shape-independent gradient graph
// GetShape failed, build shape-independent gradient graph
ArgDef a_axes = IA("ReduceAxes_" + a.name);
ArgDef b_axes = IA("ReduceAxes_" + b.name);
ArgDef A_shape = IA("Shape_" + a.name);
@@ -1001,7 +995,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetDivGradient) {
output.push_back(NodeDef("Identity", {tmp_grad}, {GI(0)}));
}
} else {
//GetShape failed, build shape-independent gradient graph
// GetShape failed, build shape-independent gradient graph
ArgDef a_axes = IA("ReduceAxes_" + a.name);
ArgDef A_shape = IA("Shape_" + a.name);
ArgDef B_shape = IA("Shape_" + b.name);
@@ -1133,17 +1127,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
ArgDef grad = GO(0);
if (!keepdims) {
size_t numInputs = GetSrcNodeInputSize();
if (SrcNodeOpsetVersion() < 13) { //axes is attribute
if (SrcNodeOpsetVersion() < 13) { // axes is attribute
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));

grad = IA("Unqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
}
} else if (numInputs == 2) { //optional input 'axes' is available as input I(1)
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
grad = IA("Unqueezed_Grad");
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), I(1)}, {grad}));
} //axes is not available, the GO(0) is a scalar which can be expanded to required shape
} // axes is not available, the GO(0) is a scalar which can be expanded to required shape
}

result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
@@ -1443,7 +1437,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetExpandGradient) {
{GI(0)}));
}
} else {
//GetShape failed, build shape-independent gradient graph
// GetShape failed, build shape-independent gradient graph
ArgDef a_axes = IA("ReduceAxes_" + a.name);
ArgDef A_shape = IA("Shape_" + a.name);
ArgDef Y_shape = IA("Shape_" + y.name);
@@ -1549,10 +1543,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetTileGradient) {
NodeDef unsqueeze_axes = ConstantVectorNode(axes_values, Name("unsqueeze_axes"));
result.push_back(unsqueeze_axes);
result.push_back(NodeDef("Unsqueeze", {IA("orig_shape"), unsqueeze_axes.output_args[0]}, {IA("2d_orig_shape")})); // M, N, K
result.push_back(NodeDef("Unsqueeze", {I(1), unsqueeze_axes.output_args[0]}, {IA("2d_repeats")})); //a, b, c
result.push_back(NodeDef("Unsqueeze", {I(1), unsqueeze_axes.output_args[0]}, {IA("2d_repeats")})); // a, b, c
} else {
result.push_back(NodeDef("Unsqueeze", {IA("orig_shape")}, {IA("2d_orig_shape")}, {MakeAttribute("axes", axes_values)})); // M, N, K
result.push_back(NodeDef("Unsqueeze", {I(1)}, {IA("2d_repeats")}, {MakeAttribute("axes", axes_values)})); //a, b, c
result.push_back(NodeDef("Unsqueeze", {I(1)}, {IA("2d_repeats")}, {MakeAttribute("axes", axes_values)})); // a, b, c
}
result.push_back(NodeDef("Concat", {IA("2d_repeats"), IA("2d_orig_shape")}, {IA("concated_dims_T")},
{MakeAttribute("axis", int64_t(1))})); // [[a, M], [b, N], [c, K]]
35 changes: 31 additions & 4 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -623,7 +623,7 @@ void RegisterTrainingOpSchemas() {
.AddOpset("", 13)
.Const("one", int64_t(1))
.Const("k", axis)
.Const("axis_zero", std::vector<int64_t>({0})) // a 1D tensor constant
.Const("axis_zero", std::vector<int64_t>({0})) // a 1D tensor constant
.Add(R"(
shape = Shape (dY)
n_as_vector = Shape (shape)
@@ -835,8 +835,8 @@
}
});

//TODO: Move this to the right location. Its only here for quick experimentation.
//TODO: Use the mutli weight / grad version.
// TODO: Move this to the right location. It's only here for quick experimentation.
// TODO: Use the multi weight / grad version.
ONNX_CONTRIB_OPERATOR_SCHEMA(SGDOptimizer)
.SetDomain(kMSDomain)
.SinceVersion(1)
@@ -2081,7 +2081,6 @@ Example 4:
return true;
});


ONNX_CONTRIB_OPERATOR_SCHEMA(SigmoidGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
@@ -2112,7 +2111,35 @@ Example 4:
onnx_opset_13.set_version(13);

return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
});

ONNX_CONTRIB_OPERATOR_SCHEMA(TanhGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("TanhGrad")
.AllowUncheckedAttributes()
.Input(0, "dY", "The gradient tensor from output.", "T")
.Input(1, "Y", "The input tensor. ", "T")
.Output(0, "dX", "Gradient of the input.", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
auto elem_type = (ONNX_NAMESPACE::TensorProto_DataType)tp->tensor_type().elem_type();
std::vector<FunctionBodyHelper::NodeDef> body{
ONNX_NAMESPACE::Const("C_One", 1.0f, elem_type),
{{"YSquare"}, "Mul", {"Y", "Y"}},
{{"dTanhX"}, "Sub", {"C_One", "YSquare"}},
{{"dX"}, "Mul", {"dY", "dTanhX"}}};

return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {});
});
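For execution providers without the fused kernel, the context-dependent function body above lets TanhGrad expand into standard ONNX ops. A rough Python sketch of that expansion using onnx.helper, assuming a float32 input; node and tensor names are illustrative, not the exact FunctionProto the schema emits:

from onnx import TensorProto, helper

# Scalar constant 1 matching the input's element type (float32 assumed here).
one = helper.make_tensor("C_One_value", TensorProto.FLOAT, [], [1.0])
nodes = [
    helper.make_node("Constant", [], ["C_One"], value=one),
    helper.make_node("Mul", ["Y", "Y"], ["YSquare"]),           # Y^2
    helper.make_node("Sub", ["C_One", "YSquare"], ["dTanhX"]),  # 1 - Y^2
    helper.make_node("Mul", ["dY", "dTanhX"], ["dX"]),          # dY * (1 - Y^2)
]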

ONNX_CONTRIB_OPERATOR_SCHEMA(LayerNormalizationGrad)
4 changes: 4 additions & 0 deletions orttraining/orttraining/test/gradient/function_ops_test.cc
@@ -235,5 +235,9 @@ TEST_F(FunExpansionTest, SigmoidGrad_float) {
TestUnaryOpGrad<float, true>("SigmoidGrad");
}

TEST_F(FunExpansionTest, TanhGrad_float) {
TestUnaryOpGrad<float, true>("TanhGrad");
}

} // namespace test
} // namespace onnxruntime
@@ -4176,3 +4176,37 @@ def run_step(model, x):
        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
        _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
        _test_helpers.assert_values_are_close(ort_loss, pt_loss)


def test_tanh_grad():
    class NeuralNetTanh(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetTanh, self).__init__()

            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.tanh = torch.nn.Tanh()

        def forward(self, input1):
            out = self.fc1(input1)
            out = self.tanh(out)
            return out

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction, loss

    device = 'cuda'

    N, D_in, H, D_out = 120, 1536, 500, 1536
    pt_model = NeuralNetTanh(D_in, H, D_out).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    for step in range(10):
        pt_x = torch.randn(N, D_in, device=device, requires_grad=True)
        ort_x = copy.deepcopy(pt_x)
        ort_prediction, ort_loss = run_step(ort_model, ort_x)
        pt_prediction, pt_loss = run_step(pt_model, pt_x)
        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
        _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
        _test_helpers.assert_values_are_close(ort_loss, pt_loss)
@@ -80,6 +80,9 @@ float SigmoidGrad(float dy, float y) {
return dy * y * (1 - y);
}

float TanhGrad(float dy, float y) {
return dy * (1 - y * y);
}
} // namespace

TEST(GeluGradTest, Basic) {
@@ -180,6 +183,22 @@ TEST(SigmoidGradTest, Basic) {
{}, 1, kMSDomain);
}

TEST(TanhGradTest, Basic) {
const std::vector<float> y_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
const std::vector<float> dY(7, 1.0f);

TestElementwiseGradientOp(
"TanhGrad",
{{"dY", dY}, {"Y", y_vals}},
[](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], y = params[1];

return TanhGrad(dy, y);
},
{}, 1, kMSDomain);
}

namespace {
template <typename TComputeGeluGradScalarFn>
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,
@@ -83,6 +83,13 @@ TEST(CudaKernelTest, SigmoidGrad_basic) {
}
}

TEST(CudaKernelTest, TanhGrad_basic) {
std::vector<std::vector<int64_t>> test_dims{{4}, {16, 2}, {8, 2, 128, 128}};
for (const auto& test_dim : test_dims) {
TestActivations(test_dim, "TanhGrad", true /* grad_op */);
}
}

static void TestActivationsWithBroadcastBias(
const std::vector<int64_t>& tensor_dim,
const std::string& operator_name,
14 changes: 8 additions & 6 deletions orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc
@@ -46,12 +46,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherElementsGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad);

// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
//class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad);
//class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_float, DropoutGrad);
//class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_double, DropoutGrad);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_float, DropoutGrad);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_double, DropoutGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_MLFloat16, DropoutGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_float, DropoutGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_double, DropoutGrad);
@@ -154,11 +155,12 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherElementsGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad)>,
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
//BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad)>,
//BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_float, DropoutGrad)>,
//BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_double, DropoutGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_float, DropoutGrad)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_double, DropoutGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_MLFloat16, DropoutGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_float, DropoutGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_double, DropoutGrad)>,