## Import & Setup

In [None]:
# from tqdm import trange
import numpy as np
import matplotlib.pyplot as plt

import drjit as dr
import mitsuba as mi

from libs import utils

# Set Mitsuba's log level to Info.
mi.set_log_level(mi.LogLevel.Info)
# Use matplotlib's 'ggplot' style sheet for every plot in this notebook.
plt.style.use('ggplot')

In [None]:
# Pick the Mitsuba variant by index; changing the trailing index switches
# between the listed variants (index 2 selects 'cuda_ad_acoustic').
mi_var = ['scalar_rgb', 'cuda_ad_rgb', 'cuda_ad_acoustic'][2]
mi.set_variant(mi_var)
print(f"Mitsuba variant set to '{mi_var}'")

## Batch Sensor

In [None]:
# Scene with a "batch" sensor that groups two microphones (micA, micB), which
# share a single "tape" film and a single low-discrepancy sampler. It is
# rendered with the "prb_acoustic" integrator. The scene contains no shapes or
# emitters; it only serves to inspect sensor/film behaviour.
scene = mi.load_dict({
    "type": "scene",
    "integrator": {
        "type": "prb_acoustic",
        "max_time": 1,
        "max_depth": 150,
    },
    "sensor": {
        "type": "batch",
        # First microphone, translated to (1, 0, 0).
        "micA": {
            "type": "microphone",
            "to_world": mi.ScalarTransform4f.translate([1, 0, 0]),
        },
        # Second microphone, translated to (0, 0, 1).
        "micB": {
            "type": "microphone",
            "to_world": mi.ScalarTransform4f.translate([0, 0, 1]),
        },
        # Film with 4 wavelength bins and 10 time bins, box reconstruction filter.
        "tape": {
            "type": "tape",
            "wav_bins":  4,
            "time_bins": 10,
            "rfilter": { "type": "box" },
            "count": True
        },
        "sampler": { "type": "ldsampler", "sample_count": 4 },
    }
})

In [None]:
# Manually set up primary-ray sampling for the batch sensor: every film column
# gets `spp` wavefront lanes, and all lanes are processed in a single wavefront.
sensor  = scene.sensors()[0]
film    = sensor.film()
sampler = sensor.sampler()

# One wavefront holds all samples of all film columns.
spp = sampler.sample_count()
sampler.set_samples_per_wavefront(spp)

film_size = film.crop_size()
wavefront_size = film_size.x * spp
sampler.seed(0, wavefront_size)

# Map each wavefront lane to its film column (lane index divided by spp).
# A right shift is used when spp is a power of two, an integer division
# otherwise; the shift/divisor is passed through dr.opaque(...).
log_spp = dr.log2i(spp)
idx = dr.arange(mi.UInt32, wavefront_size)
if (1 << log_spp) == spp:
    idx >>= dr.opaque(mi.UInt32, log_spp)
else:
    idx //= dr.opaque(mi.UInt32, spp)

# Integer position on the film (second coordinate fixed to 0), then scaled by
# the reciprocal of the crop size.
pos = mi.Vector2i(idx, 0 * idx)
scale = dr.rcp(mi.ScalarVector2f(film_size))
pos_adjusted = mi.Vector2f(pos) * scale

# Only the first element of the returned tuple is shown as the cell output
# (presumably the sampled rays -- confirm against the sensor API).
sensor.sample_ray_differential(
    time=0.0,
    sample1=mi.Float(idx) + 1.,
    sample2=pos_adjusted,
    sample3=sampler.next_2d(),
    active=True,
)[0]

In [None]:
# Let the integrator prepare the sampler and sample the primary rays itself
# (fixed seed 0, 4 samples per pixel, no reparameterization).
prb     = scene.integrator()
sensor  = scene.sensors()[0]
sampler, spp = prb.prepare(sensor, seed=0, spp=4)
# NOTE(review): sample_rays is unpacked into four values here; what `w`, `p`
# and `det` mean is not visible in this notebook -- confirm against the
# integrator. Also note that this overwrites the notebook-global `p`.
ray, w, p, det = prb.sample_rays(scene=scene, sensor=sensor, sampler=sampler, reparam=None)

