## Acoustic Integration - Automatic Differentiation

http://localhost:8888/?token=sloth

In [None]:
# One-time session setup. After the first execution `mi` exists in the
# namespace, so re-running this cell skips the whole body: the Mitsuba
# variant is not set again and the session seeds are not redrawn.
if "mi" not in vars():
    import os

    from tqdm import trange

    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    import drjit as dr
    import mitsuba as mi

    from libs import utils, acoustic_torch

    # Mitsuba variant and log level used for the rest of the notebook.
    mi.set_variant('cuda_ad_acoustic')
    mi.set_log_level(mi.LogLevel.Warn)

    plt.style.use('ggplot')
    # NOTE(review): project helper; what passing False does is not visible
    # here -- confirm in libs/utils.
    utils.drjit_turn_off_optimizations(False)

    # Two independently drawn integers in [0, 2**30). Later cells pass
    # `sess_seed` as `seed` and `sess_seed_g` as `seed_grad` to mi.render.
    sess_seed   = np.random.randint(0, 2**30)
    sess_seed_g = np.random.randint(0, 2**30)
    print(f"session seeds are: sess_seed={sess_seed}; sess_seed_g={sess_seed_g}")

### Scene Construction

In [None]:
# One absorption value per wavelength bin ("wav_bins" below is derived from
# the length of this list) and a single scattering value.
absorption = [0.65, 0.65]
scattering = 0.05

# Keyword arguments for utils.shoebox_scene(); also forwarded to
# utils.plot_hist() by later cells.
config = {
    # geometry
    "box_dim":        [20.0, 20.0, 20.0],
    "mic_pos":        [9.0, 9.0, 9.0],
    "speaker_pos":    [11.0, 11.0, 11.0],
    "speaker_radius": 1.0,

    # materials: absorption is given as (1-based bin index, value) pairs
    "absorption": list(enumerate(absorption, start=1)),
    "scattering": scattering,

    # histogram layout
    "wav_bins":  len(absorption),
    "time_bins": 150,
    "max_time":  1.5,

    # integrator (alternative: "prb_acoustic")
    "integrator": "prb_reparam_acoustic",
    "max_depth":  50,
    "spp":        2**18,
}

# Bins per unit of time, and the start time of every time bin.
fs = config["time_bins"] / config["max_time"]
time = np.linspace(0.0, config["max_time"], num=config["time_bins"], endpoint=False)

# Alternative: derive the depth from the geometry instead of fixing it above.
# config["max_depth"] = utils.estimate_max_depth(config["box_dim"], config["max_time"], 1.5)
print(f"max_depth = {config['max_depth']}")

scene_dict = utils.shoebox_scene(**config)
# Alternative scene builder:
# scene_dict = utils.shoebox_scene_visual(**config, resf=4)

In [None]:
# Integrator options that are not covered by `config`.
integrator_cfg = scene_dict["integrator"]
integrator_cfg["skip_direct"] = False
integrator_cfg["reparam_max_depth"] = 8

# Optional Gaussian reconstruction filter on the film (disabled):
# scene_dict["sensor"]["film"]["rfilter"] = {
#     "type": "gaussian",
#     "stddev": 0.25 * 343. * config["max_time"] / config["time_bins"],
# }

# Separate BSDF whose absorption is optimised later. It is placed in front
# of the existing entries so the faces below can reference it by id.
opt_bsdf_dict = {
    "type": "acousticbsdf",
    "scattering": {"type": "spectrum", "value": scattering},
    "absorption": {
        "type": "spectrum",
        "value": list(enumerate([0.65, 0.65], start=1)),
    },
}
scene_dict = {"opt_bsdf": opt_bsdf_dict, **scene_dict}

# Previously used subset: ['back', 'right', 'top', 'bottom']
faces = np.array(['back', 'front', 'left', 'right', 'top', 'bottom'])

# Only the first two faces (back and front) use the optimised BSDF.
for face in faces[:2]:
    scene_dict["shoebox"][face]["bsdf"] = {"type": "ref", "id": "opt_bsdf"}

scene = mi.load_dict(scene_dict)

### Reference Histogram

In [None]:
# Render the scene once with the session seed and plot the
# hist_ref[:, :, 0] slice of the result.
hist_ref = mi.render(scene, seed=sess_seed)
utils.plot_hist(hist_ref[:, :, 0], **config)
# utils.plot_img(hist_ref)

### Optimization Setup

In [None]:
# Scene-parameter key of the absorption values of the optimised BSDF.
key_a = "opt_bsdf.absorption.values"

params = mi.traverse(scene)

# Vertex positions of every shoebox face at load time, unravelled into
# Point3f arrays. apply_transform() scales these references rather than the
# current scene state.
face_keys = [f"shoebox.{face}.vertex_positions" for face in faces]
vertex_pos_ref = {
    face_key: dr.unravel(mi.Point3f, params[face_key])
    for face_key in face_keys
}

# display(params)
display(params[key_a])
# display(vertex_pos_ref)

In [None]:
# Adam optimizer (lr=0.005) holding the two optimised variables:
#   opt[key_a] -- two absorption values, both starting at 0.65
#   opt['s']   -- a 3-component scale vector starting at (1, 1, 1); its
#                 components are turned into face transforms by
#                 apply_transform()
opt = mi.ad.Adam(lr=0.005)
opt[key_a] = mi.Float([0.65, 0.65])
opt['s']   = mi.Vector3f(1.0, 1.0, 1.0)
# Alternative random start near 1 (NOTE(review): refers to `opt_s`, which is
# not defined in this notebook):
# opt_s['s'] = mi.Vector3f((np.random.rand(3) - 0.5) * 0.1 + 1.0)

def apply_transform(params_to_update):
    """Clamp the optimised variables and write them into the scene parameters.

    Reads the notebook-level ``opt``, ``key_a`` and ``vertex_pos_ref``.

    Steps:
      1. ``opt[key_a]`` is clamped to [0.2, 0.9] and ``opt['s']`` to [0.5, 2].
      2. For each pair of opposite faces, a ``mi.Transform4f.scale`` built
         from one component of ``opt['s']`` is applied to the reference
         vertex positions and the result is written to the face's
         ``vertex_positions`` entry (z -> back/front, x -> left/right,
         y -> bottom/top).
      3. The absorption values are written to ``key_a`` and
         ``params_to_update.update()`` is called.

    Args:
        params_to_update: scene parameter container (as returned by
            ``mi.traverse``) that is modified in place.
    """
    opt[key_a] = dr.clamp(opt[key_a], 0.2, 0.9)
    opt['s'] = dr.clamp(opt['s'], 0.5, 2.)

    scale = opt['s']
    # (scale component, faces driven by it), in the order they are written.
    face_groups = (
        (scale.z, ("back", "front")),
        (scale.x, ("left", "right")),
        (scale.y, ("bottom", "top")),
    )
    for factor, face_names in face_groups:
        transf = mi.Transform4f.scale(factor)
        for face_name in face_names:
            key = f"shoebox.{face_name}.vertex_positions"
            params_to_update[key] = dr.ravel(transf @ vertex_pos_ref[key])

    params_to_update[key_a] = opt[key_a]
    params_to_update.update()

In [None]:
@dr.wrap_ad(source='drjit', target='torch')
def norm(hist):
    """Divide ``hist`` by the axis-0 sum of its ``[:, :, 1]`` slice.

    The sum has one entry per index of the second axis and is broadcast
    over the first and last axes. The body uses torch operations; the
    ``dr.wrap_ad`` decorator bridges it to Dr.Jit callers.
    """
    totals = torch.sum(hist[:, :, 1], dim=0)
    return hist / totals[None, :, None]

