diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index a33b338..4c791e6 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -12,6 +12,8 @@
 #include
 #include

+#include
+
 #include
 #include
@@ -2273,7 +2275,8 @@ class IrParser {
     }

     {
-      auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
+      auto ptr_op = getOperatorForLiteral(
+          "aten::gelu(Tensor self, int approximate=0) -> Tensor");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
@@ -2283,7 +2286,20 @@
                 c10::nullopt, value_map[node->inputs()[0]->unique()]);
             auto self = list_val.front();
             list_val.pop_front();
-            auto out = gelu(self);
+
+            auto approximate = constant_as<int>(node->input(1));
+            TORCH_INTERNAL_ASSERT(
+                approximate.has_value(),
+                "The approximate parameter is required.");
+            const bool kApproximate = approximate.value();
+
+            Val* out = nullptr;
+            if (kApproximate == at::Gelu::Tanh) {
+              out = fast_gelu(self);
+            } else {
+              out = unaryOp(UnaryOpType::Gelu, self);
+            }
+
             value_map.emplace(
                 node->output()->unique(), ValueHolder(out, format));
           },
@@ -2293,7 +2309,7 @@

     {
       auto ptr_op = getOperatorForLiteral(
-          "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
+          "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
@@ -2308,7 +2324,19 @@
             auto self = list_val.front();
             list_val.pop_front();

-            auto grad_in = gelu_backward(grad_out, self);
+            auto approximate = constant_as<int>(node->input(2));
+            TORCH_INTERNAL_ASSERT(
+                approximate.has_value(),
+                "The approximate parameter is required.");
+            const bool kApproximate = approximate.value();
+
+            Val* grad_in = nullptr;
+            if (kApproximate == at::Gelu::Tanh) {
+              grad_in = fast_gelu_backward(grad_out, self);
+            } else {
+              grad_in = gelu_backward(grad_out, self);
+            }
+
             value_map.emplace(
                 node->output()->unique(), ValueHolder(grad_in, format));
           },
@@ -3015,6 +3043,38 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
     }
   }

+  static auto gelu_schema =
+      getOperatorForLiteral(
+          "aten::gelu(Tensor self, int approximate=0) -> Tensor")
+          ->schema();
+  if (node->matches(gelu_schema)) {
+    switch (offset) {
+      // argument 1: approximate;
+      case 1:
+        profileInt(pr, node, offset);
+        break;
+      default:
+        return false;
+    }
+    return true;
+  }
+
+  static auto gelu_backward_schema =
+      getOperatorForLiteral(
+          "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor")
+          ->schema();
+  if (node->matches(gelu_backward_schema)) {
+    switch (offset) {
+      // argument 2: approximate;
+      case 2:
+        profileInt(pr, node, offset);
+        break;
+      default:
+        return false;
+    }
+    return true;
+  }
+
   static auto softmax_backward_data_schema =
       getOperatorForLiteral(
           "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
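For reference, the two lowering paths selected by the patch correspond to the two standard GELU formulas: `UnaryOpType::Gelu` to the exact erf-based form, and `fast_gelu` to the tanh approximation that the new `approximate` argument opts into. A minimal standalone sketch of the math, assuming `fast_gelu` lowers to the usual tanh-based approximation (the names `gelu_exact` and `gelu_tanh` below are illustrative, not part of this patch):

```cpp
#include <cmath>
#include <cstdio>

// Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
double gelu_exact(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

// Tanh approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
double gelu_tanh(double x) {
  const double kPi = std::acos(-1.0);
  const double kBeta = std::sqrt(2.0 / kPi);
  return 0.5 * x * (1.0 + std::tanh(kBeta * (x + 0.044715 * x * x * x)));
}

int main() {
  // The two forms agree closely; the tanh variant is a cheaper
  // approximation of the erf-based one.
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("x=% .2f  exact=% .6f  tanh=% .6f\n",
                x, gelu_exact(x), gelu_tanh(x));
  }
  return 0;
}
```

Note that with the schema default `approximate=0` the parser keeps the exact path (`unaryOp(UnaryOpType::Gelu, ...)` / `gelu_backward`), so existing fusions are unaffected; only `kApproximate == at::Gelu::Tanh` switches to `fast_gelu` / `fast_gelu_backward`.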