## Basic scene (coordinates)

In [None]:
# Scene description used to check the coordinate conventions: a perspective
# camera at the origin looking along +z (up = +y) at a diffuse cube that is
# translated to z = 10, lit by a constant environment emitter.
cube_scene = {
    "type": "scene",
    "integrator": {
        "type": "path",
        "max_depth": 8,
        "hide_emitters": True,
    },
    "sensor": {
        "type": "perspective",
        "near_clip": 0.1,
        "far_clip": 100.,
        "to_world": mi.ScalarTransform4f.look_at(
            origin=[0, 0, 0],
            target=[0, 0, 1],
            up=[0, 1, 0]
        ),
        # 1024 x 768 HDR film with a Gaussian reconstruction filter.
        "film": {
            "type": "hdrfilm",
            "rfilter": {
                "type": "gaussian"
            },
            "width": 1024,
            "height": 768,
        },
        "sampler": {
            "type": "independent",
            "sample_count": 128,
        },
    },
    "emitter": {
        "type": "constant",
        "radiance": {
            "type": "spectrum",
            "value": 0.99,
        }
    },
    "cube": {
        "type": "cube",
        "to_world": mi.ScalarTransform4f.translate([0., 0., 10.]),
        "bsdf": {
            "type": "diffuse",
            # sRGB reflectance dominated by the blue channel.
            "reflectance": {
                "type": "srgb",
                "color": [.1, .1, .9],
            },
        },
    },
}

In [None]:
# Load the cube scene from the dictionary above, render it and show the result
# as a bitmap (the last expression is the cell output).
#scene = mi.load_file("mitsuba_debug.xml")
scene = mi.load_dict(cube_scene)
#params = mi.traverse(scene)
img = mi.render(scene)
mi.Bitmap(img)

## ImageBlock

### General

In [None]:
# Single-channel block of size [3, 2] without an explicit reconstruction filter.
imb = mi.ImageBlock(size=[3, 2], offset=[0, 0], channel_count=1, coalesce=False)

# Three target positions: x = 0, 1, 2 (first list), y = 0 (second list).
p = mi.Point2u([
    [0, 1, 2],
    [0, 0, 0]
])
x = mi.Float([1., 2., 3.])

# Write one value per position, then show channel 0 of the block's tensor.
imb.put(pos=p, values=[x])
np.array(imb.tensor())[:, :, 0]

In [None]:
# Two-channel block whose size is given as [time_bins, wav_bins] = [6, 3];
# the bare `imb` displays the block's representation as the cell output.
time_bins, wav_bins = 6, 3
imb = mi.ImageBlock([time_bins, wav_bins], [0, 0], 2, coalesce=False)
imb

In [None]:
# Write (x, 1) into the two-channel block and show channel 0.
# NOTE: `p` is not defined in this cell; it uses whatever value the most
# recently executed cell assigned to the notebook-global `p`.
x = mi.Float([1., 2., 3.])
imb.put(p, [x, mi.Float(1.)])
np.array(imb.tensor())[:, :, 0]

In [None]:
# `imb` is a [6, 4] block at offset [0, 0]; `imb2` is a [6, 1] block at
# offset [0, 2].
imb  = mi.ImageBlock([6, 4], [0, 0], 1, coalesce=False)
imb2 = mi.ImageBlock([6, 1], [0, 2], 1, coalesce=False)

# Positions x = 0, 1, 2 with y = 2, i.e. the same y as imb2's offset.
p = mi.Point2u([
    [0, 1, 2],
    [2, 2, 2]
])
x = mi.Float([1., 2., 3.])

# Write into the offset block and show its channel 0.
imb2.put(pos=p, values=[x], active=True)
np.array(imb2.tensor())[:, :, 0]

In [None]:
# Accumulate the offset block `imb2` into the larger block `imb` and show
# channel 0 of the result.
imb.put_block(imb2)
np.array(imb.tensor())[:, :, 0]

