# Mat Renderer Demonstration

### Settings

In [None]:
# Directory containing the material's svBRDF texture maps.
maps_path = "assets/rusty_metal"

# Single-light setup: position, color triple, and power scale passed to Light().
light_power = 10.0
light_position = (0.0, 0.0, 10.0)
light_color = (23.47, 21.31, 20.79)

### Load material

In [None]:
from matrenderer.io import load_svbrdf_maps, show_maps, show_rendered
import matplotlib.pyplot as plt
from matrenderer.render import Renderer, Light
from matrenderer.helpers import create_single_batch_maps, create_rendered_maps
import pathlib

# Resolve the material directory to an absolute path, load its maps, and preview them.
material_dir = pathlib.Path(maps_path).absolute()
maps = load_svbrdf_maps(material_dir)
show_maps(maps)

### Rendering

In [None]:
# Render the loaded material under a single light and show every component.
lights = [Light(light_position, light_color, light_power)]
r = Renderer(lights=lights, gamma=1.0)
batched_maps = create_single_batch_maps(maps)
color, ambient, light, diffuse, specular = r.render(batched_maps)

# The light term is scaled by its maximum for display. The divisor is clamped
# so an all-zero (unlit) light term shows as zeros instead of NaNs from 0/0.
rendered_maps = create_rendered_maps(
    color[0],
    ambient[0],
    light[0] / light[0].max().clamp_min(1e-8),
    diffuse[0],
    specular[0],
)
show_maps(rendered_maps)

In [None]:
# Display batch element 0 of the render on its own, as a larger view.
show_rendered(color[0, ...])

### Compute gradients

In [None]:
import torch

# Sanity check that the renderer is differentiable w.r.t. the base color map.
# NOTE: this enables requires_grad on the loaded ground-truth map in place.
maps["basecolor"].requires_grad = True
# .backward() accumulates into .grad. Drop any gradient left over from a
# previous execution of this cell so the printed gradient reflects one pass only.
maps["basecolor"].grad = None
batched_maps = create_single_batch_maps(maps)
color, ambient, radiance, diffuse, specular = r.render(batched_maps)
# Dummy loss: the sum of every element of the rendered output.
test_loss = torch.sum(color)
test_loss.backward()
print(maps["basecolor"].grad)

### Optimize randomly initialized svBRDF maps towards the rendered image

In [None]:
from matrenderer.helpers import create_learnable_maps, sigmoid_maps

# Optimization hyper-parameters.
n_iter = 1000
show_interval = 200
learning_rate = 0.01

# Learnable maps share the spatial size of the loaded material (dims 1 and 2).
height, width = (maps["basecolor"].shape[axis] for axis in (1, 2))
learning_maps = create_learnable_maps(height, width)

# Basic initializations for the raw parameters (they are passed through
# sigmoid_maps before rendering): normal channels (0.5, 0.5, 1.0),
# metallic all zeros, ao all ones.
learning_maps["normal"] = torch.cat(
    (torch.full((2, height, width), 0.5), torch.ones((1, height, width)))
).requires_grad_()
learning_maps["metallic"] = torch.zeros(1, height, width, requires_grad=True)
learning_maps["ao"] = torch.ones(1, height, width, requires_grad=True)

In [None]:
from tqdm import tqdm

l2_loss = torch.nn.MSELoss()
# Reference image rendered from the loaded material. It is detached so no
# gradient flows back into the ground-truth maps.
target = color.detach()
pbar = tqdm(range(1, n_iter + 1))

optimizer = torch.optim.Adam(learning_maps.values(), lr=learning_rate)
for i_iter in pbar:
    # Clear gradients from the previous step before this backward pass.
    optimizer.zero_grad()
    # sigmoid_maps(...) output is what gets rendered. It presumably maps the raw
    # parameters into the renderer's valid range (confirm in matrenderer.helpers).
    batched_data = create_single_batch_maps(sigmoid_maps(learning_maps))
    rendered, _, _, _, _ = r.render(batched_data)
    loss = l2_loss(rendered, target)
    loss.backward()
    optimizer.step()
    pbar.set_description("loss: {:.5f}".format(loss.item()))

    if i_iter % show_interval == 0:
        # Show the maps as they are fed to the renderer, not the raw
        # parameters. no_grad/detach keep the plotting out of the autograd graph.
        with torch.no_grad():
            show_maps(sigmoid_maps(learning_maps))
        show_rendered(rendered[0].detach(), title="iter: #{}".format(i_iter))
        