In [None]:
from tqdm import trange

import numpy as np
import matplotlib.pyplot as plt

import drjit as dr
import mitsuba as mi

# Global configuration: pick the Mitsuba variant (CUDA backend with autodiff,
# RGB rendering) and the matplotlib theme used by every figure below.
mi.set_variant('cuda_ad_rgb')
plt.style.use('ggplot')

In [None]:
def plot_img(image, figsize=(4, 4)):
    """Display an image as an 8-bit sRGB picture without axes.

    Args:
        image: anything `mi.Bitmap` can wrap, e.g. the tensor returned by
            `mi.render` or a NumPy array of pixel values.
        figsize: matplotlib figure size in inches.
    """
    # Tone-map to 8-bit with the sRGB transfer curve so HDR values display sensibly.
    ldr_bitmap = mi.Bitmap(image).convert(
        component_format=mi.Struct.Type.UInt8,
        srgb_gamma=True,
    )

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_axis_off()
    ax.imshow(ldr_bitmap)
    plt.show()

In [None]:
from mitsuba import ScalarTransform4f as T

# Directory containing the mesh assets.
# NOTE(review): absolute path on the author's machine — adjust before running
# elsewhere (or make it relative to the notebook).
SCENE_DIR = "/home/daniel/Studium/masterarbeit/src/data/scenes"

# Integrator used for rendering; alternatives tried: "prb_reparam", "prb".
INTEGRATOR = "direct_reparam"

# Scene description for mi.load_dict. Despite the variable name, the active
# mesh is the bunny; the benchy entry is kept commented out for switching.
benchy_scene = {
    "type": "scene",
    "integrator": { "type": INTEGRATOR },

    "sensor": {
        "type": "perspective",
        "to_world": T.look_at(
            origin=[0., 0., 2.],
            target=[0., 0., 0.],
            up=[0, 1, 0]
        ),
        "fov": 60,
        "film": {
            "type": "hdrfilm",
            "width": 64,
            "height": 64,
            "rfilter": { "type": "gaussian" },
            "sample_border": True,
        },
        # "sampler": {
        #     "type": "independent",
        #     "sample_count": 32,
        # },
    },

    # Grey diffuse backdrop behind the object.
    "wall": {
        "type": "rectangle",
        "to_world": T.translate([0., 0., -2.]).scale(2.),
        "bsdf": {
            "type": "diffuse",
            "reflectance": {
                "type": "rgb",
                "value": 0.5,
            },
        },
    },

    # Object whose pose is optimized (its vertex positions are rewritten later).
    "bunny": {
        "type": "ply",
        "filename": f"{SCENE_DIR}/bunny.ply",
        "to_world": T.scale(6.5),
        "bsdf": {
            "type": "diffuse",
            "reflectance": {
                "type": "rgb",
                "value": [0.3, 0.3, 0.75],
            },
        },
    },

    # "benchy": {
    #     "type": "obj",
    #     "filename": f"{SCENE_DIR}/benchy/benchy.obj",
    #     "to_world": T.translate([0., -0.4, 0.]).rotate(axis=[1., 0., 0.], angle=15).scale(0.018),
    #     "bsdf": {
    #         "type": "diffuse",
    #         "reflectance": {
    #             "type": "rgb",
    #             "value": [0.3, 0.3, 0.8]
    #         },
    #     },
    # },

    # Small, bright spherical area light.
    "light": {
        "type": "sphere",
        "center": [2.5, 2.5, 7.0],
        "radius": 0.25,
        "emitter": {
            "type": "area",
            "radiance": {
                "type": "rgb",
                "value": [1e3, 1e3, 1e3],
            },
        },
    },
}

In [None]:
# Load the scene in its original pose and render the target image with a high
# sample count (1024 spp); this is what the optimization tries to reproduce.
scene = mi.load_dict(benchy_scene)
ref_img = mi.render(scene, spp=1024, seed=0)
plot_img(ref_img)

In [None]:
# Scene parameter that gets rewritten during optimization. Use
# 'benchy.vertex_positions' instead when the benchy mesh is enabled.
key = 'bunny.vertex_positions'

params = mi.traverse(scene)
display(params)

# Snapshot of the untransformed vertex positions as 3D points.
# NOTE(review): run this cell before any apply_transformation call — that
# function writes new positions into the scene, so re-running this cell later
# would capture already-transformed vertices as the "reference".
ref_vertices = dr.unravel(mi.Point3f, params[key])