In [None]:
# Tape film with 2 wavelength bins and 4 time bins (box filter); two blocks are
# created from it to compare per-sample puts against one put of the full grid.
film = mi.load_dict({
    "type": "tape",
    "wav_bins": 2,
    "time_bins": 4,
    "rfilter": { "type": "box" },
    "count": True
})
film.prepare([])

imb  = film.create_block()
imb2 = film.create_block()

# Four samples, all at x = 1; y = 1 is hit twice (second and fourth sample).
p = mi.Point2u([
    [1, 1, 1, 1], # wav_bins  = x
    [0, 1, 2, 1]  # time_bins = y
])
# Channel 0 receives the sample values, channel 1 receives a constant 1.
imb.put(pos=p, values=mi.Vector2f(mi.Float([1., 2., 3., -1.]), mi.Float(1.)), active=True)
display(imb.tensor()[:, :, 1].numpy())

# Flat per-channel contents of the first block.
x = imb.tensor()[:, :, 0].array
y = imb.tensor()[:, :, 1].array

# Write those contents into the second block at every position of the film's
# crop window; both blocks must then hold identical tensors.
p = mi.Point2u(dr.meshgrid(
    dr.arange(mi.UInt32, film.crop_size().x),
    dr.arange(mi.UInt32, film.crop_size().y),
))
imb2.put(pos=p, values=mi.Vector2f(x, y), active=True)

assert dr.all(dr.eq(imb.tensor(), imb2.tensor()))

In [None]:
# Read the block back at three positions (x = 1, y = 0..2); the returned
# values are the cell output.
p = mi.Point2u([
    [1, 1, 1], # wav_bins  = x
    [0, 1, 2]  # time_bins = y
])
imb.read(pos=p)

In [None]:
# 2 x 2 HDR film with a box filter, used by the weighting experiment below.
film = mi.load_dict({
    "type": "hdrfilm",
    "width": 2,
    "height": 2,
    "rfilter": { "type": "box" },
})
film.prepare([])

# Eight random RGB samples; `p` visits each of the four pixels twice
# (entries i and i + 4 share the same position). `det` holds one random
# weight per sample. The random values are unseeded, so they differ per run.
rgb = mi.Spectrum([np.random.rand(8), np.random.rand(8), np.random.rand(8)])
p   = mi.Point2u([0, 0, 1, 1, 0, 0, 1, 1], [0, 1, 0, 1, 0, 1, 0, 1])
det = mi.Float(np.random.rand(8))

### Weighting

In [None]:
# Compare two ways of accumulating the same samples into the film:
#   block  - plain values with the default weight/alpha,
#   block2 - values pre-multiplied by `det` and weighted by `det`.
block = film.create_block()
block2 = film.create_block()

block.put(
    pos=p,
    wavelengths=mi.Color0f(),
    value=rgb,
    # weight=1.0,
    # alpha=1.0
)

# NOTE(review): this call passes mi.Float() as wavelengths while the call above
# passes mi.Color0f() -- confirm that both are accepted by the active variant.
block2.put(
    pos=p,
    wavelengths=mi.Float(),
    value=rgb * det,
    weight=det,
    alpha=1.0
)

# Develop each block separately; the film is cleared in between.
film.clear()
film.put_block(block)
img = film.develop()

film.clear()
film.put_block(block2)
img2 = film.develop()

# Reference values: entries i and i + 4 of `p` address the same pixel, so the
# per-pixel sums pair element i with element 4 + i.
rgb_det     = rgb * det
idx         = dr.arange(mi.UInt32, 4)
rgb_sum     = dr.gather(mi.Spectrum, rgb,     idx) + dr.gather(mi.Spectrum, rgb,     4 + idx)
rgb_det_sum = dr.gather(mi.Spectrum, rgb_det, idx) + dr.gather(mi.Spectrum, rgb_det, 4 + idx)
det_sum     = dr.gather(mi.Float,    det,     idx) + dr.gather(mi.Float,    det,     4 + idx)

# Each developed image is shown next to its reference: the mean of the two
# samples for `img`, the det-weighted mean for `img2`.
# NOTE(review): `p` enumerates the pixels with x varying slowest, whereas the
# developed image is presumably stored row by row; entries 1 and 2 may
# therefore appear swapped between a display and its reference -- confirm.
display(dr.unravel(mi.Spectrum, img.array))
display(rgb_sum / 2.)

