diff --git a/nestedtensor/csrc/activation.cpp b/nestedtensor/csrc/activation.cpp index cb7f5688..66d85026 100644 --- a/nestedtensor/csrc/activation.cpp +++ b/nestedtensor/csrc/activation.cpp @@ -8,7 +8,7 @@ namespace F = torch::nn::functional; namespace at { -Tensor NestedTensor_gelu(const Tensor& self) { +Tensor NestedTensor_gelu(const Tensor& self, const int64_t approximate) { if (is_nested_tensor_impl(self) && get_is_contiguous(self)) { return wrap_buffer( at::gelu(get_buffer(self)), @@ -16,7 +16,7 @@ Tensor NestedTensor_gelu(const Tensor& self) { get_efficient_nested_stride(self)); } return map_nested_tensor( - [](at::Tensor tensor) { return at::gelu(tensor); }, self); + [&approximate](at::Tensor tensor) { return at::gelu(tensor, approximate); }, self); } Tensor NestedTensor_elu(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {