In [None]:
import os

import drjit as dr
import matplotlib.pyplot as plt
import mitsuba as mi
import numpy as np
from tqdm import trange

from utils import plot_img

# Global configuration: plotting style, Mitsuba variant and log verbosity.
plt.style.use('ggplot')
mi.set_variant('cuda_ad_rgb')
mi.set_log_level(mi.LogLevel.Warn)

# Two fresh seeds per session: `sess_seed` drives the primal renders and
# `sess_seed_g` is passed as `seed_grad` in the optimisation loop. They are
# printed so a run can be reproduced by hard-coding them.
sess_seed, sess_seed_g = [np.random.randint(0, 2**30) for _ in range(2)]
print(f"session seeds are: sess_seed={sess_seed}; sess_seed_g={sess_seed_g}")

In [None]:
from mitsuba import ScalarTransform4f as T

# Mesh for the emissive wall. The default is a machine-specific absolute path;
# set the RECTANGLE_MESH environment variable to run the notebook elsewhere.
RECTANGLE_MESH = os.environ.get(
    "RECTANGLE_MESH",
    "/home/daniel/Studium/masterarbeit/data/scenes/meshes/rectangle.ply",
)

# Scene: a pinhole camera at (5, 5, 10) looking down -z at (5, 5, 0), and a
# single PLY mesh ("wall") that is itself an area emitter of uniform radiance 1.
# `direct_reparam` is Mitsuba's reparameterized direct-illumination integrator,
# used so that gradients w.r.t. the mesh vertices are available.
scene_dict = {
    "type": "scene",
    "integrator": {
        "type": "direct_reparam",
        "hide_emitters": True,
    },

    "sensor": {
        "type": "perspective",
        "fov": 60,
        "to_world": T.look_at(
            origin=[5., 5., 10.],
            target=[5., 5., 0.],
            up=[0, 1, 0]
        ),
        "film": {
            "type": "hdrfilm",
            "width":  256,
            "height": 256,
            "sample_border": True,
        },
    },

    "wall": {
        "type": "ply",
        "filename": RECTANGLE_MESH,
        "emitter": {
            "type": "area",
            "radiance": { "type": "uniform", "value": 1. },
        },
    }
}

In [None]:
# Base corner vertices of the wall quad: a 2x2 square in the z = 0 plane,
# shifted so that it is centred on (5, 5, 0), the camera's target.
vertex_base = mi.Point3f(
    [-1., -1.,  1.,  1.],
    [-1.,  1., -1.,  1.],
    [ 0.,  0.,  0.,  0.],
) + mi.Point3f(5., 5., 0.)

# Parameter path of the wall mesh's vertex buffer inside `mi.traverse(scene)`.
key = 'wall.vertex_positions'

scene = mi.load_dict(scene_dict)
params = mi.traverse(scene)

def apply_transformation(params, opt):
    """Write the optimizer's current transform into the wall mesh.

    Computes ``vertex_base * opt['s'] + opt['p']``, stores it flattened in
    ``params[key]`` and calls ``params.update()`` so the scene uses the new
    vertex positions.

    Side effects on ``opt`` (on every call):
      * ``opt['s']`` is overwritten with 1.0, so the scale is currently frozen.
        The commented-out clamp is the alternative that would let it be
        optimised within [0.1, 2.0].
      * ``opt['p']`` is clamped component-wise to [-3, 3] and its z component
        is set to 0. The wall therefore stays in the z = 0 plane, because
        ``vertex_base`` has z = 0.

    Args:
        params: Result of ``mi.traverse(scene)``; must contain ``key``.
        opt: The ``mi.ad.Adam`` optimizer holding ``'s'`` and ``'p'``.
    """
    # Scale pinned to 1 (see docstring); restore the clamp to optimise it.
    opt['s'] = mi.Float(1.) #dr.clamp(opt['s'],  0.1, 2.0)
    opt['p'] = dr.clamp(opt['p'], -3.0, 3.0)
    # Pin the translation to the z = 0 plane.
    opt['p'].z = 0.0

    # Flatten the transformed Point3f into the mesh's vertex buffer layout.
    params[key] = dr.ravel(vertex_base * opt['s'] + opt['p'])
    params.update()

# Adam over a uniform scale 's' and a translation 'p'. The starting pose
# p = (1, 1, 0) is applied to the scene right away, so the reference render
# that follows uses it.
opt = mi.ad.Adam(lr=0.01)
opt['s'], opt['p'] = mi.Float(1.0), mi.Point3f(1.0, 1.0, 0.0)
apply_transformation(params, opt)

In [None]:
# Reference image at the starting pose (256 spp, session seed), then shown.
ref_img = mi.render(scene, spp=256, seed=sess_seed)
plot_img(ref_img, figsize=(4, 4))

