Skip to content

Commit

Permalink
Add suspend and resume grad function in C++
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers committed Sep 15, 2022
1 parent fdb652f commit 112c294
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions include/drjit/array_router.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <drjit/array_traits.h>
#include <drjit/array_utils.h>
#include <drjit/array_constants.h>
#include <drjit-core/containers.h>

#if defined(min) || defined(max)
# error min/max are defined as preprocessor symbols! Define NOMINMAX on MSVC.
Expand Down Expand Up @@ -1554,6 +1555,61 @@ template <typename... Ts> void disable_grad(Ts&... ts) {
(set_grad_enabled(ts, false), ...);
}

namespace detail {
template <typename T>
void collect_ad_indices(dr_index_vector &indices, const T &value) {
if constexpr (array_depth_v<T> > 1) {
for (size_t i = 0; i < value.derived().size(); ++i)
collect_ad_indices(indices, value.derived().entry(i));
} else if constexpr (is_diff_v<T>) {
uint32_t index = value.index_ad();
if (index)
indices.push_back(index);
} else if constexpr (is_drjit_struct_v<T>) {
struct_support_t<T>::apply_1(
value, [&](const auto &x) { collect_ad_indices(indices, x); });
}
}
}

template <typename T> struct resume_grad {
static constexpr bool Enabled =
is_diff_v<T> && std::is_floating_point_v<scalar_t<T>>;
template <typename... Args>
resume_grad(const Args &... args) {
if constexpr (Enabled) {
dr_index_vector indices;
(detail::collect_ad_indices(indices, args), ...);
detail::ad_scope_enter<typename T::Type>(
detail::ADScope::Resume, indices.size(), indices.data());
}
}

~resume_grad() {
if constexpr (Enabled)
detail::ad_scope_leave<typename T::Type>(true);
}
};

template <typename T> struct suspend_grad {
static constexpr bool Enabled =
is_diff_v<T> && std::is_floating_point_v<scalar_t<T>>;
template <typename... Args>
suspend_grad(const Args &... args) {
if constexpr (Enabled) {
dr_index_vector indices;
(detail::collect_ad_indices(indices, args), ...);
detail::ad_scope_enter<typename T::Type>(
detail::ADScope::Suspend, indices.size(), indices.data());
}
}

~suspend_grad() {
if constexpr (Enabled)
detail::ad_scope_leave<typename T::Type>(true);
}
};

template <bool UnderlyingType, typename T>
decltype(auto) detach(T &&value) {
using Result = std::conditional_t<UnderlyingType, detached_t<T>, std::decay_t<T>>;
Expand Down

0 comments on commit 112c294

Please sign in to comment.