## Implementation of neural representation of SVBRDF parameters using Mitsuba 3

In this notebook we show how to use neural networks inside the Mitsuba renderer as a spatially varying parameter representation (e.g. albedo, roughness, etc.).
This implementation embeds a PyTorch module 'inside' the renderer, so that it can be used for forward rendering and for backpropagation to the network parameters using automatic differentiation. The idea is simple: trick Mitsuba into recognizing your network as a learnable module so that it records it in the computation graph. This is done by defining a fake Mitsuba vector that Mitsuba does recognize: the vector is zero, but it has gradient tracking enabled and is added to the network inputs, which forces Mitsuba to track gradients for the network. Using a neural network at the core of the rendering algorithm requires Mitsuba's wavefront mode, i.e. we need to turn off the mega kernel later in the notebook.

Finally, we run a simple inverse rendering example to show that the idea works. A more in-depth implementation, used in practice, can be found in our project repository: https://inverse-neural-radiosity.github.io

In [1]:
import mitsuba as mi
import drjit as dr
import torch
import torch.nn as nn
mi.set_variant('cuda_ad_rgb')
import tinycudann as tcnn
from tqdm import tqdm

### Custom Texture

We embed the neural network inside a custom texture, implemented as a subclass of mi.Texture. The renderer queries this network through .eval() when a mi.Spectrum is required (such as for an RGB albedo network). When only one channel is required, such as for a roughness network, the renderer calls eval_1() instead.

In [2]:
class MyTexture(mi.Texture, nn.Module):
    def __init__(self, props: mi.Properties) -> None:
        mi.Texture.__init__(self, props)
        nn.Module.__init__(self)
        self.network = None        

    def traverse(self, callback):
        if self.network is not None:
            self.network.traverse(callback)
        callback.put_parameter("texture", self, mi.ParamFlags.NonDifferentiable)

    def eval(self, si, active=True, dirs=None, norms=None, albedo=None):
        return self.network.eval(si.p, dirs, norms, albedo)

    def eval_1(self, si, active=True):
        return mi.Float(self.eval(si)[0])

    def eval_1_grad(self, *args, **kwargs):
        raise NotImplementedError()

    def eval_3(self, *args, **kwargs):
        raise NotImplementedError()

    def mean(self, *args, **kwargs):
        raise NotImplementedError()

    def to_string(self):
        return (
            "MyTexture[\n"
            f"  network={self.network}\n"
            "]"
        )

mi.register_texture('mytexture', MyTexture)

### PyTorch blocks
No special treatment is required here. We define a reflectance network to represent albedo; it takes a location as input and outputs an albedo value in [0,1]^3. We use the tiny-cuda-nn library by Müller et al. for the input encoding; one could use a simple positional encoding instead.
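
As a rough illustration of the positional-encoding alternative mentioned above, here is a minimal, framework-free sketch of a frequency encoding for one 3D point. The function name `positional_encoding` and the `n_frequencies` parameter are hypothetical and not part of this notebook; a real replacement for `TcnnEmbedding` would operate on batched torch tensors and expose `n_output_dims = 3 * 2 * n_frequencies`.

```python
import math


def positional_encoding(point, n_frequencies=4):
    """Frequency-encode one 3D point (hypothetical helper, not from the notebook).

    Each coordinate is mapped to sin/cos pairs at doubling frequencies,
    so the result has 3 * 2 * n_frequencies entries.
    """
    features = []
    for coord in point:
        for level in range(n_frequencies):
            freq = (2.0 ** level) * math.pi
            features.append(math.sin(freq * coord))
            features.append(math.cos(freq * coord))
    return features


encoded = positional_encoding((0.0, 0.5, -1.0), n_frequencies=4)
print(len(encoded))  # 24
```

With such an encoding, the first linear layer would see `3 + 24 = 27` input features instead of the hash-grid feature count.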

In [3]:
class TcnnEmbedding(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        self.embedding = tcnn.Encoding(3, config, dtype=torch.float32)
        self.n_output_dims = self.embedding.n_output_dims

    def forward(self, x):
        return self.embedding.forward(x)


class ReflectanceMlp(nn.Module):
    def __init__(
        self,
        width: int,
        hidden: int,
        embedding,
    ):
        super().__init__()

        self.embedding = TcnnEmbedding(embedding)
        in_size = 3 + self.embedding.n_output_dims

        hidden_layers = []
        for _ in range(hidden):
            hidden_layers.append(nn.Linear(width, width))
            hidden_layers.append(nn.LeakyReLU(inplace=True))

        self.network = nn.Sequential(
            nn.Linear(in_size, width),
            nn.LeakyReLU(inplace=True),
            *hidden_layers,
            nn.Linear(width, 3),
            nn.Sigmoid()
        )

    def forward(self, points):
        net_in = torch.cat([points, self.embedding(points)], dim=-1)
        ret = self.network(net_in)
        return ret

### Wrapper
We wrap the torch modules in a class that handles derivative communication between Mitsuba and PyTorch using dr.wrap_ad(). We also need a utility function, vec_to_tens_safe(), which makes sure that casting types in Mitsuba does not detach the gradients of that variable.

On its own, Mitsuba does not recognize the PyTorch network as a differentiable parameter, but it can ask PyTorch to track its operations; when backpropagating, Mitsuba passes the gradients it has computed so far to the PyTorch network (using dr.wrap_ad()). However, if this network is the only part of your system with gradient tracking enabled, Mitsuba does not notice it. To trick it, we put a fake vector, the grad_activator, in this class. As you can see, this vector is a mi.Vector3f(0) and it is added to the network input in the _eval() method. We later enable gradient tracking for this vector, which makes Mitsuba ask PyTorch to do gradient tracking as well. All this is why we need the wrapper below.
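
To make the trick more concrete, here is a toy pure-Python tracer. It is not Dr.Jit or PyTorch: the `Var` class, the `network` function and all values are hypothetical. It only shows that adding a tracked zero to an untracked input leaves the value unchanged while making the result part of the recorded graph. The hand-off of gradients to PyTorch that dr.wrap_ad() performs is not modelled here.

```python
class Var:
    """Toy scalar with optional gradient tracking (illustration only)."""

    def __init__(self, value, tracked=False, parents=()):
        self.value = value
        self.tracked = tracked
        self.parents = parents  # tuples of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        tracked = self.tracked or other.tracked
        parents = ((self, 1.0), (other, 1.0)) if tracked else ()
        return Var(self.value + other.value, tracked, parents)

    def __mul__(self, other):
        tracked = self.tracked or other.tracked
        parents = ((self, other.value), (other, self.value)) if tracked else ()
        return Var(self.value * other.value, tracked, parents)

    def backward(self, seed=1.0):
        # Accumulate the incoming gradient and pass it on to tracked parents.
        self.grad += seed
        for parent, local in self.parents:
            if parent.tracked:
                parent.backward(seed * local)


def network(x, hidden_weight=2.0):
    # The weight is invisible to the toy tracer, much like PyTorch
    # parameters are invisible to Mitsuba.
    return x * Var(hidden_weight)


position = Var(0.3)                      # untracked input point
plain = network(position)
print(plain.tracked)                     # False

grad_activator = Var(0.0, tracked=True)  # zero, but tracked
activated = network(position + grad_activator)
print(activated.tracked, activated.value == plain.value)  # True True

activated.backward()
print(grad_activator.grad)               # 2.0
```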

In [4]:
def vec_to_tens_safe(vec):
    # A utility function that converts a Vector3f to a TensorXf safely in mitsuba while keeping the gradients;
    # a regular type cast mi.TensorXf(vector) detaches the gradients
    return mi.TensorXf(dr.ravel(vec), shape=[dr.shape(vec)[1], dr.shape(vec)[0]])
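
The shape passed to mi.TensorXf above is [N, 3], which suggests that the raveled data is interleaved as x0, y0, z0, x1, y1, z1, and so on. The following plain-Python sketch mimics that layout under this assumption; the helper names `ravel_interleaved` and `reshape_rows` are hypothetical and do not use Dr.Jit.

```python
def ravel_interleaved(xs, ys, zs):
    """Flatten three coordinate arrays as [x0, y0, z0, x1, y1, z1, ...]."""
    flat = []
    for x, y, z in zip(xs, ys, zs):
        flat.extend((x, y, z))
    return flat


def reshape_rows(flat, n_cols=3):
    """View a flat list as rows of n_cols entries (row-major)."""
    return [flat[i:i + n_cols] for i in range(0, len(flat), n_cols)]


# Two points stored component-wise, as in a structure-of-arrays vector
xs, ys, zs = [1.0, 4.0], [2.0, 5.0], [3.0, 6.0]
flat = ravel_interleaved(xs, ys, zs)
rows = reshape_rows(flat)
print(rows)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```

Each row is then one point, which is the batch layout a PyTorch network expects.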

In [5]:
class MitsubaWrapper(nn.Module):
    def __init__(self, scene_min: float, scene_max: float, name: str = None):
        super().__init__()
        self.grad_activator = mi.Vector3f(0)
        self.scene_min = scene_min
        self.scene_max = scene_max
        self.name = name or type(self).__name__

    def eval(self, pts, dirs=None, norms=None, albedo=None):
        pts = (pts - self.scene_min) / (self.scene_max - self.scene_min)
        result = self._eval(pts, dirs, norms, albedo)
        return result

    def traverse(self, callback):
        callback.put_parameter("grad_activator", self.grad_activator, mi.ParamFlags.Differentiable)
        self._traverse(callback)

    def _eval(self, pts, dirs, norms, albedo):
        raise NotImplementedError()

    def _traverse(self, callback):
        pass

class MitsubaReflectanceNetworkWrapper(MitsubaWrapper):
    def __init__(
        self,
        width: int,
        hidden: int,
        embedding: dict,
        scene_min,
        scene_max,
    ):
        super().__init__(scene_min, scene_max, "bsdf_net")
        self.network = ReflectanceMlp(width, hidden, embedding)

    def _eval(self, pts, dirs, norms, albedo):
        pts = 2 * pts - 1
        p_tensor = vec_to_tens_safe(pts + self.grad_activator)
        torch_out = self.eval_torch(p_tensor)
        output = dr.unravel(mi.Vector3f, torch_out.array)
        return dr.clamp(output, 0, 1)

    @dr.wrap_ad(source="drjit", target="torch")
    def eval_torch(self, pts):
        return self.network(pts)

    def _traverse(self, callback):
        callback.put_parameter("network", self.network, mi.ParamFlags.Differentiable)
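
The two coordinate remappings performed in eval() and _eval() can be summarized by the following plain-Python sketch. The function names and the bounding box values are hypothetical; the notebook applies the same arithmetic to Mitsuba vectors.

```python
def normalize_to_unit_cube(point, bbox_min, bbox_max):
    """Map a point from the scene bounding box to [0, 1] per axis."""
    return tuple(
        (p - lo) / (hi - lo) for p, lo, hi in zip(point, bbox_min, bbox_max)
    )


def to_signed_range(point):
    """Map coordinates from [0, 1] to [-1, 1]."""
    return tuple(2.0 * p - 1.0 for p in point)


# Hypothetical bounding box and query point
bbox_min = (-1.0, -1.0, -1.0)
bbox_max = (1.0, 1.0, 3.0)
unit = normalize_to_unit_cube((0.0, 1.0, -1.0), bbox_min, bbox_max)
signed = to_signed_range(unit)
print(unit)    # (0.5, 1.0, 0.0)
print(signed)  # (0.0, 1.0, -1.0)
```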


## Example: inverse rendering on the Cornell box

We intend to recover the albedo of a Cornell box from a single input view:

In [6]:
scene = mi.load_dict(mi.cornell_box())
gt = mi.render(scene, spp=512)

In [7]:
# utility function that non-differentiably renders images using Mitsuba aov integrator
def render_nondiff():
    integrator = mi.load_dict({'type': 'aov',
                                'aovs': 'ab:albedo',
                                'my_image': {'type': 'path'}
                                })
    with dr.suspend_grad():
        with torch.no_grad():
            return mi.render(scene, integrator = integrator, spp = 32)


In [8]:
embedding = {
    'otype': 'HashGrid',
    'n_levels': 17,
    'n_features_per_level': 2,
    'log2_hashmap_size': 18,
    'base_resolution': 2,
    'per_level_scale': 1.5
}

network = MitsubaReflectanceNetworkWrapper(256, 0, embedding, scene.bbox().min, scene.bbox().max)
network.cuda()

MitsubaReflectanceNetworkWrapper(
  (network): ReflectanceMlp(
    (embedding): TcnnEmbedding(
      (embedding): Encoding(n_input_dims=3, n_output_dims=34, seed=1337, dtype=torch.float32, hyperparams={'base_resolution': 2, 'hash': 'CoherentPrime', 'interpolation': 'Linear', 'log2_hashmap_size': 18, 'n_features_per_level': 2, 'n_levels': 17, 'otype': 'Grid', 'per_level_scale': 1.5, 'type': 'Hash'})
    )
    (network): Sequential(
      (0): Linear(in_features=37, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01, inplace=True)
      (2): Linear(in_features=256, out_features=3, bias=True)
      (3): Sigmoid()
    )
  )
)

We create a diffuse BRDF whose reflectance is our custom texture, which has the neural network under the hood.

In [9]:
texture = mi.load_dict({'type':'mytexture'})
texture.network = network
bsdf = mi.load_dict({'type':'diffuse', 'reflectance': texture})


We set this diffuse BRDF as the one and only BRDF representing all objects in the scene:

In [10]:
objects = ['light', 'floor', 'ceiling', 'back', 'green-wall', 'red-wall', 'small-box', 'large-box']
new_cbox_dict = mi.cornell_box()
for key in objects:
    new_cbox_dict[key]['bsdf'] = bsdf

In [11]:
scene = mi.load_dict(new_cbox_dict)

### Initial state

The initial state of the network, with its Sigmoid output activation, results in a gray albedo (~0.5) everywhere:
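
A quick numeric check shows why the starting albedo is gray. It assumes that the freshly initialized last linear layer produces pre-activations close to zero, which this notebook does not verify explicitly; the sample values below are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


print(sigmoid(0.0))  # 0.5

# Small pre-activations (assumed typical right after initialization)
# stay close to 0.5 after the Sigmoid.
for pre_activation in (-0.2, 0.0, 0.2):
    print(pre_activation, round(sigmoid(pre_activation), 3))
```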

In [12]:
img = render_nondiff()

# Albedo map

mi.Bitmap(img[:,:,-7:-4])

In [13]:
# Rendering

mi.Bitmap(img[:,:,:4])

Gradient tracking for grad_activator must now be enabled to make sure our network receives gradients:

In [14]:
params = mi.traverse(texture)
dr.enable_grad(params['grad_activator'])

We have to turn off the mega kernel now:

In [15]:
def mega_kernel(state):
    dr.set_flag(dr.JitFlag.LoopRecord, state)
    dr.set_flag(dr.JitFlag.VCallRecord, state)
    dr.set_flag(dr.JitFlag.VCallOptimize, state)

mega_kernel(False)

In [16]:
optim = torch.optim.Adam(network.parameters(), lr=0.005)

In [17]:
for i in tqdm(range(200)):
    optim.zero_grad()
    img = mi.render(scene, spp=8, params=params)
    assert dr.grad_enabled(img)
    # Mean squared error against the reference rendering
    loss = dr.mean_nested((img - gt) ** 2)
    dr.backward(loss)
    optim.step()

    # Release cached GPU memory held by Dr.Jit and PyTorch
    dr.flush_malloc_cache()
    torch.cuda.empty_cache()


100%|██████████| 200/200 [01:26<00:00,  2.32it/s]
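
The structure of the loop above (render, compare against the reference with a mean squared error, backpropagate, update) can be reproduced in a toy setting without Mitsuba. The sketch below is only an analogy under stated assumptions: the "renderer" is a hypothetical one-line diffuse model, the gradient is derived by hand instead of using dr.backward(), and plain gradient descent stands in for Adam. All names and values are made up for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def render(albedo, irradiance):
    # Toy "renderer": per-pixel radiance of a diffuse surface
    return [albedo * e for e in irradiance]


irradiance = [0.2, 0.5, 0.9, 1.0]  # hypothetical per-pixel lighting
target = render(0.8, irradiance)   # reference image, true albedo 0.8

logit = 0.0          # sigmoid(0) = 0.5, i.e. the gray initial state
learning_rate = 2.0  # hypothetical step size

for step in range(500):
    albedo = sigmoid(logit)
    image = render(albedo, irradiance)
    # Derivative of the mean squared error with respect to the albedo
    grad_albedo = sum(
        2.0 * (i - t) * e for i, t, e in zip(image, target, irradiance)
    ) / len(image)
    # Chain rule through the sigmoid parameterization
    grad_logit = grad_albedo * albedo * (1.0 - albedo)
    logit -= learning_rate * grad_logit

print(round(sigmoid(logit), 3))
```

The recovered value should approach the reference albedo of 0.8, starting from the gray initial state.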


### Final state
Voilà! We have now trained our albedo neural network:

In [18]:
img = render_nondiff()

# Albedo map
mi.Bitmap(img[:,:,-7:-4])

In [19]:
mi.Bitmap(img[:,:,:4])

We should mention that, to achieve better results, one should train for longer, use a higher sample count per pixel, and use multiple input views.

## Citation

```bibtex
@misc{hadadan2023inverse,
      title={Inverse Global Illumination using a Neural Radiometric Prior},
      author={Saeed Hadadan and Geng Lin and Jan Novák and Fabrice Rousselle and Matthias Zwicker},
      year={2023},
      eprint={2305.02192},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

