## Gradients with respect to the input.

#### 1. MLPs.

$$ y = \mathrm{MLP}(x) $$

#### 2. Voxels.

$$ y = \mathrm{voxel}(x) $$


In [85]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from tava.utils.training import learning_rate_decay


class MLP(nn.Module):
    """ReLU MLP that re-injects its raw input every `skip_layer` layers.

    After hidden layer ``i`` (``i > 0`` and ``i % skip_layer == 0``) the
    original input is concatenated to the activations, so whichever layer
    comes next (a hidden layer or the output head) receives
    ``net_width + input_dim`` features.
    """

    def __init__(
        self, input_dim=3, net_depth=8, net_width=128, skip_layer=4, output_dim=1
    ):
        super().__init__()
        self.input_dim = input_dim
        self.net_depth = net_depth
        self.skip_layer = skip_layer
        self.input_layers = nn.ModuleList()
        fan_in = input_dim
        for depth in range(net_depth):
            self.input_layers.append(nn.Linear(fan_in, net_width))
            # Width seen by the *next* layer: widened when the input is
            # concatenated after this one.
            fan_in = net_width + (input_dim if self._is_skip(depth) else 0)
        self.output_layer = nn.Linear(fan_in, output_dim)
        self.net_activation = torch.nn.ReLU()

    def _is_skip(self, depth):
        """Return True if the raw input is concatenated after layer `depth`."""
        return depth % self.skip_layer == 0 and depth > 0

    def forward(self, x):
        """Map ``[..., input_dim]`` inputs to ``[..., output_dim]`` outputs."""
        skip = x
        hidden = x
        for depth, layer in enumerate(self.input_layers):
            hidden = self.net_activation(layer(hidden))
            if self._is_skip(depth):
                hidden = torch.cat([hidden, skip], dim=-1)
        return self.output_layer(hidden)


class Voxel(nn.Module):
    """Dense learnable feature grid queried with ``F.grid_sample``.

    The grid parameter has shape ``[1, output_dim] + [res] * input_dim``;
    queries of shape ``[..., input_dim]`` return ``[..., output_dim]``.

    :param res: grid resolution along every spatial axis.
    :param bbox: stored in the ``bbox`` buffer; not used by ``forward``
        (the bbox normalization there is commented out).
    :param input_dim: number of spatial dims of the grid / queries.
    :param output_dim: feature channels per grid cell.
    :param init: optional initial grid of shape
        ``[output_dim] + [res] * input_dim``; defaults to standard normal noise.
    :param mode: ``grid_sample`` interpolation mode. ``None`` selects
        ``"bicubic"`` for 2-D grids and ``"bilinear"`` otherwise, because
        ``grid_sample`` only implements bicubic for 4-D (2-D spatial) input.
    """

    def __init__(self, res, bbox, input_dim=3, output_dim=1, init=None, mode=None):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        data_shape = [output_dim] + [res] * input_dim
        if init is not None:
            assert init.shape == torch.Size(data_shape)
        else:
            init = torch.randn(data_shape)
        self.data = nn.Parameter(init.unsqueeze(0))
        self.register_buffer("bbox", torch.tensor(bbox))
        if mode is None:
            mode = "bicubic" if input_dim == 2 else "bilinear"
        self.mode = mode

    def forward(self, x):
        """Sample the grid at ``x`` of shape ``[..., input_dim]``.

        Coordinates are squashed with ``sigmoid(x) * 2 - 1`` so any real input
        lands inside (-1, 1), grid_sample's normalized range.
        """
        assert x.shape[-1] == self.input_dim
        x = torch.sigmoid(x) * 2. - 1.  # [-1, 1]

        # bbox_min, bbox_max = self.bbox.split(split_size=[3, 3], dim=-1)
        # x = (x - bbox_min) / (bbox_max - bbox_min) * 2. - 1.  # [-1, 1]
        # grid_sample's last grid axis is ordered (W, H[, D]), i.e. reversed
        # relative to the data's spatial dims -- the convention is k, j, i.
        x = x.flip(dims=(-1,))

        # Flatten all queries into one row of the sampling grid.
        grid = x.reshape([1] * self.input_dim + [-1, self.input_dim])
        out = F.grid_sample(
            self.data,
            grid,
            padding_mode='zeros',
            align_corners=True,
            mode=self.mode,
        ).transpose(1, -1)  # channels last: [1, N, 1(, 1), output_dim]
        return out.reshape(list(x.shape[:-1]) + [self.output_dim])


