diff --git a/python/monarch/gradient/_gradient_generator.cpp b/python/monarch/gradient/_gradient_generator.cpp index 01ea6eed3..858d7b835 100644 --- a/python/monarch/gradient/_gradient_generator.cpp +++ b/python/monarch/gradient/_gradient_generator.cpp @@ -23,6 +23,12 @@ #include // @manual=//caffe2:torch_extension #include // @manual=//caffe2:torch_extension +#define TORCH_VERSION_NEWER_THAN(major, minor, patch) \ + ((TORCH_VERSION_MAJOR > (major)) || \ + (TORCH_VERSION_MAJOR == (major) && TORCH_VERSION_MINOR > (minor)) || \ + (TORCH_VERSION_MAJOR == (major) && TORCH_VERSION_MINOR == (minor) && \ + TORCH_VERSION_PATCH > (patch))) + using torch::autograd::Edge; using torch::autograd::InputBuffer; using torch::autograd::Node; @@ -420,12 +426,20 @@ struct GradientGenerator { DEBUG_PRINT( "// add: " << node->node->name() << ", input_nr=" << static_cast(input_nr) << "\n"); +#if TORCH_VERSION_NEWER_THAN(2, 8, 0) realInputBuffer(node).add( input_nr, check_and_reduce(node->node, input_nr, std::move(t)), std::nullopt, std::nullopt, node->node); +#else + realInputBuffer(node).add( + input_nr, + check_and_reduce(node->node, input_nr, std::move(t)), + std::nullopt, + std::nullopt); +#endif } InputBuffer& realInputBuffer(NodeState* state) {