Skip to content

Commit

Permalink
Update Dr.Jit-Core, work around Clang 14 miscompilation (issue #282)
Browse files Browse the repository at this point in the history
On Clang 14, `ArrayBase::shuffle_<..>()` "forgets" to call the default
constructor of the `Derived out` values. On AD variants, this causes a
reference counting failure when `out.entry(..) = ` invokes the
destructor of a field with uninitialized memory. Very strange.

This commit works around the issue by setting
`out = dr::zeros<Derived>()`. It would be good to investigate this
further and potentially submit a Clang bug at some point.
  • Loading branch information
wjakob committed Dec 20, 2022
1 parent ebeed9b commit bfab9ac
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 123 deletions.
2 changes: 1 addition & 1 deletion include/drjit/array_base.h
Expand Up @@ -595,7 +595,7 @@ template <typename Value_, bool IsMask_, typename Derived_> struct ArrayBase {
sizeof...(Indices) == Derived::ActualSize, "shuffle(): Invalid size!");
DRJIT_CHKSCALAR("shuffle_");
size_t idx = 0; (void) idx;
Derived out;
Derived out = zeros<Derived>();
((out.entry(idx++) = derived().entry(Indices % Derived::Size)), ...);
return out;
}
Expand Down
142 changes: 53 additions & 89 deletions include/drjit/jit.h
Expand Up @@ -74,13 +74,13 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {

template <typename T, typename Derived2>
JitArray(const JitArray<Backend, T, Derived2> &v) {
m_index = jit_var_new_cast(v.index(), Type, 0);
m_index = jit_var_cast(v.index(), Type, 0);
}

template <typename T, typename Derived2>
JitArray(const JitArray<Backend, T, Derived2> &v,
detail::reinterpret_flag) {
m_index = jit_var_new_cast(v.index(), Type, 1);
m_index = jit_var_cast(v.index(), Type, 1);
}

template <typename T, enable_if_scalar_t<T> = 0>
Expand All @@ -92,7 +92,7 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
else
av = jit_registry_get_id(Backend, value);

m_index = jit_var_new_literal(Backend, Type, &av, 1, 0, IsClass);
m_index = jit_var_literal(Backend, Type, &av, 1, 0, IsClass);
}

template <typename... Ts, enable_if_t<(sizeof...(Ts) > 1 &&
Expand Down Expand Up @@ -129,71 +129,67 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
// -----------------------------------------------------------------------

Derived add_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Add, m_index, v.m_index));
return steal(jit_var_add(m_index, v.m_index));
}

Derived sub_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Sub, m_index, v.m_index));
return steal(jit_var_sub(m_index, v.m_index));
}

Derived mul_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Mul, m_index, v.m_index));
return steal(jit_var_mul(m_index, v.m_index));
}

Derived mulhi_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Mulhi, m_index, v.m_index));
return steal(jit_var_mulhi(m_index, v.m_index));
}

Derived div_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Div, m_index, v.m_index));
return steal(jit_var_div(m_index, v.m_index));
}

Derived mod_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Mod, m_index, v.m_index));
return steal(jit_var_mod(m_index, v.m_index));
}

auto gt_(const Derived &v) const {
return mask_t<Derived>::steal(jit_var_new_op_2(JitOp::Gt, m_index, v.m_index));
return mask_t<Derived>::steal(jit_var_gt(m_index, v.m_index));
}

auto ge_(const Derived &v) const {
return mask_t<Derived>::steal(jit_var_new_op_2(JitOp::Ge, m_index, v.m_index));
return mask_t<Derived>::steal(jit_var_ge(m_index, v.m_index));
}

auto lt_(const Derived &v) const {
return mask_t<Derived>::steal(jit_var_new_op_2(JitOp::Lt, m_index, v.m_index));
return mask_t<Derived>::steal(jit_var_lt(m_index, v.m_index));
}

auto le_(const Derived &v) const {
return mask_t<Derived>::steal(jit_var_new_op_2(JitOp::Le, m_index, v.m_index));
return mask_t<Derived>::steal(jit_var_le(m_index, v.m_index));
}

auto eq_(const Derived &v) const {
return mask_t<Derived>::steal(jit_var_new_op_2(JitOp::Eq, m_index, v.m_index));
return mask_t<Derived>::steal(jit_var_eq(m_index, v.m_index));
}

auto neq_(const Derived &v) const {
return mask_t<Derived>::steal(jit_var_new_op_2(JitOp::Neq, m_index, v.m_index));
return mask_t<Derived>::steal(jit_var_neq(m_index, v.m_index));
}

Derived neg_() const {
return steal(jit_var_new_op_1(JitOp::Neg, m_index));
}
Derived neg_() const { return steal(jit_var_neg(m_index)); }

Derived not_() const {
return steal(jit_var_new_op_1(JitOp::Not, m_index));
}
Derived not_() const { return steal(jit_var_not(m_index)); }