def loss(hist, hist_ref=None):
    """Squared difference between the first two entries of ``acoustic_torch.T``.

    ``acoustic_torch.EDC`` is applied to ``hist[:, :, 0]`` (with ``db=True``
    and ``norm=True``), and ``acoustic_torch.T`` is evaluated on the notebook's
    ``time`` axis and that result. The return value is
    ``(t[0] - t[1]) ** 2``.

    Args:
        hist: rendered histogram; only the ``[:, :, 0]`` slice is used.
        hist_ref: must be left as ``None``; this loss uses no reference.
    """
    assert hist_ref is None
    # Alternative measure that works on the histogram directly:
    # t = acoustic_torch.TS(mi.TensorXf(time), hist[:, :, 0])
    decay = acoustic_torch.EDC(hist[:, :, 0], db=True, norm=True)
    t_vals = acoustic_torch.T(mi.TensorXf(time), decay)
    difference = t_vals[0] - t_vals[1]
    # Alternative objective: dr.sum(t_vals)
    return dr.sqr(difference)

In [None]:
# Visualise the gradient of the loss with respect to the rendered histogram
# itself (not the scene parameters).
# s = 0.05
# params["acoustic_bsdf.scattering.value"] = s
# params["opt_bsdf.scattering.value"]      = s
opt[key_a] = mi.Float([0.2, 0.9])
apply_transform(params)
hist = mi.render(scene, seed=sess_seed)
# Gradient tracking is enabled on the render result, after rendering, so
# backpropagation stops at `hist`.
dr.enable_grad(hist)

l = loss(hist)
dr.backward(l, flags=dr.ADFlag.ClearNone)

# Plot the [:, :, 0] slice of d(loss)/d(hist).
grad = dr.grad(hist)
utils.plot_hist(grad[:, :, 0], **config)

# Disabled: normalised plots of hist * grad and of grad for index 1 of the
# second axis.
# X = hist[:, 1, 0].numpy() * grad[:, 1, 0].numpy()
# plt.plot(time, X / np.max(np.abs(X)))
# G = grad[:, 1, 0].numpy()
# plt.plot(time, G / np.max(np.abs(G)))
# plt.xlim(-0.05, 0.4)
# plt.ylim(-1.10, 1.1)
# plt.show()

In [None]:
# Sweep of the second absorption value over [0.6, 0.9], recording the negated
# loss (column 0) and its gradient w.r.t. that value (column 1).
n = 7
d = np.zeros((n, 2))
a = np.linspace(0.6, 0.9, n)

for i in trange(n):
    # NOTE: the sweep is currently disabled -- this `break` exits on the
    # first iteration, so nothing below runs and `d` stays all zeros.
    break
    opt[key_a] = mi.Float([0.2, a[i]])
    apply_transform(params)
    hist = mi.render(scene, params, seed=sess_seed, seed_grad=sess_seed_g)
    l    = -1. * loss(hist)
    dr.backward(l)
    d[i, 0] = l.numpy()
    d[i, 1] = dr.grad(opt[key_a])[1]

# Gradient column against the swept absorption value.
plt.plot(a, d[:, 1])
plt.show()

### Main Loop

In [None]:
vals, grads, losses = [], [], []

%matplotlib ipympl
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 3))

ax1.set_title("values")
ax1.set_xlim(-1, 51)
ax1.set_ylim(-0.1, 2.1)

ax2.set_title("gradients")
ax2.set_xlim(-1, 51)
ax2.set_ylim(-0.1, 1.1);

ax3.set_title("loss")
ax3.set_xlim(-1, 51)
ax3.set_ylim(-0.1, 1.1);

In [None]:
# Main optimisation loop. The loss is negated before backpropagation, so the
# optimizer (which descends on `l`) drives loss(hist) upwards.
iters  = 50
if iters > 1:
    # x-extent of the live plots: history recorded so far plus this run.
    n  = len(vals) + iters