display(dr.unravel(mi.Spectrum, img2.array))
display(rgb_det_sum / det_sum)

### Reconstruction Filter

In [None]:
# Project helper from libs.utils. NOTE(review): presumably toggles Dr.Jit
# optimizations; the meaning of the `False` argument is not visible in this
# notebook -- confirm against the helper's definition.
utils.drjit_turn_off_optimizations(False)

In [None]:
# Tape film that only serves as the source of a Gaussian reconstruction filter
# (stddev 1.0); swap in the commented "tent" line to try a different filter.
film = mi.load_dict({
    "type": "tape",
    "wav_bins":  21,
    "time_bins": 21,
    "rfilter": { "type": "gaussian", "stddev": 1.0 },
    # "rfilter": { "type": "tent" },
    "count": True
})

film.prepare([])

# Stand-alone 11 x 11 two-channel block that uses the film's filter. It is
# built with y_only=True; judging by the name, the filter then acts along y
# only (TODO confirm against the ImageBlock implementation).
imb = mi.ImageBlock(size=[11, 11], offset=[0, 0], channel_count=2, coalesce=False, rfilter=film.rfilter(), y_only=True)
# imb = film.create_block()
# imb

In [None]:
# Splat a single value of 100 at the fractional position (5.2, 5.2) and plot
# channel 0 of the block to see how the filter spreads it.
p = mi.Point2f(5.2, 5.2)
x = mi.Float([100])

imb.put(pos=p, values=[x, 0])
img = np.array(imb.tensor())[:, :, 0]

plt.imshow(img, interpolation='none', cmap="inferno")
plt.axis("off")
plt.show()

In [None]:
# Build a block from a tensor that holds a single impulse of 100 at index
# [5, 5], then read it back through the reconstruction filter at every integer
# position.
X = np.zeros((21, 11, 1))
X[5, 5] = 100
imb = mi.ImageBlock(mi.TensorXf(X), rfilter=film.rfilter(), y_only=True)

img = np.zeros_like(X)
# Visit all (row, column) index pairs in row-major order.
for i, j in np.ndindex(*X.shape[:2]):
    # The read position is passed as (j, i), i.e. the column index comes first.
    p = mi.Point2f(j, i)
    img[i, j] = imb.read(pos=p)[0].numpy()

# Left: raw block contents. Right: values obtained by the filtered read.
fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.imshow(imb.tensor().numpy()[:, :, 0], interpolation='none', cmap="inferno")
ax1.axis("off")

ax2.imshow(img[:, :, 0], interpolation='none', cmap="inferno")
ax2.axis("off")

plt.show()

## Shapes

In [None]:
# Load a default cube shape and look up the position of every vertex.
cube = mi.load_dict({ "type": "cube" })
# Vertex lookups take integer indices, so the index range is built as UInt32
# instead of relying on an implicit float -> integer conversion.
cube.vertex_position(dr.arange(mi.UInt32, cube.vertex_count()))

In [None]:
# Look up the normal of every vertex. The index range is built as UInt32
# (vertex lookups take integer indices) and its length comes from the mesh
# itself rather than a hard-coded vertex count.
cube.vertex_normal(dr.arange(mi.UInt32, cube.vertex_count()))

## Warp

In [None]:
# Stratified sampler (sample_count = 2**10), seeded with seed 0 and a
# wavefront size of 2**14; `S` holds one 2D sample per wavefront lane.
sampler = mi.load_dict({ "type": "stratified", "sample_count": 2 ** 10 })
sampler.seed(0, 2**14)
S = sampler.next_2d()

In [None]:
# Warp the 2D samples onto a cone (second argument 0.7 is passed to
# square_to_uniform_cone as its cone parameter).
X = mi.warp.square_to_uniform_cone(S, 0.7)

# Optional rotation about the y axis; an angle of 0 leaves the points unchanged.
X = mi.Transform4f.rotate(axis=[0, 1, 0], angle=0) @ X