template <typename T> Derived or_(const T &v) const {
return steal(jit_var_new_op_2(JitOp::Or, m_index, v.index()));
return steal(jit_var_or(m_index, v.index()));
}

template <typename T> Derived and_(const T &v) const {
return steal(jit_var_new_op_2(JitOp::And, m_index, v.index()));
return steal(jit_var_and(m_index, v.index()));
}

template <typename T> Derived xor_(const T &v) const {
return steal(jit_var_new_op_2(JitOp::Xor, m_index, v.index()));
return steal(jit_var_xor(m_index, v.index()));
}

template <typename T> Derived andnot_(const T &a) const {
Expand All @@ -205,111 +201,83 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
}

Derived sl_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Shl, m_index, v.index()));
return steal(jit_var_shl(m_index, v.index()));
}

template <int Imm> Derived sr_() const {
return sr_(Imm);
}

Derived sr_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Shr, m_index, v.index()));
}

Derived abs_() const {
return steal(jit_var_new_op_1(JitOp::Abs, m_index));
return steal(jit_var_shr(m_index, v.index()));
}

Derived sqrt_() const {
return steal(jit_var_new_op_1(JitOp::Sqrt, m_index));
}

Derived rcp_() const {
return steal(jit_var_new_op_1(JitOp::Rcp, m_index));
}

Derived rsqrt_() const {
return steal(jit_var_new_op_1(JitOp::Rsqrt, m_index));
}
Derived abs_() const { return steal(jit_var_abs(m_index)); }
Derived sqrt_() const { return steal(jit_var_sqrt(m_index)); }
Derived rcp_() const { return steal(jit_var_rcp(m_index)); }
Derived rsqrt_() const { return steal(jit_var_rsqrt(m_index)); }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
Derived exp2_() const {
return steal(jit_var_new_op_1(JitOp::Exp2, m_index));
}
Derived exp2_() const { return steal(jit_var_exp2(m_index)); }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
Derived exp_() const {
return exp2(InvLogTwo<T> * derived());
}
Derived exp_() const { return exp2(InvLogTwo<T> * derived()); }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
Derived log2_() const {
return steal(jit_var_new_op_1(JitOp::Log2, m_index));
}
Derived log2_() const { return steal(jit_var_log2(m_index)); }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
Derived log_() const {
return log2(derived()) * LogTwo<T>;
}
Derived log_() const { return log2(derived()) * LogTwo<T>; }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
Derived sin_() const {
return steal(jit_var_new_op_1(JitOp::Sin, m_index));
}
Derived sin_() const { return steal(jit_var_sin(m_index)); }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
Derived cos_() const {
return steal(jit_var_new_op_1(JitOp::Cos, m_index));
}
Derived cos_() const { return steal(jit_var_cos(m_index)); }

template <typename T = Value, enable_if_t<std::is_same_v<T, float> && IsCUDA> = 0>
std::pair<Derived, Derived> sincos_() const {
return { sin_(), cos_() };
}
std::pair<Derived, Derived> sincos_() const { return { sin_(), cos_() }; }

Derived minimum_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Min, m_index, v.index()));
return steal(jit_var_min(m_index, v.index()));
}

Derived maximum_(const Derived &v) const {
return steal(jit_var_new_op_2(JitOp::Max, m_index, v.index()));
return steal(jit_var_max(m_index, v.index()));
}

Derived round_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Round, m_index));
}
Derived round_() const { return Derived::steal(jit_var_round(m_index)); }

template <typename T> T round2int_() const {
return T(round(derived()));
}

Derived floor_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Floor, m_index));
return Derived::steal(jit_var_floor(m_index));
}

template <typename T> T floor2int_() const {
return T(floor(derived()));
}

Derived ceil_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Ceil, m_index));
return Derived::steal(jit_var_ceil(m_index));
}

template <typename T> T ceil2int_() const {
return T(ceil(derived()));
}

Derived trunc_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Trunc, m_index));
return Derived::steal(jit_var_trunc(m_index));
}

template <typename T> T trunc2int_() const {
return T(trunc(derived()));
}

Derived fmadd_(const Derived &b, const Derived &c) const {
return steal(
jit_var_new_op_3(JitOp::Fmadd, m_index, b.index(), c.index()));
return steal(jit_var_fma(m_index, b.index(), c.index()));
}

Derived fmsub_(const Derived &b, const Derived &c) const {
Expand All @@ -328,19 +296,19 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
static Derived select_(const Mask &m, const Derived &t, const Derived &f) {
static_assert(std::is_same_v<Mask, mask_t<Derived>>);
return steal(
jit_var_new_op_3(JitOp::Select, m.index(), t.index(), f.index()));
jit_var_select(m.index(), t.index(), f.index()));
}

Derived popcnt_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Popc, m_index));
return Derived::steal(jit_var_popc(m_index));
}

Derived lzcnt_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Clz, m_index));
return Derived::steal(jit_var_clz(m_index));
}

