In [None]:
import drjit as dr
import mitsuba as mi

# Autodiff-enabled LLVM variant: needed so renders can be differentiated
# with respect to scene parameters later in the notebook.
mi.set_variant('llvm_ad_rgb')

# Description of a reparameterizing direct-illumination integrator.
# NOTE(review): this dict is never passed to load_file() or render() in this
# notebook, so whatever integrator the XML scene declares is what actually
# runs. Gradients w.r.t. vertex positions need a reparameterizing integrator;
# confirm the XML uses one (or pass this via render(..., integrator=...)).
integrator = dict(type='direct_reparam')

# Scalar transform alias (not referenced by the cells below).
from mitsuba.scalar_rgb import Transform4f as T

# Load the test scene; resx/resy are presumably substituted for $resx/$resy
# placeholders inside the XML file — confirm against the scene file.
scene = mi.load_file("myscenes/barreltest.xml", resx=160, resy=90)

In [None]:
# Target image: the scene exactly as loaded (no mesh transformation has been
# applied yet), rendered with many samples per pixel to keep noise low.
img_ref = mi.render(scene, spp=512)

# Show the target as this cell's rich output
mi.util.convert_to_bitmap(img_ref)

In [None]:
# Differentiable scene parameters, addressable by string key
params = mi.traverse(scene)

# Snapshot the barrel mesh's original vertices as Point3f. apply_transformation
# always transforms this snapshot, so successive pose updates never compound.
flat_vertex_buffer = params['mesh-barrels_obj.vertex_positions']
initial_vertex_positions = dr.unravel(mi.Point3f, flat_vertex_buffer)

In [None]:
def apply_transformation(params, opt, initial_positions=None,
                         key='mesh-barrels_obj.vertex_positions',
                         trans_bound=0.5, angle_bound=0.5, angle_scale=100.0):
    """Clamp the optimized pose and write the transformed mesh into the scene.

    ``opt['trans']`` and ``opt['angle']`` are first clamped and written back
    into ``opt``, so the optimizer itself continues from in-bounds values.
    A transform ``translate(trans) @ rotX @ rotY @ rotZ`` is then applied to
    the *untransformed* vertices, so repeated calls never accumulate.

    Parameters
    ----------
    params : mi.SceneParameters
        Result of ``mi.traverse(scene)``. ``params[key]`` is overwritten and
        ``params.update()`` is called to propagate the change to the scene.
    opt : mi.ad.Optimizer
        Holds ``'trans'`` and ``'angle'`` as ``mi.Point3f`` (modified in place).
    initial_positions : mi.Point3f, optional
        Untransformed vertices. Defaults to the notebook-level
        ``initial_vertex_positions`` (looked up at call time).
    key : str
        Scene-parameter key of the flat vertex-position buffer to overwrite.
    trans_bound, angle_bound : float
        Symmetric clamp limits for ``opt['trans']`` and ``opt['angle']``.
    angle_scale : float
        Multiplier mapping the optimized angle to the value passed to
        ``Transform4f.rotate`` (Mitsuba interprets that value in degrees).
    """
    if initial_positions is None:
        # Backward-compatible fallback to the snapshot taken in an earlier cell.
        initial_positions = initial_vertex_positions

    # Write the clamped values back so the optimizer state stays in bounds.
    opt['trans'] = dr.clamp(opt['trans'], -trans_bound, trans_bound)
    opt['angle'] = dr.clamp(opt['angle'], -angle_bound, angle_bound)

    trans = opt['trans']
    angle = opt['angle']
    trafo = mi.Transform4f.translate(
        [trans.x, trans.y, trans.z]
    ).rotate(
        [1, 0, 0], angle.x * angle_scale
    ).rotate(
        [0, 1, 0], angle.y * angle_scale
    ).rotate(
        [0, 0, 1], angle.z * angle_scale
    )

    params[key] = dr.ravel(trafo @ initial_positions)
    params.update()

In [None]:
# Adam over the pose parameters: rotation angles and translation
opt = mi.ad.Adam(lr=0.025)

# Initial guess for the pose (angles are scaled inside apply_transformation).
# NOTE(review): apply_transformation clamps 'trans' to [-0.5, 0.5], so the
# y-offset of 1 below is effectively 0.5 — confirm that is intended.
opt['angle'] = mi.Point3f(0, 0.1, 0.2)
opt['trans'] = mi.Point3f(0, 1, 0)

# Move the mesh to the initial guess and keep a high-quality render of it
apply_transformation(params, opt)
img_init = mi.render(scene, spp=1024, seed=0)

mi.util.convert_to_bitmap(img_init)

In [None]:
iteration_count = 100
spp = 16

loss_hist = []
for it in range(iteration_count):
    # Push the current (clamped) optimizer state into the scene geometry
    apply_transformation(params, opt)

    # Differentiable rendering; passing `params` tracks gradients w.r.t. them
    img = mi.render(scene, params, seed=it, spp=spp)

    # Objective: summed squared error against the reference, divided by
    # len(img). NOTE(review): len() of the image tensor is presumably its
    # leading dimension (height), not the pixel count — this only rescales
    # the loss, it does not move the optimum.
    loss = dr.sum(dr.sqr(img - img_ref)) / len(img)

    # Backpropagate through the rendering process
    dr.backward(loss)

    # Optimizer: take a gradient descent step
    opt.step()

    # Store a plain Python float rather than a Dr.Jit array
    loss_hist.append(loss[0])
    print(f"Iteration {it:02d}: error={loss[0]:.6f}, angle=[{opt['angle'].x[0]:.4f}, {opt['angle'].y[0]:.4f}, {opt['angle'].z[0]:.4f}], trans=[{opt['trans'].x[0]:.4f}, {opt['trans'].y[0]:.4f}, {opt['trans'].z[0]:.4f}]", end='\r')

# The final opt.step() changed the parameters after the last render; push them
# into the scene so subsequent renders show the final optimized pose rather
# than the pose from one step earlier.
apply_transformation(params, opt)

# Terminate the '\r'-overwritten progress line
print()

In [None]:
from matplotlib import pyplot as plt

# High-quality render of the scene in its current (optimized) state
img_opt = mi.render(scene, spp=1024)

fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Top-left: objective value per iteration (this is the loss, not a parameter error)
axs[0, 0].plot(loss_hist)
axs[0, 0].set_xlabel('iteration')
axs[0, 0].set_ylabel('Loss')
axs[0, 0].set_title('Loss over iterations')

# Remaining panels: initial / optimized / reference renders for visual comparison
image_panels = [
    (axs[0, 1], img_init, 'Initial Image'),
    (axs[1, 0], img_opt, 'Optimized image'),
    (axs[1, 1], img_ref, 'Reference Image'),
]
for ax, image, title in image_panels:
    ax.imshow(mi.util.convert_to_bitmap(image))
    ax.axis('off')
    ax.set_title(title)

fig.tight_layout()
plt.show()