X = X.numpy()

# 3D scatter plot of the warped points, coloured by their first coordinate.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], s=5.0, c=X[:, 0])
ax.set_xlim(-1., 1.)
ax.set_ylim(-1., 1.)
ax.set_zlim(-1., 1.)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
# plt.show() matches the other plotting cells of this notebook; Figure.show()
# emits a warning when the active backend is not a GUI backend.
plt.show()

## Dr.Jit Differentiation

In [None]:
from drjit.cuda.ad import Array3f, Float

### Forward vs Backward

Differentiation of the vector norm $||\textbf{x}||$ using the chain rule
\begin{align}
    \frac{\partial}{\partial\mathbf{x}} ||\textbf{x}|| &= \frac{\partial}{\partial\mathbf{x}} \sqrt{x_1^2 + x_2^2 + \ldots + x_k^2} \\
    &= \frac{1}{2\sqrt{x_1^2 + x_2^2 + \ldots + x_k^2}} \cdot \frac{\partial}{\partial\mathbf{x}} \left(x_1^2 + x_2^2 + \ldots + x_k^2\right) \\
    &= \frac{1}{2||\textbf{x}||} \cdot
        \begin{pmatrix} \frac{\partial}{\partial x_1} x_1^2 \\ \frac{\partial}{\partial x_2} x_2^2 \\ \vdots \\ \frac{\partial}{\partial x_k} x_k^2 \end{pmatrix} \\
    &= \frac{1}{2||\textbf{x}||} \cdot \begin{pmatrix} 2 \cdot x_1 \\ 2 \cdot x_2 \\ \vdots \\ 2 \cdot x_k \end{pmatrix} = \frac{\textbf{x}}{||\textbf{x}||}.
\end{align}
Using the function names of Dr.Jit we obtain a more abstract "graph" of the calculation
\begin{align}
    \frac{\partial}{\partial\mathbf{x}} \text{sqrt}(\text{sum}(\text{sqr}(\mathbf{x}))) &= \text{sqrt}'(\text{sum}(\text{sqr}(\mathbf{x})))\cdot \frac{\partial}{\partial\mathbf{x}} \text{sum}(\text{sqr}(\mathbf{x})) \\
    &= \text{sqrt}'(\text{sum}(\text{sqr}(\mathbf{x})))\cdot \text{sum}'(\text{sqr}(\mathbf{x})) \cdot \frac{\partial}{\partial\mathbf{x}} \text{sqr}(\mathbf{x}) \\
    &= \text{sqrt}'(\text{sum}(\text{sqr}(\mathbf{x})))\cdot \text{sum}'(\text{sqr}(\mathbf{x})) \cdot \text{sqr}'(\textbf{x}) \cdot \frac{\partial}{\partial\mathbf{x}} \mathbf{x} \\
    &= \text{sqrt}'(\text{sum}(\text{sqr}(\mathbf{x})))\cdot \text{sum}'(\text{sqr}(\mathbf{x})) \cdot \text{sqr}'(\textbf{x}) \cdot 1.
\end{align}

In [None]:
def run_norm():
    """Build the AD graph of the Euclidean norm for two fixed 3D vectors.

    The input holds the vectors (2, 1, 1) and (3, 4, 2) and has gradient
    tracking enabled. The norm is computed step by step so that every
    intermediate node can be inspected afterwards.

    Returns:
        A tuple ``(X, X_quad, X_sum, norm)``: the input, its element-wise
        square, the sum of the squares, and the square root of that sum.
    """
    vectors = Array3f(np.array([[2, 1, 1], [3, 4, 2.]]))
    dr.enable_grad(vectors)

    squared = dr.sqr(vectors)
    squared_sum = dr.sum(squared)
    return vectors, squared, squared_sum, dr.sqrt(squared_sum)

# Show only the fourth returned value, i.e. the norm.
run_norm()[3]