for i in trange(iters):
    # Clamp the optimised variables and push them into the scene parameters.
    apply_transform(params)
    # The primal and the gradient seed both advance with the iteration index.
    hist = mi.render(scene, params, seed=sess_seed+i, seed_grad=sess_seed_g+i)
    l    = -1. * loss(hist)
    dr.backward(l, flags=dr.ADFlag.ClearNone)

    # A NaN in the scale gradient invalidates this iteration: zero both
    # gradients, print the iteration index and skip recording and stepping.
    if dr.any(dr.isnan(dr.grad(opt['s'])))[0]:
        dr.set_grad(opt['s'], 0.)
        dr.set_grad(opt[key_a], 0.)
        print(i)
        continue

    if iters < 2:
        # Single-iteration mode: only show the scale gradient, no step.
        display(dr.grad(opt['s']))
    else:
        # One row per iteration: three scale components followed by the
        # absorption values (same layout for values and gradients).
        vals.append(np.append(opt['s'].numpy()[0], opt[key_a].numpy()))
        grads.append(np.append(dr.grad(opt['s']).numpy()[0], dr.grad(opt[key_a]).numpy()))
        losses.append(l.numpy())

        opt.step()

        # Redraw the three panels from the full history.
        ax1.clear()
        ax1.set_title("values")
        ax1.set_xlim(-n * 0.02, n * 1.02)
        ax1.set_ylim(-0.1, 2.1)
        ax1.plot(vals)
        # Clamp bounds used by apply_transform: dotted for the absorption
        # values, dash-dotted for the scale.
        ax1.hlines([0.2, 0.9], 0, n, colors='k', linestyles='dotted', linewidth=1.0)
        ax1.hlines([0.5, 2.0], 0, n, colors='k', linestyles='dashdot', linewidth=1.0)

        ax2.clear()
        ax2.set_title("gradients")
        ax2.set_xlim(-n * 0.02, n * 1.02)
        # Only the scale gradients (all columns but the last two) are shown.
        ax2.plot(np.array(grads)[:, :-2], label=["x", "y", "z"])
        # ax2.plot(np.array(grads)[:, -2:], label=["a1", "a2"])
        # ax2.legend()

        ax3.clear()
        ax3.set_title("loss")
        ax3.set_xlim(-n * 0.02, n * 1.02)
        # Draw on ax3 explicitly. plt.plot() targets pyplot's *current* axes,
        # which is no longer this panel once any other figure has been
        # created or closed in the session.
        ax3.plot(losses)

        fig.canvas.draw()

In [None]:
# Stack the recorded values and gradients along a new leading axis
# (V[0] = values, V[1] = gradients). Saving to `fname` is currently disabled.
if iters > 1:
    V = np.stack([np.array(vals), np.array(grads)])
    fname = "../data/ism-diff-exp/ism-diff-exp-n3-02.npy"
    # np.save(fname, V)

### Loss per Optimization Step

In [None]:
# Disabled (`if False`): re-evaluates the loss at every recorded optimisation
# step with the fixed session seed and saves it next to `fname`. Relies on
# `V` and `fname` from the previous cell (or on the commented load below).
if False:
    # fname = "../data/ism-diff-exp/ism-diff-exp-n3-02.npy"
    # V = np.load(fname)
    losses = []

    for i in trange(V.shape[1]):
        # Each row of V[0] holds the scale components followed by the last
        # two entries, the absorption values.
        opt[key_a] = mi.Float(V[0, i, -2:])
        opt['s']   = mi.Vector3f(V[0, i, :-2])
        apply_transform(params)

        img = mi.render(scene, seed=sess_seed)
        losses.append(loss(img))

    np.save(fname.replace('.npy', '-losses.npy'), np.array(losses)[:, 0])

    f, ax = plt.subplots(1, 1)
    ax.plot(np.array(losses)[:, 0])
    plt.show()

### Manual Parameter Search

In [None]:
# Brute-force scan of the loss over the z component of the scale, with the
# x and y components held at 1.
n = 31
values = np.linspace(0.5, 2., n)
res = np.zeros(n)