Derived tzcnt_() const {
return Derived::steal(jit_var_new_op_1(JitOp::Ctz, m_index));
return Derived::steal(jit_var_ctz(m_index));
}

//! @}
Expand Down Expand Up @@ -406,7 +374,7 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {

static Derived zero_(size_t size) {
Value value = 0;
return steal(jit_var_new_literal(Backend, Type, &value, size));
return steal(jit_var_literal(Backend, Type, &value, size));
}

static Derived full_(Value value, size_t size) {
Expand All @@ -416,8 +384,7 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
else
av = jit_registry_get_id(Backend, value);

return steal(
jit_var_new_literal(Backend, Type, &av, size, false, IsClass));
return steal(jit_var_literal(Backend, Type, &av, size, false, IsClass));
}

static Derived opaque_(Value value, size_t size) {
Expand All @@ -427,8 +394,7 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
else
av = jit_registry_get_id(Backend, value);

return steal(
jit_var_new_literal(Backend, Type, &av, size, true, IsClass));
return steal(jit_var_literal(Backend, Type, &av, size, true, IsClass));
}

static Derived arange_(ssize_t start, ssize_t stop, ssize_t step) {
Expand Down Expand Up @@ -500,8 +466,7 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
const Mask &mask) {
static_assert(
std::is_same_v<detached_t<Mask>, detached_t<mask_t<Derived>>>);
return steal(
jit_var_new_gather(src.index(), index.index(), mask.index()));
return steal(jit_var_gather(src.index(), index.index(), mask.index()));
}

template <bool, typename Index, typename Mask>
Expand All @@ -515,8 +480,8 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
void scatter_(Derived &dst, const Index &index, const Mask &mask) const {
static_assert(
std::is_same_v<detached_t<Mask>, detached_t<mask_t<Derived>>>);
dst = steal(jit_var_new_scatter(dst.index(), m_index, index.index(),
mask.index(), ReduceOp::None));
dst = steal(jit_var_scatter(dst.index(), m_index, index.index(),
mask.index(), ReduceOp::None));
}

template <typename Index, typename Mask>
Expand All @@ -532,8 +497,8 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
const Mask &mask) const {
static_assert(
std::is_same_v<detached_t<Mask>, detached_t<mask_t<Derived>>>);
dst = steal(jit_var_new_scatter(dst.index(), m_index, index.index(),
mask.index(), op));
dst = steal(jit_var_scatter(dst.index(), m_index, index.index(),
mask.index(), op));
}

//! @}
Expand Down Expand Up @@ -644,8 +609,7 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
}

static auto counter(size_t size) {
return uint32_array_t<Derived>::steal(
jit_var_new_counter(Backend, size));
return uint32_array_t<Derived>::steal(jit_var_counter(Backend, size));
}

void set_label_(const char *label) {
Expand Down
2 changes: 1 addition & 1 deletion include/drjit/loop.h
Expand Up @@ -336,7 +336,7 @@ struct Loop<Mask, enable_if_jit_array_t<Mask>> {
// Blend with loop state from last iteration based on mask
for (uint32_t i = 0; i < m_indices.size(); ++i) {
uint32_t i1 = *m_indices[i], i2 = m_indices_prev[i];
*m_indices[i] = jit_var_new_op_3(JitOp::Select, m_cond.index(), i1, i2);
*m_indices[i] = jit_var_select(m_cond.index(), i1, i2);
jit_var_dec_ref(i1);
jit_var_dec_ref(i2);
}
Expand Down
10 changes: 2 additions & 8 deletions include/drjit/vcall_jit_record.h
Expand Up @@ -118,14 +118,8 @@ Result vcall_jit_record_impl(const char *name, uint32_t n_inst,
jit_state.set_self(i);

Mask vcall_mask = true;
if constexpr (Backend == JitBackend::LLVM) {
// no-op to copy the mask into a local parameter
vcall_mask = Mask::steal(jit_var_new_stmt(
Backend, VarType::Bool,
"$r0 = bitcast <$w x i1> %mask to <$w x i1>", 1, 0,
nullptr));
}

if constexpr (Backend == JitBackend::LLVM)
vcall_mask = Mask::steal(jit_var_vcall_mask(Backend));
jit_state.set_mask(vcall_mask.index());

if constexpr (std::is_same_v<Result, std::nullptr_t>) {
Expand Down
4 changes: 1 addition & 3 deletions src/python/main.cpp
Expand Up @@ -258,9 +258,7 @@ PYBIND11_MODULE(drjit_ext, m_) {
.value("Host", AllocType::Host)
.value("HostAsync", AllocType::HostAsync)
.value("HostPinned", AllocType::HostPinned)
.value("Device", AllocType::Device)
.value("Managed", AllocType::Managed)
.value("ManagedReadMostly", AllocType::ManagedReadMostly);
.value("Device", AllocType::Device);

py::enum_<JitBackend>(m, "JitBackend")
.value("CUDA", JitBackend::CUDA)
Expand Down

0 comments on commit bfab9ac

Please sign in to comment.