
Enzyme error when differentiating a stateful lambda, but no problems with an explicitly defined functor #1343

Closed
samuelpmishLLNL opened this issue Jul 20, 2023 · 3 comments


@samuelpmishLLNL
Collaborator

Hi, I'm trying to figure out how to differentiate functors/lambdas with Enzyme.

Functors seem to work after introducing a level of indirection to turn the operator() into a free function:

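#include <array>

// Stateful functor: a and b are its state.
struct Functor {
  auto operator() (const std::array<double,2> & x) const {
    return x[0] * a + x[1] * x[1] * b * b;
  }
  double a, b;
};

// Free function that forwards to operator(), so Enzyme can be given a plain function pointer.
template < typename T, typename ... arg_types >
auto functor_wrapper(const T & f, arg_types && ... args) {
  return f(args ...);
}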

...

Functor f{1.0, 2.0};
std::array<double, 2> x{3.0, 4.0};
std::array<double, 2> dx{5.0, 6.0};

// works!
double dy = __enzyme_fwddiff<double>((void*)functor_wrapper<decltype(f), std::array<double,2> >, enzyme_const, &f, enzyme_dup, &x, &dx);

But when following the same pattern with a stateful lambda, we get an error message:

auto f = [a = 1.0, b = 2.0](std::array<double,2> x) {
  return x[0] * a + x[1] * x[1] * b * b;
};
std::array<double, 2> x{3.0, 4.0};
std::array<double, 2> dx{5.0, 6.0};

// error: function '__enzyme_fwddiff<
//    double, 
//    int, 
//    (lambda at <source>:29:12) *, 
//    int, 
//    std::array<double, 2> *, 
//    std::array<double, 2> *
//  >' is used but not defined in this translation unit, 
//     and cannot be defined in any other translation unit because its type does not have linkage
double dy = __enzyme_fwddiff<double>((void*)functor_wrapper<decltype(f), std::array<double,2> >, enzyme_const, &f, enzyme_dup, &x, &dx);

Is there a limitation in Enzyme that prevents differentiation of stateful lambdas? The error message seems misleading: this doesn't look like a linkage issue, since all of this code lives in a single translation unit.

From a C++ point of view, I think of these two examples as practically equivalent, since the compiler implements a lambda by turning it into a functor object behind the scenes (as cppinsights.io shows):

[Screenshot (2023-07-20): cppinsights.io output showing the lambda rewritten as a compiler-generated functor class]
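The C++ Insights output is not reproduced in this text, so here is a hand-written sketch of the same lowering. The class name LambdaClosure and the call_* helpers are hypothetical (real closure types are unnamed), and the generated code shown by C++ Insights may differ in detail. One thing the sketch makes visible: the generated class is defined inside the enclosing function, so it is a local class with no linkage, whereas Functor sits at namespace scope and has external linkage. In plain C++ all three still compute the same value:

```cpp
#include <array>

// Namespace-scope functor from the issue: its type has external linkage.
struct Functor {
  auto operator()(const std::array<double, 2>& x) const {
    return x[0] * a + x[1] * x[1] * b * b;
  }
  double a, b;
};

double call_functor(const std::array<double, 2>& in) {
  Functor f{1.0, 2.0};
  return f(in);
}

// The stateful lambda from the issue.
double call_lambda(const std::array<double, 2>& in) {
  auto f = [a = 1.0, b = 2.0](std::array<double, 2> x) {
    return x[0] * a + x[1] * x[1] * b * b;
  };
  return f(in);
}

// Hand-written approximation of what the compiler generates for that lambda.
// The class is defined inside the function body, so it is a local class and
// has no linkage. "LambdaClosure" is a made-up name for illustration.
double call_lowered_lambda(const std::array<double, 2>& in) {
  class LambdaClosure {
   public:
    LambdaClosure(double a_, double b_) : a(a_), b(b_) {}
    double operator()(std::array<double, 2> x) const {
      return x[0] * a + x[1] * x[1] * b * b;
    }

   private:
    double a;  // captured by value from "a = 1.0"
    double b;  // captured by value from "b = 2.0"
  };
  LambdaClosure f{1.0, 2.0};
  return f(in);
}
```

Evaluated at x = {3.0, 4.0}, all three helpers return 3*1 + 4*4*2*2 = 67.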


Here's a complete example of the issue: https://fwd.gymni.ch/PlRQpl
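To sanity-check whatever Enzyme returns, here is a hand-derived forward-mode derivative of the same function, with a and b held constant (as enzyme_const does for the functor/lambda) and x duplicated with tangent dx (as enzyme_dup does). reference_fwddiff is a hypothetical helper for illustration, not part of Enzyme:

```cpp
#include <array>

// Hand-derived forward-mode derivative (Jacobian-vector product) of
//   f(x) = x0*a + x1*x1*b*b
// with a, b treated as constants:
//   df/dx0 = a,   df/dx1 = 2*x1*b*b
double reference_fwddiff(const std::array<double, 2>& x,
                         const std::array<double, 2>& dx,
                         double a, double b) {
  return a * dx[0] + 2.0 * x[1] * b * b * dx[1];
}
```

With the values from the issue (a = 1, b = 2, x = {3, 4}, dx = {5, 6}) this gives 1*5 + 2*4*2*2*6 = 197, which is what the __enzyme_fwddiff call should produce for both the functor and the lambda, assuming it compiles.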

@wsmoses
Member

wsmoses commented Jul 20, 2023

Change it to (void*)&f as below and it's fine:

  double dy = __enzyme_fwddiff<double>((void*)functor_wrapper<decltype(f), std::array<double,2> >, enzyme_const, (void*)&f, enzyme_dup, &x, &dx);

@wsmoses
Member

wsmoses commented Jul 20, 2023

This is purely a C++ error, independent of Enzyme. Basically it's saying that no code could ever provide the definition of __enzyme_fwddiff, since one of its template arguments is a type without linkage (the lambda's local closure type) and the template is only declared here, i.e. assumed to be defined externally. However, the compiler actually can provide the definition (Enzyme does exactly that here). Converting the argument to a non-local type (such as void* in my fix, or your Functor above) makes it fine.
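A minimal sketch of the same fix packaged as a reusable helper, assuming the call pattern from this thread. as_void_ptr is a hypothetical name, not an Enzyme API. Because the pointer is converted to void* before it reaches __enzyme_fwddiff, none of that template's arguments involve the lambda's closure type. The helper itself is instantiated with the closure type, but it is defined in the same translation unit, so that is allowed:

```cpp
#include <array>

// Hypothetical helper, not part of Enzyme: pass __enzyme_fwddiff a void*
// instead of a pointer to the closure type, so every template argument it is
// instantiated with has linkage. as_void_ptr itself is defined right here, so
// instantiating it with a lambda's (no-linkage) closure type is fine.
template <typename T>
void* as_void_ptr(T* p) {
  return static_cast<void*>(p);
}

// Usage with Enzyme (left as a comment, since it needs the Enzyme plugin and
// the declarations of __enzyme_fwddiff, enzyme_const and enzyme_dup):
//   double dy = __enzyme_fwddiff<double>(
//       (void*)functor_wrapper<decltype(f), std::array<double, 2>>,
//       enzyme_const, as_void_ptr(&f), enzyme_dup, &x, &dx);
```

The helper is equivalent to writing (void*)&f at the call site; its only advantage is that the intent is named in one place. It deliberately rejects pointers to const objects rather than casting const away.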

@wsmoses wsmoses closed this as completed Jul 20, 2023
@samuelpmishLLNL
Collaborator Author

Thanks for the quick response! Is there any way to emit a more descriptive error message here, to help users diagnose this sort of problem and apply your fix?
