
Fix read-after-free in jitc_cuda_assemble #78

Merged: 1 commit, Feb 2, 2024
Conversation

@merlinND (Member) commented on Feb 2, 2024

Identified using the reproducer script below, with MI_SANITIZE_ADDRESS enabled, DRJIT_VALGRIND=1 defined in var.cpp, and everything running under AddressSanitizer.

The full command was:

LD_PRELOAD="$(pwd)/../resources/dlopen_interceptor.so:$(clang++-15 -print-file-name=libclang_rt.asan-x86_64.so)" ASAN_OPTIONS="protect_shadow_gap=0,verify_asan_link_order=0,detect_stack_use_after_return=1,detect_leaks=1" DRJIT_NO_RTLD_DEEPBIND=1 CUDA_VISIBLE_DEVICES=0 CUDA_LAUNCH_BLOCKING=1 python3.11 -u ../src/render/tests/test_segfault_optix.py > log.txt 2>&1

where dlopen_interceptor.so is a short library that strips the RTLD_DEEPBIND flag from any dlopen call. Code by zhi_xz (source):

// Compile with: clang -shared -fPIC -o dlopen_interceptor.so dlopen_interceptor.c
#include <dlfcn.h>
#include <stdio.h>

typedef void* (*orig_dlopen_func_type)(const char*, int);

void* dlopen(const char* filename, int flags) {
  void* result;
  orig_dlopen_func_type original_dlopen;
  // (void*)-1l is the value of RTLD_NEXT: look up the real dlopen further down the chain
  original_dlopen = (orig_dlopen_func_type) dlsym(((void*)-1l), "dlopen");
  if (flags & RTLD_DEEPBIND) {
    printf("Intercepted dlopen(%s, %d)\n", filename, flags);
    flags &= ~RTLD_DEEPBIND;
    printf("Adjusted flags to %d\n", flags);
  }
  result = (*original_dlopen)(filename, flags);
  return result;
}

This may have been the root cause of the following crashes:

Edit: in hindsight, this same bug was already found & fixed for LLVM in the following PR: #58

Reproducer script
import drjit as dr
import mitsuba as mi

# Doesn't seem to segfault without vcall
USE_BSDF_VCALL = True
# Doesn't seem to segfault without OptiX
USE_OPTIX = True
# Doesn't seem to segfault if ray tracing result is evaluated
USE_SI_EVAL = False


def test01_reproduce_segfault():
    mi.set_variant('cuda_ad_rgb')

    # dr.set_log_level(dr.LogLevel.Warn)
    mi.set_log_level(mi.LogLevel.Trace)
    # dr.set_log_level(dr.LogLevel.Debug)
    # dr.set_flag(dr.JitFlag.ValueNumbering, False)

    scene = mi.load_dict({
        "type": "scene",
        "my_bsdf": {
            "type": "twosided",
            "nested": {
                "type": "diffuse",
                "reflectance": {"type": "rgb", "value": [0.102473, 0.102473, 0.102473]},
            },
        },

        # "emitter": {
        #     "type": "constant",
        #     # "radiance": {"type": "rgb", "value": [1.0, 1.0, 1.0]},
        # },

        "mesh-large_reflector": {
            "type": "rectangle",
            "to_world": (
                mi.ScalarTransform4f.scale([10, 10, 10])
            ),
            "bsdf": {
                "type": "ref",
                "id": "my_bsdf",
            }
        },
        "mesh-small_reflector": {
            "type": "rectangle",
            "to_world": (
                mi.ScalarTransform4f.translate([0, 0, 6])
                @ mi.ScalarTransform4f.scale([10, 10, 10])
            ),
            "bsdf": {
                "type": "ref",
                "id": "my_bsdf",
            }
        },
    })

    if not USE_BSDF_VCALL:
        single_bsdf = scene.shapes()[0].bsdf()


    n_cells = 128

    # Note: crash doesn't happen with max_depth=1 or max_depth=2
    # With max_depth=3, happens reliably at it_i=2465
    def forward(ray: mi.Ray3f, max_depth=3):
        result = dr.zeros(mi.Float, n_cells)

        ctx = mi.BSDFContext()  # mode=mi.TransportMode.Importance

        # for bounce_i in range(max_depth):
        for bounce_i in range(max_depth):
            if USE_OPTIX:
                # do_ray_tracing = mi.Mask(True)
                # Crash triggers faster when providing this mask
                do_ray_tracing = dr.full(mi.Mask, True, shape=dr.width(ray))
                si_scene = scene.ray_intersect(ray, active=do_ray_tracing)
            else:
                # Doesn't crash if using analytic intersections (no OptiX)
                si_scene = dr.zeros(mi.SurfaceInteraction3f, dr.width(ray))
                assert dr.all(dr.isinf(si_scene.t))
                for shape in scene.shapes():
                    # TODO: not quite correct
                    si = shape.ray_intersect(ray)
                    # print('Per shape:', dr.count(si.is_valid()))
                    si_scene[si.is_valid() & (si.t < si_scene.t) & (si.t >= 0)] = si

            active_scene = si_scene.is_valid()

            # Doesn't crash if evaled
            if USE_SI_EVAL:
                dr.eval(si_scene)

            if USE_BSDF_VCALL:
                bsdf_ptr = si_scene.bsdf()
            else:
                bsdf_ptr = single_bsdf

            value = bsdf_ptr.eval(ctx, si_scene, si_scene.to_local(-ray.d), active=active_scene)

            value = dr.mean(value)
            assert isinstance(value, mi.Float)
            index = dr.arange(mi.UInt32, dr.width(ray)) % dr.width(result)

            active = active_scene & dr.isfinite(value)

            # Crashes with either `scatter` or `scatter_reduce`
            # dr.scatter_reduce(dr.ReduceOp.Add, result, value, index, active=active)
            dr.scatter(result, value, index, active=active)

        return result


    params = mi.traverse(scene)
    param_name = 'mesh-large_reflector.bsdf.brdf_0.reflectance.value'
    params.keep(param_name)
    opt = mi.ad.Adam(lr=0.05, params=params)
    params.update(opt)

    n_iter = 100000
    n_samples_sqrt = 64
    for it_i in range(n_iter):

        print('\n\n----------- forward')
        dir_samples = dr.meshgrid(*[dr.linspace(mi.Float, 0, 1, n_samples_sqrt) for _ in range(2)])
        ray = mi.Ray3f(
            o=mi.Point3f(5.0, 0.0, 5.0),
            d=mi.warp.square_to_uniform_sphere(mi.Point2f(dir_samples))
        )

        result = forward(ray)

        loss = dr.mean(dr.sqr((result - 0.5)))

        print('\n\n----------- eval(result, loss)')
        dr.eval(result, loss)

        # if it_i == 0:
        if it_i == 164:
            # dr.set_log_level(dr.LogLevel.Trace)

            # dr.set_label(params[param_name], 'params.value')
            # dr.set_label(opt[param_name], 'opt.value')
            # dr.set_label(loss, 'loss')
            # dr.set_label(ray, 'ray')

            # with open('graph_ad.dot', 'w') as f:
            #     f.write(dr.graphviz_ad(as_str=True))
            pass

        # The kernel that segfaults is fully in the backward pass.
        print('\n\n----------- backward')
        with dr.scoped_set_flag(dr.JitFlag.KernelHistory, True):
            dr.backward(loss)
            for k, v in opt.items():
                dr.schedule(v, dr.grad(v))
            dr.eval()

            # Useful to try and print the PTX of the crashing kernel
            history = dr.kernel_history()
            for kernel in history:
                if ('hash' in kernel) and ('66ac29079185f015' in kernel['hash']):
                    breakpoint()

        print('\n\n----------- step')
        opt.step()
        params.update(opt)

        print(f'{it_i=}: {loss=}')


    print('Optimization done!')


if __name__ == '__main__':
    test01_reproduce_segfault()

merlinND requested a review from wjakob on February 2, 2024, 11:12
@merlinND (Member, Author) commented on Feb 2, 2024

Follow-up questions:

  • Is the same issue present in the LLVM backend? --> Actually, the LLVM backend already had a fix.
  • Do we need to understand what caused the v pointer to be invalidated and take action based on that, or is it enough to redo the lookup? --> After discussion with @wjakob: virtual function calls should be the only case that causes new variables to be allocated or de-allocated.
  • Would it be worth adding checks so the lookup is only redone when there was an actual invalidation? --> As suggested by @wjakob, switched to a solution without any new lookup, plus an assert for debug mode (see the sketch below).

With help from @wjakob.

The LLVM backend already included a different fix; this commit applies the same (cheaper) fix to both backends.
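To make the discussion above concrete, here is a minimal, hypothetical C++ sketch of the bug pattern: a raw Variable pointer obtained from a growable table is dereferenced after assembling a virtual function call, which may create new variables and therefore reallocate the table. All names below (variables, lookup, assemble_vcall, ...) are invented for illustration; this is not the actual drjit-core code, and the real patch additionally guards its assumption with a debug-mode assert.

// Hypothetical illustration of the read-after-free pattern; all names are
// invented for this sketch and do not correspond to drjit-core's sources.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Variable {
    uint32_t reg;   // register index used when emitting code
};

// Growable variable table: appending to it may reallocate the storage,
// invalidating any raw `Variable *` obtained from it earlier.
std::vector<Variable> variables { Variable{ 42 } };

Variable *lookup(uint32_t index) { return &variables[index]; }

// Assembling a virtual function call may create new JIT variables,
// i.e. it may append to (and reallocate) `variables`.
void assemble_vcall() { variables.emplace_back(); }

// Buggy pattern: `v` is dereferenced after a call that may have
// reallocated the table, so the read can touch freed memory.
void assemble_buggy(uint32_t index) {
    Variable *v = lookup(index);
    assemble_vcall();
    printf("reg = %u\n", v->reg);   // potential read-after-free
}

// Fixed pattern (in the spirit of this PR): copy everything that is needed
// from the variable *before* the call, so no stale pointer is dereferenced
// and no second lookup is required afterwards.
void assemble_fixed(uint32_t index) {
    uint32_t reg = lookup(index)->reg;
    assemble_vcall();
    printf("reg = %u\n", reg);
}

int main() {
    assemble_fixed(0);   // safe
    assemble_buggy(0);   // AddressSanitizer should flag the dangling read here
    return 0;
}

Compiling this sketch with -fsanitize=address and running it should produce a heap-use-after-free report similar in spirit to the one that motivated this PR.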
@wjakob (Member) commented on Feb 2, 2024

LGTM! I will merge this -- Merlin, could I also ask you to bubble this up to a commit in Mitsuba so that we can see if the CI still passes?

wjakob merged commit a525d6a into master on Feb 2, 2024 (4 checks passed)
@merlinND (Member, Author) commented on Feb 2, 2024

Thank you! Doing it now.

Edit: mitsuba-renderer/mitsuba3#1057

@njroussel (Member) commented

> Is the same issue present in the LLVM backend? --> actually, the LLVM backend already had a fix.

My bad, I did not think about checking this for CUDA at the time we fixed it...
Thanks for the fix, Merlin!
