In [None]:
import os

import drjit as dr
import matplotlib.pyplot as plt
import mitsuba as mi

# Select a Mitsuba variant before any scene is loaded. The "_ad_" variants are
# the differentiable ones; dr.enable_grad / dr.backward in later cells need one.
mi.set_variant("cuda_ad_rgb")

In [None]:
# Root directory of the scene collection. Set the SCENES_DIR environment
# variable to run the notebook on another machine; without it, the original
# author's local path is used unchanged.
base = os.environ.get('SCENES_DIR', '/home/daniel/Studium/masterarbeit/src/data/scenes')
scene_path = os.path.join(base, 'benchy', 'scene.xml')

# resx / resy are forwarded to the scene file as substitution parameters
# (presumably $resx / $resy inside scene.xml - TODO confirm).
scene = mi.load_file(scene_path, resx=400, resy=300)
# SceneParameters mapping for this scene; printed so the key used in the
# next cell can be looked up.
params = mi.traverse(scene)

print(params)

In [None]:
# Parameter key of the quantity being recovered: the `reflectance` value of
# the benchy's BSDF (one of the keys listed by print(params)).
key = "benchy.bsdf.reflectance.value"
# Ground-truth colour, held in its own Color3f so that the optimiser's later
# writes to params[key] leave this reference value untouched.
ref_col = mi.Color3f(params[key])
print(ref_col)

In [None]:
# Put the ground-truth colour back into the scene before rendering. This keeps
# the cell correct on re-run, even after the optimisation has modified
# params[key].
params[key] = ref_col
params.update()
# High sample count keeps Monte Carlo noise in the reference image low.
ref_img = mi.render(scene, spp=1024)
# Only a cell's last expression is displayed, so show the sRGB-encoded bitmap
# directly (a bare mi.Bitmap(ref_img) line before it would be discarded).
mi.Bitmap.convert(mi.Bitmap(ref_img), srgb_gamma=True)

In [None]:
# Adam optimiser that owns the single colour being fitted; params.update(opt)
# copies its current value into the scene.
opt = mi.ad.Adam(lr=0.05)
opt[key] = mi.Color3f([0.8, 0.2, 0.2])  # initial guess for the colour
params.update(opt)

# Preview of the scene rendered with the initial guess.
img = mi.render(scene)
mi.Bitmap(img).convert(srgb_gamma=True)

In [None]:
# Visualise dLoss/dImg for the current render: which pixels would have to
# change to reduce the MSE against the reference. enable_grad turns `img` into
# a differentiable leaf, so dr.grad(img) below holds the per-pixel gradient.
dr.enable_grad(img)
loss = dr.mean(dr.sqr(img - ref_img))
dr.backward(loss)
grad = mi.TensorXf(dr.grad(img))
# Min-max rescale to [0, 1] for display. The divisor is clamped away from
# zero: a constant gradient (e.g. img == ref_img gives all zeros) would
# otherwise produce 0/0 = NaN in every pixel.
grad = dr.maximum(0., grad - dr.min(grad))
grad = dr.minimum(1., grad / dr.maximum(dr.max(grad), 1e-12))
mi.Bitmap.convert(mi.Bitmap(grad), srgb_gamma=True)

In [None]:
def mse(image):
    """Return the mean squared error of `image` against the reference render `ref_img`."""
    return dr.mean(dr.sqr(image - ref_img))

epochs = 50
errors = []  # squared parameter error after each iteration, as Python floats

for i in range(epochs):
    # Use a different seed every iteration. With the default (fixed) seed each
    # render reuses the same noise pattern, so successive gradient estimates
    # are not statistically independent and the optimiser fits that noise.
    img = mi.render(scene, params, spp=64, seed=i)
    loss = mse(img)
    dr.backward(loss)
    opt.step()
    # Project the updated colour back onto [0, 1] after the Adam step.
    opt[key] = dr.clamp(opt[key], 0.0, 1.0)
    params.update(opt)

    # Convergence in parameter space: squared distance to the true colour.
    err_ref = dr.sum(dr.sqr(ref_col - params[key]))
    err = err_ref[0]
    print(f"Iteration {i+1:02d}: parameter error = {err:6f}", end='\r')
    # Store the plain float rather than the Dr.Jit array, so the history does
    # not keep device arrays alive and plots as a simple 1-D series.
    errors.append(err)

print('\nOptimization complete.')

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(errors)
ax.set(xlabel='Iteration', ylabel='Squared parameter error',
       title='Convergence of the reflectance estimate')
plt.show()

# Final render with the optimised colour, sRGB-encoded like the reference and
# initial renders so the images can be compared directly.
img = mi.render(scene)
mi.Bitmap.convert(mi.Bitmap(img), srgb_gamma=True)