In [None]:
# Move the wall to p = (1, 0, 0), away from the reference pose, and look at
# the gradient of the image-space MSE w.r.t. every pixel of that render.
opt['p'] = mi.Point3f(1.0, 0.0, 0.)
apply_transformation(params, opt)
img = mi.render(scene, seed=0, spp=256)

# Rendering without passing `params` gives an image that is not attached to
# the AD graph, so enable gradients on the image itself.
assert not dr.grad_enabled(img)
dr.enable_grad(img)
loss = dr.mean(dr.sqr(img - ref_img))
dr.backward(loss)
grad = mi.TensorXf(dr.grad(img))

fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for channel, (ax, label) in enumerate(zip(axes, ("R", "G", "B"))):
    # Use a separate name for the AxesImage so that `img` keeps the render.
    im = ax.imshow(grad[:, :, channel], interpolation="none", cmap="gray")
    fig.colorbar(im, ax=ax, fraction=0.1, shrink=0.8).set_label('dLoss/dPixel')
    ax.set_title(f"{label} channel")
    ax.axis("off")
plt.show()

In [None]:
# Split the per-pixel loss gradient by sign: `h` is 1 where the gradient is
# negative (0 elsewhere), and `gradn` keeps only the negative part.
gradn = grad.numpy()
h = (gradn < 0.).astype(gradn.dtype)
gradn[gradn > 0] = 0.

# Normalise magnitudes to [0, 1] for display. Guard the case with no negative
# gradients, where dividing by the maximum (0) would produce NaNs.
grad_abs = np.abs(gradn)
grad_max = grad_abs.max()
grad_vis = grad_abs / grad_max if grad_max > 0 else grad_abs

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(grad_vis)
ax2.imshow(h)
ax1.axis("off")
ax2.axis("off")
# The figure is built but closed without being displayed; replace
# `plt.close(fig)` with `plt.show()` to inspect it.
plt.close(fig)

In [None]:
# Per-step histories, appended to by the optimisation loop.
losses = []
vals = []

In [None]:
# iters == 1: a single gradient evaluation plus a debug arrow plot.
# iters > 1 (e.g. 50): run the optimisation and record histories.
iters = 1 #50

for i in trange(iters) if iters > 1 else range(iters):
    apply_transformation(params, opt)
    img = mi.render(scene, params, seed=sess_seed+i, seed_grad=sess_seed_g+i, spp=16)

    loss = dr.mean(dr.sqr(img - ref_img))
    # ClearNone keeps interior gradients (e.g. on params[key]) readable after
    # the backward pass; the debug branch below reads them.
    dr.backward(loss, flags=dr.ADFlag.ClearNone)

    # Alternative experiment: backpropagate the hand-crafted image gradient
    # `gradn` from the sign-mask cell instead of the MSE loss.
    # dr.set_grad(img, mi.TensorXf(gradn))
    # dr.backward_to(opt['s'], flags=dr.ADFlag.ClearNone)

    if iters < 2:
        # Debug view: per-vertex x/y gradients, plus the gradient on the
        # translation drawn from the wall centre, over the reference square.
        v = vertex_base.numpy()[:, :2]
        g = dr.unravel(mi.Point3f, dr.grad(params[key])).numpy()[:, :2]
        p = dr.grad(opt['p']).numpy()[0, :2]
        print(g)

        w = 0.005
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.set_xlim(3, 7)
        ax.set_ylim(3, 7)
        ax.plot([4, 6, 6, 4, 4], [4, 4, 6, 6, 4], color="C1", alpha=0.8)
        ax.quiver(v[:, 0], v[:, 1], g[:, 0], g[:, 1], width=w)
        ax.quiver(5, 5, p[0], p[1], width=w)
        plt.show()
    else:
        opt.step()
        # Record host-side copies. Storing the Dr.Jit `loss` itself would keep
        # its AD graph (not cleared under ClearNone) alive across iterations.
        losses.append(loss.numpy())
        # `opt` only holds 's' (frozen) and 'p'; track the translation.
        vals.append(opt['p'].numpy())

In [None]:
if iters > 1 and losses and vals:
    # One row per iteration. Entries may be width-1 arrays (or Dr.Jit values),
    # so flatten each entry into a row.
    loss_hist = np.asarray(losses).reshape(len(losses), -1)
    val_hist = np.asarray(vals).reshape(len(vals), -1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(loss_hist[:, 0])
    ax1.set_xlabel("iteration")
    ax1.set_ylabel("loss")
    ax2.plot(val_hist)
    ax2.set_xlabel("iteration")
    # The optimizer has no angle parameter, hence the generic label.
    ax2.set_ylabel("parameter value")
    plt.show()