In [None]:
# Adam over a two-parameter pose: 'angle' (rotation about y) and 'trans'
# (translation in the x/y plane). Both start away from zero, i.e. away from
# the reference pose (zero angle and translation leave the vertices unchanged).
initial_angle = mi.Float(0.25)
initial_trans = mi.Point2f(0.1, -0.25)

opt = mi.ad.Adam(lr=0.025)
opt['angle'] = initial_angle
opt['trans'] = initial_trans

def apply_transformation(params, opt):
    """Pose the mesh from the optimizer's current parameters.

    Clamps opt['angle'] and opt['trans'] to [-0.5, 0.5] (writing the clamped
    values back into `opt`), then rotates and translates the reference vertices
    (global `ref_vertices`) and stores them in `params[key]` (global `key`)
    before pushing the change into the scene with `params.update()`.
    """
    # Keep both parameters inside their feasible box. Order matches the
    # original: angle first, then translation.
    for name in ('angle', 'trans'):
        opt[name] = dr.clamp(opt[name], -0.5, 0.5)

    # Read the values back from `opt` after the write, rather than reusing the
    # clamped expressions, so the transform is built from the optimizer's
    # stored variables.
    angle, trans = opt['angle'], opt['trans']

    # Scale the angle by 100 so [-0.5, 0.5] spans +/-50 units of
    # Transform4f.rotate's angle (presumably degrees — confirm in Mitsuba docs).
    trafo = (
        mi.Transform4f.translate([trans.x, trans.y, 0.0])
        .rotate([0., 1., 0.], angle * 100.)
    )

    params[key] = dr.ravel(trafo @ ref_vertices)
    params.update()

In [None]:
# Visualize dLoss/dImage for the perturbed starting pose. The image is rendered
# without `params`, so it is not attached to the scene parameters (the assert
# checks this). Gradient tracking is enabled on the image itself and the MSE
# loss is back-propagated only as far as the pixels.
apply_transformation(params, opt)
img = mi.render(scene, seed=0, spp=1024)

assert not dr.grad_enabled(img)

dr.enable_grad(img)
diff_img = img - ref_img
loss     = dr.mean(dr.sqr(diff_img))
dr.backward(loss)

# Min-max normalize the gradient to [0, 1] for display. If the gradient is
# constant (e.g. the render matches the reference exactly), the original
# division produced NaNs; show an all-zero image instead.
grad = mi.TensorXf(dr.grad(img)).numpy()
grad_lo, grad_hi = grad.min(), grad.max()
if grad_hi > grad_lo:
    grad = np.clip((grad - grad_lo) / (grad_hi - grad_lo), 0., 1.)
else:
    grad = np.zeros_like(grad)

img_merged = np.hstack([img.numpy(), grad])
plot_img(img_merged, figsize=(8, 4))

In [None]:
# Optimization loop. Each iteration re-poses the mesh from the optimizer's
# current parameters, renders it with `params` passed so gradients reach the
# vertex positions, and takes one Adam step on the image MSE.
iters = 50
errors, losses = [], []

for i in trange(iters):
    apply_transformation(params, opt)
    img = mi.render(scene, params, seed=i, spp=16)

    loss = dr.mean(dr.sqr(img - ref_img))
    dr.backward(loss)

    opt.step()

    # Vertex-space error of the pose rendered in this iteration. params[key]
    # was computed from the grad-enabled optimizer variables, so detach it.
    # Otherwise this diagnostic records AD nodes that are never
    # back-propagated, and `errors` keeps them alive for every iteration.
    err = dr.sum(dr.sqr(dr.ravel(ref_vertices) - dr.detach(params[key])))

    # Keep plain Python floats in the history instead of Dr.Jit arrays.
    losses.append(loss.numpy().item())
    errors.append(err.numpy().item())

fig, (ax_err, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_err.plot(errors)
ax_err.set(title='Vertex position error', xlabel='Iteration',
           ylabel='Sum of squared distances')

ax_loss.plot(losses)
ax_loss.set(title='Image loss (MSE)', xlabel='Iteration')

plt.show()

# Re-pose with the final parameters; params[key] still holds the pose from
# before the last opt.step(), so without this the final render is one step stale.
apply_transformation(params, opt)
img = mi.render(scene, seed=0, spp=1024)
plot_img(img)