class CNNVoxel(nn.Module):
    """3-D feature grid decoded from a fixed random latent by transposed convs.

    ``get_voxel`` upsamples the ``[1, latent_dim, 1, 1, 1]`` latent ``z`` by 2x
    per ``ConvTranspose3d`` (``int(log2(res))`` of them, halving the channels
    each time). It then maps to ``output_dim`` channels with a 3x3x3 conv. The
    grid is sampled with ``grid_sample`` (trilinear for 5-D input) at the
    query points.

    NOTE(review): the decoded grid side is ``2 ** int(log2(res))``, so a
    non-power-of-two ``res`` is rounded down. Channel counts reach 0 if
    ``latent_dim`` is smaller than that side.
    """

    def __init__(self, res, bbox, latent_dim=1024, output_dim=1):
        super().__init__()
        self.res = res
        # Queries are always 3-D. Exposed because callers such as Engine.run
        # size their sampled inputs from `model.input_dim`.
        self.input_dim = 3
        self.output_dim = output_dim
        self.latent_dim = latent_dim

        self.register_buffer("bbox", torch.tensor(bbox))
        # Fixed (buffer, not Parameter) latent code: only the decoder is learned.
        self.register_buffer("z", torch.randn(1, latent_dim, 1, 1, 1))
        self.layers = nn.ModuleList()
        feature_dim = latent_dim
        for _ in range(int(math.log2(res))):
            # kernel 4, stride 2, padding 1 exactly doubles each spatial side.
            self.layers.append(
                nn.ConvTranspose3d(feature_dim, feature_dim // 2, 4, 2, 1)
            )
            feature_dim = feature_dim // 2
        self.output_layer = nn.Conv3d(feature_dim, output_dim, 3, 1, 1)

    def get_voxel(self):
        """Decode ``z`` into a ``[1, output_dim, S, S, S]`` grid."""
        out = self.z
        for layer in self.layers:
            out = layer(out)
            out = F.leaky_relu(out, 0.2)
        out = self.output_layer(out)
        return out

    def forward(self, x):
        """Sample the decoded grid at ``x`` of shape ``[..., 3]``."""
        assert x.shape[-1] == 3, "only support 3d for now."
        x = torch.sigmoid(x) * 2. - 1.  # [-1, 1]

        data = self.get_voxel()

        # bbox_min, bbox_max = self.bbox.split(split_size=[3, 3], dim=-1)
        # x = (x - bbox_min) / (bbox_max - bbox_min) * 2. - 1.  # [-1, 1]
        x = x.flip(dims=(-1,))  # the convention is k, j, i

        out = F.grid_sample(
            data,
            x.reshape(1, 1, 1, -1, 3),
            padding_mode='zeros',
            align_corners=True,
        ).transpose(1, -1)
        return out.reshape(list(x.shape[:-1]) + [self.output_dim])


class ListVoxel(nn.Module):
    """Sum of learnable grids at resolutions ``res, res // 2, ..., 1``.

    There are ``int(log2(res)) + 1`` levels. Each query is sampled
    (bilinear/trilinear, zero padding) from every level and the results are
    added.

    :param res: finest grid resolution.
    :param bbox: stored in the ``bbox`` buffer; not used by ``forward``.
    :param output_dim: feature channels per cell.
    :param input_dim: spatial dims of the grids and queries (default 3, the
        original behaviour; 2 is also supported by ``grid_sample``).
    """

    def __init__(self, res, bbox, output_dim=1, input_dim=3):
        super().__init__()
        # Exposed for callers such as Engine.run that read `model.input_dim`.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.register_buffer("bbox", torch.tensor(bbox))
        self.data = nn.ParameterList()
        num_levels = int(math.log2(res)) + 1
        for _ in range(num_levels):
            self.data.append(
                nn.Parameter(torch.randn([1, output_dim] + [res] * input_dim))
            )
            res = res // 2

    def forward(self, x):
        """Sum the per-level samples at ``x`` of shape ``[..., input_dim]``."""
        assert x.shape[-1] == self.input_dim, f"expected {self.input_dim}-d inputs."
        x = torch.sigmoid(x) * 2. - 1.  # [-1, 1]

        # bbox_min, bbox_max = self.bbox.split(split_size=[3, 3], dim=-1)
        # x = (x - bbox_min) / (bbox_max - bbox_min) * 2. - 1.  # [-1, 1]
        x = x.flip(dims=(-1,))  # the convention is k, j, i

        # The sampling grid is identical for every level; build it once.
        grid = x.reshape([1] * self.input_dim + [-1, self.input_dim])
        out_shape = list(x.shape[:-1]) + [self.output_dim]
        output = torch.zeros(out_shape).to(x)
        for data in self.data:
            out = F.grid_sample(
                data,
                grid,
                padding_mode='zeros',
                align_corners=True,
            ).transpose(1, -1)
            output = output + out.reshape(out_shape)
        return output


class Engine():
    """Training / inversion harness around a model.

    ``run`` fits ``model`` to ``target_func`` with SGD on inputs sampled
    uniformly from [0, 1). ``optim_input`` keeps the model fixed and
    optimizes an input tensor so that the model output matches the output at
    a ground-truth input.
    """

    def __init__(
        self,
        model,
        target_func,
        lr_init,
        lr_final,
        max_steps,
        batch_size,
        device,
        input_dim=None,
    ):
        """
        :param model: module to train; moved to ``device``.
        :param target_func: maps sampled inputs to regression targets;
            evaluated under ``torch.no_grad``.
        :param lr_init: initial learning rate for ``learning_rate_decay``.
        :param lr_final: final learning rate for ``learning_rate_decay``.
        :param max_steps: number of SGD steps performed by ``run``.
        :param batch_size: inputs sampled per step in ``run``.
        :param device: device for the model and sampled inputs.
        :param input_dim: width of the sampled inputs. ``None`` falls back to
            ``model.input_dim``, which containers such as ``nn.Sequential``
            do not define -- pass it explicitly for those.
        """
        self.model = model.to(device)
        self.target_func = target_func
        self.lr_init = lr_init
        self.lr_final = lr_final
        self.max_steps = max_steps
        self.batch_size = batch_size
        self.device = device
        self.input_dim = input_dim

    def _sample_dim(self):
        """Return the width of the inputs sampled by ``run``."""
        if self.input_dim is not None:
            return self.input_dim
        return self.model.input_dim

    @staticmethod
    def _set_lr(optimizer, lr):
        """Overwrite the learning rate of every parameter group."""
        for group in optimizer.param_groups:
            group["lr"] = lr

    def run(self):
        """Fit ``self.model`` to ``self.target_func`` for ``max_steps`` steps."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr_init)

        pbar = tqdm(range(self.max_steps))
        for step in pbar:
            lr = learning_rate_decay(
                step, self.lr_init, self.lr_final, self.max_steps,
            )
            self._set_lr(optimizer, lr)

            # Sample directly on the target device (no host -> device copy).
            x = torch.rand(
                (self.batch_size, self._sample_dim()), device=self.device
            )
            y = self.model(x)
            with torch.no_grad():
                target = self.target_func(x)
            loss = F.mse_loss(y, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description(
                f"step {step:07d} lr {lr:.4f}: loss {loss.item():.7f}"
            )

    def optim_input(self, x, x_gt, lr_init, lr_final, max_steps):
        """Recover an input by gradient descent through the (frozen) model.

        :param x: initial guess; cloned, so the caller's tensor is untouched.
        :param x_gt: input whose model output is the optimization target.
        :param lr_init: initial learning rate for ``learning_rate_decay``.
        :param lr_final: final learning rate for ``learning_rate_decay``.
        :param max_steps: number of SGD steps on the input.
        :return: the optimized input, a leaf tensor with ``requires_grad=True``.
        """
        x = x.to(self.device)
        x_gt = x_gt.to(self.device)
        x_opt = x.clone().detach().requires_grad_(True)
        with torch.no_grad():
            target = self.model(x_gt)

        # Freeze the model while inverting it. Otherwise every backward pass
        # also accumulates gradients into the model parameters, which are
        # never used here. The original flags are restored afterwards.
        params = list(self.model.parameters())
        saved_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)
        try:
            optimizer = torch.optim.SGD([x_opt], lr=lr_init)
            pbar = tqdm(range(max_steps))
            for step in pbar:
                lr = learning_rate_decay(step, lr_init, lr_final, max_steps)
                self._set_lr(optimizer, lr)
                y = self.model(x_opt)
                loss = F.mse_loss(y, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.set_description(
                    f"step {step:07d} lr {lr:.4f}: loss {loss.item():.7f} "
                    f"grad {x_opt.grad.abs().mean().item():.7f} "
                    f"err {F.mse_loss(x_opt.detach(), x_gt).item(): .7f} "
                )
        finally:
            for p, flag in zip(params, saved_flags):
                p.requires_grad_(flag)
        return x_opt

In [86]:
# Initial guesses and ground-truth inputs for the inversion experiments:
# 128 points each, uniform in [0, 1)^2, drawn in that order.
# NOTE: `input` shadows the builtin; later cells depend on this name.
input, input_gt = (torch.rand((128, 2)) for _ in range(2))

In [87]:
# Reference model: a randomly initialised 2 -> 21 MLP. With max_steps=0,
# run() performs no updates, so the MLP stays at its random initialisation.
# It is then inverted: optimise `input` so that MLP(input) matches MLP(input_gt).
engine_mlp = Engine(
    model=MLP(input_dim=2, output_dim=21),
    # Unused while max_steps == 0. NOTE(review): if training were enabled, these
    # 2-channel targets would not match the 21-channel model output.
    target_func=lambda x: x,
    lr_init=1e-1,
    lr_final=1e-2,
    max_steps=0,
    batch_size=40960,
    device="cuda:1",
)
engine_mlp.run()
input_opt = engine_mlp.optim_input(input, input_gt, lr_init=1e5, lr_final=1e5, max_steps=1000)

0it [00:00, ?it/s]
step 0000999 lr 100000.0000: loss 0.0000000 grad 0.0000000 err  0.0000000 : 100%|██████████| 1000/1000 [00:04<00:00, 216.78it/s]


In [88]:
# Distil the reference MLP into a 2-D, 21-channel 256x256 Voxel grid, then
# invert the grid from `input` towards `input_gt`.
res = 256
bbox = [0, 0, 0, 1, 1, 1]  # stored by Voxel but not used by its forward pass
# NOTE(review): the commented-out init below stacks three meshgrids into shape
# [3, res, res, res], while Voxel(input_dim=2, output_dim=21) expects
# [output_dim, res, res] = [21, 256, 256]; its shape assertion would fail.
# init = torch.stack(torch.meshgrid(
#     torch.linspace(bbox[0], bbox[3], steps=res),
#     torch.linspace(bbox[1], bbox[4], steps=res),
#     torch.linspace(bbox[2], bbox[5], steps=res),
#     indexing="ij",
# ))
init = None
engine_voxel = Engine(
    model=Voxel(res=res, bbox=bbox, input_dim=2, output_dim=21, init=init),
    target_func=lambda x: engine_mlp.model(x),  # regress onto the reference MLP
    lr_init=1e4,
    lr_final=1e4,
    max_steps=2000,
    batch_size=40960,
    device="cuda:1",
)
engine_voxel.run()
input_opt = engine_voxel.optim_input(input, input_gt, lr_init=1e-1, lr_final=1e-4, max_steps=2000)

step 0001999 lr 10000.0000: loss 0.0000004: 100%|██████████| 2000/2000 [00:07<00:00, 252.00it/s]
step 0001999 lr 0.0001: loss 0.0000038 grad 0.0000002 err  0.1613045 : 100%|██████████| 2000/2000 [00:03<00:00, 562.72it/s]


In [84]:
# Show the first three recovered inputs rescaled by 255.
input_opt[:3].mul(255)

tensor([[  2.3684, 131.8379],
        [122.5303, 161.5197],
        [ 39.2880, 248.4836]], device='cuda:1', grad_fn=<MulBackward0>)

In [16]:
# Voxel grid followed by a one-layer MLP head, regressed onto engine_mlp.
# `input` / `input_gt` are (128, 2), so the grid must be built with
# input_dim=2: Voxel asserts x.shape[-1] == input_dim, whose default is 3.
# NOTE: Engine.run reads `model.input_dim`, which nn.Sequential does not
# define; with max_steps=0 the training loop body never executes.
res = 256
bbox = [0, 0, 0, 1, 1, 1]
init = None
engine_vomlp = Engine(
    model=nn.Sequential(
        Voxel(res=res, bbox=bbox, input_dim=2, output_dim=21, init=init),
        MLP(input_dim=21, net_depth=1, output_dim=21),
    ),
    target_func=lambda x: engine_mlp.model(x),
    lr_init=1e1,
    lr_final=1e0,
    max_steps=0,
    batch_size=40960,
    device="cuda:1",
)
engine_vomlp.run()
input_opt = engine_vomlp.optim_input(input, input_gt, lr_init=1e-3, lr_final=1e-5, max_steps=5000)

0it [00:00, ?it/s]
step 0004999 lr 0.0000: loss 0.0129986 grad 0.0015699 err  0.1554769 valid 1.000: 100%|██████████| 5000/5000 [01:00<00:00, 82.71it/s]


In [None]:
# res = 32
# bbox = [0, 0, 0, 1, 1, 1]
# engine_cnnvoxel = Engine(
#     model=CNNVoxel(res=res, bbox=bbox, output_dim=21),
#     target_func=lambda x: engine_mlp.model(x),
#     lr_init=1e0,
#     lr_final=1e-3,
#     max_steps=10000,
#     batch_size=40960,
#     device="cuda:1",
# )
# engine_cnnvoxel.run()
# input_opt = engine_cnnvoxel.optim_input(input, input_gt, lr_init=1e-2, lr_final=1e-3, max_steps=5000)

In [21]:
# Multi-resolution grid (levels res, res // 2, ..., 1, summed) regressed onto
# the reference MLP for a single step, then inverted.
# NOTE(review): `input` / `input_gt` are (128, 2) here, so this needs a
# ListVoxel that accepts 2-D queries and exposes `input_dim` (read by
# Engine.run). The captured output below includes a `valid` field that
# optim_input no longer prints, so it likely comes from an earlier version;
# confirm before relying on it.
res = 8
bbox = [0, 0, 0, 1, 1, 1]
engine_voxel = Engine(
    model=ListVoxel(res=res, bbox=bbox, output_dim=21),
    target_func=lambda x: engine_mlp.model(x),
    lr_init=1e1,
    lr_final=1e-1,
    max_steps=1,
    batch_size=40960,
    device="cuda:1",
)
engine_voxel.run()
input_opt = engine_voxel.optim_input(input, input_gt, lr_init=1e0, lr_final=1e0, max_steps=2000)

step 0000000 lr 10.0000: loss 1.2410700: 100%|██████████| 1/1 [00:00<00:00, 39.43it/s]
step 0001999 lr 1.0000: loss 0.1599888 grad 0.0000060 err  0.0998197 valid 1.000: 100%|██████████| 2000/2000 [00:05<00:00, 354.79it/s]


In [None]:
# 3-D query points for the legacy experiments below.
# `device` is otherwise first assigned in a later cell, so running the
# notebook top to bottom would raise NameError here; define it locally.
device = "cuda:1"
input = torch.rand((1024, 3), device=device)
input_gt = torch.rand((1024, 3), device=device)

In [None]:
# Invert `mlp` by optimising its input with Adam.
# NOTE(review): `mlp` and `y_mlp_gt` are not defined in any earlier cell shown
# here (`y_mlp_gt` is assigned further down), presumably left over from an
# earlier version of this notebook; confirm before running.
# Move to the device once, *before* creating the leaf. Rebinding
# `x = x.to(device)` inside the loop would, if `input` lived elsewhere,
# replace the optimised leaf with a non-leaf copy whose .grad stays None.
x = input.clone().detach().to(device).requires_grad_(True)

optimizer = torch.optim.Adam([x], lr=1e-3)
for i in range(10000):
    y = mlp(x)
    loss = F.mse_loss(y, y_mlp_gt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("i", i, "loss", loss.data, "grad", x.grad.abs().mean(), "err", F.mse_loss(x, input_gt).data)

In [None]:
# Fit a 256^3 Voxel grid to the identity map on [0, 1)^3.
device = "cuda:1"

# output_dim=3 so the prediction matches the 3-channel target `x`. With the
# default output_dim=1, F.mse_loss would broadcast [N, 1] against [N, 3] and
# fit a single channel to all three coordinates.
# NOTE(review): 3-D queries need a grid_sample mode that supports 5-D input
# (bilinear / nearest); bicubic is 4-D only, so check the mode Voxel uses.
voxel = Voxel(res=256, bbox=[0, 0, 0, 1, 1, 1], output_dim=3).to(device)
optimizer = torch.optim.SGD(voxel.parameters(), lr=1e5)
for i in range(5000):
    x = torch.rand((40960, 3), device=device)
    y = voxel(x)
    loss = F.mse_loss(y, x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("i", i, "loss", loss.data)

In [None]:
# Re-draw 3-D query points and compare `voxel` with `mlp` at `input_gt`.
# NOTE(review): `mlp` is not defined in any cell shown here (the notebook's
# MLP lives in `engine_mlp.model` and takes 2-D inputs), presumably a leftover
# from an earlier version; confirm before running.
input = torch.rand((1024, 3)).to(device)
input_gt = torch.rand((1024, 3)).to(device)
with torch.no_grad():
    y_voxel_gt = voxel(input_gt)
    y_mlp_gt = mlp(input_gt)
    print ((y_voxel_gt - y_mlp_gt).abs().mean())  # mean absolute difference of the two outputs

In [None]:
# Invert `mlp` by optimising its input with Adam (repeat of the earlier cell,
# now run against the freshly drawn `input` / `y_mlp_gt`).
# NOTE(review): `mlp` is not defined in any cell shown here; confirm before running.
# Move to the device once, *before* creating the leaf. Rebinding
# `x = x.to(device)` inside the loop would, if `input` lived elsewhere,
# replace the optimised leaf with a non-leaf copy whose .grad stays None.
x = input.clone().detach().to(device).requires_grad_(True)

optimizer = torch.optim.Adam([x], lr=1e-3)
for i in range(10000):
    y = mlp(x)
    loss = F.mse_loss(y, y_mlp_gt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("i", i, "loss", loss.data, "grad", x.grad.abs().mean(), "err", F.mse_loss(x, input_gt).data)

In [None]:
from typing import List

import torch
import torch.nn as nn
import tinycudann as tcnn


class TCNNHashPositionalEncoder(nn.Module):
    """Hash positional encoder from Instant-NGP (tiny-cuda-nn ``HashGrid``).

    https://github.com/NVlabs/instant-ngp

    Inputs are normalized to [0, 1] per axis with ``bounding_box`` before
    being encoded.
    """

    def __init__(
        self,
        bounding_box: List[float],
        in_dim: int = 3,
        n_levels: int = 16,
        n_features_per_level: int = 2,
        log2_hashmap_size: int = 19,
        base_resolution: int = 16,
        per_level_scale: float = 2.0,
    ):
        """
        :param bounding_box: ``in_dim`` minima followed by ``in_dim`` maxima,
            e.g. ``[min_x, min_y, min_z, max_x, max_y, max_z]`` for 3-D.
        :param in_dim: number of input coordinates.
        :param n_levels: number of hash-grid levels.
        :param n_features_per_level: features stored per level.
        :param log2_hashmap_size: log2 of the per-level hash table size.
        :param base_resolution: resolution of the coarsest level.
        :param per_level_scale: resolution growth factor between levels.
        :raises ValueError: if ``bounding_box`` does not hold ``2 * in_dim`` values.
        """
        super().__init__()
        self.bounding_box = torch.tensor(bounding_box)
        if self.bounding_box.numel() != 2 * in_dim:
            raise ValueError(
                f"bounding_box must have {2 * in_dim} values, "
                f"got {self.bounding_box.numel()}"
            )
        self.in_dim = in_dim
        self.n_levels = n_levels
        self.n_features_per_level = n_features_per_level
        # The input to the tcnn.Encoding should be normalized
        # to (0, 1) using `self.bounding_box`
        self.encoder = tcnn.Encoding(
            n_input_dims=in_dim,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )

    @property
    def out_dim(self):
        """Width of the encoding: ``n_levels * n_features_per_level``."""
        return self.n_levels * self.n_features_per_level

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :params x: [..., in_dim] coordinates inside ``bounding_box``.
        :return x_enc: [..., self.out_dim], cast back to ``x``'s dtype/device.
        """
        bb_min, bb_max = torch.split(
            self.bounding_box.to(x), [self.in_dim, self.in_dim], dim=0
        )
        x = (x - bb_min) / (bb_max - bb_min)
        batch_shape = list(x.shape[:-1])
        # Encode float32 positions, not half. fp16 has a 10-bit mantissa
        # (~5e-4 spacing just below 1.0), far coarser than the finest level
        # (base_resolution * per_level_scale ** (n_levels - 1) cells per axis),
        # so a half cast would collapse nearby points onto the same cell.
        enc = self.encoder(x.reshape(-1, x.shape[-1]).float())
        return enc.to(x).reshape(batch_shape + [self.out_dim])
        

In [None]:
import torch.nn.functional as F

# Sanity check: can 16 random 3-D points be recovered by gradient descent
# through a fixed, randomly initialised hash encoder?
device = "cuda:9"  # NOTE(review): assigned but not used in this cell

# x_gt and x are created on the CPU; presumably tiny-cuda-nn handles the
# transfer itself -- confirm.
x_gt = torch.rand((16, 3))
encoder_gt = TCNNHashPositionalEncoder([0, 0, 0, 1, 1, 1])
out_gt = encoder_gt(x_gt)

x = torch.rand((16, 3), requires_grad=True)
# encoder = TCNNHashPositionalEncoder([0, 0, 0, 1, 1, 1])
# out = encoder(x)

# Only `x` is meant to move: the encoder's group has lr 0.0, so Adam never
# changes its parameters. NOTE(review): nothing keeps `x` inside the [0, 1]
# bounding box during optimisation.
optimizer = torch.optim.Adam([
    {"params": [x], "lr": 1e-2},
    {"params": encoder_gt.parameters(), "lr": 0.0}
])
for _ in range(10000):    
    out = encoder_gt(x)
    loss = F.mse_loss(out, out_gt.detach())
    err = F.mse_loss(x.detach(), x_gt)  # input error measured before this step's update
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print ("loss", loss.data, "err", err.data, "grad", x.grad.abs().mean())