`forward_to(value)` searches for variables considered as inputs and propagates their gradients ("changes", set by `dr.set_grad(...)`) towards the given `value`. This can be interpreted as "how does `value` change when the input (with its gradient set) changes". Given a standard basis vector as the input gradient, we obtain the partial derivative of the output with respect to the corresponding input dimension. In essence, it calculates (from left to right, in exactly that order)
\begin{equation}
    \text{dr.grad(norm)} =
    \underbrace{
        \underbrace{ \text{grad}_\textbf{x} \cdot \text{sqr}'(\textbf{x}) }_\text{dr.grad(X\_quad)}
        \cdot \text{sum}'(\text{sqr}(\mathbf{x}))
    }_\text{dr.grad(X\_sum)} \cdot \text{sqrt}'(\text{sum}(\text{sqr}(\mathbf{x}))).
\end{equation}

The flag `dr.ADFlag.ClearNone` just turns off optimizations, so that intermediate values and gradients can be printed.

In [None]:
# Forward mode towards `norm`: seed the input gradient with the basis vector
# (0, 0, 1), so the result is the partial derivative with respect to the third
# component. ClearNone is passed so that the gradients of intermediate nodes
# can still be read afterwards.
X, X_quad, X_sum, norm = run_norm()
dr.set_grad(X, Array3f(0, 0, 1))
dr.forward_to(norm, flags=dr.ADFlag.ClearNone)
dr.grad(X_sum), dr.grad(norm)

`dr.forward_from(X)` (equivalently `dr.forward(X)`) propagates a gradient of ones from the given input through the whole AD graph. It yields the same result as `dr.sum(dr.backward_from(value))`, i.e. the sum of all partial derivatives (compare to `forward_to(...)`, but with all dimensions of the input gradient set to one).

In [None]:
# Forward mode starting from `X` without setting a gradient explicitly; the
# resulting gradient of `norm` is the cell output.
X, X_quad, X_sum, norm = run_norm()
dr.forward_from(X, flags=dr.ADFlag.ClearNone)
dr.grad(norm)

`dr.backward_from(value)`/`dr.backward(value)` calculates the gradients of the inputs in the traditional way of backpropagation (from left to right)
\begin{equation}
    \frac{\partial}{\partial\mathbf{x}} \text{sqrt}(\text{sum}(\text{sqr}(\mathbf{x}))) = 
    \underbrace {
        \underbrace{ \text{sqrt}'(\text{sum}(\text{sqr}(\mathbf{x}))) }_\text{d\_sqrt} \cdot \text{sum}'(\text{sqr}(\mathbf{x}))
    }_\text{d\_sum} \cdot \text{sqr}'(\textbf{x}).
\end{equation}

In [None]:
# Compare a hand-written backward pass with Dr.Jit's backward mode.
X, X_quad, X_sum, norm = run_norm()

# manual backward pass
# Gradient tracking is suspended so these operations are not recorded.
with dr.suspend_grad():
    # sqrt'(s) = 1 / (2 * sqrt(s)) evaluated at s = X_sum
    d_sqrt = dr.rcp(Float(2) * dr.sqrt(X_sum))
    # sum' contributes a factor of one per component
    d_sum  = Array3f(1) * d_sqrt
    # sqr'(x) = 2 * x; the product equals x / ||x||, the gradient w.r.t. X
    d_norm = Float(2) * X * d_sum
    print(d_norm)

# Dr.Jit's result; expected to match the printed manual gradient.
dr.backward_from(norm, flags=dr.ADFlag.ClearNone)
dr.grad(X)

`dr.backward_to(input)` backpropagates set gradients of a given variable through the AD graph.

In [None]:
# Backward mode with an explicitly set output gradient: one entry of 1 per
# input vector is assigned to `norm` and propagated back to `X`.
X, X_quad, X_sum, norm = run_norm()
dr.set_grad(norm, Float([1., 1.]))
dr.backward_to(X, flags=dr.ADFlag.ClearNone)
dr.grad(X)

### Chain Rule by two Dr.Jit backward passes

In [None]:
# Three random (unseeded) input values shared by the chain-rule cells below.
x = Float(np.random.rand(3))

In [None]:
# Reference: differentiate loss = cos(p)^2 with respect to p in one backward pass.
p = Float(x)
dr.enable_grad(p)
img  = dr.cos(p)
loss = dr.sqr(img)
dr.backward(loss)
dr.grad(p)

In [None]:
# Same derivative, split into two backward passes.
p = Float(x)
img = dr.cos(p)

# Pass 1: treat `img` as the differentiable input and obtain d(loss)/d(img).
dr.enable_grad(img)
loss = dr.sqr(img)
dr.backward(loss)
d_img = dr.grad(img)

# Pass 2: recompute `img` with gradients enabled on `p` and backpropagate
# d_img * img, which applies the chain rule d(loss)/d(img) * d(img)/d(p).
# The result is expected to match the single-pass reference cell.
dr.enable_grad(p)
img = dr.cos(p)
dr.backward(d_img * img)
dr.grad(p)

In [None]:
# Two-pass variant that seeds the second pass via set_grad/enqueue/traverse
# instead of multiplying d_img into the output.
p = Float(x)
img = dr.cos(p)

# Pass 1: obtain d(loss)/d(img) with `img` as the differentiable input.
dr.enable_grad(img)
loss = dr.sqr(img)
dr.backward(loss)
d_img = dr.grad(img)

# Pass 2: recompute `img` with gradients enabled on `p`.
dr.enable_grad(p)
img = dr.cos(p)

# Assign d_img as the gradient of `img`, enqueue `img` for backward-mode
# traversal and run the traversal explicitly.
dr.set_grad(img, d_img)
dr.enqueue(dr.ADMode.Backward, img)
dr.traverse(Float, dr.ADMode.Backward)

dr.grad(p)

## Dr.Jit Loop

In [None]:
# Construct a recorded loop whose state consists of a spectrum and the tensor
# of an ImageBlock. No loop body is executed here; the cell only checks that
# the loop can be created with this state.
u = mi.Spectrum([0., 3., 4., 5., 8., 9.])
x = mi.TensorXf(np.random.rand(10, 3, 2))
imb = mi.ImageBlock(size=[3, 2], offset=[0, 0], channel_count=2, coalesce=False)

# The state is supplied as a lambda returning the loop variables.
loop = mi.Loop(
    name="Test",
    state=lambda: (u, imb.tensor())
)

In [None]:
# Show the block's tensor after the loop object has been constructed.
imb.tensor()

## AcousticBSDF AD

In [None]:
# Acoustic BSDF with constant scattering (0.5) and absorption (0.9); the
# commented lines are an alternative with two (wavelength, value) pairs each.
bsdf = mi.load_dict({
    "type": "acousticbsdf",
    # "scattering": { "type": "spectrum", "value": [(1, 0.2), (2, 0.3)] },
    # "absorption": { "type": "spectrum", "value": [(1, 0.2), (2, 0.3)] },
    "scattering": { "type": "spectrum", "value": 0.5 },
    "absorption": { "type": "spectrum", "value": 0.9 },
})

# Parameter key of the absorption value; the commented key belongs to the
# commented spectrum definition above.
# key = "absorption.values"
key = "absorption.value"
params = mi.traverse(bsdf)
display(params)
display(params[key])

In [None]:
# Evaluate the BSDF for a hand-built surface interaction and differentiate the
# value with respect to the absorption parameter.
ctx = mi.BSDFContext()

# Interaction at the origin with normal and incident direction both along +z.
si    = mi.SurfaceInteraction3f()
si.p  = [0, 0, 0]
si.n  = [0, 0, 1]
si.wi = [0, 0, 1]
si.sh_frame = mi.Frame3f(si.n)
si.wavelengths = mi.Spectrum(1.0)

# Outgoing direction along +z; the commented lines build an angled direction instead.
# theta = 19. / 19.0 * (dr.pi / 2)
# wo = mi.Vector3f([dr.sin(theta), 0, dr.cos(theta)])
wo = mi.Vector3f([0, 0, 1])

# Track gradients of the absorption parameter, evaluate, backpropagate from
# the BSDF value and show the value together with the parameter gradient.
dr.enable_grad(params[key])
val, pdf = bsdf.eval_pdf(ctx, si, wo)
dr.backward_from(val)
val, dr.grad(params[key])