print(values)

for i, z_scale in zip(trange(n), values):
    # Other configurations that were scanned:
    # opt[key_a] = mi.Float([0.2, 0.9])
    # opt['s']   = mi.Vector3f(1.3, 1.3, z_scale)
    # opt['s']   = mi.Vector3f(z_scale, z_scale, 1.0)
    opt['s'] = mi.Vector3f(1.0, 1.0, z_scale)
    apply_transform(params)

    img = mi.render(scene, seed=sess_seed)

    # Alternative measure (absolute instead of squared difference):
    # edc = acoustic_torch.EDC(img[:, :, 0], db=True, norm=True)
    # t   = acoustic_torch.T(mi.TensorXf(time), edc)
    # res[i] = dr.abs(t[1] - t[0]).numpy()
    res[i] = loss(img).numpy()

In [None]:
# Close the current figure, then plot the scanned loss `res` against the
# scanned scale `values` from the previous cell.
plt.close()
plt.plot(values, res)
plt.show()

### Scattering Dependence

In [None]:
# Load stored optimisation runs. Each data file is expected to have a sibling
# "<name>-losses.npy"; files whose name contains "losses" or "adj" are not
# treated as data files.
path = "../data/ism-diff-exp"
files = os.listdir(path)

# V[first index, second index] holds one stored run, L the matching losses.
V = np.zeros((6, 2, 2, 250, 5))
L = np.zeros((6, 2, 250))
for entry in files:
    if any(tag in entry for tag in ("losses", "adj")):
        continue

    # The name without its first 14 and last 4 characters is split on '-';
    # the first and last fields are 1-based indices into V and L.
    fields = entry[14:-4].split('-')
    row = int(fields[0]) - 1
    col = int(fields[-1]) - 1

    data_path = os.path.join(path, entry)
    V[row, col] = np.load(data_path)
    L[row, col] = np.load(data_path.replace('.npy', '-losses.npy'))

In [None]:
# For every first index and both runs: take the stored step with the largest
# loss, re-render it with scattering set to 0 and print the re-rendered
# losses next to the stored ones.
for n in range(6):
    params["acoustic_bsdf.scattering.value"] = 0.
    params["opt_bsdf.scattering.value"]      = 0.

    rerendered, stored = [], []
    for run in (0, 1):
        best = np.argmax(L[n, run])
        # Row layout: scale components first, last two entries absorption.
        opt[key_a] = mi.Float(V[n, run, 0, best, -2:])
        opt['s']   = mi.Vector3f(V[n, run, 0, best, :-2])
        # apply_transform() also calls params.update(), which picks up the
        # scattering values set above.
        apply_transform(params)
        img = mi.render(scene, seed=sess_seed)
        rerendered.append(loss(img).numpy()[0])
        stored.append(L[n, run, best])

    print(n, rerendered[0], rerendered[1], stored[0], stored[1])

In [None]:
# Loss as a function of the scattering value, evaluated at the stored step
# with the largest loss of run V[2, 0].
best = np.argmax(L[2, 0])
best_row = V[2, 0, 0, best]
opt[key_a] = mi.Float(best_row[-2:])
opt['s']   = mi.Vector3f(best_row[:-2])
apply_transform(params)

n = 101
# NOTE(review): this rebinds the notebook-level `scattering` (a scalar in the
# scene-construction cell) to an array -- re-run that cell before rebuilding
# the scene.
scattering = np.linspace(0., 1., n, endpoint=True)

losses = []
for step, s_val in zip(trange(n), scattering):
    # Both BSDFs get the same scattering value.
    for bsdf_key in ("acoustic_bsdf.scattering.value",
                     "opt_bsdf.scattering.value"):
        params[bsdf_key] = s_val
    params.update()
    img = mi.render(scene, seed=sess_seed)
    losses.append(loss(img).numpy()[0])

plt.plot(scattering, losses)
plt.show()